From 8b3f7327d7d791926fcf3d1d5f08c8ed6ea15a3d Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 19:24:28 -0600 Subject: [PATCH 01/35] Added Eve Optimizer --- optax/_src/alias.py | 41 ++++++++++++++++++++++ optax/_src/transform.py | 76 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 113 insertions(+), 4 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index c6ae6b60..6dbcf7fd 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -339,6 +339,47 @@ def amsgrad( _scale_by_learning_rate(learning_rate), ) +def eve( + learning_rate: float = 1e-3, + b1: float = 0.9, + b2: float = 0.999, + b3: float = 0.999, + c: float = 10., + eps: float = 1e-8, + f_star: float = 0., + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """The Eve optimizer. + + Eve is an SGD variant with adaptive global and local learning rates. The `learning_rate` + used for each weight is computed from estimates of first- and second-order + moments of the gradients (using suitable exponential moving averages), as in Adam. + The global learning rate is additionally scaled by a measure of sub-optimality: it is + increased when the loss is far from its estimated minimum `f_star` and decreased as + the loss approaches it. + + References: + Hayashi et al, 2018: https://arxiv.org/abs/1611.01505 + + Args: + learning_rate: this is the initial global scaling factor. + b1: the exponential decay rate to track the first moment of past gradients. + b2: the exponential decay rate to track the second moment of past gradients. + b3: the exponential decay rate to track the sub-optimality. + c: the clipping limit to prevent extreme global learning rate changes. + eps: a small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + f_star: estimate of the global minimum of the objective. + mu_dtype: optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + the corresponding `GradientTransformation`. + """ + return combine.chain( + transform.scale_by_eve( + b1=b1, b2=b2, b3=b3, c=c, eps=eps, f_star=f_star, mu_dtype=mu_dtype), + _scale_by_learning_rate(learning_rate), + ) def fromage( learning_rate: float, diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 2bbc75e9..58dd3c29 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -449,6 +449,74 @@ def update_fn(updates, state, params=None): return base.GradientTransformation(init_fn, update_fn) +class ScaleByEveState(NamedTuple): + """State for the Eve algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32.
+ mu: base.Updates + nu: base.Updates + d: float + f_prev: float + + +def scale_by_eve(b1: float = 0.9, + b2: float = 0.999, + b3: float = 0.999, + c: float = 10., + eps: float = 1e-8, + f_star: float = 0., + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """Rescale updates according to the Eve algorithm. + + References: + [Hayashi et al, 2018](https://arxiv.org/abs/1611.01505) + + Args: + b1: the exponential decay rate to track the first moment of past gradients. + b2: the exponential decay rate to track the second moment of past gradients. + b3: the exponential decay rate to track the sub-optimality. + c: the clipping limit to prevent extreme global learning rate changes. + eps: a small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + f_star: estimate of the global minimum of the objective. + mu_dtype: optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + An (init_fn, update_fn) tuple. + """ + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params): + mu = jax.tree_util.tree_map( # First moment + lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) + nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByEveState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f_prev=1.) + + def update_fn(updates: base.Updates, state: ScaleByEveState, f: float): + """Eve requires the current loss f = f_t as an extra argument; the loss of + the previous iteration is carried in the state as state.f_prev = f_{t-1}. + """ + mu = update_moment(updates, state.mu, b1, 1) + nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) + count_inc = numerics.safe_int32_increment(state.count) + # Adam-style bias correction of the moment estimates. + mu_hat = bias_correction(mu, b1, count_inc) + nu_hat = bias_correction(nu, b2, count_inc) + # Sub-optimality ratio; `count_inc` is traced under `jit`, so the first-step + # value is selected with `jnp.where` rather than a Python conditional. + d_new = jnp.abs(f - state.f_prev) / (jnp.minimum(f, state.f_prev) - f_star) + d_tilde = jnp.clip(d_new, 1. / c, c) + d = jnp.where(count_inc > 1, b3 * state.d + (1. - b3) * d_tilde, 1.)
+ updates = jax.tree_util.tree_map( + lambda m, v: m / (jnp.sqrt(v) + eps) / d, mu_hat, nu_hat) + mu = utils.cast_tree(mu, mu_dtype) + return updates, ScaleByEveState(count=count_inc, mu=mu, nu=nu, d=d, f_prev=f) + + return base.GradientTransformation(init_fn, update_fn) + + ScaleState = base.EmptyState From 1374b6e880d01c90f25210b9a8e9ac810c5392f4 Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 19:30:56 -0600 Subject: [PATCH 02/35] renamed for testing --- docs/conf.py | 4 +- docs/ext/coverage_check.py | 8 +- examples/differentially_private_sgd.py | 10 +- examples/flax_example.py | 10 +- examples/haiku_example.py | 10 +- examples/lookahead_mnist.py | 12 +- examples/mnist.py | 10 +- examples/mnist_test.py | 4 +- optax_add_eve/__init__.py | 347 +++++ optax_add_eve/_src/alias.py | 926 +++++++++++++ optax_add_eve/_src/alias_test.py | 186 +++ optax_add_eve/_src/base_test.py | 139 ++ optax_add_eve/_src/clipping.py | 222 +++ optax_add_eve/_src/clipping_test.py | 96 ++ optax_add_eve/_src/combine.py | 150 ++ optax_add_eve/_src/combine_test.py | 152 +++ optax_add_eve/_src/constrain.py | 97 ++ optax_add_eve/_src/constrain_test.py | 116 ++ optax_add_eve/_src/control_variates.py | 419 ++++++ optax_add_eve/_src/control_variates_test.py | 595 ++++++++ optax_add_eve/_src/equivalence_test.py | 176 +++ .../_src/experimental/complex_valued.py | 121 ++ .../_src/experimental/complex_valued_test.py | 79 ++ optax_add_eve/_src/experimental/extra_args.py | 167 +++ .../_src/experimental/extra_args_test.py | 65 + optax_add_eve/_src/factorized.py | 199 +++ optax_add_eve/_src/factorized_test.py | 45 + optax_add_eve/_src/float64_test.py | 94 ++ optax_add_eve/_src/linear_algebra.py | 201 +++ optax_add_eve/_src/linear_algebra_test.py | 62 + optax_add_eve/_src/lookahead.py | 192 +++ optax_add_eve/_src/lookahead_test.py | 140 ++ optax_add_eve/_src/loss.py | 521 +++++++ optax_add_eve/_src/loss_test.py | 500 +++++++ optax_add_eve/_src/numerics_test.py | 112 ++ optax_add_eve/_src/privacy.py | 74 + optax_add_eve/_src/privacy_test.py | 112 ++ optax_add_eve/_src/schedule.py | 620 +++++++++ optax_add_eve/_src/schedule_test.py | 649 +++++++++ optax_add_eve/_src/second_order_test.py | 93 ++ .../_src/stochastic_gradient_estimators.py | 317 +++++ .../stochastic_gradient_estimators_test.py | 371 +++++ optax_add_eve/_src/transform.py | 1206 +++++++++++++++++ optax_add_eve/_src/transform_test.py | 305 +++++ optax_add_eve/_src/update.py | 103 ++ optax_add_eve/_src/update_test.py | 83 ++ optax_add_eve/_src/utils.py | 152 +++ optax_add_eve/_src/utils_test.py | 65 + optax_add_eve/_src/wrappers.py | 547 ++++++++ optax_add_eve/_src/wrappers_test.py | 623 +++++++++ optax_add_eve/experimental/__init__.py | 23 + optax_add_eve/optax_test.py | 29 + 52 files changed, 11525 insertions(+), 34 deletions(-) create mode 100644 optax_add_eve/__init__.py create mode 100644 optax_add_eve/_src/alias.py create mode 100644 optax_add_eve/_src/alias_test.py create mode 100644 optax_add_eve/_src/base_test.py create mode 100644 optax_add_eve/_src/clipping.py create mode 100644 optax_add_eve/_src/clipping_test.py create mode 100644 optax_add_eve/_src/combine.py create mode 100644 optax_add_eve/_src/combine_test.py create mode 100644 optax_add_eve/_src/constrain.py create mode 100644 optax_add_eve/_src/constrain_test.py create mode 100644 optax_add_eve/_src/control_variates.py create mode 100644 optax_add_eve/_src/control_variates_test.py create mode 100644 optax_add_eve/_src/equivalence_test.py create mode 100644
optax_add_eve/_src/experimental/complex_valued.py create mode 100644 optax_add_eve/_src/experimental/complex_valued_test.py create mode 100644 optax_add_eve/_src/experimental/extra_args.py create mode 100644 optax_add_eve/_src/experimental/extra_args_test.py create mode 100644 optax_add_eve/_src/factorized.py create mode 100644 optax_add_eve/_src/factorized_test.py create mode 100644 optax_add_eve/_src/float64_test.py create mode 100644 optax_add_eve/_src/linear_algebra.py create mode 100644 optax_add_eve/_src/linear_algebra_test.py create mode 100644 optax_add_eve/_src/lookahead.py create mode 100644 optax_add_eve/_src/lookahead_test.py create mode 100644 optax_add_eve/_src/loss.py create mode 100644 optax_add_eve/_src/loss_test.py create mode 100644 optax_add_eve/_src/numerics_test.py create mode 100644 optax_add_eve/_src/privacy.py create mode 100644 optax_add_eve/_src/privacy_test.py create mode 100644 optax_add_eve/_src/schedule.py create mode 100644 optax_add_eve/_src/schedule_test.py create mode 100644 optax_add_eve/_src/second_order_test.py create mode 100644 optax_add_eve/_src/stochastic_gradient_estimators.py create mode 100644 optax_add_eve/_src/stochastic_gradient_estimators_test.py create mode 100644 optax_add_eve/_src/transform.py create mode 100644 optax_add_eve/_src/transform_test.py create mode 100644 optax_add_eve/_src/update.py create mode 100644 optax_add_eve/_src/update_test.py create mode 100644 optax_add_eve/_src/utils.py create mode 100644 optax_add_eve/_src/utils_test.py create mode 100644 optax_add_eve/_src/wrappers.py create mode 100644 optax_add_eve/_src/wrappers_test.py create mode 100644 optax_add_eve/experimental/__init__.py create mode 100644 optax_add_eve/optax_test.py diff --git a/docs/conf.py b/docs/conf.py index 936006a0..fc8fc231 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -121,7 +121,7 @@ def new_process_docstring(app, what, name, obj, options, lines): sys.path.insert(0, os.path.abspath('../')) sys.path.append(os.path.abspath('ext')) -import optax +import optax_add_eve from sphinxcontrib import katex # -- Project information ----------------------------------------------------- @@ -246,7 +246,7 @@ def linkcode_resolve(domain, info): # TODO(slebedev): support tags after we release an initial version. return 'https://github.com/deepmind/optax/tree/master/optax/%s#L%d#L%d' % ( os.path.relpath(filename, start=os.path.dirname( - optax.__file__)), lineno, lineno + len(source) - 1) + optax_add_eve.__file__)), lineno, lineno + len(source) - 1) # -- Intersphinx configuration ----------------------------------------------- diff --git a/docs/ext/coverage_check.py b/docs/ext/coverage_check.py index c31cb75f..9b42626e 100644 --- a/docs/ext/coverage_check.py +++ b/docs/ext/coverage_check.py @@ -16,8 +16,8 @@ from typing import Any, Mapping -import optax -from optax._src import test_utils +import optax_add_eve +from optax_add_eve._src import test_utils from sphinx import application from sphinx import builders from sphinx import errors @@ -25,7 +25,7 @@ def optax_public_symbols(): names = set() - for module_name, module in test_utils.find_internal_python_modules(optax): + for module_name, module in test_utils.find_internal_python_modules(optax_add_eve): for name in module.__all__: names.add(module_name + "." 
+ name) return names @@ -55,4 +55,4 @@ def finish(self) -> None: def setup(app: application.Sphinx) -> Mapping[str, Any]: app.add_builder(OptaxCoverageCheck) - return dict(version=optax.__version__, parallel_read_safe=True) + return dict(version=optax_add_eve.__version__, parallel_read_safe=True) diff --git a/examples/differentially_private_sgd.py b/examples/differentially_private_sgd.py index 5cce0953..e011713a 100644 --- a/examples/differentially_private_sgd.py +++ b/examples/differentially_private_sgd.py @@ -70,7 +70,7 @@ import jax from jax.example_libraries import stax import jax.numpy as jnp -import optax +import optax_add_eve # pylint: disable=g-bad-import-order import datasets # Located in the examples folder. @@ -119,7 +119,7 @@ def compute_epsilon(steps, target_delta=1e-5): def loss_fn(params, batch): logits = predict(params, batch['image']) - return optax.softmax_cross_entropy(logits, batch['label']).mean(), logits + return optax_add_eve.softmax_cross_entropy(logits, batch['label']).mean(), logits @jax.jit @@ -136,12 +136,12 @@ def main(_): full_test_batch = next(test_dataset.as_numpy_iterator()) if FLAGS.dpsgd: - tx = optax.dpsgd(learning_rate=FLAGS.learning_rate, + tx = optax_add_eve.dpsgd(learning_rate=FLAGS.learning_rate, l2_norm_clip=FLAGS.l2_norm_clip, noise_multiplier=FLAGS.noise_multiplier, seed=FLAGS.seed) else: - tx = optax.sgd(learning_rate=FLAGS.learning_rate) + tx = optax_add_eve.sgd(learning_rate=FLAGS.learning_rate) @jax.jit def train_step(params, opt_state, batch): @@ -154,7 +154,7 @@ def train_step(params, opt_state, batch): grads, _ = grad_fn(params, batch) updates, new_opt_state = tx.update(grads, opt_state, params) - new_params = optax.apply_updates(params, updates) + new_params = optax_add_eve.apply_updates(params, updates) return new_params, new_opt_state key = jax.random.PRNGKey(FLAGS.seed) diff --git a/examples/flax_example.py b/examples/flax_example.py index a507a02c..d47460a6 100644 --- a/examples/flax_example.py +++ b/examples/flax_example.py @@ -19,7 +19,7 @@ from flax import linen as nn import jax import jax.numpy as jnp -import optax +import optax_add_eve def main(argv): @@ -70,11 +70,11 @@ def squared_error(x, y): # Construct a simple Adam optimiser using the transforms in optax. # You could also just use the `optax.adam` alias, but we show here how # to do so manually so that you may construct your own `custom` optimiser. - tx = optax.chain( + tx = optax_add_eve.chain( # Set the parameters of Adam. Note the learning_rate is not here. - optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8), + optax_add_eve.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8), # Put a minus sign to *minimise* the loss. - optax.scale(-learning_rate) + optax_add_eve.scale(-learning_rate) ) # Create optimiser state. @@ -89,7 +89,7 @@ def squared_error(x, y): # Update the optimiser state, create an update to the params. updates, opt_state = tx.update(grads, opt_state) # Update the parameters. - params = optax.apply_updates(params, updates) + params = optax_add_eve.apply_updates(params, updates) print(f'Loss[{step}] = {loss_val}') diff --git a/examples/haiku_example.py b/examples/haiku_example.py index 3d8bbe2d..0854ab87 100644 --- a/examples/haiku_example.py +++ b/examples/haiku_example.py @@ -19,7 +19,7 @@ import haiku as hk import jax import jax.numpy as jnp -import optax +import optax_add_eve def main(argv): @@ -48,11 +48,11 @@ def mean_square_loss(params, x): # Construct a simple Adam optimiser using the transforms in optax. 
# You could also just use the `optax.adam` alias, but we show here how # to do so manually so that you may construct your own `custom` optimiser. - opt_init, opt_update = optax.chain( + opt_init, opt_update = optax_add_eve.chain( # Set the parameters of Adam. Note the learning_rate is not here. - optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8), + optax_add_eve.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8), # Put a minus sign to *minimise* the loss. - optax.scale(-learning_rate) + optax_add_eve.scale(-learning_rate) ) # Initialise the model's parameters and the optimiser's state. @@ -71,7 +71,7 @@ def mean_square_loss(params, x): # Transform the gradients using the optimiser. updates, opt_state = opt_update(grad, opt_state, params) # Update parameters. - params = optax.apply_updates(params, updates) + params = optax_add_eve.apply_updates(params, updates) if __name__ == '__main__': diff --git a/examples/lookahead_mnist.py b/examples/lookahead_mnist.py index fce49efa..df08264c 100644 --- a/examples/lookahead_mnist.py +++ b/examples/lookahead_mnist.py @@ -19,7 +19,7 @@ import jax from jax import random import jax.numpy as jnp -import optax +import optax_add_eve # pylint: disable=g-bad-import-order import datasets # Located in the examples folder. @@ -45,18 +45,18 @@ def main(unused_argv) -> None: (*HIDDEN_SIZES, num_classes)) # Set up the fast optimizer (adam) and wrap lookahead around it. - fast_optimizer = optax.adam(LEARNING_RATE) - optimizer = optax.lookahead(fast_optimizer, SYNC_PERIOD, SLOW_LEARNING_RATE) + fast_optimizer = optax_add_eve.adam(LEARNING_RATE) + optimizer = optax_add_eve.lookahead(fast_optimizer, SYNC_PERIOD, SLOW_LEARNING_RATE) def get_loss(fast_params, batch): logits = apply_params_fn(fast_params, batch['image']) - return jnp.mean(optax.softmax_cross_entropy(logits, batch['label'])) + return jnp.mean(optax_add_eve.softmax_cross_entropy(logits, batch['label'])) @jax.jit def train_step(params, optimizer_state, batch): grads = jax.grad(get_loss)(params.fast, batch) updates, opt_state = optimizer.update(grads, optimizer_state, params) - return optax.apply_updates(params, updates), opt_state + return optax_add_eve.apply_updates(params, updates), opt_state example_input = next(train_dataset.as_numpy_iterator())['image'] initial_params = init_params_fn(random.PRNGKey(SEED), example_input) @@ -66,7 +66,7 @@ def train_step(params, optimizer_state, batch): # initial model parameters. The first line below is only necessary for the # lookahead wrapper; without it the initial parameters could be used in the # initialization function of the optimizer directly. - params = optax.LookaheadParams.init_synced(initial_params) + params = optax_add_eve.LookaheadParams.init_synced(initial_params) opt_state = optimizer.init(params) # Training loop diff --git a/examples/mnist.py b/examples/mnist.py index d79f97af..ac1c395a 100644 --- a/examples/mnist.py +++ b/examples/mnist.py @@ -22,7 +22,7 @@ import jax from jax import random import jax.numpy as jnp -import optax +import optax_add_eve # pylint: disable=g-bad-import-order import datasets # Located in the examples folder. @@ -70,7 +70,7 @@ def mlp_model(inputs: chex.Array) -> chex.Array: return hk.without_apply_rng(mlp_model) -def train_on_mnist(optimizer: optax.GradientTransformation, +def train_on_mnist(optimizer: optax_add_eve.GradientTransformation, hidden_sizes: Sequence[int]) -> float: """Trains an MLP on MNIST using a given optimizer. 
@@ -90,13 +90,13 @@ def train_on_mnist(optimizer: optax.GradientTransformation, def get_loss(params, batch): logits = apply_params_fn(params, batch['image']) - return jnp.mean(optax.softmax_cross_entropy(logits, batch['label'])) + return jnp.mean(optax_add_eve.softmax_cross_entropy(logits, batch['label'])) @jax.jit def train_step(params, optimizer_state, batch): grads = jax.grad(get_loss)(params, batch) updates, opt_state = optimizer.update(grads, optimizer_state, params) - return optax.apply_updates(params, updates), opt_state + return optax_add_eve.apply_updates(params, updates), opt_state example_input = next(train_dataset.as_numpy_iterator())['image'] params = init_params_fn(random.PRNGKey(SEED), example_input) @@ -116,7 +116,7 @@ def train_step(params, optimizer_state, batch): def main(unused_argv): """Trains an MLP on MNIST using the adam optimizers.""" - return train_on_mnist(optax.adam(LEARNING_RATE), DEFAULT_HIDDEN_SIZES) + return train_on_mnist(optax_add_eve.adam(LEARNING_RATE), DEFAULT_HIDDEN_SIZES) if __name__ == '__main__': diff --git a/examples/mnist_test.py b/examples/mnist_test.py index afc8b636..9c0d8f48 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -21,7 +21,7 @@ import haiku as hk import jax import numpy as np -import optax +import optax_add_eve import tensorflow as tf # pylint: disable=g-bad-import-order @@ -71,7 +71,7 @@ def test_train_on_mnist_can_fit_linear_mock_data(self): dataset = tf.data.Dataset.from_tensor_slices(data).repeat(8).batch(10) with mock.patch.object( datasets, 'load_image_dataset', return_value=dataset): - final_accuracy = mnist.train_on_mnist(optax.adam(0.01), hidden_sizes=(1,)) + final_accuracy = mnist.train_on_mnist(optax_add_eve.adam(0.01), hidden_sizes=(1,)) self.assertEqual(final_accuracy, 1.) diff --git a/optax_add_eve/__init__.py b/optax_add_eve/__init__.py new file mode 100644 index 00000000..ac576393 --- /dev/null +++ b/optax_add_eve/__init__.py @@ -0,0 +1,347 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Optax: composable gradient processing and optimization, in JAX.""" + +from optax_add_eve import experimental +from optax_add_eve._src.alias import adabelief +from optax_add_eve._src.alias import adafactor +from optax_add_eve._src.alias import adagrad +from optax_add_eve._src.alias import adam +from optax_add_eve._src.alias import adamax +from optax_add_eve._src.alias import adamaxw +from optax_add_eve._src.alias import adamw +from optax_add_eve._src.alias import amsgrad +from optax_add_eve._src.alias import dpsgd +from optax_add_eve._src.alias import fromage +from optax_add_eve._src.alias import lamb +from optax_add_eve._src.alias import lars +from optax_add_eve._src.alias import MaskOrFn +from optax_add_eve._src.alias import noisy_sgd +from optax_add_eve._src.alias import novograd +from optax_add_eve._src.alias import optimistic_gradient_descent +from optax_add_eve._src.alias import radam +from optax_add_eve._src.alias import rmsprop +from optax_add_eve._src.alias import ScalarOrSchedule +from optax_add_eve._src.alias import sgd +from optax_add_eve._src.alias import sm3 +from optax_add_eve._src.alias import yogi +from optax_add_eve._src.base import EmptyState +from optax_add_eve._src.base import GradientTransformation +from optax_add_eve._src.base import identity +from optax_add_eve._src.base import OptState +from optax_add_eve._src.base import Params +from optax_add_eve._src.base import Schedule +from optax_add_eve._src.base import set_to_zero +from optax_add_eve._src.base import stateless +from optax_add_eve._src.base import stateless_with_tree_map +from optax_add_eve._src.base import TransformInitFn +from optax_add_eve._src.base import TransformUpdateFn +from optax_add_eve._src.base import Updates +from optax_add_eve._src.clipping import adaptive_grad_clip +from optax_add_eve._src.clipping import AdaptiveGradClipState +from optax_add_eve._src.clipping import clip +from optax_add_eve._src.clipping import clip_by_block_rms +from optax_add_eve._src.clipping import clip_by_global_norm +from optax_add_eve._src.clipping import ClipByGlobalNormState +from optax_add_eve._src.clipping import ClipState +from optax_add_eve._src.clipping import per_example_global_norm_clip +from optax_add_eve._src.combine import chain +from optax_add_eve._src.combine import multi_transform +from optax_add_eve._src.combine import MultiTransformState +from optax_add_eve._src.constrain import keep_params_nonnegative +from optax_add_eve._src.constrain import NonNegativeParamsState +from optax_add_eve._src.constrain import zero_nans +from optax_add_eve._src.constrain import ZeroNansState +from optax_add_eve._src.control_variates import control_delta_method +from optax_add_eve._src.control_variates import control_variates_jacobians +from optax_add_eve._src.control_variates import moving_avg_baseline +from optax_add_eve._src.factorized import FactoredState +from optax_add_eve._src.factorized import scale_by_factored_rms +from optax_add_eve._src.linear_algebra import global_norm +from optax_add_eve._src.linear_algebra import matrix_inverse_pth_root +from optax_add_eve._src.linear_algebra import power_iteration +from optax_add_eve._src.lookahead import lookahead +from optax_add_eve._src.lookahead import LookaheadParams +from optax_add_eve._src.lookahead import LookaheadState +from optax_add_eve._src.loss import cosine_distance +from optax_add_eve._src.loss import cosine_similarity +from optax_add_eve._src.loss import ctc_loss +from 
optax_add_eve._src.loss import ctc_loss_with_forward_probs +from optax_add_eve._src.loss import hinge_loss +from optax_add_eve._src.loss import huber_loss +from optax_add_eve._src.loss import l2_loss +from optax_add_eve._src.loss import log_cosh +from optax_add_eve._src.loss import sigmoid_binary_cross_entropy +from optax_add_eve._src.loss import smooth_labels +from optax_add_eve._src.loss import softmax_cross_entropy +from optax_add_eve._src.loss import softmax_cross_entropy_with_integer_labels +from optax_add_eve._src.numerics import safe_int32_increment +from optax_add_eve._src.numerics import safe_norm +from optax_add_eve._src.numerics import safe_root_mean_squares +from optax_add_eve._src.privacy import differentially_private_aggregate +from optax_add_eve._src.privacy import DifferentiallyPrivateAggregateState +from optax_add_eve._src.schedule import constant_schedule +from optax_add_eve._src.schedule import cosine_decay_schedule +from optax_add_eve._src.schedule import cosine_onecycle_schedule +from optax_add_eve._src.schedule import exponential_decay +from optax_add_eve._src.schedule import inject_hyperparams +from optax_add_eve._src.schedule import InjectHyperparamsState +from optax_add_eve._src.schedule import join_schedules +from optax_add_eve._src.schedule import linear_onecycle_schedule +from optax_add_eve._src.schedule import linear_schedule +from optax_add_eve._src.schedule import piecewise_constant_schedule +from optax_add_eve._src.schedule import piecewise_interpolate_schedule +from optax_add_eve._src.schedule import polynomial_schedule +from optax_add_eve._src.schedule import sgdr_schedule +from optax_add_eve._src.schedule import warmup_cosine_decay_schedule +from optax_add_eve._src.schedule import warmup_exponential_decay_schedule +from optax_add_eve._src.second_order import fisher_diag +from optax_add_eve._src.second_order import hessian_diag +from optax_add_eve._src.second_order import hvp +from optax_add_eve._src.stochastic_gradient_estimators import measure_valued_jacobians +from optax_add_eve._src.stochastic_gradient_estimators import pathwise_jacobians +from optax_add_eve._src.stochastic_gradient_estimators import score_function_jacobians +from optax_add_eve._src.transform import add_decayed_weights +from optax_add_eve._src.transform import add_noise +from optax_add_eve._src.transform import AddDecayedWeightsState +from optax_add_eve._src.transform import additive_weight_decay +from optax_add_eve._src.transform import AdditiveWeightDecayState +from optax_add_eve._src.transform import AddNoiseState +from optax_add_eve._src.transform import apply_every +from optax_add_eve._src.transform import ApplyEvery +from optax_add_eve._src.transform import bias_correction +from optax_add_eve._src.transform import centralize +from optax_add_eve._src.transform import ema +from optax_add_eve._src.transform import EmaState +from optax_add_eve._src.transform import scale +from optax_add_eve._src.transform import scale_by_adam +from optax_add_eve._src.transform import scale_by_adamax +from optax_add_eve._src.transform import scale_by_amsgrad +from optax_add_eve._src.transform import scale_by_belief +from optax_add_eve._src.transform import scale_by_novograd +from optax_add_eve._src.transform import scale_by_optimistic_gradient +from optax_add_eve._src.transform import scale_by_param_block_norm +from optax_add_eve._src.transform import scale_by_param_block_rms +from optax_add_eve._src.transform import scale_by_radam +from optax_add_eve._src.transform import scale_by_rms +from 
optax_add_eve._src.transform import scale_by_rss +from optax_add_eve._src.transform import scale_by_schedule +from optax_add_eve._src.transform import scale_by_sm3 +from optax_add_eve._src.transform import scale_by_stddev +from optax_add_eve._src.transform import scale_by_trust_ratio +from optax_add_eve._src.transform import scale_by_yogi +from optax_add_eve._src.transform import ScaleByAdamState +from optax_add_eve._src.transform import ScaleByAmsgradState +from optax_add_eve._src.transform import ScaleByBeliefState +from optax_add_eve._src.transform import ScaleByNovogradState +from optax_add_eve._src.transform import ScaleByRmsState +from optax_add_eve._src.transform import ScaleByRssState +from optax_add_eve._src.transform import ScaleByRStdDevState +from optax_add_eve._src.transform import ScaleByScheduleState +from optax_add_eve._src.transform import ScaleBySM3State +from optax_add_eve._src.transform import ScaleByTrustRatioState +from optax_add_eve._src.transform import ScaleState +from optax_add_eve._src.transform import trace +from optax_add_eve._src.transform import TraceState +from optax_add_eve._src.transform import update_infinity_moment +from optax_add_eve._src.transform import update_moment +from optax_add_eve._src.transform import update_moment_per_elem_norm +from optax_add_eve._src.update import apply_updates +from optax_add_eve._src.update import incremental_update +from optax_add_eve._src.update import periodic_update +from optax_add_eve._src.utils import multi_normal +from optax_add_eve._src.utils import scale_gradient +from optax_add_eve._src.wrappers import apply_if_finite +from optax_add_eve._src.wrappers import ApplyIfFiniteState +from optax_add_eve._src.wrappers import flatten +from optax_add_eve._src.wrappers import masked +from optax_add_eve._src.wrappers import MaskedNode +from optax_add_eve._src.wrappers import MaskedState +from optax_add_eve._src.wrappers import maybe_update +from optax_add_eve._src.wrappers import MaybeUpdateState +from optax_add_eve._src.wrappers import MultiSteps +from optax_add_eve._src.wrappers import MultiStepsState +from optax_add_eve._src.wrappers import ShouldSkipUpdateFunction +from optax_add_eve._src.wrappers import skip_large_updates +from optax_add_eve._src.wrappers import skip_not_finite + +__version__ = "0.1.5.dev" + +__all__ = ( + "adabelief", + "adafactor", + "adagrad", + "adam", + "adamax", + "adamaxw", + "adamw", + "adaptive_grad_clip", + "AdaptiveGradClipState", + "add_decayed_weights", + "add_noise", + "AddDecayedWeightsState", + "additive_weight_decay", + "AdditiveWeightDecayState", + "AddNoiseState", + "amsgrad", + "apply_every", + "apply_if_finite", + "apply_updates", + "ApplyEvery", + "ApplyIfFiniteState", + "centralize", + "chain", + "clip_by_block_rms", + "clip_by_global_norm", + "clip", + "ClipByGlobalNormState", + "ClipState", + "constant_schedule", + "ctc_loss", + "ctc_loss_with_forward_probs", + "control_delta_method", + "control_variates_jacobians", + "cosine_decay_schedule", + "cosine_distance", + "cosine_onecycle_schedule", + "cosine_similarity", + "differentially_private_aggregate", + "DifferentiallyPrivateAggregateState", + "dpsgd", + "ema", + "EmaState", + "EmptyState", + "exponential_decay", + "FactoredState", + "fisher_diag", + "flatten", + "fromage", + "global_norm", + "GradientTransformation", + "hinge_loss", + "hessian_diag", + "huber_loss", + "hvp", + "identity", + "incremental_update", + "inject_hyperparams", + "InjectHyperparamsState", + "join_schedules", + "keep_params_nonnegative", + "l2_loss", + 
"lamb", + "lars", + "linear_onecycle_schedule", + "linear_schedule", + "log_cosh", + "lookahead", + "LookaheadParams", + "LookaheadState", + "masked", + "MaskOrFn", + "MaskedState", + "matrix_inverse_pth_root", + "maybe_update", + "MaybeUpdateState", + "measure_valued_jacobians", + "moving_avg_baseline", + "multi_normal", + "multi_transform", + "MultiSteps", + "MultiStepsState", + "MultiTransformState", + "noisy_sgd", + "novograd", + "NonNegativeParamsState", + "OptState", + "Params", + "pathwise_jacobians", + "periodic_update", + "per_example_global_norm_clip", + "piecewise_constant_schedule", + "piecewise_interpolate_schedule", + "polynomial_schedule", + "power_iteration", + "radam", + "rmsprop", + "safe_int32_increment", + "safe_norm", + "safe_root_mean_squares", + "ScalarOrSchedule", + "scale_by_adam", + "scale_by_adamax", + "scale_by_amsgrad", + "scale_by_belief", + "scale_by_factored_rms", + "scale_by_novograd", + "scale_by_param_block_norm", + "scale_by_param_block_rms", + "scale_by_radam", + "scale_by_rms", + "scale_by_rss", + "scale_by_schedule", + "scale_by_sm3", + "scale_by_stddev", + "scale_by_trust_ratio", + "scale_by_yogi", + "scale_gradient", + "scale", + "ScaleByAdamState", + "ScaleByAmsgradState", + "ScaleByBeliefState", + "ScaleByNovogradState", + "ScaleByRmsState", + "ScaleByRssState", + "ScaleByRStdDevState", + "ScaleByScheduleState", + "ScaleBySM3State", + "ScaleByTrustRatioState", + "ScaleState", + "Schedule", + "score_function_jacobians", + "set_to_zero", + "sgd", + "sgdr_schedule", + "ShouldSkipUpdateFunction", + "sigmoid_binary_cross_entropy", + "skip_large_updates", + "skip_not_finite", + "sm3", + "smooth_labels", + "softmax_cross_entropy", + "stateless", + "stateless_with_tree_map", + "trace", + "TraceState", + "TransformInitFn", + "TransformUpdateFn", + "Updates", + "warmup_cosine_decay_schedule", + "warmup_exponential_decay_schedule", + "yogi", + "zero_nans", + "ZeroNansState", +) + +# _________________________________________ +# / Please don't use symbols in `_src` they \ +# \ are not part of the Optax public API. / +# ----------------------------------------- +# \ ^__^ +# \ (oo)\_______ +# (__)\ )\/\ +# ||----w | +# || || +# diff --git a/optax_add_eve/_src/alias.py b/optax_add_eve/_src/alias.py new file mode 100644 index 00000000..b5935ae9 --- /dev/null +++ b/optax_add_eve/_src/alias.py @@ -0,0 +1,926 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Aliases for popular optimizers.""" + +from typing import Any, Callable, Optional, Union + +import jax.numpy as jnp + +from optax_add_eve._src import base +from optax_add_eve._src import clipping +from optax_add_eve._src import combine +from optax_add_eve._src import factorized +from optax_add_eve._src import privacy +from optax_add_eve._src import transform +from optax_add_eve._src import wrappers + + +ScalarOrSchedule = Union[float, base.Schedule] +MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]] + + +def _scale_by_learning_rate(learning_rate: ScalarOrSchedule, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return transform.scale_by_schedule(lambda count: m * learning_rate(count)) + return transform.scale(m * learning_rate) + + +def adabelief( + learning_rate: ScalarOrSchedule, + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-16, + eps_root: float = 1e-16) -> base.GradientTransformation: + """The AdaBelief optimizer. + + AdaBelief is an adaptive learning rate optimizer that focuses on fast + convergence, generalization, and stability. It adapts the step size depending + on its "belief" in the gradient direction — the optimizer adaptively scales + the step size by the difference between the predicted and observed gradients. + AdaBelief is a modified version of Adam and contains the same number of + parameters. + + References: + Zhuang et al, 2020: https://arxiv.org/abs/2010.07468 + + Args: + learning_rate: A fixed global scaling factor. + b1: Exponential decay rate to track the first moment of past gradients. + b2: Exponential decay rate to track the second moment of past gradients. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the second moment of the prediction error to + improve numerical stability. If backpropagating gradients through the + gradient transformation (e.g. for meta-learning), this must be non-zero. + + Returns: + The corresponding `GradientTransformation`. + """ + return combine.chain( + transform.scale_by_belief(b1=b1, b2=b2, eps=eps, eps_root=eps_root), + _scale_by_learning_rate(learning_rate), + ) + + +def adafactor( + learning_rate: Optional[ScalarOrSchedule] = None, + min_dim_size_to_factor: int = 128, + decay_rate: float = 0.8, + decay_offset: int = 0, + multiply_by_parameter_scale: float = True, + clipping_threshold: Optional[float] = 1.0, + momentum: Optional[float] = None, + dtype_momentum: Any = jnp.float32, + weight_decay_rate: Optional[float] = None, + eps: float = 1e-30, + factored: bool = True, + weight_decay_mask: MaskOrFn = None, + ) -> base.GradientTransformation: + """The Adafactor optimizer. + + Adafactor is an adaptive learning rate optimizer that focuses on fast + training of large scale neural networks. It saves memory by using a factored + estimate of the second order moments used to scale gradients. + + References: + Shazeer and Stern, 2018: https://arxiv.org/abs/1804.04235 + + Args: + learning_rate: A fixed global scaling factor. Note: the natural scale for + Adafactor's LR is markedly different from Adam, one doesn't use the + 1/sqrt(hidden) correction for this optim with attention-based models. + min_dim_size_to_factor: Only factor the statistics if two array dimensions + have at least this size. + decay_rate: Controls second-moment exponential decay schedule. 
+ decay_offset: For fine-tuning, one may set this to the starting step + number of the fine-tuning phase. + multiply_by_parameter_scale: If True, then scale learning_rate by + parameter norm. If False, provided learning_rate is absolute step size. + clipping_threshold: Optional clipping threshold. Must be >= 1. If None, + clipping is disabled. + momentum: Optional value between 0 and 1, enables momentum and uses extra + memory if non-None! None by default. + dtype_momentum: Data type of momentum buffers. + weight_decay_rate: Optional rate at which to decay weights. + eps: Regularization constant for root mean squared gradient. + factored: Whether to use factored second-moment estimates. + weight_decay_mask: A tree with same structure as (or a prefix of) + the params PyTree, or a Callable that returns such a pytree given + the params/updates. The leaves should be booleans, `True` + for leaves/subtrees you want to apply the transformation to, + and `False` for those you want to skip. + + Returns: + The corresponding `GradientTransformation`. + """ + # The core of the algorithm is a procedure for rescaling gradients + # by a factored estimate of the root mean squared gradients. + # This reduces memory compared to algorithms such as Adam or RmsProp, + # by not having to hold a separate estimate for each weight. + tx = [ + factorized.scale_by_factored_rms( + factored, decay_rate, decay_offset, min_dim_size_to_factor, eps)] + # This basic rescaling is typically combined with one or more of the following + # transformation (all can be disabled via adafactor's constructor args). + if clipping_threshold is not None: + tx.append(clipping.clip_by_block_rms(clipping_threshold)) + if learning_rate is not None: + tx.append(_scale_by_learning_rate(learning_rate, flip_sign=False)) + if multiply_by_parameter_scale: + tx.append(transform.scale_by_param_block_rms()) + if momentum is not None: + tx.append( + transform.ema(momentum, debias=False, accumulator_dtype=dtype_momentum)) + if weight_decay_rate is not None: + tx.append(transform.add_decayed_weights( + weight_decay_rate, mask=weight_decay_mask)) + # In gradient "descent" we follow the negative gradient. + tx.append(transform.scale(-1)) + return combine.chain(*tx) + + +def adagrad( + learning_rate: ScalarOrSchedule, + initial_accumulator_value: float = 0.1, + eps: float = 1e-7 +) -> base.GradientTransformation: + """The Adagrad optimizer. + + Adagrad is an algorithm for gradient based optimization that anneals the + learning rate for each parameter during the course of training. + + WARNING: Adagrad's main limit is the monotonic accumulation of squared + gradients in the denominator: since all terms are >0, the sum keeps growing + during training and the learning rate eventually becomes vanishingly small. + + References: + Duchi et al, 2011: https://jmlr.org/papers/v12/duchi11a.html + + Args: + learning_rate: A fixed global scaling factor. + initial_accumulator_value: Initial value for the accumulator. + eps: A small constant applied to denominator inside of the square root + (as in RMSProp) to avoid dividing by zero when rescaling. + + Returns: + The corresponding `GradientTransformation`. 
+ """ + return combine.chain( + transform.scale_by_rss( + initial_accumulator_value=initial_accumulator_value, eps=eps), + _scale_by_learning_rate(learning_rate), + ) + + +def adam( + learning_rate: ScalarOrSchedule, + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + r"""The classic Adam optimizer. + + Adam is an SGD variant with gradient scaling adaptation. The scaling + used for each parameter is computed from estimates of first and second-order + moments of the gradients (using suitable exponential moving averages). + + Let :math:`\alpha_t` represent the learning rate and :math:`\beta_1, \beta_2`, + :math:`\varepsilon`, :math:`\bar{\varepsilon}` represent the arguments + ``b1``, ``b2``, ``eps`` and ``eps_root`` respectievly. The learning rate is + indexed by :math:`t` since the learning rate may also be provided by a + schedule function. + + The ``init`` function of this optimizer initializes an internal state + :math:`S_0 := (m_0, v_0) = (0, 0)`, representing initial estimates for the + first and second moments. In practice these values are stored as pytrees + containing all zeros, with the same shape as the model updates. + At step :math:`t`, the ``update`` function of this optimizer takes as + arguments the incoming gradients :math:`g_t` and optimizer state :math:`S_t` + and computes updates :math:`u_t` and new state :math:`S_{t+1}`. Thus, for + :math:`t > 0`, we have, + + .. math:: + \begin{align*} + m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ + v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ + \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ + \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ + u_t &\leftarrow \alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + + \bar{\varepsilon}} + \varepsilon} \right)\\ + S_t &\leftarrow (m_t, v_t). + \end{align*} + + References: + Kingma et al, 2014: https://arxiv.org/abs/1412.6980 + + Args: + learning_rate: A fixed global scaling factor. + b1: Exponential decay rate to track the first moment of past gradients. + b2: Exponential decay rate to track the second moment of past gradients. + eps: A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + eps_root: A small constant applied to denominator inside the square root (as + in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. + mu_dtype: Optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + The corresponding `GradientTransformation`. + """ + return combine.chain( + transform.scale_by_adam( + b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype), + _scale_by_learning_rate(learning_rate), + ) + + +def adamw( + learning_rate: ScalarOrSchedule, + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + mu_dtype: Optional[Any] = None, + weight_decay: float = 1e-4, + mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, +) -> base.GradientTransformation: + """Adam with weight decay regularization. + + AdamW uses weight decay to regularize learning towards small weights, as + this leads to better generalization. 
In SGD you can also use L2 regularization + to implement this as an additive loss term, however L2 regularization + does not behave as intended for adaptive gradient algorithms such as Adam. + + References: + Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101 + + Args: + learning_rate: A fixed global scaling factor. + b1: Exponential decay rate to track the first moment of past gradients. + b2: Exponential decay rate to track the second moment of past gradients. + eps: A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + eps_root: A small constant applied to denominator inside the square root (as + in RMSProp), to avoid dividing by zero when rescaling. This is needed for + instance when computing (meta-)gradients through Adam. + mu_dtype: Optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent + with other frameworks such as PyTorch, but different from + (Loshchilov et al, 2019) where the weight decay is only multiplied with + the "schedule multiplier", but not the base learning rate. + mask: A tree with same structure as (or a prefix of) the params PyTree, + or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Adam gradient transformations are applied to all parameters. + + Returns: + The corresponding `GradientTransformation`. + """ + return combine.chain( + transform.scale_by_adam( + b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype), + transform.add_decayed_weights(weight_decay, mask), + _scale_by_learning_rate(learning_rate), + ) + + +def amsgrad( + learning_rate: ScalarOrSchedule, + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """The AMSGrad optimiser. + + The original Adam can fail to converge to the optimal solution in some cases. + AMSGrad guarantees convergence by using a long-term memory of past gradients. + + References: + Reddi et al, 2018: https://openreview.net/forum?id=ryQu7f-RZ + + Args: + learning_rate: A fixed global scaling factor. + b1: Exponential decay rate to track the first moment of past gradients. + b2: Exponential decay rate to track the second moment of past gradients. + eps: A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + eps_root: A small constant applied to denominator inside the square root (as + in RMSProp), to avoid dividing by zero when rescaling. This is needed for + instance when computing (meta-)gradients through Adam. + mu_dtype: Optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + The corresponding `GradientTransformation`. 
+ """ + return combine.chain( + transform.scale_by_amsgrad( + b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype), + _scale_by_learning_rate(learning_rate), + ) + +def eve( + learning_rate: float = 1e-3, + b1: float = 0.9, + b2: float = 0.999, + b3: float = 0.999, + c: float = 10., + eps: float = 1e-8, + f_star: float = 0., + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """The Eve optimizer. + + Eve is an SGD variant with adaptive global and local learning rates. The `learning_rate` + used for each weight is computed from estimates of first- and second-order + moments of the gradients (using suitable exponential moving averages), as in Adam. + The global learning rate is additionally scaled by a measure of sub-optimality: it is + increased when the loss is far from its estimated minimum `f_star` and decreased as + the loss approaches it. + + References: + Hayashi et al, 2018: https://arxiv.org/abs/1611.01505 + + Args: + learning_rate: this is the initial global scaling factor. + b1: the exponential decay rate to track the first moment of past gradients. + b2: the exponential decay rate to track the second moment of past gradients. + b3: the exponential decay rate to track the sub-optimality. + c: the clipping limit to prevent extreme global learning rate changes. + eps: a small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + f_star: estimate of the global minimum of the objective. + mu_dtype: optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + the corresponding `GradientTransformation`. + """ + return combine.chain( + transform.scale_by_eve( + b1=b1, b2=b2, b3=b3, c=c, eps=eps, f_star=f_star, mu_dtype=mu_dtype), + _scale_by_learning_rate(learning_rate), + ) + +def fromage( + learning_rate: float, + min_norm: float = 1e-6 +) -> base.GradientTransformation: + """The Frobenius matched gradient descent (Fromage) optimizer. + + Fromage is a learning algorithm that does not require learning rate tuning. + The optimizer is based on modeling neural network gradients via deep relative + trust (a distance function on deep neural networks). Fromage is similar to the + LARS optimizer and can work on a range of standard neural network benchmarks, + such as natural language Transformers and generative adversarial networks. + + References: + Bernstein et al, 2020: https://arxiv.org/abs/2002.03432 + + Args: + learning_rate: A fixed global scaling factor. + min_norm: A minimum value that the norm of the gradient updates and the norm + of the layer parameters can be clipped to to avoid dividing by zero when + computing the trust ratio (as in the LARS paper). + + Returns: + The corresponding `GradientTransformation`. + """ + mult = 1 / jnp.sqrt(1 + learning_rate ** 2) + return combine.chain( + transform.scale_by_trust_ratio(min_norm), + _scale_by_learning_rate(learning_rate * mult), + transform.add_decayed_weights((mult - 1)), + ) + + +def lars( + learning_rate: ScalarOrSchedule, + weight_decay: float = 0., + weight_decay_mask: MaskOrFn = True, + trust_coefficient: float = 0.001, + eps: float = 0., + trust_ratio_mask: MaskOrFn = True, + momentum: float = 0.9, + nesterov: bool = False, +) -> base.GradientTransformation: + """The LARS optimizer. + + LARS is a layer-wise adaptive optimizer introduced to help scale SGD to + larger batch sizes. LARS later inspired the LAMB optimizer.
+ + References: + You et al, 2017: https://arxiv.org/abs/1708.03888 + + Args: + learning_rate: A fixed global scaling factor. + weight_decay: Strength of the weight decay regularization. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the transformation to, and `False` for those you want to skip. + trust_coefficient: A multiplier for the trust ratio. + eps: Optional additive constant in the trust ratio denominator. + trust_ratio_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the transformation to, and `False` for those you want to skip. + momentum: Decay rate for momentum. + nesterov: Whether to use Nesterov momentum. + + Returns: + The corresponding `GradientTransformation`. + """ + return combine.chain( + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), + wrappers.masked( + inner=transform.scale_by_trust_ratio( + trust_coefficient=trust_coefficient, eps=eps), + mask=trust_ratio_mask), + _scale_by_learning_rate(learning_rate), + transform.trace(decay=momentum, nesterov=nesterov), + ) + + +def lamb( + learning_rate: ScalarOrSchedule, + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-6, + eps_root: float = 0.0, + weight_decay: float = 0., + mask: MaskOrFn = None, +) -> base.GradientTransformation: + """The LAMB optimizer. + + LAMB is a general purpose layer-wise adaptive large batch optimizer designed + to provide consistent training performance across a wide range of tasks, + including those that use attention-based models (such as Transformers) and + ResNet-50. The optimizer is able to work with small and large batch sizes. + LAMB was inspired by the LARS learning algorithm. + + References: + You et al, 2019: https://arxiv.org/abs/1904.00962 + + Args: + learning_rate: A fixed global scaling factor. + b1: Exponential decay rate to track the first moment of past gradients. + b2: Exponential decay rate to track the second moment of past gradients. + eps: A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + eps_root: A small constant applied to denominator inside the square root (as + in RMSProp), to avoid dividing by zero when rescaling. This is needed for + instance when computing (meta-)gradients through Adam. + weight_decay: Strength of the weight decay regularization. + mask: A tree with same structure as (or a prefix of) the params PyTree, + or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the transformation to, and `False` for those you want to skip. + + Returns: + The corresponding `GradientTransformation`. + """ + return combine.chain( + transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root), + transform.add_decayed_weights(weight_decay=weight_decay, mask=mask), + transform.scale_by_trust_ratio(), + _scale_by_learning_rate(learning_rate), + ) + + +def noisy_sgd( + learning_rate: ScalarOrSchedule, + eta: float = 0.01, + gamma: float = 0.55, + seed: int = 0 +) -> base.GradientTransformation: + r"""A variant of SGD with added noise. 
+ + It has been found that adding noise to the gradients can improve + both the training error and the generalization error in very deep networks. + + References: + Neelakantan et al, 2014: https://arxiv.org/abs/1511.06807 + + Args: + learning_rate: A fixed global scaling factor. + eta: Initial variance for the Gaussian noise added to gradients. + gamma: A parameter controlling the annealing of noise over time, the + variance decays according to `(1+t)^-\gamma`. + seed: Seed for the pseudo-random generation process. + + Returns: + The corresponding `GradientTransformation`. + """ + return combine.chain( + transform.add_noise(eta, gamma, seed), + _scale_by_learning_rate(learning_rate), + ) + + +def novograd( + learning_rate: ScalarOrSchedule, + b1: float = 0.9, + b2: float = 0.25, + eps: float = 1e-6, + eps_root: float = 0.0, + weight_decay: float = 0., +) -> base.GradientTransformation: + """NovoGrad optimizer. + + NovoGrad is more robust to the initial learning rate and + weight initialization than other methods. For example, + NovoGrad works well without LR warm-up, while other methods require it. + NovoGrad performs exceptionally well for large batch training, e.g. it + outperforms other methods for ResNet-50 for all batches up to 32K. + In addition, NovoGrad requires half the memory compared to Adam. + It was introduced together with Jasper ASR model. + + References: + Ginsburg et al, 2019: https://arxiv.org/abs/1905.11286 + Li et al, 2019: https://arxiv.org/abs/1904.03288 + + Args: + learning_rate: A fixed global scaling factor. + b1: An exponential decay rate to track the first moment of past gradients. + b2: An exponential decay rate to track the second moment of past gradients. + eps: A small constant applied to denominator outside of the square root (as + in the Adam paper) to avoid dividing by zero when rescaling. + eps_root: A small constant applied to denominator inside + the square root (as in RMSProp), to avoid dividing by zero when rescaling. + This is needed for instance when computing (meta-)gradients through Adam. + weight_decay: Strength of the weight decay regularization. + + Returns: + The corresponding `GradientTransformation`. + """ + return combine.chain( + transform.scale_by_novograd( + b1=b1, b2=b2, eps=eps, eps_root=eps_root, weight_decay=weight_decay), + _scale_by_learning_rate(learning_rate), + ) + + +def optimistic_gradient_descent( + learning_rate: ScalarOrSchedule, + alpha: ScalarOrSchedule = 1.0, + beta: ScalarOrSchedule = 1.0 +) -> base.GradientTransformation: + """An Optimistic Gradient Descent optimizer. + + Optimistic gradient descent is an approximation of extra-gradient methods + which require multiple gradient calls to compute the next update. It has + strong formal guarantees for last-iterate convergence in min-max games, for + which standard gradient descent can oscillate or even diverge. + + References: + Mokhtari et al, 2019: https://arxiv.org/abs/1901.08511v2 + + Args: + learning_rate: A fixed global scaling factor. + alpha: Coefficient for generalized OGD. + beta: Coefficient for generalized OGD negative momentum. + + Returns: + A `GradientTransformation`. + """ + return combine.chain( + transform.scale_by_optimistic_gradient(alpha=alpha, beta=beta), + _scale_by_learning_rate(learning_rate) + ) + + +def radam( + learning_rate: ScalarOrSchedule, + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + threshold: float = 5.0 +) -> base.GradientTransformation: + """The Rectified Adam optimizer. 
+
+  The adaptive learning rate in Adam has undesirably large variance in early
+  stages of training, due to the limited number of training samples used to
+  estimate the optimizer's statistics. Rectified Adam addresses this issue
+  by analytically reducing the large variance.
+
+  References:
+    Liu et al, 2019: https://arxiv.org/abs/1908.03265
+
+  Args:
+    learning_rate: A fixed global scaling factor.
+    b1: Exponential decay rate to track the first moment of past gradients.
+    b2: Exponential decay rate to track the second moment of past gradients.
+    eps: A small constant applied to denominator outside of the square root
+      (as in the Adam paper) to avoid dividing by zero when rescaling.
+    eps_root: A small constant applied to denominator inside the square root (as
+      in RMSProp), to avoid dividing by zero when rescaling. This is needed for
+      instance when computing (meta-)gradients through Adam.
+    threshold: Threshold for variance tractability.
+
+  Returns:
+    The corresponding `GradientTransformation`.
+  """
+  return combine.chain(
+      transform.scale_by_radam(
+          b1=b1, b2=b2, eps=eps, eps_root=eps_root, threshold=threshold),
+      _scale_by_learning_rate(learning_rate),
+  )
+
+
+def rmsprop(
+    learning_rate: ScalarOrSchedule,
+    decay: float = 0.9,
+    eps: float = 1e-8,
+    initial_scale: float = 0.,
+    centered: bool = False,
+    momentum: Optional[float] = None,
+    nesterov: bool = False
+) -> base.GradientTransformation:
+  # pylint: disable=line-too-long
+  """A flexible RMSProp optimizer.
+
+  RMSProp is an SGD variant with learning rate adaptation. The `learning_rate`
+  used for each weight is scaled by a suitable estimate of the magnitude of the
+  gradients on previous steps. Several variants of RMSProp can be found
+  in the literature. This alias provides an easy to configure RMSProp
+  optimizer that can be used to switch between several of these variants.
+
+  References:
+    Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf
+    Graves, 2013: https://arxiv.org/abs/1308.0850
+
+  Args:
+    learning_rate: A fixed global scaling factor.
+    decay: Decay used to track the magnitude of previous gradients.
+    eps: A small numerical constant to avoid dividing by zero when rescaling.
+    initial_scale: Initial value of accumulators tracking the magnitude of
+      previous updates. PyTorch uses `0`, TF1 uses `1`. When reproducing results
+      from a paper, verify the value used by the authors.
+    centered: Whether the second moment or the variance of the past gradients is
+      used to rescale the latest gradients.
+    momentum: Decay rate used by the momentum term, when it is set to `None`,
+      then momentum is not used at all.
+    nesterov: Whether Nesterov momentum is used.
+
+  Returns:
+    The corresponding `GradientTransformation`.
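+
+  Example (an illustrative sketch, not part of the original documentation;
+  it assumes this alias is exported at the package level as
+  ``optax.rmsprop``)::
+
+    import jax.numpy as jnp
+    import optax
+
+    params = {'w': jnp.zeros(3)}
+    solver = optax.rmsprop(learning_rate=1e-2, momentum=0.9)
+    opt_state = solver.init(params)
+    grads = {'w': jnp.ones(3)}  # gradients, e.g. obtained with jax.grad
+    updates, opt_state = solver.update(grads, opt_state, params)
+    params = optax.apply_updates(params, updates)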
+ """ + # pylint: enable=line-too-long + if centered: + return combine.chain( + transform.scale_by_stddev( + decay=decay, eps=eps, initial_scale=initial_scale), + _scale_by_learning_rate(learning_rate), + (transform.trace(decay=momentum, nesterov=nesterov) + if momentum is not None else base.identity()) + ) + return combine.chain( + transform.scale_by_rms( + decay=decay, eps=eps, initial_scale=initial_scale), + _scale_by_learning_rate(learning_rate), + (transform.trace(decay=momentum, nesterov=nesterov) + if momentum is not None else base.identity()) + ) + + +def sgd( + learning_rate: ScalarOrSchedule, + momentum: Optional[float] = None, + nesterov: bool = False, + accumulator_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """A canonical Stochastic Gradient Descent optimizer. + + This implements stochastic gradient descent. It also includes support for + momentum, and nesterov acceleration, as these are standard practice when + using stochastic gradient descent to train deep neural networks. + + References: + Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf + + Args: + learning_rate: A fixed global scaling factor. + momentum: Decay rate used by the momentum term, when it is set to `None`, + then momentum is not used at all. + nesterov: Whether Nesterov momentum is used. + accumulator_dtype: Optional `dtype` to be used for the accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + A `GradientTransformation`. + """ + return combine.chain( + (transform.trace(decay=momentum, nesterov=nesterov, + accumulator_dtype=accumulator_dtype) + if momentum is not None else base.identity()), + _scale_by_learning_rate(learning_rate) + ) + + +def sm3( + learning_rate: float, + momentum: float = 0.9 +) -> base.GradientTransformation: + """The SM3 optimizer. + + SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) is a + memory-efficient adaptive optimizer designed to decrease memory overhead when + training very large models, such as the Transformer for machine translation, + BERT for language modeling, and AmoebaNet-D for image classification. SM3: 1) + applies to tensors of arbitrary dimensions and any predefined cover of the + parameters; 2) adapts the learning rates in an adaptive and data-driven manner + (like Adagrad and unlike Adafactor); and 3) comes with rigorous convergence + guarantees in stochastic convex optimization settings. + + References: + Anil et al, 2019: https://arxiv.org/abs/1901.11150 + + Args: + learning_rate: A fixed global scaling factor. + momentum: Decay rate used by the momentum term (when it is not set to + `None`, then momentum is not used at all). + + Returns: + The corresponding `GradientTransformation`. + """ + return combine.chain( + transform.scale_by_sm3(momentum), + transform.scale(-learning_rate), + ) + + +def yogi( + learning_rate: ScalarOrSchedule, + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-3, +) -> base.GradientTransformation: + # pylint: disable=line-too-long + """The Yogi optimizer. + + Yogi is an adaptive optimizer, which provides control in tuning the effective + learning rate to prevent it from increasing. By doing so, it focuses on + addressing the issues of convergence and generalization in exponential moving + average-based adaptive methods (such as Adam and RMSprop). Yogi is a + modification of Adam and uses the same parameters. 
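+
+  Example (an illustrative sketch, not part of the original documentation; it
+  assumes the package-level exports ``optax.chain``, ``optax.scale_by_yogi``
+  and ``optax.scale``)::
+
+    import optax
+
+    # Roughly what this alias does for a constant learning rate: rescale the
+    # updates with Yogi statistics, then multiply by the negative step size.
+    solver = optax.chain(
+        optax.scale_by_yogi(b1=0.9, b2=0.999, eps=1e-3),
+        optax.scale(-3e-4),
+    )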
+ + References: + Zaheer et al, 2018: https://proceedings.neurips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf + + Args: + learning_rate: A fixed global scaling factor. + b1: Exponential decay rate to track the first moment of past gradients. + b2: Exponential decay rate to track the second moment of past gradients. + eps: A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + + Returns: + The corresponding `GradientTransformation`. + """ + # pylint: enable=line-too-long + return combine.chain( + transform.scale_by_yogi(b1=b1, b2=b2, eps=eps), + _scale_by_learning_rate(learning_rate), + ) + + +def dpsgd( + learning_rate: ScalarOrSchedule, + l2_norm_clip: float, + noise_multiplier: float, + seed: int, + momentum: Optional[float] = None, + nesterov: bool = False +) -> base.GradientTransformation: + """The DPSGD optimizer. + + Differential privacy is a standard for privacy guarantees of algorithms + learning from aggregate databases including potentially sensitive information. + DPSGD offers protection against a strong adversary with full knowledge of the + training mechanism and access to the model’s parameters. + + WARNING: This `GradientTransformation` expects input updates to have a batch + dimension on the 0th axis. That is, this function expects per-example + gradients as input (which are easy to obtain in JAX using `jax.vmap`). + + References: + Abadi et al, 2016: https://arxiv.org/abs/1607.00133 + + Args: + learning_rate: A fixed global scaling factor. + l2_norm_clip: Maximum L2 norm of the per-example gradients. + noise_multiplier: Ratio of standard deviation to the clipping norm. + seed: Initial seed used for the jax.random.PRNGKey + momentum: Decay rate used by the momentum term, when it is set to `None`, + then momentum is not used at all. + nesterov: Whether Nesterov momentum is used. + + Returns: + A `GradientTransformation`. + """ + return combine.chain( + privacy.differentially_private_aggregate( + l2_norm_clip=l2_norm_clip, + noise_multiplier=noise_multiplier, + seed=seed), + (transform.trace(decay=momentum, nesterov=nesterov) + if momentum is not None else base.identity()), + _scale_by_learning_rate(learning_rate) + ) + + +def adamax( + learning_rate: ScalarOrSchedule, + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, +) -> base.GradientTransformation: + """A variant of the Adam optimizer that uses the infinity norm. + + References: + Kingma et al, 2014: https://arxiv.org/abs/1412.6980 + + Args: + learning_rate: A fixed global scaling factor. + b1: Exponential decay rate to track the first moment of past gradients. + b2: Exponential decay rate to track the maximum of past gradients. + eps: A small constant applied to denominator to avoid dividing by zero when + rescaling. + + Returns: + The corresponding `GradientTransformation`. + """ + return combine.chain( + transform.scale_by_adamax(b1=b1, b2=b2, eps=eps,), + _scale_by_learning_rate(learning_rate), + ) + + +def adamaxw( + learning_rate: ScalarOrSchedule, + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + weight_decay: float = 1e-4, + mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, +) -> base.GradientTransformation: + """Adamax with weight decay regularization. + + AdamaxW uses weight decay to regularize learning towards small weights, as + this leads to better generalization. 
In SGD you can also use L2 regularization + to implement this as an additive loss term, however L2 regularization + does not behave as intended for adaptive gradient algorithms such as Adam. + + WARNING: Sometimes you may want to skip weight decay for BatchNorm scale or + for the bias parameters. You can use `optax.masked` to make your own AdamaxW + variant where `additive_weight_decay` is applied only to a subset of `params`. + + References: + Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101 + + Args: + learning_rate: A fixed global scaling factor. + b1: Exponential decay rate to track the first moment of past gradients. + b2: Exponential decay rate to track the maximum of past gradients. + eps: A small constant applied to denominator to avoid dividing by zero when + rescaling. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent + with other frameworks such as PyTorch, but different from + (Loshchilov et al, 2019) where the weight decay is only multiplied with + the "schedule multiplier", but not the base learning rate. + mask: A tree with same structure as (or a prefix of) the params PyTree, + or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Adamax gradient transformations are applied to all parameters. + + Returns: + The corresponding `GradientTransformation`. + """ + return combine.chain( + transform.scale_by_adamax(b1=b1, b2=b2, eps=eps), + transform.add_decayed_weights(weight_decay, mask), + _scale_by_learning_rate(learning_rate), + ) diff --git a/optax_add_eve/_src/alias_test.py b/optax_add_eve/_src/alias_test.py new file mode 100644 index 00000000..46f0643d --- /dev/null +++ b/optax_add_eve/_src/alias_test.py @@ -0,0 +1,186 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for `alias.py`.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +import jax +import jax.numpy as jnp + +from optax_add_eve._src import alias +from optax_add_eve._src import numerics +from optax_add_eve._src import schedule +from optax_add_eve._src import update + +_OPTIMIZERS_UNDER_TEST = ( + dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1e-3, momentum=0.9)), + dict(opt_name='adafactor', opt_kwargs=dict(learning_rate=5e-3)), + dict(opt_name='adagrad', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='adam', opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name='adamw', opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name='lars', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1e-3)), + dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1e-3, eta=1e-4)), + dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1e-3)), + dict( + opt_name='optimistic_gradient_descent', + opt_kwargs=dict(learning_rate=2e-3, alpha=0.7, beta=0.1)), + dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=5e-3)), + dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=5e-3, momentum=0.9)), + dict(opt_name='fromage', opt_kwargs=dict(learning_rate=5e-3)), + dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1e-2)), + dict(opt_name='radam', opt_kwargs=dict(learning_rate=5e-3)), + dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1e-1)), + dict( + opt_name='dpsgd', + opt_kwargs=dict( + learning_rate=1e-3, + l2_norm_clip=10., + noise_multiplier=1e-3, + seed=0, + momentum=0.2)), +) + + +def _setup_parabola(dtype): + """Quadratic function as an optimization target.""" + initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype) + final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype) + + if jnp.iscomplexobj(dtype): + final_params *= 1 + 1j + + @jax.grad + def get_updates(params): + return jnp.sum(numerics.abs_sq(params - final_params)) + + return initial_params, final_params, get_updates + + +def _setup_rosenbrock(dtype): + """Rosenbrock function as an optimization target.""" + a = 1.0 + b = 100.0 + + if jnp.iscomplexobj(dtype): + a *= 1 + 1j + + initial_params = jnp.array([0.0, 0.0], dtype=dtype) + final_params = jnp.array([a, a**2], dtype=dtype) + + @jax.grad + def get_updates(params): + return (numerics.abs_sq(a - params[0]) + + b * numerics.abs_sq(params[1] - params[0]**2)) + + return initial_params, final_params, get_updates + + +class AliasTest(chex.TestCase): + + @parameterized.product( + _OPTIMIZERS_UNDER_TEST, + target=(_setup_parabola, _setup_rosenbrock), + dtype=(jnp.float32, jnp.complex64), + ) + def test_optimization(self, opt_name, opt_kwargs, target, dtype): + if (opt_name + in ('fromage', 'noisy_sgd', 'sm3', 'optimistic_gradient_descent') and + jnp.iscomplexobj(dtype)): + raise absltest.SkipTest( + f'{opt_name} does not support complex parameters.') + + opt = getattr(alias, opt_name)(**opt_kwargs) + initial_params, final_params, get_updates = target(dtype) + + @jax.jit + def step(params, state): + updates = get_updates(params) + if opt_name == 'dpsgd': + updates = updates[None] + # Complex gradients need to be conjugated before being added to 
parameters + # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 + updates = jax.tree_util.tree_map(lambda x: x.conj(), updates) + updates, state = opt.update(updates, state, params) + params = update.apply_updates(params, updates) + return params, state + + params = initial_params + state = opt.init(params) + for _ in range(10000): + params, state = step(params, state) + + chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2) + + @chex.all_variants + @parameterized.product(_OPTIMIZERS_UNDER_TEST) + def test_optimizers_can_be_wrapped_in_inject_hyperparams( + self, opt_name, opt_kwargs): + """Checks that optimizers can be wrapped in inject_hyperparams.""" + # See also https://github.com/deepmind/optax/issues/412. + opt_factory = getattr(alias, opt_name) + opt = opt_factory(**opt_kwargs) + if opt_name == 'adafactor': + # Adafactor wrapped in inject_hyperparams currently needs a static + # argument to be specified in order to be jittable. See issue + # https://github.com/deepmind/optax/issues/412. + opt_inject = schedule.inject_hyperparams( + opt_factory, static_args=('min_dim_size_to_factor',))(**opt_kwargs) + else: + opt_inject = schedule.inject_hyperparams(opt_factory)(**opt_kwargs) + + params = [-jnp.ones((2, 3)), jnp.ones((2, 5, 2))] + grads = [jnp.ones((2, 3)), -jnp.ones((2, 5, 2))] + + state = self.variant(opt.init)(params) + updates, new_state = self.variant(opt.update)(grads, state, params) + + state_inject = self.variant(opt_inject.init)(params) + updates_inject, new_state_inject = self.variant(opt_inject.update)( + grads, state_inject, params) + + with self.subTest('Equality of updates.'): + chex.assert_trees_all_close(updates_inject, updates, rtol=1e-4) + with self.subTest('Equality of new optimizer states.'): + chex.assert_trees_all_close( + new_state_inject.inner_state, new_state, rtol=1e-4) + + @parameterized.named_parameters([ + ('float32', 'float32'), + ('bfloat16', 'bfloat16'), + ('complex64', 'complex64'), + ('None', None), + ]) + def test_explicit_dtype(self, dtype): + expected_dtype = jax.dtypes.canonicalize_dtype(dtype) # None -> float32 + tx = alias.sgd(0.1, momentum=0.9, accumulator_dtype=dtype) + trace_state, _ = tx.init(jnp.array([0.0, 0.0])) + self.assertEqual(expected_dtype, trace_state.trace.dtype) + tx = alias.adam(0.1, mu_dtype=dtype) + adam_state, _ = tx.init(jnp.array([0.0, 0.0])) + self.assertEqual(expected_dtype, adam_state.mu.dtype) + tx = alias.adamw(0.1, mu_dtype=dtype) + adam_state, _, _ = tx.init(jnp.array([0.0, 0.0])) + self.assertEqual(expected_dtype, adam_state.mu.dtype) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/base_test.py b/optax_add_eve/_src/base_test.py new file mode 100644 index 00000000..65c898b4 --- /dev/null +++ b/optax_add_eve/_src/base_test.py @@ -0,0 +1,139 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for base.py.""" + +from absl.testing import absltest + +import chex +import jax +import jax.numpy as jnp +import numpy as np + +from optax_add_eve._src import base + +# pylint:disable=no-value-for-parameter + + +class BaseTest(chex.TestCase): + + def test_typing(self): + """Ensure that the type annotations work for the update function.""" + + def f(updates, opt_state, params=None): + del params + return updates, opt_state + + def g(f: base.TransformUpdateFn): + updates = np.zeros([]) + params = np.zeros([]) + opt_state = np.zeros([]) + + f(updates, opt_state) + f(updates, opt_state, params) + f(updates, opt_state, params=params) + + g(f) + + @chex.all_variants + def test_set_to_zero_returns_tree_of_correct_zero_arrays(self): + """Tests that zero transform returns a tree of zeros of correct shape.""" + grads = ({'a': np.ones((3, 4)), 'b': 1.}, np.ones((1, 2, 3))) + updates, _ = self.variant(base.set_to_zero().update)(grads, + base.EmptyState()) + correct_zeros = ({'a': np.zeros((3, 4)), 'b': 0.}, np.zeros((1, 2, 3))) + chex.assert_trees_all_close(updates, correct_zeros, rtol=0) + + @chex.all_variants(with_pmap=False) + def test_set_to_zero_is_stateless(self): + """Tests that the zero transform returns an empty state.""" + self.assertEqual( + self.variant(base.set_to_zero().init)(params=None), base.EmptyState()) + + +class StatelessTest(chex.TestCase): + """Tests for the stateless transformation.""" + + @chex.all_variants + def test_stateless(self): + params = {'a': jnp.zeros((1, 2)), 'b': jnp.ones((1,))} + updates = {'a': jnp.ones((1, 2)), 'b': jnp.full((1,), 2.0)} + + @base.stateless + def opt(g, p): + return jax.tree_util.tree_map(lambda g_, p_: g_ + 0.1 * p_, g, p) + + state = opt.init(params) + update_fn = self.variant(opt.update) + new_updates, _ = update_fn(updates, state, params) + expected_updates = {'a': jnp.ones((1, 2)), 'b': jnp.array([2.1])} + chex.assert_trees_all_close(new_updates, expected_updates) + + @chex.all_variants + def test_stateless_no_params(self): + updates = {'linear': jnp.full((5, 3), 3.0)} + + @base.stateless + def opt(g, _): + return jax.tree_util.tree_map(lambda g_: g_ * 2, g) + + state = opt.init(None) + update_fn = self.variant(opt.update) + new_updates, _ = update_fn(updates, state) + expected_updates = {'linear': jnp.full((5, 3), 6.0)} + chex.assert_trees_all_close(new_updates, expected_updates) + + def test_init_returns_emptystate(self): + def weight_decay(g, p): + return jax.tree_util.tree_map(lambda g_, p_: g_ + 0.1 * p_, g, p) + + opt = base.stateless(weight_decay) + state = opt.init(None) + self.assertIsInstance(state, base.EmptyState) + + +class StatelessWithTreeMapTest(chex.TestCase): + """Tests for the stateless_with_tree_map transformation.""" + + @chex.all_variants + def test_stateless_with_tree_map(self): + params = {'a': jnp.zeros((1, 2)), 'b': jnp.ones((1,))} + updates = {'a': jnp.ones((1, 2)), 'b': jnp.full((1,), 2.0)} + + opt = base.stateless_with_tree_map(lambda g, p: g + 0.1 * p) + state = opt.init(params) + update_fn = self.variant(opt.update) + new_updates, _ = update_fn(updates, state, params) + expected_updates = {'a': jnp.ones((1, 2)), 'b': jnp.array([2.1])} + chex.assert_trees_all_close(new_updates, expected_updates) + + @chex.all_variants + def test_stateless_with_tree_map_no_params(self): + updates = {'linear': jnp.full((5, 3), 3.0)} + + opt = base.stateless_with_tree_map(lambda g, _: g * 2.0) + state = opt.init(None) + update_fn = 
self.variant(opt.update) + new_updates, _ = update_fn(updates, state) + expected_updates = {'linear': jnp.full((5, 3), 6.0)} + chex.assert_trees_all_close(new_updates, expected_updates) + + def test_init_returns_emptystate(self): + opt = base.stateless_with_tree_map(lambda g, p: g + 0.1 * p) + state = opt.init(None) + self.assertIsInstance(state, base.EmptyState) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/clipping.py b/optax_add_eve/_src/clipping.py new file mode 100644 index 00000000..5eb1dc9d --- /dev/null +++ b/optax_add_eve/_src/clipping.py @@ -0,0 +1,222 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradient clipping transformations. + +Note that complex numbers are also supported, see +https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 +""" +from typing import Tuple + +import chex +import jax +import jax.numpy as jnp + +from optax_add_eve._src import base +from optax_add_eve._src import linear_algebra +from optax_add_eve._src import numerics + +ClipState = base.EmptyState + + +def clip(max_delta: chex.Numeric) -> base.GradientTransformation: + """Clips updates element-wise, to be in ``[-max_delta, +max_delta]``. + + Args: + max_delta: The maximum absolute value for each element in the update. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return ClipState() + + def update_fn(updates, state, params=None): + del params + updates = jax.tree_util.tree_map( + lambda g: jnp.clip(g, -max_delta, max_delta), updates) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +def clip_by_block_rms(threshold: float) -> base.GradientTransformation: + """Clips updates to a max rms for the gradient of each param vector or matrix. + + A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix + (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree. + + Args: + threshold: The maximum rms for the gradient of each param vector or matrix. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return base.EmptyState() + + def update_fn(updates, state, params=None): + del params + + def _clip_fn(u): + clip_denom = jnp.maximum( + 1.0, + jnp.sqrt(jnp.mean(numerics.abs_sq(u))) / threshold) + return u / clip_denom + + updates = jax.tree_util.tree_map(_clip_fn, updates) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +ClipByGlobalNormState = base.EmptyState + + +def clip_by_global_norm(max_norm: float) -> base.GradientTransformation: + """Clips updates using their global norm. + + References: + [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063) + + Args: + max_norm: The maximum global norm for an update. + + Returns: + A `GradientTransformation` object. 
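+
+  Example (an illustrative sketch, not part of the original documentation; it
+  assumes the package-level exports ``optax.chain``,
+  ``optax.clip_by_global_norm`` and ``optax.adam``)::
+
+    import optax
+
+    # Clip the global gradient norm to 1.0 before applying the Adam update.
+    solver = optax.chain(
+        optax.clip_by_global_norm(1.0),
+        optax.adam(learning_rate=1e-3),
+    )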
+ """ + + def init_fn(params): + del params + return ClipByGlobalNormState() + + def update_fn(updates, state, params=None): + del params + g_norm = linear_algebra.global_norm(updates) + # TODO(b/163995078): revert back to the following (faster) implementation + # once analysed how it affects backprop through update (e.g. meta-gradients) + # g_norm = jnp.maximum(max_norm, g_norm) + # updates = jax.tree_util.tree_map( + # lambda t: (t / g_norm) * max_norm, updates) + trigger = jnp.squeeze(g_norm < max_norm) + chex.assert_shape(trigger, ()) # A scalar. + + def clip_fn(t): + return jax.lax.select(trigger, t, (t / g_norm.astype(t.dtype)) * max_norm) + + updates = jax.tree_util.tree_map(clip_fn, updates) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +def per_example_global_norm_clip(grads: chex.Array, + l2_norm_clip: float) -> Tuple[chex.Array, int]: + """Applies gradient clipping per-example using their global norm. + + References: + [Abadi et al, 2016](https://arxiv.org/abs/1607.00133) + + Args: + grads: flattened update; the function expects these to have a batch + dimension on the 0th axis. + l2_norm_clip: maximum L2 norm of the per-example gradients. + + Returns: + A tuple containing sum of the clipped per-example grads, and the number of + per-example grads that were clipped. + """ + bsize = grads[0].shape[0] + + if any(g.ndim == 0 or bsize != g.shape[0] for g in grads): + raise ValueError( + 'Unlike other transforms, `per_example_global_norm_clip` expects' + ' `grads` to have a batch dimension in the 0th axis.') + + global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads) + divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0) + num_clipped = jnp.greater(divisors, 1.0).sum() + clipped_sum = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads] + return clipped_sum, num_clipped + + +def unitwise_norm(x: chex.Array) -> chex.Array: + """Computes norms of each output unit separately.""" + if jnp.squeeze(x).ndim <= 1: # Scalars and vectors + squared_norm = jnp.sum(numerics.abs_sq(x), keepdims=True) + # Note that this assumes parameters with a shape of length 3 are multihead + # linear parameters--if you wish to apply AGC to 1D convs, you may need + # to modify this line. + elif x.ndim in (2, 3): # Linear layers of shape IO or multihead linear + squared_norm = jnp.sum(numerics.abs_sq(x), axis=0, keepdims=True) + elif x.ndim == 4: # Conv kernels of shape HWIO + squared_norm = jnp.sum(numerics.abs_sq(x), axis=(0, 1, 2), keepdims=True) + else: + raise ValueError( + f'Expected parameter with shape in {1, 2, 3, 4}, got {x.shape}.') + chex.assert_is_broadcastable(squared_norm.shape, x.shape) + return jnp.broadcast_to(jnp.sqrt(squared_norm), x.shape) + + +def unitwise_clip(g_norm: chex.Array, + max_norm: chex.Array, + grad: chex.Array, + div_eps: float = 1e-6) -> chex.Array: + """Applies gradient clipping unit-wise.""" + # This little max(., div_eps) is distinct from the normal eps and just + # prevents division by zero. It technically should be impossible to engage. + clipped_grad = grad * (max_norm / jnp.maximum(g_norm, div_eps)) + chex.assert_equal_shape((g_norm, max_norm, grad, clipped_grad)) + return jnp.where(g_norm < max_norm, grad, clipped_grad) + + +AdaptiveGradClipState = base.EmptyState + + +def adaptive_grad_clip(clipping: float, + eps: float = 1e-3) -> base.GradientTransformation: + """Clips updates to be at most ``clipping * parameter_norm``, unit-wise. 
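+
+  Example (an illustrative sketch, not part of the original documentation; it
+  assumes the package-level exports ``optax.chain``,
+  ``optax.adaptive_grad_clip`` and ``optax.sgd``)::
+
+    import optax
+
+    # Keep each unit-wise update norm below 0.01 times the parameter norm.
+    solver = optax.chain(
+        optax.adaptive_grad_clip(0.01, eps=1e-3),
+        optax.sgd(learning_rate=0.1, momentum=0.9),
+    )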
+
+  References:
+    [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
+    Recognition Without Normalization. (https://arxiv.org/abs/2102.06171)
+
+  Args:
+    clipping: The maximum allowed ratio of update norm to parameter norm.
+    eps: An epsilon term to prevent clipping of zero-initialized params.
+
+  Returns:
+    A `GradientTransformation` object.
+  """
+
+  def init_fn(params):
+    del params
+    return AdaptiveGradClipState()
+
+  def update_fn(updates, state, params):
+    if params is None:
+      raise ValueError(base.NO_PARAMS_MSG)
+    g_norm, p_norm = jax.tree_util.tree_map(unitwise_norm, (updates, params))
+    # Maximum allowable norm.
+    max_norm = jax.tree_util.tree_map(
+        lambda x: clipping * jnp.maximum(x, eps), p_norm)
+    # If grad norm > clipping * param_norm, rescale.
+    updates = jax.tree_util.tree_map(unitwise_clip, g_norm, max_norm, updates)
+    return updates, state
+
+  return base.GradientTransformation(init_fn, update_fn)
diff --git a/optax_add_eve/_src/clipping_test.py b/optax_add_eve/_src/clipping_test.py
new file mode 100644
index 00000000..e2676284
--- /dev/null
+++ b/optax_add_eve/_src/clipping_test.py
@@ -0,0 +1,96 @@
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `clipping.py`."""
+
+from absl.testing import absltest
+
+import chex
+import jax
+import jax.numpy as jnp
+
+from optax_add_eve._src import clipping
+from optax_add_eve._src import linear_algebra
+
+STEPS = 50
+LR = 1e-2
+
+
+class ClippingTest(absltest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
+    self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))
+
+  def test_clip(self):
+    updates = self.per_step_updates
+    # For a sufficiently high delta the update should not be changed.
+    clipper = clipping.clip(1e6)
+    clipped_updates, _ = clipper.update(updates, None)
+    chex.assert_trees_all_close(clipped_updates, updates)
+    # Clipping at delta=1 should make all updates exactly 1.
+    clipper = clipping.clip(1.)
+    clipped_updates, _ = clipper.update(updates, None)
+    chex.assert_trees_all_close(
+        clipped_updates, jax.tree_util.tree_map(jnp.ones_like, updates))
+
+  def test_clip_by_block_rms(self):
+    rmf_fn = lambda t: jnp.sqrt(jnp.mean(t**2))
+    updates = self.per_step_updates
+    for i in range(1, STEPS + 1):
+      clipper = clipping.clip_by_block_rms(1. / i)
+      # Check that the clipper actually works and block rms is <= threshold
+      updates, _ = clipper.update(updates, None)
+      self.assertAlmostEqual(rmf_fn(updates[0]), 1. / i)
+      self.assertAlmostEqual(rmf_fn(updates[1]), 1. / i)
+      # Check that continuously clipping won't cause numerical issues.
+ updates_step, _ = clipper.update(self.per_step_updates, None) + chex.assert_trees_all_close(updates, updates_step) + + def test_clip_by_global_norm(self): + updates = self.per_step_updates + for i in range(1, STEPS + 1): + clipper = clipping.clip_by_global_norm(1. / i) + # Check that the clipper actually works and global norm is <= max_norm + updates, _ = clipper.update(updates, None) + self.assertAlmostEqual( + linear_algebra.global_norm(updates), 1. / i, places=6) + # Check that continuously clipping won't cause numerical issues. + updates_step, _ = clipper.update(self.per_step_updates, None) + chex.assert_trees_all_close(updates, updates_step) + + def test_adaptive_grad_clip(self): + updates = self.per_step_updates + params = self.init_params + for i in range(1, STEPS + 1): + clip_r = 1. / i + clipper = clipping.adaptive_grad_clip(clip_r) + + # Check that the clipper actually works and upd_norm is < c * param_norm. + updates, _ = clipper.update(updates, None, params) + u_norm, p_norm = jax.tree_util.tree_map( + clipping.unitwise_norm, (updates, params)) + cmp = jax.tree_util.tree_map( + lambda u, p, c=clip_r: u - c * p < 1e-6, u_norm, p_norm) + for leaf in jax.tree_util.tree_leaves(cmp): + self.assertTrue(leaf.all()) + + # Check that continuously clipping won't cause numerical issues. + updates_step, _ = clipper.update(self.per_step_updates, None, params) + chex.assert_trees_all_close(updates, updates_step) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/combine.py b/optax_add_eve/_src/combine.py new file mode 100644 index 00000000..a3a4542a --- /dev/null +++ b/optax_add_eve/_src/combine.py @@ -0,0 +1,150 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flexibly compose gradient transformations.""" + +from typing import Callable, NamedTuple, Union, Mapping, Hashable + +import jax + +from optax_add_eve._src import base +from optax_add_eve._src import wrappers + + +def chain( + *args: base.GradientTransformation +) -> base.GradientTransformation: + """Applies a list of chainable update transformations. + + Given a sequence of chainable transforms, `chain` returns an `init_fn` + that constructs a `state` by concatenating the states of the individual + transforms, and returns an `update_fn` which chains the update transformations + feeding the appropriate state to each. + + Args: + *args: a sequence of chainable (init_fn, update_fn) tuples. + + Returns: + A single (init_fn, update_fn) tuple. + """ + + init_fns, update_fns = zip(*args) + + def init_fn(params): + return tuple(fn(params) for fn in init_fns) + + def update_fn(updates, state, params=None): + if len(update_fns) != len(state): + raise ValueError('The number of updates and states has to be the same in ' + 'chain! 
Make sure you have called init first!') + + new_state = [] + for s, fn in zip(state, update_fns): + updates, new_s = fn(updates, s, params) + new_state.append(new_s) + return updates, tuple(new_state) + + return base.GradientTransformation(init_fn, update_fn) + + +class MultiTransformState(NamedTuple): + inner_states: Mapping[Hashable, NamedTuple] + + +def multi_transform( + transforms: Mapping[Hashable, base.GradientTransformation], + param_labels: Union[base.PyTree, Callable[[base.PyTree], base.PyTree]] +) -> base.GradientTransformation: + """Partitions params and applies a different transformation to each subset. + + Below is an example where we apply Adam to the weights and SGD to the biases + of a 2-layer neural network:: + + import optax + import jax + import jax.numpy as jnp + + def map_nested_fn(fn): + '''Recursively apply `fn` to the key-value pairs of a nested dict''' + def map_fn(nested_dict): + return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v)) + for k, v in nested_dict.items()} + return map_fn + + params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)}, + 'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}} + gradients = jax.tree_util.tree_map(jnp.ones_like, params) # dummy gradients + + label_fn = map_nested_fn(lambda k, _: k) + tx = optax.multi_transform({'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, + label_fn) + state = tx.init(params) + updates, new_state = tx.update(gradients, state, params) + new_params = optax.apply_updates(params, updates) + + Instead of providing a ``label_fn``, you may provide a PyTree of labels + directly. Also, this PyTree may be a prefix of the parameters PyTree. This + is demonstrated in the GAN pseudocode below:: + + generator_params = ... + discriminator_params = ... + all_params = (generator_params, discriminator_params) + param_labels = ('generator', 'discriminator') + + tx = optax.multi_transform( + {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)}, + param_labels) + + If you would like to not optimize some parameters, you may wrap + ``optax.multi_transform`` with :func:`optax.masked`. + + Args: + transforms: A mapping from labels to transformations. Each transformation + will be only be applied to parameters with the same label. + param_labels: A PyTree that is the same shape or a prefix of the + parameters/updates (or a function that returns one given the parameters as + input). The leaves of this PyTree correspond to the keys of the transforms + (therefore the values at the leaves must be a subset of the keys). + + Returns: + An ``optax.GradientTransformation``. 
+ """ + def make_mask(labels, group): + return jax.tree_util.tree_map(lambda label: label == group, labels) + + def init_fn(params): + labels = param_labels(params) if callable(param_labels) else param_labels + + label_set = set(jax.tree_util.tree_leaves(labels)) + if not label_set.issubset(transforms.keys()): + raise ValueError('Some parameters have no corresponding transformation.\n' + f'Parameter labels: {list(sorted(label_set))} \n' + f'Transforms keys: {list(sorted(transforms.keys()))} \n') + + inner_states = { + group: wrappers.masked(tx, make_mask(labels, group)).init(params) + for group, tx in transforms.items() + } + return MultiTransformState(inner_states) + + def update_fn(updates, state, params=None): + labels = param_labels(updates) if callable(param_labels) else param_labels + new_inner_state = {} + for group, tx in transforms.items(): + masked_tx = wrappers.masked(tx, make_mask(labels, group)) + updates, new_inner_state[group] = masked_tx.update( + updates, state.inner_states[group], params) + return updates, MultiTransformState(new_inner_state) + + return base.GradientTransformation(init_fn, update_fn) diff --git a/optax_add_eve/_src/combine_test.py b/optax_add_eve/_src/combine_test.py new file mode 100644 index 00000000..122858e7 --- /dev/null +++ b/optax_add_eve/_src/combine_test.py @@ -0,0 +1,152 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `combine.py`.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +import jax +import jax.numpy as jnp + +from optax_add_eve._src import alias +from optax_add_eve._src import combine +from optax_add_eve._src import transform +from optax_add_eve._src import update + + +STEPS = 50 +LR = 1e-2 + + +class ComposeTest(chex.TestCase): + + def setUp(self): + super().setUp() + self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.])) + self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.])) + + @chex.all_variants + def test_chain(self): + transformations = [ + transform.scale_by_adam(), + transform.trace(decay=0, nesterov=False), + transform.scale(-LR)] + + # Apply updates with chain. + chain_params = self.init_params + chained_transforms = combine.chain(*transformations) + state = chained_transforms.init(chain_params) + self.assertIsInstance(state, tuple) + + @self.variant + def update_fn(updates, state): + return chained_transforms.update(updates, state) + + for _ in range(STEPS): + updates, state = update_fn(self.per_step_updates, state) + self.assertIsInstance(state, tuple) + chain_params = update.apply_updates(chain_params, updates) + + # Manually apply sequence of transformations. 
+ manual_params = self.init_params + states = [t.init(manual_params) for t in transformations] + for _ in range(STEPS): + updates = self.per_step_updates + new_states = [] + for t, s in zip(transformations, states): + updates, state = t.update(updates, s) + new_states.append(state) + manual_params = update.apply_updates(manual_params, updates) + states = new_states + + # Check equivalence. + chex.assert_trees_all_close(manual_params, chain_params, rtol=1e-4) + + +def _map_keys_fn(fn): + def map_fn(nested_dict): + return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v)) + for k, v in nested_dict.items()} + return map_fn + + +class MultiTransformTest(chex.TestCase): + """Tests for the multi_transform wrapper.""" + + @chex.all_variants + @parameterized.parameters(True, False) + def test_multi_transform(self, use_fn): + params = {'a1': 1., 'b1': 2., 'z1': {'a2': 3., 'z2': {'c1': 4.}}} + params = jax.tree_util.tree_map(jnp.asarray, params) + input_updates = jax.tree_util.tree_map(lambda x: x / 10.0, params) + tx_dict = {'a': transform.scale(-1.0), + 'b': transform.ema(0.0), # stateful + 'c': transform.scale(2.0)} + param_labels = _map_keys_fn(lambda k, _: k[0]) + if not use_fn: + param_labels = param_labels(params) + tx = combine.multi_transform(tx_dict, param_labels) + update_fn = self.variant(tx.update) + state = self.variant(tx.init)(params) + + correct_update_fn = _map_keys_fn( + lambda k, v: {'a': -v, 'b': v, 'c': 2.0*v}[k[0]]) + + updates, state = update_fn(input_updates, state, params) + correct_updates = correct_update_fn(input_updates) + chex.assert_trees_all_close(updates, correct_updates) + + # Check repeated application, this time with no params. + correct_updates = correct_update_fn(correct_updates) + updates, state = update_fn(updates, state) + chex.assert_trees_all_close(updates, correct_updates) + + @parameterized.parameters(list, tuple, dict) + def test_empty(self, container): + init_fn, update_fn = combine.multi_transform( + {0: alias.sgd(1.)}, lambda _: 0) + updates, _ = update_fn(container(), init_fn(container())) + self.assertEqual(updates, container()) + + @chex.all_variants + @parameterized.parameters( + (False, False), (False, True), (True, False), (True, True)) + def test_labels_mismatch(self, use_extra_label, use_fn): + # The labels from label_fn must be a subet of the keys for the tx. + params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}} + params = jax.tree_util.tree_map(jnp.asarray, params) + label_tree = {'a': 0, 'b': [1, 0], 'c': 1} # prefix of params + + if use_extra_label: + label_tree['a'] = 3 + + transforms = {0: alias.sgd(1.), + 1: alias.adam(1., b1=0., b2=0.), + 2: transform.trace(1.0)} + init_fn, update_fn = combine.multi_transform( + transforms, (lambda _: label_tree) if use_fn else label_tree) + + if use_extra_label: + with self.assertRaises(ValueError): + self.variant(init_fn)(params) + else: + state = self.variant(init_fn)(params) + updates = jax.tree_util.tree_map(lambda x: x / 10.0, params) + self.variant(update_fn)(updates, state) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/constrain.py b/optax_add_eve/_src/constrain.py new file mode 100644 index 00000000..f1bf38e1 --- /dev/null +++ b/optax_add_eve/_src/constrain.py @@ -0,0 +1,97 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradient transformations used to enforce specific constraints.""" + +from typing import Any, NamedTuple + +import jax +import jax.numpy as jnp + +from optax_add_eve._src import base + +# pylint:disable=no-value-for-parameter + + +NonNegativeParamsState = base.EmptyState + + +def keep_params_nonnegative() -> base.GradientTransformation: + """Modifies the updates to keep parameters non-negative, i.e. >= 0. + + This transformation ensures that parameters after the update will be + larger than or equal to zero. + In a chain of transformations, this should be the last one. + + WARNING: the transformation expects input params to be non-negative. + When params is negative the transformed update will move them to 0. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return NonNegativeParamsState() + + def update_fn(updates, state, params): + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + + updates = jax.tree_util.tree_map( + lambda p, u: jnp.where((p + u) < 0., -p, u), params, updates) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +class ZeroNansState(NamedTuple): + """Contains a tree. + + The entry `found_nan` has the same tree structure as that of the parameters. + Each leaf is a single boolean which contains True iff a NaN was detected in + the corresponding parameter array at the last call to `update`. + """ + found_nan: Any + + +def zero_nans() -> base.GradientTransformation: + """A transformation which replaces NaNs with 0. + + Zeroing values in gradients is guaranteed to produce a direction of + non-increasing loss. + + The state of the transformation has the same tree structure as that of the + parameters. Each leaf is a single boolean which contains True iff a NaN was + detected in the corresponding parameter array at the last call to `update`. + This state is not used by the transformation internally, but lets users be + aware when NaNs have been zeroed out. + + Returns: + A `GradientTransformation`. + """ + + def init_fn(params): + return ZeroNansState(jax.tree_util.tree_map( + lambda p: jnp.array(False, dtype=jnp.bool_), params)) + + def update_fn(updates, opt_state, params=None): + del params + opt_state = ZeroNansState( + jax.tree_util.tree_map(lambda p: jnp.any(jnp.isnan(p)), updates)) + updates = jax.tree_util.tree_map( + lambda p: jnp.where(jnp.isnan(p), jnp.zeros_like(p), p), updates) + return updates, opt_state + + return base.GradientTransformation(init=init_fn, update=update_fn) diff --git a/optax_add_eve/_src/constrain_test.py b/optax_add_eve/_src/constrain_test.py new file mode 100644 index 00000000..ca52232b --- /dev/null +++ b/optax_add_eve/_src/constrain_test.py @@ -0,0 +1,116 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optax._src.constrain.""" + +from absl.testing import absltest + +import chex +import jax.numpy as jnp + +from optax_add_eve._src import combine +from optax_add_eve._src import constrain +from optax_add_eve._src import transform +from optax_add_eve._src import update + +STEPS = 50 +LR = 1e-2 + + +class ConstraintsTest(chex.TestCase): + + def test_keep_params_nonnegative(self): + grads = (jnp.array([500., -500., 0.]), + jnp.array([500., -500., 0.]), + jnp.array([500., -500., 0.])) + + params = (jnp.array([-1., -1., -1.]), + jnp.array([1., 1., 1.]), + jnp.array([0., 0., 0.])) + + # vanilla sgd + opt = combine.chain( + transform.trace(decay=0, nesterov=False), transform.scale(-LR)) + opt_state = opt.init(params) + + updates, _ = opt.update(grads, opt_state, params) + new_params = update.apply_updates(params, updates) + + chex.assert_trees_all_close(new_params, (jnp.array([-6., 4., -1.]), + jnp.array([-4., 6., 1.]), + jnp.array([-5., 5., 0.]))) + + # sgd with keeping parameters non-negative + opt = combine.chain( + transform.trace(decay=0, nesterov=False), transform.scale(-LR), + constrain.keep_params_nonnegative()) + opt_state = opt.init(params) + + updates, _ = opt.update(grads, opt_state, params) + new_params = update.apply_updates(params, updates) + + chex.assert_trees_all_close(new_params, (jnp.array([0., 4., 0.]), + jnp.array([0., 6., 1.]), + jnp.array([0., 5., 0.]))) + + @chex.all_variants + def test_zero_nans(self): + params = (jnp.zeros([3]), jnp.zeros([3]), jnp.zeros([3])) + + opt = constrain.zero_nans() + opt_state = self.variant(opt.init)(params) + update_fn = self.variant(opt.update) + + chex.assert_trees_all_close( + opt_state, + constrain.ZeroNansState((jnp.array(False),) * 3)) + + # Check an upate with nans + grads_with_nans = (jnp.ones([3]), + jnp.array([1., float('nan'), float('nan')]), + jnp.array([float('nan'), 1., 1.])) + updates, opt_state = update_fn(grads_with_nans, opt_state) + chex.assert_trees_all_close( + opt_state, + constrain.ZeroNansState( + (jnp.array(False), jnp.array(True), jnp.array(True)))) + chex.assert_trees_all_close( + updates, + (jnp.ones([3]), jnp.array([1., 0., 0.]), jnp.array([0., 1., 1.]))) + + # Check an upate with nans and infs + grads_with_nans_infs = (jnp.ones([3]), + jnp.array([1., float('nan'), + float('nan')]), + jnp.array([float('inf'), 1., 1.])) + updates, opt_state = update_fn(grads_with_nans_infs, opt_state) + chex.assert_trees_all_close( + opt_state, + constrain.ZeroNansState( + (jnp.array(False), jnp.array(True), jnp.array(False)))) + chex.assert_trees_all_close(updates, (jnp.ones([3]), jnp.array( + [1., 0., 0.]), jnp.array([float('inf'), 1., 1.]))) + + # Check an upate with only good values + grads = (jnp.ones([3]), jnp.ones([3]), jnp.ones([3])) + updates, opt_state = update_fn(grads, opt_state) + chex.assert_trees_all_close( + opt_state, + constrain.ZeroNansState( + (jnp.array(False), jnp.array(False), jnp.array(False)))) + chex.assert_trees_all_close(updates, grads) + + +if __name__ == '__main__': + absltest.main() diff --git 
a/optax_add_eve/_src/control_variates.py b/optax_add_eve/_src/control_variates.py new file mode 100644 index 00000000..33316a76 --- /dev/null +++ b/optax_add_eve/_src/control_variates.py @@ -0,0 +1,419 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Implementation of control variates. + + We are interested in computing the gradient using control variates: + \nabla_{\theta} E_{p(x; \theta)} f(x) + = \nabla_{\theta} [E_{p(x; \theta)} f(x) - h(x; \theta) + E_{p(x; \theta)}] + = \nabla_{\theta} [E_{p(x; \theta)} f(x) - h(x; \theta)] + + \nabla_{\theta} E_{p(x; \theta)}] + = \nabla_{\theta} [E_{p(x; \theta)} f(x) - h(x; \theta)] + + \nabla_{\theta} E_{p(x; \theta)}] + = \nabla_{\theta} \int {p(x; \theta)} (f(x) - h(x; \theta)) dx + + \nabla_{\theta} E_{p(x; \theta)}] + = \int \nabla_{\theta} {p(x; \theta)} (f(x) - h(x; \theta)) dx + + [E_{p(x; \theta)} \nabla_{\theta} (f(x) - h(x; \theta)) + + \nabla_{\theta} E_{p(x; \theta)}] + = \int \nabla_{\theta} {p(x; \theta)} (f(x) - h(x; \theta)) dx + - [E_{p(x; \theta)} \nabla_{\theta} h(x; \theta) + + \nabla_{\theta} E_{p(x; \theta)}] + + The above computation is performed in `control_variates_jacobians`. + + When adding a new control variate, one does not need to implement the jacobian + computation, but instead has to implement the forward computation. + + Each control variate implemented has to satisfy the following API: + * control_variate(function) + This returns a tuple of three functions: + * The first element of the tuple is a function which returns the + control variate value for a set of samples. It takes in as + arguments the parameters used to construct the distribution, + the distributional samples, and the state of the control variate + (if any). The return value of this function will have shape + `num_samples`, where `num_samples` is the number of samples + provided as input. + * The second is a function returns the expected value of the control + variate. The input arguments of this function are the parameters + of the distribution and the state of the control variate. + * The third is a function which updates the state of the control + variate, and returns the updated states. + + For examples, see `control_delta_method` and `moving_avg_baseline`. +""" +from typing import Any, Callable, Sequence, Tuple + +import chex +import jax +import jax.numpy as jnp + +from optax_add_eve._src import base + + +CvState = Any +ComputeCv = Callable[[base.Params, chex.Array, CvState], float] +CvExpectedValue = Callable[[base.Params, CvState], CvState] +UpdateCvState = Callable[[base.Params, chex.Array, CvState], CvState] +ControlVariate = Tuple[ComputeCv, CvExpectedValue, UpdateCvState] + + +def control_delta_method( + function: Callable[[chex.Array], float]) -> ControlVariate: + """The control delta covariate method. 
+ + Control variate obtained by performing a second order Taylor expansion + on the cost function f at the mean of the input distribution. + + Only implemented for Gaussian random variables. + + For details, see: https://icml.cc/2012/papers/687.pdf + + Args: + function: The function for which to compute the control variate. + The function takes in one argument (a sample from the distribution) and + returns a floating point value. + + Returns: + A tuple of three functions, to compute the control variate, the + expected value of the control variate, and to update the control variate + state. + """ + + def delta( + params: base.Params, + sample: chex.Array, + state: CvState = None) -> chex.Array: + """"Second order expansion of `function` at the mean of the input dist.""" + del state + mean_dist = params[0] + centered_sample = sample - mean_dist + # Function is a function of samples. Here, we use the mean as the input + # since we do a Taylor expansion of function around the mean. + grads = jax.grad(function)(mean_dist) + hessians = jax.hessian(function)(mean_dist) + assert hessians.ndim == 2 + control_variate = function(mean_dist) + control_variate += jnp.dot(centered_sample, grads) + control_variate += jnp.dot( + jnp.dot(centered_sample, hessians), centered_sample) / 2. + return control_variate + + def expected_value_delta( + params: base.Params, state: CvState) -> float: + """"Expected value of second order expansion of `function` at dist mean.""" + del state + mean_dist = params[0] + var_dist = jnp.square(jnp.exp(params[1])) + hessians = jax.hessian(function)(mean_dist) + + assert hessians.ndim == 2 + hess_diags = jnp.diag(hessians) + assert hess_diags.ndim == 1 + + # Trace (Hessian * Sigma) and we use that Sigma is diagonal. + expected_second_order_term = jnp.sum(var_dist * hess_diags) / 2. + + expected_control_variate = function(mean_dist) + expected_control_variate += expected_second_order_term + return expected_control_variate + + def update_state( + params: base.Params, + samples: chex.Array, + state: CvState = None) -> CvState: + """"No state kept, so no operation is done.""" + del params, samples + return state + + return delta, expected_value_delta, update_state + + +def moving_avg_baseline( + function: Callable[[chex.Array], float], + decay: float = 0.99, + zero_debias: bool = True, + use_decay_early_training_heuristic=True) -> ControlVariate: + """A moving average baseline. + + It has no effect on the pathwise or measure valued estimator. + + Args: + function: The function for which to compute the control variate. + The function takes in one argument (a sample from the distribution) and + returns a floating point value. + decay: The decay rate for the moving average. + zero_debias: Whether or not to use zero debiasing for the moving average. + use_decay_early_training_heuristic: Whether or not to use a heuristic which + overrides the decay value early in training based on + min(decay, (1.0 + i) / (10.0 + i)). This stabilises training and was + adapted from the Tensorflow codebase. + + Returns: + A tuple of three functions, to compute the control variate, the + expected value of the control variate, and to update the control variate + state. 
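+
+  Example (an illustrative sketch, not part of the original documentation; the
+  ``(value, step)`` state layout follows the implementation below)::
+
+    import jax.numpy as jnp
+
+    f = lambda x: jnp.sum(x ** 2)
+    cv, expected_cv, update_state = moving_avg_baseline(f, decay=0.99)
+
+    params = (jnp.zeros(2), jnp.zeros(2))  # e.g. (mean, log_scale)
+    samples = jnp.ones((16, 2))            # a batch of distribution samples
+    state = (jnp.zeros([]), 0)             # (moving average, step counter)
+    state = update_state(params, samples, state)
+    baseline = cv(params, samples[0], state)  # returns the moving average
+    expected = expected_cv(params, state)     # also the moving average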
+ """ + def moving_avg( + params: base.Params, + samples: chex.Array, + state: CvState = None) -> CvState: + """"Return the moving average.""" + del params, samples + return state[0] + + def expected_value_moving_avg( + params: base.Params, state: CvState) -> chex.Array: + """"Return the moving average.""" + del params + return state[0] + + def update_state( + params: base.Params, + samples: chex.Array, + state: CvState = None) -> CvState: + """"Update the moving average.""" + del params + value, i = state + + if use_decay_early_training_heuristic: + iteration_decay = jnp.minimum(decay, (1.0 + i) / (10.0 + i)) + else: + iteration_decay = decay + + updated_value = iteration_decay * value + updated_value += (1 - iteration_decay) * jnp.mean( + jax.vmap(function)(samples)) + + if zero_debias: + updated_value /= (jnp.ones([]) - jnp.power(iteration_decay, i + 1)) + + return (jax.lax.stop_gradient(updated_value), i + 1) + + return moving_avg, expected_value_moving_avg, update_state + + +def _map(cv, params, samples, state): + return jax.vmap(lambda x: cv(params, x, state))(samples) + + +def control_variates_jacobians( + function: Callable[[chex.Array], float], + control_variate_from_function: Callable[[Callable[[chex.Array], float]], + ControlVariate], + grad_estimator: Callable[..., jnp.array], + params: base.Params, + dist_builder: Callable[..., Any], + rng: chex.PRNGKey, + num_samples: int, + control_variate_state: CvState = None, + estimate_cv_coeffs: bool = False, + estimate_cv_coeffs_num_samples: int = 20) -> Tuple[ + Sequence[chex.Array], CvState]: + r"""Obtain jacobians using control variates. + + We will compute each term individually. The first term will use stochastic + gradient estimation. The second term will be computes using Monte + Carlo estimation and automatic differentiation to compute + \nabla_{\theta} h(x; \theta). The the third term will be computed using + automatic differentiation, as we restrict ourselves to control variates + which compute this expectation in closed form. + + This function updates the state of the control variate (once), before + computing the control variate coefficients. + + Args: + function: Function f(x) for which to estimate grads_{params} E_dist f(x). + The function takes in one argument (a sample from the distribution) and + returns a floating point value. + control_variate_from_function: The control variate to use to reduce + variance. See `control_delta_method` and `moving_avg_baseline` examples. + grad_estimator: The gradient estimator to be used to compute the gradients. + Note that not all control variates will reduce variance for all + estimators. For example, the `moving_avg_baseline` will make no difference + to the measure valued or pathwise estimators. + params: A tuple of jnp arrays. + The parameters for which to construct the distribution and for which we + want to compute the jacobians. + dist_builder: a constructor which builds a distribution given the input + parameters specified by params. `dist_builder(params)` should return a + valid distribution. + rng: a PRNGKey key. + num_samples: Int, the number of samples used to compute the grads. + control_variate_state: The control variate state. This is used for control + variates which keep states (such as the moving average baselines). + estimate_cv_coeffs: Boolean. Whether or not to estimate the optimal control + variate coefficient via `estimate_control_variate_coefficients`. + estimate_cv_coeffs_num_samples: The number of samples to use to estimate + the optimal coefficient. 
These need to be new samples to ensure that the + objective is unbiased. + + Returns: + A tuple of size two: + * A tuple of size `params`, each element is `num_samples x param.shape` + jacobian vector containing the estimates of the gradients obtained + for each sample. + The mean of this vector is the gradient wrt to parameters that can be + used for learning. The entire jacobian vector can be used to assess + estimator variance. + * The updated CV state. + """ + control_variate = control_variate_from_function(function) + stochastic_cv, expected_value_cv, update_state_cv = control_variate + data_dim = params[0].shape[0] + if estimate_cv_coeffs: + cv_coeffs = estimate_control_variate_coefficients( + function, control_variate_from_function, grad_estimator, params, + dist_builder, rng, estimate_cv_coeffs_num_samples, + control_variate_state) + else: + cv_coeffs = [1.0] * len(params) + + # \int \nabla_{\theta} {p(x; \theta)} (f(x) - h(x; \theta)) dx + function_jacobians = grad_estimator( + function, params, dist_builder, rng, num_samples) + + # Chain rule since CVs can also depend on parameters - for example, for the + # pathwise gradient estimator they have in order to have an effect on + # gradient. + # The rng has to be the same as passed to the grad_estimator above so that we + # obtain the same samples. + samples = dist_builder(*params).sample((num_samples,), seed=rng) + # If the CV has state, update it. + control_variate_state = update_state_cv( + params, samples, control_variate_state) + + def samples_fn(x): + return stochastic_cv( + jax.lax.stop_gradient(params), x, control_variate_state) + + cv_jacobians = grad_estimator( + samples_fn, params, dist_builder, rng, num_samples) + + # The gradients of the stochastic covariate with respect to the parameters. + def param_fn(x): + return jnp.mean(_map( + stochastic_cv, x, + jax.lax.stop_gradient(samples), control_variate_state)) + + # [E_{p(x; \theta)} \nabla_{\theta} h(x; \theta) + cv_param_grads = jax.grad(param_fn)(params) + # The gradients of the closed form expectation of the control variate + # with respect to the parameters: # \nabla_{\theta} E_{p(x; \theta)}]. 
+ expected_value_grads = jax.grad( + lambda x: expected_value_cv(x, control_variate_state))(params) + + jacobians = [] + for param_index, param in enumerate(params): + chex.assert_shape(function_jacobians[param_index], (num_samples, data_dim)) + chex.assert_shape(cv_jacobians[param_index], (num_samples, data_dim)) + chex.assert_shape(cv_param_grads[param_index], (data_dim,)) + chex.assert_shape(expected_value_grads[param_index], (data_dim,)) + + cv_coeff = cv_coeffs[param_index] + # \int \nabla_{\theta} {p(x; \theta)} (f(x) - h(x; \theta)) dx + param_jacobians = function_jacobians[param_index] + param_jacobians -= cv_coeff * cv_jacobians[param_index] + # - [E_{p(x; \theta)} \nabla_{\theta} h(x; \theta) + param_jacobians -= cv_coeff * cv_param_grads[param_index] + # \nabla_{\theta} E_{p(x; \theta)}] + param_jacobians += cv_coeff * expected_value_grads[param_index] + + chex.assert_shape(param_jacobians, (num_samples,) + param.shape) + jacobians.append(param_jacobians) + + return jacobians, control_variate_state + + +def estimate_control_variate_coefficients( + function: Callable[[chex.Array], float], + control_variate_from_function: Callable[[Callable[[chex.Array], float]], + ControlVariate], + grad_estimator: Callable[..., jnp.array], + params: base.Params, + dist_builder: Callable[..., Any], + rng: chex.PRNGKey, + num_samples: int, + control_variate_state: CvState = None, + eps: float = 1e-3) -> Sequence[float]: + r"""Estimates the control variate coefficients for the given parameters. + + For each variable `var_k`, the coefficient is given by: + \sum_k cov(df/d var_k, d cv/d var_k) / (\sum var(d cv/d var_k) + eps) + + Where var_k is the k'th element of the parameters in `params`. + The covariance and variance calculations are done from samples obtained + from the distribution obtained by calling `dist_builder` on the input + `params`. + + This function does not update the state of the control variate. + + Args: + function: Function f(x) for which to estimate grads_{params} E_dist f(x). + The function takes in one argument (a sample from the distribution) and + returns a floating point value. + control_variate_from_function: The control variate to use to reduce + variance. See `control_delta_method` and `moving_avg_baseline` examples. + grad_estimator: The gradient estimator to be used to compute the gradients. + Note that not all control variates will reduce variance for all + estimators. For example, the `moving_avg_baseline` will make no difference + to the measure valued or pathwise estimators. + params: A tuple of jnp arrays. + The parameters for which to construct the distribution and for which we + want to compute the jacobians. + dist_builder: a constructor which builds a distribution given the input + parameters specified by params. `dist_builder(params)` should return a + valid distribution. + rng: a PRNGKey key. + num_samples: Int, the number of samples used to compute the grads. + control_variate_state: The control variate state. This is used for control + variates which keep states (such as the moving average baselines). + eps: A small constant used to avoid numerical issues. Float. + + Returns: + A list of control variate coefficients (each a scalar), for each parameter + in `params`. + """ + # Resample to avoid biased gradients. + cv_rng, _ = jax.random.split(rng) + del rng # Avoid using rng in this function. + stochastic_cv, _, _ = control_variate_from_function(function) + + # Samples have to be the same so we use the same rng. 
+ cv_jacobians = grad_estimator( + lambda x: stochastic_cv(params, x, control_variate_state), + params, dist_builder, cv_rng, num_samples) + function_jacobians = grad_estimator( + function, params, dist_builder, cv_rng, num_samples) + + def compute_coeff(param_cv_jacs, param_f_jacs): + assert param_f_jacs.ndim == 2 + assert param_cv_jacs.ndim == 2 + + mean_f = jnp.mean(param_f_jacs, axis=0) + mean_cv = jnp.mean(param_cv_jacs, axis=0) + + cov = jnp.mean((param_f_jacs - mean_f) * (param_cv_jacs - mean_cv), axis=0) + + assert cov.ndim == 1 + + # Compute the coefficients which minimize variance. + # Since we want to minimize the variances across parameter dimensions, + # the optimal coefficients are given by the sum of covariances per + # dimensions over the sum of variances per dimension. + cv_coeff = jnp.sum(cov) / (jnp.sum(jnp.var(param_cv_jacs, axis=0)) + eps) + return jax.lax.stop_gradient(cv_coeff) + + return [compute_coeff(cv_jacobians[i], function_jacobians[i]) + for i in range(len(params))] diff --git a/optax_add_eve/_src/control_variates_test.py b/optax_add_eve/_src/control_variates_test.py new file mode 100644 index 00000000..3dc2edd7 --- /dev/null +++ b/optax_add_eve/_src/control_variates_test.py @@ -0,0 +1,595 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `control_variates.py`.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +import jax +import jax.numpy as jnp +import numpy as np + +from optax_add_eve._src import control_variates +from optax_add_eve._src import stochastic_gradient_estimators as sge +from optax_add_eve._src import utils + + +# Set seed for deterministic sampling. +np.random.seed(42) + + +def _assert_equal(actual, expected, rtol=1e-2, atol=1e-2): + """Asserts that arrays are equal.""" + # Note: assert_allclose does not check shapes + chex.assert_equal_shape((actual, expected)) + + # Scalar. 
+ if not actual.shape: + np.testing.assert_allclose( + np.asarray(actual), np.asarray(expected), rtol, atol) + return + + # We get around the bug https://github.com/numpy/numpy/issues/13801 + zero_indices = np.argwhere(expected == 0) + if not np.all(np.abs(actual[zero_indices]) <= atol): + raise AssertionError(f'Larger than {atol} diff in {actual[zero_indices]}') + + non_zero_indices = np.argwhere(expected != 0) + np.testing.assert_allclose( + np.asarray(actual)[non_zero_indices], + expected[non_zero_indices], rtol, atol) + + +def _map(cv, params, samples, state=None): + return jax.vmap(lambda x: cv(params, x, state))(samples) + + +def _map_variant(variant): + return variant(_map, static_argnums=0) + + +def _cv_jac_variant(variant): + return variant( + control_variates.control_variates_jacobians, + static_argnums=(0, 1, 2, 4, 6, 7, 8)) + + +class DeltaControlVariateTest(chex.TestCase): + + @chex.all_variants + @parameterized.parameters([(1.0, 0.5)]) + def testQuadraticFunction(self, effective_mean, effective_log_scale): + data_dims = 20 + num_samples = 10**6 + rng = jax.random.PRNGKey(1) + + mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) + log_scale = effective_log_scale * jnp.ones( + shape=(data_dims), dtype=jnp.float32) + params = [mean, log_scale] + + dist = utils.multi_normal(*params) + dist_samples = dist.sample((num_samples,), rng) + function = lambda x: jnp.sum(x**2) + + cv, expected_cv, _ = control_variates.control_delta_method(function) + avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples)) + expected_cv_value = jnp.sum(dist_samples**2) / num_samples + + # This should be an analytical computation, the result needs to be + # accurate. + _assert_equal(avg_cv, expected_cv_value, rtol=1e-1, atol=1e-3) + _assert_equal(expected_cv(params, None), expected_cv_value, rtol=0.02) + + @chex.all_variants + @parameterized.parameters([(1.0, 1.0)]) + def testPolinomialFunction(self, effective_mean, effective_log_scale): + data_dims = 10 + num_samples = 10**3 + + mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) + log_scale = effective_log_scale * jnp.ones( + shape=(data_dims), dtype=jnp.float32) + params = [mean, log_scale] + + dist = utils.multi_normal(*params) + rng = jax.random.PRNGKey(1) + dist_samples = dist.sample((num_samples,), rng) + function = lambda x: jnp.sum(x**5) + + cv, expected_cv, _ = control_variates.control_delta_method(function) + avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples)) + + # Check that the average value of the control variate is close to the + # expected value. + _assert_equal(avg_cv, expected_cv(params, None), rtol=1e-1, atol=1e-3) + + @chex.all_variants + def testNonPolynomialFunction(self): + data_dims = 10 + num_samples = 10**3 + + mean = jnp.ones(shape=(data_dims), dtype=jnp.float32) + log_scale = jnp.ones(shape=(data_dims), dtype=jnp.float32) + params = [mean, log_scale] + + rng = jax.random.PRNGKey(1) + dist = utils.multi_normal(*params) + dist_samples = dist.sample((num_samples,), rng) + function = lambda x: jnp.sum(jnp.log(x**2)) + + cv, expected_cv, _ = control_variates.control_delta_method(function) + avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples)) + + # Check that the average value of the control variate is close to the + # expected value. + _assert_equal(avg_cv, expected_cv(params, None), rtol=1e-1, atol=1e-3) + + # Second order expansion is log(\mu**2) + 1/2 * \sigma**2 (-2 / \mu**2) + expected_cv_val = - np.exp(1.) 
** 2 * data_dims + _assert_equal( + expected_cv(params, None), expected_cv_val, rtol=1e-1, atol=1e-3) + + +class MovingAverageBaselineTest(chex.TestCase): + + @chex.all_variants + @parameterized.parameters( + [(1.0, 0.5, 0.9), + (1.0, 0.5, 0.99)]) + def testLinearFunction( + self, effective_mean, effective_log_scale, decay): + weights = jnp.array([1., 2., 3.], dtype=jnp.float32) + num_samples = 10**4 + data_dims = len(weights) + + mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) + log_scale = effective_log_scale * jnp.ones( + shape=(data_dims), dtype=jnp.float32) + + params = [mean, log_scale] + function = lambda x: jnp.sum(weights * x) + + rng = jax.random.PRNGKey(1) + dist = utils.multi_normal(*params) + dist_samples = dist.sample((num_samples,), rng) + + cv, expected_cv, update_state = control_variates.moving_avg_baseline( + function, decay=decay, zero_debias=False, + use_decay_early_training_heuristic=False) + + state_1 = jnp.array(1.) + avg_cv = jnp.mean(_map_variant(self.variant)( + cv, params, dist_samples, (state_1, 0))) + _assert_equal(avg_cv, state_1) + _assert_equal(expected_cv(params, (state_1, 0)), state_1) + + state_2 = jnp.array(2.) + avg_cv = jnp.mean( + _map_variant(self.variant)(cv, params, dist_samples, (state_2, 0))) + _assert_equal(avg_cv, state_2) + _assert_equal(expected_cv(params, (state_2, 0)), state_2) + + update_state_1 = update_state(params, dist_samples, (state_1, 0))[0] + _assert_equal( + update_state_1, + decay * state_1 + (1 - decay) * function(mean)) + + update_state_2 = update_state(params, dist_samples, (state_2, 0))[0] + _assert_equal( + update_state_2, + decay * state_2 + (1 - decay) * function(mean)) + + @chex.all_variants + @parameterized.parameters( + [(1.0, 0.5, 0.9), + (1.0, 0.5, 0.99)]) + def testLinearFunctionWithHeuristic( + self, effective_mean, effective_log_scale, decay): + weights = jnp.array([1., 2., 3.], dtype=jnp.float32) + num_samples = 10**5 + data_dims = len(weights) + + mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) + log_scale = effective_log_scale * jnp.ones( + shape=(data_dims), dtype=jnp.float32) + + params = [mean, log_scale] + function = lambda x: jnp.sum(weights * x) + + rng = jax.random.PRNGKey(1) + dist = utils.multi_normal(*params) + dist_samples = dist.sample((num_samples,), rng) + + cv, expected_cv, update_state = control_variates.moving_avg_baseline( + function, decay=decay, zero_debias=False, + use_decay_early_training_heuristic=True) + + state_1 = jnp.array(1.) + avg_cv = jnp.mean(_map_variant(self.variant)( + cv, params, dist_samples, (state_1, 0))) + _assert_equal(avg_cv, state_1) + _assert_equal(expected_cv(params, (state_1, 0)), state_1) + + state_2 = jnp.array(2.) + avg_cv = jnp.mean( + _map_variant(self.variant)(cv, params, dist_samples, (state_2, 0))) + _assert_equal(avg_cv, state_2) + _assert_equal(expected_cv(params, (state_2, 0)), state_2) + + first_step_decay = 0.1 + update_state_1 = update_state(params, dist_samples, (state_1, 0))[0] + _assert_equal( + update_state_1, + first_step_decay * state_1 + (1 - first_step_decay) * function(mean)) + + second_step_decay = 2. 
/ 11 + update_state_2 = update_state(params, dist_samples, (state_2, 1))[0] + _assert_equal( + update_state_2, + second_step_decay * state_2 + (1 - second_step_decay) * function(mean)) + + @parameterized.parameters( + [(1.0, 0.5, 0.9), + (1.0, 0.5, 0.99)]) + def testLinearFunctionZeroDebias( + self, effective_mean, effective_log_scale, decay): + weights = jnp.array([1., 2., 3.], dtype=jnp.float32) + num_samples = 10**5 + data_dims = len(weights) + + mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) + log_scale = effective_log_scale * jnp.ones( + shape=(data_dims), dtype=jnp.float32) + + params = [mean, log_scale] + function = lambda x: jnp.sum(weights * x) + + rng = jax.random.PRNGKey(1) + dist = utils.multi_normal(*params) + dist_samples = dist.sample((num_samples,), rng) + + update_state = control_variates.moving_avg_baseline( + function, decay=decay, zero_debias=False, + use_decay_early_training_heuristic=False)[-1] + + update_state_zero_debias = control_variates.moving_avg_baseline( + function, decay=decay, zero_debias=True, + use_decay_early_training_heuristic=False)[-1] + + updated_state = update_state(params, dist_samples, (jnp.array(0.), 0))[0] + _assert_equal(updated_state, (1 - decay) * function(mean)) + + updated_state_zero_debias = update_state_zero_debias( + params, dist_samples, (jnp.array(0.), 0))[0] + _assert_equal( + updated_state_zero_debias, function(mean)) + + +class DeltaMethodAnalyticalExpectedGrads(chex.TestCase): + """Tests for grads approximations.""" + + @chex.all_variants + @parameterized.named_parameters( + chex.params_product([ + ('_score_function_jacobians', 1.0, 1.0, sge.score_function_jacobians), + ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians), + ('_measure_valued_jacobians', 1.0, 1.0, sge.measure_valued_jacobians), + ], [ + ('estimate_cv_coeffs', True), + ('no_estimate_cv_coeffs', False), + ], + named=True)) + def testQuadraticFunction(self, effective_mean, effective_log_scale, + grad_estimator, estimate_cv_coeffs): + data_dims = 3 + num_samples = 10**3 + + mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) + log_scale = effective_log_scale * jnp.ones( + shape=(data_dims), dtype=jnp.float32) + + params = [mean, log_scale] + function = lambda x: jnp.sum(x**2) + rng = jax.random.PRNGKey(1) + + jacobians = _cv_jac_variant(self.variant)( + function, + control_variates.control_delta_method, + grad_estimator, + params, + utils.multi_normal, # dist_builder + rng, + num_samples, + None, # No cv state. 
+ estimate_cv_coeffs)[0] + + expected_mean_grads = 2 * effective_mean * np.ones( + data_dims, dtype=np.float32) + expected_log_scale_grads = 2 * np.exp(2 * effective_log_scale) * np.ones( + data_dims, dtype=np.float32) + + mean_jacobians = jacobians[0] + chex.assert_shape(mean_jacobians, (num_samples, data_dims)) + mean_grads_from_jacobian = jnp.mean(mean_jacobians, axis=0) + + log_scale_jacobians = jacobians[1] + chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) + log_scale_grads_from_jacobian = jnp.mean(log_scale_jacobians, axis=0) + + _assert_equal(mean_grads_from_jacobian, expected_mean_grads, + rtol=1e-1, atol=1e-3) + _assert_equal(log_scale_grads_from_jacobian, expected_log_scale_grads, + rtol=1e-1, atol=1e-3) + + @chex.all_variants + @parameterized.named_parameters( + chex.params_product([ + ('_score_function_jacobians', 1.0, 1.0, sge.score_function_jacobians), + ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians), + ('_measure_valued_jacobians', 1.0, 1.0, sge.measure_valued_jacobians), + ], [ + ('estimate_cv_coeffs', True), + ('no_estimate_cv_coeffs', False), + ], + named=True)) + def testCubicFunction( + self, effective_mean, effective_log_scale, grad_estimator, + estimate_cv_coeffs): + data_dims = 1 + num_samples = 10**5 + + mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) + log_scale = effective_log_scale * jnp.ones( + shape=(data_dims), dtype=jnp.float32) + + params = [mean, log_scale] + function = lambda x: jnp.sum(x**3) + rng = jax.random.PRNGKey(1) + + jacobians = _cv_jac_variant(self.variant)( + function, + control_variates.control_delta_method, + grad_estimator, + params, + utils.multi_normal, + rng, + num_samples, + None, # No cv state. + estimate_cv_coeffs)[0] + + # The third order uncentered moment of the Gaussian distribution is + # mu**3 + 2 mu * sigma **2. We use that to compute the expected value + # of the gradients. Note: for the log scale we need use the chain rule. 
+ expected_mean_grads = ( + 3 * effective_mean**2 + 3 * np.exp(effective_log_scale)**2) + expected_mean_grads *= np.ones(data_dims, dtype=np.float32) + expected_log_scale_grads = ( + 6 * effective_mean * np.exp(effective_log_scale) ** 2) + expected_log_scale_grads *= np.ones(data_dims, dtype=np.float32) + + mean_jacobians = jacobians[0] + chex.assert_shape(mean_jacobians, (num_samples, data_dims)) + mean_grads_from_jacobian = jnp.mean(mean_jacobians, axis=0) + + log_scale_jacobians = jacobians[1] + chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) + log_scale_grads_from_jacobian = jnp.mean(log_scale_jacobians, axis=0) + + _assert_equal(mean_grads_from_jacobian, expected_mean_grads, + rtol=1e-1, atol=1e-3) + + _assert_equal(log_scale_grads_from_jacobian, expected_log_scale_grads, + rtol=1e-1, atol=1e-3) + + @chex.all_variants + @parameterized.named_parameters( + chex.params_product([ + ('_score_function_jacobians', 1.0, 1.0, sge.score_function_jacobians), + ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians), + ('_measure_valued_jacobians', 1.0, 1.0, sge.measure_valued_jacobians), + ], [ + ('estimate_cv_coeffs', True), + ('no_estimate_cv_coeffs', False), + ], + named=True)) + def testForthPowerFunction( + self, effective_mean, effective_log_scale, grad_estimator, + estimate_cv_coeffs): + data_dims = 1 + num_samples = 10**5 + + mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) + log_scale = effective_log_scale * jnp.ones( + shape=(data_dims), dtype=jnp.float32) + + params = [mean, log_scale] + function = lambda x: jnp.sum(x**4) + rng = jax.random.PRNGKey(1) + + jacobians = _cv_jac_variant(self.variant)( + function, + control_variates.control_delta_method, + grad_estimator, + params, + utils.multi_normal, + rng, + num_samples, + None, # No cv state + estimate_cv_coeffs)[0] + # The third order uncentered moment of the Gaussian distribution is + # mu**4 + 6 mu **2 sigma **2 + 3 sigma**4. We use that to compute the + # expected value of the gradients. + # Note: for the log scale we need use the chain rule. 
+ expected_mean_grads = ( + 3 * effective_mean**3 + + 12 * effective_mean * np.exp(effective_log_scale)**2) + expected_mean_grads *= np.ones(data_dims, dtype=np.float32) + expected_log_scale_grads = 12 * ( + effective_mean**2 * np.exp(effective_log_scale) + + np.exp(effective_log_scale) ** 3) * np.exp(effective_log_scale) + expected_log_scale_grads *= np.ones(data_dims, dtype=np.float32) + + mean_jacobians = jacobians[0] + chex.assert_shape(mean_jacobians, (num_samples, data_dims)) + mean_grads_from_jacobian = jnp.mean(mean_jacobians, axis=0) + + log_scale_jacobians = jacobians[1] + chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) + log_scale_grads_from_jacobian = jnp.mean(log_scale_jacobians, axis=0) + + _assert_equal(mean_grads_from_jacobian, expected_mean_grads, + rtol=1e-1, atol=1e-3) + + _assert_equal(log_scale_grads_from_jacobian, expected_log_scale_grads, + rtol=1e-1, atol=1e-3) + + +class ConsistencyWithStandardEstimators(chex.TestCase): + """Tests for consistency between estimators.""" + + @chex.all_variants + @parameterized.named_parameters( + chex.params_product([ + ('_score_function_jacobians', 1, 1, sge.score_function_jacobians, + 10**6), + ('_pathwise_jacobians', 1, 1, sge.pathwise_jacobians, 10**5), + ('_measure_valued_jacobians', 1, 1, sge.measure_valued_jacobians, + 10**5), + ], [ + ('control_delta_method', control_variates.control_delta_method), + ('moving_avg_baseline', control_variates.moving_avg_baseline), + ], + named=True)) + def testWeightedLinearFunction(self, effective_mean, effective_log_scale, + grad_estimator, num_samples, + control_variate_from_function): + """Check that the gradients are consistent between estimators.""" + weights = jnp.array([1., 2., 3.], dtype=jnp.float32) + data_dims = len(weights) + + mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) + log_scale = effective_log_scale * jnp.ones( + shape=(data_dims), dtype=jnp.float32) + + params = [mean, log_scale] + function = lambda x: jnp.sum(weights * x) + rng = jax.random.PRNGKey(1) + cv_rng, ge_rng = jax.random.split(rng) + + jacobians = _cv_jac_variant(self.variant)( + function, + control_variate_from_function, + grad_estimator, + params, + utils.multi_normal, # dist_builder + cv_rng, # rng + num_samples, + (0., 0), # control_variate_state + False)[0] + + mean_jacobians = jacobians[0] + chex.assert_shape(mean_jacobians, (num_samples, data_dims)) + mean_grads = jnp.mean(mean_jacobians, axis=0) + + log_scale_jacobians = jacobians[1] + chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) + log_scale_grads = jnp.mean(log_scale_jacobians, axis=0) + + # We use a different random number generator for the gradient estimator + # without the control variate. 
+ no_cv_jacobians = grad_estimator( + function, [mean, log_scale], + utils.multi_normal, ge_rng, num_samples=num_samples) + + no_cv_mean_jacobians = no_cv_jacobians[0] + chex.assert_shape(no_cv_mean_jacobians, (num_samples, data_dims)) + no_cv_mean_grads = jnp.mean(no_cv_mean_jacobians, axis=0) + + no_cv_log_scale_jacobians = no_cv_jacobians[1] + chex.assert_shape(no_cv_log_scale_jacobians, (num_samples, data_dims)) + no_cv_log_scale_grads = jnp.mean(no_cv_log_scale_jacobians, axis=0) + + _assert_equal(mean_grads, no_cv_mean_grads, rtol=1e-1, atol=5e-2) + _assert_equal(log_scale_grads, no_cv_log_scale_grads, rtol=1, atol=5e-2) + + @chex.all_variants + @parameterized.named_parameters( + chex.params_product([ + ('_score_function_jacobians', 1, 1, sge.score_function_jacobians, + 10**5), + ('_pathwise_jacobians', 1, 1, sge.pathwise_jacobians, 10**5), + ('_measure_valued_jacobians', 1, 1, sge.measure_valued_jacobians, + 10**5), + ], [ + ('control_delta_method', control_variates.control_delta_method), + ('moving_avg_baseline', control_variates.moving_avg_baseline), + ], + named=True)) + def testNonPolynomialFunction( + self, effective_mean, effective_log_scale, + grad_estimator, num_samples, control_variate_from_function): + """Check that the gradients are consistent between estimators.""" + data_dims = 3 + + mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) + log_scale = effective_log_scale * jnp.ones( + shape=(data_dims), dtype=jnp.float32) + + params = [mean, log_scale] + function = lambda x: jnp.log(jnp.sum(x**2)) + rng = jax.random.PRNGKey(1) + cv_rng, ge_rng = jax.random.split(rng) + + jacobians = _cv_jac_variant(self.variant)( + function, + control_variate_from_function, + grad_estimator, + params, + utils.multi_normal, + cv_rng, + num_samples, + (0., 0), # control_variate_state + False)[0] + + mean_jacobians = jacobians[0] + chex.assert_shape(mean_jacobians, (num_samples, data_dims)) + mean_grads = jnp.mean(mean_jacobians, axis=0) + + log_scale_jacobians = jacobians[1] + chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) + log_scale_grads = jnp.mean(log_scale_jacobians, axis=0) + + # We use a different random number generator for the gradient estimator + # without the control variate. + no_cv_jacobians = grad_estimator( + function, [mean, log_scale], + utils.multi_normal, ge_rng, num_samples=num_samples) + + no_cv_mean_jacobians = no_cv_jacobians[0] + chex.assert_shape(no_cv_mean_jacobians, (num_samples, data_dims)) + no_cv_mean_grads = jnp.mean(no_cv_mean_jacobians, axis=0) + + no_cv_log_scale_jacobians = no_cv_jacobians[1] + chex.assert_shape(no_cv_log_scale_jacobians, (num_samples, data_dims)) + no_cv_log_scale_grads = jnp.mean(no_cv_log_scale_jacobians, axis=0) + + _assert_equal(mean_grads, no_cv_mean_grads, rtol=1e-1, atol=5e-2) + _assert_equal(log_scale_grads, no_cv_log_scale_grads, rtol=1e-1, atol=5e-2) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/equivalence_test.py b/optax_add_eve/_src/equivalence_test.py new file mode 100644 index 00000000..9130e0c7 --- /dev/null +++ b/optax_add_eve/_src/equivalence_test.py @@ -0,0 +1,176 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests of equivalence between optax and other optimiser libraries.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +from flax import optim +from jax.example_libraries import optimizers +import jax.numpy as jnp + +from optax_add_eve._src import alias +from optax_add_eve._src import update + + +STEPS = 50 +LR = 1e-2 +LR_SCHED = lambda _: LR # Trivial constant "schedule". + + +class OptimizersEquivalenceTest(chex.TestCase): + + def setUp(self): + super().setUp() + self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4., 5.])) + self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3., 1.])) + + @chex.all_variants + @parameterized.named_parameters( + ('sgd', alias.sgd(LR, 0.0), + optimizers.sgd(LR), 1e-5), + ('adam', alias.adam(LR, 0.9, 0.999, 1e-8), + optimizers.adam(LR, 0.9, 0.999), 1e-4), + ('rmsprop', alias.rmsprop(LR, decay=.9, eps=0.1), + optimizers.rmsprop(LR, .9, 0.1), 1e-5), + ('rmsprop_momentum', alias.rmsprop( + LR, decay=.9, eps=0.1, momentum=0.9), + optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5), + ('adagrad', alias.adagrad(LR, 0., 0.,), + optimizers.adagrad(LR, 0.), 1e-5), + ('sgd', alias.sgd(LR_SCHED, 0.0), + optimizers.sgd(LR), 1e-5), + ('adam', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8), + optimizers.adam(LR, 0.9, 0.999), 1e-4), + ('rmsprop', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1), + optimizers.rmsprop(LR, .9, 0.1), 1e-5), + ('rmsprop_momentum', alias.rmsprop( + LR_SCHED, decay=.9, eps=0.1, momentum=0.9), + optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5), + ('adagrad', alias.adagrad(LR_SCHED, 0., 0.,), + optimizers.adagrad(LR, 0.), 1e-5), + ('sm3', alias.sm3(LR), optimizers.sm3(LR), 1e-2), + ) + def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer, rtol): + + # example_libraries/optimizers.py + jax_params = self.init_params + opt_init, opt_update, get_params = jax_optimizer + state = opt_init(jax_params) + for i in range(STEPS): + state = opt_update(i, self.per_step_updates, state) + jax_params = get_params(state) + + # optax + optax_params = self.init_params + state = optax_optimizer.init(optax_params) + + @self.variant + def step(updates, state): + return optax_optimizer.update(updates, state) + + for _ in range(STEPS): + updates, state = step(self.per_step_updates, state) + optax_params = update.apply_updates(optax_params, updates) + + # Check equivalence. + chex.assert_trees_all_close(jax_params, optax_params, rtol=rtol) + + +class FlaxOptimizersEquivalenceTest(chex.TestCase): + + def setUp(self): + super().setUp() + self.init_params = ( + jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.])) + self.per_step_updates = ( + jnp.array([0., 0.3, 500., 5.]), jnp.array([300., 3.])) + + @parameterized.named_parameters( + ('sgd', + alias.sgd(LR), + optim.GradientDescent(LR)), + ('momentum', + alias.sgd(LR, momentum=0.9), + optim.Momentum(LR, beta=0.9)), # Different names. 
+ ('nesterov_momentum', + alias.sgd(LR, momentum=0.9, nesterov=True), + optim.Momentum(LR, beta=0.9, nesterov=True)), + ('rmsprop', + alias.rmsprop(LR), + optim.RMSProp(LR)), + ('centered_rmsprop', + alias.rmsprop(LR, centered=True), + optim.RMSProp(LR, centered=True)), + ('adam', + alias.adam(LR), + optim.Adam(LR)), + ('adam_w', + alias.adamw(LR, weight_decay=1e-4), + optim.Adam(LR, weight_decay=1e-4)), # Different name. + ('adagrad', + alias.adagrad(LR, initial_accumulator_value=0.), # Different default! + optim.Adagrad(LR)), + ('lamb', + alias.lamb(LR), + optim.LAMB(LR)), + ('lars', + alias.lars( + LR, weight_decay=.5, trust_coefficient=0.003, + momentum=0.9, eps=1e-3), + optim.LARS( + LR, weight_decay=.5, trust_coefficient=0.003, + beta=0.9, eps=1e-3)), + ('adafactor', + alias.adafactor( + learning_rate=LR / 10., + factored=True, + multiply_by_parameter_scale=True, + clipping_threshold=1.0, + decay_rate=0.8, + min_dim_size_to_factor=2), + optim.Adafactor( + learning_rate=LR / 10., + factored=True, + multiply_by_parameter_scale=True, + clipping_threshold=1.0, + decay_rate=0.8, + min_dim_size_to_factor=2)), + ) + def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer): + + # flax/optim + flax_params = self.init_params + flax_optimizer = flax_optimizer.create(flax_params) + for _ in range(STEPS): + flax_optimizer = flax_optimizer.apply_gradient( + self.per_step_updates) + flax_params = flax_optimizer.target + + # optax + optax_params = self.init_params + state = optax_optimizer.init(optax_params) + for _ in range(STEPS): + updates, state = optax_optimizer.update( + self.per_step_updates, state, optax_params) + optax_params = update.apply_updates(optax_params, updates) + + # Check equivalence. + chex.assert_trees_all_close(flax_params, optax_params, rtol=2e-4) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/experimental/complex_valued.py b/optax_add_eve/_src/experimental/complex_valued.py new file mode 100644 index 00000000..5c1c7b54 --- /dev/null +++ b/optax_add_eve/_src/experimental/complex_valued.py @@ -0,0 +1,121 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Complex-valued optimization. + +When using `split_real_and_imaginary` to wrap an optimizer, we split the complex +parameters and updates into pairs of real ones before sending them to the +`update` of the wrapped optimizer, and merge the pairs of transformed real +updates into complex ones afterward. In this way, optimizers on complex +parameters behave the same way as if they were running on two real parameters. + +Note that the convention of conjugate for complex gradients in JAX is different +from that in PyTorch and other frameworks, and we need to manually conjugate the +gradients between `jax.grad` and `optimizer.update`. 
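+
+For example, a minimal sketch mirroring the accompanying tests (the inner
+`scale_by_adam` transformation here is just one possible choice):
+
+  import jax
+  import jax.numpy as jnp
+
+  from optax_add_eve._src import transform
+  from optax_add_eve._src.experimental import complex_valued
+
+  params = jnp.array([1. + 2.j, 3. - 4.j])
+  loss = lambda z: (z.conj() * z).real.sum()
+
+  optimizer = complex_valued.split_real_and_imaginary(transform.scale_by_adam())
+  opt_state = optimizer.init(params)
+
+  grads = jax.grad(loss)(params)
+  grads = jax.tree_util.tree_map(lambda g: g.conj(), grads)  # JAX convention
+  updates, opt_state = optimizer.update(grads, opt_state, params)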
+ +See details at https://github.com/deepmind/optax/issues/196 +""" + +from typing import NamedTuple, Union + +import chex +import jax +import jax.numpy as jnp + +from optax_add_eve._src import base + + +class SplitRealAndImaginaryArrays(NamedTuple): + """A pair of real arrays split from a complex array.""" + real: chex.Array + imaginary: chex.Array + + +def _complex_to_real_pair( + x: chex.Array +) -> Union[chex.Array, SplitRealAndImaginaryArrays]: + """Splits a complex array into a `SplitRealAndImaginaryArrays`. + + Args: + x: The input array, can be complex or real. + + Returns: + `SplitRealAndImaginaryArrays` if the input is a complex array. If the + input is a real array, it is passed through unmodified. + """ + if jnp.iscomplexobj(x): + return SplitRealAndImaginaryArrays(x.real, x.imag) + else: + return x + + +def _real_pair_to_complex( + x: Union[chex.Array, SplitRealAndImaginaryArrays] +) -> chex.Array: + """Merges a `SplitRealAndImaginaryArrays` into a complex array. + + Args: + x: The input `SplitRealAndImaginaryArrays` or array. + + Returns: + A complex array obtained from the real and imaginary parts of the + `SplitRealAndImaginaryArrays`. If the input is not a + `SplitRealAndImaginaryArrays`, it is passed through unmodified. + """ + if isinstance(x, SplitRealAndImaginaryArrays): + return x.real + x.imaginary * 1j + else: + return x + + +class SplitRealAndImaginaryState(NamedTuple): + """Maintains the inner transformation state for `split_real_and_imaginary`.""" + inner_state: base.OptState + + +def split_real_and_imaginary( + inner: base.GradientTransformation +) -> base.GradientTransformation: + """Splits the real and imaginary components of complex updates into two. + + The inner transformation processes real parameters and updates, and the + pairs of transformed real updates are merged into complex updates. + + Parameters and updates that are real before splitting are passed through + unmodified. + + Args: + inner: The inner transformation. + + Returns: + An `optax.GradientTransformation`. + """ + + def init_fn(params): + params = jax.tree_util.tree_map(_complex_to_real_pair, params) + inner_state = inner.init(params) + return SplitRealAndImaginaryState(inner_state) + + def update_fn(updates, state, params=None): + inner_state = state.inner_state + updates = jax.tree_util.tree_map(_complex_to_real_pair, updates) + params = jax.tree_util.tree_map(_complex_to_real_pair, params) + updates, inner_state = inner.update(updates, inner_state, params) + updates = jax.tree_util.tree_map( + _real_pair_to_complex, + updates, + is_leaf=lambda x: isinstance(x, SplitRealAndImaginaryArrays)) + return updates, SplitRealAndImaginaryState(inner_state) + + return base.GradientTransformation(init_fn, update_fn) diff --git a/optax_add_eve/_src/experimental/complex_valued_test.py b/optax_add_eve/_src/experimental/complex_valued_test.py new file mode 100644 index 00000000..57ad98e1 --- /dev/null +++ b/optax_add_eve/_src/experimental/complex_valued_test.py @@ -0,0 +1,79 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `complex_valued.py`.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +import jax +import jax.numpy as jnp +import numpy as np + +from optax_add_eve._src import transform +from optax_add_eve._src import update +from optax_add_eve._src.experimental import complex_valued + + +def _loss_fun_complex_to_real(z): + return (z.conj() * z).real.sum() + + +def _loss_fun_real_to_real(params): + x, y = params + return _loss_fun_complex_to_real(x + y * 1j) + + +class ComplexValuedTest(parameterized.TestCase): + + @chex.all_variants + @parameterized.named_parameters([ + ('adam', transform.scale_by_adam), + ('param_block_norm', transform.scale_by_param_block_norm), + ]) + def test_split_real_and_imaginary(self, scaler_constr): + + def do_update(loss_fun, optimizer, params, opt_state): + loss, grads = jax.value_and_grad(loss_fun)(params) + # Complex gradients need to be conjugated before being added to parameters + grads = jax.tree_util.tree_map(lambda x: x.conj(), grads) + updates, opt_state = self.variant(optimizer.update)( + grads, opt_state, params) + params = update.apply_updates(params, updates) + return loss, grads, params, opt_state + + x = jnp.array([[0.1, 0.2, 0.3], [-0.1, -0.2, -0.3]]) + y = jnp.array([[0.5, -0.5, 0], [0.1, 0.3, -0.2]]) + z = x + y * 1j + + optimizer = scaler_constr() + optimizer_complex = complex_valued.split_real_and_imaginary(optimizer) + opt_state = self.variant(optimizer.init)((x, y)) + opt_state_complex = self.variant(optimizer_complex.init)(z) + + # Check that the loss, the gradients, and the parameters are the same for + # real-to-real and complex-to-real loss functions in each step + for _ in range(3): + loss, (gx, gy), (x, y), opt_state = do_update( + _loss_fun_real_to_real, optimizer, (x, y), opt_state) + loss_complex, gz, z, opt_state_complex = do_update( + _loss_fun_complex_to_real, optimizer_complex, z, opt_state_complex) + np.testing.assert_allclose(loss, loss_complex) + np.testing.assert_allclose(gx + gy * 1j, gz) + np.testing.assert_allclose(x + y * 1j, z) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/experimental/extra_args.py b/optax_add_eve/_src/experimental/extra_args.py new file mode 100644 index 00000000..7264fbc0 --- /dev/null +++ b/optax_add_eve/_src/experimental/extra_args.py @@ -0,0 +1,167 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Support for extra kwargs in a gradient transformation's `init` and `update`. + +Some users have the need to condition the behavior of a gradient +transformations on dynamical quantities such as the loss. 
With this experimental +feature we support passing additional kwargs to `init` and `update`. + +We introduce `GradientTransformationWithExtraArgs` as an experimental feature. +You can use the new `named_chain` to combine both old-style and new-style +transformations. We will then monitor users to understand how they use it and +gather feedback from optax users before merging this into the stable API. +""" + +from typing import Any, Mapping, Optional, Tuple, Union, NamedTuple + +from optax_add_eve._src import base +import typing_extensions + + +class InitFnWithExtraArgs(typing_extensions.Protocol): + """Like `TransformInitFn` but with optional `extra_args`.""" + + def __call__( + self, + params: base.Params, + *, + extra_args: Optional[Mapping[str, Any]] = None, + ) -> base.OptState: + """The `init` function.""" + + +class UpdateFnWithExtraArgs(typing_extensions.Protocol): + """Like `TransformUpdateFn` but with optional `extra_args`.""" + + def __call__( + self, + updates: base.Updates, + state: base.OptState, + params: Optional[base.Params] = None, + *, + extra_args: Optional[Mapping[str, Any]] = None, + ) -> Tuple[base.Updates, base.OptState]: + """The `update` function.""" + + +class GradientTransformationWithExtraArgs(NamedTuple): + """A pair of pure functions implementing a gradient transformation. + + GradientTransformationWithExtraArgs is just like GradientTransformation but + both the `init` and `update` functions accept an additional `extra_args` dict. + This can be used to provide additional dynamic information that is not + computed by the GradientTransformation itself (e.g. loss) but that may be + needed by specific algorithms. + """ + init: InitFnWithExtraArgs + update: UpdateFnWithExtraArgs + + +AnyGradientTransformation = Union[ + base.GradientTransformation, GradientTransformationWithExtraArgs] +NamedTransform = Tuple[str, AnyGradientTransformation] + + +def named_chain( + *transforms: NamedTransform) -> GradientTransformationWithExtraArgs: + """Chains optax gradient transformations. + + The `transforms` are `(name, transformation)` pairs, constituted of a string + `name` and an associated gradient transformation `transformation`. The + gradient transformation must be an instance of either + `GradientTransformation` or `GradientTransformationWithExtraArgs`. + + Each `name` is used as key for the state of the corresponding transformation + within the `named_chain` state. Thus the state of the gradient transformation + with a given `name` can be retrieved as `opt_state[name]`. + + The `named_chain` accepts an `extra_args` meta-dictionary whose fields are + the transformations' names and its values are the corresponding extra_args: + + Example: + tx = named_chain(('one', tx1), ('two', tx2)) + + extra_args={ + 'one': {'loss': 0.1}, + 'two': {'loss': 0.3, 'temperature': 0.01}} + tx.init(params, extra_args=extra_args} + tx.update(grads, state, params, extra_args=extra_args) + + # tx1 receives {'loss': 0.1} as extra_args + # tx2 receives {'loss': 0.3, 'temperature': 0.01} as extra_args + + If one of the transformations does not need extra_args the corresponding + name can just be skipped in the `named_chain` extra_args: + + Example: + tx = named_chain(('one', tx1), ('two', tx2)) + + extra_args={'one': {'loss': 0.1}} + tx.init(params, extra_args=extra_args} + tx.update(grads, state, params, extra_args=extra_args) + + # tx1 receives {'loss': 0.1} as extra_args. + # tx2 is called without passing the extra_args. 
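+
+  A self-contained sketch of the same pattern (the `scale_by_loss`
+  transformation below is hypothetical, mirroring the one in the unit tests):
+
+    import jax
+    import jax.numpy as jnp
+    from optax_add_eve._src import base
+    from optax_add_eve._src import transform
+    from optax_add_eve._src.experimental import extra_args as extra
+
+    def scale_by_loss():
+      def init_fn(params, *, extra_args):
+        del params, extra_args
+        return base.EmptyState()
+      def update_fn(updates, state, params, *, extra_args):
+        del params
+        return jax.tree_util.tree_map(
+            lambda u: u / extra_args['loss'], updates), state
+      return extra.GradientTransformationWithExtraArgs(init_fn, update_fn)
+
+    tx = extra.named_chain(
+        ('scale', transform.scale(0.1)), ('by_loss', scale_by_loss()))
+    params = {'w': jnp.ones((4,))}
+    grads = params  # stand-in gradients
+    meta = {'by_loss': {'loss': 2.0}}
+    state = tx.init(params, extra_args=meta)
+    updates, state = tx.update(grads, state, params, extra_args=meta)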
+ + Args: + *transforms: an arbitrary number of `(name, tx)` pairs, constituted of a + string `name` and an associated gradient transformation `tx`. The latter + is a `GradientTransformation` or `GradientTransformationWithExtraArgs`. + + Returns: + A single (init_fn, update_fn) tuple. + """ + + names = [name for name, _ in transforms] + if len(names) != len(set(names)): + raise ValueError( + f'Named transformations must have unique names, but got {names}') + + def init_fn(params, *, extra_args=None): + states = {} + for (name, tx) in transforms: + _assert_is_gradient_transformation(tx) + if (extra_args is not None and + isinstance(tx, GradientTransformationWithExtraArgs)): + states[name] = tx.init( + params, extra_args=extra_args.get(name)) + else: + states[name] = tx.init(params) + return states + + def update_fn(updates, state, params=None, *, extra_args=None): + new_state = {} + for (name, tx) in transforms: + _assert_is_gradient_transformation(tx) + if (extra_args is not None and + isinstance(tx, GradientTransformationWithExtraArgs)): + updates, new_state[name] = tx.update( + updates, state[name], params, extra_args=extra_args.get(name)) + else: + updates, new_state[name] = tx.update(updates, state[name], params) + return updates, new_state + + return GradientTransformationWithExtraArgs(init_fn, update_fn) + + +def _assert_is_gradient_transformation(tx): + valid_types = ( + base.GradientTransformation, + GradientTransformationWithExtraArgs) + if not isinstance(tx, valid_types): + raise ValueError( + 'The transformation `tx` must be a valid gradient transformation, ' + 'that is an instance of either `GradientTransformation` or ' + 'an instance of `GradientTransformationWithExtraArgs`') diff --git a/optax_add_eve/_src/experimental/extra_args_test.py b/optax_add_eve/_src/experimental/extra_args_test.py new file mode 100644 index 00000000..b24fca48 --- /dev/null +++ b/optax_add_eve/_src/experimental/extra_args_test.py @@ -0,0 +1,65 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for extra_kwargs.""" + +from absl.testing import absltest +import chex +import jax +import jax.numpy as jnp + +from optax_add_eve._src import base +from optax_add_eve._src import transform +from optax_add_eve._src.experimental import extra_args as extra + + +def scale_by_loss(): + """Scale the gradient by the absolute value of the loss.""" + + def init_fn(params, *, extra_args): + del params, extra_args + return base.EmptyState() + + def update_fn(updates, state, params, *, extra_args): + del params + updates = jax.tree_util.tree_map( + lambda u: u / extra_args['loss'], updates) + return updates, state + + return extra.GradientTransformationWithExtraArgs(init_fn, update_fn) + + +class ExtraArgsTest(absltest.TestCase): + + def test_named_chain(self): + tx = extra.named_chain( + ('scale', transform.scale(0.1)), + ('scale_by_policy_loss', scale_by_loss()), + ('scale_by_value_loss', scale_by_loss()), + ) + + params = {'a': jnp.ones((4,))} + grads = params + extra_args = { + 'scale_by_policy_loss': {'loss': 0.01}, + 'scale_by_value_loss': {'loss': 10.0}} + + opt_state = tx.init(params, extra_args=extra_args) + updates, opt_state = tx.update( + grads, opt_state, params, extra_args=extra_args) + chex.assert_trees_all_close(updates, {'a': jnp.ones((4,))}) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/factorized.py b/optax_add_eve/_src/factorized.py new file mode 100644 index 00000000..b3bbec45 --- /dev/null +++ b/optax_add_eve/_src/factorized.py @@ -0,0 +1,199 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Factorized optimizers.""" + +import dataclasses +from typing import NamedTuple, Optional, Tuple, Callable + +import chex +import jax +import jax.numpy as jnp +import numpy as np + +from optax_add_eve._src import base +from optax_add_eve._src import numerics +from optax_add_eve._src import utils + +# pylint:disable=no-value-for-parameter + + +def _decay_rate_pow(i: int, exponent: float = 0.8) -> float: + """Second-order moment decay schedule.""" + t = jnp.array(i, jnp.float32) + 1.0 + return 1.0 - t**(-exponent) + + +def _factored_dims( + shape: base.Shape, + factored: bool, + min_dim_size_to_factor: int +) -> Optional[Tuple[int, int]]: + """Whether to use a factored second moment estimator. + + This function returns a tuple with the two largest axes to reduce over. + If no two dimensions have size >= min_dim_size_to_factor, return None. + + Args: + shape: an input shape + factored: whether to use factored second-moment estimator for 2d vars. + min_dim_size_to_factor: only factor accumulator if two array dimensions + have at least this size. 
+ + Returns: + None or a tuple of ints + """ + if not factored or len(shape) < 2: + return None + sorted_dims = np.argsort(shape) + if shape[sorted_dims[-2]] < min_dim_size_to_factor: + return None + return int(sorted_dims[-2]), int(sorted_dims[-1]) + + +@dataclasses.dataclass +class _UpdateResult: + """Opaque containter that is not traversed by jax.tree_util.tree_map.""" + update: chex.Array # the update to apply to params + v_row: chex.Array # used for factored params. + v_col: chex.Array # used for factored params. + v: chex.Array # used for params where factoring is skipped. + + +class FactoredState(NamedTuple): + """Overall state of the gradient transformation.""" + count: chex.Array # number of update steps. + v_row: chex.ArrayTree # Tree of factored params. + v_col: chex.ArrayTree # Tree of factored params. + v: chex.ArrayTree # Tree for params where factoring is skipped. + + +def scale_by_factored_rms( + factored: bool = True, + decay_rate: float = 0.8, + step_offset: int = 0, + min_dim_size_to_factor: int = 128, + epsilon: float = 1e-30, + decay_rate_fn: Callable[[int, float], chex.Array] = _decay_rate_pow): + """Scaling by a factored estimate of the gradient rms (as in Adafactor). + + This is a so-called "1+epsilon" scaling algorithms, that is extremely memory + efficient compared to RMSProp/Adam, and has had wide success when applied to + large-scale training of attention-based models. + + References: + [Shazeer et al, 2018](https://arxiv.org/abs/1804.04235) + + Args: + factored: boolean: whether to use factored second-moment estimates.. + decay_rate: float: controls second-moment exponential decay schedule. + step_offset: for finetuning, one may set this to the starting step-number + of the fine tuning phase. + min_dim_size_to_factor: only factor accumulator if two array dimensions + are at least this size. + epsilon: Regularization constant for squared gradient. + decay_rate_fn: A function that accepts the current step, the decay rate + parameter and controls the schedule for the second momentum. Defaults to + the original adafactor's power decay schedule. One potential shortcoming + of the orignal schedule is the fact that second momentum converges to 1, + which effectively freezes the second momentum. To prevent this the user + can opt for a custom schedule that sets an upper bound for the second + momentum, like in [Zhai et al., 2021](https://arxiv.org/abs/2106.04560). + + Returns: + the corresponding `GradientTransformation`. 
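+
+  A minimal usage sketch (shapes are illustrative; note that `update` requires
+  `params` because the factoring decision depends on parameter shapes):
+
+    import jax
+    import jax.numpy as jnp
+
+    tx = scale_by_factored_rms(min_dim_size_to_factor=2)
+    params = {'w': jnp.ones((4, 8)), 'b': jnp.zeros((8,))}
+    grads = jax.tree_util.tree_map(jnp.ones_like, params)
+
+    state = tx.init(params)
+    updates, state = tx.update(grads, state, params)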
+ """ + + def _to_state(count: chex.Array, result_tree): + """Maps from a tree of (factored) values to separate trees of values.""" + return FactoredState( + count=count, + v_row=jax.tree_util.tree_map(lambda o: o.v_row, result_tree), + v_col=jax.tree_util.tree_map(lambda o: o.v_col, result_tree), + v=jax.tree_util.tree_map(lambda o: o.v, result_tree)) + + def init_fn(params): + """Initialise the optimiser's state.""" + + def _init(param): + shape = param.shape + factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor) + if factored_dims is not None: + d1, d0 = factored_dims + vr_shape = np.delete(shape, d0) + vc_shape = np.delete(shape, d1) + return _UpdateResult( + update=jnp.zeros((1,)), + v_row=jnp.zeros(vr_shape), + v_col=jnp.zeros(vc_shape), + v=jnp.zeros((1,))) + else: + return _UpdateResult( + update=jnp.zeros((1,)), + v_row=jnp.zeros((1,)), + v_col=jnp.zeros((1,)), + v=jnp.zeros(param.shape)) + + return _to_state( + jnp.zeros([], jnp.int32), jax.tree_util.tree_map(_init, params)) + + def update_fn(grads, state, params): + """Apply gradient transformation.""" + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + + def _update(grad, v_row, v_col, v, param, step): + shape = param.shape + decay_rate_t = decay_rate_fn(step - step_offset, decay_rate) + + # Scaled by factorized second moment statistics. + new_v_row = jnp.zeros((1,)) + new_v_col = jnp.zeros((1,)) + new_v = jnp.zeros((1,)) + + factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor) + if factored_dims is not None: + d1, d0 = factored_dims + grad_sqr = numerics.abs_sq(grad) + epsilon + new_v_row = ( + decay_rate_t * v_row + + (1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d0)) + new_v_col = ( + decay_rate_t * v_col + + (1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d1)) + reduced_d1 = d1-1 if d1 > d0 else d1 + row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True) + row_factor = (new_v_row / row_col_mean) ** -0.5 + col_factor = (new_v_col) ** -0.5 + update = ( + grad * + jnp.expand_dims(row_factor, axis=d0) * + jnp.expand_dims(col_factor, axis=d1)) + else: + grad_sqr = numerics.abs_sq(grad) + epsilon + new_v = decay_rate_t * v + (1. - decay_rate_t) * grad_sqr + update = grad * (new_v)**-0.5 + + return _UpdateResult(update, new_v_row, new_v_col, new_v) + + # Transform grad and compute new per-parameter stats. + output = jax.tree_util.tree_map( + lambda *args: _update(*args, state.count), + grads, state.v_row, state.v_col, state.v, params) + + # Unpack updates / stats and return. + updates = jax.tree_util.tree_map(lambda o: o.update, output) + return updates, _to_state(utils.safe_int32_increment(state.count), output) + + return base.GradientTransformation(init_fn, update_fn) diff --git a/optax_add_eve/_src/factorized_test.py b/optax_add_eve/_src/factorized_test.py new file mode 100644 index 00000000..d0e2f90a --- /dev/null +++ b/optax_add_eve/_src/factorized_test.py @@ -0,0 +1,45 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `factorized.py`.""" + +from absl.testing import parameterized + +import chex +import jax.numpy as jnp + +from optax_add_eve._src import factorized + + +class FactorizedTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.])) + self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.])) + + @chex.all_variants + def test_scale_by_factored_rms(self): + params = self.init_params + + scaler = factorized.scale_by_factored_rms() + init_fn = self.variant(scaler.init) + transform_fn = self.variant(scaler.update) + + state = init_fn(params) + chex.assert_tree_all_finite(state) + + updates, state = transform_fn(self.per_step_updates, state, params) + chex.assert_tree_all_finite((params, updates, state)) + chex.assert_tree_all_equal_shapes(params, updates) diff --git a/optax_add_eve/_src/float64_test.py b/optax_add_eve/_src/float64_test.py new file mode 100644 index 00000000..9f22516d --- /dev/null +++ b/optax_add_eve/_src/float64_test.py @@ -0,0 +1,94 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests that types are preserved by the `update` calls when jax_enbable_x64.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +import jax +from jax.config import config +import jax.numpy as jnp + +from optax_add_eve._src import alias +from optax_add_eve._src import base +from optax_add_eve._src import clipping +from optax_add_eve._src import transform +from optax_add_eve._src import update + + +ALL_MODULES = [ + ('identity', base.identity, {}), + ('clip', clipping.clip, dict(max_delta=1.0)), + ('clip_by_global_norm', clipping.clip_by_global_norm, dict(max_norm=1.0)), + ('trace', transform.trace, dict(decay=0.5, nesterov=False)), + ('trace_with_nesterov', transform.trace, dict(decay=0.5, nesterov=True)), + ('scale_by_rss', transform.scale_by_rss, {}), + ('scale_by_rms', transform.scale_by_rms, {}), + ('scale_by_stddev', transform.scale_by_stddev, {}), + ('adam', transform.scale_by_adam, {}), + ('scale', transform.scale, dict(step_size=3.0)), + ('additive_weight_decay', transform.additive_weight_decay, + dict(weight_decay=0.1)), + ('scale_by_schedule', transform.scale_by_schedule, + dict(step_size_fn=lambda x: x * 0.1)), + ('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}), + ('add_noise', transform.add_noise, dict(eta=1.0, gamma=0.1, seed=42)), + ('apply_every_k', transform.apply_every, {}), + ('adagrad', alias.adagrad, dict(learning_rate=0.1)), + ('adam', alias.adam, dict(learning_rate=0.1)), + ('adamw', alias.adamw, dict(learning_rate=0.1)), + ('fromage', alias.fromage, dict(learning_rate=0.1)), + ('lamb', alias.lamb, dict(learning_rate=0.1)), + ('noisy_sgd', alias.noisy_sgd, dict(learning_rate=0.1)), + ('rmsprop', alias.rmsprop, dict(learning_rate=0.1)), + ('sgd', alias.sgd, dict(learning_rate=0.1)), + ('dpsgd', alias.dpsgd, + dict(learning_rate=0.1, l2_norm_clip=0.9, noise_multiplier=1.1, seed=42)), +] + + +class Float64Test(parameterized.TestCase): + + def _assert_dtype_equals(self, tree1, tree2): + tree1_types = jax.tree_util.tree_map(lambda t: t.dtype, tree1) + tree2_types = jax.tree_util.tree_map(lambda t: t.dtype, tree2) + self.assertEqual(tree1_types, tree2_types) + + @chex.all_variants + @parameterized.named_parameters(ALL_MODULES) + def test_mixed_dtype_input_outputs(self, transform_constr, transform_kwargs): + initial_params = ( + jnp.array([1., 2.], dtype=jnp.float32), + jnp.array([3., 4.], dtype=jnp.float64)) + updates = ( + jnp.array([10., 21.], dtype=jnp.float32), + jnp.array([33., 42.], dtype=jnp.float64)) + scaler = transform_constr(**transform_kwargs) + init_fn = self.variant(scaler.init) + update_fn = self.variant(scaler.update) + + initial_state = init_fn(initial_params) + updates, new_state = update_fn( + updates, initial_state, params=initial_params) + new_params = update.apply_updates(initial_params, updates) + + self._assert_dtype_equals(initial_state, new_state) + self._assert_dtype_equals(initial_params, new_params) + + +if __name__ == '__main__': + config.update('jax_enable_x64', True) + absltest.main() diff --git a/optax_add_eve/_src/linear_algebra.py b/optax_add_eve/_src/linear_algebra.py new file mode 100644 index 00000000..0caedd69 --- /dev/null +++ b/optax_add_eve/_src/linear_algebra.py @@ -0,0 +1,201 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear algebra utilities used in optimisation.""" + +import chex +import jax +from jax import lax +import jax.numpy as jnp +import numpy as np + +from optax_add_eve._src import base +from optax_add_eve._src import numerics + + +def global_norm(updates: base.Updates) -> base.Updates: + """Compute the global norm across a nested structure of tensors.""" + return jnp.sqrt(sum( + jnp.sum(numerics.abs_sq(x)) for x in jax.tree_util.tree_leaves(updates))) + + +def power_iteration(matrix: chex.Array, + num_iters: int = 100, + error_tolerance: float = 1e-6, + precision: lax.Precision = lax.Precision.HIGHEST): + r"""Power iteration algorithm. + + The power iteration algorithm takes a symmetric PSD matrix `A`, and produces + a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue + of `A`, and a vector v, which is the corresponding eigenvector of `A`. + + References: + [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration) + + Args: + matrix: the symmetric PSD matrix. + num_iters: Number of iterations. + error_tolerance: Iterative exit condition. + precision: precision XLA related flag, the available options are: + a) lax.Precision.DEFAULT (better step time, but not precise); + b) lax.Precision.HIGH (increased precision, slower); + c) lax.Precision.HIGHEST (best possible precision, slowest). + + Returns: + eigen vector, eigen value + """ + matrix_size = matrix.shape[-1] + def _iter_condition(state): + i, unused_v, unused_s, unused_s_v, run_step = state + return jnp.logical_and(i < num_iters, run_step) + + def _iter_body(state): + """One step of power iteration.""" + i, new_v, s, s_v, unused_run_step = state + new_v = new_v / jnp.linalg.norm(new_v) + + s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision) + s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision) + return (i + 1, s_v, s_new, s_v, + jnp.greater(jnp.abs(s_new - s), error_tolerance)) + + # Figure out how to use step as seed for random. + v_0 = np.random.uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype) + + init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True]) + _, v_out, s_out, _, _ = lax.while_loop( + _iter_condition, _iter_body, init_state) + v_out = v_out / jnp.linalg.norm(v_out) + return v_out, s_out + + +def matrix_inverse_pth_root(matrix: chex.Array, + p: int, + num_iters: int = 100, + ridge_epsilon: float = 1e-6, + error_tolerance: float = 1e-6, + precision: lax.Precision = lax.Precision.HIGHEST): + """Computes `matrix^(-1/p)`, where `p` is a positive integer. + + This function uses the Coupled newton iterations algorithm for + the computation of a matrix's inverse pth root. + + + References: + [Functions of Matrices, Theory and Computation, + Nicholas J Higham, Pg 184, Eq 7.18]( + https://epubs.siam.org/doi/book/10.1137/1.9780898717778) + + Args: + matrix: the symmetric PSD matrix whose power it to be computed + p: exponent, for p a positive integer. + num_iters: Maximum number of iterations. + ridge_epsilon: Ridge epsilon added to make the matrix positive definite. 
+ error_tolerance: Error indicator, useful for early termination. + precision: precision XLA related flag, the available options are: + a) lax.Precision.DEFAULT (better step time, but not precise); + b) lax.Precision.HIGH (increased precision, slower); + c) lax.Precision.HIGHEST (best possible precision, slowest). + + Returns: + matrix^(-1/p) + """ + + # We use float32 for the matrix inverse pth root. + # Switch to f64 if you have hardware that supports it. + matrix_size = matrix.shape[0] + alpha = jnp.asarray(-1.0 / p, jnp.float32) + identity = jnp.eye(matrix_size, dtype=jnp.float32) + _, max_ev = power_iteration( + matrix=matrix, num_iters=100, + error_tolerance=1e-6, precision=precision) + ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16) + + def _unrolled_mat_pow_1(mat_m): + """Computes mat_m^1.""" + return mat_m + + def _unrolled_mat_pow_2(mat_m): + """Computes mat_m^2.""" + return jnp.matmul(mat_m, mat_m, precision=precision) + + def _unrolled_mat_pow_4(mat_m): + """Computes mat_m^4.""" + mat_pow_2 = _unrolled_mat_pow_2(mat_m) + return jnp.matmul( + mat_pow_2, mat_pow_2, precision=precision) + + def _unrolled_mat_pow_8(mat_m): + """Computes mat_m^4.""" + mat_pow_4 = _unrolled_mat_pow_4(mat_m) + return jnp.matmul( + mat_pow_4, mat_pow_4, precision=precision) + + def mat_power(mat_m, p): + """Computes mat_m^p, for p == 1, 2, 4 or 8. + + Args: + mat_m: a square matrix + p: a positive integer + + Returns: + mat_m^p + """ + # We unrolled the loop for performance reasons. + exponent = jnp.round(jnp.log2(p)) + return lax.switch( + jnp.asarray(exponent, jnp.int32), [ + _unrolled_mat_pow_1, + _unrolled_mat_pow_2, + _unrolled_mat_pow_4, + _unrolled_mat_pow_8, + ], (mat_m)) + + def _iter_condition(state): + (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, + run_step) = state + error_above_threshold = jnp.logical_and( + error > error_tolerance, run_step) + return jnp.logical_and(i < num_iters, error_above_threshold) + + def _iter_body(state): + (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state + mat_m_i = (1 - alpha) * identity + alpha * mat_m + new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision) + new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision) + new_error = jnp.max(jnp.abs(new_mat_m - identity)) + # sometimes error increases after an iteration before decreasing and + # converging. 1.2 factor is used to bound the maximal allowed increase. 
+ return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, + new_error < error * 1.2) + + if matrix_size == 1: + resultant_mat_h = (matrix + ridge_epsilon)**alpha + error = 0 + else: + damped_matrix = matrix + ridge_epsilon * identity + + z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix)) + new_mat_m_0 = damped_matrix * z + new_error = jnp.max(jnp.abs(new_mat_m_0 - identity)) + new_mat_h_0 = identity * jnp.power(z, 1.0 / p) + init_state = tuple( + [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True]) + _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop( + _iter_condition, _iter_body, init_state) + error = jnp.max(jnp.abs(mat_m - identity)) + is_converged = jnp.asarray(convergence, old_mat_h.dtype) + resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h + resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype) + return resultant_mat_h, error diff --git a/optax_add_eve/_src/linear_algebra_test.py b/optax_add_eve/_src/linear_algebra_test.py new file mode 100644 index 00000000..5ad8172b --- /dev/null +++ b/optax_add_eve/_src/linear_algebra_test.py @@ -0,0 +1,62 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optax._src.linear_algebra.""" + +from absl.testing import absltest + +import jax.numpy as jnp +import numpy as np +from optax_add_eve._src import linear_algebra +import scipy.stats + + +class LinearAlgebraTest(absltest.TestCase): + + def test_global_norm(self): + flat_updates = jnp.array([2., 4., 3., 5.], dtype=jnp.float32) + nested_updates = dict( + a=jnp.array([2., 4.], dtype=jnp.float32), + b=jnp.array([3., 5.], dtype=jnp.float32)) + np.testing.assert_array_equal( + jnp.sqrt(jnp.sum(flat_updates**2)), + linear_algebra.global_norm(nested_updates)) + + def test_matrix_inverse_pth_root(self): + """Test for matrix inverse pth root.""" + + def _gen_symmetrix_matrix(dim, condition_number): + u = scipy.stats.ortho_group.rvs(dim=dim).astype(np.float64) + v = u.T + diag = np.diag([condition_number ** (-i/(dim-1)) for i in range(dim)]) + return u @ diag @ v + + # Fails after it reaches a particular condition number. + for e in range(2, 12): + condition_number = 10 ** e + ms = _gen_symmetrix_matrix(16, condition_number) + self.assertLess( + np.abs(np.linalg.cond(ms) - condition_number), + condition_number * 0.01) + error = linear_algebra.matrix_inverse_pth_root( + ms.astype(np.float32), 4, ridge_epsilon=1e-12)[1] + if e < 7: + self.assertLess(error, 0.1) + else: + # No guarantee of success after e >= 7 + pass + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/lookahead.py b/optax_add_eve/_src/lookahead.py new file mode 100644 index 00000000..97b3a6e9 --- /dev/null +++ b/optax_add_eve/_src/lookahead.py @@ -0,0 +1,192 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 
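[Editor's note: a quick illustrative check of the power-iteration helper above; a sketch, not part of the patch. For a diagonal PSD matrix the dominant eigenpair is known in closed form.]

import jax.numpy as jnp

from optax_add_eve._src import linear_algebra

a = jnp.diag(jnp.array([3.0, 1.0, 0.5], dtype=jnp.float32))  # symmetric PSD
v, s = linear_algebra.power_iteration(a, num_iters=200)
# s approximates the largest eigenvalue (3.0) and v the matching unit-norm
# eigenvector; note the start vector comes from numpy's global RNG.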
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A lookahead optimization wrapper.""" + +from typing import NamedTuple, Tuple + +from absl import logging +import jax +import jax.numpy as jnp + +from optax_add_eve._src import base + +# pylint:disable=no-value-for-parameter + + +class LookaheadState(NamedTuple): + """State of the `GradientTransformation` returned by `lookahead`. + + Attributes: + fast_state: Optimizer state of the fast optimizer. + steps_since_sync: Number of fast optimizer steps taken since slow and fast + parameters were synchronized. + """ + fast_state: base.OptState + steps_since_sync: jnp.ndarray + + +class LookaheadParams(NamedTuple): + """Holds a pair of slow and fast parameters for the lookahead optimizer. + + Gradients should always be calculated with the fast parameters. The slow + parameters should be used for testing and inference as they generalize better. + See the reference for a detailed discussion. + + References: + [Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf) + + Attributes: + fast: Fast parameters. + slow: Slow parameters. + """ + fast: base.Params + slow: base.Params + + @classmethod + def init_synced(cls, params: base.Params) -> 'LookaheadParams': + """Initialize a pair of synchronized lookahead parameters.""" + return cls(slow=params, fast=params) + + +def lookahead( + fast_optimizer: base.GradientTransformation, + sync_period: int, + slow_step_size: float, + reset_state: bool = False +) -> base.GradientTransformation: + """Lookahead optimizer. + + Performs steps with a fast optimizer and periodically updates a set of slow + parameters. Optionally resets the fast optimizer state after synchronization + by calling the init function of the fast optimizer. + + Updates returned by the lookahead optimizer should not be modified before they + are applied, otherwise fast and slow parameters are not synchronized + correctly. + + References: + [Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf) + + Args: + fast_optimizer: The optimizer to use in the inner loop of lookahead. + sync_period: Number of fast optimizer steps to take before synchronizing + parameters. Must be >= 1. + slow_step_size: Step size of the slow parameter updates. + reset_state: Whether to reset the optimizer state of the fast opimizer after + each synchronization. + + Returns: + A `GradientTransformation` with init and update functions. The updates + passed to the update function should be calculated using the fast lookahead + parameters only. + """ + if sync_period < 1: + raise ValueError('Synchronization period must be >= 1.') + + def init_fn(params: base.Params) -> LookaheadState: + try: + fast_params = params.fast + except AttributeError: + # Allowing init_fn to be called with fast parameters reduces the + # modifications necessary to adapt code to use lookahead in some cases. + logging.warning( + '`params` has no attribute `fast`. 
Continuing by assuming that ' + 'only fast parameters were passed to lookahead init.') + fast_params = params + + return LookaheadState( + fast_state=fast_optimizer.init(fast_params), + steps_since_sync=jnp.zeros(shape=(), dtype=jnp.int32)) + + def update_fn( + updates: base.Updates, state: LookaheadState, + params: LookaheadParams) -> Tuple[LookaheadParams, LookaheadState]: + updates, fast_state = fast_optimizer.update(updates, state.fast_state, + params.fast) + + sync_next = (state.steps_since_sync == sync_period - 1) + updates = _lookahead_update(updates, sync_next, params, slow_step_size) + if reset_state: + # Jittable way of resetting the fast optimizer state if parameters will be + # synchronized after this update step. + initial_state = fast_optimizer.init(params.fast) + fast_state = jax.tree_util.tree_map( + lambda current, init: (1 - sync_next) * current + sync_next * init, + fast_state, initial_state) + + steps_since_sync = (state.steps_since_sync + 1) % sync_period + return updates, LookaheadState(fast_state, steps_since_sync) + + return base.GradientTransformation(init_fn, update_fn) + + +def _lookahead_update( + updates: base.Updates, sync_next: bool, params: LookaheadParams, + slow_step_size: float) -> LookaheadParams: + """Returns the updates corresponding to one lookahead step. + + References: + [Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf) + + Args: + updates: Updates returned by the fast optimizer. + sync_next: Wether fast and slow parameters should be synchronized after the + fast optimizer step. + params: Current fast and slow parameters as `LookaheadParams` object. + slow_step_size: Step size of the slow optimizer. + + Returns: + The updates for the lookahead parameters. + """ + # In the paper, lookahead is presented as two nested loops. To write lookahead + # as optax wrapper, these loops have to be broken into successive updates. + # This leads to two types of update steps: + # + # Non-synchronization steps (sync_next == False): + # The updates returned by the fast optimizer are used for the fast parameters + # without change and the slow parameter updates are zero (i.e. fast_updates = + # updates, slow_updates = 0). + # + # Synchronisation step (sync_next == True): + # This consists of two substeps: a last fast optimizer step and the + # synchronization. + # Substep 1 (last fast optimizer step): + # last_fast_params = fast_params + updates + # Substep 2 (synchronization): + # new_slow_params = slow_params + slow_step_size * ( + # last_fast_params - slow_params) + # new_fast_params = new_slow_params + # + # Merging into a single update step we get the update rules: + # slow_updates = slow_step_size * (fast_params + updates - slow_params) + # fast_updates = new_slow_params - fast_params = updates - (1 - + # slow_step_size) * (fast_params + updates - slow_params) + # + # To make the equations jittable, the two types of steps are merged. 
Defining + # last_difference = fast_params + updates - slow_params, this yields the + # following equtions which are implemented below: + # slow_updates = slow_step_size * sync_next * last_difference + # fast_updates = updates - ( + # 1 - slow_step_size) * sync_next * last_difference + last_difference = jax.tree_util.tree_map( + lambda f, u, s: f + u - s, params.fast, updates, params.slow) + slow_updates = jax.tree_util.tree_map( + lambda diff: slow_step_size * sync_next * diff, last_difference) + fast_updates = jax.tree_util.tree_map( + lambda up, diff: up - sync_next * (1 - slow_step_size) * diff, updates, + last_difference) + + return LookaheadParams(fast=fast_updates, slow=slow_updates) + diff --git a/optax_add_eve/_src/lookahead_test.py b/optax_add_eve/_src/lookahead_test.py new file mode 100644 index 00000000..99964a1d --- /dev/null +++ b/optax_add_eve/_src/lookahead_test.py @@ -0,0 +1,140 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `lookahead.py`.""" + +from typing import NamedTuple + +from absl.testing import absltest +from absl.testing import parameterized +import chex +import jax +import jax.numpy as jnp +import numpy as np +from optax_add_eve._src import alias +from optax_add_eve._src import base +from optax_add_eve._src import lookahead +from optax_add_eve._src import update + + +def _build_sgd(): + return alias.sgd(1.) + + +class TestOptimizerState(NamedTuple): + """Fast optimizer state for the lookahead tests.""" + aggregate_grads: base.Params + # Include a variable with non-zero initial value to check that it is reset + # correctly by the lookahead optimizer. + is_reset: bool = True + + +def _test_optimizer(step_size: float) -> base.GradientTransformation: + """Fast optimizer for the lookahead tests.""" + + # Use SGD for simplicity but add non-trivial optimizer state so that the + # resetting behaviour of lookahead can be tested. + def init_fn(params): + aggregate_grads = jax.tree_util.tree_map(jnp.zeros_like, params) + return TestOptimizerState(aggregate_grads, is_reset=True) + + def update_fn(updates, state, params): + # The test optimizer does not use the parameters, but we check that they + # have been passed correctly. 
+ chex.assert_trees_all_equal_shapes(updates, params) + aggregate_grads = update.apply_updates(state.aggregate_grads, updates) + updates = jax.tree_util.tree_map(lambda u: step_size * u, updates) + return updates, TestOptimizerState(aggregate_grads, is_reset=False) + + return base.GradientTransformation(init_fn, update_fn) + + +class LookaheadTest(chex.TestCase): + """Tests for the lookahead optimizer.""" + + def setUp(self): + super().setUp() + self.grads = {'x': np.array(2.), 'y': np.array(-2.)} + self.initial_params = {'x': np.array(3.), 'y': np.array(-3.)} + self.synced_initial_params = lookahead.LookaheadParams.init_synced( + self.initial_params) + + def loop(self, optimizer, num_steps, params): + """Performs a given number of optimizer steps.""" + init_fn, update_fn = optimizer + # Use the chex variant to check various function versions (jit, pmap, etc). + step = self.variant(update_fn) + opt_state = self.variant(init_fn)(params) + for _ in range(num_steps): + updates, opt_state = step(self.grads, opt_state, params) + params = update.apply_updates(params, updates) + + return params, opt_state + + @chex.all_variants + def test_lookahead(self): + """Tests the lookahead optimizer in an analytically tractable setting.""" + sync_period = 3 + optimizer = lookahead.lookahead( + _test_optimizer(-0.5), sync_period=sync_period, slow_step_size=1 / 3) + + final_params, _ = self.loop(optimizer, 2 * sync_period, + self.synced_initial_params) + # x steps must be: 3 -> 2 -> 1 -> 2 (sync) -> 1 -> 0 -> 1 (sync). + # Similarly for y (with sign flipped). + correct_final_params = {'x': 1, 'y': -1} + chex.assert_trees_all_close(final_params.slow, correct_final_params) + + @chex.all_variants + @parameterized.parameters([False], [True]) + def test_lookahead_state_reset(self, reset_state): + """Checks that lookahead resets the fast optimizer state correctly.""" + num_steps = sync_period = 3 + fast_optimizer = _test_optimizer(-0.5) + optimizer = lookahead.lookahead( + fast_optimizer, + sync_period=sync_period, + slow_step_size=0.5, + reset_state=reset_state) + + _, opt_state = self.loop(optimizer, num_steps, self.synced_initial_params) + fast_state = opt_state.fast_state + if reset_state: + correct_state = fast_optimizer.init(self.initial_params) + else: + _, correct_state = self.loop(fast_optimizer, num_steps, + self.initial_params) + + chex.assert_trees_all_close(fast_state, correct_state) + + @chex.all_variants + @parameterized.parameters( + [1, 0.5, {'x': np.array(1.), 'y': np.array(-1.)}], + [1, 0, {'x': np.array(3.), 'y': np.array(-3.)}], + [1, 1, {'x': np.array(-1.), 'y': np.array(1.)}], + [2, 1, {'x': np.array(-1.), 'y': np.array(1.)}]) # pyformat: disable + def test_lookahead_edge_cases(self, sync_period, slow_step_size, + correct_result): + """Checks special cases of the lookahed optimizer parameters.""" + # These edge cases are important to check since users might use them as + # simple ways of disabling lookahead in experiments. + optimizer = lookahead.lookahead( + _test_optimizer(-1), sync_period, slow_step_size) + final_params, _ = self.loop( + optimizer, num_steps=2, params=self.synced_initial_params) + chex.assert_trees_all_close(final_params.slow, correct_result) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/loss.py b/optax_add_eve/_src/loss.py new file mode 100644 index 00000000..578741f1 --- /dev/null +++ b/optax_add_eve/_src/loss.py @@ -0,0 +1,521 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 
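[Editor's note: a hypothetical end-to-end sketch of driving the lookahead wrapper above; not part of the patch, and `loss_fn` plus the parameter tree are invented for illustration. Gradients are taken with respect to the fast parameters, while the slow parameters are the ones to evaluate or checkpoint.]

import jax
import jax.numpy as jnp

from optax_add_eve._src import alias
from optax_add_eve._src import lookahead
from optax_add_eve._src import update

def loss_fn(params):
  return jnp.sum(params['w'] ** 2)

opt = lookahead.lookahead(alias.sgd(0.1), sync_period=5, slow_step_size=0.5)
params = lookahead.LookaheadParams.init_synced({'w': jnp.ones((3,))})
opt_state = opt.init(params)

for _ in range(10):
  grads = jax.grad(loss_fn)(params.fast)   # gradients w.r.t. fast params only
  updates, opt_state = opt.update(grads, opt_state, params)
  params = update.apply_updates(params, updates)

# params.slow now holds the slow parameters used for evaluation or inference.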
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Standard losses used in optimisation. + +We provide implementations of the most canonical losses used in deep +learning. These operate transparently on batches, and do not perform any +reduction over the batch dimensions, leaving it to the user to, for instance, +mean or sum losses across batch dimensions. +""" + +from typing import Optional, Tuple + +import chex +import jax +import jax.numpy as jnp + +from optax_add_eve._src import utils + + +def l2_loss( + predictions: chex.Array, + targets: Optional[chex.Array] = None, +) -> chex.Array: + """Calculates the L2 loss for a set of predictions. + + Note: the 0.5 term is standard in "Pattern Recognition and Machine Learning" + by Bishop, but not "The Elements of Statistical Learning" by Tibshirani. + + References: + [Chris Bishop, 2006](https://bit.ly/3eeP0ga) + + Args: + predictions: a vector of arbitrary shape `[...]`. + targets: a vector with shape broadcastable to that of `predictions`; + if not provided then it is assumed to be a vector of zeros. + + Returns: + elementwise squared differences, with same shape as `predictions`. + """ + chex.assert_type([predictions], float) + if targets is not None: + # Avoid broadcasting logic for "-" operator. + chex.assert_equal_shape((predictions, targets)) + errors = (predictions - targets) if (targets is not None) else predictions + return 0.5 * (errors)**2 + + +def huber_loss( + predictions: chex.Array, + targets: Optional[chex.Array] = None, + delta: float = 1.) -> chex.Array: + """Huber loss, similar to L2 loss close to zero, L1 loss away from zero. + + If gradient descent is applied to the `huber loss`, it is equivalent to + clipping gradients of an `l2_loss` to `[-delta, delta]` in the backward pass. + + References: + [Huber, 1964](www.projecteuclid.org/download/pdf_1/euclid.aoms/1177703732) + + Args: + predictions: a vector of arbitrary shape `[...]`. + targets: a vector with shape broadcastable to that of `predictions`; + if not provided then it is assumed to be a vector of zeros. + delta: the bounds for the huber loss transformation, defaults at 1. + + Returns: + elementwise huber losses, with the same shape of `predictions`. + """ + chex.assert_type([predictions], float) + errors = (predictions - targets) if (targets is not None) else predictions + # 0.5 * err^2 if |err| <= d + # 0.5 * d^2 + d * (|err| - d) if |err| > d + abs_errors = jnp.abs(errors) + quadratic = jnp.minimum(abs_errors, delta) + # Same as max(abs_x - delta, 0) but avoids potentially doubling gradient. + linear = abs_errors - quadratic + return 0.5 * quadratic ** 2 + delta * linear + + +def smooth_labels( + labels: chex.Array, + alpha: float, +) -> jnp.ndarray: + """Apply label smoothing. + + Label smoothing is often used in combination with a cross-entropy loss. 
+ Smoothed labels favour small logit gaps, and it has been shown that this can + provide better model calibration by preventing overconfident predictions. + + References: + [Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf) + + Args: + labels: one hot labels to be smoothed. + alpha: the smoothing factor, the greedy category with be assigned + probability `(1-alpha) + alpha / num_categories` + + Returns: + a smoothed version of the one hot input labels. + + """ + chex.assert_type([labels], float) + num_categories = labels.shape[-1] + return (1.0 - alpha) * labels + alpha / num_categories + + +def sigmoid_binary_cross_entropy(logits, labels): + """Computes element-wise sigmoid cross entropy given logits and labels. + + This can be used to measure the error in discrete classification tasks in + which each class is an independent binary prediction and different classes + are not mutually exclusive. This may be used for multilabel image + classification for instance a model may predict that an image contains both a + cat and a dog. + + References: + [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) + + Args: + logits: Each element is the unnormalized log probability of a binary + prediction. + labels: The target probabilities, must have a shape broadcastable to that of + `logits`. + + Returns: + cross entropy for each binary prediction, same shape as `logits`. + """ + chex.assert_type([logits], float) + log_p = jax.nn.log_sigmoid(logits) + # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter more numerically stable + log_not_p = jax.nn.log_sigmoid(-logits) + return -labels * log_p - (1. - labels) * log_not_p + + +def softmax_cross_entropy( + logits: chex.Array, + labels: chex.Array, +) -> chex.Array: + """Computes the softmax cross entropy between sets of logits and labels. + + Measures the probability error in discrete classification tasks in which + the classes are mutually exclusive (each entry is in exactly one class). + For example, each CIFAR-10 image is labeled with one and only one label: + an image can be a dog or a truck, but not both. + + References: + [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) + + Args: + logits: Unnormalized log probabilities, with shape `[..., num_classes]`. + labels: Valid probability distributions (non-negative, sum to 1), e.g a + one hot encoding specifying the correct class for each input; + must have a shape broadcastable to `[..., num_classes]`` + + Returns: + cross entropy between each prediction and the corresponding target + distributions, with shape `[...]`. + """ + chex.assert_type([logits], float) + return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) + + +def softmax_cross_entropy_with_integer_labels( + logits: chex.Array, + labels: chex.Array, +) -> chex.Array: + """Computes softmax cross entropy between sets of logits and integer labels. + + Measures the probability error in discrete classification tasks in which + the classes are mutually exclusive (each entry is in exactly one class). + For example, each CIFAR-10 image is labeled with one and only one label: + an image can be a dog or a truck, but not both. + + References: + [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) + + Args: + logits: Unnormalized log probabilities, with shape `[..., num_classes]`. + labels: Integers specifying the correct class for each input, with shape + `[...]`. 
+ + Returns: + Cross entropy between each prediction and the corresponding target + distributions, with shape `[...]`. + """ + chex.assert_type([logits], float) + chex.assert_type([labels], int) + # This is like jnp.take_along_axis(jax.nn.log_softmax(...), ...) except that + # we avoid subtracting the normalizer from all values, just from the values + # for the correct labels. + logits_max = jnp.max(logits, axis=-1, keepdims=True) + logits -= jax.lax.stop_gradient(logits_max) + label_logits = jnp.take_along_axis(logits, labels[..., None], axis=-1)[..., 0] + log_normalizers = jnp.log(jnp.sum(jnp.exp(logits), axis=-1)) + return log_normalizers - label_logits + + +def cosine_similarity( + predictions: chex.Array, + targets: chex.Array, + epsilon: float = 0., +) -> chex.Array: + r"""Computes the cosine similarity between targets and predictions. + + The cosine **similarity** is a measure of similarity between vectors defined + as the cosine of the angle between them, which is also the inner product of + those vectors normalized to have unit norm. + + References: + [Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity) + + Args: + predictions: The predicted vectors, with shape `[..., dim]`. + targets: Ground truth target vectors, with shape `[..., dim]`. + epsilon: minimum norm for terms in the denominator of the cosine similarity. + + Returns: + cosine similarity measures, with shape `[...]`. + """ + chex.assert_type([predictions, targets], float) + # vectorize norm fn, to treat all dimensions except the last as batch dims. + batched_norm_fn = jnp.vectorize( + utils.safe_norm, signature='(k)->()', excluded={1}) + # normalise the last dimension of targets and predictions. + unit_targets = targets / jnp.expand_dims( + batched_norm_fn(targets, epsilon), axis=-1) + unit_predictions = predictions / jnp.expand_dims( + batched_norm_fn(predictions, epsilon), axis=-1) + # return cosine similarity. + return jnp.sum(unit_targets * unit_predictions, axis=-1) + + +def cosine_distance( + predictions: chex.Array, + targets: chex.Array, + epsilon: float = 0., +) -> chex.Array: + r"""Computes the cosine distance between targets and predictions. + + The cosine **distance**, implemented here, measures the **dissimilarity** + of two vectors as the opposite of cosine **similarity**: `1 - cos(\theta)`. + + References: + [Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity) + + Args: + predictions: The predicted vectors, with shape `[..., dim]`. + targets: Ground truth target vectors, with shape `[..., dim]`. + epsilon: minimum norm for terms in the denominator of the cosine similarity. + + Returns: + cosine distances, with shape `[...]`. + """ + chex.assert_type([predictions, targets], float) + # cosine distance = 1 - cosine similarity. + return 1. - cosine_similarity(predictions, targets, epsilon) + + +def log_cosh( + predictions: chex.Array, + targets: Optional[chex.Array] = None, +) -> chex.Array: + """Calculates the log-cosh loss for a set of predictions. + + log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)` + for large x. It is a twice differentiable alternative to the Huber loss. + + References: + [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym) + + Args: + predictions: a vector of arbitrary shape `[...]`. + targets: a vector with shape broadcastable to that of `predictions`; + if not provided then it is assumed to be a vector of zeros. + + Returns: + the log-cosh loss, with same shape as `predictions`. 
+ """ + chex.assert_type([predictions], float) + errors = (predictions - targets) if (targets is not None) else predictions + # log(cosh(x)) = log((exp(x) + exp(-x))/2) = log(exp(x) + exp(-x)) - log(2) + return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype) + + +def ctc_loss_with_forward_probs( + logits: chex.Array, + logit_paddings: chex.Array, + labels: chex.Array, + label_paddings: chex.Array, + blank_id: int = 0, + log_epsilon: float = -1e5) -> Tuple[chex.Array, chex.Array, chex.Array]: + r"""Computes CTC loss and CTC forward-probabilities. + + The CTC loss is a loss function based on log-likelihoods of the model that + introduces a special blank symbol :math:`\phi` to represent variable-length + output sequences. + + Forward probabilities returned by this function, as auxiliary results, are + grouped into two part: blank alpha-probability and non-blank alpha + probability. Those are defined as follows: + + .. math:: + \alpha_{\mathrm{BLANK}}(t, n) = + \sum_{\pi_{1:t-1}} p(\pi_t = \phi | \pi_{1:t-1}, y_{1:n-1}, \cdots), \\ + \alpha_{\mathrm{LABEL}}(t, n) = + \sum_{\pi_{1:t-1}} p(\pi_t = y_n | \pi_{1:t-1}, y_{1:n-1}, \cdots). + + Here, :math:`\pi` denotes the alignment sequence in the reference + [Graves et al, 2006] that is blank-inserted representations of ``labels``. + The return values are the logarithms of the above probabilities. + + References: + [Graves et al, 2006](https://dl.acm.org/doi/abs/10.1145/1143844.1143891) + + Args: + logits: (B, T, K)-array containing logits of each class where B denotes + the batch size, T denotes the max time frames in ``logits``, and K + denotes the number of classes including a class for blanks. + logit_paddings: (B, T)-array. Padding indicators for ``logits``. Each + element must be either 1.0 or 0.0, and ``logitpaddings[b, t] == 1.0`` + denotes that ``logits[b, t, :]`` are padded values. + labels: (B, N)-array containing reference integer labels where N denotes + the max time frames in the label sequence. + label_paddings: (B, N)-array. Padding indicators for ``labels``. Each + element must be either 1.0 or 0.0, and ``labelpaddings[b, n] == 1.0`` + denotes that ``labels[b, n]`` is a padded label. In the current + implementation, ``labels`` must be right-padded, i.e. each row + ``labelpaddings[b, :]`` must be repetition of zeroes, followed by + repetition of ones. + blank_id: Id for blank token. ``logits[b, :, blank_id]`` are used as + probabilities of blank symbols. + log_epsilon: Numerically-stable approximation of log(+0). + + Returns: + A tuple ``(loss_value, logalpha_blank, logalpha_nonblank)``. Here, + ``loss_value`` is a (B,)-array containing the loss values for each sequence + in the batch, ``logalpha_blank`` and ``logalpha_nonblank`` are + (T, B, N+1)-arrays where the (t, b, n)-th element denotes + \log \alpha_B(t, n) and \log \alpha_L(t, n), respectively, for ``b``-th + sequence in the batch. + """ + + chex.assert_rank(logits, 3) + chex.assert_rank(labels, 2) + batchsize, unused_maxinputlen, num_classes = logits.shape + batchsize_of_labels, maxlabellen = labels.shape + chex.assert_equal(batchsize, batchsize_of_labels) + chex.assert_equal(labels.shape, label_paddings.shape) + chex.assert_equal(logits.shape[:2], logit_paddings.shape) + + logprobs = jax.nn.log_softmax(logits) + labellens = maxlabellen - jnp.sum(label_paddings, axis=1).astype(jnp.int32) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. 
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32) + repeat = jnp.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, blank_id:blank_id + 1] # [B, T, 1] + logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K] + logprobs_emit = jnp.einsum('btk,bnk->btn', logprobs, one_hot) + logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + logalpha_phi_init = jnp.ones( + (batchsize, maxlabellen + 1)) * log_epsilon # [B, N] + logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) + logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon + + def update_phi_score(phi, added_score): + # Update `phi[:, 1:]`` with adding `added_score` in log space. + return jnp.concatenate( + [phi[:, :1], jnp.logaddexp(phi[:, 1:], added_score)], axis=-1) + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, + prev_emit + logprob_emit) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = update_phi_score( + next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)) + + pad = pad.reshape((batchsize, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + xs = (logprobs_emit, logprobs_phi, logit_paddings.transpose((1, 0))) + _, (logalpha_phi, + logalpha_emit) = jax.lax.scan(loop_body, + (logalpha_phi_init, logalpha_emit_init), xs) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1]) + logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) + + # extract per_seq_loss + one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1] + per_seq_loss = -jnp.einsum('bn,bn->b', logalpha_phi_last, one_hot) + + return per_seq_loss, logalpha_phi, logalpha_emit + + +def ctc_loss(logits: chex.Array, + logit_paddings: chex.Array, + labels: chex.Array, + label_paddings: chex.Array, + blank_id: int = 0, + log_epsilon: float = -1e5) -> chex.Array: + """Computes CTC loss. + + See docstring for ``ctc_loss_with_forward_probs`` for details. + + Args: + logits: (B, T, K)-array containing logits of each class where B denotes + the batch size, T denotes the max time frames in ``logits``, and K + denotes the number of classes including a class for blanks. + logit_paddings: (B, T)-array. Padding indicators for ``logits``. Each + element must be either 1.0 or 0.0, and ``logitpaddings[b, t] == 1.0`` + denotes that ``logits[b, t, :]`` are padded values. + labels: (B, N)-array containing reference integer labels where N denotes + the max time frames in the label sequence. + label_paddings: (B, N)-array. Padding indicators for ``labels``. Each + element must be either 1.0 or 0.0, and ``labelpaddings[b, n] == 1.0`` + denotes that ``labels[b, n]`` is a padded label. In the current + implementation, ``labels`` must be right-padded, i.e. each row + ``labelpaddings[b, :]`` must be repetition of zeroes, followed by + repetition of ones. + blank_id: Id for blank token. 
``logits[b, :, blank_id]`` are used as + probabilities of blank symbols. + log_epsilon: Numerically-stable approximation of log(+0). + + Returns: + (B,)-array containing loss values for each sequence in the batch. + """ + per_seq_loss, _, _ = ctc_loss_with_forward_probs( + logits, logit_paddings, labels, label_paddings, + blank_id=blank_id, log_epsilon=log_epsilon) + return per_seq_loss + + +def kl_divergence(log_predictions: chex.Array, + targets: chex.Array) -> chex.Array: + """Computes the Kullback-Leibler divergence (relative entropy) loss. + + Measures the information gain achieved if target probability distribution + would be used instead of predicted probability distribution. + + References: + [Kullback, Leibler, 1951](https://www.jstor.org/stable/2236703) + + Args: + log_predictions: Probabilities of predicted distribution with shape + [..., dim]. Expected to be in the log-space to avoid underflow. + targets: Probabilities of target distribution with shape [..., dim]. + Expected to be strictly positive. + + Returns: + Kullback-Leibler divergence of predicted distribution from target + distribution with shape [...]. + """ + chex.assert_type([log_predictions, targets], float) + loss = targets * (jnp.log(targets) - log_predictions) + return jnp.sum(loss, axis=-1) + + +def kl_divergence_with_log_targets(log_predictions: chex.Array, + log_targets: chex.Array) -> chex.Array: + """Computes the Kullback-Leibler divergence (relative entropy) loss. + + Version of kl_div_loss where targets are given in log-space. + + Args: + log_predictions: Probabilities of predicted distribution with shape + [..., dim]. Expected to be in the log-space to avoid underflow. + log_targets: Probabilities of target distribution with shape [..., dim]. + Expected to be in the log-space. + + Returns: + Kullback-Leibler divergence of predicted distribution from target + distribution with shape [...]. + """ + chex.assert_type([log_predictions, log_targets], float) + loss = jnp.exp(log_targets) * (log_targets - log_predictions) + return jnp.sum(loss, axis=-1) + + +def hinge_loss(predictor_outputs: chex.Array, + targets: chex.Array) -> chex.Array: + """Computes the hinge loss for binary classification. + + Args: + predictor_outputs: Outputs of the decision function. + targets: Target values. Target values should be strictly in the set {-1, 1}. + + Returns: + Binary Hinge Loss. + """ + return jnp.maximum(0, 1 - predictor_outputs * targets) diff --git a/optax_add_eve/_src/loss_test.py b/optax_add_eve/_src/loss_test.py new file mode 100644 index 00000000..dd183177 --- /dev/null +++ b/optax_add_eve/_src/loss_test.py @@ -0,0 +1,500 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
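[Editor's note: to make the shape and padding conventions above concrete, a small illustrative call with made-up sizes; a sketch, not part of the patch.]

import jax
import jax.numpy as jnp

from optax_add_eve._src import loss

logits = jax.random.normal(jax.random.PRNGKey(0), (2, 6, 5))  # (B, T, K); class 0 is the blank
logit_paddings = jnp.zeros((2, 6))                  # no padded frames
labels = jnp.array([[1, 2, 3], [4, 2, 1]])          # (B, N) integer labels
label_paddings = jnp.array([[0., 0., 0.],           # sequence 0: three labels
                            [0., 0., 1.]])          # sequence 1: two labels, right-padded
per_seq_loss = loss.ctc_loss(logits, logit_paddings, labels, label_paddings)
# per_seq_loss has shape (B,); reduce it (e.g. with jnp.mean) for a scalar loss.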
+# ============================================================================== +"""Tests for optax._src.loss.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +import jax +import jax.numpy as jnp +import numpy as np + +from optax_add_eve._src import loss + + +class L2LossTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.ys = jnp.array([-2., -1., 0.5, 1.]) + self.ts = jnp.array([-1.5, 0., -1, 1.]) + # compute expected outputs in numpy. + self.exp = 0.5 * (self.ts - self.ys) ** 2 + + @chex.all_variants + def test_scalar(self): + np.testing.assert_allclose( + self.variant(loss.l2_loss)(self.ys[0], self.ts[0]), self.exp[0]) + + @chex.all_variants + def test_batched(self): + np.testing.assert_allclose( + self.variant(loss.l2_loss)(self.ys, self.ts), self.exp) + + @chex.all_variants + def test_shape_mismatch(self): + with self.assertRaises(AssertionError): + _ = self.variant(loss.l2_loss)(self.ys, jnp.expand_dims(self.ts, axis=-1)) + + +class HuberLossTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.ys = np.array([-2.0, 0.5, 0., 0.5, 2.0, 4.0, 132.]) + self.ts = np.array([0.0, -0.5, 0., 1., 1.0, 2.0, 0.3]) + # computed expected outputs manually. + self.exp = np.array([1.5, 0.5, 0., 0.125, 0.5, 1.5, 131.2]) + + @chex.all_variants + def test_scalar(self): + np.testing.assert_allclose( + self.variant(loss.huber_loss)(self.ys[0], self.ts[0], delta=1.0), + self.exp[0]) + + @chex.all_variants + def test_batched(self): + np.testing.assert_allclose( + self.variant(loss.huber_loss)(self.ys, self.ts, delta=1.0), + self.exp) + + +class SmoothLabelsTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.ts = np.array([[0., 1., 0.], [1., 0., 0.]], dtype=np.float32) + # compute expected outputs in numpy. + self.exp_alpha_zero = self.ts + self.exp_alpha_zero_point_one = 0.9 * self.ts + 0.1 / self.ts.shape[-1] + self.exp_alpha_one = jnp.ones_like(self.ts) / self.ts.shape[-1] + + @chex.all_variants + def test_scalar(self): + """Tests for a full batch.""" + np.testing.assert_allclose( + self.variant(loss.smooth_labels)(self.ts[0], 0.), + self.exp_alpha_zero[0], atol=1e-4) + np.testing.assert_allclose( + self.variant(loss.smooth_labels)(self.ts[0], 0.1), + self.exp_alpha_zero_point_one[0], atol=1e-4) + np.testing.assert_allclose( + self.variant(loss.smooth_labels)(self.ts[0], 1.), + self.exp_alpha_one[0], atol=1e-4) + + @chex.all_variants + def test_batched(self): + """Tests for a full batch.""" + np.testing.assert_allclose( + self.variant(loss.smooth_labels)(self.ts, 0.), + self.exp_alpha_zero, atol=1e-4) + np.testing.assert_allclose( + self.variant(loss.smooth_labels)(self.ts, 0.1), + self.exp_alpha_zero_point_one, atol=1e-4) + np.testing.assert_allclose( + self.variant(loss.smooth_labels)(self.ts, 1.), + self.exp_alpha_one, atol=1e-4) + + +class SoftmaxCrossEntropyTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.ys = np.array([[10., 1., -2.], [1., 4., 0.2]], dtype=np.float32) + self.ts = np.array([[0., 1., 0.], [1., 0., 0.]], dtype=np.float32) + # taken expected outputs from rlax. 
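+    # (each value equals logsumexp(ys, axis=-1) minus the logit of the
+    #  one-hot target class)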
+ self.exp = np.array([9.00013, 3.0696733], dtype=np.float32) + + @chex.all_variants + def test_scalar(self): + """Tests for a full batch.""" + np.testing.assert_allclose( + self.variant(loss.softmax_cross_entropy)(self.ys[0], self.ts[0]), + self.exp[0], atol=1e-4) + + @chex.all_variants + def test_batched(self): + """Tests for a full batch.""" + np.testing.assert_allclose( + self.variant(loss.softmax_cross_entropy)(self.ys, self.ts), + self.exp, atol=1e-4) + + +class SoftmaxCrossEntropyWithIntegerLabelsTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.ys = np.array([[10., 1., -2.], [1., 4., 0.2]], dtype=np.float32) + self.ts = np.array([1, 0], dtype=np.int32) + + @chex.all_variants + def test_consistent_with_softmax_cross_entropy_scalar(self): + """Tests for a scalar.""" + exp = loss.softmax_cross_entropy(self.ys[0], jax.nn.one_hot(self.ts[0], 3)) + np.testing.assert_allclose( + self.variant(loss.softmax_cross_entropy_with_integer_labels)( + self.ys[0], self.ts[0]), + exp, rtol=1e-6) + + @chex.all_variants + def test_consistent_with_softmax_cross_entropy_batched(self): + """Tests for a full batch.""" + exp = loss.softmax_cross_entropy(self.ys, jax.nn.one_hot(self.ts, 3)) + np.testing.assert_allclose( + self.variant(loss.softmax_cross_entropy_with_integer_labels)( + self.ys, self.ts), + exp, rtol=1e-6) + + +class SigmoidCrossEntropyTest(parameterized.TestCase): + + @parameterized.parameters( + dict(preds=np.array([-1e+09, -1e-09]), + labels=np.array([1., 0.]), + expected=5e+08), + dict(preds=np.array([-1e+09, -1e-09]), + labels=np.array([0., 1.]), + expected=0.3465736), + dict(preds=np.array([1e+09, 1e-09]), + labels=np.array([1., 0.]), + expected=0.3465736), + dict(preds=np.array([1e+09, 1e-09]), + labels=np.array([0., 1.]), + expected=5e+08), + dict(preds=np.array([-1e+09, 1e-09]), + labels=np.array([1., 0.]), + expected=5e+08), + dict(preds=np.array([-1e+09, 1e-09]), + labels=np.array([0., 1.]), + expected=0.3465736), + dict(preds=np.array([1e+09, -1e-09]), + labels=np.array([1., 0.]), + expected=0.3465736), + dict(preds=np.array([1e+09, -1e-09]), + labels=np.array([0., 1.]), + expected=5e+08), + dict(preds=np.array([0., 0.]), + labels=np.array([1., 0.]), + expected=0.6931472), + dict(preds=np.array([0., 0.]), + labels=np.array([0., 1.]), + expected=0.6931472), + ) + def testSigmoidCrossEntropy(self, preds, labels, expected): + tested = jnp.mean(loss.sigmoid_binary_cross_entropy(preds, labels)) + np.testing.assert_allclose(tested, expected, rtol=1e-6, atol=1e-6) + + +class CosineDistanceTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.ys = np.array([[10., 1., -2.], [1., 4., 0.2]], dtype=np.float32) + self.ts = np.array([[0., 1.2, 0.2], [1., -0.3, 0.]], dtype=np.float32) + # distance computed expected output from `scipy 1.20`. + self.exp = np.array([0.9358251989, 1.0464068465], dtype=np.float32) + + @chex.all_variants + def test_scalar_distance(self): + """Tests for a full batch.""" + np.testing.assert_allclose( + self.variant(loss.cosine_distance)(self.ys[0], self.ts[0]), + self.exp[0], atol=1e-4) + + @chex.all_variants + def test_scalar_similarity(self): + """Tests for a full batch.""" + np.testing.assert_allclose( + self.variant(loss.cosine_similarity)(self.ys[0], self.ts[0]), + 1. 
- self.exp[0], atol=1e-4) + + @chex.all_variants + def test_batched_distance(self): + """Tests for a full batch.""" + np.testing.assert_allclose( + self.variant(loss.cosine_distance)(self.ys, self.ts), + self.exp, atol=1e-4) + + @chex.all_variants + def test_batched_similarity(self): + """Tests for a full batch.""" + np.testing.assert_allclose( + self.variant(loss.cosine_similarity)(self.ys, self.ts), + 1. - self.exp, atol=1e-4) + + +# TODO(b/188419459): add test for grad and second order grad. +class LogCoshTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + # Test large values for overflow + self.ys = jnp.array([500, -2., -1., 0.5, 1.]) + self.ts = jnp.array([-200, -1.5, 0., -1, 1.]) + # computed using tensorflow.keras.losses.log_cosh v2.4.1 + self.exp = jnp.array([699.3068, 0.12011445, 0.4337809, 0.85544014, 0.]) + self.exp_ys_only = jnp.array( + [499.30685, 1.3250027, 0.4337809, 0.12011451, 0.43378082]) + + @chex.all_variants + def test_scalar(self): + out = self.variant(loss.log_cosh)(self.ys[0], self.ts[0]) + np.testing.assert_allclose(out, self.exp[0], atol=1e-5) + + @chex.all_variants + def test_batched(self): + out = self.variant(loss.log_cosh)(self.ys, self.ts) + np.testing.assert_allclose(out, self.exp, atol=1e-5) + + @chex.all_variants + def test_scalar_predictions_only(self): + out = self.variant(loss.log_cosh)(self.ys[0]) + np.testing.assert_allclose(out, self.exp_ys_only[0], atol=1e-5) + + @chex.all_variants + def test_batched_predictions_only(self): + out = self.variant(loss.log_cosh)(self.ys) + np.testing.assert_allclose(out, self.exp_ys_only, atol=1e-5) + + +def _lengths_to_paddings(lengths: chex.Array, maxlength: int) -> chex.Array: + indices = jnp.arange(maxlength).reshape((1,) * lengths.ndim + (maxlength,)) + lengths = jnp.expand_dims(lengths, axis=-1) + elem_valid = indices < lengths + return np.logical_not(elem_valid).astype(np.float32) + + +def _average_ctc_loss(logprobs: chex.Array, logprob_paddings: chex.Array, + labels: chex.Array, + label_paddings: chex.Array) -> chex.Array: + return jnp.average( + loss.ctc_loss(logprobs, logprob_paddings, labels, label_paddings)) + + +class CTCTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + np.random.seed(1234) + self._rtol = 5e-3 if jax.default_backend() != 'cpu' else 1e-6 + + @chex.all_variants + def test_with_one_to_one_alignment(self): + # when inputsteps and outputsteps are equal, no blank will be allowed. + batchsize = 8 + steps = 50 + nclasses = 40 + logits = np.random.randn(batchsize, steps, nclasses) + labels = np.random.uniform( + 1, nclasses, size=(batchsize, steps)).astype(np.int32) + + # This function only covers the cases without same-label repetition. + # `test_repeat_with_one_to_one_alignment` below complements those cases. + # So, redraw the samples for satisfying the non-repetition constraint. + for n in range(labels.shape[0]): + for t in range(1, labels.shape[1]): + while labels[n, t] == labels[n, t - 1]: + labels[n, t] = np.random.uniform(1, nclasses) + + results = self.variant(loss.ctc_loss_with_forward_probs)( + logits, np.zeros(logits.shape[:2]), + labels, np.zeros(labels.shape)) + (per_seq_loss, logalpha_blank, logalpha_emit) = results + + logprobs = jax.nn.log_softmax(logits) + for b in range(batchsize): + p = 0.0 + for t in range(steps): + p += logprobs[b, t, labels[b, t]] + np.testing.assert_allclose( + np.array(-p), per_seq_loss[b], rtol=self._rtol) + + # Check forward-probabilities. + # 1. 
All-phi path: logalpha_blank[-1, b, 0] must be a probability of + # the path that outputs blank symbols for all the frames. + np.testing.assert_allclose(logalpha_blank[-1, b, 0], + np.sum(logprobs[b, :, 0]), + rtol=self._rtol) + + # 2. After emitting all the labels + # the negated loss must be identical with the forward probability of + # paths after consuming all the labels (because one-to-one alignment + # doesn't allow extra blank symbols) + np.testing.assert_allclose(logalpha_emit[-1, b, steps - 1], + -per_seq_loss[b], + rtol=self._rtol) + # and, this forward probability must be copied to the blank forward + # probability of the next step. + np.testing.assert_allclose(logalpha_blank[-1, b, steps], + -per_seq_loss[b], + rtol=self._rtol) + + @chex.all_variants + def test_with_one_to_one_alignment_and_paddings(self): + batch_size = 5 + nclasses = 13 + steps = 7 + logits = np.random.normal(size=[batch_size, steps, nclasses]) + logprobs = jax.nn.log_softmax(logits) + + labels = [] + for n in range(batch_size): + row = list(range(1, nclasses)) + np.random.shuffle(row) + labels.append(row[:steps]) + labels = np.array(labels) + + lengths = np.random.randint(3, 6, size=(batch_size,)) + paddings = _lengths_to_paddings(lengths, steps) + + actual_loss = self.variant(loss.ctc_loss)(logits, paddings, labels, + paddings) + + value_and_grad = self.variant(jax.value_and_grad(_average_ctc_loss)) + unused_avg_loss, actual_gradients = value_and_grad(logits, paddings, labels, + paddings) + + for n in range(batch_size): + expected_loss = -sum(logprobs[n, t, k] + for t, k in enumerate(labels[n, :lengths[n]])) + np.testing.assert_allclose(expected_loss, actual_loss[n], rtol=self._rtol) + + expected_gradients = np.array(jax.nn.softmax(logits[n])) + expected_gradients[lengths[n]:] = 0.0 + for t, k in enumerate(labels[n, :lengths[n]]): + expected_gradients[t, k] -= 1.0 + expected_gradients /= batch_size + np.testing.assert_allclose( + expected_gradients, actual_gradients[n], rtol=self._rtol) + + @chex.all_variants + def test_repeat_with_one_to_one_alignment(self): + # test if it can correctly handle the same-label repetition. + nclasses = 5 + labels = np.array([ + [1, 2, 2, 3], + [2, 3, 4, 4], + [1, 1, 1, 1], + [1, 1, 2, 3], + [1, 1, 1, 2], + ]) + expected_alignment = [ # expected minimal alignment + [1, 2, 0, 2, 3], + [2, 3, 4, 0, 4], + [1, 0, 1, 0, 1, 0, 1], + [1, 0, 1, 2, 3], + [1, 0, 1, 0, 1, 2], + ] + batch_size = len(labels) + label_lens = np.array([4] * batch_size) + label_steps = 6 + # Designed to have two padding elements on the right. 
+ labels = np.pad(labels, [(0, 0), (0, label_steps - labels.shape[1])]) + label_paddings = _lengths_to_paddings(label_lens, label_steps) + + logit_lengths = np.array([len(seq) for seq in expected_alignment]) + logit_steps = max(logit_lengths) + logits = np.random.randn(batch_size, logit_steps, nclasses) + logit_paddings = _lengths_to_paddings(logit_lengths, logit_steps) + + per_seq_loss = self.variant(loss.ctc_loss)(logits, logit_paddings, labels, + label_paddings) + + logprobs = jax.nn.log_softmax(logits) + for n in range(batch_size): + expected_loss = -sum(logprobs[n, t, k] + for t, k in enumerate(expected_alignment[n])) + np.testing.assert_allclose( + jnp.array(expected_loss), per_seq_loss[n], rtol=self._rtol) + + +class KLDivergenceTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.log_ps = np.array( + [[-2.9957, -3.5066, -3.9120, -1.2040, -0.6931, -2.3026], + [-1.6094, -1.6094, -1.6094, -2.3026, -1.8971, -1.8971]]) + self.qs = np.array([[0.2, 0.2, 0.2, 0.1, 0.15, 0.15], + [0.05, 0.03, 0.02, 0.3, 0.5, 0.1]]) + # Computed kullback-leibler divergence of P from Q. + self.exp = np.array([0.8875625, 0.7187435584901326]) + + @chex.all_variants + def test_scalar(self): + np.testing.assert_allclose( + self.variant(loss.kl_divergence)(self.log_ps[0], self.qs[0]), + self.exp[0], + atol=1e-4) + + @chex.all_variants + def test_batched(self): + np.testing.assert_allclose( + self.variant(loss.kl_divergence)(self.log_ps, self.qs), + self.exp, + atol=1e-4) + + +class KLDivergenceWithLogTargetsTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.log_ps = np.array( + [[-2.9957, -3.5066, -3.9120, -1.2040, -0.6931, -2.3026], + [-1.6094, -1.6094, -1.6094, -2.3026, -1.8971, -1.8971]]) + self.qs = np.array([[-1.6094, -1.6094, -1.6094, -2.3026, -1.8971, -1.8971], + [-2.9957, -3.5066, -3.9120, -1.2040, -0.6931, -2.3026]]) + # Computed kullback-leibler divergence of P from Q. + self.exp = np.array([0.8875625, 0.7187435584901326]) + + @chex.all_variants + def test_scalar(self): + np.testing.assert_allclose( + self.variant(loss.kl_divergence_with_log_targets)(self.log_ps[0], + self.qs[0]), + self.exp[0], + atol=1e-4) + + @chex.all_variants + def test_batched(self): + np.testing.assert_allclose( + self.variant(loss.kl_divergence_with_log_targets)(self.log_ps, self.qs), + self.exp, + atol=1e-4) + + +class HingeLossTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.ys = np.array([ + -0.97740268, -1.01812625, -0.81675726, -0.73605974, 2.08235648, + 1.84101354, -1.0581002 + ]) + self.ts = np.array([-1, -1, -1, -1, 1, 1, -1]) + # Computed expected outputs. + self.correct_result = np.array( + [0.02259731, 0., 0.18324274, 0.26394027, 0., 0., 0.]) + + @chex.all_variants + def test_batched(self): + np.testing.assert_allclose( + self.variant(loss.hinge_loss)(self.ys, self.ts), + self.correct_result, + atol=1e-4) + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/numerics_test.py b/optax_add_eve/_src/numerics_test.py new file mode 100644 index 00000000..89c7a706 --- /dev/null +++ b/optax_add_eve/_src/numerics_test.py @@ -0,0 +1,112 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optax._src.numerics.""" + +import functools +import itertools +import re + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +import jax +import jax.numpy as jnp +import numpy as np + +from optax_add_eve._src import numerics + +_ALL_ORDS = [None, np.inf, -np.inf, 'fro', 'nuc', 0, 1, 2, -2, -2, -1.5, 1.5] + +int32_array = lambda i: jnp.array(i, dtype=jnp.int32) +float32_array = lambda i: jnp.array(i, dtype=jnp.float32) + + +def _invalid_ord_axis_inputs(ord_axis_keepdims): + ord_, axis = ord_axis_keepdims[0], ord_axis_keepdims[1] + return any(((ord_ == 0 and axis is None), + (isinstance(ord_, float) and axis is None), + (isinstance(ord_, str) and axis is not None))) + + +class NumericsTest(chex.TestCase): + + @chex.all_variants + def test_safe_int32_increments(self): + inc_fn = self.variant(numerics.safe_int32_increment) + # increment small numbers correctly. + base = int32_array(3) + incremented = inc_fn(base) + np.testing.assert_array_equal(incremented, int32_array(4)) + # avoid overflow when incrementing maxint. + base = int32_array(np.iinfo(np.int32).max) + incremented = inc_fn(base) + np.testing.assert_array_equal(incremented, base) + + @chex.all_variants + @parameterized.parameters( + itertools.filterfalse( + _invalid_ord_axis_inputs, + itertools.product(_ALL_ORDS, [None, 0, 1], [False, True]))) + def test_safe_norm(self, ord, axis, keepdims): # pylint: disable=redefined-builtin + dnorm_dx = self.variant( + jax.jacfwd( + functools.partial( + numerics.safe_norm, ord=ord, axis=axis, keepdims=keepdims), + argnums=0)) + # Test gradient is 0. in 0. when zero min norm is used. + g = dnorm_dx(float32_array(jnp.zeros((3, 4))), float32_array(0.)) + np.testing.assert_array_equal(g, jnp.zeros_like(g)) + # Test gradient is 0. in 0. when non zero min norm is used. + g = dnorm_dx(float32_array(jnp.zeros((3, 4))), float32_array(3.)) + np.testing.assert_array_equal(g, jnp.zeros_like(g)) + + @chex.all_variants + def test_safe_rms(self): + drms_dx = self.variant(jax.grad(numerics.safe_root_mean_squares)) + # Test gradient is 0. in 0. when zero min rms is used. + g = drms_dx(float32_array(0.), float32_array(0.)) + np.testing.assert_array_equal(g, jnp.zeros_like(g)) + # Test gradient is 0. in 0. when non zero min rms is used. + g = drms_dx(float32_array(0.), float32_array(3.)) + np.testing.assert_array_equal(g, jnp.zeros_like(g)) + + def test_complex_vs_real_abs_sqr(self): + # Tests that JAX generates the same HLO from `numerics.abs_sq`, + # `jnp.square(x)`, `x * x`, and `x**2`. + real_sq_fns = (lambda x: x**2, lambda x: x * x, jnp.square) + + def _get_hlo_repr(f, x): + hlo_string = jax.xla_computation(f)(x).as_hlo_text() + return re.sub('HloModule.*?\n', '', + re.sub('ENTRY.*?{', 'ENTRY XXXX', hlo_string)) + + # Real arg (same HLO). + for real_sq_fn in real_sq_fns: + for real_x in (3, 3.0, np.array([4, 5.2])): + self.assertEqual( + _get_hlo_repr(real_sq_fn, real_x), + _get_hlo_repr(numerics.abs_sq, real_x)) + + # Complex arg (different HLOs). 
+ for real_sq_fn in real_sq_fns: + for complex_x in (1j, 3. + 1j, np.array([4 + 1j, 5.2 + 1j])): + self.assertNotEqual( + _get_hlo_repr(real_sq_fn, complex_x), + _get_hlo_repr(numerics.abs_sq, complex_x)) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/privacy.py b/optax_add_eve/_src/privacy.py new file mode 100644 index 00000000..78c58210 --- /dev/null +++ b/optax_add_eve/_src/privacy.py @@ -0,0 +1,74 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Differential Privacy utilities.""" + +from typing import NamedTuple + +import jax +import jax.numpy as jnp + +from optax_add_eve._src import base +from optax_add_eve._src import clipping + + +# pylint:disable=no-value-for-parameter +class DifferentiallyPrivateAggregateState(NamedTuple): + """State containing PRNGKey for `differentially_private_aggregate`.""" + rng_key: jnp.array + + +def differentially_private_aggregate( + l2_norm_clip: float, + noise_multiplier: float, + seed: int +) -> base.GradientTransformation: + """Aggregates gradients based on the DPSGD algorithm. + + WARNING: Unlike other transforms, `differentially_private_aggregate` expects + the input updates to have a batch dimension in the 0th axis. That is, this + function expects per-example gradients as input (which are easy to obtain in + JAX using `jax.vmap`). It can still be composed with other transformations as + long as it is the first in the chain. + + References: + [Abadi et al, 2016](https://arxiv.org/abs/1607.00133) + + Args: + l2_norm_clip: maximum L2 norm of the per-example gradients. + noise_multiplier: ratio of standard deviation to the clipping norm. + seed: initial seed used for the jax.random.PRNGKey + + Returns: + A `GradientTransformation`. + """ + noise_std = l2_norm_clip * noise_multiplier + + def init_fn(params): + del params + return DifferentiallyPrivateAggregateState(rng_key=jax.random.PRNGKey(seed)) + + def update_fn(updates, state, params=None): + del params + grads_flat, grads_treedef = jax.tree_util.tree_flatten(updates) + bsize = grads_flat[0].shape[0] + clipped, _ = clipping.per_example_global_norm_clip(grads_flat, l2_norm_clip) + + new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat) + 1) + noised = [(g + noise_std * jax.random.normal(r, g.shape, g.dtype)) / bsize + for g, r in zip(clipped, rngs)] + return (jax.tree_util.tree_unflatten(grads_treedef, noised), + DifferentiallyPrivateAggregateState(rng_key=new_key)) + + return base.GradientTransformation(init_fn, update_fn) diff --git a/optax_add_eve/_src/privacy_test.py b/optax_add_eve/_src/privacy_test.py new file mode 100644 index 00000000..82455063 --- /dev/null +++ b/optax_add_eve/_src/privacy_test.py @@ -0,0 +1,112 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `privacy.py`.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +import jax +import jax.numpy as jnp + +from optax_add_eve._src import privacy + + +class DifferentiallyPrivateAggregateTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.batch_size = 8 + self.params = {'key_a': (jnp.zeros((2, 3, 4)), jnp.zeros([])), + 'key_b': jnp.zeros((6, 7))} + # Example `i`'s grads are full of `i`s. Important to include 0 to ensure + # there are no divisons by 0 (e.g. in norm clipping) + a = jnp.arange(self.batch_size) + self.per_eg_grads = jax.tree_util.tree_map( + lambda p: jnp.moveaxis(a * jnp.ones(p.shape+(self.batch_size,)), -1, 0), + self.params) + + @chex.all_variants + def test_no_privacy(self): + """l2_norm_clip=MAX_FLOAT32 and noise_multiplier=0 should recover SGD.""" + dp_agg = privacy.differentially_private_aggregate( + l2_norm_clip=jnp.finfo(jnp.float32).max, + noise_multiplier=0., + seed=0) + state = dp_agg.init(self.params) + update_fn = self.variant(dp_agg.update) + mean_grads = jax.tree_util.tree_map(lambda g: g.mean(0), self.per_eg_grads) + + for _ in range(3): + updates, state = update_fn(self.per_eg_grads, state) + chex.assert_trees_all_close(updates, mean_grads) + + @chex.all_variants + @parameterized.parameters(0.5, 10.0, 20.0, 40.0, 80.0) + def test_clipping_norm(self, l2_norm_clip): + dp_agg = privacy.differentially_private_aggregate( + l2_norm_clip=l2_norm_clip, + noise_multiplier=0., + seed=42) + state = dp_agg.init(self.params) + update_fn = self.variant(dp_agg.update) + + # Shape of the three arrays below is (self.batch_size, ) + norms = [jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1) + for g in jax.tree_util.tree_leaves(self.per_eg_grads)] + global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0) + divisors = jnp.maximum(global_norms / l2_norm_clip, 1.) + # Since the values of all the parameters are the same within each example, + # we can easily compute what the values should be: + expected_val = jnp.mean(jnp.arange(self.batch_size) / divisors) + expected_tree = jax.tree_util.tree_map( + lambda p: jnp.broadcast_to(expected_val, p.shape), self.params) + + for _ in range(3): + updates, state = update_fn(self.per_eg_grads, state, self.params) + chex.assert_trees_all_close(updates, expected_tree, rtol=2e-7) + + @chex.all_variants + @parameterized.parameters((3.0, 2.0), (1.0, 5.0), (100.0, 4.0), (1.0, 90.0)) + def test_noise_multiplier(self, l2_norm_clip, noise_multiplier): + """Standard dev. 
of noise should be l2_norm_clip * noise_multiplier.""" + dp_agg = privacy.differentially_private_aggregate( + l2_norm_clip=l2_norm_clip, + noise_multiplier=noise_multiplier, + seed=1337) + state = dp_agg.init(None) + update_fn = self.variant(dp_agg.update) + expected_std = l2_norm_clip * noise_multiplier + + grads = [jnp.ones((1, 100, 100))] # batch size 1 + for _ in range(3): + updates, state = update_fn(grads, state) + chex.assert_trees_all_close(expected_std, + jnp.std(updates[0]), + atol=0.1 * expected_std) + + def test_aggregated_updates_as_input_fails(self): + """Expect per-example gradients as input to this transform.""" + dp_agg = privacy.differentially_private_aggregate(l2_norm_clip=0.1, + noise_multiplier=1.1, + seed=2021) + state = dp_agg.init(self.params) + mean_grads = jax.tree_util.tree_map(lambda g: g.mean(0), self.per_eg_grads) + with self.assertRaises(ValueError): + dp_agg.update(mean_grads, state, self.params) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/schedule.py b/optax_add_eve/_src/schedule.py new file mode 100644 index 00000000..4fcdca7d --- /dev/null +++ b/optax_add_eve/_src/schedule.py @@ -0,0 +1,620 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""JAX Schedules. + +Schedules may be used to anneal the value of a hyper-parameter over time; for +instance, they may be used to anneal the learning rate used to update an agent's +parameters or the exploration factor used to select actions. +""" + +import functools +import inspect +from typing import Callable, Dict, Union, NamedTuple, Optional, Iterable, Sequence + +from absl import logging +import chex +import jax +import jax.numpy as jnp + +from optax_add_eve._src import base +from optax_add_eve._src import numerics + + +def constant_schedule( + value: Union[float, int] +) -> base.Schedule: + """Constructs a constant schedule. + + Args: + value: value to be held constant throughout. + + Returns: + schedule: A function that maps step counts to values. + """ + return lambda count: value + + +def polynomial_schedule( + init_value: chex.Scalar, + end_value: chex.Scalar, + power: chex.Scalar, + transition_steps: int, + transition_begin: int = 0 +) -> base.Schedule: + """Constructs a schedule with polynomial transition from init to end value. + + Args: + init_value: initial value for the scalar to be annealed. + end_value: end value of the scalar to be annealed. + power: the power of the polynomial used to transition from init to end. + transition_steps: number of steps over which annealing takes place, + the scalar starts changing at `transition_begin` steps and completes + the transition by `transition_begin + transition_steps` steps. + If `transition_steps <= 0`, then the entire annealing process is disabled + and the value is held fixed at `init_value`. + transition_begin: must be positive. 
After how many steps to start annealing + (before this many steps the scalar value is held fixed at `init_value`). + + Returns: + schedule: A function that maps step counts to values. + """ + if transition_steps <= 0: + logging.info( + 'A polynomial schedule was set with a non-positive `transition_steps` ' + 'value; this results in a constant schedule with value `init_value`.') + return lambda count: init_value + + if transition_begin < 0: + logging.info( + 'An exponential schedule was set with a negative `transition_begin` ' + 'value; this will result in `transition_begin` falling back to `0`.') + transition_begin = 0 + + def schedule(count): + count = jnp.clip(count - transition_begin, 0, transition_steps) + frac = 1 - count / transition_steps + return (init_value - end_value) * (frac**power) + end_value + return schedule + + +# Alias polynomial schedule to linear schedule for convenience. +def linear_schedule( + init_value: chex.Scalar, + end_value: chex.Scalar, + transition_steps: int, + transition_begin: int = 0 +) -> base.Schedule: + return polynomial_schedule( + init_value=init_value, end_value=end_value, power=1, + transition_steps=transition_steps, transition_begin=transition_begin) + + +def piecewise_constant_schedule( + init_value: float, + boundaries_and_scales: Optional[Dict[int, float]] = None +) -> base.Schedule: + """Returns a function which implements a piecewise constant schedule. + + Args: + init_value: An initial value `init_v`. + boundaries_and_scales: A map from boundaries `b_i` to non-negative scaling + factors `f_i`. For any step count `s`, the schedule returns `init_v` + scaled by the product of all factors `f_i` such that `b_i` < `s`. + + Returns: + schedule: A function that maps step counts to values. + """ + if boundaries_and_scales is not None: + all_positive = all(scale >= 0. for scale in boundaries_and_scales.values()) + if not all_positive: + raise ValueError( + '`piecewise_constant_schedule` expects non-negative scale factors') + + def schedule(count): + v = init_value + if boundaries_and_scales is not None: + for threshold, scale in sorted(boundaries_and_scales.items()): + indicator = jnp.maximum(0., jnp.sign(threshold - count)) + v = v * indicator + (1 - indicator) * scale * v + return v + + return schedule + + +def exponential_decay( + init_value: float, + transition_steps: int, + decay_rate: float, + transition_begin: int = 0, + staircase: bool = False, + end_value: Optional[float] = None +) -> base.Schedule: + """Constructs a schedule with either continuous or discrete exponential decay. + + This function applies an exponential decay function to a provided initial + value. The function returns the decayed value as follows: + + ``` + decayed_value = init_value * decay_rate ^ (count / transition_steps) + ``` + + If the argument `staircase` is `True`, then `count / transition_steps` is + an integer division and the decayed value follows a staircase function. + + Args: + init_value: the initial learning rate. + transition_steps: must be positive. See the decay computation above. + decay_rate: must not be zero. The decay rate. + transition_begin: must be positive. After how many steps to start annealing + (before this many steps the scalar value is held fixed at `init_value`). + staircase: if `True`, decay the values at discrete intervals. + end_value: the value at which the exponential decay stops. When + `decay_rate` < 1, `end_value` is treated as a lower bound, otherwise as + an upper bound. Has no effect when `decay_rate` = 0. 
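+
+  For instance, a schedule that halves the value every 1000 steps (the
+  numbers here are purely illustrative) could be constructed as::
+
+    schedule_fn = exponential_decay(
+        init_value=1., transition_steps=1000, decay_rate=0.5)
+    schedule_fn(0)     # 1.0
+    schedule_fn(1000)  # 0.5
+    schedule_fn(2000)  # 0.25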
+ + Returns: + schedule: A function that maps step counts to values. + """ + + if transition_steps <= 0: + logging.info( + 'An exponential schedule was set with a non-positive `transition_steps`' + ' value; this will result in a constant schedule with value ' + '`init_value`.') + return lambda count: init_value + + if decay_rate == 0: + logging.info( + 'An exponential schedule was set with a zero `decay_rate` value; ' + 'this will result in a constant schedule with value `init_value`.') + return lambda count: init_value + + if transition_begin < 0: + logging.info( + 'An exponential schedule was set with a negative `transition_begin` ' + 'value; this will result in `transition_begin` falling back to `0`.') + transition_begin = 0 + + if end_value is not None: + clip_fn = jnp.maximum if decay_rate < 1.0 else jnp.minimum + + def schedule(count): + decreased_count = count - transition_begin + p = decreased_count / transition_steps + if staircase: + p = jnp.floor(p) + decayed_value = jnp.where( + decreased_count <= 0, init_value, init_value * jnp.power(decay_rate, p)) + if end_value is not None: + decayed_value = clip_fn(decayed_value, end_value) + return decayed_value + + return schedule + + +def cosine_decay_schedule( + init_value: float, + decay_steps: int, + alpha: float = 0.0 +) -> base.Schedule: + """Returns a function which implements cosine learning rate decay. + + The schedule does not restart when ``decay_steps`` has been reached. Instead, + the learning rate remains constant afterwards. For a cosine schedule with + restarts, :func:`optax.join_schedules` can be used to join several cosine + decay schedules. + + For more details see: https://arxiv.org/abs/1608.03983. + + Args: + init_value: An initial value `init_v`. + decay_steps: Positive integer - the number of steps for which to apply + the decay for. + alpha: Float. The minimum value of the multiplier used to adjust the + learning rate. + + Returns: + schedule: A function that maps step counts to values. + """ + if not decay_steps > 0: + raise ValueError('The cosine_decay_schedule requires positive decay_steps!') + + def schedule(count): + count = jnp.minimum(count, decay_steps) + cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * count / decay_steps)) + decayed = (1 - alpha) * cosine_decay + alpha + return init_value * decayed + + return schedule + + +def _linear_interpolate(start: float, end: float, pct: float): + return (end-start) * pct + start + + +def _cosine_interpolate(start: float, end: float, pct: float): + return end + (start-end) / 2.0 * (jnp.cos(jnp.pi * pct) + 1) + + +def piecewise_interpolate_schedule( + interpolate_type: str, + init_value: float, + boundaries_and_scales: Optional[Dict[int, float]] = None +) -> base.Schedule: + """Returns a function which implements a piecewise interpolated schedule. + + Args: + interpolate_type: 'linear' or 'cosine', specifying the interpolation + strategy. + init_value: An initial value `init_v`. + boundaries_and_scales: A map from boundaries `b_i` to non-negative scaling + factors `f_i`. At boundary step `b_i`, the schedule returns `init_v` + scaled by the product of all factors `f_j` such that `b_j` <= `b_i`. The + values in between each boundary will be interpolated as per `type`. + + Returns: + schedule: A function that maps step counts to values. 
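+
+  For instance, with purely illustrative values, a linear schedule that ramps
+  from 200 to 300 over the first 5 steps and then decays to 75 by step 10
+  could be written as::
+
+    schedule_fn = piecewise_interpolate_schedule(
+        'linear', init_value=200., boundaries_and_scales={5: 1.5, 10: 0.25})
+    schedule_fn(0)   # 200.
+    schedule_fn(5)   # 300.  (200 * 1.5)
+    schedule_fn(10)  # 75.   (300 * 0.25)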
+ """ + if interpolate_type == 'linear': + interpolate_fn = _linear_interpolate + elif interpolate_type == 'cosine': + interpolate_fn = _cosine_interpolate + else: + raise ValueError('`interpolate_type` must be either \'cos\' or \'linear\'') + + if boundaries_and_scales: + boundaries, scales = zip(*sorted(boundaries_and_scales.items())) + if not all(scale >= 0. for scale in scales): + raise ValueError( + '`piecewise_interpolate_schedule` expects non-negative scale factors') + else: + boundaries, scales = (), () + + bounds = jnp.stack((0,) + boundaries) + values = jnp.cumprod(jnp.stack((init_value,) + scales)) + interval_sizes = (bounds[1:] - bounds[:-1]) + + def schedule(count): + indicator = (bounds[:-1] <= count) & (count < bounds[1:]) + pct = (count - bounds[:-1]) / interval_sizes + interp_vals = interpolate_fn(values[:-1], values[1:], pct) + return indicator.dot(interp_vals) + (bounds[-1] <= count) * values[-1] + + return schedule + + +def linear_onecycle_schedule( + transition_steps: int, + peak_value: float, + pct_start: float = 0.3, + pct_final: float = 0.85, + div_factor: float = 25.0, + final_div_factor: float = 1e4 +) -> base.Schedule: + """Returns a function which implements the onecycle learning rate schedule. + + This function uses a linear annealing strategy. + For more details see: https://arxiv.org/abs/1708.07120 + + Args: + transition_steps: Number of steps over which annealing takes place. + peak_value: Maximum value attained by schedule at pct_start percent + of the cycle (in number of steps). + pct_start: The percentage of the cycle (in number of steps) spent + increasing the learning rate. + pct_final: The percentage of the cycle (in number of steps) spent + increasing to peak_value then decreasing back to init_value. + div_factor: Determines the initial value via init_value = + peak_value / div_factor + final_div_factor: Determines the final value via final_value = + init_value / final_div_factor + + Returns: + schedule: A function that maps step counts to values. + """ + if transition_steps <= 0: + raise ValueError( + 'A linear onecycle schedule was set with a non-positive ' + '`transition_steps`') + + return piecewise_interpolate_schedule( + 'linear', + peak_value / div_factor, + {int(pct_start * transition_steps): div_factor, + int(pct_final * transition_steps): 1. / div_factor, + transition_steps: 1. / final_div_factor}) + + +def cosine_onecycle_schedule( + transition_steps: int, + peak_value: float, + pct_start: float = 0.3, + div_factor: float = 25.0, + final_div_factor: float = 1e4 +) -> base.Schedule: + """Returns a function which implements the onecycle learning rate schedule. + + This function uses a cosine annealing strategy. + For more details see: https://arxiv.org/abs/1708.07120 + + Args: + transition_steps: Number of steps over which annealing takes place. + peak_value: Maximum value attained by schedule at pct_start percent + of the cycle (in number of steps). + pct_start: The percentage of the cycle (in number of steps) spent + increasing the learning rate. + div_factor: Determines the initial value via init_value = + peak_value / div_factor + final_div_factor: Determines the final value via final_value = + init_value / final_div_factor + + Returns: + schedule: A function that maps step counts to values. 
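+
+  For instance, with purely illustrative values, the schedule below starts at
+  peak_value / div_factor = 100, reaches the peak of 1000 after
+  int(0.4 * 5) = 2 steps, and then decays towards
+  init_value / final_div_factor = 1::
+
+    schedule_fn = cosine_onecycle_schedule(
+        transition_steps=5, peak_value=1000., pct_start=0.4,
+        div_factor=10., final_div_factor=100.)
+    schedule_fn(0)  # 100.
+    schedule_fn(2)  # 1000.
+    schedule_fn(5)  # 1.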
+ """ + if transition_steps <= 0: + raise ValueError( + 'A linear onecycle schedule was set with a non-positive ' + '`transition_steps`') + + return piecewise_interpolate_schedule( + 'cosine', + peak_value / div_factor, + {int(pct_start * transition_steps): div_factor, + int(transition_steps): 1. / (div_factor * final_div_factor)}) + + +def join_schedules(schedules: Sequence[base.Schedule], + boundaries: Sequence[int]) -> base.Schedule: + """Sequentially apply multiple schedules. + + Args: + schedules: A list of callables (expected to be optax schedules). Each + schedule will receive a step count indicating the number of steps since + the previous boundary transition. + boundaries: A list of integers (of length one less than schedules) that + indicate when to transition between schedules. + Returns: + schedule: A function that maps step counts to values. + """ + def schedule(step: jnp.DeviceArray) -> jnp.DeviceArray: + output = schedules[0](step) + for boundary, schedule in zip(boundaries, schedules[1:]): + output = jnp.where(step < boundary, output, schedule(step - boundary)) + return output + return schedule + + +def warmup_cosine_decay_schedule( + init_value: float, + peak_value: float, + warmup_steps: int, + decay_steps: int, + end_value: float = 0.0 +) -> base.Schedule: + """Linear warmup followed by cosine decay. + + Args: + init_value: Initial value for the scalar to be annealed. + peak_value: Peak value for scalar to be annealed at end of warmup. + warmup_steps: Positive integer, the length of the linear warmup. + decay_steps: Positive integer, the total length of the schedule. Note that + this includes the warmup time, so the number of steps during which cosine + annealing is applied is `decay_steps - warmup_steps`. + end_value: End value of the scalar to be annealed. + Returns: + schedule: A function that maps step counts to values. + """ + schedules = [ + linear_schedule( + init_value=init_value, + end_value=peak_value, + transition_steps=warmup_steps), + cosine_decay_schedule( + init_value=peak_value, + decay_steps=decay_steps - warmup_steps, + alpha=end_value/peak_value)] + return join_schedules(schedules, [warmup_steps]) + + +def warmup_exponential_decay_schedule( + init_value: float, + peak_value: float, + warmup_steps: int, + transition_steps: int, + decay_rate: float, + transition_begin: int = 0, + staircase: bool = False, + end_value: Optional[float] = None +) -> base.Schedule: + """Linear warmup followed by exponential decay. + + Args: + init_value: Initial value for the scalar to be annealed. + peak_value: Peak value for scalar to be annealed at end of warmup. + warmup_steps: Positive integer, the length of the linear warmup. + transition_steps: must be positive. See `exponential_decay` for more + details. + decay_rate: must not be zero. The decay rate. + transition_begin: must be positive. After how many steps to start annealing + (before this many steps the scalar value is held fixed at `peak_value`). + staircase: if `True`, decay the values at discrete intervals. + end_value: the value at which the exponential decay stops. When + `decay_rate` < 1, `end_value` is treated as a lower bound, otherwise as + an upper bound. Has no effect when `decay_rate` = 0. + Returns: + schedule: A function that maps step counts to values. 
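+
+  For instance, with purely illustrative values, the schedule below warms up
+  linearly from 0 to 1 over 50 steps and thereafter halves the value every
+  100 steps::
+
+    schedule_fn = warmup_exponential_decay_schedule(
+        init_value=0., peak_value=1., warmup_steps=50,
+        transition_steps=100, decay_rate=0.5)
+    schedule_fn(0)    # 0.
+    schedule_fn(50)   # 1.
+    schedule_fn(150)  # 0.5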
+ """ + schedules = [ + linear_schedule( + init_value=init_value, + end_value=peak_value, + transition_steps=warmup_steps), + exponential_decay( + init_value=peak_value, + transition_steps=transition_steps, + decay_rate=decay_rate, + transition_begin=transition_begin, + staircase=staircase, + end_value=end_value)] + return join_schedules(schedules, [warmup_steps]) + + +def sgdr_schedule(cosine_kwargs: Iterable[Dict[str, chex.Numeric]] + ) -> base.Schedule: + """SGD with warm restarts, from Loschilov & Hutter (arXiv:1608.03983). + + This learning rate schedule applies multiple joined cosine decay cycles. + For more details see: https://arxiv.org/abs/1608.03983 + + Args: + cosine_kwargs: An Iterable of dicts, where each element specifies the + arguments to pass to each cosine decay cycle. The `decay_steps` kwarg + will specify how long each cycle lasts for, and therefore when to + transition to the next cycle. + Returns: + schedule: A function that maps step counts to values. + """ + boundaries = [] + schedules = [] + step = 0 + for kwargs in cosine_kwargs: + schedules += [warmup_cosine_decay_schedule(**kwargs)] + boundaries += [step + kwargs['decay_steps']] + step += kwargs['decay_steps'] + return join_schedules(schedules, boundaries[:-1]) + + +def _convert_floats(x, dtype): + """Convert float-like inputs to dtype, rest pass through.""" + if jax.dtypes.scalar_type_of(x) == float: + return jnp.asarray(x, dtype=dtype) + return x + + +class InjectHyperparamsState(NamedTuple): + """Maintains inner transform state, hyperparameters, and step count.""" + count: jnp.ndarray # shape=(), dtype=jnp.int32 + hyperparams: Dict[str, chex.Numeric] + inner_state: base.OptState + + +def inject_hyperparams( + inner_factory: Callable[..., base.GradientTransformation], + static_args: Union[str, Iterable[str]] = (), + hyperparam_dtype: Optional[jnp.dtype] = None, +) -> Callable[..., base.GradientTransformation]: + """Wrapper that injects hyperparameters into the inner GradientTransformation. + + This wrapper allows you to pass schedules (i.e. a function that returns a + numeric value given a step count) instead of constants for + hyperparameters. You may only schedule numeric hyperparameters (i.e. boolean + flags cannot be scheduled). + + For example, to use ``scale_by_adam`` with a piecewise linear + schedule for beta_1 and constant for beta_2:: + + scheduled_adam = optax.inject_hyperparams(optax.scale_by_adam)( + b1=optax.piecewise_linear_schedule(...), + b2=0.99) + + You may manually change numeric hyperparameters that were not scheduled + through the ``hyperparams`` dict in the ``InjectHyperparamState``:: + + state = scheduled_adam.init(params) + updates, state = scheduled_adam.update(grads, state) + state.hyperparams['b2'] = 0.95 + updates, state = scheduled_adam.update(updates, state) # uses b2 = 0.95 + + Manually overriding scheduled hyperparameters will have no effect (e.g. + in the code sample above, you cannot manually adjust ``b1``). + + Args: + inner_factory: a function that returns the inner + ``optax.GradientTransformation`` given the hyperparameters. + static_args: a string or iterable of strings specifying which + callable parameters are not schedules. inject_hyperparams treats all + callables as schedules by default, so if a hyperparameter is a + non-schedule callable, you must specify that using this argument. + hyperparam_dtype: Optional datatype override. If specified, all float + hyperparameters will be cast to this type. 
+ + Returns: + A callable that returns a ``optax.GradientTransformation``. This callable + accepts the same arguments as ``inner_factory``, except you may provide + schedules in place of the constant arguments. + """ + static_args = ({static_args} if isinstance(static_args, str) else + set(static_args)) + inner_signature = inspect.signature(inner_factory) + + if not static_args.issubset(inner_signature.parameters): + raise ValueError( + '`static_args` must specify a subset of `inner_factory`\'s parameters. ' + f'Given `static_args`: {static_args}. `inner_factory` parameters: ' + f'{set(inner_signature.parameters.keys())}') + + @functools.wraps(inner_factory) + def wrapped_transform(*args, **kwargs) -> base.GradientTransformation: + bound_arguments = inner_signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + + sched_hps, numeric_hps, other_hps = {}, {}, {} + for name, value in bound_arguments.arguments.items(): + if name in static_args or isinstance(value, bool): + other_hps[name] = value + elif callable(value): + sched_hps[name] = value + elif isinstance(value, (int, float, chex.Array)): + numeric_hps[name] = value + else: + other_hps[name] = value + + def schedule_fn(count, dtype): + return {k: _convert_floats(f(count), dtype) for k, f in sched_hps.items()} + + def init_fn(params): + count = jnp.zeros([], jnp.int32) + if hyperparam_dtype is None: + dtype = getattr(next(iter( + jax.tree_util.tree_leaves(params)), None), 'dtype', None) + else: + dtype = hyperparam_dtype + hparams = { + k: jnp.asarray(_convert_floats(v, dtype)) + for k, v in numeric_hps.items()} + hparams.update(schedule_fn(count, dtype)) + return InjectHyperparamsState( # pylint:disable=too-many-function-args + count, hparams, inner_factory(**other_hps, **hparams).init(params)) + + def update_fn(updates, state, params=None): + if hyperparam_dtype is None: + dtype = getattr(next(iter( + jax.tree_util.tree_leaves(updates)), None), 'dtype', None) + else: + dtype = hyperparam_dtype + hparams = {k: _convert_floats(v, dtype) + for k, v in state.hyperparams.items()} + hparams.update(schedule_fn(state.count, dtype)) + updates, inner_state = inner_factory(**other_hps, **hparams).update( + updates, state.inner_state, params) + count_inc = numerics.safe_int32_increment(state.count) + + # pylint:disable=too-many-function-args + return updates, InjectHyperparamsState(count_inc, hparams, inner_state) + # pylint:enable=too-many-function-args + + return base.GradientTransformation(init_fn, update_fn) + + return wrapped_transform diff --git a/optax_add_eve/_src/schedule_test.py b/optax_add_eve/_src/schedule_test.py new file mode 100644 index 00000000..a862c442 --- /dev/null +++ b/optax_add_eve/_src/schedule_test.py @@ -0,0 +1,649 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for `schedule.py`.""" + +import functools + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +import jax +import jax.numpy as jnp +import numpy as np + +from optax_add_eve._src import clipping +from optax_add_eve._src import schedule +from optax_add_eve._src import transform +from optax_add_eve._src import wrappers + + +class ConstantTest(chex.TestCase): + + @chex.all_variants + def test_constant(self): + """Check constant schedule.""" + # Get schedule function. + const_value = 10 + num_steps = 15 + schedule_fn = self.variant(schedule.constant_schedule(const_value)) + # Test that generated values equal the expected schedule values. + generated_vals = [] + for count in range(num_steps): + # Compute next value. + generated_vals.append(schedule_fn(count)) + # Test output. + expected_vals = np.array([const_value] * num_steps, dtype=np.float32) + np.testing.assert_allclose( + expected_vals, np.array(generated_vals), atol=1e-3) + + +class PolynomialTest(chex.TestCase): + + @chex.all_variants + def test_linear(self): + """Check linear schedule.""" + # Get schedule function. + schedule_fn = self.variant( + schedule.polynomial_schedule( + init_value=10., end_value=20., power=1, transition_steps=10)) + # Test that generated values equal the expected schedule values. + generated_vals = [] + for count in range(15): + # Compute next value. + generated_vals.append(schedule_fn(count)) + # Test output. + expected_vals = np.array(list(range(10, 20)) + [20] * 5, dtype=np.float32) + np.testing.assert_allclose( + expected_vals, np.array(generated_vals), atol=1e-3) + + @chex.all_variants + def test_zero_steps_schedule(self): + # Get schedule function. + initial_value = 10. + end_value = 20. + + for num_steps in [-1, 0]: + schedule_fn = self.variant( + schedule.polynomial_schedule( + init_value=initial_value, end_value=end_value, + power=1, transition_steps=num_steps)) + for count in range(15): + np.testing.assert_allclose(schedule_fn(count), initial_value) + + @chex.all_variants + def test_nonlinear(self): + """Check non-linear (quadratic) schedule.""" + # Get schedule function. + schedule_fn = self.variant( + schedule.polynomial_schedule( + init_value=25., end_value=10., power=2, transition_steps=10)) + # Test that generated values equal the expected schedule values. + generated_vals = [] + for count in range(15): + # Compute next value. + generated_vals.append(schedule_fn(count)) + # Test output. + expected_vals = np.array( + [10. + 15. * (1. - n / 10)**2 for n in range(10)] + [10] * 5, + dtype=np.float32) + np.testing.assert_allclose( + expected_vals, np.array(generated_vals), atol=1e-3) + + @chex.all_variants + def test_with_decay_begin(self): + """Check quadratic schedule with non-zero schedule begin.""" + # Get schedule function. + schedule_fn = self.variant( + schedule.polynomial_schedule( + init_value=30., end_value=10., power=2, + transition_steps=10, transition_begin=4)) + # Test that generated values equal the expected schedule values. + generated_vals = [] + for count in range(20): + # Compute next value. + generated_vals.append(schedule_fn(count)) + # Test output. + expected_vals = np.array( + [30.] * 4 + [10. + 20. * (1. 
- n / 10)**2 for n in range(10)] + + [10] * 6, + dtype=np.float32) + np.testing.assert_allclose( + expected_vals, np.array(generated_vals), atol=1e-3) + + +class PiecewiseConstantTest(chex.TestCase): + + @chex.all_variants + def test_positive(self): + """Check piecewise constant schedule of positive values.""" + # Get schedule function. + schedule_fn = self.variant( + schedule.piecewise_constant_schedule(0.1, {3: 2., 6: 0.5})) + # Test that generated values equal the expected schedule values. + generated_vals = [] + for count in range(10): + # Compute next value. + generated_vals.append(schedule_fn(count)) + # Test output. + expected_vals = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1]) + np.testing.assert_allclose( + expected_vals, np.array(generated_vals), atol=1e-3) + + @chex.all_variants + def test_negative(self): + """Check piecewise constant schedule of negative values.""" + # Get schedule function. + schedule_fn = self.variant( + schedule.piecewise_constant_schedule(-0.1, {3: 2., 6: 0.5})) + # Test that generated values equal the expected schedule values. + generated_vals = [] + for count in range(10): + # Compute next value. + generated_vals.append(schedule_fn(count)) + # Test output. + expected_vals = -1 * np.array( + [0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1]) + np.testing.assert_allclose( + expected_vals, np.array(generated_vals), atol=1e-3) + + +class ExponentialTest(chex.TestCase): + + @chex.all_variants + @parameterized.parameters(False, True) + def test_constant_schedule(self, staircase): + """Checks constant schedule for exponential decay schedule.""" + num_steps = 15 + # Get schedule function. + init_value = 1. + schedule_fn = self.variant( + schedule.exponential_decay( + init_value=init_value, transition_steps=num_steps, + decay_rate=1., staircase=staircase)) + # Test that generated values equal the expected schedule values. + generated_vals = [] + for count in range(num_steps): + generated_vals.append(schedule_fn(count)) + expected_vals = np.array([init_value] * num_steps, dtype=np.float32) + np.testing.assert_allclose( + expected_vals, np.array(generated_vals), atol=1e-3) + + @chex.all_variants + @parameterized.parameters(False, True) + def test_nonvalid_transition_steps(self, staircase): + """Checks nonvalid decay steps results in a constant schedule.""" + init_value = 1. + for transition_steps in [-1, 0]: + schedule_fn = self.variant( + schedule.exponential_decay( + init_value=init_value, transition_steps=transition_steps, + decay_rate=1., staircase=staircase)) + for count in range(15): + np.testing.assert_allclose(schedule_fn(count), init_value) + + @chex.all_variants + @parameterized.parameters(False, True) + def test_nonvalid_decay_rate(self, staircase): + """Checks nonvalid decay steps results in a constant schedule.""" + init_value = 1. + schedule_fn = self.variant( + schedule.exponential_decay( + init_value=init_value, transition_steps=2, + decay_rate=0., staircase=staircase)) + for count in range(15): + np.testing.assert_allclose(schedule_fn(count), init_value) + + @chex.all_variants + @parameterized.parameters((False, 0), (True, 0), (False, 5), (True, 5)) + def test_exponential(self, staircase, transition_begin): + """Checks non-linear (quadratic) schedule.""" + # Get schedule function. + init_value = 1. + num_steps = 15 + transition_steps = 2 + decay_rate = 2. 
+ schedule_fn = self.variant( + schedule.exponential_decay( + init_value=init_value, transition_steps=transition_steps, + decay_rate=decay_rate, transition_begin=transition_begin, + staircase=staircase)) + + # Test that generated values equal the expected schedule values. + def _staircased(count): + p = count / transition_steps + if staircase: + p = np.floor(p) + return p + + generated_vals = [] + for count in range(num_steps + transition_begin): + generated_vals.append(schedule_fn(count)) + expected_vals = np.array( + [init_value] * transition_begin + [ + init_value * np.power(decay_rate, _staircased(count)) + for count in range(num_steps) + ], + dtype=np.float32) + np.testing.assert_allclose( + expected_vals, np.array(generated_vals), atol=1e-3) + + @chex.all_variants + @parameterized.parameters( + (0.2, 0.1, False), (1.0, 0.1, False), (2.0, 3.0, False), + (0.2, 0.1, True), (1.0, 0.1, True), (2.0, 3.0, True)) + def test_end_value_with_staircase(self, decay_rate, end_value, staircase): + # Get schedule function. + init_value = 1. + num_steps = 11 + transition_steps = 2 + transition_begin = 3 + schedule_fn = self.variant( + schedule.exponential_decay( + init_value=init_value, transition_steps=transition_steps, + decay_rate=decay_rate, transition_begin=transition_begin, + staircase=staircase, end_value=end_value)) + + # Test that generated values equal the expected schedule values. + def _staircased(count): + p = count / transition_steps + if staircase: + p = np.floor(p) + return p + + generated_vals = [] + for count in range(num_steps + transition_begin): + generated_vals.append(schedule_fn(count)) + expected_vals = np.array( + [init_value] * transition_begin + [ + init_value * np.power(decay_rate, _staircased(count)) + for count in range(num_steps) + ], + dtype=np.float32) + + if decay_rate < 1.0: + expected_vals = np.maximum(expected_vals, end_value) + else: + expected_vals = np.minimum(expected_vals, end_value) + + np.testing.assert_allclose( + expected_vals, np.array(generated_vals), atol=1e-3) + + @chex.all_variants + def test_immutable_count(self): + """Checks constant schedule for exponential decay schedule.""" + num_steps = 5 + # Get schedule function. + init_value = 32. + schedule_fn = self.variant( + schedule.exponential_decay( + init_value=init_value, transition_steps=1, + decay_rate=0.5)) + # Test that generated values equal the expected schedule values. + generated_vals = [] + for count in range(num_steps): + # Jax arrays are read-only in ChexVariantType.WITHOUT_DEVICE. + immutable_count = jnp.array(count, dtype=jnp.float32) + generated_vals.append(schedule_fn(immutable_count)) + expected_vals = np.array([32, 16, 8, 4, 2], dtype=np.float32) + np.testing.assert_allclose( + expected_vals, np.array(generated_vals), atol=1e-3) + + +class CosineDecayTest(chex.TestCase): + + @chex.all_variants + def test_decay_count_smaller_count(self): + """Check cosine schedule decay for the entire training schedule.""" + initial_value = 0.1 + schedule_fn = self.variant( + schedule.cosine_decay_schedule(initial_value, 10, 0.0)) + # Test that generated values equal the expected schedule values. + generated_vals = [] + for count in range(10): + # Compute next value. + generated_vals.append(schedule_fn(count)) + # Test output. 
+ expected_multipliers = np.array( + 0.5 + 0.5 * np.cos( + np.pi * np.array( + [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]))) + np.testing.assert_allclose( + initial_value * expected_multipliers, + np.array(generated_vals), atol=1e-3) + + @chex.all_variants + def test_decay_count_greater_count(self): + """Check cosine schedule decay for a part of the training schedule.""" + initial_value = 0.1 + schedule_fn = self.variant( + schedule.cosine_decay_schedule(initial_value, 5, 0.0)) + # Test that generated values equal the expected schedule values. + generated_vals = [] + for count in range(12): + # Compute next value. + generated_vals.append(schedule_fn(count)) + + # Test output. + expected_multipliers = np.array( + 0.5 + 0.5 * np.cos( + np.pi * np.array( + [0.0, 0.2, 0.4, 0.6, 0.8, 1., 1., 1., 1., 1., 1., 1.]))) + np.testing.assert_allclose( + initial_value * expected_multipliers, + np.array(generated_vals), atol=1e-3) + + @chex.all_variants + def test_decay_count_greater_count_with_alpha(self): + """Check cosine schedule decay for a part of the training schedule.""" + # Get schedule function. + initial_value = 0.1 + schedule_fn = self.variant( + schedule.cosine_decay_schedule(initial_value, 5, 0.1)) + # Test that generated values equal the expected schedule values. + generated_vals = [] + for count in range(12): + # Compute next value. + generated_vals.append(schedule_fn(count)) + + # Test output. + expected_multipliers = np.array( + 0.5 + 0.5 * np.cos( + np.pi * np.array( + [0.0, 0.2, 0.4, 0.6, 0.8, 1., 1., 1., 1., 1., 1., 1.]))) + expected_multipliers = 0.9 * expected_multipliers + 0.1 + np.testing.assert_allclose( + initial_value * expected_multipliers, + np.array(generated_vals), atol=1e-3) + + +class WarmupCosineDecayTest(chex.TestCase): + + @chex.all_variants + @parameterized.named_parameters( + ('with end value', 10, 0.5, 1e-4), + ('without end value', 5, 3, 0.),) + def test_limits(self, init_value, peak_value, end_value): + """Check cosine schedule decay for the entire training schedule.""" + schedule_fn = self.variant(schedule.warmup_cosine_decay_schedule( + init_value=init_value, + peak_value=peak_value, + warmup_steps=100, + decay_steps=1000, + end_value=end_value, + )) + + np.testing.assert_allclose(init_value, schedule_fn(0)) + np.testing.assert_allclose(peak_value, schedule_fn(100)) + np.testing.assert_allclose(end_value, schedule_fn(1000), rtol=1e-3) + + +class SGDRTest(chex.TestCase): + + @chex.all_variants + @parameterized.named_parameters( + ('with step decay', 1.6, 0.8, 0.4), + ('without step_decay', 1.6, 1.6, 1.6),) + def test_limits(self, lr0, lr1, lr2): + """Check cosine schedule decay for the entire training schedule.""" + lr_kwargs = [] + for step, lr in zip([2e3, 3e3, 5e3], [lr0, lr1, lr2]): + lr_kwargs += [dict(decay_steps=int(step), peak_value=lr, + init_value=0, end_value=0.0, warmup_steps=500)] + schedule_fn = self.variant(schedule.sgdr_schedule(lr_kwargs)) + np.testing.assert_allclose(lr0, schedule_fn(500)) + np.testing.assert_allclose(lr1, schedule_fn(2500)) + np.testing.assert_allclose(lr2, schedule_fn(5500)) + + +class PiecewiseInterpolateTest(chex.TestCase): + + @chex.all_variants + def test_linear_piecewise(self): + schedule_fn = self.variant(schedule.piecewise_interpolate_schedule( + 'linear', 200., {5: 1.5, 10: 0.25})) + generated_vals = [schedule_fn(step) for step in range(13)] + expected_vals = [200., 220., 240., 260., 280., 300., 255., 210., 165., + 120., 75., 75., 75.] 
+ np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3) + + @chex.all_variants + def test_cos_piecewise(self): + schedule_fn = self.variant(schedule.piecewise_interpolate_schedule( + 'cosine', 400., {5: 1.2, 3: 0.6, 7: 1.})) + generated_vals = [schedule_fn(step) for step in range(9)] + expected_vals = [400., 360., 280., 240., 264., 288., 288., 288., 288.] + np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3) + + @chex.all_variants + def test_empty_dict(self): + schedule_fn = self.variant(schedule.piecewise_interpolate_schedule( + 'linear', 13., {})) + generated_vals = [schedule_fn(step) for step in range(5)] + expected_vals = [13., 13., 13., 13., 13.] + np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3) + + @chex.all_variants + def test_no_dict(self): + schedule_fn = self.variant(schedule.piecewise_interpolate_schedule( + 'cosine', 17.)) + generated_vals = [schedule_fn(step) for step in range(3)] + expected_vals = [17., 17., 17.] + np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3) + + def test_invalid_type(self): + # pytype: disable=wrong-arg-types + with self.assertRaises(ValueError): + schedule.piecewise_interpolate_schedule('linar', 13.) + with self.assertRaises(ValueError): + schedule.piecewise_interpolate_schedule('', 13., {5: 3.}) + with self.assertRaises(ValueError): + schedule.piecewise_interpolate_schedule(None, 13., {}) + # pytype: enable=wrong-arg-types + + def test_invalid_scale(self): + with self.assertRaises(ValueError): + schedule.piecewise_interpolate_schedule('linear', 13., {5: -3}) + + +class OneCycleTest(chex.TestCase): + + @chex.all_variants + def test_linear(self): + schedule_fn = self.variant(schedule.linear_onecycle_schedule( + transition_steps=10, + peak_value=1000, + pct_start=0.3, + pct_final=0.7, + div_factor=10., + final_div_factor=100.)) + + generated_vals = [schedule_fn(step) for step in range(12)] + expected_vals = [100., 400., 700., 1000., 775., 550., 325., 100., 67., + 34., 1., 1.] + np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3) + + @chex.all_variants + def test_cosine(self): + schedule_fn = self.variant(schedule.cosine_onecycle_schedule( + transition_steps=5, + peak_value=1000., + pct_start=0.4, + div_factor=10., + final_div_factor=100.)) + + generated_vals = [schedule_fn(step) for step in range(7)] + expected_vals = [100., 550., 1000., 750.25, 250.75, 1., 1.] + np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3) + + def test_nonpositive_transition_steps(self): + with self.assertRaises(ValueError): + schedule.cosine_onecycle_schedule(transition_steps=0, peak_value=5.) + with self.assertRaises(ValueError): + schedule.linear_onecycle_schedule(transition_steps=0, peak_value=5.) 
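+
+
+# The tests below exercise `inject_hyperparams`, which wraps a transform so
+# that a schedule can drive one of its numeric hyperparameters. A minimal
+# usage sketch (illustrative values only):
+#
+#   opt = schedule.inject_hyperparams(transform.scale)(
+#       step_size=schedule.cosine_decay_schedule(0.1, decay_steps=100))
+#   state = opt.init(params)
+#   updates, state = opt.update(grads, state)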
+ + +class InjectHyperparamsTest(chex.TestCase): + """Tests for the inject_hyperparams wrapper.""" + + @chex.all_variants + def test_updates(self): + optim = schedule.inject_hyperparams(transform.scale)( # stateless + step_size=schedule.piecewise_constant_schedule( + 3.0, {1: 5, 7: 2, 12: 1.5})) + + params = [jnp.zeros([], dtype=jnp.float32)] + state = self.variant(optim.init)(params) + update_fn = self.variant(optim.update) + expected_step_size = [3.0]*2 + [15.0]*6 + [30.0]*5 + [45.0]*3 + + grads = [jnp.ones([], dtype=jnp.float32)] + for i in range(15): + updates, state = update_fn(grads, state, params=params) + np.testing.assert_almost_equal(updates[0], expected_step_size[i+1]) + + @chex.all_variants + def test_hyperparams_state(self): + optim = schedule.inject_hyperparams(transform.trace)( # stateful + decay=schedule.piecewise_constant_schedule( + 0.8, {3: 0.5, 9: 1.25}), + nesterov=True) + + params = [jnp.zeros([2, 3]) for _ in range(3)] + state = self.variant(optim.init)(params) + update_fn = self.variant(optim.update) + + expected_mom = [0.8]*4 + [0.4]*6 + [0.5]*2 + grads = jax.tree_util.tree_map(jnp.ones_like, params) + for i in range(12): + np.testing.assert_almost_equal(state.hyperparams['decay'], + expected_mom[i]) + _, state = update_fn(grads, state) + + np.testing.assert_almost_equal(state.hyperparams['decay'], + expected_mom[-1]) + + @chex.all_variants + def test_constant_hyperparams(self): + optim = schedule.inject_hyperparams(transform.scale_by_adam)(b1=0., b2=0.) + + params = [jnp.zeros([2, 3]) for _ in range(3)] + state = self.variant(optim.init)(params) + update_fn = self.variant(optim.update) + + grads = jax.tree_util.tree_map(jnp.ones_like, params) + for _ in range(5): + updates, state = update_fn(grads, state, params) + np.testing.assert_almost_equal(state.hyperparams['b1'], 0.0) + np.testing.assert_almost_equal(state.hyperparams['b2'], 0.0) + np.testing.assert_almost_equal(state.hyperparams['eps'], 1e-8) + np.testing.assert_almost_equal(state.hyperparams['eps_root'], 0.0) + assert 'eps' in state.hyperparams + chex.assert_trees_all_close(updates, grads) + + @chex.all_variants + def test_overriding_hyperparam(self): + optim = schedule.inject_hyperparams(clipping.clip_by_global_norm)(0.1) + params = jnp.zeros((3, 5, 7)) + state = self.variant(optim.init)(params) + update_fn = self.variant(optim.update) + + grads = jnp.ones_like(params) + for i in range(5): + state.hyperparams['max_norm'] = i + updates, state = update_fn(grads, state) + assert np.isclose(jnp.linalg.norm(updates.ravel()), i) + + @chex.all_variants + @parameterized.named_parameters(('string', 'mask'), ('list', ['mask'])) + def test_static_args(self, static_args): + @functools.partial(schedule.inject_hyperparams, static_args=static_args) + def custom_optim(learning_rate, mask): + return wrappers.masked(transform.scale(-learning_rate), mask) + + optim = custom_optim( + 0.1, functools.partial(jax.tree_util.tree_map, lambda x: x.ndim > 1)) + params = [jnp.ones((1, 2)), jnp.ones(2), jnp.ones((1, 1, 1))] + grads = params + state = self.variant(optim.init)(params) + updates, state = self.variant(optim.update)(grads, state) + expected_updates = jax.tree_util.tree_map( + lambda x: -0.1 * x if x.ndim > 1 else x, grads) + + assert set(state.hyperparams.keys()) == {'learning_rate'}, state.hyperparams + chex.assert_trees_all_close(updates, expected_updates) + + @chex.all_variants + @parameterized.named_parameters(('one_arg', 'b1'), ('two_arg', ['b1', 'b2'])) + def test_numeric_static_args(self, static_args): + optim = 
schedule.inject_hyperparams( + transform.scale_by_adam, static_args=static_args)(b1=0.9, b2=0.95) + + params = [jnp.ones((1, 2)), jnp.ones(2), jnp.ones((1, 1, 1))] + grads = params + state = self.variant(optim.init)(params) + _, state = self.variant(optim.update)(grads, state) + + assert not set(state.hyperparams.keys()).intersection(set(static_args)) + + @chex.all_variants + @parameterized.named_parameters( + ('bf16hyp f32param bf16grad', jnp.bfloat16, jnp.float32, jnp.bfloat16), + ('bf16hyp f32param f32_grads', jnp.bfloat16, jnp.float32, jnp.float32), + ('f32hyp bf16param bf16grad', jnp.float32, jnp.bfloat16, jnp.bfloat16), + ('f32hyp f32param bf16grad', jnp.float32, jnp.float32, jnp.bfloat16), + ('f32hyp bf16param f32grad', jnp.float32, jnp.bfloat16, jnp.float32), + ) + def test_hyperparam_dtypes(self, + hyperparam_dtype, + param_dtype, + grad_dtype): + """Tests that hyperparam dtype override works as desired.""" + optim = schedule.inject_hyperparams( + transform.scale_by_adam, + hyperparam_dtype=hyperparam_dtype)(b1=0.9, b2=0.95) + + params = [jnp.ones((1, 2), dtype=param_dtype), + jnp.ones(2, dtype=param_dtype), + jnp.ones((1, 1, 1), dtype=param_dtype)] + grads = jax.tree_map(lambda x: x.astype(grad_dtype), params) + state = self.variant(optim.init)(params) + # Check that the hyperparams are overriden + self.assertEqual(state.hyperparams['b1'].dtype, hyperparam_dtype) + self.assertEqual(state.hyperparams['b2'].dtype, hyperparam_dtype) + + _, state = self.variant(optim.update)(grads, state) + + self.assertEqual(state.hyperparams['b1'].dtype, hyperparam_dtype) + self.assertEqual(state.hyperparams['b2'].dtype, hyperparam_dtype) + + @parameterized.named_parameters(('string', 'lr'), ('list', ['lr'])) + def test_static_args_error(self, static_args): + with self.assertRaises(ValueError): + schedule.inject_hyperparams(transform.scale, static_args=static_args) + + @chex.all_variants + def test_inject_hyperparams_starts_with_step_count_zero(self): + """Checks that inject_hyperparams uses step count 0 in the first update.""" + # See also: https://github.com/deepmind/optax/issues/415. + opt = schedule.inject_hyperparams(transform.scale)(lambda count: count) + params = jnp.zeros(3) + grads = jnp.array([-1, 0, 1]) + updates, _ = self.variant(opt.update)(grads, opt.init(params)) + np.testing.assert_array_equal(updates, np.zeros(3)) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/second_order_test.py b/optax_add_eve/_src/second_order_test.py new file mode 100644 index 00000000..820f1ed8 --- /dev/null +++ b/optax_add_eve/_src/second_order_test.py @@ -0,0 +1,93 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for `second_order.py`.""" + +import collections +import functools +import itertools + +from absl.testing import absltest + +import chex +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + +from optax_add_eve._src import second_order + + +NUM_CLASSES = 2 +NUM_SAMPLES = 3 +NUM_FEATURES = 4 + + +class SecondOrderTest(chex.TestCase): + + def setUp(self): + super().setUp() + + self.data = np.random.rand(NUM_SAMPLES, NUM_FEATURES) + self.labels = np.random.randint(NUM_CLASSES, size=NUM_SAMPLES) + + def net_fn(z): + mlp = hk.Sequential( + [hk.Linear(10), jax.nn.relu, hk.Linear(NUM_CLASSES)], name='mlp') + return jax.nn.log_softmax(mlp(z)) + + net = hk.without_apply_rng(hk.transform(net_fn)) + self.parameters = net.init(jax.random.PRNGKey(0), self.data) + + def loss(params, inputs, targets): + log_probs = net.apply(params, inputs) + return -jnp.mean(hk.one_hot(targets, NUM_CLASSES) * log_probs) + + self.loss_fn = loss + + def jax_hessian_diag(loss_fun, params, inputs, targets): + """This is the 'ground-truth' obtained via the JAX library.""" + hess = jax.hessian(loss_fun)(params, inputs, targets) + + # Extracts the diagonal components. + hess_diag = collections.defaultdict(dict) + for k0, k1 in itertools.product(params.keys(), ['w', 'b']): + params_shape = params[k0][k1].shape + n_params = np.prod(params_shape) + hess_diag[k0][k1] = jnp.diag(hess[k0][k1][k0][k1].reshape( + n_params, n_params)).reshape(params_shape) + for k, v in hess_diag.items(): + hess_diag[k] = v + return second_order.ravel(hess_diag) + + self.hessian = jax_hessian_diag( + self.loss_fn, self.parameters, self.data, self.labels) + + @chex.all_variants + def test_hessian_diag(self): + hessian_diag_fn = self.variant( + functools.partial(second_order.hessian_diag, self.loss_fn)) + actual = hessian_diag_fn(self.parameters, self.data, self.labels) + np.testing.assert_array_almost_equal(self.hessian, actual, 5) + + @chex.all_variants + def test_fisher_diag_shape(self): + fisher_diag_fn = self.variant( + functools.partial(second_order.fisher_diag, self.loss_fn)) + fisher_diagonal = fisher_diag_fn(self.parameters, self.data, self.labels) + chex.assert_equal_shape([fisher_diagonal, self.hessian]) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/stochastic_gradient_estimators.py b/optax_add_eve/_src/stochastic_gradient_estimators.py new file mode 100644 index 00000000..82d0d0f5 --- /dev/null +++ b/optax_add_eve/_src/stochastic_gradient_estimators.py @@ -0,0 +1,317 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Stochastic Monte Carlo gradient estimators. 
+ +Utility functions to approximate gradients of the form using Monte Carlo +estimation: + \nabla_{\theta} E_{p(x; \theta)} f(x) + +Here f is assumed to have no dependence on the parameters theta - if f has +dependence on theta, the functions below need to be called with `stop_grad(f)` +and the chain rule needs to be applied outside these functions in order +to obtain unbiased gradient. + +For more details, see: +S. Mohamed, M. Rosca, M. Figurnov, A Mnih. + Monte Carlo Gradient Estimation in Machine Learning. JMLR, 2020. +""" + +import math +from typing import Any, Callable, Sequence + +import chex +import jax +import jax.numpy as jnp +import numpy as np +from optax_add_eve._src import base +from optax_add_eve._src import utils + + +def score_function_jacobians( + function: Callable[[chex.Array], float], + params: base.Params, + dist_builder: Callable[..., Any], + rng: chex.PRNGKey, + num_samples: int) -> Sequence[chex.Array]: + r"""Score function gradient estimation. + + Approximates: + \nabla_{\theta} E_{p(x; \theta)} f(x) + With: + E_{p(x; \theta)} f(x) \nabla_{\theta} \log p(x; \theta) + + Requires: p to be differentiable wrt to theta. Applicable to both continuous + and discrete random variables. No requirements on f. + + Args: + function: Function f(x) for which to estimate grads_{params} E_dist f(x). + The function takes in one argument (a sample from the distribution) and + returns a floating point value. + params: A tuple of jnp arrays. + The parameters for which to construct the distribution. + dist_builder: a constructor which builds a distribution given the input + parameters specified by params. `dist_builder(params)` should return a + valid distribution. + rng: a PRNGKey key. + num_samples: Int, the number of samples used to compute the grads. + + Returns: + A tuple of size `params`, each element is `num_samples x param.shape` + jacobian vector containing the estimates of the gradients obtained for + each sample. + The mean of this vector is the gradient wrt to parameters that can be used + for learning. The entire jacobian vector can be used to assess estimator + variance. + """ + def surrogate(params): + dist = dist_builder(*params) + one_sample_surrogate_fn = lambda x: function(x) * dist.log_prob(x) + samples = jax.lax.stop_gradient(dist.sample((num_samples,), seed=rng)) + # We vmap the function application over samples - this ensures that the + # function we use does not have to be vectorized itself. + return jax.vmap(one_sample_surrogate_fn)(samples) + + return jax.jacfwd(surrogate)(params) + + +def pathwise_jacobians( + function: Callable[[chex.Array], float], + params: base.Params, + dist_builder: Callable[..., Any], + rng: chex.PRNGKey, + num_samples: int) -> Sequence[chex.Array]: + r"""Pathwise gradient estimation. + + Approximates: + \nabla_{\theta} E_{p(x; \theta)} f(x) + With: + E_{p(\epsilon)} \nabla_{\theta} f(g(\epsilon, \theta)) + where x = g(\epsilon, \theta). g depends on the distribution p. + + Requires: p to be reparametrizable and the reparametrization to be implemented + in tensorflow_probability. Applicable to continuous random variables. + f needs to be differentiable. + + Args: + function: Function f(x) for which to estimate grads_{params} E_dist f(x). + The function takes in one argument (a sample from the distribution) and + returns a floating point value. + params: A tuple of jnp arrays. + The parameters for which to construct the distribution. 
+ dist_builder: a constructor which builds a distribution given the input + parameters specified by params. `dist_builder(params)` should return a + valid distribution. + rng: a PRNGKey key. + num_samples: Int, the number of samples used to compute the grads. + + Returns: + A tuple of size `params`, each element is `num_samples x param.shape` + jacobian vector containing the estimates of the gradients obtained for + each sample. + The mean of this vector is the gradient wrt to parameters that can be used + for learning. The entire jacobian vector can be used to assess estimator + variance. + """ + def surrogate(params): + # We vmap the function application over samples - this ensures that the + # function we use does not have to be vectorized itself. + dist = dist_builder(*params) + return jax.vmap(function)(dist.sample((num_samples,), seed=rng)) + + return jax.jacfwd(surrogate)(params) + + +def measure_valued_jacobians( + function: Callable[[chex.Array], float], + params: base.Params, + dist_builder: Callable[..., Any], + rng: chex.PRNGKey, + num_samples: int, + coupling: bool = True) -> Sequence[chex.Array]: + r"""Measure valued gradient estimation. + + Approximates: + \nabla_{\theta} E_{p(x; \theta)} f(x) + With: + 1./ c (E_{p1(x; \theta)} f(x) - E_{p2(x; \theta)} f(x)) where p1 and p2 are + measures which depend on p. + + Currently only supports computing gradients of expectations of Gaussian RVs. + + Args: + function: Function f(x) for which to estimate grads_{params} E_dist f(x). + The function takes in one argument (a sample from the distribution) and + returns a floating point value. + params: A tuple of jnp arrays. + The parameters for which to construct the distribution. + dist_builder: a constructor which builds a distribution given the input + parameters specified by params. `dist_builder(params)` should return a + valid distribution. + rng: a PRNGKey key. + num_samples: Int, the number of samples used to compute the grads. + coupling: A boolean. Whether or not to use coupling for the positive and + negative samples. Recommended: True, as this reduces variance. + + Returns: + A tuple of size `params`, each element is `num_samples x param.shape` + jacobian vector containing the estimates of the gradients obtained for + each sample. + The mean of this vector is the gradient wrt to parameters that can be used + for learning. The entire jacobian vector can be used to assess estimator + variance. + """ + if dist_builder is not utils.multi_normal: + raise ValueError( + 'Unsupported distribution builder for measure_valued_jacobians!') + dist = dist_builder(*params) + # Need to apply chain rule for log scale grad (instead of scale grad). + return [ + measure_valued_estimation_mean( + function, dist, rng, num_samples, coupling=coupling), + jnp.exp(dist.log_scale) * measure_valued_estimation_std( + function, dist, rng, num_samples, coupling=coupling)] + + +def measure_valued_estimation_mean( + function: Callable[[chex.Array], float], + dist: Any, + rng: chex.PRNGKey, + num_samples: int, + coupling: bool = True) -> chex.Array: + """Measure valued grads of a Gaussian expectation of `function` wrt the mean. + + Args: + function: Function f(x) for which to estimate grads_{mean} E_dist f(x). + The function takes in one argument (a sample from the distribution) and + returns a floating point value. + dist: a distribution on which we can call `sample`. + rng: a PRNGKey key. + num_samples: Int, the number of samples used to compute the grads. + coupling: A boolean. 
Whether or not to use coupling for the positive and + negative samples. Recommended: True, as this reduces variance. + + Returns: + A `num_samples x D` vector containing the estimates of the gradients + obtained for each sample. The mean of this vector can be used to update + the mean parameter. The entire vector can be used to assess estimator + variance. + """ + mean, log_std = dist.params + std = jnp.exp(log_std) + + dist_samples = dist.sample((num_samples,), seed=rng) + + pos_rng, neg_rng = jax.random.split(rng) + pos_sample = jax.random.weibull_min( + pos_rng, scale=math.sqrt(2.), concentration=2., shape=dist_samples.shape) + + if coupling: + neg_sample = pos_sample + else: + neg_sample = jax.random.weibull_min( + neg_rng, + scale=math.sqrt(2.), + concentration=2., + shape=dist_samples.shape) + + # N x D + positive_diag = mean + std * pos_sample + # N x D + negative_diag = mean - std * neg_sample + + # NOTE: you can sample base samples here if you use the same rng + # Duplicate the D dimension - N x D x D. + base_dist_samples = utils.tile_second_to_last_dim(dist_samples) + positive = utils.set_diags(base_dist_samples, positive_diag) + negative = utils.set_diags(base_dist_samples, negative_diag) + + c = np.sqrt(2 * np.pi) * std # D + # Apply function. We apply the function to each element of N x D x D. + # We apply a function that takes a sample and returns one number, so the + # output will be N x D (which is what we want, batch by dimension). + # We apply a function in parallel to the batch. + # Broadcast the division. + vmaped_function = jax.vmap(jax.vmap(function, 1, 0)) + grads = (vmaped_function(positive) - vmaped_function(negative)) / c + + chex.assert_shape(grads, (num_samples,) + std.shape) + return grads + + +def measure_valued_estimation_std( + function: Callable[[chex.Array], float], + dist: Any, + rng: chex.PRNGKey, + num_samples: int, + coupling: bool = True) -> chex.Array: + """Measure valued grads of a Gaussian expectation of `function` wrt the std. + + Args: + function: Function f(x) for which to estimate grads_{std} E_dist f(x). + The function takes in one argument (a sample from the distribution) and + returns a floating point value. + dist: a distribution on which we can call `sample`. + rng: a PRNGKey key. + num_samples: Int, the number of samples used to compute the grads. + coupling: A boolean. Whether or not to use coupling for the positive and + negative samples. Recommended: True, as this reduces variance. + + Returns: + A `num_samples x D` vector containing the estimates of the gradients + obtained for each sample. The mean of this vector can be used to update + the scale parameter. The entire vector can be used to assess estimator + variance. + """ + mean, log_std = dist.params + std = jnp.exp(log_std) + + dist_samples = dist.sample((num_samples,), seed=rng) + + pos_rng, neg_rng = jax.random.split(rng) + + # The only difference between mean and std gradients is what we sample. + pos_sample = jax.random.double_sided_maxwell( + pos_rng, loc=0.0, scale=1.0, shape=dist_samples.shape) + if coupling: + unif_rvs = jax.random.uniform(neg_rng, dist_samples.shape) + neg_sample = unif_rvs * pos_sample + else: + neg_sample = jax.random.normal(neg_rng, dist_samples.shape) + + # Both need to be positive in the case of the scale. + # N x D + positive_diag = mean + std * pos_sample + # N x D + negative_diag = mean + std * neg_sample + + # NOTE: you can sample base samples here if you use the same rng + # Duplicate the D dimension - N x D x D. 
+ base_dist_samples = utils.tile_second_to_last_dim(dist_samples) + positive = utils.set_diags(base_dist_samples, positive_diag) + negative = utils.set_diags(base_dist_samples, negative_diag) + + # Different C for the scale + c = std # D + # Apply function. We apply the function to each element of N x D x D. + # We apply a function that takes a sample and returns one number, so the + # output will be N x D (which is what we want, batch by dimension). + # We apply a function in parallel to the batch. + # Broadcast the division. + vmaped_function = jax.vmap(jax.vmap(function, 1, 0)) + grads = (vmaped_function(positive) - vmaped_function(negative)) / c + + chex.assert_shape(grads, (num_samples,) + std.shape) + return grads + diff --git a/optax_add_eve/_src/stochastic_gradient_estimators_test.py b/optax_add_eve/_src/stochastic_gradient_estimators_test.py new file mode 100644 index 00000000..e89532d4 --- /dev/null +++ b/optax_add_eve/_src/stochastic_gradient_estimators_test.py @@ -0,0 +1,371 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `stochastic_gradient_estimators.py`.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +import jax +import jax.numpy as jnp +import numpy as np + +from optax_add_eve._src import stochastic_gradient_estimators as sge +from optax_add_eve._src import utils + + +# Set seed for deterministic sampling. 
+np.random.seed(42) + + +_estimator_to_num_samples = { + sge.score_function_jacobians: 5 * 10**5, + sge.measure_valued_jacobians: 10**5, + sge.pathwise_jacobians: 5 * 10**4, +} + +_weighted_estimator_to_num_samples = { + sge.score_function_jacobians: 5 * 10**6, + sge.measure_valued_jacobians: 5 * 10**5, + sge.pathwise_jacobians: 5 * 10**4, +} + + +def _ones(dims): + return jnp.ones(shape=(dims), dtype=jnp.float32) + + +def _assert_equal(actual, expected, rtol=1e-2, atol=1e-2): + """Asserts that arrays are equal.""" + # Note: assert_allclose does not check shapes + chex.assert_equal_shape((actual, expected)) + + # We get around the bug https://github.com/numpy/numpy/issues/13801 + zero_indices = np.argwhere(expected == 0) + if not np.all(np.abs(actual[zero_indices]) <= atol): + raise AssertionError(f'Larger than {atol} diff in {actual[zero_indices]}') + + non_zero_indices = np.argwhere(expected != 0) + np.testing.assert_allclose( + np.asarray(actual)[non_zero_indices], + expected[non_zero_indices], rtol, atol) + + +def _estimator_variant(variant, estimator): + return variant(estimator, static_argnums=(0, 2, 4)) + + +def _measure_valued_variant(variant): + return variant( + sge.measure_valued_jacobians, + static_argnums=(0, 2, 4, 5)) + + +class GradientEstimatorsTest(chex.TestCase): + + @chex.all_variants + @parameterized.named_parameters( + chex.params_product([ + ('_score_function_jacobians', sge.score_function_jacobians), + ('_pathwise_jacobians', sge.pathwise_jacobians), + ('_measure_valued_jacobians', sge.measure_valued_jacobians), + ], [ + ('0.1', 0.1), + ('0.5', 0.5), + ('0.9', 0.9), + ], + named=True)) + def testConstantFunction(self, estimator, constant): + data_dims = 3 + num_samples = _estimator_to_num_samples[estimator] + + effective_mean = 1.5 + mean = effective_mean * _ones(data_dims) + + effective_log_scale = 0.0 + log_scale = effective_log_scale * _ones(data_dims) + rng = jax.random.PRNGKey(1) + + jacobians = _estimator_variant(self.variant, estimator)( + lambda x: jnp.array(constant), [mean, log_scale], + utils.multi_normal, rng, num_samples) + + # Average over the number of samples. 
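+    # For a constant integrand f(x) = c, E_{p(x; theta)}[f(x)] = c does not
+    # depend on theta, so the averaged mean and log-scale gradient estimates
+    # below should both be approximately zero.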
+ mean_jacobians = jacobians[0] + chex.assert_shape(mean_jacobians, (num_samples, data_dims)) + mean_grads = np.mean(mean_jacobians, axis=0) + expected_mean_grads = np.zeros(data_dims, dtype=np.float32) + + log_scale_jacobians = jacobians[1] + chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) + log_scale_grads = np.mean(log_scale_jacobians, axis=0) + expected_log_scale_grads = np.zeros(data_dims, dtype=np.float32) + + _assert_equal(mean_grads, expected_mean_grads, atol=5e-3) + _assert_equal(log_scale_grads, expected_log_scale_grads, atol=5e-3) + + @chex.all_variants + @parameterized.named_parameters( + chex.params_product([ + ('_score_function_jacobians', sge.score_function_jacobians), + ('_pathwise_jacobians', sge.pathwise_jacobians), + ('_measure_valued_jacobians', sge.measure_valued_jacobians), + ], [ + ('0.5_-1.', 0.5, -1.), + ('0.7_0.0)', 0.7, 0.0), + ('0.8_0.1', 0.8, 0.1), + ], + named=True)) + def testLinearFunction(self, estimator, effective_mean, effective_log_scale): + data_dims = 3 + num_samples = _estimator_to_num_samples[estimator] + rng = jax.random.PRNGKey(1) + + mean = effective_mean * _ones(data_dims) + log_scale = effective_log_scale * _ones(data_dims) + + jacobians = _estimator_variant(self.variant, estimator)( + np.sum, [mean, log_scale], + utils.multi_normal, rng, num_samples) + + mean_jacobians = jacobians[0] + chex.assert_shape(mean_jacobians, (num_samples, data_dims)) + mean_grads = np.mean(mean_jacobians, axis=0) + expected_mean_grads = np.ones(data_dims, dtype=np.float32) + + log_scale_jacobians = jacobians[1] + chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) + log_scale_grads = np.mean(log_scale_jacobians, axis=0) + expected_log_scale_grads = np.zeros(data_dims, dtype=np.float32) + + _assert_equal(mean_grads, expected_mean_grads) + _assert_equal(log_scale_grads, expected_log_scale_grads) + + @chex.all_variants + @parameterized.named_parameters( + chex.params_product([ + ('_score_function_jacobians', sge.score_function_jacobians), + ('_pathwise_jacobians', sge.pathwise_jacobians), + ('_measure_valued_jacobians', sge.measure_valued_jacobians), + ], [ + ('1.0_0.3', 1.0, 0.3), + ], + named=True)) + def testQuadraticFunction( + self, estimator, effective_mean, effective_log_scale): + data_dims = 3 + num_samples = _estimator_to_num_samples[estimator] + rng = jax.random.PRNGKey(1) + + mean = effective_mean * _ones(data_dims) + log_scale = effective_log_scale * _ones(data_dims) + + jacobians = _estimator_variant(self.variant, estimator)( + lambda x: np.sum(x**2) / 2, [mean, log_scale], + utils.multi_normal, rng, num_samples) + + mean_jacobians = jacobians[0] + chex.assert_shape(mean_jacobians, (num_samples, data_dims)) + mean_grads = np.mean(mean_jacobians, axis=0) + expected_mean_grads = effective_mean * np.ones( + data_dims, dtype=np.float32) + + log_scale_jacobians = jacobians[1] + chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) + log_scale_grads = np.mean(log_scale_jacobians, axis=0) + expected_log_scale_grads = np.exp(2 * effective_log_scale) * np.ones( + data_dims, dtype=np.float32) + + _assert_equal(mean_grads, expected_mean_grads, atol=5e-2) + _assert_equal(log_scale_grads, expected_log_scale_grads, atol=5e-2) + + @chex.all_variants + @parameterized.named_parameters( + chex.params_product([ + ('_score_function_jacobians', sge.score_function_jacobians), + ('_pathwise_jacobians', sge.pathwise_jacobians), + ('_measure_valued_jacobians', sge.measure_valued_jacobians), + ], [ + ('case_1', [1.0, 2.0, 3.], [-1., 0.3, 
-2.], [1., 1., 1.]), + ('case_2', [1.0, 2.0, 3.], [-1., 0.3, -2.], [4., 2., 3.]), + ('case_3', [1.0, 2.0, 3.], [0.1, 0.2, 0.1], [10., 5., 1.]), + ], + named=True)) + def testWeightedLinear( + self, estimator, effective_mean, effective_log_scale, weights): + num_samples = _weighted_estimator_to_num_samples[estimator] + rng = jax.random.PRNGKey(1) + + mean = jnp.array(effective_mean) + log_scale = jnp.array(effective_log_scale) + weights = jnp.array(weights) + + data_dims = len(effective_mean) + + function = lambda x: jnp.sum(x * weights) + jacobians = _estimator_variant(self.variant, estimator)( + function, [mean, log_scale], + utils.multi_normal, rng, num_samples) + + mean_jacobians = jacobians[0] + chex.assert_shape(mean_jacobians, (num_samples, data_dims)) + mean_grads = np.mean(mean_jacobians, axis=0) + + log_scale_jacobians = jacobians[1] + chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) + log_scale_grads = np.mean(log_scale_jacobians, axis=0) + + expected_mean_grads = weights + expected_log_scale_grads = np.zeros(data_dims, dtype=np.float32) + + _assert_equal(mean_grads, expected_mean_grads, atol=5e-2) + _assert_equal(log_scale_grads, expected_log_scale_grads, atol=5e-2) + + @chex.all_variants + @parameterized.named_parameters( + chex.params_product([ + ('_score_function_jacobians', sge.score_function_jacobians), + ('_pathwise_jacobians', sge.pathwise_jacobians), + ('_measure_valued_jacobians', sge.measure_valued_jacobians), + ], [ + ('case_1', [1.0, 2.0, 3.], [-1., 0.3, -2.], [1., 1., 1.]), + ('case_2', [1.0, 2.0, 3.], [-1., 0.3, -2.], [4., 2., 3.]), + ('case_3', [1.0, 2.0, 3.], [0.1, 0.2, 0.1], [3., 5., 1.]), + ], + named=True)) + def testWeightedQuadratic( + self, estimator, effective_mean, effective_log_scale, weights): + num_samples = _weighted_estimator_to_num_samples[estimator] + rng = jax.random.PRNGKey(1) + + mean = jnp.array(effective_mean, dtype=jnp.float32) + log_scale = jnp.array(effective_log_scale, dtype=jnp.float32) + weights = jnp.array(weights, dtype=jnp.float32) + + data_dims = len(effective_mean) + + function = lambda x: jnp.sum(x * weights) ** 2 + jacobians = _estimator_variant(self.variant, estimator)( + function, [mean, log_scale], + utils.multi_normal, rng, num_samples) + + mean_jacobians = jacobians[0] + chex.assert_shape(mean_jacobians, (num_samples, data_dims)) + mean_grads = np.mean(mean_jacobians, axis=0) + + log_scale_jacobians = jacobians[1] + chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) + log_scale_grads = np.mean(log_scale_jacobians, axis=0) + + expected_mean_grads = 2 * weights * np.sum(weights * mean) + effective_scale = np.exp(log_scale) + expected_scale_grads = 2 * weights ** 2 * effective_scale + expected_log_scale_grads = expected_scale_grads * effective_scale + + _assert_equal(mean_grads, expected_mean_grads, atol=1e-1, rtol=1e-1) + _assert_equal( + log_scale_grads, expected_log_scale_grads, atol=1e-1, rtol=1e-1) + + @chex.all_variants + @parameterized.named_parameters( + chex.params_product( + [ + ('_sum_cos_x', [1.0], [1.0], lambda x: jnp.sum(jnp.cos(x))), + # Need to ensure that the mean is not too close to 0. 
+ ('_sum_log_x', [10.0], [0.0], lambda x: jnp.sum(jnp.log(x))), + ('_sum_cos_2x', [1.0, 2.0], [1.0, -2 + ], lambda x: jnp.sum(jnp.cos(2 * x))), + ('_cos_sum_2x', [1.0, 2.0], [1.0, -2 + ], lambda x: jnp.cos(jnp.sum(2 * x))), + ], + [ + ('coupling', True), + ('nocoupling', False), + ], + named=True)) + def testNonPolynomialFunctionConsistencyWithPathwise(self, effective_mean, + effective_log_scale, + function, coupling): + num_samples = 10**5 + rng = jax.random.PRNGKey(1) + measure_rng, pathwise_rng = jax.random.split(rng) + + mean = jnp.array(effective_mean, dtype=jnp.float32) + log_scale = jnp.array(effective_log_scale, dtype=jnp.float32) + data_dims = len(effective_mean) + + measure_valued_jacobians = _measure_valued_variant(self.variant)( + function, [mean, log_scale], + utils.multi_normal, measure_rng, num_samples, coupling) + + measure_valued_mean_jacobians = measure_valued_jacobians[0] + chex.assert_shape(measure_valued_mean_jacobians, (num_samples, data_dims)) + measure_valued_mean_grads = np.mean(measure_valued_mean_jacobians, axis=0) + + measure_valued_log_scale_jacobians = measure_valued_jacobians[1] + chex.assert_shape( + measure_valued_log_scale_jacobians, (num_samples, data_dims)) + measure_valued_log_scale_grads = np.mean( + measure_valued_log_scale_jacobians, axis=0) + + pathwise_jacobians = _estimator_variant( + self.variant, sge.pathwise_jacobians)(function, [mean, log_scale], + utils.multi_normal, pathwise_rng, + num_samples) + + pathwise_mean_jacobians = pathwise_jacobians[0] + chex.assert_shape(pathwise_mean_jacobians, (num_samples, data_dims)) + pathwise_mean_grads = np.mean(pathwise_mean_jacobians, axis=0) + + pathwise_log_scale_jacobians = pathwise_jacobians[1] + chex.assert_shape(pathwise_log_scale_jacobians, (num_samples, data_dims)) + pathwise_log_scale_grads = np.mean(pathwise_log_scale_jacobians, axis=0) + + _assert_equal( + pathwise_mean_grads, measure_valued_mean_grads, rtol=5e-1, atol=1e-1) + _assert_equal( + pathwise_log_scale_grads, measure_valued_log_scale_grads, + rtol=5e-1, atol=1e-1) + + +class MeasuredValuedEstimatorsTest(chex.TestCase): + + @chex.all_variants + @parameterized.parameters([True, False]) + def testRaisesErrorForNonGaussian(self, coupling): + num_samples = 10**5 + rng = jax.random.PRNGKey(1) + + function = lambda x: jnp.sum(x) ** 2 + + mean = jnp.array(0, dtype=jnp.float32) + log_scale = jnp.array(0., dtype=jnp.float32) + + class TestDist(): + + def __init__(self, params): + self._params = params + + def sample(self, n): + return np.zeros(n) + + with self.assertRaises(ValueError): + _measure_valued_variant(self.variant)( + function, [mean, log_scale], + TestDist, rng, num_samples, coupling) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/transform.py b/optax_add_eve/_src/transform.py new file mode 100644 index 00000000..ba4037ac --- /dev/null +++ b/optax_add_eve/_src/transform.py @@ -0,0 +1,1206 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradient transformations.""" + +import functools +from typing import Any, Callable, NamedTuple, Optional, Union + +import chex +import jax +import jax.numpy as jnp + +from optax_add_eve._src import base +from optax_add_eve._src import clipping +from optax_add_eve._src import numerics +from optax_add_eve._src import utils +from optax_add_eve._src import wrappers + +# pylint:disable=no-value-for-parameter + +_abs_sq = numerics.abs_sq + + +class TraceState(NamedTuple): + """Holds an aggregation of past updates.""" + trace: base.Params + + +def trace( + decay: float, + nesterov: bool = False, + accumulator_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """Compute a trace of past updates. + + Note: `trace` and `ema` have very similar but distinct updates; + `trace = decay * trace + t`, while `ema = decay * ema + (1-decay) * t`. + Both are frequently found in the optimization literature. + + Args: + decay: Decay rate for the trace of past updates. + nesterov: Whether to use Nesterov momentum. + accumulator_dtype: Optional `dtype` to be used for the accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + A `GradientTransformation` object. + """ + + accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) + + def init_fn(params): + return TraceState( + trace=jax.tree_util.tree_map( + lambda t: jnp.zeros_like(t, dtype=accumulator_dtype), params)) + + def update_fn(updates, state, params=None): + del params + f = lambda g, t: g + decay * t + new_trace = jax.tree_util.tree_map(f, updates, state.trace) + updates = ( + jax.tree_util.tree_map(f, updates, new_trace) if nesterov + else new_trace) + new_trace = utils.cast_tree(new_trace, accumulator_dtype) + return updates, TraceState(trace=new_trace) + + return base.GradientTransformation(init_fn, update_fn) + + +def update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order`-th moment.""" + return jax.tree_util.tree_map( + lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments) + + +def update_infinity_moment(updates, moments, decay, eps): + """Compute the exponential moving average of the infinity norm.""" + return jax.tree_util.tree_map( + lambda g, t: jnp.maximum(jnp.abs(g) + eps, decay * t), updates, moments) + + +def update_moment_per_elem_norm(updates, moments, decay, order): + """Compute the EMA of the `order`-th moment of the element-wise norm.""" + + def orderth_norm(g): + if jnp.isrealobj(g): + return g ** order + else: + half_order = order / 2 + # JAX generates different HLO for int and float `order` + if half_order.is_integer(): + half_order = int(half_order) + return _abs_sq(g) ** half_order + + return jax.tree_util.tree_map( + lambda g, t: (1 - decay) * orderth_norm(g) + decay * t, updates, moments) + + +@functools.partial(jax.jit, inline=True) +def bias_correction(moment, decay, count): + """Performs bias correction. It becomes a no-op as count goes to infinity.""" + # The conversion to the data type of the moment ensures that bfloat16 remains + # bfloat16 in the optimizer state. This conversion has to be done after + # `bias_correction_` is calculated as calculating `decay**count` in low + # precision can result in it being rounded to 1 and subsequently a + # "division by zero" error. 
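+  # For example, with decay=0.999 and count=1 the factor below is
+  # 1 - 0.999**1 = 1e-3, so the first moment estimate is divided by 1e-3
+  # (i.e. scaled up roughly 1000x); as count grows the factor approaches 1
+  # and the correction becomes a no-op.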
+ bias_correction_ = 1 - decay**count + + # Perform division in the original precision. + return jax.tree_util.tree_map( + lambda t: t / bias_correction_.astype(t.dtype), moment) + + +def _reject_complex(params): + if any(jnp.iscomplexobj(x) for x in jax.tree_util.tree_leaves(params)): + raise ValueError('This transformation does not support complex parameters.') + + +class EmaState(NamedTuple): + """Holds an exponential moving average of past updates.""" + count: chex.Array # shape=(), dtype=jnp.int32. + ema: base.Params + + +def ema( + decay: float, + debias: bool = True, + accumulator_dtype: Optional[Any] = None +) -> base.GradientTransformation: + """Compute an exponential moving average of past updates. + + Note: `trace` and `ema` have very similar but distinct updates; + `ema = decay * ema + (1-decay) * t`, while `trace = decay * trace + t`. + Both are frequently found in the optimization literature. + + Args: + decay: Decay rate for the exponential moving average. + debias: Whether to debias the transformed gradient. + accumulator_dtype: Optional `dtype` to used for the accumulator; if `None` + then the `dtype` is inferred from `params` and `updates`. + + Returns: + A `GradientTransformation` object. + """ + + accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) + + def init_fn(params): + return EmaState( + count=jnp.zeros([], jnp.int32), + ema=jax.tree_util.tree_map( + lambda t: jnp.zeros_like(t, dtype=accumulator_dtype), params)) + + def update_fn(updates, state, params=None): + del params + updates = new_ema = update_moment(updates, state.ema, decay, order=1) + count_inc = utils.safe_int32_increment(state.count) + if debias: + updates = bias_correction(new_ema, decay, count_inc) + state_ema = utils.cast_tree(new_ema, accumulator_dtype) + return updates, EmaState(count=count_inc, ema=state_ema) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByRssState(NamedTuple): + """State holding the sum of gradient squares to date.""" + sum_of_squares: base.Updates + + +def scale_by_rss( + initial_accumulator_value: float = 0.1, + eps: float = 1e-7 +) -> base.GradientTransformation: + """Rescale updates by the root of the sum of all squared gradients to date. + + References: + [Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) + [McMahan et al., 2010](https://arxiv.org/abs/1002.4908) + + Args: + initial_accumulator_value: Starting value for accumulators, must be >= 0. + eps: A small floating point value to avoid zero denominator. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + sum_of_squares = jax.tree_util.tree_map( + lambda t: jnp.full_like(t, initial_accumulator_value), params) + return ScaleByRssState(sum_of_squares=sum_of_squares) + + def update_fn(updates, state, params=None): + del params + sum_of_squares = jax.tree_util.tree_map( + lambda g, t: _abs_sq(g) + t, updates, state.sum_of_squares) + inv_sqrt_g_square = jax.tree_util.tree_map( + lambda t: jnp.where(t > 0, jax.lax.rsqrt(t + eps), 0.0), sum_of_squares) + updates = jax.tree_util.tree_map( + lambda scale, g: scale * g, inv_sqrt_g_square, updates) + return updates, ScaleByRssState(sum_of_squares=sum_of_squares) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByRmsState(NamedTuple): + """State for exponential root mean-squared (RMS)-normalized updates.""" + nu: base.Updates + + +def scale_by_rms( + decay: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0. 
+) -> base.GradientTransformation: + """Rescale updates by the root of the exp. moving avg of the square. + + References: + [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) + + Args: + decay: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + initial_scale: Initial value for second moment. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + nu = jax.tree_util.tree_map( + lambda n: jnp.full_like(n, initial_scale), params) # second moment + return ScaleByRmsState(nu=nu) + + def update_fn(updates, state, params=None): + del params + nu = update_moment_per_elem_norm(updates, state.nu, decay, 2) + updates = jax.tree_util.tree_map( + lambda g, n: g * jax.lax.rsqrt(n + eps), updates, nu) + return updates, ScaleByRmsState(nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByRStdDevState(NamedTuple): + """State for centered exponential moving average of squares of updates.""" + mu: base.Updates + nu: base.Updates + + +def scale_by_stddev( + decay: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0. +) -> base.GradientTransformation: + """Rescale updates by the root of the centered exp. moving average of squares. + + References: + [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) + + Args: + decay: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + initial_scale: Initial value for second moment. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + mu = jax.tree_util.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_util.tree_map( + lambda n: jnp.full_like(n, initial_scale), params) # second moment + return ScaleByRStdDevState(mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = update_moment(updates, state.mu, decay, 1) + nu = update_moment_per_elem_norm(updates, state.nu, decay, 2) + updates = jax.tree_util.tree_map( + lambda g, m, n: g * jax.lax.rsqrt(n - _abs_sq(m) + eps), + updates, mu, nu) + return updates, ScaleByRStdDevState(mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the Adam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: base.Updates + nu: base.Updates + + +def scale_by_adam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """Rescale updates according to the Adam algorithm. + + References: + [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + mu_dtype: Optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype is inferred from `params` and `updates`. + + Returns: + A `GradientTransformation` object. 
+ """ + + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params): + mu = jax.tree_util.tree_map( # First moment + lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) + nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = update_moment(updates, state.mu, b1, 1) + nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) + count_inc = numerics.safe_int32_increment(state.count) + mu_hat = bias_correction(mu, b1, count_inc) + nu_hat = bias_correction(nu, b2, count_inc) + updates = jax.tree_util.tree_map( + lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat) + mu = utils.cast_tree(mu, mu_dtype) + return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByAmsgradState(NamedTuple): + """State for the AMSGrad algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: base.Updates + nu: base.Updates + nu_max: base.Updates + + +def scale_by_amsgrad( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """Rescale updates according to the AMSGrad algorithm. + + References: + [Reddi et al, 2018](https://openreview.net/forum?id=ryQu7f-RZ) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + mu_dtype: Optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype is inferred from `params` and `updates`. + + Returns: + A `GradientTransformation` object. + """ + + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params): + mu = jax.tree_util.tree_map( # First moment + lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) + nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment + nu_max = jax.tree_util.tree_map(jnp.zeros_like, params) + return ScaleByAmsgradState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, + nu_max=nu_max) + + def update_fn(updates, state, params=None): + del params + mu = update_moment(updates, state.mu, b1, 1) + nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) + count_inc = numerics.safe_int32_increment(state.count) + mu_hat = bias_correction(mu, b1, count_inc) + nu_hat = bias_correction(nu, b2, count_inc) + nu_max = jax.tree_util.tree_map(jnp.maximum, state.nu_max, nu_hat) + updates = jax.tree_util.tree_map( + lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_max) + mu = utils.cast_tree(mu, mu_dtype) + return updates, ScaleByAmsgradState(count=count_inc, mu=mu, nu=nu, + nu_max=nu_max) + + return base.GradientTransformation(init_fn, update_fn) + + +def scale_by_adamax( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8 +) -> base.GradientTransformation: + """Rescale updates according to the Adamax algorithm. 
+
+  References:
+    [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
+
+  Args:
+    b1: Decay rate for the exponentially weighted average of grads.
+    b2: Decay rate for the exponentially weighted maximum of grads.
+    eps: Term added to the denominator to improve numerical stability.
+
+  Returns:
+    A `GradientTransformation` object.
+  """
+
+  def init_fn(params):
+    mu = jax.tree_util.tree_map(jnp.zeros_like, params)  # First moment
+    nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Infinite moment
+    return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)
+
+  def update_fn(updates, state, params=None):
+    del params
+    count_inc = numerics.safe_int32_increment(state.count)
+    mu = update_moment(updates, state.mu, b1, 1)
+    nu = update_infinity_moment(updates, state.nu, b2, eps)
+    # Bias correction for mean. No bias correction needed for infinity moment.
+    mu_hat = bias_correction(mu, b1, count_inc)
+    updates = jax.tree_util.tree_map(lambda m, v: m / v, mu_hat, nu)
+    return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)
+
+  return base.GradientTransformation(init_fn, update_fn)
+
+
+class ScaleByEveState(NamedTuple):
+  """State for the Eve algorithm."""
+  count: chex.Array  # shape=(), dtype=jnp.int32.
+  mu: base.Updates
+  nu: base.Updates
+  d: float
+  f_prev: float
+
+
+def scale_by_eve(b1: float = 0.9,
+                 b2: float = 0.999,
+                 b3: float = 0.999,
+                 c: float = 10.,
+                 eps: float = 1e-8,
+                 f_star: float = 0.,
+                 mu_dtype: Optional[Any] = None,
+) -> base.GradientTransformation:
+  """Rescale updates according to the Eve algorithm.
+
+  References:
+    [Hayashi et al, 2018](https://arxiv.org/abs/1611.01505)
+
+  Args:
+    b1: the exponential decay rate to track the first moment of past gradients.
+    b2: the exponential decay rate to track the second moment of past gradients.
+    b3: the exponential decay rate to track the sub-optimality.
+    c: the clipping limit to prevent extreme global learning rate changes.
+    eps: a small constant applied to the denominator outside of the square root
+      (as in the Adam paper) to avoid dividing by zero when rescaling.
+    f_star: estimate of the global minimum of the loss.
+    mu_dtype: optional `dtype` to be used for the first order accumulator; if
+      `None` then the `dtype` is inferred from `params` and `updates`.
+
+  Returns:
+    An (init_fn, update_fn) tuple.
+  """
+  mu_dtype = utils.canonicalize_dtype(mu_dtype)
+
+  def init_fn(params):
+    mu = jax.tree_util.tree_map(  # First moment
+        lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
+    nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
+    return ScaleByEveState(
+        count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f_prev=1.)
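+
+  # Sub-optimality factor maintained by `update_fn` below:
+  #   d_hat_t = |f_t - f_{t-1}| / (min(f_t, f_{t-1}) - f_star)
+  #   d_t = b3 * d_{t-1} + (1 - b3) * clip(d_hat_t, 1/c, c)  for t > 1, else 1.
+  # The Adam-style step is then divided by d_t.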
+
+  def update_fn(updates: base.Updates, state: ScaleByEveState, f: float):
+    """Applies the Eve update.
+
+    Unlike most transforms, the third argument is not `params` but the loss
+    for the current iteration, f = f_t; `ScaleByEveState` holds the loss from
+    the previous iteration, state.f_prev = f_{t-1}.
+    """
+    mu = update_moment(updates, state.mu, b1, 1)
+    nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
+    count_inc = numerics.safe_int32_increment(state.count)
+    mu_hat = jax.tree_util.tree_map(lambda m: jnp.asarray(m / (1 - b1)), mu)
+    nu_hat = jax.tree_util.tree_map(lambda v: jnp.asarray(v / (1 - b2)), nu)
+    # `count_inc` is traced under `jit`, so the first-step special case is
+    # handled with `jnp.where` rather than Python control flow.
+    d_new = jnp.abs(f - state.f_prev) / (jnp.minimum(f, state.f_prev) - f_star)
+    d_tilde = jnp.clip(d_new, 1. / c, c)
+    d = jnp.where(count_inc > 1, b3 * state.d + (1 - b3) * d_tilde, 1.)
+    updates = jax.tree_util.tree_map(
+        lambda m, v: m / (jnp.sqrt(v) + eps) / d, mu_hat, nu_hat)
+    mu = utils.cast_tree(mu, mu_dtype)
+    return updates, ScaleByEveState(
+        count=count_inc, mu=mu, nu=nu, d=d, f_prev=f)
+
+  return base.GradientTransformation(init_fn, update_fn)
+
+
+ScaleState = base.EmptyState
+
+
+def scale(
+    step_size: float
+) -> base.GradientTransformation:
+  """Scale updates by some fixed scalar `step_size`.
+
+  Args:
+    step_size: A scalar corresponding to a fixed scaling factor for updates.
+
+  Returns:
+    A `GradientTransformation` object.
+  """
+
+  def init_fn(params):
+    del params
+    return ScaleState()
+
+  def update_fn(updates, state, params=None):
+    del params
+    updates = jax.tree_util.tree_map(lambda g: step_size * g, updates)
+    return updates, state
+
+  return base.GradientTransformation(init_fn, update_fn)
+
+
+def scale_by_param_block_norm(
+    min_scale: float = 1e-3
+) -> base.GradientTransformation:
+  """Scale updates for each param block by the norm of that block's parameters.
+
+  A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix
+  (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.
+
+  Args:
+    min_scale: Minimum scaling factor.
+
+  Returns:
+    A `GradientTransformation` object.
+  """
+
+  def init_fn(params):
+    del params
+    return base.EmptyState()
+
+  def update_fn(updates, state, params):
+    if params is None:
+      raise ValueError(base.NO_PARAMS_MSG)
+    updates = jax.tree_util.tree_map(
+        lambda u, p: u * numerics.safe_norm(p, min_scale),
+        updates, params)
+    return updates, state
+
+  return base.GradientTransformation(init_fn, update_fn)
+
+
+def scale_by_param_block_rms(
+    min_scale: float = 1e-3
+) -> base.GradientTransformation:
+  """Scale updates by rms of the gradient for each param vector or matrix.
+
+  A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix
+  (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.
+
+  Args:
+    min_scale: Minimum scaling factor.
+
+  Returns:
+    A `GradientTransformation` object.
+  """
+
+  def init_fn(params):
+    del params
+    return base.EmptyState()
+
+  def update_fn(updates, state, params):
+    if params is None:
+      raise ValueError(base.NO_PARAMS_MSG)
+    updates = jax.tree_util.tree_map(
+        lambda u, p: u * numerics.safe_root_mean_squares(p, min_scale),
+        updates, params)
+    return updates, state
+
+  return base.GradientTransformation(init_fn, update_fn)
+
+
+class ScaleByBeliefState(NamedTuple):
+  """State for the rescaling by AdaBelief algorithm."""
+  count: chex.Array  # shape=(), dtype=jnp.int32.
+ mu: base.Updates + nu: base.Updates + + +def scale_by_belief( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-16, + eps_root: float = 1e-16 +) -> base.GradientTransformation: + """Rescale updates according to the AdaBelief algorithm. + + References: + [Zhuang et al, 2020](https://arxiv.org/abs/2010.07468) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of variance of grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the second moment of the prediction error to + improve numerical stability. If backpropagating gradients through the + gradient transformation (e.g. for meta-learning), this must be non-zero. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + mu = jax.tree_util.tree_map(jnp.zeros_like, params) # First moment + s = jax.tree_util.tree_map(jnp.zeros_like, params) # Second Central moment + return ScaleByBeliefState(count=jnp.zeros([], jnp.int32), mu=mu, nu=s) + + def update_fn(updates, state, params=None): + del params + mu = update_moment(updates, state.mu, b1, 1) + prediction_error = jax.tree_util.tree_map( + lambda g, m: g-m, updates, state.mu) + nu = update_moment_per_elem_norm(prediction_error, state.nu, b2, 2) + nu = jax.tree_util.tree_map(lambda v: v + eps_root, nu) + count_inc = numerics.safe_int32_increment(state.count) + mu_hat = bias_correction(mu, b1, count_inc) + nu_hat = bias_correction(nu, b2, count_inc) + updates = jax.tree_util.tree_map( + lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat) + return updates, ScaleByBeliefState(count=count_inc, mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +def scale_by_yogi( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-3, + eps_root: float = 0.0, + initial_accumulator_value: float = 1e-6 +) -> base.GradientTransformation: + """Rescale updates according to the Yogi algorithm. + + Supports complex numbers, see + https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 + + References: + [Zaheer et al, 2018](https://papers.nips.cc/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html) #pylint:disable=line-too-long + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of variance of grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + initial_accumulator_value: The starting value for accumulators. + Only positive values are allowed. + + Returns: + A `GradientTransformation` object. 
+ """ + + def init_fn(params): + value_like = lambda p: jnp.full_like(p, initial_accumulator_value) + mu = jax.tree_util.tree_map(value_like, params) # First moment + nu = jax.tree_util.tree_map(value_like, params) # Second Central moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = update_moment(updates, state.mu, b1, 1) + nu = jax.tree_util.tree_map( + lambda g, v: v - (1 - b2) * jnp.sign(v - _abs_sq(g)) * _abs_sq(g), + updates, state.nu) + count_inc = numerics.safe_int32_increment(state.count) + mu_hat = bias_correction(mu, b1, count_inc) + nu_hat = bias_correction(nu, b2, count_inc) + updates = jax.tree_util.tree_map( + lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +def scale_by_radam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + threshold: float = 5.0 +) -> base.GradientTransformation: + """Rescale updates according to the Rectified Adam algorithm. + + References: + [Liu et al, 2020](https://arxiv.org/abs/1908.03265) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + threshold: Threshold for variance tractability. + + Returns: + A `GradientTransformation` object. + """ + + ro_inf = 2./(1 - b2) - 1 + def _radam_update(params): + ro = params[0] + mu_hat = params[1] + nu_hat = params[2] + r = jnp.sqrt((ro - 4)*(ro - 2)*ro_inf/((ro_inf - 4)*(ro_inf - 2)*ro)) + updates = jax.tree_util.tree_map( + lambda m, v: r*m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat) + return updates + + def init_fn(params): + mu = jax.tree_util.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = update_moment(updates, state.mu, b1, 1) + nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) + count_inc = numerics.safe_int32_increment(state.count) + b2t = b2**count_inc + ro = ro_inf - 2 * count_inc * b2t / (1 - b2t) + mu_hat = bias_correction(mu, b1, count_inc) + nu_hat = bias_correction(nu, b2, count_inc) + updates = jax.lax.cond( + ro >= threshold, _radam_update, lambda _: mu_hat, + (ro, mu_hat, nu_hat)) + return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +AddDecayedWeightsState = base.EmptyState + + +def add_decayed_weights( + weight_decay: float = 0.0, + mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None +) -> base.GradientTransformation: + """Add parameter scaled by `weight_decay`. + + Args: + weight_decay: A scalar weight decay rate. + mask: A tree with same structure as (or a prefix of) the params PyTree, + or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the transformation to, and `False` for those you want to skip. + + Returns: + A `GradientTransformation` object. 
+ """ + + def init_fn(params): + del params + return AddDecayedWeightsState() + + def update_fn(updates, state, params): + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + updates = jax.tree_util.tree_map( + lambda g, p: g + weight_decay * p, updates, params) + return updates, state + + # If mask is not `None`, apply mask to the gradient transformation. + # E.g. it is common to skip weight decay on bias units and batch stats. + if mask is not None: + return wrappers.masked( + base.GradientTransformation(init_fn, update_fn), mask) + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByScheduleState(NamedTuple): + """Maintains count for scale scheduling.""" + count: chex.Array # shape=(), dtype=jnp.int32 + + +def scale_by_schedule( + step_size_fn: base.Schedule +) -> base.GradientTransformation: + """Scale updates using a custom schedule for the `step_size`. + + Args: + step_size_fn: A function that takes an update count as input and proposes + the step_size to multiply the updates by. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return ScaleByScheduleState(count=jnp.zeros([], jnp.int32)) + + def update_fn(updates, state, params=None): + del params + step_size = step_size_fn(state.count) + updates = jax.tree_util.tree_map( + lambda g: jnp.array(step_size, dtype=g.dtype) * g, updates) + return updates, ScaleByScheduleState( + count=numerics.safe_int32_increment(state.count)) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByTrustRatioState(NamedTuple): + """The scale and decay trust ratio transformation is stateless.""" + + +def scale_by_trust_ratio( + min_norm: float = 0.0, + trust_coefficient: float = 1., + eps: float = 0., +) -> base.GradientTransformation: + """Scale updates by trust ratio`. + + References: + [You et. al 2020](https://arxiv.org/abs/1904.00962) + + Args: + min_norm: Minimum norm for params and gradient norms; by default is zero. + trust_coefficient: A multiplier for the trust ratio. + eps: Additive constant added to the denominator for numerical stability. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return ScaleByTrustRatioState() + + def update_fn(updates, state, params): + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + + def _scale_update(update, param): + + # Clip norms to minimum value, by default no clipping. + param_norm = numerics.safe_norm(param, min_norm) + update_norm = numerics.safe_norm(update, min_norm) + trust_ratio = trust_coefficient * param_norm / (update_norm + eps) + + # If no minimum norm clipping is used + # Set trust_ratio to 1 in case where parameters would never be updated. + zero_norm = jnp.logical_or(param_norm == 0., update_norm == 0.) + safe_trust_ratio = jnp.where( + zero_norm, jnp.array(1.0, dtype=param.dtype), trust_ratio) + + return update * safe_trust_ratio + + updates = jax.tree_util.tree_map(_scale_update, updates, params) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +class AddNoiseState(NamedTuple): + """State for adding gradient noise. Contains a count for annealing.""" + count: chex.Array + rng_key: chex.PRNGKey + + +def add_noise( + eta: float, + gamma: float, + seed: int +) -> base.GradientTransformation: + """Add gradient noise. + + References: + [Neelakantan et al, 2014](https://arxiv.org/abs/1511.06807) + + Args: + eta: Base variance of the gaussian noise added to the gradient. 
+ gamma: Decay exponent for annealing of the variance. + seed: Seed for random number generation. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return AddNoiseState( + count=jnp.zeros([], jnp.int32), rng_key=jax.random.PRNGKey(seed)) + + def update_fn(updates, state, params=None): # pylint: disable=missing-docstring + del params + num_vars = len(jax.tree_util.tree_leaves(updates)) + treedef = jax.tree_util.tree_structure(updates) + count_inc = numerics.safe_int32_increment(state.count) + variance = eta / count_inc**gamma + standard_deviation = jnp.sqrt(variance) + all_keys = jax.random.split(state.rng_key, num=num_vars + 1) + noise = jax.tree_util.tree_map( + lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype), + updates, jax.tree_util.tree_unflatten(treedef, all_keys[1:])) + updates = jax.tree_util.tree_map( + lambda g, n: g + standard_deviation.astype(g.dtype) * n, + updates, noise) + return updates, AddNoiseState(count=count_inc, rng_key=all_keys[0]) + + return base.GradientTransformation(init_fn, update_fn) + + +class ApplyEvery(NamedTuple): + """Contains a counter and a gradient accumulator.""" + count: chex.Array + grad_acc: base.Updates + + +def apply_every( + k: int = 1 +) -> base.GradientTransformation: + """Accumulate gradients and apply them every k steps. + + Note that if this transformation is part of a chain, the states of the other + transformations will still be updated at every step. In particular, using + `apply_every` with a batch size of N/2 and k=2 is not necessarily equivalent + to not using `apply_every` with a batch size of N. If this equivalence is + important for you, consider using the `optax.MultiSteps`. + + Args: + k: Emit non-zero gradients every k steps, otherwise accumulate them. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + grad_acc = jax.tree_util.tree_map(jnp.zeros_like, params) + return ApplyEvery(count=jnp.zeros([], jnp.int32), grad_acc=grad_acc) + + def update_fn(updates, state, params=None): + del params + c = state.count % k + acc = c != 0 + grad_acc = jax.tree_util.tree_map( + lambda g, ga: acc * ga + g, updates, state.grad_acc) + emit = c == (k - 1) + updates = jax.tree_util.tree_map(lambda ga: emit * ga, grad_acc) + count_inc = numerics.safe_int32_increment(state.count) + return updates, ApplyEvery(count=count_inc % k, grad_acc=grad_acc) + + return base.GradientTransformation(init_fn, update_fn) + + +def _subtract_mean(g): + if len(g.shape) > 1: + return g - g.mean(tuple(range(1, len(g.shape))), keepdims=True) + else: + return g + + +CentralState = base.EmptyState + + +def centralize() -> base.GradientTransformation: + """Centralize gradients. + + References: + [Yong et al, 2020](https://arxiv.org/abs/2004.01461) + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return CentralState() + + def update_fn(updates, state, params=None): + del params + updates = jax.tree_util.tree_map(_subtract_mean, updates) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleBySM3State(NamedTuple): + """State for the SM3 algorithm.""" + mu: base.Updates + nu: base.Updates + + +def scale_by_sm3( + b1: float = 0.9, + b2: float = 1.0, + eps: float = 1e-8 +) -> base.GradientTransformation: + """Scale updates by sm3`. + + References: + [Anil et. al 2019](https://arxiv.org/abs/1901.11150) + + Args: + b1: Decay rate for the exponentially weighted average of grads. 
+ b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + + Returns: + A `GradientTransformation` object. + """ + + def zeros_for_dim(p): + return [jnp.zeros([s]) for s in p.shape] + + def init_fn(params): + _reject_complex(params) + mu = jax.tree_util.tree_map(zeros_for_dim, params) + nu = jax.tree_util.tree_map(jnp.zeros_like, params) + return ScaleBySM3State(mu, nu) + + def _expanded_shape(shape, axis): + # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i. + # For eg: i = 1 returns [1, N, 1]. + rank = len(shape) + return [1] * axis + [shape[axis]] + [1] * (rank - axis - 1) + + def _new_accum(g, v): + coeffs = ((1.0 - b2) if b2 != 1.0 else 1.0, b2) + if g.ndim < 2: + return coeffs[0]*g**2 + coeffs[1]*v[0] + else: + return coeffs[0]*g**2 + coeffs[1]*functools.reduce(jnp.minimum, v) + + def _new_mu(g, i): + if g.ndim < 2: + return g + else: + return jnp.max(g, axis=other_axes(i, g.ndim)) + + def other_axes(idx, ndim): + return list(range(idx)) + list(range(idx+1, ndim)) + + def update_fn(updates, state, params=None): + del params + mu = jax.tree_util.tree_map( + lambda g, v: # pylint:disable=g-long-lambda + [jnp.reshape(v[i], _expanded_shape(g.shape, i)) for i in range(g.ndim)], + updates, state.mu) + accum = jax.tree_util.tree_map(_new_accum, updates, mu) + accum_inv_sqrt = jax.tree_util.tree_map( + lambda t: jnp.where(t > 0, jax.lax.rsqrt(t + eps), 0.0), accum) + up = jax.tree_util.tree_map(lambda g, a: g*a, updates, accum_inv_sqrt) + nu = update_moment(up, state.nu, b1, 1) + mu = jax.tree_util.tree_map( + lambda g: [_new_mu(g, i) for i in range(g.ndim)], accum) + + return nu, ScaleBySM3State(mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByNovogradState(NamedTuple): + """State for Novograd.""" + count: chex.Array + mu: base.Updates + nu: base.Updates + + +def scale_by_novograd( + b1: float = 0.9, + b2: float = 0.25, + eps: float = 1e-8, + eps_root: float = 0.0, + weight_decay: float = 0.0, + mu_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """Computes NovoGrad updates. + + References: + [Ginsburg et al, 2019](https://arxiv.org/abs/1905.11286) + + Args: + b1: A decay rate for the exponentially weighted average of grads. + b2: A decay rate for the exponentially weighted average of squared grads. + eps: A term added to the denominator to improve numerical stability. + eps_root: A term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + weight_decay: A scalar weight decay rate. + mu_dtype: An optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype is inferred from `params` and `updates`. + + Returns: + The corresponding `GradientTransformation`. 
+ """ + + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params): + mu = jax.tree_util.tree_map( # First moment + lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) + nu = jax.tree_util.tree_map(lambda _: 0.0, params) # Second moment + return ScaleByNovogradState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def nu_addition(grads): + return jnp.linalg.norm(grads)**2 + + def mu_addition(grads, params, nu): + return grads / (jnp.sqrt(nu + eps_root) + eps) + weight_decay * params + + def init_nu(grads, nu): + del nu + return jax.tree_util.tree_map(nu_addition, grads) + + def update_nu(grads, nu): + updates = jax.tree_util.tree_map(nu_addition, grads) + return update_moment(updates, nu, b2, 1) + + def init_mu(grads, params, mu, nu): + del mu + return jax.tree_util.tree_map(mu_addition, grads, params, nu) + + def update_mu(grads, params, mu, nu): + updates = jax.tree_util.tree_map(mu_addition, grads, params, nu) + return jax.tree_util.tree_map(lambda m, u: b1 * m + u, mu, updates) + + # Second moment + def update_fn(updates, state, params): + count_inc = numerics.safe_int32_increment(state.count) + + nu = jax.lax.cond(count_inc == 1, init_nu, update_nu, updates, state.nu) + + mu = jax.lax.cond(count_inc == 1, init_mu, update_mu, updates, params, + state.mu, nu) + + mu = utils.cast_tree(mu, mu_dtype) + updates = mu + return updates, ScaleByNovogradState(count=count_inc, mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +def scale_by_optimistic_gradient(alpha: float = 1.0, + beta: float = 1.0 + ) -> base.GradientTransformation: + """Compute generalized optimistic gradients. + + References: + [Mokhtari et al, 2019](https://arxiv.org/abs/1901.08511v2) + + Args: + alpha: Coefficient for generalized optimistic gradient descent. + beta: Coefficient for negative momentum. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + prev_grads = jax.tree_util.tree_map(jnp.zeros_like, params) + return TraceState(trace=prev_grads) + + def update_fn(updates, state, params=None): + del params + + new_updates = jax.tree_util.tree_map( + lambda grad_t, grad_tm1: (alpha + beta) * grad_t - beta * grad_tm1, + updates, state.trace) + return new_updates, TraceState(trace=updates) + + return base.GradientTransformation(init_fn, update_fn) + + +# TODO(b/183800387): remove legacy aliases. +# These legacy aliases are here for checkpoint compatibility +# To be removed once checkpoints have updated. +_safe_int32_increment = numerics.safe_int32_increment +safe_int32_increment = numerics.safe_int32_increment +AdditiveWeightDecayState = AddDecayedWeightsState +additive_weight_decay = add_decayed_weights +ClipState = clipping.ClipState +ClipByGlobalNormState = clipping.ClipByGlobalNormState diff --git a/optax_add_eve/_src/transform_test.py b/optax_add_eve/_src/transform_test.py new file mode 100644 index 00000000..8218c2d9 --- /dev/null +++ b/optax_add_eve/_src/transform_test.py @@ -0,0 +1,305 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +"""Tests for `transform.py`.""" + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +import jax +import jax.numpy as jnp +import numpy as np + +from optax_add_eve._src import alias +from optax_add_eve._src import combine +from optax_add_eve._src import transform +from optax_add_eve._src import update + +STEPS = 50 +LR = 1e-2 + + +class TransformTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.])) + self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.])) + + @chex.all_variants + @parameterized.named_parameters([ + ('adam', transform.scale_by_adam), + ('adamax', transform.scale_by_adamax), + ('rmsprop', transform.scale_by_rms), + ('stddev', transform.scale_by_stddev), + ('trust_ratio', transform.scale_by_trust_ratio), + ('param_block_norm', transform.scale_by_param_block_norm), + ('param_block_rms', transform.scale_by_param_block_rms), + ]) + def test_scalers(self, scaler_constr): + params = self.init_params + + scaler = scaler_constr() + init_fn = self.variant(scaler.init) + transform_fn = self.variant(scaler.update) + + state = init_fn(params) + chex.assert_tree_all_finite(state) + + updates, state = transform_fn(self.per_step_updates, state, params) + chex.assert_tree_all_finite((params, updates, state)) + jax.tree_util.tree_map( + lambda *args: chex.assert_equal_shape(args), params, updates) + + @chex.all_variants + def test_add_decayed_weights(self): + # Define a transform that add decayed weights. + # We can define a mask either as a pytree, or as a function that + # returns the pytree. Below we define the pytree directly. + mask = (True, dict(a=True, b=False)) + tx = transform.add_decayed_weights(0.1, mask=mask) + # Define input updates and weights. + updates = ( + jnp.zeros((2,), dtype=jnp.float32), + dict( + a=jnp.zeros((2,), dtype=jnp.float32), + b=jnp.zeros((2,), dtype=jnp.float32),)) + weights = ( + jnp.ones((2,), dtype=jnp.float32), + dict( + a=jnp.ones((2,), dtype=jnp.float32), + b=jnp.ones((2,), dtype=jnp.float32),)) + # This mask means that we will add decayed weights to the first two + # terms in the input updates, but not to the last element. + expected_tx_updates = ( + 0.1*jnp.ones((2,), dtype=jnp.float32), + dict( + a=0.1*jnp.ones((2,), dtype=jnp.float32), + b=jnp.zeros((2,), dtype=jnp.float32),)) + # Apply transform + state = tx.init(weights) + transform_fn = self.variant(tx.update) + new_updates, _ = transform_fn(updates, state, weights) + # Assert output as expected. 
+ chex.assert_trees_all_close(new_updates, expected_tx_updates) + + @chex.all_variants + def test_ema(self): + values = jnp.array([5.0, 7.0]) + decay = 0.9 + d = decay + + ema = transform.ema(decay=decay, debias=False) + state = ema.init(values[0]) # init to zeroes + + transform_fn = self.variant(ema.update) + mean, state = transform_fn(values[0], state) + np.testing.assert_allclose(mean, (1-d) * values[0], atol=1e-4) + + mean, state = transform_fn(values[1], state) + np.testing.assert_allclose( + mean, + (1 - d) * (values[1] + d * values[0]), atol=1e-2) + + @chex.all_variants + def test_ema_debias(self): + values = jnp.array([5.0, 7.0]) + decay = 0.9 + d = decay + + ema = transform.ema(decay=decay) + state = ema.init(values[0]) + + transform_fn = self.variant(ema.update) + mean, state = transform_fn(values[0], state) + np.testing.assert_allclose(mean, values[0], atol=1e-4) + + mean, state = transform_fn(values[1], state) + np.testing.assert_allclose( + mean, + ((1 - d) * values[1] + d * (1 - d) * values[0]) / (1 - d**2), + atol=1e-2) + # The state must not be debiased. + np.testing.assert_allclose( + state.ema, + (1 - d) * values[1] + d * (1 - d) * values[0], + atol=1e-2) + + @chex.all_variants + def test_update_infinity_moment(self): + values = jnp.array([5.0, 7.0]) + decay = 0.9 + d = decay + + transform_fn = self.variant(transform.update_infinity_moment) + + # identity if updating with itself (and positive decay) + np.testing.assert_allclose( + transform_fn(values, values, decay=d, eps=0.), + values, + atol=1e-4 + ) + # return (decayed) max when updating with zeros + np.testing.assert_allclose( + transform_fn(jnp.zeros_like(values), values, decay=d, eps=0.), + d * values, + atol=1e-4 + ) + # infinity norm takes absolute values + np.testing.assert_allclose( + transform_fn(-values, jnp.zeros_like(values), decay=d, eps=0.), + values, + atol=1e-4 + ) + # return at least `eps` + np.testing.assert_allclose( + transform_fn(jnp.zeros_like(values), jnp.zeros_like(values), + decay=d, eps=1e-2), + jnp.ones_like(values) * 1e-2, + atol=1e-4 + ) + + @chex.all_variants + def test_apply_every(self): + # The frequency of the application of sgd + k = 4 + zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.])) + + # optax sgd + optax_sgd_params = self.init_params + sgd = alias.sgd(LR, 0.0) + state_sgd = sgd.init(optax_sgd_params) + + # optax sgd plus apply every + optax_sgd_apply_every_params = self.init_params + sgd_apply_every = combine.chain( + transform.apply_every(k=k), + transform.trace(decay=0, nesterov=False), + transform.scale(-LR)) + state_sgd_apply_every = sgd_apply_every.init(optax_sgd_apply_every_params) + transform_fn = self.variant(sgd_apply_every.update) + + for i in range(STEPS): + # Apply a step of sgd + updates_sgd, state_sgd = sgd.update(self.per_step_updates, state_sgd) + optax_sgd_params = update.apply_updates(optax_sgd_params, updates_sgd) + + # Apply a step of sgd_apply_every + updates_sgd_apply_every, state_sgd_apply_every = transform_fn( + self.per_step_updates, state_sgd_apply_every) + optax_sgd_apply_every_params = update.apply_updates( + optax_sgd_apply_every_params, updates_sgd_apply_every) + + # Every k steps, check equivalence. + if i % k == k-1: + chex.assert_trees_all_close( + optax_sgd_apply_every_params, optax_sgd_params, + atol=1e-6, rtol=1e-5) + # Otherwise, check update is zero. 
+ else: + chex.assert_trees_all_close( + updates_sgd_apply_every, zero_update, atol=0.0, rtol=0.0) + + def test_scale(self): + updates = self.per_step_updates + for i in range(1, STEPS + 1): + factor = 0.1 ** i + rescaler = transform.scale(factor) + # Apply rescaling. + scaled_updates, _ = rescaler.update(updates, None) + # Manually scale updates. + def rescale(t): + return t * factor # pylint:disable=cell-var-from-loop + manual_updates = jax.tree_util.tree_map(rescale, updates) + # Check the rescaled updates match. + chex.assert_trees_all_close(scaled_updates, manual_updates) + + @parameterized.named_parameters([ + ('1d', [1.0, 2.0], [1.0, 2.0]), + ('2d', [[1.0, 2.0], [3.0, 4.0]], [[-0.5, 0.5], [-0.5, 0.5]]), + ('3d', [[[1., 2.], [3., 4.]], + [[5., 6.], [7., 8.]]], [[[-1.5, -0.5], [0.5, 1.5]], + [[-1.5, -0.5], [0.5, 1.5]]]), + ]) + def test_centralize(self, inputs, outputs): + inputs = jnp.asarray(inputs) + outputs = jnp.asarray(outputs) + centralizer = transform.centralize() + centralized_inputs, _ = centralizer.update(inputs, None) + chex.assert_trees_all_close(centralized_inputs, outputs) + + @chex.all_variants + def test_add_noise_has_correct_variance_scaling(self): + # Prepare to compare noise with a rescaled unit-variance substitute. + eta = 0.3 + gamma = 0.55 + seed = 314 + noise = transform.add_noise(eta, gamma, seed) + noise_unit = transform.add_noise(1.0, 0.0, seed) + + params = self.init_params + state = noise.init(params) + state_unit = noise_unit.init(params) + + # Check the noise itself by adding it to zeros. + updates = jax.tree_util.tree_map(jnp.zeros_like, params) + + for i in range(1, STEPS + 1): + updates_i, state = self.variant(noise.update)(updates, state) + updates_i_unit, state_unit = noise_unit.update(updates, state_unit) + + scale = jnp.sqrt(eta / i**gamma) + + updates_i_rescaled = jax.tree_util.tree_map( + lambda g, s=scale: g * s, updates_i_unit) + + chex.assert_trees_all_close(updates_i, updates_i_rescaled, rtol=1e-4) + + def test_scale_by_optimistic_gradient(self): + + def f(params: jnp.ndarray) -> jnp.ndarray: + return params['x'] ** 2 + + initial_params = { + 'x': jnp.array(2.0) + } + + og = transform.scale_by_optimistic_gradient() + og_state = og.init(initial_params) + # Provide some arbitrary previous gradient. + og_state.trace['x'] = 1.5 + + g = jax.grad(f)(initial_params) + og_true = 2 * g['x'] - og_state.trace['x'] + og, og_state = og.update(g, og_state) + + # Compare transformation output with manually computed optimistic gradient. + chex.assert_trees_all_close(og_true, og['x']) + + @chex.all_variants + def test_bias_correction_bf16(self): + bias_correction_fn = self.variant(transform.bias_correction) + m = jnp.logspace(-10, 10, num=21, dtype=jnp.bfloat16) # 1e-10 ... 1e10 + for decay in (0.9, 0.99, 0.999, 0.9995): + for count in (1, 10, 100, 1000): + chex.assert_tree_all_finite( + bias_correction_fn(m, decay, count), + custom_message=f'failed with decay={decay}, count={count}') + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/update.py b/optax_add_eve/_src/update.py new file mode 100644 index 00000000..ad88eee8 --- /dev/null +++ b/optax_add_eve/_src/update.py @@ -0,0 +1,103 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Apply transformed gradient updates to parameters.""" + +import chex +import jax +import jax.numpy as jnp + +from optax_add_eve._src import base + + +def apply_updates(params: base.Params, updates: base.Updates) -> base.Params: + """Applies an update to the corresponding parameters. + + This is a utility functions that applies an update to a set of parameters, and + then returns the updated parameters to the caller. As an example, the update + may be a gradient transformed by a sequence of`GradientTransformations`. This + function is exposed for convenience, but it just adds updates and parameters; + you may also apply updates to parameters manually, using `tree_map` + (e.g. if you want to manipulate updates in custom ways before applying them). + + Args: + params: a tree of parameters. + updates: a tree of updates, the tree structure and the shape of the leaf + nodes must match that of `params`. + + Returns: + Updated parameters, with same structure, shape and type as `params`. + """ + return jax.tree_util.tree_map( + lambda p, u: jnp.asarray(p + u).astype(jnp.asarray(p).dtype), + params, updates) + + +def incremental_update( + new_tensors: base.Params, + old_tensors: base.Params, + step_size: chex.Numeric +) -> base.Params: + """Incrementally update parameters via polyak averaging. + + Polyak averaging tracks an (exponential moving) average of the past + parameters of a model, for use at test/evaluation time. + + References: + [Polyak et al, 1991](https://epubs.siam.org/doi/10.1137/0330046) + + Args: + new_tensors: the latest value of the tensors. + old_tensors: a moving average of the values of the tensors. + step_size: the step_size used to update the polyak average on each step. + + Returns: + an updated moving average `step_size*new+(1-step_size)*old` of the params. + """ + return jax.tree_util.tree_map( + lambda new, old: step_size * new + (1.0 - step_size) * old, + new_tensors, old_tensors) + + +def periodic_update( + new_tensors: base.Params, + old_tensors: base.Params, + steps: chex.Array, + update_period: int +) -> base.Params: + """Periodically update all parameters with new values. + + A slow copy of a model's parameters, updated every K actual updates, can be + used to implement forms of self-supervision (in supervised learning), or to + stabilise temporal difference learning updates (in reinforcement learning). + + References: + [Grill et al., 2020](https://arxiv.org/abs/2006.07733) + [Mnih et al., 2015](https://arxiv.org/abs/1312.5602) + + Args: + new_tensors: the latest value of the tensors. + old_tensors: a slow copy of the model's parameters. + steps: number of update steps on the "online" network. + update_period: every how many steps to update the "target" network. + + Returns: + a slow copy of the model's parameters, updated every `update_period` steps. 
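+
+  For illustration, a minimal sketch of keeping a slow "target" copy of some
+  "online" parameters (the names below are hypothetical and not part of this
+  module)::
+
+    target_params = periodic_update(
+        online_params, target_params, steps=step_count, update_period=100)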
+ """ + return jax.lax.cond( + jnp.mod(steps, update_period) == 0, + lambda _: new_tensors, + lambda _: old_tensors, + None) + diff --git a/optax_add_eve/_src/update_test.py b/optax_add_eve/_src/update_test.py new file mode 100644 index 00000000..73f57128 --- /dev/null +++ b/optax_add_eve/_src/update_test.py @@ -0,0 +1,83 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `update.py`.""" + +from absl.testing import absltest + +import chex +import jax +import jax.numpy as jnp + +from optax_add_eve._src import update + + +class UpdateTest(chex.TestCase): + + @chex.all_variants + def test_apply_updates(self): + params = ({'a': jnp.ones((3, 2))}, jnp.ones((1,))) + grads = jax.tree_util.tree_map(lambda t: 2 * t, params) + exp_params = jax.tree_util.tree_map(lambda t: 3 * t, params) + new_params = self.variant(update.apply_updates)(params, grads) + + chex.assert_trees_all_close( + exp_params, new_params, atol=1e-10, rtol=1e-5) + + @chex.all_variants + def test_apply_updates_mixed_precision(self): + params = ( + {'a': jnp.ones((3, 2), dtype=jnp.bfloat16)}, + jnp.ones((1,), dtype=jnp.bfloat16)) + grads = jax.tree_util.tree_map( + lambda t: (2 * t).astype(jnp.float32), params) + new_params = self.variant(update.apply_updates)(params, grads) + + for leaf in jax.tree_util.tree_leaves(new_params): + assert leaf.dtype == jnp.bfloat16 + + @chex.all_variants + def test_incremental_update(self): + params_1 = ({'a': jnp.ones((3, 2))}, jnp.ones((1,))) + params_2 = jax.tree_util.tree_map(lambda t: 2 * t, params_1) + exp_params = jax.tree_util.tree_map(lambda t: 1.5 * t, params_1) + new_params = self.variant( + update.incremental_update)(params_2, params_1, 0.5) + + chex.assert_trees_all_close( + exp_params, new_params, atol=1e-10, rtol=1e-5) + + @chex.all_variants + def test_periodic_update(self): + params_1 = ({'a': jnp.ones((3, 2))}, jnp.ones((1,))) + params_2 = jax.tree_util.tree_map(lambda t: 2 * t, params_1) + + update_period = 5 + update_fn = self.variant(update.periodic_update) + + for j in range(3): + for i in range(1, update_period): + new_params = update_fn( + params_2, params_1, j*update_period+i, update_period) + chex.assert_trees_all_close( + params_1, new_params, atol=1e-10, rtol=1e-5) + + new_params = update_fn( + params_2, params_1, (j+1)*update_period, update_period) + chex.assert_trees_all_close( + params_2, new_params, atol=1e-10, rtol=1e-5) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/utils.py b/optax_add_eve/_src/utils.py new file mode 100644 index 00000000..a61febff --- /dev/null +++ b/optax_add_eve/_src/utils.py @@ -0,0 +1,152 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for testing.""" + +from typing import Optional, Tuple, Sequence + +import chex +import jax +import jax.numpy as jnp +import jax.scipy.stats.norm as multivariate_normal + +from optax_add_eve._src import linear_algebra +from optax_add_eve._src import numerics + + +def tile_second_to_last_dim(a: chex.Array) -> chex.Array: + ones = jnp.ones_like(a) + a = jnp.expand_dims(a, axis=-1) + return jnp.expand_dims(ones, axis=-2) * a + + +def canonicalize_dtype( + dtype: Optional[chex.ArrayDType]) -> Optional[chex.ArrayDType]: + """Canonicalise a dtype, skip if None.""" + if dtype is not None: + return jax.dtypes.canonicalize_dtype(dtype) + return dtype + + +def cast_tree(tree: chex.ArrayTree, + dtype: Optional[chex.ArrayDType]) -> chex.ArrayTree: + """Cast tree to given dtype, skip if None.""" + if dtype is not None: + return jax.tree_util.tree_map(lambda t: t.astype(dtype), tree) + else: + return tree + + +def set_diags(a: chex.Array, new_diags: chex.Array) -> chex.Array: + """Set the diagonals of every DxD matrix in an input of shape NxDxD. + + Args: + a: rank 3, tensor NxDxD. + new_diags: NxD matrix, the new diagonals of each DxD matrix. + + Returns: + NxDxD tensor, with the same contents as `a` but with the diagonal + changed to `new_diags`. + """ + n, d, d1 = a.shape + assert d == d1 + + indices1 = jnp.repeat(jnp.arange(n), d) + indices2 = jnp.tile(jnp.arange(d), n) + indices3 = indices2 + + # Use numpy array setting + a = a.at[indices1, indices2, indices3].set(new_diags.flatten()) + return a + + +class MultiNormalDiagFromLogScale(): + """MultiNormalDiag which directly exposes its input parameters.""" + + def __init__(self, loc: chex.Array, log_scale: chex.Array): + self._log_scale = log_scale + self._scale = jnp.exp(log_scale) + self._mean = loc + self._param_shape = jax.lax.broadcast_shapes( + self._mean.shape, self._scale.shape) + + def sample(self, shape: Sequence[int], + seed: chex.PRNGKey) -> chex.Array: + sample_shape = tuple(shape) + self._param_shape + return jax.random.normal( + seed, shape=sample_shape) * self._scale + self._mean + + def log_prob(self, x: chex.Array) -> chex.Array: + log_prob = multivariate_normal.logpdf(x, loc=self._mean, scale=self._scale) + # Sum over parameter axes. + sum_axis = [-(i + 1) for i in range(len(self._param_shape))] + return jnp.sum(log_prob, axis=sum_axis) + + @property + def log_scale(self) -> chex.Array: + return self._log_scale + + @property + def params(self) -> Sequence[chex.Array]: + return [self._mean, self._log_scale] + + +def multi_normal(loc: chex.Array, + log_scale: chex.Array) -> MultiNormalDiagFromLogScale: + return MultiNormalDiagFromLogScale(loc=loc, log_scale=log_scale) + + +@jax.custom_vjp +def _scale_gradient(inputs: chex.ArrayTree, scale: float) -> chex.ArrayTree: + """Internal gradient scaling implementation.""" + del scale # Only used for the backward pass defined in _scale_gradient_bwd. 
+ return inputs + + +def _scale_gradient_fwd(inputs: chex.ArrayTree, + scale: float) -> Tuple[chex.ArrayTree, float]: + return _scale_gradient(inputs, scale), scale + + +def _scale_gradient_bwd(scale: float, + g: chex.ArrayTree) -> Tuple[chex.ArrayTree, None]: + return (jax.tree_util.tree_map(lambda g_: g_ * scale, g), None) + + +_scale_gradient.defvjp(_scale_gradient_fwd, _scale_gradient_bwd) + + +def scale_gradient(inputs: chex.ArrayTree, scale: float) -> chex.ArrayTree: + """Scales gradients for the backwards pass. + + Args: + inputs: A nested array. + scale: The scale factor for the gradient on the backwards pass. + + Returns: + An array of the same structure as `inputs`, with scaled backward gradient. + """ + # Special case scales of 1. and 0. for more efficiency. + if scale == 1.: + return inputs + elif scale == 0.: + return jax.lax.stop_gradient(inputs) + else: + return _scale_gradient(inputs, scale) + + +# TODO(b/183800387): remove legacy aliases. +safe_norm = numerics.safe_norm +safe_int32_increment = numerics.safe_int32_increment +global_norm = linear_algebra.global_norm diff --git a/optax_add_eve/_src/utils_test.py b/optax_add_eve/_src/utils_test.py new file mode 100644 index 00000000..03f13d3d --- /dev/null +++ b/optax_add_eve/_src/utils_test.py @@ -0,0 +1,65 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for `utils.py`.""" + +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized + +import jax + +from optax_add_eve._src import utils + + +class ScaleGradientTest(parameterized.TestCase): + + @parameterized.product(inputs=[-1., 0., 1.], scale=[-0.5, 0., 0.5, 1., 2.]) + @mock.patch.object(jax.lax, 'stop_gradient', wraps=jax.lax.stop_gradient) + def test_scale_gradient(self, mock_sg, inputs, scale): + + def fn(inputs): + outputs = utils.scale_gradient(inputs, scale) + return outputs ** 2 + + grad = jax.grad(fn) + self.assertEqual(grad(inputs), 2 * inputs * scale) + if scale == 0.: + mock_sg.assert_called_once_with(inputs) + else: + self.assertFalse(mock_sg.called) + self.assertEqual(fn(inputs), inputs ** 2) + + @parameterized.product(scale=[-0.5, 0., 0.5, 1., 2.]) + def test_scale_gradient_pytree(self, scale): + + def fn(inputs): + outputs = utils.scale_gradient(inputs, scale) + outputs = jax.tree_util.tree_map(lambda x: x ** 2, outputs) + return sum(jax.tree_util.tree_leaves(outputs)) + + inputs = dict(a=-1., b=dict(c=(2.,), d=0.)) + + grad = jax.grad(fn) + grads = grad(inputs) + jax.tree_util.tree_map( + lambda i, g: self.assertEqual(g, 2 * i * scale), inputs, grads) + self.assertEqual( + fn(inputs), + sum(jax.tree_util.tree_leaves( + jax.tree_util.tree_map(lambda x: x**2, inputs)))) + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/_src/wrappers.py b/optax_add_eve/_src/wrappers.py new file mode 100644 index 00000000..3ae66026 --- /dev/null +++ b/optax_add_eve/_src/wrappers.py @@ -0,0 +1,547 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Transformation wrappers.""" + +import functools +from typing import Any, Callable, NamedTuple, Optional, Tuple, Union + +import chex +import jax +from jax import lax +import jax.numpy as jnp +from jax.tree_util import tree_flatten +from jax.tree_util import tree_map +from jax.tree_util import tree_unflatten +import numpy as np +from optax_add_eve._src import base +from optax_add_eve._src import numerics +import typing_extensions + +Array = jnp.ndarray + + +def flatten( + inner: base.GradientTransformation +) -> base.GradientTransformation: + """Flattens parameters and gradients for init and update of inner transform. + + This can reduce the overhead of performing many calculations on lots of small + variables, at the cost of slightly increased memory usage. + + Args: + inner: Inner transformation to flatten inputs for. + + Returns: + New GradientTransformation. 
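+
+  For illustration, a minimal usage sketch (the parameter tree and the `adam`
+  alias below are assumptions for the example, not part of this module)::
+
+    params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}
+    grads = jax.tree_util.tree_map(jnp.ones_like, params)
+    opt = flatten(optax.adam(1e-3))
+    state = opt.init(params)
+    updates, state = opt.update(grads, state, params)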
+ """ + + def _flatten(params): + """Flattens and concatenates all tensors in params to a single vector.""" + params, _ = tree_flatten(params) + return jnp.concatenate([jnp.reshape(param, [-1]) for param in params]) + + def _unflatten(updates, flat): + """Extracts tensors from flat, using the structure and shapes of params.""" + updates_flat, treedef = tree_flatten(updates) + offsets = [] + for update in updates_flat: + size = np.prod(update.shape) + if offsets: + offsets.append(size + offsets[-1]) + else: + offsets.append(size) + del offsets[-1] + flat_split = jnp.split(flat, offsets) + reshaped = [ + jnp.reshape(flat_update, update.shape) + for flat_update, update in zip(flat_split, updates_flat) + ] + return tree_unflatten(treedef, reshaped) + + def init_fn(params): + flat = _flatten(params) + return inner.init(flat) + + def update_fn(updates, state, params=None): + if params is not None: + params = _flatten(params) + updates_flat, state = inner.update(_flatten(updates), state, params) + updates = _unflatten(updates, updates_flat) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +class ApplyIfFiniteState(NamedTuple): + """State of the `GradientTransformation` returned by `apply_if_finite`. + + Fields: + notfinite_count: Number of consecutive gradient updates containing an Inf or + a NaN. This number is reset to 0 whenever a gradient update without an Inf + or a NaN is done. + last_finite: Whether or not the last gradient update contained an Inf of a + NaN. + total_notfinite: Total number of gradient updates containing an Inf or + a NaN since this optimizer was initialised. This number is never reset. + inner_state: The state of the inner `GradientTransformation`. + """ + notfinite_count: jnp.array + last_finite: jnp.array + total_notfinite: jnp.array + inner_state: Any + + +def apply_if_finite( + inner: base.GradientTransformation, + max_consecutive_errors: int +) -> base.GradientTransformation: + """A function that wraps an optimizer to make it robust to a few NaNs or Infs. + + The purpose of this function is to prevent any optimization to happen if the + gradients contain NaNs or Infs. That is, when a NaN of Inf is detected in the + gradients, the wrapped optimizer ignores that gradient update. If the NaNs or + Infs persist after a given number of updates, the wrapped optimizer gives up + and accepts the update. + + Args: + inner: Inner transformation to be wrapped. + max_consecutive_errors: Maximum number of consecutive gradient updates + containing NaNs of Infs that the wrapped optimizer will ignore. After + that many ignored updates, the optimizer will give up and accept. + + Returns: + New GradientTransformation. 
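+
+  For illustration, a minimal usage sketch (`params` and `grads` are assumed
+  to be existing pytrees; `optax.adam` stands in for any inner optimizer)::
+
+    opt = apply_if_finite(optax.adam(1e-3), max_consecutive_errors=5)
+    opt_state = opt.init(params)
+    updates, opt_state = opt.update(grads, opt_state, params)
+    params = optax.apply_updates(params, updates)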
+ """ + + def init(params): + return ApplyIfFiniteState( + notfinite_count=jnp.zeros([], jnp.int32), + last_finite=jnp.array(True, jnp.bool_), + total_notfinite=jnp.zeros([], jnp.int32), + inner_state=inner.init(params)) + + def update(updates, state, params=None): + inner_state = state.inner_state + flat_updates = tree_flatten(updates)[0] + isfinite = jnp.all( + jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates])) + notfinite_count = jnp.where( + isfinite, jnp.zeros([], jnp.int32), + numerics.safe_int32_increment(state.notfinite_count)) + + def do_update(_): + return inner.update(updates, inner_state, params) + def reject_update(_): + return (tree_map(jnp.zeros_like, updates), inner_state) + + updates, new_inner_state = lax.cond( + jnp.logical_or(isfinite, notfinite_count > max_consecutive_errors), + do_update, reject_update, operand=None) + + return updates, ApplyIfFiniteState( + notfinite_count=notfinite_count, + last_finite=isfinite, + total_notfinite=jnp.where( + isfinite, state.total_notfinite, + numerics.safe_int32_increment(state.total_notfinite)), + inner_state=new_inner_state) + + return base.GradientTransformation(init=init, update=update) + + +def _zeros_tree_like(inp_tree): + return jax.tree_util.tree_map(jnp.zeros_like, inp_tree) + + +class MultiStepsState(NamedTuple): + """State of the `GradientTransformation` returned by `MultiSteps`. + + Fields: + mini_step: current mini-step counter. At an update, this either increases by + 1 or is reset to 0. + gradient_step: gradient step counter. This only increases after enough + mini-steps have been accumulated. + inner_opt_state: the state of the wrapped otpimiser. + acc_grads: accumulated gradients over multiple mini-steps. + skip_state: an arbitrarily nested tree of arrays. This is only + relevant when passing a `should_skip_update_fn` to `MultiSteps`. This + structure will then contain values for debugging and or monitoring. The + actual structure will vary depending on the choice of + `ShouldSkipUpdateFunction`. + """ + mini_step: Array + gradient_step: Array + inner_opt_state: Any + acc_grads: Any + skip_state: chex.ArrayTree = () + + +class ShouldSkipUpdateFunction(typing_extensions.Protocol): + + def __call__(self, updates: base.Updates, gradient_step: Array, + params: Optional[base.Params]) -> Tuple[Array, chex.ArrayTree]: + """Returns true to indicate that updates should be skipped in a multi-step. + + Args: + updates: The updates that the gradient transformation has proposed + to apply + gradient_step: The current gradient step (see + `MultiStepsState.gradient_step`). This can be used for example to reject + large gradients with an annealed maximum allowed gradient norm. + params: If known, the current parameter tree of the function being + transformed. + Returns: + A tuple: + * First element is an array with a single bool indicating whether or not + the updates should be applied. + * Second element is an arbitrarily nested structure of arrays that will be + stored in `MultiStepsState.skip_state`. The structure will vary from + function to function. Debugging info, or values to monitor, can be put + in this structure. + """ + + +def skip_not_finite( + updates: base.Updates, gradient_step: Array, + params: Optional[base.Params]) -> Tuple[Array, chex.ArrayTree]: + """Returns True iff any of the `updates` contains an inf or a NaN. + + Args: + updates: see `ShouldSkipUpdateFunction`. + gradient_step: see `ShouldSkipUpdateFunction`. + params: see `ShouldSkipUpdateFunction`. 
+ + Returns: + A tuple: + * First element is a scalar array of type bool. + * Second element is a dictionary with keys: + - `should_skip`: True iff `updates` contains an inf or a NaN. + - `num_not_finite`: total number of inf and NaN found in `updates`. + """ + del gradient_step, params + all_is_finite = [jnp.sum(jnp.logical_not(jnp.isfinite(p))) + for p in jax.tree_util.tree_leaves(updates)] + num_not_finite = jnp.sum(jnp.array(all_is_finite)) + should_skip = num_not_finite > 0 + return should_skip, dict(should_skip=should_skip, + num_not_finite=num_not_finite) + + +def skip_large_updates(updates: base.Updates, + gradient_step: Array, + params: Optional[base.Params], + max_squared_norm: float) -> Tuple[Array, chex.ArrayTree]: + """Returns True if the global norm square of `updates` is small enough. + + Args: + updates: see `ShouldSkipUpdateFunction`. + gradient_step: see `ShouldSkipUpdateFunction`. + params: see `ShouldSkipUpdateFunction`. + max_squared_norm: only updates with a norm square strictly less than this + value will be accepted. + + Returns: + A tuple: + * First element is a scalar array of type bool. + * Second element is a dictionary with keys: + - `should_skip`: True iff square norm of `updates` is larger or equal than + `max_squared_norm`. + - `norm_squared`: overall norm square of the `updates`. + """ + del gradient_step, params + norm_sq = jnp.sum( + jnp.array([jnp.sum(p**2) for p in jax.tree_util.tree_leaves(updates)])) + # This will also return True if `norm_sq` is NaN. + should_skip = jnp.logical_not(norm_sq < max_squared_norm) + return should_skip, dict(should_skip=should_skip, norm_squared=norm_sq) + + +class MultiSteps: + """An optimizer wrapper to accumulate gradients over multiple steps. + + This wrapper collects together the updates passed to its `update` function + over consecutive steps until a given number of scheduled steps is reached. + In each of these intermediate steps, the returned value from the optimizer is + a tree of zeros of the same shape of the updates passed as input. + + Once the scheduled number of intermediate 'mini-steps' has been reached, the + gradients accumulated to the current time will be passed to the wrapped + optimizer's update function, (with the inner optimizer's state being updated + appropriately) and then returned to the caller. The wrapper's accumulated + gradients are then set back to zero and the process starts again. + + The number of mini-steps per gradient update is controlled by a function, and + it can vary over training. This offers a means of varying batch size over + training. + """ + + def __init__( + self, + opt: base.GradientTransformation, + every_k_schedule: Union[int, Callable[[Array], Array]], + use_grad_mean: bool = True, + should_skip_update_fn: Optional[ShouldSkipUpdateFunction] = None): + """Initialiser. + + Args: + opt: the wrapped optimizer. + every_k_schedule: an int or f a function. + * As a function, it returns how many mini-steps should be accumulated + in a single gradient step. Its only argument is the current + gradient step count. By varying the returned value, users can vary the + overall training batch size. + * If an `int`, this is the constant number of mini-steps per gradient + update. + use_grad_mean: if `True` (the default), gradients accumulated over + multiple mini-steps are averaged. Otherwise, they are summed. + should_skip_update_fn: if provided, this function is used to decide when + to accept or reject the updates from a mini-step. 
When a mini-step is + rejected, the inner state of `MultiSteps` is not updated. In other + words, it is as if this mini-step never happened. For example: + * to ignore updates containing inf or NaN, do + `should_skip_update_fn=skip_not_finite`; + * to ignore updates with a norm square larger then 42, do + `should_skip_update_fn=functools.partial(skip_large_updates, + max_norm_sq=42.)`. + Note that the optimizer's state `MultiStepsState` contains a field + `skip_state` in which debugging and monitoring information returned + by `should_skip_update_fn` is written. + """ + self._opt = opt + if isinstance(every_k_schedule, int): + self._every_k_schedule = lambda step: every_k_schedule + else: + self._every_k_schedule = every_k_schedule + self._use_grad_mean = use_grad_mean + + if self._use_grad_mean: + # Use Welford algorithm for numerically stable aggregation of mean. + self._acc_update = ( + lambda grad, acc, *, n_acc: acc + (grad - acc) / (n_acc + 1)) + else: + self._acc_update = lambda grad, acc, *, n_acc: grad + acc + + if should_skip_update_fn is None: + + def should_skip_update_fn(*unused_args, **unused_kwargs): + return jnp.array(False, dtype=jnp.bool_), () + + self._should_skip_update_fn = should_skip_update_fn + + @property + def inner_opt(self): + return self._opt + + def init(self, params: Any) -> MultiStepsState: + """Builds and returns initial `MultiStepsState`.""" + updates = _zeros_tree_like(params) + gradient_step = jnp.zeros([], dtype=jnp.int32) + _, skip_state = self._should_skip_update_fn(updates, gradient_step, params) + init_state = MultiStepsState( + mini_step=jnp.zeros([], dtype=jnp.int32), + gradient_step=gradient_step, + inner_opt_state=self._opt.init(params), + acc_grads=updates, + skip_state=skip_state) + return init_state + + def update(self, + updates: base.Updates, + state: MultiStepsState, + params: Optional[base.Params] = None + ) -> Tuple[base.Updates, MultiStepsState]: + """Accumulates gradients and proposes non-zero updates every `k_steps`.""" + k_steps = self._every_k_schedule(state.gradient_step) + acc_grads = jax.tree_util.tree_map( + functools.partial(self._acc_update, n_acc=state.mini_step), + updates, state.acc_grads) + + should_skip_update, skip_state = self._should_skip_update_fn( + updates, state.gradient_step, params) + + def final_step(args): + del args + final_updates, new_inner_state = self._opt.update( + acc_grads, state.inner_opt_state, params=params) + new_state = MultiStepsState( + mini_step=jnp.zeros([], dtype=jnp.int32), + gradient_step=numerics.safe_int32_increment(state.gradient_step), + inner_opt_state=new_inner_state, + acc_grads=_zeros_tree_like(acc_grads), + skip_state=skip_state) + return final_updates, new_state + + def mid_step(args): + del args + updates_shape_dtype, _ = jax.eval_shape( + self._opt.update, acc_grads, state.inner_opt_state, params=params) + mid_updates = jax.tree_util.tree_map( + lambda sd: jnp.zeros(sd.shape, sd.dtype), updates_shape_dtype) + new_state = MultiStepsState( + mini_step=numerics.safe_int32_increment(state.mini_step), + gradient_step=state.gradient_step, + inner_opt_state=state.inner_opt_state, + acc_grads=acc_grads, + skip_state=skip_state) + return mid_updates, new_state + + new_updates, new_state = jax.lax.cond( + state.mini_step < k_steps - 1, (), mid_step, (), final_step) + + if (should_skip_update.dtype, should_skip_update.shape) != (jnp.bool_, ()): + raise ValueError( + 'The `should_skip_update_fn` function should return a boolean scalar ' + f'array, but it returned an array of dtype 
{should_skip_update.dtype}' + f' and shape {should_skip_update.shape}') + + multi_state_when_skip = MultiStepsState( + mini_step=state.mini_step, + gradient_step=state.gradient_step, + inner_opt_state=state.inner_opt_state, + acc_grads=state.acc_grads, + skip_state=skip_state) + zero_updates = jax.tree_util.tree_map(jnp.zeros_like, updates) + new_updates, new_state = jax.lax.cond( + should_skip_update, + (), lambda args: (zero_updates, multi_state_when_skip), + (), lambda args: (new_updates, new_state)) + + return new_updates, new_state + + def has_updated(self, state: MultiStepsState) -> Array: + return jnp.logical_and(state.mini_step == 0, state.gradient_step > 0) + + def gradient_transformation(self) -> base.GradientTransformation: + return base.GradientTransformation(init=self.init, update=self.update) + + +class MaskedState(NamedTuple): + """Maintains inner transform state for masked transformations.""" + inner_state: Any + + +class MaskedNode(NamedTuple): + """A node used to mask out unspecified parts of a tree. + + This node is ignored when mapping functions across the tree e.g. using + `jax.tree_util.tree_map` since it is a container without children. It can + therefore be used to mask out parts of a tree. + """ + + +def masked( + inner: base.GradientTransformation, + mask: Union[base.PyTree, Callable[[base.Params], base.PyTree]] +) -> base.GradientTransformation: + """Mask updates so only some are transformed, the rest are passed through. + + For example, it is common to skip weight decay for BatchNorm scale and all + bias parameters. In many networks, these are the only parameters with only + one dimension. So, you may create a mask function to mask these out as + follows:: + + mask_fn = lambda p: jax.tree_util.tree_map(lambda x: x.ndim != 1, p) + weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask_fn) + + You may alternatively create the mask pytree upfront:: + + mask = jax.tree_util.tree_map(lambda x: x.ndim != 1, params) + weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask) + + For the ``inner`` transform, state will only be stored for the parameters that + have a mask value of ``True``. + + Args: + inner: Inner transformation to mask. + mask: a PyTree with same structure as (or a prefix of) the params PyTree, or + a Callable that returns such a pytree given the params/updates. The leaves + should be booleans, ``True`` for leaves/subtrees you want to apply the + transformation to, and ``False`` for those you want to skip. The mask must + be static for the gradient transformation to be jit-compilable. + + Returns: + New GradientTransformation wrapping ``inner``. 
+ """ + def mask_pytree(pytree, mask_tree): + return tree_map(lambda m, p: p if m else MaskedNode(), mask_tree, pytree) + + def init_fn(params): + mask_tree = mask(params) if callable(mask) else mask + masked_params = mask_pytree(params, mask_tree) + return MaskedState(inner_state=inner.init(masked_params)) + + def update_fn(updates, state, params=None): + mask_tree = mask(updates) if callable(mask) else mask + masked_updates = mask_pytree(updates, mask_tree) + masked_params = None if params is None else mask_pytree(params, mask_tree) + + new_masked_updates, new_inner_state = inner.update( + masked_updates, state.inner_state, masked_params) + + new_updates = tree_map( + lambda m, new_u, old_u: new_u if m else old_u, + mask_tree, new_masked_updates, updates) + return new_updates, MaskedState(inner_state=new_inner_state) + + return base.GradientTransformation(init_fn, update_fn) + + +class MaybeUpdateState(NamedTuple): + """Maintains inner transform state and adds a step counter.""" + inner_state: Any + step: Array + + +def maybe_update( + inner: base.GradientTransformation, + should_update_fn: Callable[[Array], Array] +) -> base.GradientTransformation: + """Calls the inner update function only at certain steps. + + Creates a transformation wrapper which counts the number of times the `update` + function has been called. This counter is passed to the `should_update_fn` to + decide when to call the inner update function. + + When not calling the inner update function, the `updates` and the inner state + are left untouched and just passed through. The step counter is increased + regardless. + + Args: + inner: the inner transformation. + should_update_fn: this function takes in a step counter (array of shape [] + and dtype int32), and returns a boolean array of shape []. + + Returns: + An `optax.GradientTransformation`. + """ + + def init_fn(params): + return MaybeUpdateState( + inner_state=inner.init(params), step=jnp.zeros([], dtype=jnp.int32)) + + def update_fn(updates, state, params=None): + + def do_update(_): + return inner.update(updates, state.inner_state, params) + + def reject_update(_): + return updates, state.inner_state + + updates, new_inner_state = lax.cond( + should_update_fn(state.step), do_update, reject_update, operand=None) + return updates, MaybeUpdateState(new_inner_state, + numerics.safe_int32_increment(state.step)) + + return base.GradientTransformation(init_fn, update_fn) diff --git a/optax_add_eve/_src/wrappers_test.py b/optax_add_eve/_src/wrappers_test.py new file mode 100644 index 00000000..1bdfa95f --- /dev/null +++ b/optax_add_eve/_src/wrappers_test.py @@ -0,0 +1,623 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for `wrappers.py`.""" + +import copy + +from absl.testing import absltest +from absl.testing import parameterized + +import chex +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +from optax_add_eve._src import alias +from optax_add_eve._src import combine +from optax_add_eve._src import constrain +from optax_add_eve._src import transform +from optax_add_eve._src import update +from optax_add_eve._src import wrappers +import tree + + +def _build_sgd(): + return alias.sgd(1.) + + +def _build_stateful_sgd(): + # This SGD behaves like _build_sgd but also tests the optimizer state. The + # momentum is set to zero rather than None so that the momentum terms are + # calculated, but do not change the results. + return alias.sgd(1., momentum=0.) + + +class WrappersTest(parameterized.TestCase): + + def test_flatten(self): + def init_params(): + return (jnp.array([1., 2.]), jnp.array([3., 4.])) + + per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.])) + + # First calculate new params without flattening + optax_sgd_params = init_params() + sgd = alias.sgd(1e-2, 0.0) + state_sgd = sgd.init(optax_sgd_params) + updates_sgd, state_sgd = sgd.update(per_step_updates, state_sgd) + sgd_params_no_flatten = update.apply_updates(optax_sgd_params, updates_sgd) + + # And now calculate new params with flattening + optax_sgd_params = init_params() + sgd = wrappers.flatten(sgd) + state_sgd = sgd.init(optax_sgd_params) + updates_sgd, state_sgd = sgd.update(per_step_updates, state_sgd) + sgd_params_flatten = update.apply_updates(optax_sgd_params, updates_sgd) + + # Test that both give the same result + chex.assert_trees_all_close( + sgd_params_no_flatten, sgd_params_flatten, atol=1e-7, rtol=1e-7) + + @chex.variants(with_jit=True, without_jit=True, with_pmap=True) + @parameterized.named_parameters( + ('sgd', _build_sgd), + ('stateful_sgd', _build_stateful_sgd), + ) + def test_apply_if_finite(self, opt_builder): + one = jnp.ones([]) + nan = jnp.array(jnp.nan) + def fn(x): + return x * hk.get_parameter('p', [], init=hk.initializers.Constant(0.)) + + fn = hk.without_apply_rng(hk.transform(fn)) + params = fn.init(jax.random.PRNGKey(1905), one) + opt = wrappers.apply_if_finite(opt_builder(), 2) + state = opt.init(params) + grads_fn = jax.grad(self.variant(fn.apply)) + # Do one successful param update + grads = grads_fn(params, one) + updates, state = opt.update(grads, state, params) + params = update.apply_updates(params, updates) + # We know exactly what should be the value of params since we are + # effectively using sgd in all cases. 
+ self.assertEqual(-1., float(jax.tree_util.tree_flatten(params)[0][0])) + self.assertTrue(bool(state.last_finite)) + # Check 2 rejected param updates + for step in range(2): + grads = grads_fn(params, nan) + updates, state = opt.update(grads, state, params) + params = update.apply_updates(params, updates) + self.assertEqual(-1., float(jax.tree_util.tree_flatten(params)[0][0])) + self.assertFalse(bool(state.last_finite)) + self.assertEqual(step + 1, int(state.notfinite_count)) + # Next successful param update + grads = grads_fn(params, one) + updates, state = opt.update(grads, state, params) + params = update.apply_updates(params, updates) + self.assertEqual(-2., float(jax.tree_util.tree_flatten(params)[0][0])) + self.assertTrue(bool(state.last_finite)) + # Again 2 rejected param updates + for step in range(2): + grads = grads_fn(params, nan) + updates, state = opt.update(grads, state, params) + params = update.apply_updates(params, updates) + self.assertEqual(-2., float(jax.tree_util.tree_flatten(params)[0][0])) + self.assertFalse(bool(state.last_finite)) + self.assertEqual(step + 1, int(state.notfinite_count)) + # Next param update with NaN is accepted since we reached maximum + grads = grads_fn(params, nan) + updates, state = opt.update(grads, state, params) + params = update.apply_updates(params, updates) + self.assertTrue(bool(jnp.isnan(jax.tree_util.tree_flatten(params)[0][0]))) + self.assertEqual(5, int(state.total_notfinite)) + + def test_apply_if_finite_pmap(self): + # Unlike in `test_apply_if_finite`: + # * pmap is applied to the gradient computation and the optimisation; + # * the NaNs are caused inside the function and do not come from the inputs. + half = jnp.ones([1]) / 2. + two = jnp.ones([1]) * 2. # Causes a NaN in arctanh + def fn(x): + return jnp.arctanh(x) * hk.get_parameter( + 'p', [], init=hk.initializers.Constant(0.)) + fn = hk.without_apply_rng(hk.transform(fn)) + + opt = wrappers.apply_if_finite(alias.sgd(1.), 2) + def fn_update(params, opt_state, x): + grads = jax.grad(fn.apply)(params, x) + grads = jax.lax.psum(grads, axis_name='i') + updates, new_opt_state = opt.update(grads, opt_state, params) + new_params = update.apply_updates(params, updates) + return new_params, new_opt_state + fn_update = jax.pmap(fn_update, axis_name='i') + + params = fn.init(jax.random.PRNGKey(1905), half) + opt_state = opt.init(params) + params = jax.tree_util.tree_map(lambda x: x[None], params) + opt_state = jax.tree_util.tree_map(lambda x: x[None], opt_state) + # Do one successful param update + params, opt_state = fn_update(params, opt_state, half) + self.assertTrue(bool(opt_state.last_finite)) + # Check 2 rejected param updates + for step in range(2): + params, opt_state = fn_update(params, opt_state, two) + self.assertFalse(bool(opt_state.last_finite)) + self.assertEqual(step + 1, int(opt_state.notfinite_count)) + # Next successful param update + params, opt_state = fn_update(params, opt_state, half) + self.assertTrue(bool(opt_state.last_finite)) + # Again 2 rejected param updates + for step in range(2): + params, opt_state = fn_update(params, opt_state, two) + self.assertFalse(bool(opt_state.last_finite)) + self.assertEqual(step + 1, int(opt_state.notfinite_count)) + # Next param update with NaN is accepted since we reached maximum + params, opt_state = fn_update(params, opt_state, two) + self.assertEqual(5, int(opt_state.total_notfinite)) + + @chex.variants(with_jit=True, without_jit=True, with_pmap=True) + def test_multi_steps(self): + batch_size = 32 + x_size = 7 + # Parameters 
should be updated only every `k_steps` optimisation steps. + k_steps = 4 + data = jnp.ones([batch_size, x_size]) + + def get_loss(x): + loss = jnp.sum(hk.Linear(10)(x)**2) + return loss + + loss_init, loss_apply = hk.without_apply_rng(hk.transform(get_loss)) + params = loss_init(jax.random.PRNGKey(1915), data) + + ms_opt = wrappers.MultiSteps( + # Use a non-trivial inner optimiser: + # * it has a state, + # * it requires the params for the update. + combine.chain(transform.scale_by_adam(), + transform.additive_weight_decay(1e-2), + transform.scale(-1e-4)), k_steps) + opt_init, opt_update = ms_opt.gradient_transformation() + + # Put the training in one function, to check that the update is indeed + # jittable. + def train_step(data, opt_state, params): + grad = jax.grad(loss_apply)(params, data) + updates, opt_state = opt_update(grad, opt_state, params) + return updates, opt_state + + opt_state = opt_init(params) + + prev_loss = loss_apply(params, data) + for idx in range(5 * k_steps): + updates, opt_state = self.variant(train_step)(data, opt_state, params) + new_params = update.apply_updates(params, updates) + new_loss = loss_apply(new_params, data) + if idx % k_steps < k_steps - 1: + # The parameters should not have changed and the loss should be + # constant. + jax.tree_util.tree_map( + np.testing.assert_array_equal, new_params, params) + np.testing.assert_equal(new_loss, prev_loss) + self.assertFalse(ms_opt.has_updated(opt_state)) + else: + # This is a step where parameters should actually have been updated, and + # the loss should accordingly go down. + np.testing.assert_array_less(new_loss, prev_loss) + prev_loss = new_loss + self.assertTrue(ms_opt.has_updated(opt_state)) + params = new_params + + def test_multi_steps_every_k_schedule(self): + # Test a non-trivial schedule which varies over time. + ms_opt = wrappers.MultiSteps( + alias.sgd(1e-4), lambda grad_step: jnp.where(grad_step < 2, 1, 3)) + opt_init, opt_update = ms_opt.gradient_transformation() + params = dict(a=jnp.zeros([])) + opt_state = opt_init(params) + grad = dict(a=jnp.zeros([])) + self.assertFalse(ms_opt.has_updated(opt_state)) + # First two steps have 1 mini-step per update. + for _ in range(2): + _, opt_state = opt_update(grad, opt_state, params) + self.assertTrue(ms_opt.has_updated(opt_state)) + # Subsequently, mini-steps should have 3 mini-steps per update. + for _ in range(5): + for _ in range(2): + _, opt_state = opt_update(grad, opt_state, params) + self.assertFalse(ms_opt.has_updated(opt_state)) + _, opt_state = opt_update(grad, opt_state, params) + self.assertTrue(ms_opt.has_updated(opt_state)) + + def test_multi_steps_computes_mean(self): + k_steps = 4 + ms_opt = wrappers.MultiSteps( + transform.scale(1.0), k_steps, use_grad_mean=True) + opt_init, opt_update = ms_opt.gradient_transformation() + params = dict(a=jnp.zeros([])) + opt_state = opt_init(params) + grads = [dict(a=jnp.ones([]) * i) for i in [1, 2, 3, 4]] + self.assertFalse(ms_opt.has_updated(opt_state)) + + # First 3 steps don't update. + for grad in grads[:-1]: + _, opt_state = opt_update(grad, opt_state, params) + self.assertFalse(ms_opt.has_updated(opt_state)) + + # Actual update. 
+ new_params, opt_state = opt_update(grads[-1], opt_state, params) + self.assertTrue(ms_opt.has_updated(opt_state)) + np.testing.assert_array_equal(new_params['a'], 2.5) + + def test_skip_not_finite(self): + step = jnp.zeros([], dtype=jnp.int32) + + with self.subTest('test_pos_inf'): + should_skip, skip_state = wrappers.skip_not_finite( + [jnp.array(float('inf')), jnp.zeros([])], step, None) + self.assertTrue(bool(should_skip)) + self.assertTrue(bool(skip_state['should_skip'])) + self.assertEqual(int(skip_state['num_not_finite']), 1) + + with self.subTest('test_neg_inf'): + should_skip, skip_state = wrappers.skip_not_finite( + [jnp.array(-float('inf')), jnp.zeros([])], step, None) + self.assertTrue(bool(should_skip)) + self.assertTrue(bool(skip_state['should_skip'])) + self.assertEqual(int(skip_state['num_not_finite']), 1) + + with self.subTest('test_nan'): + should_skip, skip_state = wrappers.skip_not_finite( + [jnp.array(float('nan')), jnp.zeros([])], step, None) + self.assertTrue(bool(should_skip)) + self.assertTrue(bool(skip_state['should_skip'])) + self.assertEqual(int(skip_state['num_not_finite']), 1) + + with self.subTest('test_finite'): + should_skip, skip_state = wrappers.skip_not_finite( + [jnp.array(11.), jnp.zeros([])], step, None) + self.assertFalse(bool(should_skip)) + self.assertFalse(bool(skip_state['should_skip'])) + self.assertEqual(int(skip_state['num_not_finite']), 0) + + def test_skip_large_updates(self): + step = jnp.zeros([], dtype=jnp.int32) + + with self.subTest('test_inf'): + should_skip, skip_state = wrappers.skip_large_updates( + [jnp.array(float('inf')), jnp.zeros([])], step, None, 100.) + self.assertTrue(bool(should_skip)) + self.assertTrue(bool(skip_state['should_skip'])) + self.assertEqual(float(skip_state['norm_squared']), float('inf')) + + with self.subTest('test_nan'): + should_skip, skip_state = wrappers.skip_large_updates( + [jnp.array(float('nan')), jnp.zeros([])], step, None, 100.) + self.assertTrue(bool(should_skip)) + self.assertTrue(bool(skip_state['should_skip'])) + # Recall that NaN != NaN. + norm_squared = float(skip_state['norm_squared']) + self.assertNotEqual(norm_squared, norm_squared) + + with self.subTest('test_large'): + should_skip, skip_state = wrappers.skip_large_updates( + [jnp.array(11.), jnp.zeros([])], step, None, 100.) + self.assertTrue(bool(should_skip)) + self.assertTrue(bool(skip_state['should_skip'])) + self.assertEqual(float(skip_state['norm_squared']), 121.) + + with self.subTest('test_small'): + should_skip, skip_state = wrappers.skip_large_updates( + [jnp.zeros([]), jnp.zeros([])], step, None, 100.) + self.assertFalse(bool(should_skip)) + self.assertFalse(bool(skip_state['should_skip'])) + self.assertEqual(float(skip_state['norm_squared']), 0.) 
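+
+  # Illustrative sketch (not exercised by the tests above): the MultiSteps
+  # docstring states that `MultiStepsState.skip_state` stores the monitoring
+  # information returned by `should_skip_update_fn`. Assuming the same
+  # `skip_large_updates` predicate used above, that field could be read back
+  # for logging as follows:
+  #
+  #   ms_opt = wrappers.MultiSteps(
+  #       alias.sgd(1.), 2,
+  #       should_skip_update_fn=lambda updates, step, params:
+  #       wrappers.skip_large_updates(updates, step, params, 100.))
+  #   opt_init, opt_update = ms_opt.gradient_transformation()
+  #   params = dict(a=jnp.zeros([]))
+  #   opt_state = opt_init(params)
+  #   _, opt_state = opt_update(dict(a=jnp.full([], 20.)), opt_state, params)
+  #   # The mini-step is rejected (20.**2 > 100.), and the squared norm that
+  #   # triggered the skip is available for monitoring:
+  #   norm_sq = float(opt_state.skip_state['norm_squared'])  # == 400.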
+ + def test_multi_steps_skip_not_finite(self): + k_steps = 2 + ms_opt = wrappers.MultiSteps( + alias.sgd(1.), k_steps, should_skip_update_fn=wrappers.skip_not_finite) + opt_init, opt_update = ms_opt.gradient_transformation() + opt_init = jax.jit(opt_init) + opt_update = jax.jit(opt_update) + params = dict(a=jnp.zeros([])) + opt_state = opt_init(params) + + with self.subTest('test_good_updates'): + updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) + self.assertEqual(int(opt_state.mini_step), 1) + params = update.apply_updates(params, updates) + updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) + self.assertEqual(int(opt_state.mini_step), 0) + params = update.apply_updates(params, updates) + np.testing.assert_array_equal(params['a'], -jnp.ones([])) + + with self.subTest('test_inf_updates'): + updates, opt_state = opt_update( + dict(a=jnp.array(float('inf'))), opt_state, params) + self.assertEqual(int(opt_state.mini_step), 0) # No increase in mini_step + params = update.apply_updates(params, updates) + np.testing.assert_array_equal(params['a'], -jnp.ones([])) + + with self.subTest('test_nan_updates'): + updates, opt_state = opt_update( + dict(a=jnp.full([], float('nan'))), opt_state, params) + self.assertEqual(int(opt_state.mini_step), 0) # No increase in mini_step + params = update.apply_updates(params, updates) + np.testing.assert_array_equal(params['a'], -jnp.ones([])) + + with self.subTest('test_final_good_updates'): + updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) + self.assertEqual(int(opt_state.mini_step), 1) + params = update.apply_updates(params, updates) + updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) + self.assertEqual(int(opt_state.mini_step), 0) + params = update.apply_updates(params, updates) + np.testing.assert_array_equal(params['a'], -jnp.full([], 2.)) + + +class MaskedTest(chex.TestCase): + """Tests for the masked wrapper.""" + + @chex.all_variants + @parameterized.named_parameters( + ('sgd', _build_sgd, False), + ('stateful_sgd', _build_stateful_sgd, False), + ('sgd_w_mask_fn', _build_sgd, True), + ('stateful_sgd_w_mask_fn', _build_stateful_sgd, True), + ) + def test_masked(self, opt_builder, use_fn): + mask = {'a': True, + 'b': [False, True], + 'c': {'d': True, 'e': (False, True)}} + mask_arg = lambda _: mask if use_fn else mask + params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}} + params = jax.tree_util.tree_map(jnp.asarray, params) + input_updates = jax.tree_util.tree_map(lambda x: x/10., params) + + # Negate the updates wherever the mask is True + def masked_negate(updates): + return jax.tree_util.tree_map( + lambda upd, m: -upd if m else upd, updates, mask) + correct_updates = masked_negate(input_updates) + + init_fn, update_fn = wrappers.masked(opt_builder(), mask_arg) + update_fn = self.variant(update_fn) + state = self.variant(init_fn)(params) + updates, state = update_fn(input_updates, state, params) + chex.assert_trees_all_close(updates, correct_updates) + + # Check repeated application, this time with no params. 
+ correct_updates = masked_negate(correct_updates) + updates, state = update_fn(updates, state) + chex.assert_trees_all_close(updates, correct_updates) + + @chex.all_variants + @parameterized.named_parameters( + ('sgd', _build_sgd), + ('stateful_sgd', _build_stateful_sgd), + ) + def test_prefix_mask(self, opt_builder): + """Test when the mask is a prefix of the updates PyTree.""" + mask = {'a': True, 'b': False, 'c': {'d': False, 'e': True}} + params = {'a': 1., 'b': {'f': 2.}, 'c': {'d': 3., 'e': ([4., 5.], 6.)}} + params = jax.tree_util.tree_map(jnp.asarray, params) + input_updates = jax.tree_util.tree_map(lambda x: x/10., params) + + # Negate the updates wherever the mask (or mask parent) is True + def _masked_sgd_on_updates(m, upd): + return jax.tree_util.tree_map(lambda x: -x, upd) if m else upd + correct_updates = jax.tree_util.tree_map( + _masked_sgd_on_updates, mask, input_updates) + + init_fn, update_fn = wrappers.masked(opt_builder(), mask) + update_fn = self.variant(update_fn) + state = self.variant(init_fn)(params) + updates, state = update_fn(input_updates, state, params) + chex.assert_trees_all_close(updates, correct_updates) + + # Check repeated application, this time with no params. + correct_updates = jax.tree_util.tree_map( + _masked_sgd_on_updates, mask, correct_updates) + updates, state = update_fn(updates, state) + chex.assert_trees_all_close(updates, correct_updates) + + @chex.all_variants + def test_update_requires_params(self): + weight_decay = 0.1 + mask = {'a': True, + 'b': [False, True], + 'c': {'d': True, 'e': (False, True)}} + params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}} + params = jax.tree_util.tree_map(jnp.asarray, params) + input_updates = jax.tree_util.tree_map(lambda x: x/10., params) + + correct_updates = jax.tree_util.tree_map( + lambda m, u, p: u + weight_decay * p if m else u, + mask, input_updates, params) + + init_fn, update_fn = wrappers.masked( + transform.additive_weight_decay(weight_decay), mask) + update_fn = self.variant(update_fn) + + state = self.variant(init_fn)(params) + updates, state = update_fn(input_updates, state, params) + chex.assert_trees_all_close(updates, correct_updates) + + params = update.apply_updates(params, updates) + + # Test repeated application + new_correct_updates = jax.tree_util.tree_map( + lambda m, u, p: u + weight_decay * p if m else u, + mask, correct_updates, params) + updates, state = update_fn(correct_updates, state, params) + chex.assert_trees_all_close(updates, new_correct_updates) + + @parameterized.parameters(list, tuple, dict) + def test_empty(self, container): + init_fn, update_fn = wrappers.masked(_build_sgd(), container()) + update_fn(container(), init_fn(container())) + + @parameterized.parameters( + (False, False), (False, True), (True, False), (True, True)) + def test_tree_mismatch_fails(self, extra_key_in_mask, use_fn): + mask = {'a': True, + 'b': [False, True], + 'c': {'d': True, 'e': (False, True)}} + mask_arg = lambda _: mask if use_fn else mask + params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}} + params = jax.tree_util.tree_map(jnp.asarray, params) + + if extra_key_in_mask: + mask['c']['extra'] = True + else: + params['c']['extra'] = 7 + + init_fn = wrappers.masked(_build_sgd(), mask_arg)[0] + with self.assertRaises(ValueError): + init_fn(params) + + @chex.all_variants + def test_mask_fn(self): + params = {'a': jnp.ones((1, 2)), 'b': (jnp.ones((1,)), np.ones((1, 2, 3)))} + mask_fn = lambda p: jax.tree_util.tree_map(lambda x: x.ndim > 1, p) + init_fn, update_fn 
= wrappers.masked(transform.add_decayed_weights(0.1), + mask_fn) + update_fn = self.variant(update_fn) + + state = self.variant(init_fn)(params) + grads = jax.tree_util.tree_map(lambda x: x*2, params) + updates, state = update_fn(grads, state, params) + np.testing.assert_allclose(updates['a'], grads['a'] + 0.1*params['a']) + np.testing.assert_allclose(updates['b'][0], grads['b'][0]) + np.testing.assert_allclose(updates['b'][1], + grads['b'][1] + 0.1*params['b'][1]) + + @chex.all_variants + @parameterized.named_parameters( + ('sgd', _build_sgd), + ('stateful_sgd', _build_stateful_sgd), + ) + def test_nested_mask(self, opt_builder): + # https://github.com/deepmind/optax/issues/271 + params = {'linear_1': {'w': jnp.zeros((1, 1)), 'b': jnp.zeros(1)}, + 'linear_2': {'w': jnp.zeros((1, 2)), 'b': jnp.zeros(2)}, + 'linear_3': {'w': jnp.zeros((2, 3)), 'b': jnp.zeros(3)}} + + outer_mask = lambda p: jax.tree_util.tree_map(lambda x: x.ndim > 1, p) + inner_mask = jax.tree_util.tree_map(lambda _: True, params) + inner_mask['linear_2'] = False + + inner = wrappers.masked(opt_builder(), inner_mask) + init_fn, update_fn = wrappers.masked(inner, outer_mask) + + input_updates = jax.tree_util.tree_map(jnp.ones_like, params) + correct_updates = copy.deepcopy(input_updates) + correct_updates['linear_1']['w'] *= -1.0 + correct_updates['linear_3']['w'] *= -1.0 + + state = self.variant(init_fn)(params) + updates, state = self.variant(update_fn)(input_updates, state, params) + chex.assert_trees_all_close(updates, correct_updates) + + @chex.all_variants + def test_masked_state_structure(self): + # https://github.com/deepmind/optax/issues/271 + params = {'a': [jnp.ones(1), (jnp.ones(2), jnp.ones(3))], + 'b': {'c': jnp.ones(4), 'd': jnp.ones(5)}} + mask = {'a': [True, (True, False)], 'b': False} + tx = wrappers.masked(_build_stateful_sgd(), mask) + trace = self.variant(tx.init)(params).inner_state[0].trace + expected_trace = { + 'a': [jnp.zeros(1), (jnp.zeros(2), wrappers.MaskedNode())], + 'b': wrappers.MaskedNode() + } + chex.assert_tree_all_equal_structs(trace, expected_trace) + + def test_masked_state_is_compatible_with_deepmind_tree(self): + """Checks that the masked state is compatible with deepmind/tree. + + DeepMind's tree library and `jax.tree_util` have slightly different + behavior: jax treats `None`s as tree nodes without children while + deepmind/tree treats them as leaves with `None` values. This has led to bugs + when users used deepmind/tree to manipulate masked optimizer states. + + This test ensures that masked parts of the optimizer state are also ignored + by deepmind/tree. + """ + params = { + 'a': [jnp.ones(1), (jnp.ones(2), jnp.ones(3))], + 'b': [jnp.ones(4)] + } + mask = {'a': [True, (True, False)], 'b': False} + opt_init, _ = wrappers.masked(_build_stateful_sgd(), mask) + state = opt_init(params) + chex.assert_trees_all_equal(tree.map_structure(np.array, state), state) + + +class MaybeUpdateTest(chex.TestCase): + """Tests for the maybe_update wrapper.""" + + NUM_STEPS = 3 + + @chex.all_variants + def test_stateless_inner(self): + params = jnp.zeros([]) + grads = jnp.ones([]) + + def should_update(step): + return step < MaybeUpdateTest.NUM_STEPS + + opt = wrappers.maybe_update(transform.scale(2.), should_update) + state = opt.init(params) + update_fn = self.variant(opt.update) + for _ in range(MaybeUpdateTest.NUM_STEPS): + updates, state = update_fn(grads, state) + self.assertEqual(updates, 2.) + # Further updates stop calling the inner optimiser. 
+ for _ in range(5): + updates, state = update_fn(grads, state) + self.assertEqual(updates, 1.) + + @chex.all_variants + def test_statefull_inner(self): + params = jnp.zeros([]) + grads_with_nan = jnp.array(float('nan')) + grads = jnp.ones([]) + + def should_update(step): + return step < MaybeUpdateTest.NUM_STEPS + + opt = wrappers.maybe_update(constrain.zero_nans(), should_update) + state = opt.init(params) + update_fn = self.variant(opt.update) + for _ in range(MaybeUpdateTest.NUM_STEPS - 1): + updates, state = update_fn(grads_with_nan, state) + self.assertEqual(updates, 0.) + self.assertEqual(state.inner_state.found_nan, True) + updates, state = update_fn(grads, state) + self.assertEqual(updates, 1.) + self.assertEqual(state.inner_state.found_nan, False) + # Further updates stop calling the inner optimiser. + for _ in range(5): + updates, state = update_fn(grads_with_nan, state) + # Warning: do not use assertEqual with a NaN as NaN == NaN returns False. + self.assertTrue(jnp.isnan(updates)) + # Inner state is not be updated. + self.assertEqual(state.inner_state.found_nan, False) + + +if __name__ == '__main__': + absltest.main() diff --git a/optax_add_eve/experimental/__init__.py b/optax_add_eve/experimental/__init__.py new file mode 100644 index 00000000..61cb5150 --- /dev/null +++ b/optax_add_eve/experimental/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental features in Optax. + +Features may be removed or modified at any time. +""" + +from optax_add_eve._src.experimental.complex_valued import split_real_and_imaginary +from optax_add_eve._src.experimental.complex_valued import SplitRealAndImaginaryState +from optax_add_eve._src.experimental.extra_args import GradientTransformationWithExtraArgs +from optax_add_eve._src.experimental.extra_args import named_chain diff --git a/optax_add_eve/optax_test.py b/optax_add_eve/optax_test.py new file mode 100644 index 00000000..ea6af7b9 --- /dev/null +++ b/optax_add_eve/optax_test.py @@ -0,0 +1,29 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for optax.""" + +from absl.testing import absltest +import optax_add_eve + + +class OptaxTest(absltest.TestCase): + """Test optax can be imported correctly.""" + + def test_import(self): + self.assertTrue(hasattr(optax_add_eve, 'GradientTransformation')) + + +if __name__ == '__main__': + absltest.main() From 5e63aeeca7af5f3233bef5071a08c7b632ed5287 Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 19:32:03 -0600 Subject: [PATCH 03/35] reverted rename --- docs/conf.py | 4 ++-- docs/ext/coverage_check.py | 8 ++++---- examples/differentially_private_sgd.py | 10 +++++----- examples/flax_example.py | 10 +++++----- examples/haiku_example.py | 10 +++++----- examples/lookahead_mnist.py | 12 ++++++------ examples/mnist.py | 10 +++++----- examples/mnist_test.py | 4 ++-- 8 files changed, 34 insertions(+), 34 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index fc8fc231..936006a0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -121,7 +121,7 @@ def new_process_docstring(app, what, name, obj, options, lines): sys.path.insert(0, os.path.abspath('../')) sys.path.append(os.path.abspath('ext')) -import optax_add_eve +import optax from sphinxcontrib import katex # -- Project information ----------------------------------------------------- @@ -246,7 +246,7 @@ def linkcode_resolve(domain, info): # TODO(slebedev): support tags after we release an initial version. return 'https://github.com/deepmind/optax/tree/master/optax/%s#L%d#L%d' % ( os.path.relpath(filename, start=os.path.dirname( - optax_add_eve.__file__)), lineno, lineno + len(source) - 1) + optax.__file__)), lineno, lineno + len(source) - 1) # -- Intersphinx configuration ----------------------------------------------- diff --git a/docs/ext/coverage_check.py b/docs/ext/coverage_check.py index 9b42626e..c31cb75f 100644 --- a/docs/ext/coverage_check.py +++ b/docs/ext/coverage_check.py @@ -16,8 +16,8 @@ from typing import Any, Mapping -import optax_add_eve -from optax_add_eve._src import test_utils +import optax +from optax._src import test_utils from sphinx import application from sphinx import builders from sphinx import errors @@ -25,7 +25,7 @@ def optax_public_symbols(): names = set() - for module_name, module in test_utils.find_internal_python_modules(optax_add_eve): + for module_name, module in test_utils.find_internal_python_modules(optax): for name in module.__all__: names.add(module_name + "." + name) return names @@ -55,4 +55,4 @@ def finish(self) -> None: def setup(app: application.Sphinx) -> Mapping[str, Any]: app.add_builder(OptaxCoverageCheck) - return dict(version=optax_add_eve.__version__, parallel_read_safe=True) + return dict(version=optax.__version__, parallel_read_safe=True) diff --git a/examples/differentially_private_sgd.py b/examples/differentially_private_sgd.py index e011713a..5cce0953 100644 --- a/examples/differentially_private_sgd.py +++ b/examples/differentially_private_sgd.py @@ -70,7 +70,7 @@ import jax from jax.example_libraries import stax import jax.numpy as jnp -import optax_add_eve +import optax # pylint: disable=g-bad-import-order import datasets # Located in the examples folder. 
@@ -119,7 +119,7 @@ def compute_epsilon(steps, target_delta=1e-5): def loss_fn(params, batch): logits = predict(params, batch['image']) - return optax_add_eve.softmax_cross_entropy(logits, batch['label']).mean(), logits + return optax.softmax_cross_entropy(logits, batch['label']).mean(), logits @jax.jit @@ -136,12 +136,12 @@ def main(_): full_test_batch = next(test_dataset.as_numpy_iterator()) if FLAGS.dpsgd: - tx = optax_add_eve.dpsgd(learning_rate=FLAGS.learning_rate, + tx = optax.dpsgd(learning_rate=FLAGS.learning_rate, l2_norm_clip=FLAGS.l2_norm_clip, noise_multiplier=FLAGS.noise_multiplier, seed=FLAGS.seed) else: - tx = optax_add_eve.sgd(learning_rate=FLAGS.learning_rate) + tx = optax.sgd(learning_rate=FLAGS.learning_rate) @jax.jit def train_step(params, opt_state, batch): @@ -154,7 +154,7 @@ def train_step(params, opt_state, batch): grads, _ = grad_fn(params, batch) updates, new_opt_state = tx.update(grads, opt_state, params) - new_params = optax_add_eve.apply_updates(params, updates) + new_params = optax.apply_updates(params, updates) return new_params, new_opt_state key = jax.random.PRNGKey(FLAGS.seed) diff --git a/examples/flax_example.py b/examples/flax_example.py index d47460a6..a507a02c 100644 --- a/examples/flax_example.py +++ b/examples/flax_example.py @@ -19,7 +19,7 @@ from flax import linen as nn import jax import jax.numpy as jnp -import optax_add_eve +import optax def main(argv): @@ -70,11 +70,11 @@ def squared_error(x, y): # Construct a simple Adam optimiser using the transforms in optax. # You could also just use the `optax.adam` alias, but we show here how # to do so manually so that you may construct your own `custom` optimiser. - tx = optax_add_eve.chain( + tx = optax.chain( # Set the parameters of Adam. Note the learning_rate is not here. - optax_add_eve.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8), + optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8), # Put a minus sign to *minimise* the loss. - optax_add_eve.scale(-learning_rate) + optax.scale(-learning_rate) ) # Create optimiser state. @@ -89,7 +89,7 @@ def squared_error(x, y): # Update the optimiser state, create an update to the params. updates, opt_state = tx.update(grads, opt_state) # Update the parameters. - params = optax_add_eve.apply_updates(params, updates) + params = optax.apply_updates(params, updates) print(f'Loss[{step}] = {loss_val}') diff --git a/examples/haiku_example.py b/examples/haiku_example.py index 0854ab87..3d8bbe2d 100644 --- a/examples/haiku_example.py +++ b/examples/haiku_example.py @@ -19,7 +19,7 @@ import haiku as hk import jax import jax.numpy as jnp -import optax_add_eve +import optax def main(argv): @@ -48,11 +48,11 @@ def mean_square_loss(params, x): # Construct a simple Adam optimiser using the transforms in optax. # You could also just use the `optax.adam` alias, but we show here how # to do so manually so that you may construct your own `custom` optimiser. - opt_init, opt_update = optax_add_eve.chain( + opt_init, opt_update = optax.chain( # Set the parameters of Adam. Note the learning_rate is not here. - optax_add_eve.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8), + optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8), # Put a minus sign to *minimise* the loss. - optax_add_eve.scale(-learning_rate) + optax.scale(-learning_rate) ) # Initialise the model's parameters and the optimiser's state. @@ -71,7 +71,7 @@ def mean_square_loss(params, x): # Transform the gradients using the optimiser. updates, opt_state = opt_update(grad, opt_state, params) # Update parameters. 
- params = optax_add_eve.apply_updates(params, updates) + params = optax.apply_updates(params, updates) if __name__ == '__main__': diff --git a/examples/lookahead_mnist.py b/examples/lookahead_mnist.py index df08264c..fce49efa 100644 --- a/examples/lookahead_mnist.py +++ b/examples/lookahead_mnist.py @@ -19,7 +19,7 @@ import jax from jax import random import jax.numpy as jnp -import optax_add_eve +import optax # pylint: disable=g-bad-import-order import datasets # Located in the examples folder. @@ -45,18 +45,18 @@ def main(unused_argv) -> None: (*HIDDEN_SIZES, num_classes)) # Set up the fast optimizer (adam) and wrap lookahead around it. - fast_optimizer = optax_add_eve.adam(LEARNING_RATE) - optimizer = optax_add_eve.lookahead(fast_optimizer, SYNC_PERIOD, SLOW_LEARNING_RATE) + fast_optimizer = optax.adam(LEARNING_RATE) + optimizer = optax.lookahead(fast_optimizer, SYNC_PERIOD, SLOW_LEARNING_RATE) def get_loss(fast_params, batch): logits = apply_params_fn(fast_params, batch['image']) - return jnp.mean(optax_add_eve.softmax_cross_entropy(logits, batch['label'])) + return jnp.mean(optax.softmax_cross_entropy(logits, batch['label'])) @jax.jit def train_step(params, optimizer_state, batch): grads = jax.grad(get_loss)(params.fast, batch) updates, opt_state = optimizer.update(grads, optimizer_state, params) - return optax_add_eve.apply_updates(params, updates), opt_state + return optax.apply_updates(params, updates), opt_state example_input = next(train_dataset.as_numpy_iterator())['image'] initial_params = init_params_fn(random.PRNGKey(SEED), example_input) @@ -66,7 +66,7 @@ def train_step(params, optimizer_state, batch): # initial model parameters. The first line below is only necessary for the # lookahead wrapper; without it the initial parameters could be used in the # initialization function of the optimizer directly. - params = optax_add_eve.LookaheadParams.init_synced(initial_params) + params = optax.LookaheadParams.init_synced(initial_params) opt_state = optimizer.init(params) # Training loop diff --git a/examples/mnist.py b/examples/mnist.py index ac1c395a..d79f97af 100644 --- a/examples/mnist.py +++ b/examples/mnist.py @@ -22,7 +22,7 @@ import jax from jax import random import jax.numpy as jnp -import optax_add_eve +import optax # pylint: disable=g-bad-import-order import datasets # Located in the examples folder. @@ -70,7 +70,7 @@ def mlp_model(inputs: chex.Array) -> chex.Array: return hk.without_apply_rng(mlp_model) -def train_on_mnist(optimizer: optax_add_eve.GradientTransformation, +def train_on_mnist(optimizer: optax.GradientTransformation, hidden_sizes: Sequence[int]) -> float: """Trains an MLP on MNIST using a given optimizer. 
@@ -90,13 +90,13 @@ def train_on_mnist(optimizer: optax_add_eve.GradientTransformation, def get_loss(params, batch): logits = apply_params_fn(params, batch['image']) - return jnp.mean(optax_add_eve.softmax_cross_entropy(logits, batch['label'])) + return jnp.mean(optax.softmax_cross_entropy(logits, batch['label'])) @jax.jit def train_step(params, optimizer_state, batch): grads = jax.grad(get_loss)(params, batch) updates, opt_state = optimizer.update(grads, optimizer_state, params) - return optax_add_eve.apply_updates(params, updates), opt_state + return optax.apply_updates(params, updates), opt_state example_input = next(train_dataset.as_numpy_iterator())['image'] params = init_params_fn(random.PRNGKey(SEED), example_input) @@ -116,7 +116,7 @@ def train_step(params, optimizer_state, batch): def main(unused_argv): """Trains an MLP on MNIST using the adam optimizers.""" - return train_on_mnist(optax_add_eve.adam(LEARNING_RATE), DEFAULT_HIDDEN_SIZES) + return train_on_mnist(optax.adam(LEARNING_RATE), DEFAULT_HIDDEN_SIZES) if __name__ == '__main__': diff --git a/examples/mnist_test.py b/examples/mnist_test.py index 9c0d8f48..afc8b636 100644 --- a/examples/mnist_test.py +++ b/examples/mnist_test.py @@ -21,7 +21,7 @@ import haiku as hk import jax import numpy as np -import optax_add_eve +import optax import tensorflow as tf # pylint: disable=g-bad-import-order @@ -71,7 +71,7 @@ def test_train_on_mnist_can_fit_linear_mock_data(self): dataset = tf.data.Dataset.from_tensor_slices(data).repeat(8).batch(10) with mock.patch.object( datasets, 'load_image_dataset', return_value=dataset): - final_accuracy = mnist.train_on_mnist(optax_add_eve.adam(0.01), hidden_sizes=(1,)) + final_accuracy = mnist.train_on_mnist(optax.adam(0.01), hidden_sizes=(1,)) self.assertEqual(final_accuracy, 1.) 
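The examples above all follow the same pattern: build a gradient transformation, call its `update` inside a jitted train step, and apply the result with `optax.apply_updates`. The `MultiSteps` wrapper from `wrappers.py` drops into that pattern unchanged. Below is a minimal sketch of gradient accumulation over four mini-batches; the model, loss, and data are hypothetical placeholders rather than code from the examples, and the import name assumes the `optax` package name restored by this revert.

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
  # Toy linear model, just to have something to differentiate.
  preds = batch['x'] @ params['w']
  return jnp.mean((preds - batch['y']) ** 2)

# Accumulate gradients over 4 mini-steps, then apply a single adam update.
ms_opt = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4)
tx = ms_opt.gradient_transformation()

params = {'w': jnp.zeros((8, 1))}
opt_state = tx.init(params)

@jax.jit
def train_step(params, opt_state, batch):
  grads = jax.grad(loss_fn)(params, batch)
  updates, opt_state = tx.update(grads, opt_state, params)
  # `updates` is all zeros on the first three mini-steps, so the parameters
  # only move on every fourth call.
  return optax.apply_updates(params, updates), opt_state

batch = {'x': jnp.ones((16, 8)), 'y': jnp.ones((16, 1))}
for _ in range(8):
  params, opt_state = train_step(params, opt_state, batch)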
From 3f3b2a06f04c6846fb2b3e167350a66c45a2bb80 Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 19:32:12 -0600 Subject: [PATCH 04/35] reverted rename --- optax_add_eve/__init__.py | 347 ----- optax_add_eve/_src/alias.py | 926 ------------- optax_add_eve/_src/alias_test.py | 186 --- optax_add_eve/_src/base_test.py | 139 -- optax_add_eve/_src/clipping.py | 222 --- optax_add_eve/_src/clipping_test.py | 96 -- optax_add_eve/_src/combine.py | 150 -- optax_add_eve/_src/combine_test.py | 152 --- optax_add_eve/_src/constrain.py | 97 -- optax_add_eve/_src/constrain_test.py | 116 -- optax_add_eve/_src/control_variates.py | 419 ------ optax_add_eve/_src/control_variates_test.py | 595 -------- optax_add_eve/_src/equivalence_test.py | 176 --- .../_src/experimental/complex_valued.py | 121 -- .../_src/experimental/complex_valued_test.py | 79 -- optax_add_eve/_src/experimental/extra_args.py | 167 --- .../_src/experimental/extra_args_test.py | 65 - optax_add_eve/_src/factorized.py | 199 --- optax_add_eve/_src/factorized_test.py | 45 - optax_add_eve/_src/float64_test.py | 94 -- optax_add_eve/_src/linear_algebra.py | 201 --- optax_add_eve/_src/linear_algebra_test.py | 62 - optax_add_eve/_src/lookahead.py | 192 --- optax_add_eve/_src/lookahead_test.py | 140 -- optax_add_eve/_src/loss.py | 521 ------- optax_add_eve/_src/loss_test.py | 500 ------- optax_add_eve/_src/numerics_test.py | 112 -- optax_add_eve/_src/privacy.py | 74 - optax_add_eve/_src/privacy_test.py | 112 -- optax_add_eve/_src/schedule.py | 620 --------- optax_add_eve/_src/schedule_test.py | 649 --------- optax_add_eve/_src/second_order_test.py | 93 -- .../_src/stochastic_gradient_estimators.py | 317 ----- .../stochastic_gradient_estimators_test.py | 371 ----- optax_add_eve/_src/transform.py | 1206 ----------------- optax_add_eve/_src/transform_test.py | 305 ----- optax_add_eve/_src/update.py | 103 -- optax_add_eve/_src/update_test.py | 83 -- optax_add_eve/_src/utils.py | 152 --- optax_add_eve/_src/utils_test.py | 65 - optax_add_eve/_src/wrappers.py | 547 -------- optax_add_eve/_src/wrappers_test.py | 623 --------- optax_add_eve/experimental/__init__.py | 23 - optax_add_eve/optax_test.py | 29 - setup.py | 2 +- 45 files changed, 1 insertion(+), 11492 deletions(-) delete mode 100644 optax_add_eve/__init__.py delete mode 100644 optax_add_eve/_src/alias.py delete mode 100644 optax_add_eve/_src/alias_test.py delete mode 100644 optax_add_eve/_src/base_test.py delete mode 100644 optax_add_eve/_src/clipping.py delete mode 100644 optax_add_eve/_src/clipping_test.py delete mode 100644 optax_add_eve/_src/combine.py delete mode 100644 optax_add_eve/_src/combine_test.py delete mode 100644 optax_add_eve/_src/constrain.py delete mode 100644 optax_add_eve/_src/constrain_test.py delete mode 100644 optax_add_eve/_src/control_variates.py delete mode 100644 optax_add_eve/_src/control_variates_test.py delete mode 100644 optax_add_eve/_src/equivalence_test.py delete mode 100644 optax_add_eve/_src/experimental/complex_valued.py delete mode 100644 optax_add_eve/_src/experimental/complex_valued_test.py delete mode 100644 optax_add_eve/_src/experimental/extra_args.py delete mode 100644 optax_add_eve/_src/experimental/extra_args_test.py delete mode 100644 optax_add_eve/_src/factorized.py delete mode 100644 optax_add_eve/_src/factorized_test.py delete mode 100644 optax_add_eve/_src/float64_test.py delete mode 100644 optax_add_eve/_src/linear_algebra.py delete mode 100644 optax_add_eve/_src/linear_algebra_test.py delete mode 100644 optax_add_eve/_src/lookahead.py 
delete mode 100644 optax_add_eve/_src/lookahead_test.py delete mode 100644 optax_add_eve/_src/loss.py delete mode 100644 optax_add_eve/_src/loss_test.py delete mode 100644 optax_add_eve/_src/numerics_test.py delete mode 100644 optax_add_eve/_src/privacy.py delete mode 100644 optax_add_eve/_src/privacy_test.py delete mode 100644 optax_add_eve/_src/schedule.py delete mode 100644 optax_add_eve/_src/schedule_test.py delete mode 100644 optax_add_eve/_src/second_order_test.py delete mode 100644 optax_add_eve/_src/stochastic_gradient_estimators.py delete mode 100644 optax_add_eve/_src/stochastic_gradient_estimators_test.py delete mode 100644 optax_add_eve/_src/transform.py delete mode 100644 optax_add_eve/_src/transform_test.py delete mode 100644 optax_add_eve/_src/update.py delete mode 100644 optax_add_eve/_src/update_test.py delete mode 100644 optax_add_eve/_src/utils.py delete mode 100644 optax_add_eve/_src/utils_test.py delete mode 100644 optax_add_eve/_src/wrappers.py delete mode 100644 optax_add_eve/_src/wrappers_test.py delete mode 100644 optax_add_eve/experimental/__init__.py delete mode 100644 optax_add_eve/optax_test.py diff --git a/optax_add_eve/__init__.py b/optax_add_eve/__init__.py deleted file mode 100644 index ac576393..00000000 --- a/optax_add_eve/__init__.py +++ /dev/null @@ -1,347 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Optax: composable gradient processing and optimization, in JAX.""" - -from optax_add_eve import experimental -from optax_add_eve._src.alias import adabelief -from optax_add_eve._src.alias import adafactor -from optax_add_eve._src.alias import adagrad -from optax_add_eve._src.alias import adam -from optax_add_eve._src.alias import adamax -from optax_add_eve._src.alias import adamaxw -from optax_add_eve._src.alias import adamw -from optax_add_eve._src.alias import amsgrad -from optax_add_eve._src.alias import dpsgd -from optax_add_eve._src.alias import fromage -from optax_add_eve._src.alias import lamb -from optax_add_eve._src.alias import lars -from optax_add_eve._src.alias import MaskOrFn -from optax_add_eve._src.alias import noisy_sgd -from optax_add_eve._src.alias import novograd -from optax_add_eve._src.alias import optimistic_gradient_descent -from optax_add_eve._src.alias import radam -from optax_add_eve._src.alias import rmsprop -from optax_add_eve._src.alias import ScalarOrSchedule -from optax_add_eve._src.alias import sgd -from optax_add_eve._src.alias import sm3 -from optax_add_eve._src.alias import yogi -from optax_add_eve._src.base import EmptyState -from optax_add_eve._src.base import GradientTransformation -from optax_add_eve._src.base import identity -from optax_add_eve._src.base import OptState -from optax_add_eve._src.base import Params -from optax_add_eve._src.base import Schedule -from optax_add_eve._src.base import set_to_zero -from optax_add_eve._src.base import stateless -from optax_add_eve._src.base import stateless_with_tree_map -from optax_add_eve._src.base import TransformInitFn -from optax_add_eve._src.base import TransformUpdateFn -from optax_add_eve._src.base import Updates -from optax_add_eve._src.clipping import adaptive_grad_clip -from optax_add_eve._src.clipping import AdaptiveGradClipState -from optax_add_eve._src.clipping import clip -from optax_add_eve._src.clipping import clip_by_block_rms -from optax_add_eve._src.clipping import clip_by_global_norm -from optax_add_eve._src.clipping import ClipByGlobalNormState -from optax_add_eve._src.clipping import ClipState -from optax_add_eve._src.clipping import per_example_global_norm_clip -from optax_add_eve._src.combine import chain -from optax_add_eve._src.combine import multi_transform -from optax_add_eve._src.combine import MultiTransformState -from optax_add_eve._src.constrain import keep_params_nonnegative -from optax_add_eve._src.constrain import NonNegativeParamsState -from optax_add_eve._src.constrain import zero_nans -from optax_add_eve._src.constrain import ZeroNansState -from optax_add_eve._src.control_variates import control_delta_method -from optax_add_eve._src.control_variates import control_variates_jacobians -from optax_add_eve._src.control_variates import moving_avg_baseline -from optax_add_eve._src.factorized import FactoredState -from optax_add_eve._src.factorized import scale_by_factored_rms -from optax_add_eve._src.linear_algebra import global_norm -from optax_add_eve._src.linear_algebra import matrix_inverse_pth_root -from optax_add_eve._src.linear_algebra import power_iteration -from optax_add_eve._src.lookahead import lookahead -from optax_add_eve._src.lookahead import LookaheadParams -from optax_add_eve._src.lookahead import LookaheadState -from optax_add_eve._src.loss import cosine_distance -from optax_add_eve._src.loss import cosine_similarity -from optax_add_eve._src.loss import ctc_loss -from 
optax_add_eve._src.loss import ctc_loss_with_forward_probs -from optax_add_eve._src.loss import hinge_loss -from optax_add_eve._src.loss import huber_loss -from optax_add_eve._src.loss import l2_loss -from optax_add_eve._src.loss import log_cosh -from optax_add_eve._src.loss import sigmoid_binary_cross_entropy -from optax_add_eve._src.loss import smooth_labels -from optax_add_eve._src.loss import softmax_cross_entropy -from optax_add_eve._src.loss import softmax_cross_entropy_with_integer_labels -from optax_add_eve._src.numerics import safe_int32_increment -from optax_add_eve._src.numerics import safe_norm -from optax_add_eve._src.numerics import safe_root_mean_squares -from optax_add_eve._src.privacy import differentially_private_aggregate -from optax_add_eve._src.privacy import DifferentiallyPrivateAggregateState -from optax_add_eve._src.schedule import constant_schedule -from optax_add_eve._src.schedule import cosine_decay_schedule -from optax_add_eve._src.schedule import cosine_onecycle_schedule -from optax_add_eve._src.schedule import exponential_decay -from optax_add_eve._src.schedule import inject_hyperparams -from optax_add_eve._src.schedule import InjectHyperparamsState -from optax_add_eve._src.schedule import join_schedules -from optax_add_eve._src.schedule import linear_onecycle_schedule -from optax_add_eve._src.schedule import linear_schedule -from optax_add_eve._src.schedule import piecewise_constant_schedule -from optax_add_eve._src.schedule import piecewise_interpolate_schedule -from optax_add_eve._src.schedule import polynomial_schedule -from optax_add_eve._src.schedule import sgdr_schedule -from optax_add_eve._src.schedule import warmup_cosine_decay_schedule -from optax_add_eve._src.schedule import warmup_exponential_decay_schedule -from optax_add_eve._src.second_order import fisher_diag -from optax_add_eve._src.second_order import hessian_diag -from optax_add_eve._src.second_order import hvp -from optax_add_eve._src.stochastic_gradient_estimators import measure_valued_jacobians -from optax_add_eve._src.stochastic_gradient_estimators import pathwise_jacobians -from optax_add_eve._src.stochastic_gradient_estimators import score_function_jacobians -from optax_add_eve._src.transform import add_decayed_weights -from optax_add_eve._src.transform import add_noise -from optax_add_eve._src.transform import AddDecayedWeightsState -from optax_add_eve._src.transform import additive_weight_decay -from optax_add_eve._src.transform import AdditiveWeightDecayState -from optax_add_eve._src.transform import AddNoiseState -from optax_add_eve._src.transform import apply_every -from optax_add_eve._src.transform import ApplyEvery -from optax_add_eve._src.transform import bias_correction -from optax_add_eve._src.transform import centralize -from optax_add_eve._src.transform import ema -from optax_add_eve._src.transform import EmaState -from optax_add_eve._src.transform import scale -from optax_add_eve._src.transform import scale_by_adam -from optax_add_eve._src.transform import scale_by_adamax -from optax_add_eve._src.transform import scale_by_amsgrad -from optax_add_eve._src.transform import scale_by_belief -from optax_add_eve._src.transform import scale_by_novograd -from optax_add_eve._src.transform import scale_by_optimistic_gradient -from optax_add_eve._src.transform import scale_by_param_block_norm -from optax_add_eve._src.transform import scale_by_param_block_rms -from optax_add_eve._src.transform import scale_by_radam -from optax_add_eve._src.transform import scale_by_rms -from 
optax_add_eve._src.transform import scale_by_rss -from optax_add_eve._src.transform import scale_by_schedule -from optax_add_eve._src.transform import scale_by_sm3 -from optax_add_eve._src.transform import scale_by_stddev -from optax_add_eve._src.transform import scale_by_trust_ratio -from optax_add_eve._src.transform import scale_by_yogi -from optax_add_eve._src.transform import ScaleByAdamState -from optax_add_eve._src.transform import ScaleByAmsgradState -from optax_add_eve._src.transform import ScaleByBeliefState -from optax_add_eve._src.transform import ScaleByNovogradState -from optax_add_eve._src.transform import ScaleByRmsState -from optax_add_eve._src.transform import ScaleByRssState -from optax_add_eve._src.transform import ScaleByRStdDevState -from optax_add_eve._src.transform import ScaleByScheduleState -from optax_add_eve._src.transform import ScaleBySM3State -from optax_add_eve._src.transform import ScaleByTrustRatioState -from optax_add_eve._src.transform import ScaleState -from optax_add_eve._src.transform import trace -from optax_add_eve._src.transform import TraceState -from optax_add_eve._src.transform import update_infinity_moment -from optax_add_eve._src.transform import update_moment -from optax_add_eve._src.transform import update_moment_per_elem_norm -from optax_add_eve._src.update import apply_updates -from optax_add_eve._src.update import incremental_update -from optax_add_eve._src.update import periodic_update -from optax_add_eve._src.utils import multi_normal -from optax_add_eve._src.utils import scale_gradient -from optax_add_eve._src.wrappers import apply_if_finite -from optax_add_eve._src.wrappers import ApplyIfFiniteState -from optax_add_eve._src.wrappers import flatten -from optax_add_eve._src.wrappers import masked -from optax_add_eve._src.wrappers import MaskedNode -from optax_add_eve._src.wrappers import MaskedState -from optax_add_eve._src.wrappers import maybe_update -from optax_add_eve._src.wrappers import MaybeUpdateState -from optax_add_eve._src.wrappers import MultiSteps -from optax_add_eve._src.wrappers import MultiStepsState -from optax_add_eve._src.wrappers import ShouldSkipUpdateFunction -from optax_add_eve._src.wrappers import skip_large_updates -from optax_add_eve._src.wrappers import skip_not_finite - -__version__ = "0.1.5.dev" - -__all__ = ( - "adabelief", - "adafactor", - "adagrad", - "adam", - "adamax", - "adamaxw", - "adamw", - "adaptive_grad_clip", - "AdaptiveGradClipState", - "add_decayed_weights", - "add_noise", - "AddDecayedWeightsState", - "additive_weight_decay", - "AdditiveWeightDecayState", - "AddNoiseState", - "amsgrad", - "apply_every", - "apply_if_finite", - "apply_updates", - "ApplyEvery", - "ApplyIfFiniteState", - "centralize", - "chain", - "clip_by_block_rms", - "clip_by_global_norm", - "clip", - "ClipByGlobalNormState", - "ClipState", - "constant_schedule", - "ctc_loss", - "ctc_loss_with_forward_probs", - "control_delta_method", - "control_variates_jacobians", - "cosine_decay_schedule", - "cosine_distance", - "cosine_onecycle_schedule", - "cosine_similarity", - "differentially_private_aggregate", - "DifferentiallyPrivateAggregateState", - "dpsgd", - "ema", - "EmaState", - "EmptyState", - "exponential_decay", - "FactoredState", - "fisher_diag", - "flatten", - "fromage", - "global_norm", - "GradientTransformation", - "hinge_loss", - "hessian_diag", - "huber_loss", - "hvp", - "identity", - "incremental_update", - "inject_hyperparams", - "InjectHyperparamsState", - "join_schedules", - "keep_params_nonnegative", - "l2_loss", - 
"lamb", - "lars", - "linear_onecycle_schedule", - "linear_schedule", - "log_cosh", - "lookahead", - "LookaheadParams", - "LookaheadState", - "masked", - "MaskOrFn", - "MaskedState", - "matrix_inverse_pth_root", - "maybe_update", - "MaybeUpdateState", - "measure_valued_jacobians", - "moving_avg_baseline", - "multi_normal", - "multi_transform", - "MultiSteps", - "MultiStepsState", - "MultiTransformState", - "noisy_sgd", - "novograd", - "NonNegativeParamsState", - "OptState", - "Params", - "pathwise_jacobians", - "periodic_update", - "per_example_global_norm_clip", - "piecewise_constant_schedule", - "piecewise_interpolate_schedule", - "polynomial_schedule", - "power_iteration", - "radam", - "rmsprop", - "safe_int32_increment", - "safe_norm", - "safe_root_mean_squares", - "ScalarOrSchedule", - "scale_by_adam", - "scale_by_adamax", - "scale_by_amsgrad", - "scale_by_belief", - "scale_by_factored_rms", - "scale_by_novograd", - "scale_by_param_block_norm", - "scale_by_param_block_rms", - "scale_by_radam", - "scale_by_rms", - "scale_by_rss", - "scale_by_schedule", - "scale_by_sm3", - "scale_by_stddev", - "scale_by_trust_ratio", - "scale_by_yogi", - "scale_gradient", - "scale", - "ScaleByAdamState", - "ScaleByAmsgradState", - "ScaleByBeliefState", - "ScaleByNovogradState", - "ScaleByRmsState", - "ScaleByRssState", - "ScaleByRStdDevState", - "ScaleByScheduleState", - "ScaleBySM3State", - "ScaleByTrustRatioState", - "ScaleState", - "Schedule", - "score_function_jacobians", - "set_to_zero", - "sgd", - "sgdr_schedule", - "ShouldSkipUpdateFunction", - "sigmoid_binary_cross_entropy", - "skip_large_updates", - "skip_not_finite", - "sm3", - "smooth_labels", - "softmax_cross_entropy", - "stateless", - "stateless_with_tree_map", - "trace", - "TraceState", - "TransformInitFn", - "TransformUpdateFn", - "Updates", - "warmup_cosine_decay_schedule", - "warmup_exponential_decay_schedule", - "yogi", - "zero_nans", - "ZeroNansState", -) - -# _________________________________________ -# / Please don't use symbols in `_src` they \ -# \ are not part of the Optax public API. / -# ----------------------------------------- -# \ ^__^ -# \ (oo)\_______ -# (__)\ )\/\ -# ||----w | -# || || -# diff --git a/optax_add_eve/_src/alias.py b/optax_add_eve/_src/alias.py deleted file mode 100644 index b5935ae9..00000000 --- a/optax_add_eve/_src/alias.py +++ /dev/null @@ -1,926 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Aliases for popular optimizers.""" - -from typing import Any, Callable, Optional, Union - -import jax.numpy as jnp - -from optax_add_eve._src import base -from optax_add_eve._src import clipping -from optax_add_eve._src import combine -from optax_add_eve._src import factorized -from optax_add_eve._src import privacy -from optax_add_eve._src import transform -from optax_add_eve._src import wrappers - - -ScalarOrSchedule = Union[float, base.Schedule] -MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]] - - -def _scale_by_learning_rate(learning_rate: ScalarOrSchedule, flip_sign=True): - m = -1 if flip_sign else 1 - if callable(learning_rate): - return transform.scale_by_schedule(lambda count: m * learning_rate(count)) - return transform.scale(m * learning_rate) - - -def adabelief( - learning_rate: ScalarOrSchedule, - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-16, - eps_root: float = 1e-16) -> base.GradientTransformation: - """The AdaBelief optimizer. - - AdaBelief is an adaptive learning rate optimizer that focuses on fast - convergence, generalization, and stability. It adapts the step size depending - on its "belief" in the gradient direction — the optimizer adaptively scales - the step size by the difference between the predicted and observed gradients. - AdaBelief is a modified version of Adam and contains the same number of - parameters. - - References: - Zhuang et al, 2020: https://arxiv.org/abs/2010.07468 - - Args: - learning_rate: A fixed global scaling factor. - b1: Exponential decay rate to track the first moment of past gradients. - b2: Exponential decay rate to track the second moment of past gradients. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the second moment of the prediction error to - improve numerical stability. If backpropagating gradients through the - gradient transformation (e.g. for meta-learning), this must be non-zero. - - Returns: - The corresponding `GradientTransformation`. - """ - return combine.chain( - transform.scale_by_belief(b1=b1, b2=b2, eps=eps, eps_root=eps_root), - _scale_by_learning_rate(learning_rate), - ) - - -def adafactor( - learning_rate: Optional[ScalarOrSchedule] = None, - min_dim_size_to_factor: int = 128, - decay_rate: float = 0.8, - decay_offset: int = 0, - multiply_by_parameter_scale: float = True, - clipping_threshold: Optional[float] = 1.0, - momentum: Optional[float] = None, - dtype_momentum: Any = jnp.float32, - weight_decay_rate: Optional[float] = None, - eps: float = 1e-30, - factored: bool = True, - weight_decay_mask: MaskOrFn = None, - ) -> base.GradientTransformation: - """The Adafactor optimizer. - - Adafactor is an adaptive learning rate optimizer that focuses on fast - training of large scale neural networks. It saves memory by using a factored - estimate of the second order moments used to scale gradients. - - References: - Shazeer and Stern, 2018: https://arxiv.org/abs/1804.04235 - - Args: - learning_rate: A fixed global scaling factor. Note: the natural scale for - Adafactor's LR is markedly different from Adam, one doesn't use the - 1/sqrt(hidden) correction for this optim with attention-based models. - min_dim_size_to_factor: Only factor the statistics if two array dimensions - have at least this size. - decay_rate: Controls second-moment exponential decay schedule. 
- decay_offset: For fine-tuning, one may set this to the starting step - number of the fine-tuning phase. - multiply_by_parameter_scale: If True, then scale learning_rate by - parameter norm. If False, provided learning_rate is absolute step size. - clipping_threshold: Optional clipping threshold. Must be >= 1. If None, - clipping is disabled. - momentum: Optional value between 0 and 1, enables momentum and uses extra - memory if non-None! None by default. - dtype_momentum: Data type of momentum buffers. - weight_decay_rate: Optional rate at which to decay weights. - eps: Regularization constant for root mean squared gradient. - factored: Whether to use factored second-moment estimates. - weight_decay_mask: A tree with same structure as (or a prefix of) - the params PyTree, or a Callable that returns such a pytree given - the params/updates. The leaves should be booleans, `True` - for leaves/subtrees you want to apply the transformation to, - and `False` for those you want to skip. - - Returns: - The corresponding `GradientTransformation`. - """ - # The core of the algorithm is a procedure for rescaling gradients - # by a factored estimate of the root mean squared gradients. - # This reduces memory compared to algorithms such as Adam or RmsProp, - # by not having to hold a separate estimate for each weight. - tx = [ - factorized.scale_by_factored_rms( - factored, decay_rate, decay_offset, min_dim_size_to_factor, eps)] - # This basic rescaling is typically combined with one or more of the following - # transformation (all can be disabled via adafactor's constructor args). - if clipping_threshold is not None: - tx.append(clipping.clip_by_block_rms(clipping_threshold)) - if learning_rate is not None: - tx.append(_scale_by_learning_rate(learning_rate, flip_sign=False)) - if multiply_by_parameter_scale: - tx.append(transform.scale_by_param_block_rms()) - if momentum is not None: - tx.append( - transform.ema(momentum, debias=False, accumulator_dtype=dtype_momentum)) - if weight_decay_rate is not None: - tx.append(transform.add_decayed_weights( - weight_decay_rate, mask=weight_decay_mask)) - # In gradient "descent" we follow the negative gradient. - tx.append(transform.scale(-1)) - return combine.chain(*tx) - - -def adagrad( - learning_rate: ScalarOrSchedule, - initial_accumulator_value: float = 0.1, - eps: float = 1e-7 -) -> base.GradientTransformation: - """The Adagrad optimizer. - - Adagrad is an algorithm for gradient based optimization that anneals the - learning rate for each parameter during the course of training. - - WARNING: Adagrad's main limit is the monotonic accumulation of squared - gradients in the denominator: since all terms are >0, the sum keeps growing - during training and the learning rate eventually becomes vanishingly small. - - References: - Duchi et al, 2011: https://jmlr.org/papers/v12/duchi11a.html - - Args: - learning_rate: A fixed global scaling factor. - initial_accumulator_value: Initial value for the accumulator. - eps: A small constant applied to denominator inside of the square root - (as in RMSProp) to avoid dividing by zero when rescaling. - - Returns: - The corresponding `GradientTransformation`. 
- """ - return combine.chain( - transform.scale_by_rss( - initial_accumulator_value=initial_accumulator_value, eps=eps), - _scale_by_learning_rate(learning_rate), - ) - - -def adam( - learning_rate: ScalarOrSchedule, - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - mu_dtype: Optional[Any] = None, -) -> base.GradientTransformation: - r"""The classic Adam optimizer. - - Adam is an SGD variant with gradient scaling adaptation. The scaling - used for each parameter is computed from estimates of first and second-order - moments of the gradients (using suitable exponential moving averages). - - Let :math:`\alpha_t` represent the learning rate and :math:`\beta_1, \beta_2`, - :math:`\varepsilon`, :math:`\bar{\varepsilon}` represent the arguments - ``b1``, ``b2``, ``eps`` and ``eps_root`` respectievly. The learning rate is - indexed by :math:`t` since the learning rate may also be provided by a - schedule function. - - The ``init`` function of this optimizer initializes an internal state - :math:`S_0 := (m_0, v_0) = (0, 0)`, representing initial estimates for the - first and second moments. In practice these values are stored as pytrees - containing all zeros, with the same shape as the model updates. - At step :math:`t`, the ``update`` function of this optimizer takes as - arguments the incoming gradients :math:`g_t` and optimizer state :math:`S_t` - and computes updates :math:`u_t` and new state :math:`S_{t+1}`. Thus, for - :math:`t > 0`, we have, - - .. math:: - \begin{align*} - m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ - v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ - \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ - \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ - u_t &\leftarrow \alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + - \bar{\varepsilon}} + \varepsilon} \right)\\ - S_t &\leftarrow (m_t, v_t). - \end{align*} - - References: - Kingma et al, 2014: https://arxiv.org/abs/1412.6980 - - Args: - learning_rate: A fixed global scaling factor. - b1: Exponential decay rate to track the first moment of past gradients. - b2: Exponential decay rate to track the second moment of past gradients. - eps: A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. - eps_root: A small constant applied to denominator inside the square root (as - in RMSProp), to avoid dividing by zero when rescaling. This is needed for - example when computing (meta-)gradients through Adam. - mu_dtype: Optional `dtype` to be used for the first order accumulator; if - `None` then the `dtype` is inferred from `params` and `updates`. - - Returns: - The corresponding `GradientTransformation`. - """ - return combine.chain( - transform.scale_by_adam( - b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype), - _scale_by_learning_rate(learning_rate), - ) - - -def adamw( - learning_rate: ScalarOrSchedule, - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - mu_dtype: Optional[Any] = None, - weight_decay: float = 1e-4, - mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, -) -> base.GradientTransformation: - """Adam with weight decay regularization. - - AdamW uses weight decay to regularize learning towards small weights, as - this leads to better generalization. 
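# A small numerical check of the Adam update formulas documented above, added
# for illustration only. It is a minimal sketch assuming a hypothetical
# top-level import (`optax_add_eve`) that exposes the `adam` alias with the
# usual optax behaviour.
import jax.numpy as jnp

import optax_add_eve as optax  # hypothetical top-level import for this fork

g = jnp.array([0.5, -2.0])
opt = optax.adam(learning_rate=0.1, b1=0.9, b2=0.999, eps=1e-8)
state = opt.init(jnp.zeros_like(g))
updates, _ = opt.update(g, state)

# On the first step the bias corrections give m_hat == g and v_hat == g**2,
# so (after the sign flip applied by the learning-rate scaling) the update is
# approximately -lr * g / (|g| + eps).
expected = -0.1 * g / (jnp.abs(g) + 1e-8)
assert jnp.allclose(updates, expected, atol=1e-6)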
In SGD you can also use L2 regularization - to implement this as an additive loss term, however L2 regularization - does not behave as intended for adaptive gradient algorithms such as Adam. - - References: - Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101 - - Args: - learning_rate: A fixed global scaling factor. - b1: Exponential decay rate to track the first moment of past gradients. - b2: Exponential decay rate to track the second moment of past gradients. - eps: A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. - eps_root: A small constant applied to denominator inside the square root (as - in RMSProp), to avoid dividing by zero when rescaling. This is needed for - instance when computing (meta-)gradients through Adam. - mu_dtype: Optional `dtype` to be used for the first order accumulator; if - `None` then the `dtype` is inferred from `params` and `updates`. - weight_decay: Strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. This is consistent - with other frameworks such as PyTorch, but different from - (Loshchilov et al, 2019) where the weight decay is only multiplied with - the "schedule multiplier", but not the base learning rate. - mask: A tree with same structure as (or a prefix of) the params PyTree, - or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the weight decay to, and `False` for those you want to skip. Note - that the Adam gradient transformations are applied to all parameters. - - Returns: - The corresponding `GradientTransformation`. - """ - return combine.chain( - transform.scale_by_adam( - b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype), - transform.add_decayed_weights(weight_decay, mask), - _scale_by_learning_rate(learning_rate), - ) - - -def amsgrad( - learning_rate: ScalarOrSchedule, - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - mu_dtype: Optional[Any] = None, -) -> base.GradientTransformation: - """The AMSGrad optimiser. - - The original Adam can fail to converge to the optimal solution in some cases. - AMSGrad guarantees convergence by using a long-term memory of past gradients. - - References: - Reddi et al, 2018: https://openreview.net/forum?id=ryQu7f-RZ - - Args: - learning_rate: A fixed global scaling factor. - b1: Exponential decay rate to track the first moment of past gradients. - b2: Exponential decay rate to track the second moment of past gradients. - eps: A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. - eps_root: A small constant applied to denominator inside the square root (as - in RMSProp), to avoid dividing by zero when rescaling. This is needed for - instance when computing (meta-)gradients through Adam. - mu_dtype: Optional `dtype` to be used for the first order accumulator; if - `None` then the `dtype` is inferred from `params` and `updates`. - - Returns: - The corresponding `GradientTransformation`. 
-  """
-  return combine.chain(
-      transform.scale_by_amsgrad(
-          b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype),
-      _scale_by_learning_rate(learning_rate),
-  )
-
-
-def eve(
-    learning_rate: float = 1e-3,
-    b1: float = 0.9,
-    b2: float = 0.999,
-    b3: float = 0.999,
-    c: float = 10.,
-    eps: float = 1e-8,
-    f_star: float = 0.,
-    mu_dtype: Optional[Any] = None,
-) -> base.GradientTransformation:
-  """The Eve optimizer.
-
-  Eve is an SGD variant with adaptive global and local learning rates. The
-  local scaling used for each weight is computed from estimates of first- and
-  second-order moments of the gradients (using suitable exponential moving
-  averages), as in Adam. In addition, the global learning rate is scaled by a
-  measure of sub-optimality of the current loss: it is increased when the loss
-  is far from `f_star` and decreased as the loss approaches it. Unlike most
-  optimizers, Eve's `update` function also requires the current value of the
-  loss, which it uses to track this sub-optimality.
-
-  References:
-    Koushik and Hayashi, 2016: https://arxiv.org/abs/1611.01505
-
-  Args:
-    learning_rate: The initial global scaling factor.
-    b1: Exponential decay rate to track the first moment of past gradients.
-    b2: Exponential decay rate to track the second moment of past gradients.
-    b3: Exponential decay rate to track the sub-optimality.
-    c: Clipping limit used to prevent extreme changes of the global learning
-      rate.
-    eps: A small constant applied to denominator outside of the square root
-      (as in the Adam paper) to avoid dividing by zero when rescaling.
-    f_star: An estimate of the global minimum of the loss.
-    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
-      `None` then the `dtype` is inferred from `params` and `updates`.
-
-  Returns:
-    The corresponding `GradientTransformation`.
-  """
-  return combine.chain(
-      transform.scale_by_eve(
-          b1=b1, b2=b2, b3=b3, c=c, eps=eps, f_star=f_star, mu_dtype=mu_dtype),
-      _scale_by_learning_rate(learning_rate),
-  )
-
-
-def fromage(
-    learning_rate: float,
-    min_norm: float = 1e-6
-) -> base.GradientTransformation:
-  """The Frobenius matched gradient descent (Fromage) optimizer.
-
-  Fromage is a learning algorithm that does not require learning rate tuning.
-  The optimizer is based on modeling neural network gradients via deep relative
-  trust (a distance function on deep neural networks). Fromage is similar to the
-  LARS optimizer and can work on a range of standard neural network benchmarks,
-  such as natural language Transformers and generative adversarial networks.
-
-  References:
-    Bernstein et al, 2020: https://arxiv.org/abs/2002.03432
-
-  Args:
-    learning_rate: A fixed global scaling factor.
-    min_norm: A minimum value that the norm of the gradient updates and the norm
-      of the layer parameters can be clipped to, to avoid dividing by zero when
-      computing the trust ratio (as in the LARS paper).
-
-  Returns:
-    The corresponding `GradientTransformation`.
-  """
-  mult = 1 / jnp.sqrt(1 + learning_rate ** 2)
-  return combine.chain(
-      transform.scale_by_trust_ratio(min_norm),
-      _scale_by_learning_rate(learning_rate * mult),
-      transform.add_decayed_weights((mult - 1)),
-  )
-
-
-def lars(
-    learning_rate: ScalarOrSchedule,
-    weight_decay: float = 0.,
-    weight_decay_mask: MaskOrFn = True,
-    trust_coefficient: float = 0.001,
-    eps: float = 0.,
-    trust_ratio_mask: MaskOrFn = True,
-    momentum: float = 0.9,
-    nesterov: bool = False,
-) -> base.GradientTransformation:
-  """The LARS optimizer.
-
-  LARS is a layer-wise adaptive optimizer introduced to help scale SGD to
-  larger batch sizes. LARS later inspired the LAMB optimizer.
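# A minimal training-loop sketch for the `eve` alias defined above. This is
# illustrative only: it assumes a hypothetical top-level export
# (`optax_add_eve.eve`) and passes the current loss value through the third
# argument of `update`, which is the slot that `combine.chain` forwards to
# `scale_by_eve` in this patch. The step is left un-jitted for simplicity.
import jax
import jax.numpy as jnp

import optax_add_eve as optax  # hypothetical top-level import for this fork


def loss_fn(params, x, y):
  return jnp.mean((x @ params - y) ** 2)


params = jnp.zeros((3,))
optimizer = optax.eve(learning_rate=1e-3)
opt_state = optimizer.init(params)


def train_step(params, opt_state, x, y):
  loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
  # Unlike Adam, Eve consumes the current loss value to update its
  # sub-optimality estimate `d`, which rescales the global learning rate.
  updates, opt_state = optimizer.update(grads, opt_state, loss)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss


x, y = jnp.ones((8, 3)), jnp.zeros((8,))
params, opt_state, loss = train_step(params, opt_state, x, y)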
- - References: - You et al, 2017: https://arxiv.org/abs/1708.03888 - - Args: - learning_rate: A fixed global scaling factor. - weight_decay: Strength of the weight decay regularization. - weight_decay_mask: A tree with same structure as (or a prefix of) the params - PyTree, or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the transformation to, and `False` for those you want to skip. - trust_coefficient: A multiplier for the trust ratio. - eps: Optional additive constant in the trust ratio denominator. - trust_ratio_mask: A tree with same structure as (or a prefix of) the params - PyTree, or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the transformation to, and `False` for those you want to skip. - momentum: Decay rate for momentum. - nesterov: Whether to use Nesterov momentum. - - Returns: - The corresponding `GradientTransformation`. - """ - return combine.chain( - transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), - wrappers.masked( - inner=transform.scale_by_trust_ratio( - trust_coefficient=trust_coefficient, eps=eps), - mask=trust_ratio_mask), - _scale_by_learning_rate(learning_rate), - transform.trace(decay=momentum, nesterov=nesterov), - ) - - -def lamb( - learning_rate: ScalarOrSchedule, - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-6, - eps_root: float = 0.0, - weight_decay: float = 0., - mask: MaskOrFn = None, -) -> base.GradientTransformation: - """The LAMB optimizer. - - LAMB is a general purpose layer-wise adaptive large batch optimizer designed - to provide consistent training performance across a wide range of tasks, - including those that use attention-based models (such as Transformers) and - ResNet-50. The optimizer is able to work with small and large batch sizes. - LAMB was inspired by the LARS learning algorithm. - - References: - You et al, 2019: https://arxiv.org/abs/1904.00962 - - Args: - learning_rate: A fixed global scaling factor. - b1: Exponential decay rate to track the first moment of past gradients. - b2: Exponential decay rate to track the second moment of past gradients. - eps: A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. - eps_root: A small constant applied to denominator inside the square root (as - in RMSProp), to avoid dividing by zero when rescaling. This is needed for - instance when computing (meta-)gradients through Adam. - weight_decay: Strength of the weight decay regularization. - mask: A tree with same structure as (or a prefix of) the params PyTree, - or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the transformation to, and `False` for those you want to skip. - - Returns: - The corresponding `GradientTransformation`. - """ - return combine.chain( - transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root), - transform.add_decayed_weights(weight_decay=weight_decay, mask=mask), - transform.scale_by_trust_ratio(), - _scale_by_learning_rate(learning_rate), - ) - - -def noisy_sgd( - learning_rate: ScalarOrSchedule, - eta: float = 0.01, - gamma: float = 0.55, - seed: int = 0 -) -> base.GradientTransformation: - r"""A variant of SGD with added noise. 
- - It has been found that adding noise to the gradients can improve - both the training error and the generalization error in very deep networks. - - References: - Neelakantan et al, 2014: https://arxiv.org/abs/1511.06807 - - Args: - learning_rate: A fixed global scaling factor. - eta: Initial variance for the Gaussian noise added to gradients. - gamma: A parameter controlling the annealing of noise over time, the - variance decays according to `(1+t)^-\gamma`. - seed: Seed for the pseudo-random generation process. - - Returns: - The corresponding `GradientTransformation`. - """ - return combine.chain( - transform.add_noise(eta, gamma, seed), - _scale_by_learning_rate(learning_rate), - ) - - -def novograd( - learning_rate: ScalarOrSchedule, - b1: float = 0.9, - b2: float = 0.25, - eps: float = 1e-6, - eps_root: float = 0.0, - weight_decay: float = 0., -) -> base.GradientTransformation: - """NovoGrad optimizer. - - NovoGrad is more robust to the initial learning rate and - weight initialization than other methods. For example, - NovoGrad works well without LR warm-up, while other methods require it. - NovoGrad performs exceptionally well for large batch training, e.g. it - outperforms other methods for ResNet-50 for all batches up to 32K. - In addition, NovoGrad requires half the memory compared to Adam. - It was introduced together with Jasper ASR model. - - References: - Ginsburg et al, 2019: https://arxiv.org/abs/1905.11286 - Li et al, 2019: https://arxiv.org/abs/1904.03288 - - Args: - learning_rate: A fixed global scaling factor. - b1: An exponential decay rate to track the first moment of past gradients. - b2: An exponential decay rate to track the second moment of past gradients. - eps: A small constant applied to denominator outside of the square root (as - in the Adam paper) to avoid dividing by zero when rescaling. - eps_root: A small constant applied to denominator inside - the square root (as in RMSProp), to avoid dividing by zero when rescaling. - This is needed for instance when computing (meta-)gradients through Adam. - weight_decay: Strength of the weight decay regularization. - - Returns: - The corresponding `GradientTransformation`. - """ - return combine.chain( - transform.scale_by_novograd( - b1=b1, b2=b2, eps=eps, eps_root=eps_root, weight_decay=weight_decay), - _scale_by_learning_rate(learning_rate), - ) - - -def optimistic_gradient_descent( - learning_rate: ScalarOrSchedule, - alpha: ScalarOrSchedule = 1.0, - beta: ScalarOrSchedule = 1.0 -) -> base.GradientTransformation: - """An Optimistic Gradient Descent optimizer. - - Optimistic gradient descent is an approximation of extra-gradient methods - which require multiple gradient calls to compute the next update. It has - strong formal guarantees for last-iterate convergence in min-max games, for - which standard gradient descent can oscillate or even diverge. - - References: - Mokhtari et al, 2019: https://arxiv.org/abs/1901.08511v2 - - Args: - learning_rate: A fixed global scaling factor. - alpha: Coefficient for generalized OGD. - beta: Coefficient for generalized OGD negative momentum. - - Returns: - A `GradientTransformation`. - """ - return combine.chain( - transform.scale_by_optimistic_gradient(alpha=alpha, beta=beta), - _scale_by_learning_rate(learning_rate) - ) - - -def radam( - learning_rate: ScalarOrSchedule, - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - threshold: float = 5.0 -) -> base.GradientTransformation: - """The Rectified Adam optimizer. 
- - The adaptive learning rate in Adam has undesirably large variance in early - stages of training, due to the limited number of training samples used to - estimate the optimizer's statistics. Rectified Adam addresses this issue - by analytically reducing the large variance. - - References: - Kingma et al, 2014: https://arxiv.org/abs/1412.6980 - - Args: - learning_rate: A fixed global scaling factor. - b1: Exponential decay rate to track the first moment of past gradients. - b2: Exponential decay rate to track the second moment of past gradients. - eps: A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. - eps_root: A small constant applied to denominator inside the square root (as - in RMSProp), to avoid dividing by zero when rescaling. This is needed for - instance when computing (meta-)gradients through Adam. - threshold: Threshold for variance tractability. - - Returns: - The corresponding `GradientTransformation`. - """ - return combine.chain( - transform.scale_by_radam( - b1=b1, b2=b2, eps=eps, eps_root=eps_root, threshold=threshold), - _scale_by_learning_rate(learning_rate), - ) - - -def rmsprop( - learning_rate: ScalarOrSchedule, - decay: float = 0.9, - eps: float = 1e-8, - initial_scale: float = 0., - centered: bool = False, - momentum: Optional[float] = None, - nesterov: bool = False -) -> base.GradientTransformation: - # pylint: disable=line-too-long - """A flexible RMSProp optimizer. - - RMSProp is an SGD variant with learning rate adaptation. The `learning_rate` - used for each weight is scaled by a suitable estimate of the magnitude of the - gradients on previous steps. Several variants of RMSProp can be found - in the literature. This alias provides an easy to configure RMSProp - optimizer that can be used to switch between several of these variants. - - References: - Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf - Graves, 2013: https://arxiv.org/abs/1308.0850 - - Args: - learning_rate: A fixed global scaling factor. - decay: Decay used to track the magnitude of previous gradients. - eps: A small numerical constant to avoid dividing by zero when rescaling. - initial_scale: Initial value of accumulators tracking the magnitude of - previous updates. PyTorch uses `0`, TF1 uses `1`. When reproducing results - from a paper, verify the value used by the authors. - centered: Whether the second moment or the variance of the past gradients is - used to rescale the latest gradients. - momentum: Decay rate used by the momentum term, when it is set to `None`, - then momentum is not used at all. - nesterov: Whether Nesterov momentum is used. - - Returns: - The corresponding `GradientTransformation`. 
- """ - # pylint: enable=line-too-long - if centered: - return combine.chain( - transform.scale_by_stddev( - decay=decay, eps=eps, initial_scale=initial_scale), - _scale_by_learning_rate(learning_rate), - (transform.trace(decay=momentum, nesterov=nesterov) - if momentum is not None else base.identity()) - ) - return combine.chain( - transform.scale_by_rms( - decay=decay, eps=eps, initial_scale=initial_scale), - _scale_by_learning_rate(learning_rate), - (transform.trace(decay=momentum, nesterov=nesterov) - if momentum is not None else base.identity()) - ) - - -def sgd( - learning_rate: ScalarOrSchedule, - momentum: Optional[float] = None, - nesterov: bool = False, - accumulator_dtype: Optional[Any] = None, -) -> base.GradientTransformation: - """A canonical Stochastic Gradient Descent optimizer. - - This implements stochastic gradient descent. It also includes support for - momentum, and nesterov acceleration, as these are standard practice when - using stochastic gradient descent to train deep neural networks. - - References: - Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf - - Args: - learning_rate: A fixed global scaling factor. - momentum: Decay rate used by the momentum term, when it is set to `None`, - then momentum is not used at all. - nesterov: Whether Nesterov momentum is used. - accumulator_dtype: Optional `dtype` to be used for the accumulator; if - `None` then the `dtype` is inferred from `params` and `updates`. - - Returns: - A `GradientTransformation`. - """ - return combine.chain( - (transform.trace(decay=momentum, nesterov=nesterov, - accumulator_dtype=accumulator_dtype) - if momentum is not None else base.identity()), - _scale_by_learning_rate(learning_rate) - ) - - -def sm3( - learning_rate: float, - momentum: float = 0.9 -) -> base.GradientTransformation: - """The SM3 optimizer. - - SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) is a - memory-efficient adaptive optimizer designed to decrease memory overhead when - training very large models, such as the Transformer for machine translation, - BERT for language modeling, and AmoebaNet-D for image classification. SM3: 1) - applies to tensors of arbitrary dimensions and any predefined cover of the - parameters; 2) adapts the learning rates in an adaptive and data-driven manner - (like Adagrad and unlike Adafactor); and 3) comes with rigorous convergence - guarantees in stochastic convex optimization settings. - - References: - Anil et al, 2019: https://arxiv.org/abs/1901.11150 - - Args: - learning_rate: A fixed global scaling factor. - momentum: Decay rate used by the momentum term (when it is not set to - `None`, then momentum is not used at all). - - Returns: - The corresponding `GradientTransformation`. - """ - return combine.chain( - transform.scale_by_sm3(momentum), - transform.scale(-learning_rate), - ) - - -def yogi( - learning_rate: ScalarOrSchedule, - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-3, -) -> base.GradientTransformation: - # pylint: disable=line-too-long - """The Yogi optimizer. - - Yogi is an adaptive optimizer, which provides control in tuning the effective - learning rate to prevent it from increasing. By doing so, it focuses on - addressing the issues of convergence and generalization in exponential moving - average-based adaptive methods (such as Adam and RMSprop). Yogi is a - modification of Adam and uses the same parameters. 
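# A short sketch of how the `rmsprop` alias above switches between the common
# RMSProp variants; illustrative only, assuming a hypothetical top-level
# import for this fork.
import optax_add_eve as optax  # hypothetical top-level import for this fork

# PyTorch-style RMSProp: the accumulator starts at 0 (default initial_scale).
torch_like = optax.rmsprop(learning_rate=1e-3, decay=0.9)

# TF1-style RMSProp: the accumulator starts at 1.
tf1_like = optax.rmsprop(learning_rate=1e-3, decay=0.9, initial_scale=1.0)

# Centered variant with heavy-ball momentum, similar to the Graves (2013)
# setup referenced in the docstring above.
centered = optax.rmsprop(
    learning_rate=1e-4, decay=0.95, centered=True, momentum=0.9)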
- - References: - Zaheer et al, 2018: https://proceedings.neurips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf - - Args: - learning_rate: A fixed global scaling factor. - b1: Exponential decay rate to track the first moment of past gradients. - b2: Exponential decay rate to track the second moment of past gradients. - eps: A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. - - Returns: - The corresponding `GradientTransformation`. - """ - # pylint: enable=line-too-long - return combine.chain( - transform.scale_by_yogi(b1=b1, b2=b2, eps=eps), - _scale_by_learning_rate(learning_rate), - ) - - -def dpsgd( - learning_rate: ScalarOrSchedule, - l2_norm_clip: float, - noise_multiplier: float, - seed: int, - momentum: Optional[float] = None, - nesterov: bool = False -) -> base.GradientTransformation: - """The DPSGD optimizer. - - Differential privacy is a standard for privacy guarantees of algorithms - learning from aggregate databases including potentially sensitive information. - DPSGD offers protection against a strong adversary with full knowledge of the - training mechanism and access to the model’s parameters. - - WARNING: This `GradientTransformation` expects input updates to have a batch - dimension on the 0th axis. That is, this function expects per-example - gradients as input (which are easy to obtain in JAX using `jax.vmap`). - - References: - Abadi et al, 2016: https://arxiv.org/abs/1607.00133 - - Args: - learning_rate: A fixed global scaling factor. - l2_norm_clip: Maximum L2 norm of the per-example gradients. - noise_multiplier: Ratio of standard deviation to the clipping norm. - seed: Initial seed used for the jax.random.PRNGKey - momentum: Decay rate used by the momentum term, when it is set to `None`, - then momentum is not used at all. - nesterov: Whether Nesterov momentum is used. - - Returns: - A `GradientTransformation`. - """ - return combine.chain( - privacy.differentially_private_aggregate( - l2_norm_clip=l2_norm_clip, - noise_multiplier=noise_multiplier, - seed=seed), - (transform.trace(decay=momentum, nesterov=nesterov) - if momentum is not None else base.identity()), - _scale_by_learning_rate(learning_rate) - ) - - -def adamax( - learning_rate: ScalarOrSchedule, - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, -) -> base.GradientTransformation: - """A variant of the Adam optimizer that uses the infinity norm. - - References: - Kingma et al, 2014: https://arxiv.org/abs/1412.6980 - - Args: - learning_rate: A fixed global scaling factor. - b1: Exponential decay rate to track the first moment of past gradients. - b2: Exponential decay rate to track the maximum of past gradients. - eps: A small constant applied to denominator to avoid dividing by zero when - rescaling. - - Returns: - The corresponding `GradientTransformation`. - """ - return combine.chain( - transform.scale_by_adamax(b1=b1, b2=b2, eps=eps,), - _scale_by_learning_rate(learning_rate), - ) - - -def adamaxw( - learning_rate: ScalarOrSchedule, - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - weight_decay: float = 1e-4, - mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, -) -> base.GradientTransformation: - """Adamax with weight decay regularization. - - AdamaxW uses weight decay to regularize learning towards small weights, as - this leads to better generalization. 
In SGD you can also use L2 regularization - to implement this as an additive loss term, however L2 regularization - does not behave as intended for adaptive gradient algorithms such as Adam. - - WARNING: Sometimes you may want to skip weight decay for BatchNorm scale or - for the bias parameters. You can use `optax.masked` to make your own AdamaxW - variant where `additive_weight_decay` is applied only to a subset of `params`. - - References: - Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101 - - Args: - learning_rate: A fixed global scaling factor. - b1: Exponential decay rate to track the first moment of past gradients. - b2: Exponential decay rate to track the maximum of past gradients. - eps: A small constant applied to denominator to avoid dividing by zero when - rescaling. - weight_decay: Strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. This is consistent - with other frameworks such as PyTorch, but different from - (Loshchilov et al, 2019) where the weight decay is only multiplied with - the "schedule multiplier", but not the base learning rate. - mask: A tree with same structure as (or a prefix of) the params PyTree, - or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the weight decay to, and `False` for those you want to skip. Note - that the Adamax gradient transformations are applied to all parameters. - - Returns: - The corresponding `GradientTransformation`. - """ - return combine.chain( - transform.scale_by_adamax(b1=b1, b2=b2, eps=eps), - transform.add_decayed_weights(weight_decay, mask), - _scale_by_learning_rate(learning_rate), - ) diff --git a/optax_add_eve/_src/alias_test.py b/optax_add_eve/_src/alias_test.py deleted file mode 100644 index 46f0643d..00000000 --- a/optax_add_eve/_src/alias_test.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
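# The AdamW / AdamaxW docstrings above describe masking out weight decay for
# some parameters (e.g. biases or BatchNorm scales). Below is a minimal sketch
# of that pattern, using a hypothetical two-leaf parameter tree and a
# hypothetical top-level import for this fork.
import jax
import jax.numpy as jnp

import optax_add_eve as optax  # hypothetical top-level import for this fork

params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros((3,))}

# `True` leaves receive weight decay, `False` leaves are skipped (here: bias).
decay_mask = {'w': True, 'b': False}

opt = optax.adamaxw(learning_rate=1e-3, weight_decay=1e-4, mask=decay_mask)
opt_state = opt.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)  # dummy gradients
updates, opt_state = opt.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)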
-# ============================================================================== -"""Tests for `alias.py`.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import chex -import jax -import jax.numpy as jnp - -from optax_add_eve._src import alias -from optax_add_eve._src import numerics -from optax_add_eve._src import schedule -from optax_add_eve._src import update - -_OPTIMIZERS_UNDER_TEST = ( - dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1e-3, momentum=0.9)), - dict(opt_name='adafactor', opt_kwargs=dict(learning_rate=5e-3)), - dict(opt_name='adagrad', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adam', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='adamw', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='lars', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1e-3)), - dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1e-3, eta=1e-4)), - dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1e-3)), - dict( - opt_name='optimistic_gradient_descent', - opt_kwargs=dict(learning_rate=2e-3, alpha=0.7, beta=0.1)), - dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=5e-3)), - dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=5e-3, momentum=0.9)), - dict(opt_name='fromage', opt_kwargs=dict(learning_rate=5e-3)), - dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1e-2)), - dict(opt_name='radam', opt_kwargs=dict(learning_rate=5e-3)), - dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1e-1)), - dict( - opt_name='dpsgd', - opt_kwargs=dict( - learning_rate=1e-3, - l2_norm_clip=10., - noise_multiplier=1e-3, - seed=0, - momentum=0.2)), -) - - -def _setup_parabola(dtype): - """Quadratic function as an optimization target.""" - initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype) - final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype) - - if jnp.iscomplexobj(dtype): - final_params *= 1 + 1j - - @jax.grad - def get_updates(params): - return jnp.sum(numerics.abs_sq(params - final_params)) - - return initial_params, final_params, get_updates - - -def _setup_rosenbrock(dtype): - """Rosenbrock function as an optimization target.""" - a = 1.0 - b = 100.0 - - if jnp.iscomplexobj(dtype): - a *= 1 + 1j - - initial_params = jnp.array([0.0, 0.0], dtype=dtype) - final_params = jnp.array([a, a**2], dtype=dtype) - - @jax.grad - def get_updates(params): - return (numerics.abs_sq(a - params[0]) + - b * numerics.abs_sq(params[1] - params[0]**2)) - - return initial_params, final_params, get_updates - - -class AliasTest(chex.TestCase): - - @parameterized.product( - _OPTIMIZERS_UNDER_TEST, - target=(_setup_parabola, _setup_rosenbrock), - dtype=(jnp.float32, jnp.complex64), - ) - def test_optimization(self, opt_name, opt_kwargs, target, dtype): - if (opt_name - in ('fromage', 'noisy_sgd', 'sm3', 'optimistic_gradient_descent') and - jnp.iscomplexobj(dtype)): - raise absltest.SkipTest( - f'{opt_name} does not support complex parameters.') - - opt = getattr(alias, opt_name)(**opt_kwargs) - initial_params, final_params, get_updates = target(dtype) - - @jax.jit - def step(params, state): - updates = get_updates(params) - if opt_name == 'dpsgd': - updates = updates[None] - # Complex gradients need to be conjugated before being added to 
parameters - # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 - updates = jax.tree_util.tree_map(lambda x: x.conj(), updates) - updates, state = opt.update(updates, state, params) - params = update.apply_updates(params, updates) - return params, state - - params = initial_params - state = opt.init(params) - for _ in range(10000): - params, state = step(params, state) - - chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2) - - @chex.all_variants - @parameterized.product(_OPTIMIZERS_UNDER_TEST) - def test_optimizers_can_be_wrapped_in_inject_hyperparams( - self, opt_name, opt_kwargs): - """Checks that optimizers can be wrapped in inject_hyperparams.""" - # See also https://github.com/deepmind/optax/issues/412. - opt_factory = getattr(alias, opt_name) - opt = opt_factory(**opt_kwargs) - if opt_name == 'adafactor': - # Adafactor wrapped in inject_hyperparams currently needs a static - # argument to be specified in order to be jittable. See issue - # https://github.com/deepmind/optax/issues/412. - opt_inject = schedule.inject_hyperparams( - opt_factory, static_args=('min_dim_size_to_factor',))(**opt_kwargs) - else: - opt_inject = schedule.inject_hyperparams(opt_factory)(**opt_kwargs) - - params = [-jnp.ones((2, 3)), jnp.ones((2, 5, 2))] - grads = [jnp.ones((2, 3)), -jnp.ones((2, 5, 2))] - - state = self.variant(opt.init)(params) - updates, new_state = self.variant(opt.update)(grads, state, params) - - state_inject = self.variant(opt_inject.init)(params) - updates_inject, new_state_inject = self.variant(opt_inject.update)( - grads, state_inject, params) - - with self.subTest('Equality of updates.'): - chex.assert_trees_all_close(updates_inject, updates, rtol=1e-4) - with self.subTest('Equality of new optimizer states.'): - chex.assert_trees_all_close( - new_state_inject.inner_state, new_state, rtol=1e-4) - - @parameterized.named_parameters([ - ('float32', 'float32'), - ('bfloat16', 'bfloat16'), - ('complex64', 'complex64'), - ('None', None), - ]) - def test_explicit_dtype(self, dtype): - expected_dtype = jax.dtypes.canonicalize_dtype(dtype) # None -> float32 - tx = alias.sgd(0.1, momentum=0.9, accumulator_dtype=dtype) - trace_state, _ = tx.init(jnp.array([0.0, 0.0])) - self.assertEqual(expected_dtype, trace_state.trace.dtype) - tx = alias.adam(0.1, mu_dtype=dtype) - adam_state, _ = tx.init(jnp.array([0.0, 0.0])) - self.assertEqual(expected_dtype, adam_state.mu.dtype) - tx = alias.adamw(0.1, mu_dtype=dtype) - adam_state, _, _ = tx.init(jnp.array([0.0, 0.0])) - self.assertEqual(expected_dtype, adam_state.mu.dtype) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/base_test.py b/optax_add_eve/_src/base_test.py deleted file mode 100644 index 65c898b4..00000000 --- a/optax_add_eve/_src/base_test.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
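# The test above wraps every alias in `inject_hyperparams`. A minimal sketch
# of why that is useful follows: hyperparameters become runtime state that can
# be read or overridden between steps. It assumes a hypothetical top-level
# import for this fork.
import jax.numpy as jnp

import optax_add_eve as optax  # hypothetical top-level import for this fork

opt = optax.inject_hyperparams(optax.adam)(learning_rate=1e-3)

params = {'w': jnp.ones((2, 3))}
state = opt.init(params)
state.hyperparams['learning_rate'] = 1e-4  # e.g. decay the step size mid-run
grads = {'w': jnp.ones((2, 3))}
updates, state = opt.update(grads, state, params)

# As noted in the test, adafactor additionally needs
# `static_args=('min_dim_size_to_factor',)` to remain jittable when wrapped.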
-# ============================================================================== -"""Tests for base.py.""" - -from absl.testing import absltest - -import chex -import jax -import jax.numpy as jnp -import numpy as np - -from optax_add_eve._src import base - -# pylint:disable=no-value-for-parameter - - -class BaseTest(chex.TestCase): - - def test_typing(self): - """Ensure that the type annotations work for the update function.""" - - def f(updates, opt_state, params=None): - del params - return updates, opt_state - - def g(f: base.TransformUpdateFn): - updates = np.zeros([]) - params = np.zeros([]) - opt_state = np.zeros([]) - - f(updates, opt_state) - f(updates, opt_state, params) - f(updates, opt_state, params=params) - - g(f) - - @chex.all_variants - def test_set_to_zero_returns_tree_of_correct_zero_arrays(self): - """Tests that zero transform returns a tree of zeros of correct shape.""" - grads = ({'a': np.ones((3, 4)), 'b': 1.}, np.ones((1, 2, 3))) - updates, _ = self.variant(base.set_to_zero().update)(grads, - base.EmptyState()) - correct_zeros = ({'a': np.zeros((3, 4)), 'b': 0.}, np.zeros((1, 2, 3))) - chex.assert_trees_all_close(updates, correct_zeros, rtol=0) - - @chex.all_variants(with_pmap=False) - def test_set_to_zero_is_stateless(self): - """Tests that the zero transform returns an empty state.""" - self.assertEqual( - self.variant(base.set_to_zero().init)(params=None), base.EmptyState()) - - -class StatelessTest(chex.TestCase): - """Tests for the stateless transformation.""" - - @chex.all_variants - def test_stateless(self): - params = {'a': jnp.zeros((1, 2)), 'b': jnp.ones((1,))} - updates = {'a': jnp.ones((1, 2)), 'b': jnp.full((1,), 2.0)} - - @base.stateless - def opt(g, p): - return jax.tree_util.tree_map(lambda g_, p_: g_ + 0.1 * p_, g, p) - - state = opt.init(params) - update_fn = self.variant(opt.update) - new_updates, _ = update_fn(updates, state, params) - expected_updates = {'a': jnp.ones((1, 2)), 'b': jnp.array([2.1])} - chex.assert_trees_all_close(new_updates, expected_updates) - - @chex.all_variants - def test_stateless_no_params(self): - updates = {'linear': jnp.full((5, 3), 3.0)} - - @base.stateless - def opt(g, _): - return jax.tree_util.tree_map(lambda g_: g_ * 2, g) - - state = opt.init(None) - update_fn = self.variant(opt.update) - new_updates, _ = update_fn(updates, state) - expected_updates = {'linear': jnp.full((5, 3), 6.0)} - chex.assert_trees_all_close(new_updates, expected_updates) - - def test_init_returns_emptystate(self): - def weight_decay(g, p): - return jax.tree_util.tree_map(lambda g_, p_: g_ + 0.1 * p_, g, p) - - opt = base.stateless(weight_decay) - state = opt.init(None) - self.assertIsInstance(state, base.EmptyState) - - -class StatelessWithTreeMapTest(chex.TestCase): - """Tests for the stateless_with_tree_map transformation.""" - - @chex.all_variants - def test_stateless_with_tree_map(self): - params = {'a': jnp.zeros((1, 2)), 'b': jnp.ones((1,))} - updates = {'a': jnp.ones((1, 2)), 'b': jnp.full((1,), 2.0)} - - opt = base.stateless_with_tree_map(lambda g, p: g + 0.1 * p) - state = opt.init(params) - update_fn = self.variant(opt.update) - new_updates, _ = update_fn(updates, state, params) - expected_updates = {'a': jnp.ones((1, 2)), 'b': jnp.array([2.1])} - chex.assert_trees_all_close(new_updates, expected_updates) - - @chex.all_variants - def test_stateless_with_tree_map_no_params(self): - updates = {'linear': jnp.full((5, 3), 3.0)} - - opt = base.stateless_with_tree_map(lambda g, _: g * 2.0) - state = opt.init(None) - update_fn = 
self.variant(opt.update) - new_updates, _ = update_fn(updates, state) - expected_updates = {'linear': jnp.full((5, 3), 6.0)} - chex.assert_trees_all_close(new_updates, expected_updates) - - def test_init_returns_emptystate(self): - opt = base.stateless_with_tree_map(lambda g, p: g + 0.1 * p) - state = opt.init(None) - self.assertIsInstance(state, base.EmptyState) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/clipping.py b/optax_add_eve/_src/clipping.py deleted file mode 100644 index 5eb1dc9d..00000000 --- a/optax_add_eve/_src/clipping.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Gradient clipping transformations. - -Note that complex numbers are also supported, see -https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 -""" -from typing import Tuple - -import chex -import jax -import jax.numpy as jnp - -from optax_add_eve._src import base -from optax_add_eve._src import linear_algebra -from optax_add_eve._src import numerics - -ClipState = base.EmptyState - - -def clip(max_delta: chex.Numeric) -> base.GradientTransformation: - """Clips updates element-wise, to be in ``[-max_delta, +max_delta]``. - - Args: - max_delta: The maximum absolute value for each element in the update. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return ClipState() - - def update_fn(updates, state, params=None): - del params - updates = jax.tree_util.tree_map( - lambda g: jnp.clip(g, -max_delta, max_delta), updates) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -def clip_by_block_rms(threshold: float) -> base.GradientTransformation: - """Clips updates to a max rms for the gradient of each param vector or matrix. - - A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix - (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree. - - Args: - threshold: The maximum rms for the gradient of each param vector or matrix. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return base.EmptyState() - - def update_fn(updates, state, params=None): - del params - - def _clip_fn(u): - clip_denom = jnp.maximum( - 1.0, - jnp.sqrt(jnp.mean(numerics.abs_sq(u))) / threshold) - return u / clip_denom - - updates = jax.tree_util.tree_map(_clip_fn, updates) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -ClipByGlobalNormState = base.EmptyState - - -def clip_by_global_norm(max_norm: float) -> base.GradientTransformation: - """Clips updates using their global norm. - - References: - [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063) - - Args: - max_norm: The maximum global norm for an update. - - Returns: - A `GradientTransformation` object. 
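# A minimal sketch of the usual way `clip_by_global_norm` is used: chained in
# front of an optimizer so that gradients are clipped before the update rule
# sees them. Illustrative only, assuming a hypothetical top-level import for
# this fork.
import jax.numpy as jnp

import optax_add_eve as optax  # hypothetical top-level import for this fork

opt = optax.chain(
    optax.clip_by_global_norm(1.0),  # clip the global gradient norm to 1.0
    optax.adam(learning_rate=1e-3),
)

params = {'w': jnp.ones((4, 4))}
state = opt.init(params)
grads = {'w': 100.0 * jnp.ones((4, 4))}  # deliberately oversized gradients
updates, state = opt.update(grads, state, params)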
- """ - - def init_fn(params): - del params - return ClipByGlobalNormState() - - def update_fn(updates, state, params=None): - del params - g_norm = linear_algebra.global_norm(updates) - # TODO(b/163995078): revert back to the following (faster) implementation - # once analysed how it affects backprop through update (e.g. meta-gradients) - # g_norm = jnp.maximum(max_norm, g_norm) - # updates = jax.tree_util.tree_map( - # lambda t: (t / g_norm) * max_norm, updates) - trigger = jnp.squeeze(g_norm < max_norm) - chex.assert_shape(trigger, ()) # A scalar. - - def clip_fn(t): - return jax.lax.select(trigger, t, (t / g_norm.astype(t.dtype)) * max_norm) - - updates = jax.tree_util.tree_map(clip_fn, updates) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -def per_example_global_norm_clip(grads: chex.Array, - l2_norm_clip: float) -> Tuple[chex.Array, int]: - """Applies gradient clipping per-example using their global norm. - - References: - [Abadi et al, 2016](https://arxiv.org/abs/1607.00133) - - Args: - grads: flattened update; the function expects these to have a batch - dimension on the 0th axis. - l2_norm_clip: maximum L2 norm of the per-example gradients. - - Returns: - A tuple containing sum of the clipped per-example grads, and the number of - per-example grads that were clipped. - """ - bsize = grads[0].shape[0] - - if any(g.ndim == 0 or bsize != g.shape[0] for g in grads): - raise ValueError( - 'Unlike other transforms, `per_example_global_norm_clip` expects' - ' `grads` to have a batch dimension in the 0th axis.') - - global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads) - divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0) - num_clipped = jnp.greater(divisors, 1.0).sum() - clipped_sum = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads] - return clipped_sum, num_clipped - - -def unitwise_norm(x: chex.Array) -> chex.Array: - """Computes norms of each output unit separately.""" - if jnp.squeeze(x).ndim <= 1: # Scalars and vectors - squared_norm = jnp.sum(numerics.abs_sq(x), keepdims=True) - # Note that this assumes parameters with a shape of length 3 are multihead - # linear parameters--if you wish to apply AGC to 1D convs, you may need - # to modify this line. - elif x.ndim in (2, 3): # Linear layers of shape IO or multihead linear - squared_norm = jnp.sum(numerics.abs_sq(x), axis=0, keepdims=True) - elif x.ndim == 4: # Conv kernels of shape HWIO - squared_norm = jnp.sum(numerics.abs_sq(x), axis=(0, 1, 2), keepdims=True) - else: - raise ValueError( - f'Expected parameter with shape in {1, 2, 3, 4}, got {x.shape}.') - chex.assert_is_broadcastable(squared_norm.shape, x.shape) - return jnp.broadcast_to(jnp.sqrt(squared_norm), x.shape) - - -def unitwise_clip(g_norm: chex.Array, - max_norm: chex.Array, - grad: chex.Array, - div_eps: float = 1e-6) -> chex.Array: - """Applies gradient clipping unit-wise.""" - # This little max(., div_eps) is distinct from the normal eps and just - # prevents division by zero. It technically should be impossible to engage. - clipped_grad = grad * (max_norm / jnp.maximum(g_norm, div_eps)) - chex.assert_equal_shape((g_norm, max_norm, grad, clipped_grad)) - return jnp.where(g_norm < max_norm, grad, clipped_grad) - - -AdaptiveGradClipState = base.EmptyState - - -def adaptive_grad_clip(clipping: float, - eps: float = 1e-3) -> base.GradientTransformation: - """Clips updates to be at most ``clipping * parameter_norm``, unit-wise. 
- - References: - [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image - Recognition Without Normalization. (https://arxiv.org/abs/2102.06171) - - Args: - clipping: The maximum allowed ratio of update norm to parameter norm. - eps: An epsilon term to prevent clipping of zero-initialized params. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return AdaptiveGradClipState() - - def update_fn(updates, state, params): - if params is None: - raise ValueError(base.NO_PARAMS_MSG) - g_norm, p_norm = jax.tree_util.tree_map(unitwise_norm, (updates, params)) - # Maximum allowable norm. - max_norm = jax.tree_util.tree_map( - lambda x: clipping * jnp.maximum(x, eps), p_norm) - # If grad norm > clipping * param_norm, rescale. - updates = jax.tree_util.tree_map(unitwise_clip, g_norm, max_norm, updates) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) diff --git a/optax_add_eve/_src/clipping_test.py b/optax_add_eve/_src/clipping_test.py deleted file mode 100644 index e2676284..00000000 --- a/optax_add_eve/_src/clipping_test.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for `clipping.py`.""" - -from absl.testing import absltest - -import chex -import jax -import jax.numpy as jnp - -from optax_add_eve._src import clipping -from optax_add_eve._src import linear_algebra - -STEPS = 50 -LR = 1e-2 - - -class ClippingTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.])) - self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.])) - - def test_clip(self): - updates = self.per_step_updates - # For a sufficiently high delta the update should not be changed. - clipper = clipping.clip(1e6) - clipped_updates, _ = clipper.update(updates, None) - chex.assert_trees_all_close(clipped_updates, clipped_updates) - # Clipping at delta=1 should make all updates exactly 1. - clipper = clipping.clip(1.) - clipped_updates, _ = clipper.update(updates, None) - chex.assert_trees_all_close( - clipped_updates, jax.tree_util.tree_map(jnp.ones_like, updates)) - - def test_clip_by_block_rms(self): - rmf_fn = lambda t: jnp.sqrt(jnp.mean(t**2)) - updates = self.per_step_updates - for i in range(1, STEPS + 1): - clipper = clipping.clip_by_block_rms(1. / i) - # Check that the clipper actually works and block rms is <= threshold - updates, _ = clipper.update(updates, None) - self.assertAlmostEqual(rmf_fn(updates[0]), 1. / i) - self.assertAlmostEqual(rmf_fn(updates[1]), 1. / i) - # Check that continuously clipping won't cause numerical issues. 
- updates_step, _ = clipper.update(self.per_step_updates, None) - chex.assert_trees_all_close(updates, updates_step) - - def test_clip_by_global_norm(self): - updates = self.per_step_updates - for i in range(1, STEPS + 1): - clipper = clipping.clip_by_global_norm(1. / i) - # Check that the clipper actually works and global norm is <= max_norm - updates, _ = clipper.update(updates, None) - self.assertAlmostEqual( - linear_algebra.global_norm(updates), 1. / i, places=6) - # Check that continuously clipping won't cause numerical issues. - updates_step, _ = clipper.update(self.per_step_updates, None) - chex.assert_trees_all_close(updates, updates_step) - - def test_adaptive_grad_clip(self): - updates = self.per_step_updates - params = self.init_params - for i in range(1, STEPS + 1): - clip_r = 1. / i - clipper = clipping.adaptive_grad_clip(clip_r) - - # Check that the clipper actually works and upd_norm is < c * param_norm. - updates, _ = clipper.update(updates, None, params) - u_norm, p_norm = jax.tree_util.tree_map( - clipping.unitwise_norm, (updates, params)) - cmp = jax.tree_util.tree_map( - lambda u, p, c=clip_r: u - c * p < 1e-6, u_norm, p_norm) - for leaf in jax.tree_util.tree_leaves(cmp): - self.assertTrue(leaf.all()) - - # Check that continuously clipping won't cause numerical issues. - updates_step, _ = clipper.update(self.per_step_updates, None, params) - chex.assert_trees_all_close(updates, updates_step) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/combine.py b/optax_add_eve/_src/combine.py deleted file mode 100644 index a3a4542a..00000000 --- a/optax_add_eve/_src/combine.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flexibly compose gradient transformations.""" - -from typing import Callable, NamedTuple, Union, Mapping, Hashable - -import jax - -from optax_add_eve._src import base -from optax_add_eve._src import wrappers - - -def chain( - *args: base.GradientTransformation -) -> base.GradientTransformation: - """Applies a list of chainable update transformations. - - Given a sequence of chainable transforms, `chain` returns an `init_fn` - that constructs a `state` by concatenating the states of the individual - transforms, and returns an `update_fn` which chains the update transformations - feeding the appropriate state to each. - - Args: - *args: a sequence of chainable (init_fn, update_fn) tuples. - - Returns: - A single (init_fn, update_fn) tuple. - """ - - init_fns, update_fns = zip(*args) - - def init_fn(params): - return tuple(fn(params) for fn in init_fns) - - def update_fn(updates, state, params=None): - if len(update_fns) != len(state): - raise ValueError('The number of updates and states has to be the same in ' - 'chain! 
Make sure you have called init first!') - - new_state = [] - for s, fn in zip(state, update_fns): - updates, new_s = fn(updates, s, params) - new_state.append(new_s) - return updates, tuple(new_state) - - return base.GradientTransformation(init_fn, update_fn) - - -class MultiTransformState(NamedTuple): - inner_states: Mapping[Hashable, NamedTuple] - - -def multi_transform( - transforms: Mapping[Hashable, base.GradientTransformation], - param_labels: Union[base.PyTree, Callable[[base.PyTree], base.PyTree]] -) -> base.GradientTransformation: - """Partitions params and applies a different transformation to each subset. - - Below is an example where we apply Adam to the weights and SGD to the biases - of a 2-layer neural network:: - - import optax - import jax - import jax.numpy as jnp - - def map_nested_fn(fn): - '''Recursively apply `fn` to the key-value pairs of a nested dict''' - def map_fn(nested_dict): - return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v)) - for k, v in nested_dict.items()} - return map_fn - - params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)}, - 'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}} - gradients = jax.tree_util.tree_map(jnp.ones_like, params) # dummy gradients - - label_fn = map_nested_fn(lambda k, _: k) - tx = optax.multi_transform({'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, - label_fn) - state = tx.init(params) - updates, new_state = tx.update(gradients, state, params) - new_params = optax.apply_updates(params, updates) - - Instead of providing a ``label_fn``, you may provide a PyTree of labels - directly. Also, this PyTree may be a prefix of the parameters PyTree. This - is demonstrated in the GAN pseudocode below:: - - generator_params = ... - discriminator_params = ... - all_params = (generator_params, discriminator_params) - param_labels = ('generator', 'discriminator') - - tx = optax.multi_transform( - {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)}, - param_labels) - - If you would like to not optimize some parameters, you may wrap - ``optax.multi_transform`` with :func:`optax.masked`. - - Args: - transforms: A mapping from labels to transformations. Each transformation - will be only be applied to parameters with the same label. - param_labels: A PyTree that is the same shape or a prefix of the - parameters/updates (or a function that returns one given the parameters as - input). The leaves of this PyTree correspond to the keys of the transforms - (therefore the values at the leaves must be a subset of the keys). - - Returns: - An ``optax.GradientTransformation``. 
- """ - def make_mask(labels, group): - return jax.tree_util.tree_map(lambda label: label == group, labels) - - def init_fn(params): - labels = param_labels(params) if callable(param_labels) else param_labels - - label_set = set(jax.tree_util.tree_leaves(labels)) - if not label_set.issubset(transforms.keys()): - raise ValueError('Some parameters have no corresponding transformation.\n' - f'Parameter labels: {list(sorted(label_set))} \n' - f'Transforms keys: {list(sorted(transforms.keys()))} \n') - - inner_states = { - group: wrappers.masked(tx, make_mask(labels, group)).init(params) - for group, tx in transforms.items() - } - return MultiTransformState(inner_states) - - def update_fn(updates, state, params=None): - labels = param_labels(updates) if callable(param_labels) else param_labels - new_inner_state = {} - for group, tx in transforms.items(): - masked_tx = wrappers.masked(tx, make_mask(labels, group)) - updates, new_inner_state[group] = masked_tx.update( - updates, state.inner_states[group], params) - return updates, MultiTransformState(new_inner_state) - - return base.GradientTransformation(init_fn, update_fn) diff --git a/optax_add_eve/_src/combine_test.py b/optax_add_eve/_src/combine_test.py deleted file mode 100644 index 122858e7..00000000 --- a/optax_add_eve/_src/combine_test.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for `combine.py`.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import chex -import jax -import jax.numpy as jnp - -from optax_add_eve._src import alias -from optax_add_eve._src import combine -from optax_add_eve._src import transform -from optax_add_eve._src import update - - -STEPS = 50 -LR = 1e-2 - - -class ComposeTest(chex.TestCase): - - def setUp(self): - super().setUp() - self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.])) - self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.])) - - @chex.all_variants - def test_chain(self): - transformations = [ - transform.scale_by_adam(), - transform.trace(decay=0, nesterov=False), - transform.scale(-LR)] - - # Apply updates with chain. - chain_params = self.init_params - chained_transforms = combine.chain(*transformations) - state = chained_transforms.init(chain_params) - self.assertIsInstance(state, tuple) - - @self.variant - def update_fn(updates, state): - return chained_transforms.update(updates, state) - - for _ in range(STEPS): - updates, state = update_fn(self.per_step_updates, state) - self.assertIsInstance(state, tuple) - chain_params = update.apply_updates(chain_params, updates) - - # Manually apply sequence of transformations. 
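# [Editorial note, a sketch restating combine.chain's update_fn shown above.]
# The manual loop below mirrors what the chained optimizer does in one call:
#
#     updates = self.per_step_updates
#     for t, s in zip(transformations, states):
#         updates, s = t.update(updates, s)   # each transform rewrites `updates`
#
# so the two parameter trajectories should match up to numerical tolerance.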
- manual_params = self.init_params - states = [t.init(manual_params) for t in transformations] - for _ in range(STEPS): - updates = self.per_step_updates - new_states = [] - for t, s in zip(transformations, states): - updates, state = t.update(updates, s) - new_states.append(state) - manual_params = update.apply_updates(manual_params, updates) - states = new_states - - # Check equivalence. - chex.assert_trees_all_close(manual_params, chain_params, rtol=1e-4) - - -def _map_keys_fn(fn): - def map_fn(nested_dict): - return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v)) - for k, v in nested_dict.items()} - return map_fn - - -class MultiTransformTest(chex.TestCase): - """Tests for the multi_transform wrapper.""" - - @chex.all_variants - @parameterized.parameters(True, False) - def test_multi_transform(self, use_fn): - params = {'a1': 1., 'b1': 2., 'z1': {'a2': 3., 'z2': {'c1': 4.}}} - params = jax.tree_util.tree_map(jnp.asarray, params) - input_updates = jax.tree_util.tree_map(lambda x: x / 10.0, params) - tx_dict = {'a': transform.scale(-1.0), - 'b': transform.ema(0.0), # stateful - 'c': transform.scale(2.0)} - param_labels = _map_keys_fn(lambda k, _: k[0]) - if not use_fn: - param_labels = param_labels(params) - tx = combine.multi_transform(tx_dict, param_labels) - update_fn = self.variant(tx.update) - state = self.variant(tx.init)(params) - - correct_update_fn = _map_keys_fn( - lambda k, v: {'a': -v, 'b': v, 'c': 2.0*v}[k[0]]) - - updates, state = update_fn(input_updates, state, params) - correct_updates = correct_update_fn(input_updates) - chex.assert_trees_all_close(updates, correct_updates) - - # Check repeated application, this time with no params. - correct_updates = correct_update_fn(correct_updates) - updates, state = update_fn(updates, state) - chex.assert_trees_all_close(updates, correct_updates) - - @parameterized.parameters(list, tuple, dict) - def test_empty(self, container): - init_fn, update_fn = combine.multi_transform( - {0: alias.sgd(1.)}, lambda _: 0) - updates, _ = update_fn(container(), init_fn(container())) - self.assertEqual(updates, container()) - - @chex.all_variants - @parameterized.parameters( - (False, False), (False, True), (True, False), (True, True)) - def test_labels_mismatch(self, use_extra_label, use_fn): - # The labels from label_fn must be a subet of the keys for the tx. - params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}} - params = jax.tree_util.tree_map(jnp.asarray, params) - label_tree = {'a': 0, 'b': [1, 0], 'c': 1} # prefix of params - - if use_extra_label: - label_tree['a'] = 3 - - transforms = {0: alias.sgd(1.), - 1: alias.adam(1., b1=0., b2=0.), - 2: transform.trace(1.0)} - init_fn, update_fn = combine.multi_transform( - transforms, (lambda _: label_tree) if use_fn else label_tree) - - if use_extra_label: - with self.assertRaises(ValueError): - self.variant(init_fn)(params) - else: - state = self.variant(init_fn)(params) - updates = jax.tree_util.tree_map(lambda x: x / 10.0, params) - self.variant(update_fn)(updates, state) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/constrain.py b/optax_add_eve/_src/constrain.py deleted file mode 100644 index f1bf38e1..00000000 --- a/optax_add_eve/_src/constrain.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Gradient transformations used to enforce specific constraints.""" - -from typing import Any, NamedTuple - -import jax -import jax.numpy as jnp - -from optax_add_eve._src import base - -# pylint:disable=no-value-for-parameter - - -NonNegativeParamsState = base.EmptyState - - -def keep_params_nonnegative() -> base.GradientTransformation: - """Modifies the updates to keep parameters non-negative, i.e. >= 0. - - This transformation ensures that parameters after the update will be - larger than or equal to zero. - In a chain of transformations, this should be the last one. - - WARNING: the transformation expects input params to be non-negative. - When params is negative the transformed update will move them to 0. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return NonNegativeParamsState() - - def update_fn(updates, state, params): - if params is None: - raise ValueError(base.NO_PARAMS_MSG) - - updates = jax.tree_util.tree_map( - lambda p, u: jnp.where((p + u) < 0., -p, u), params, updates) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -class ZeroNansState(NamedTuple): - """Contains a tree. - - The entry `found_nan` has the same tree structure as that of the parameters. - Each leaf is a single boolean which contains True iff a NaN was detected in - the corresponding parameter array at the last call to `update`. - """ - found_nan: Any - - -def zero_nans() -> base.GradientTransformation: - """A transformation which replaces NaNs with 0. - - Zeroing values in gradients is guaranteed to produce a direction of - non-increasing loss. - - The state of the transformation has the same tree structure as that of the - parameters. Each leaf is a single boolean which contains True iff a NaN was - detected in the corresponding parameter array at the last call to `update`. - This state is not used by the transformation internally, but lets users be - aware when NaNs have been zeroed out. - - Returns: - A `GradientTransformation`. - """ - - def init_fn(params): - return ZeroNansState(jax.tree_util.tree_map( - lambda p: jnp.array(False, dtype=jnp.bool_), params)) - - def update_fn(updates, opt_state, params=None): - del params - opt_state = ZeroNansState( - jax.tree_util.tree_map(lambda p: jnp.any(jnp.isnan(p)), updates)) - updates = jax.tree_util.tree_map( - lambda p: jnp.where(jnp.isnan(p), jnp.zeros_like(p), p), updates) - return updates, opt_state - - return base.GradientTransformation(init=init_fn, update=update_fn) diff --git a/optax_add_eve/_src/constrain_test.py b/optax_add_eve/_src/constrain_test.py deleted file mode 100644 index ca52232b..00000000 --- a/optax_add_eve/_src/constrain_test.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for optax._src.constrain.""" - -from absl.testing import absltest - -import chex -import jax.numpy as jnp - -from optax_add_eve._src import combine -from optax_add_eve._src import constrain -from optax_add_eve._src import transform -from optax_add_eve._src import update - -STEPS = 50 -LR = 1e-2 - - -class ConstraintsTest(chex.TestCase): - - def test_keep_params_nonnegative(self): - grads = (jnp.array([500., -500., 0.]), - jnp.array([500., -500., 0.]), - jnp.array([500., -500., 0.])) - - params = (jnp.array([-1., -1., -1.]), - jnp.array([1., 1., 1.]), - jnp.array([0., 0., 0.])) - - # vanilla sgd - opt = combine.chain( - transform.trace(decay=0, nesterov=False), transform.scale(-LR)) - opt_state = opt.init(params) - - updates, _ = opt.update(grads, opt_state, params) - new_params = update.apply_updates(params, updates) - - chex.assert_trees_all_close(new_params, (jnp.array([-6., 4., -1.]), - jnp.array([-4., 6., 1.]), - jnp.array([-5., 5., 0.]))) - - # sgd with keeping parameters non-negative - opt = combine.chain( - transform.trace(decay=0, nesterov=False), transform.scale(-LR), - constrain.keep_params_nonnegative()) - opt_state = opt.init(params) - - updates, _ = opt.update(grads, opt_state, params) - new_params = update.apply_updates(params, updates) - - chex.assert_trees_all_close(new_params, (jnp.array([0., 4., 0.]), - jnp.array([0., 6., 1.]), - jnp.array([0., 5., 0.]))) - - @chex.all_variants - def test_zero_nans(self): - params = (jnp.zeros([3]), jnp.zeros([3]), jnp.zeros([3])) - - opt = constrain.zero_nans() - opt_state = self.variant(opt.init)(params) - update_fn = self.variant(opt.update) - - chex.assert_trees_all_close( - opt_state, - constrain.ZeroNansState((jnp.array(False),) * 3)) - - # Check an upate with nans - grads_with_nans = (jnp.ones([3]), - jnp.array([1., float('nan'), float('nan')]), - jnp.array([float('nan'), 1., 1.])) - updates, opt_state = update_fn(grads_with_nans, opt_state) - chex.assert_trees_all_close( - opt_state, - constrain.ZeroNansState( - (jnp.array(False), jnp.array(True), jnp.array(True)))) - chex.assert_trees_all_close( - updates, - (jnp.ones([3]), jnp.array([1., 0., 0.]), jnp.array([0., 1., 1.]))) - - # Check an upate with nans and infs - grads_with_nans_infs = (jnp.ones([3]), - jnp.array([1., float('nan'), - float('nan')]), - jnp.array([float('inf'), 1., 1.])) - updates, opt_state = update_fn(grads_with_nans_infs, opt_state) - chex.assert_trees_all_close( - opt_state, - constrain.ZeroNansState( - (jnp.array(False), jnp.array(True), jnp.array(False)))) - chex.assert_trees_all_close(updates, (jnp.ones([3]), jnp.array( - [1., 0., 0.]), jnp.array([float('inf'), 1., 1.]))) - - # Check an upate with only good values - grads = (jnp.ones([3]), jnp.ones([3]), jnp.ones([3])) - updates, opt_state = update_fn(grads, opt_state) - chex.assert_trees_all_close( - opt_state, - constrain.ZeroNansState( - (jnp.array(False), jnp.array(False), jnp.array(False)))) - chex.assert_trees_all_close(updates, grads) - - -if __name__ == '__main__': - absltest.main() diff --git 
a/optax_add_eve/_src/control_variates.py b/optax_add_eve/_src/control_variates.py deleted file mode 100644 index 33316a76..00000000 --- a/optax_add_eve/_src/control_variates.py +++ /dev/null @@ -1,419 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -r"""Implementation of control variates. - - We are interested in computing the gradient using control variates: - \nabla_{\theta} E_{p(x; \theta)} f(x) - = \nabla_{\theta} [E_{p(x; \theta)} f(x) - h(x; \theta) + E_{p(x; \theta)}] - = \nabla_{\theta} [E_{p(x; \theta)} f(x) - h(x; \theta)] - + \nabla_{\theta} E_{p(x; \theta)}] - = \nabla_{\theta} [E_{p(x; \theta)} f(x) - h(x; \theta)] - + \nabla_{\theta} E_{p(x; \theta)}] - = \nabla_{\theta} \int {p(x; \theta)} (f(x) - h(x; \theta)) dx - + \nabla_{\theta} E_{p(x; \theta)}] - = \int \nabla_{\theta} {p(x; \theta)} (f(x) - h(x; \theta)) dx - + [E_{p(x; \theta)} \nabla_{\theta} (f(x) - h(x; \theta)) - + \nabla_{\theta} E_{p(x; \theta)}] - = \int \nabla_{\theta} {p(x; \theta)} (f(x) - h(x; \theta)) dx - - [E_{p(x; \theta)} \nabla_{\theta} h(x; \theta) - + \nabla_{\theta} E_{p(x; \theta)}] - - The above computation is performed in `control_variates_jacobians`. - - When adding a new control variate, one does not need to implement the jacobian - computation, but instead has to implement the forward computation. - - Each control variate implemented has to satisfy the following API: - * control_variate(function) - This returns a tuple of three functions: - * The first element of the tuple is a function which returns the - control variate value for a set of samples. It takes in as - arguments the parameters used to construct the distribution, - the distributional samples, and the state of the control variate - (if any). The return value of this function will have shape - `num_samples`, where `num_samples` is the number of samples - provided as input. - * The second is a function returns the expected value of the control - variate. The input arguments of this function are the parameters - of the distribution and the state of the control variate. - * The third is a function which updates the state of the control - variate, and returns the updated states. - - For examples, see `control_delta_method` and `moving_avg_baseline`. -""" -from typing import Any, Callable, Sequence, Tuple - -import chex -import jax -import jax.numpy as jnp - -from optax_add_eve._src import base - - -CvState = Any -ComputeCv = Callable[[base.Params, chex.Array, CvState], float] -CvExpectedValue = Callable[[base.Params, CvState], CvState] -UpdateCvState = Callable[[base.Params, chex.Array, CvState], CvState] -ControlVariate = Tuple[ComputeCv, CvExpectedValue, UpdateCvState] - - -def control_delta_method( - function: Callable[[chex.Array], float]) -> ControlVariate: - """The control delta covariate method. 
- - Control variate obtained by performing a second order Taylor expansion - on the cost function f at the mean of the input distribution. - - Only implemented for Gaussian random variables. - - For details, see: https://icml.cc/2012/papers/687.pdf - - Args: - function: The function for which to compute the control variate. - The function takes in one argument (a sample from the distribution) and - returns a floating point value. - - Returns: - A tuple of three functions, to compute the control variate, the - expected value of the control variate, and to update the control variate - state. - """ - - def delta( - params: base.Params, - sample: chex.Array, - state: CvState = None) -> chex.Array: - """"Second order expansion of `function` at the mean of the input dist.""" - del state - mean_dist = params[0] - centered_sample = sample - mean_dist - # Function is a function of samples. Here, we use the mean as the input - # since we do a Taylor expansion of function around the mean. - grads = jax.grad(function)(mean_dist) - hessians = jax.hessian(function)(mean_dist) - assert hessians.ndim == 2 - control_variate = function(mean_dist) - control_variate += jnp.dot(centered_sample, grads) - control_variate += jnp.dot( - jnp.dot(centered_sample, hessians), centered_sample) / 2. - return control_variate - - def expected_value_delta( - params: base.Params, state: CvState) -> float: - """"Expected value of second order expansion of `function` at dist mean.""" - del state - mean_dist = params[0] - var_dist = jnp.square(jnp.exp(params[1])) - hessians = jax.hessian(function)(mean_dist) - - assert hessians.ndim == 2 - hess_diags = jnp.diag(hessians) - assert hess_diags.ndim == 1 - - # Trace (Hessian * Sigma) and we use that Sigma is diagonal. - expected_second_order_term = jnp.sum(var_dist * hess_diags) / 2. - - expected_control_variate = function(mean_dist) - expected_control_variate += expected_second_order_term - return expected_control_variate - - def update_state( - params: base.Params, - samples: chex.Array, - state: CvState = None) -> CvState: - """"No state kept, so no operation is done.""" - del params, samples - return state - - return delta, expected_value_delta, update_state - - -def moving_avg_baseline( - function: Callable[[chex.Array], float], - decay: float = 0.99, - zero_debias: bool = True, - use_decay_early_training_heuristic=True) -> ControlVariate: - """A moving average baseline. - - It has no effect on the pathwise or measure valued estimator. - - Args: - function: The function for which to compute the control variate. - The function takes in one argument (a sample from the distribution) and - returns a floating point value. - decay: The decay rate for the moving average. - zero_debias: Whether or not to use zero debiasing for the moving average. - use_decay_early_training_heuristic: Whether or not to use a heuristic which - overrides the decay value early in training based on - min(decay, (1.0 + i) / (10.0 + i)). This stabilises training and was - adapted from the Tensorflow codebase. - - Returns: - A tuple of three functions, to compute the control variate, the - expected value of the control variate, and to update the control variate - state. 
- """ - def moving_avg( - params: base.Params, - samples: chex.Array, - state: CvState = None) -> CvState: - """"Return the moving average.""" - del params, samples - return state[0] - - def expected_value_moving_avg( - params: base.Params, state: CvState) -> chex.Array: - """"Return the moving average.""" - del params - return state[0] - - def update_state( - params: base.Params, - samples: chex.Array, - state: CvState = None) -> CvState: - """"Update the moving average.""" - del params - value, i = state - - if use_decay_early_training_heuristic: - iteration_decay = jnp.minimum(decay, (1.0 + i) / (10.0 + i)) - else: - iteration_decay = decay - - updated_value = iteration_decay * value - updated_value += (1 - iteration_decay) * jnp.mean( - jax.vmap(function)(samples)) - - if zero_debias: - updated_value /= (jnp.ones([]) - jnp.power(iteration_decay, i + 1)) - - return (jax.lax.stop_gradient(updated_value), i + 1) - - return moving_avg, expected_value_moving_avg, update_state - - -def _map(cv, params, samples, state): - return jax.vmap(lambda x: cv(params, x, state))(samples) - - -def control_variates_jacobians( - function: Callable[[chex.Array], float], - control_variate_from_function: Callable[[Callable[[chex.Array], float]], - ControlVariate], - grad_estimator: Callable[..., jnp.array], - params: base.Params, - dist_builder: Callable[..., Any], - rng: chex.PRNGKey, - num_samples: int, - control_variate_state: CvState = None, - estimate_cv_coeffs: bool = False, - estimate_cv_coeffs_num_samples: int = 20) -> Tuple[ - Sequence[chex.Array], CvState]: - r"""Obtain jacobians using control variates. - - We will compute each term individually. The first term will use stochastic - gradient estimation. The second term will be computes using Monte - Carlo estimation and automatic differentiation to compute - \nabla_{\theta} h(x; \theta). The the third term will be computed using - automatic differentiation, as we restrict ourselves to control variates - which compute this expectation in closed form. - - This function updates the state of the control variate (once), before - computing the control variate coefficients. - - Args: - function: Function f(x) for which to estimate grads_{params} E_dist f(x). - The function takes in one argument (a sample from the distribution) and - returns a floating point value. - control_variate_from_function: The control variate to use to reduce - variance. See `control_delta_method` and `moving_avg_baseline` examples. - grad_estimator: The gradient estimator to be used to compute the gradients. - Note that not all control variates will reduce variance for all - estimators. For example, the `moving_avg_baseline` will make no difference - to the measure valued or pathwise estimators. - params: A tuple of jnp arrays. - The parameters for which to construct the distribution and for which we - want to compute the jacobians. - dist_builder: a constructor which builds a distribution given the input - parameters specified by params. `dist_builder(params)` should return a - valid distribution. - rng: a PRNGKey key. - num_samples: Int, the number of samples used to compute the grads. - control_variate_state: The control variate state. This is used for control - variates which keep states (such as the moving average baselines). - estimate_cv_coeffs: Boolean. Whether or not to estimate the optimal control - variate coefficient via `estimate_control_variate_coefficients`. - estimate_cv_coeffs_num_samples: The number of samples to use to estimate - the optimal coefficient. 
These need to be new samples to ensure that the - objective is unbiased. - - Returns: - A tuple of size two: - * A tuple of size `params`, each element is `num_samples x param.shape` - jacobian vector containing the estimates of the gradients obtained - for each sample. - The mean of this vector is the gradient wrt to parameters that can be - used for learning. The entire jacobian vector can be used to assess - estimator variance. - * The updated CV state. - """ - control_variate = control_variate_from_function(function) - stochastic_cv, expected_value_cv, update_state_cv = control_variate - data_dim = params[0].shape[0] - if estimate_cv_coeffs: - cv_coeffs = estimate_control_variate_coefficients( - function, control_variate_from_function, grad_estimator, params, - dist_builder, rng, estimate_cv_coeffs_num_samples, - control_variate_state) - else: - cv_coeffs = [1.0] * len(params) - - # \int \nabla_{\theta} {p(x; \theta)} (f(x) - h(x; \theta)) dx - function_jacobians = grad_estimator( - function, params, dist_builder, rng, num_samples) - - # Chain rule since CVs can also depend on parameters - for example, for the - # pathwise gradient estimator they have in order to have an effect on - # gradient. - # The rng has to be the same as passed to the grad_estimator above so that we - # obtain the same samples. - samples = dist_builder(*params).sample((num_samples,), seed=rng) - # If the CV has state, update it. - control_variate_state = update_state_cv( - params, samples, control_variate_state) - - def samples_fn(x): - return stochastic_cv( - jax.lax.stop_gradient(params), x, control_variate_state) - - cv_jacobians = grad_estimator( - samples_fn, params, dist_builder, rng, num_samples) - - # The gradients of the stochastic covariate with respect to the parameters. - def param_fn(x): - return jnp.mean(_map( - stochastic_cv, x, - jax.lax.stop_gradient(samples), control_variate_state)) - - # [E_{p(x; \theta)} \nabla_{\theta} h(x; \theta) - cv_param_grads = jax.grad(param_fn)(params) - # The gradients of the closed form expectation of the control variate - # with respect to the parameters: # \nabla_{\theta} E_{p(x; \theta)}]. 
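# [Editorial sketch of the decomposition assembled in the loop below.] For
# each parameter the jacobian estimate combines three pieces, with each
# control-variate piece weighted by its cv_coeff c:
#
#     grad E[f] ~= grad_estimator(f) - c * grad_estimator(h)   # (f - h) term
#                  - c * E[ d h / d theta ]                    # cv_param_grads
#                  + c * d E[h] / d theta                      # expected_value_grads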
- expected_value_grads = jax.grad( - lambda x: expected_value_cv(x, control_variate_state))(params) - - jacobians = [] - for param_index, param in enumerate(params): - chex.assert_shape(function_jacobians[param_index], (num_samples, data_dim)) - chex.assert_shape(cv_jacobians[param_index], (num_samples, data_dim)) - chex.assert_shape(cv_param_grads[param_index], (data_dim,)) - chex.assert_shape(expected_value_grads[param_index], (data_dim,)) - - cv_coeff = cv_coeffs[param_index] - # \int \nabla_{\theta} {p(x; \theta)} (f(x) - h(x; \theta)) dx - param_jacobians = function_jacobians[param_index] - param_jacobians -= cv_coeff * cv_jacobians[param_index] - # - [E_{p(x; \theta)} \nabla_{\theta} h(x; \theta) - param_jacobians -= cv_coeff * cv_param_grads[param_index] - # \nabla_{\theta} E_{p(x; \theta)}] - param_jacobians += cv_coeff * expected_value_grads[param_index] - - chex.assert_shape(param_jacobians, (num_samples,) + param.shape) - jacobians.append(param_jacobians) - - return jacobians, control_variate_state - - -def estimate_control_variate_coefficients( - function: Callable[[chex.Array], float], - control_variate_from_function: Callable[[Callable[[chex.Array], float]], - ControlVariate], - grad_estimator: Callable[..., jnp.array], - params: base.Params, - dist_builder: Callable[..., Any], - rng: chex.PRNGKey, - num_samples: int, - control_variate_state: CvState = None, - eps: float = 1e-3) -> Sequence[float]: - r"""Estimates the control variate coefficients for the given parameters. - - For each variable `var_k`, the coefficient is given by: - \sum_k cov(df/d var_k, d cv/d var_k) / (\sum var(d cv/d var_k) + eps) - - Where var_k is the k'th element of the parameters in `params`. - The covariance and variance calculations are done from samples obtained - from the distribution obtained by calling `dist_builder` on the input - `params`. - - This function does not update the state of the control variate. - - Args: - function: Function f(x) for which to estimate grads_{params} E_dist f(x). - The function takes in one argument (a sample from the distribution) and - returns a floating point value. - control_variate_from_function: The control variate to use to reduce - variance. See `control_delta_method` and `moving_avg_baseline` examples. - grad_estimator: The gradient estimator to be used to compute the gradients. - Note that not all control variates will reduce variance for all - estimators. For example, the `moving_avg_baseline` will make no difference - to the measure valued or pathwise estimators. - params: A tuple of jnp arrays. - The parameters for which to construct the distribution and for which we - want to compute the jacobians. - dist_builder: a constructor which builds a distribution given the input - parameters specified by params. `dist_builder(params)` should return a - valid distribution. - rng: a PRNGKey key. - num_samples: Int, the number of samples used to compute the grads. - control_variate_state: The control variate state. This is used for control - variates which keep states (such as the moving average baselines). - eps: A small constant used to avoid numerical issues. Float. - - Returns: - A list of control variate coefficients (each a scalar), for each parameter - in `params`. - """ - # Resample to avoid biased gradients. - cv_rng, _ = jax.random.split(rng) - del rng # Avoid using rng in this function. - stochastic_cv, _, _ = control_variate_from_function(function) - - # Samples have to be the same so we use the same rng. 
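# [Editorial note restating the coefficient computed below, per the docstring
# above.] For each parameter the scalar coefficient is approximately
#
#     c = sum_d Cov(df/dtheta_d, dh/dtheta_d) / (sum_d Var(dh/dtheta_d) + eps)
#
# i.e. a least-squares fit of the control-variate jacobian to the function
# jacobian, summed over that parameter's dimensions, using the shared-rng
# samples requested just above.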
- cv_jacobians = grad_estimator( - lambda x: stochastic_cv(params, x, control_variate_state), - params, dist_builder, cv_rng, num_samples) - function_jacobians = grad_estimator( - function, params, dist_builder, cv_rng, num_samples) - - def compute_coeff(param_cv_jacs, param_f_jacs): - assert param_f_jacs.ndim == 2 - assert param_cv_jacs.ndim == 2 - - mean_f = jnp.mean(param_f_jacs, axis=0) - mean_cv = jnp.mean(param_cv_jacs, axis=0) - - cov = jnp.mean((param_f_jacs - mean_f) * (param_cv_jacs - mean_cv), axis=0) - - assert cov.ndim == 1 - - # Compute the coefficients which minimize variance. - # Since we want to minimize the variances across parameter dimensions, - # the optimal coefficients are given by the sum of covariances per - # dimensions over the sum of variances per dimension. - cv_coeff = jnp.sum(cov) / (jnp.sum(jnp.var(param_cv_jacs, axis=0)) + eps) - return jax.lax.stop_gradient(cv_coeff) - - return [compute_coeff(cv_jacobians[i], function_jacobians[i]) - for i in range(len(params))] diff --git a/optax_add_eve/_src/control_variates_test.py b/optax_add_eve/_src/control_variates_test.py deleted file mode 100644 index 3dc2edd7..00000000 --- a/optax_add_eve/_src/control_variates_test.py +++ /dev/null @@ -1,595 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for `control_variates.py`.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import chex -import jax -import jax.numpy as jnp -import numpy as np - -from optax_add_eve._src import control_variates -from optax_add_eve._src import stochastic_gradient_estimators as sge -from optax_add_eve._src import utils - - -# Set seed for deterministic sampling. -np.random.seed(42) - - -def _assert_equal(actual, expected, rtol=1e-2, atol=1e-2): - """Asserts that arrays are equal.""" - # Note: assert_allclose does not check shapes - chex.assert_equal_shape((actual, expected)) - - # Scalar. 
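# [Editorial note.] 0-d (scalar) inputs take the plain assert_allclose path
# below; the zero/non-zero index split further down only makes sense for
# arrays with a shape.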
- if not actual.shape: - np.testing.assert_allclose( - np.asarray(actual), np.asarray(expected), rtol, atol) - return - - # We get around the bug https://github.com/numpy/numpy/issues/13801 - zero_indices = np.argwhere(expected == 0) - if not np.all(np.abs(actual[zero_indices]) <= atol): - raise AssertionError(f'Larger than {atol} diff in {actual[zero_indices]}') - - non_zero_indices = np.argwhere(expected != 0) - np.testing.assert_allclose( - np.asarray(actual)[non_zero_indices], - expected[non_zero_indices], rtol, atol) - - -def _map(cv, params, samples, state=None): - return jax.vmap(lambda x: cv(params, x, state))(samples) - - -def _map_variant(variant): - return variant(_map, static_argnums=0) - - -def _cv_jac_variant(variant): - return variant( - control_variates.control_variates_jacobians, - static_argnums=(0, 1, 2, 4, 6, 7, 8)) - - -class DeltaControlVariateTest(chex.TestCase): - - @chex.all_variants - @parameterized.parameters([(1.0, 0.5)]) - def testQuadraticFunction(self, effective_mean, effective_log_scale): - data_dims = 20 - num_samples = 10**6 - rng = jax.random.PRNGKey(1) - - mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) - log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32) - params = [mean, log_scale] - - dist = utils.multi_normal(*params) - dist_samples = dist.sample((num_samples,), rng) - function = lambda x: jnp.sum(x**2) - - cv, expected_cv, _ = control_variates.control_delta_method(function) - avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples)) - expected_cv_value = jnp.sum(dist_samples**2) / num_samples - - # This should be an analytical computation, the result needs to be - # accurate. - _assert_equal(avg_cv, expected_cv_value, rtol=1e-1, atol=1e-3) - _assert_equal(expected_cv(params, None), expected_cv_value, rtol=0.02) - - @chex.all_variants - @parameterized.parameters([(1.0, 1.0)]) - def testPolinomialFunction(self, effective_mean, effective_log_scale): - data_dims = 10 - num_samples = 10**3 - - mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) - log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32) - params = [mean, log_scale] - - dist = utils.multi_normal(*params) - rng = jax.random.PRNGKey(1) - dist_samples = dist.sample((num_samples,), rng) - function = lambda x: jnp.sum(x**5) - - cv, expected_cv, _ = control_variates.control_delta_method(function) - avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples)) - - # Check that the average value of the control variate is close to the - # expected value. - _assert_equal(avg_cv, expected_cv(params, None), rtol=1e-1, atol=1e-3) - - @chex.all_variants - def testNonPolynomialFunction(self): - data_dims = 10 - num_samples = 10**3 - - mean = jnp.ones(shape=(data_dims), dtype=jnp.float32) - log_scale = jnp.ones(shape=(data_dims), dtype=jnp.float32) - params = [mean, log_scale] - - rng = jax.random.PRNGKey(1) - dist = utils.multi_normal(*params) - dist_samples = dist.sample((num_samples,), rng) - function = lambda x: jnp.sum(jnp.log(x**2)) - - cv, expected_cv, _ = control_variates.control_delta_method(function) - avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples)) - - # Check that the average value of the control variate is close to the - # expected value. - _assert_equal(avg_cv, expected_cv(params, None), rtol=1e-1, atol=1e-3) - - # Second order expansion is log(\mu**2) + 1/2 * \sigma**2 (-2 / \mu**2) - expected_cv_val = - np.exp(1.) 
** 2 * data_dims - _assert_equal( - expected_cv(params, None), expected_cv_val, rtol=1e-1, atol=1e-3) - - -class MovingAverageBaselineTest(chex.TestCase): - - @chex.all_variants - @parameterized.parameters( - [(1.0, 0.5, 0.9), - (1.0, 0.5, 0.99)]) - def testLinearFunction( - self, effective_mean, effective_log_scale, decay): - weights = jnp.array([1., 2., 3.], dtype=jnp.float32) - num_samples = 10**4 - data_dims = len(weights) - - mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) - log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32) - - params = [mean, log_scale] - function = lambda x: jnp.sum(weights * x) - - rng = jax.random.PRNGKey(1) - dist = utils.multi_normal(*params) - dist_samples = dist.sample((num_samples,), rng) - - cv, expected_cv, update_state = control_variates.moving_avg_baseline( - function, decay=decay, zero_debias=False, - use_decay_early_training_heuristic=False) - - state_1 = jnp.array(1.) - avg_cv = jnp.mean(_map_variant(self.variant)( - cv, params, dist_samples, (state_1, 0))) - _assert_equal(avg_cv, state_1) - _assert_equal(expected_cv(params, (state_1, 0)), state_1) - - state_2 = jnp.array(2.) - avg_cv = jnp.mean( - _map_variant(self.variant)(cv, params, dist_samples, (state_2, 0))) - _assert_equal(avg_cv, state_2) - _assert_equal(expected_cv(params, (state_2, 0)), state_2) - - update_state_1 = update_state(params, dist_samples, (state_1, 0))[0] - _assert_equal( - update_state_1, - decay * state_1 + (1 - decay) * function(mean)) - - update_state_2 = update_state(params, dist_samples, (state_2, 0))[0] - _assert_equal( - update_state_2, - decay * state_2 + (1 - decay) * function(mean)) - - @chex.all_variants - @parameterized.parameters( - [(1.0, 0.5, 0.9), - (1.0, 0.5, 0.99)]) - def testLinearFunctionWithHeuristic( - self, effective_mean, effective_log_scale, decay): - weights = jnp.array([1., 2., 3.], dtype=jnp.float32) - num_samples = 10**5 - data_dims = len(weights) - - mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) - log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32) - - params = [mean, log_scale] - function = lambda x: jnp.sum(weights * x) - - rng = jax.random.PRNGKey(1) - dist = utils.multi_normal(*params) - dist_samples = dist.sample((num_samples,), rng) - - cv, expected_cv, update_state = control_variates.moving_avg_baseline( - function, decay=decay, zero_debias=False, - use_decay_early_training_heuristic=True) - - state_1 = jnp.array(1.) - avg_cv = jnp.mean(_map_variant(self.variant)( - cv, params, dist_samples, (state_1, 0))) - _assert_equal(avg_cv, state_1) - _assert_equal(expected_cv(params, (state_1, 0)), state_1) - - state_2 = jnp.array(2.) - avg_cv = jnp.mean( - _map_variant(self.variant)(cv, params, dist_samples, (state_2, 0))) - _assert_equal(avg_cv, state_2) - _assert_equal(expected_cv(params, (state_2, 0)), state_2) - - first_step_decay = 0.1 - update_state_1 = update_state(params, dist_samples, (state_1, 0))[0] - _assert_equal( - update_state_1, - first_step_decay * state_1 + (1 - first_step_decay) * function(mean)) - - second_step_decay = 2. 
/ 11 - update_state_2 = update_state(params, dist_samples, (state_2, 1))[0] - _assert_equal( - update_state_2, - second_step_decay * state_2 + (1 - second_step_decay) * function(mean)) - - @parameterized.parameters( - [(1.0, 0.5, 0.9), - (1.0, 0.5, 0.99)]) - def testLinearFunctionZeroDebias( - self, effective_mean, effective_log_scale, decay): - weights = jnp.array([1., 2., 3.], dtype=jnp.float32) - num_samples = 10**5 - data_dims = len(weights) - - mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) - log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32) - - params = [mean, log_scale] - function = lambda x: jnp.sum(weights * x) - - rng = jax.random.PRNGKey(1) - dist = utils.multi_normal(*params) - dist_samples = dist.sample((num_samples,), rng) - - update_state = control_variates.moving_avg_baseline( - function, decay=decay, zero_debias=False, - use_decay_early_training_heuristic=False)[-1] - - update_state_zero_debias = control_variates.moving_avg_baseline( - function, decay=decay, zero_debias=True, - use_decay_early_training_heuristic=False)[-1] - - updated_state = update_state(params, dist_samples, (jnp.array(0.), 0))[0] - _assert_equal(updated_state, (1 - decay) * function(mean)) - - updated_state_zero_debias = update_state_zero_debias( - params, dist_samples, (jnp.array(0.), 0))[0] - _assert_equal( - updated_state_zero_debias, function(mean)) - - -class DeltaMethodAnalyticalExpectedGrads(chex.TestCase): - """Tests for grads approximations.""" - - @chex.all_variants - @parameterized.named_parameters( - chex.params_product([ - ('_score_function_jacobians', 1.0, 1.0, sge.score_function_jacobians), - ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians), - ('_measure_valued_jacobians', 1.0, 1.0, sge.measure_valued_jacobians), - ], [ - ('estimate_cv_coeffs', True), - ('no_estimate_cv_coeffs', False), - ], - named=True)) - def testQuadraticFunction(self, effective_mean, effective_log_scale, - grad_estimator, estimate_cv_coeffs): - data_dims = 3 - num_samples = 10**3 - - mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) - log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32) - - params = [mean, log_scale] - function = lambda x: jnp.sum(x**2) - rng = jax.random.PRNGKey(1) - - jacobians = _cv_jac_variant(self.variant)( - function, - control_variates.control_delta_method, - grad_estimator, - params, - utils.multi_normal, # dist_builder - rng, - num_samples, - None, # No cv state. 
- estimate_cv_coeffs)[0] - - expected_mean_grads = 2 * effective_mean * np.ones( - data_dims, dtype=np.float32) - expected_log_scale_grads = 2 * np.exp(2 * effective_log_scale) * np.ones( - data_dims, dtype=np.float32) - - mean_jacobians = jacobians[0] - chex.assert_shape(mean_jacobians, (num_samples, data_dims)) - mean_grads_from_jacobian = jnp.mean(mean_jacobians, axis=0) - - log_scale_jacobians = jacobians[1] - chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) - log_scale_grads_from_jacobian = jnp.mean(log_scale_jacobians, axis=0) - - _assert_equal(mean_grads_from_jacobian, expected_mean_grads, - rtol=1e-1, atol=1e-3) - _assert_equal(log_scale_grads_from_jacobian, expected_log_scale_grads, - rtol=1e-1, atol=1e-3) - - @chex.all_variants - @parameterized.named_parameters( - chex.params_product([ - ('_score_function_jacobians', 1.0, 1.0, sge.score_function_jacobians), - ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians), - ('_measure_valued_jacobians', 1.0, 1.0, sge.measure_valued_jacobians), - ], [ - ('estimate_cv_coeffs', True), - ('no_estimate_cv_coeffs', False), - ], - named=True)) - def testCubicFunction( - self, effective_mean, effective_log_scale, grad_estimator, - estimate_cv_coeffs): - data_dims = 1 - num_samples = 10**5 - - mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) - log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32) - - params = [mean, log_scale] - function = lambda x: jnp.sum(x**3) - rng = jax.random.PRNGKey(1) - - jacobians = _cv_jac_variant(self.variant)( - function, - control_variates.control_delta_method, - grad_estimator, - params, - utils.multi_normal, - rng, - num_samples, - None, # No cv state. - estimate_cv_coeffs)[0] - - # The third order uncentered moment of the Gaussian distribution is - # mu**3 + 2 mu * sigma **2. We use that to compute the expected value - # of the gradients. Note: for the log scale we need use the chain rule. 
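# [Editorial derivation sketch for the expectations computed below.] With
# E[x**3] = mu**3 + 3 * mu * sigma**2 and sigma = exp(log_scale):
#     d/d mu        : 3 * mu**2 + 3 * sigma**2
#     d/d log_scale : 3 * mu * (2 * sigma**2) = 6 * mu * sigma**2   # chain rule
# which is what expected_mean_grads and expected_log_scale_grads encode.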
- expected_mean_grads = ( - 3 * effective_mean**2 + 3 * np.exp(effective_log_scale)**2) - expected_mean_grads *= np.ones(data_dims, dtype=np.float32) - expected_log_scale_grads = ( - 6 * effective_mean * np.exp(effective_log_scale) ** 2) - expected_log_scale_grads *= np.ones(data_dims, dtype=np.float32) - - mean_jacobians = jacobians[0] - chex.assert_shape(mean_jacobians, (num_samples, data_dims)) - mean_grads_from_jacobian = jnp.mean(mean_jacobians, axis=0) - - log_scale_jacobians = jacobians[1] - chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) - log_scale_grads_from_jacobian = jnp.mean(log_scale_jacobians, axis=0) - - _assert_equal(mean_grads_from_jacobian, expected_mean_grads, - rtol=1e-1, atol=1e-3) - - _assert_equal(log_scale_grads_from_jacobian, expected_log_scale_grads, - rtol=1e-1, atol=1e-3) - - @chex.all_variants - @parameterized.named_parameters( - chex.params_product([ - ('_score_function_jacobians', 1.0, 1.0, sge.score_function_jacobians), - ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians), - ('_measure_valued_jacobians', 1.0, 1.0, sge.measure_valued_jacobians), - ], [ - ('estimate_cv_coeffs', True), - ('no_estimate_cv_coeffs', False), - ], - named=True)) - def testForthPowerFunction( - self, effective_mean, effective_log_scale, grad_estimator, - estimate_cv_coeffs): - data_dims = 1 - num_samples = 10**5 - - mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) - log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32) - - params = [mean, log_scale] - function = lambda x: jnp.sum(x**4) - rng = jax.random.PRNGKey(1) - - jacobians = _cv_jac_variant(self.variant)( - function, - control_variates.control_delta_method, - grad_estimator, - params, - utils.multi_normal, - rng, - num_samples, - None, # No cv state - estimate_cv_coeffs)[0] - # The third order uncentered moment of the Gaussian distribution is - # mu**4 + 6 mu **2 sigma **2 + 3 sigma**4. We use that to compute the - # expected value of the gradients. - # Note: for the log scale we need use the chain rule. 
- expected_mean_grads = ( - 3 * effective_mean**3 - + 12 * effective_mean * np.exp(effective_log_scale)**2) - expected_mean_grads *= np.ones(data_dims, dtype=np.float32) - expected_log_scale_grads = 12 * ( - effective_mean**2 * np.exp(effective_log_scale) + - np.exp(effective_log_scale) ** 3) * np.exp(effective_log_scale) - expected_log_scale_grads *= np.ones(data_dims, dtype=np.float32) - - mean_jacobians = jacobians[0] - chex.assert_shape(mean_jacobians, (num_samples, data_dims)) - mean_grads_from_jacobian = jnp.mean(mean_jacobians, axis=0) - - log_scale_jacobians = jacobians[1] - chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) - log_scale_grads_from_jacobian = jnp.mean(log_scale_jacobians, axis=0) - - _assert_equal(mean_grads_from_jacobian, expected_mean_grads, - rtol=1e-1, atol=1e-3) - - _assert_equal(log_scale_grads_from_jacobian, expected_log_scale_grads, - rtol=1e-1, atol=1e-3) - - -class ConsistencyWithStandardEstimators(chex.TestCase): - """Tests for consistency between estimators.""" - - @chex.all_variants - @parameterized.named_parameters( - chex.params_product([ - ('_score_function_jacobians', 1, 1, sge.score_function_jacobians, - 10**6), - ('_pathwise_jacobians', 1, 1, sge.pathwise_jacobians, 10**5), - ('_measure_valued_jacobians', 1, 1, sge.measure_valued_jacobians, - 10**5), - ], [ - ('control_delta_method', control_variates.control_delta_method), - ('moving_avg_baseline', control_variates.moving_avg_baseline), - ], - named=True)) - def testWeightedLinearFunction(self, effective_mean, effective_log_scale, - grad_estimator, num_samples, - control_variate_from_function): - """Check that the gradients are consistent between estimators.""" - weights = jnp.array([1., 2., 3.], dtype=jnp.float32) - data_dims = len(weights) - - mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) - log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32) - - params = [mean, log_scale] - function = lambda x: jnp.sum(weights * x) - rng = jax.random.PRNGKey(1) - cv_rng, ge_rng = jax.random.split(rng) - - jacobians = _cv_jac_variant(self.variant)( - function, - control_variate_from_function, - grad_estimator, - params, - utils.multi_normal, # dist_builder - cv_rng, # rng - num_samples, - (0., 0), # control_variate_state - False)[0] - - mean_jacobians = jacobians[0] - chex.assert_shape(mean_jacobians, (num_samples, data_dims)) - mean_grads = jnp.mean(mean_jacobians, axis=0) - - log_scale_jacobians = jacobians[1] - chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) - log_scale_grads = jnp.mean(log_scale_jacobians, axis=0) - - # We use a different random number generator for the gradient estimator - # without the control variate. 
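# [Editorial note.] Because cv_rng and ge_rng come from jax.random.split, the
# two estimates are computed on independent sample sets; the comparison below
# therefore checks agreement of the estimated gradients up to Monte Carlo
# error (the rtol/atol), not sample-by-sample equality.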
- no_cv_jacobians = grad_estimator( - function, [mean, log_scale], - utils.multi_normal, ge_rng, num_samples=num_samples) - - no_cv_mean_jacobians = no_cv_jacobians[0] - chex.assert_shape(no_cv_mean_jacobians, (num_samples, data_dims)) - no_cv_mean_grads = jnp.mean(no_cv_mean_jacobians, axis=0) - - no_cv_log_scale_jacobians = no_cv_jacobians[1] - chex.assert_shape(no_cv_log_scale_jacobians, (num_samples, data_dims)) - no_cv_log_scale_grads = jnp.mean(no_cv_log_scale_jacobians, axis=0) - - _assert_equal(mean_grads, no_cv_mean_grads, rtol=1e-1, atol=5e-2) - _assert_equal(log_scale_grads, no_cv_log_scale_grads, rtol=1, atol=5e-2) - - @chex.all_variants - @parameterized.named_parameters( - chex.params_product([ - ('_score_function_jacobians', 1, 1, sge.score_function_jacobians, - 10**5), - ('_pathwise_jacobians', 1, 1, sge.pathwise_jacobians, 10**5), - ('_measure_valued_jacobians', 1, 1, sge.measure_valued_jacobians, - 10**5), - ], [ - ('control_delta_method', control_variates.control_delta_method), - ('moving_avg_baseline', control_variates.moving_avg_baseline), - ], - named=True)) - def testNonPolynomialFunction( - self, effective_mean, effective_log_scale, - grad_estimator, num_samples, control_variate_from_function): - """Check that the gradients are consistent between estimators.""" - data_dims = 3 - - mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32) - log_scale = effective_log_scale * jnp.ones( - shape=(data_dims), dtype=jnp.float32) - - params = [mean, log_scale] - function = lambda x: jnp.log(jnp.sum(x**2)) - rng = jax.random.PRNGKey(1) - cv_rng, ge_rng = jax.random.split(rng) - - jacobians = _cv_jac_variant(self.variant)( - function, - control_variate_from_function, - grad_estimator, - params, - utils.multi_normal, - cv_rng, - num_samples, - (0., 0), # control_variate_state - False)[0] - - mean_jacobians = jacobians[0] - chex.assert_shape(mean_jacobians, (num_samples, data_dims)) - mean_grads = jnp.mean(mean_jacobians, axis=0) - - log_scale_jacobians = jacobians[1] - chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) - log_scale_grads = jnp.mean(log_scale_jacobians, axis=0) - - # We use a different random number generator for the gradient estimator - # without the control variate. - no_cv_jacobians = grad_estimator( - function, [mean, log_scale], - utils.multi_normal, ge_rng, num_samples=num_samples) - - no_cv_mean_jacobians = no_cv_jacobians[0] - chex.assert_shape(no_cv_mean_jacobians, (num_samples, data_dims)) - no_cv_mean_grads = jnp.mean(no_cv_mean_jacobians, axis=0) - - no_cv_log_scale_jacobians = no_cv_jacobians[1] - chex.assert_shape(no_cv_log_scale_jacobians, (num_samples, data_dims)) - no_cv_log_scale_grads = jnp.mean(no_cv_log_scale_jacobians, axis=0) - - _assert_equal(mean_grads, no_cv_mean_grads, rtol=1e-1, atol=5e-2) - _assert_equal(log_scale_grads, no_cv_log_scale_grads, rtol=1e-1, atol=5e-2) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/equivalence_test.py b/optax_add_eve/_src/equivalence_test.py deleted file mode 100644 index 9130e0c7..00000000 --- a/optax_add_eve/_src/equivalence_test.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests of equivalence between optax and other optimiser libraries.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import chex -from flax import optim -from jax.example_libraries import optimizers -import jax.numpy as jnp - -from optax_add_eve._src import alias -from optax_add_eve._src import update - - -STEPS = 50 -LR = 1e-2 -LR_SCHED = lambda _: LR # Trivial constant "schedule". - - -class OptimizersEquivalenceTest(chex.TestCase): - - def setUp(self): - super().setUp() - self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4., 5.])) - self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3., 1.])) - - @chex.all_variants - @parameterized.named_parameters( - ('sgd', alias.sgd(LR, 0.0), - optimizers.sgd(LR), 1e-5), - ('adam', alias.adam(LR, 0.9, 0.999, 1e-8), - optimizers.adam(LR, 0.9, 0.999), 1e-4), - ('rmsprop', alias.rmsprop(LR, decay=.9, eps=0.1), - optimizers.rmsprop(LR, .9, 0.1), 1e-5), - ('rmsprop_momentum', alias.rmsprop( - LR, decay=.9, eps=0.1, momentum=0.9), - optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5), - ('adagrad', alias.adagrad(LR, 0., 0.,), - optimizers.adagrad(LR, 0.), 1e-5), - ('sgd', alias.sgd(LR_SCHED, 0.0), - optimizers.sgd(LR), 1e-5), - ('adam', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8), - optimizers.adam(LR, 0.9, 0.999), 1e-4), - ('rmsprop', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1), - optimizers.rmsprop(LR, .9, 0.1), 1e-5), - ('rmsprop_momentum', alias.rmsprop( - LR_SCHED, decay=.9, eps=0.1, momentum=0.9), - optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5), - ('adagrad', alias.adagrad(LR_SCHED, 0., 0.,), - optimizers.adagrad(LR, 0.), 1e-5), - ('sm3', alias.sm3(LR), optimizers.sm3(LR), 1e-2), - ) - def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer, rtol): - - # example_libraries/optimizers.py - jax_params = self.init_params - opt_init, opt_update, get_params = jax_optimizer - state = opt_init(jax_params) - for i in range(STEPS): - state = opt_update(i, self.per_step_updates, state) - jax_params = get_params(state) - - # optax - optax_params = self.init_params - state = optax_optimizer.init(optax_params) - - @self.variant - def step(updates, state): - return optax_optimizer.update(updates, state) - - for _ in range(STEPS): - updates, state = step(self.per_step_updates, state) - optax_params = update.apply_updates(optax_params, updates) - - # Check equivalence. - chex.assert_trees_all_close(jax_params, optax_params, rtol=rtol) - - -class FlaxOptimizersEquivalenceTest(chex.TestCase): - - def setUp(self): - super().setUp() - self.init_params = ( - jnp.array([1., 0.1, 1., 2.]), jnp.array([3., 4.])) - self.per_step_updates = ( - jnp.array([0., 0.3, 500., 5.]), jnp.array([300., 3.])) - - @parameterized.named_parameters( - ('sgd', - alias.sgd(LR), - optim.GradientDescent(LR)), - ('momentum', - alias.sgd(LR, momentum=0.9), - optim.Momentum(LR, beta=0.9)), # Different names. 
- ('nesterov_momentum', - alias.sgd(LR, momentum=0.9, nesterov=True), - optim.Momentum(LR, beta=0.9, nesterov=True)), - ('rmsprop', - alias.rmsprop(LR), - optim.RMSProp(LR)), - ('centered_rmsprop', - alias.rmsprop(LR, centered=True), - optim.RMSProp(LR, centered=True)), - ('adam', - alias.adam(LR), - optim.Adam(LR)), - ('adam_w', - alias.adamw(LR, weight_decay=1e-4), - optim.Adam(LR, weight_decay=1e-4)), # Different name. - ('adagrad', - alias.adagrad(LR, initial_accumulator_value=0.), # Different default! - optim.Adagrad(LR)), - ('lamb', - alias.lamb(LR), - optim.LAMB(LR)), - ('lars', - alias.lars( - LR, weight_decay=.5, trust_coefficient=0.003, - momentum=0.9, eps=1e-3), - optim.LARS( - LR, weight_decay=.5, trust_coefficient=0.003, - beta=0.9, eps=1e-3)), - ('adafactor', - alias.adafactor( - learning_rate=LR / 10., - factored=True, - multiply_by_parameter_scale=True, - clipping_threshold=1.0, - decay_rate=0.8, - min_dim_size_to_factor=2), - optim.Adafactor( - learning_rate=LR / 10., - factored=True, - multiply_by_parameter_scale=True, - clipping_threshold=1.0, - decay_rate=0.8, - min_dim_size_to_factor=2)), - ) - def test_flax_optim_equivalence(self, optax_optimizer, flax_optimizer): - - # flax/optim - flax_params = self.init_params - flax_optimizer = flax_optimizer.create(flax_params) - for _ in range(STEPS): - flax_optimizer = flax_optimizer.apply_gradient( - self.per_step_updates) - flax_params = flax_optimizer.target - - # optax - optax_params = self.init_params - state = optax_optimizer.init(optax_params) - for _ in range(STEPS): - updates, state = optax_optimizer.update( - self.per_step_updates, state, optax_params) - optax_params = update.apply_updates(optax_params, updates) - - # Check equivalence. - chex.assert_trees_all_close(flax_params, optax_params, rtol=2e-4) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/experimental/complex_valued.py b/optax_add_eve/_src/experimental/complex_valued.py deleted file mode 100644 index 5c1c7b54..00000000 --- a/optax_add_eve/_src/experimental/complex_valued.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Complex-valued optimization. - -When using `split_real_and_imaginary` to wrap an optimizer, we split the complex -parameters and updates into pairs of real ones before sending them to the -`update` of the wrapped optimizer, and merge the pairs of transformed real -updates into complex ones afterward. In this way, optimizers on complex -parameters behave the same way as if they were running on two real parameters. - -Note that the convention of conjugate for complex gradients in JAX is different -from that in PyTorch and other frameworks, and we need to manually conjugate the -gradients between `jax.grad` and `optimizer.update`. 
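A minimal usage sketch of the conjugation convention described above, assuming the `optax_add_eve._src` module paths from this patch are importable; the parameter values are illustrative only:

import jax
import jax.numpy as jnp
from optax_add_eve._src import alias
from optax_add_eve._src import update
from optax_add_eve._src.experimental import complex_valued

def loss_fn(z):
  # Real-valued loss of a complex parameter vector.
  return (z.conj() * z).real.sum()

z = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
optimizer = complex_valued.split_real_and_imaginary(alias.adam(1e-2))
opt_state = optimizer.init(z)

grads = jax.grad(loss_fn)(z)
# JAX's gradient convention differs from PyTorch's: conjugate the complex
# gradients before handing them to `optimizer.update`.
grads = jax.tree_util.tree_map(jnp.conj, grads)
updates, opt_state = optimizer.update(grads, opt_state, z)
z = update.apply_updates(z, updates)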
- -See details at https://github.com/deepmind/optax/issues/196 -""" - -from typing import NamedTuple, Union - -import chex -import jax -import jax.numpy as jnp - -from optax_add_eve._src import base - - -class SplitRealAndImaginaryArrays(NamedTuple): - """A pair of real arrays split from a complex array.""" - real: chex.Array - imaginary: chex.Array - - -def _complex_to_real_pair( - x: chex.Array -) -> Union[chex.Array, SplitRealAndImaginaryArrays]: - """Splits a complex array into a `SplitRealAndImaginaryArrays`. - - Args: - x: The input array, can be complex or real. - - Returns: - `SplitRealAndImaginaryArrays` if the input is a complex array. If the - input is a real array, it is passed through unmodified. - """ - if jnp.iscomplexobj(x): - return SplitRealAndImaginaryArrays(x.real, x.imag) - else: - return x - - -def _real_pair_to_complex( - x: Union[chex.Array, SplitRealAndImaginaryArrays] -) -> chex.Array: - """Merges a `SplitRealAndImaginaryArrays` into a complex array. - - Args: - x: The input `SplitRealAndImaginaryArrays` or array. - - Returns: - A complex array obtained from the real and imaginary parts of the - `SplitRealAndImaginaryArrays`. If the input is not a - `SplitRealAndImaginaryArrays`, it is passed through unmodified. - """ - if isinstance(x, SplitRealAndImaginaryArrays): - return x.real + x.imaginary * 1j - else: - return x - - -class SplitRealAndImaginaryState(NamedTuple): - """Maintains the inner transformation state for `split_real_and_imaginary`.""" - inner_state: base.OptState - - -def split_real_and_imaginary( - inner: base.GradientTransformation -) -> base.GradientTransformation: - """Splits the real and imaginary components of complex updates into two. - - The inner transformation processes real parameters and updates, and the - pairs of transformed real updates are merged into complex updates. - - Parameters and updates that are real before splitting are passed through - unmodified. - - Args: - inner: The inner transformation. - - Returns: - An `optax.GradientTransformation`. - """ - - def init_fn(params): - params = jax.tree_util.tree_map(_complex_to_real_pair, params) - inner_state = inner.init(params) - return SplitRealAndImaginaryState(inner_state) - - def update_fn(updates, state, params=None): - inner_state = state.inner_state - updates = jax.tree_util.tree_map(_complex_to_real_pair, updates) - params = jax.tree_util.tree_map(_complex_to_real_pair, params) - updates, inner_state = inner.update(updates, inner_state, params) - updates = jax.tree_util.tree_map( - _real_pair_to_complex, - updates, - is_leaf=lambda x: isinstance(x, SplitRealAndImaginaryArrays)) - return updates, SplitRealAndImaginaryState(inner_state) - - return base.GradientTransformation(init_fn, update_fn) diff --git a/optax_add_eve/_src/experimental/complex_valued_test.py b/optax_add_eve/_src/experimental/complex_valued_test.py deleted file mode 100644 index 57ad98e1..00000000 --- a/optax_add_eve/_src/experimental/complex_valued_test.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for `complex_valued.py`.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import chex -import jax -import jax.numpy as jnp -import numpy as np - -from optax_add_eve._src import transform -from optax_add_eve._src import update -from optax_add_eve._src.experimental import complex_valued - - -def _loss_fun_complex_to_real(z): - return (z.conj() * z).real.sum() - - -def _loss_fun_real_to_real(params): - x, y = params - return _loss_fun_complex_to_real(x + y * 1j) - - -class ComplexValuedTest(parameterized.TestCase): - - @chex.all_variants - @parameterized.named_parameters([ - ('adam', transform.scale_by_adam), - ('param_block_norm', transform.scale_by_param_block_norm), - ]) - def test_split_real_and_imaginary(self, scaler_constr): - - def do_update(loss_fun, optimizer, params, opt_state): - loss, grads = jax.value_and_grad(loss_fun)(params) - # Complex gradients need to be conjugated before being added to parameters - grads = jax.tree_util.tree_map(lambda x: x.conj(), grads) - updates, opt_state = self.variant(optimizer.update)( - grads, opt_state, params) - params = update.apply_updates(params, updates) - return loss, grads, params, opt_state - - x = jnp.array([[0.1, 0.2, 0.3], [-0.1, -0.2, -0.3]]) - y = jnp.array([[0.5, -0.5, 0], [0.1, 0.3, -0.2]]) - z = x + y * 1j - - optimizer = scaler_constr() - optimizer_complex = complex_valued.split_real_and_imaginary(optimizer) - opt_state = self.variant(optimizer.init)((x, y)) - opt_state_complex = self.variant(optimizer_complex.init)(z) - - # Check that the loss, the gradients, and the parameters are the same for - # real-to-real and complex-to-real loss functions in each step - for _ in range(3): - loss, (gx, gy), (x, y), opt_state = do_update( - _loss_fun_real_to_real, optimizer, (x, y), opt_state) - loss_complex, gz, z, opt_state_complex = do_update( - _loss_fun_complex_to_real, optimizer_complex, z, opt_state_complex) - np.testing.assert_allclose(loss, loss_complex) - np.testing.assert_allclose(gx + gy * 1j, gz) - np.testing.assert_allclose(x + y * 1j, z) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/experimental/extra_args.py b/optax_add_eve/_src/experimental/extra_args.py deleted file mode 100644 index 7264fbc0..00000000 --- a/optax_add_eve/_src/experimental/extra_args.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Support for extra kwargs in a gradient transformation's `init` and `update`. - -Some users have the need to condition the behavior of a gradient -transformations on dynamical quantities such as the loss. 
With this experimental -feature we support passing additional kwargs to `init` and `update`. - -We introduce `GradientTransformationWithExtraArgs` as an experimental feature. -You can use the new `named_chain` to combine both old-style and new-style -transformations. We will then monitor users to understand how they use it and -gather feedback from optax users before merging this into the stable API. -""" - -from typing import Any, Mapping, Optional, Tuple, Union, NamedTuple - -from optax_add_eve._src import base -import typing_extensions - - -class InitFnWithExtraArgs(typing_extensions.Protocol): - """Like `TransformInitFn` but with optional `extra_args`.""" - - def __call__( - self, - params: base.Params, - *, - extra_args: Optional[Mapping[str, Any]] = None, - ) -> base.OptState: - """The `init` function.""" - - -class UpdateFnWithExtraArgs(typing_extensions.Protocol): - """Like `TransformUpdateFn` but with optional `extra_args`.""" - - def __call__( - self, - updates: base.Updates, - state: base.OptState, - params: Optional[base.Params] = None, - *, - extra_args: Optional[Mapping[str, Any]] = None, - ) -> Tuple[base.Updates, base.OptState]: - """The `update` function.""" - - -class GradientTransformationWithExtraArgs(NamedTuple): - """A pair of pure functions implementing a gradient transformation. - - GradientTransformationWithExtraArgs is just like GradientTransformation but - both the `init` and `update` functions accept an additional `extra_args` dict. - This can be used to provide additional dynamic information that is not - computed by the GradientTransformation itself (e.g. loss) but that may be - needed by specific algorithms. - """ - init: InitFnWithExtraArgs - update: UpdateFnWithExtraArgs - - -AnyGradientTransformation = Union[ - base.GradientTransformation, GradientTransformationWithExtraArgs] -NamedTransform = Tuple[str, AnyGradientTransformation] - - -def named_chain( - *transforms: NamedTransform) -> GradientTransformationWithExtraArgs: - """Chains optax gradient transformations. - - The `transforms` are `(name, transformation)` pairs, constituted of a string - `name` and an associated gradient transformation `transformation`. The - gradient transformation must be an instance of either - `GradientTransformation` or `GradientTransformationWithExtraArgs`. - - Each `name` is used as key for the state of the corresponding transformation - within the `named_chain` state. Thus the state of the gradient transformation - with a given `name` can be retrieved as `opt_state[name]`. - - The `named_chain` accepts an `extra_args` meta-dictionary whose fields are - the transformations' names and its values are the corresponding extra_args: - - Example: - tx = named_chain(('one', tx1), ('two', tx2)) - - extra_args={ - 'one': {'loss': 0.1}, - 'two': {'loss': 0.3, 'temperature': 0.01}} - tx.init(params, extra_args=extra_args} - tx.update(grads, state, params, extra_args=extra_args) - - # tx1 receives {'loss': 0.1} as extra_args - # tx2 receives {'loss': 0.3, 'temperature': 0.01} as extra_args - - If one of the transformations does not need extra_args the corresponding - name can just be skipped in the `named_chain` extra_args: - - Example: - tx = named_chain(('one', tx1), ('two', tx2)) - - extra_args={'one': {'loss': 0.1}} - tx.init(params, extra_args=extra_args} - tx.update(grads, state, params, extra_args=extra_args) - - # tx1 receives {'loss': 0.1} as extra_args. - # tx2 is called without passing the extra_args. 
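The example calls above close `init` and `update` with `}` where `)` is intended. A corrected, self-contained sketch, with hypothetical concrete transforms standing in for `tx1` and `tx2` and mirroring the `extra_args_test.py` usage later in this patch:

import jax
import jax.numpy as jnp
from optax_add_eve._src import base
from optax_add_eve._src import transform
from optax_add_eve._src.experimental import extra_args as extra

def scale_by_loss():
  # Hypothetical extra-args transform: divide updates by the current loss.
  def init_fn(params, *, extra_args=None):
    del params, extra_args
    return base.EmptyState()
  def update_fn(updates, state, params=None, *, extra_args=None):
    del params
    updates = jax.tree_util.tree_map(lambda u: u / extra_args['loss'], updates)
    return updates, state
  return extra.GradientTransformationWithExtraArgs(init_fn, update_fn)

tx = extra.named_chain(('one', transform.scale(0.1)), ('two', scale_by_loss()))
params = {'w': jnp.ones((4,))}
grads = params
extra_args = {'two': {'loss': 0.5}}  # 'one' needs no extra args, so it is skipped.
opt_state = tx.init(params, extra_args=extra_args)
updates, opt_state = tx.update(grads, opt_state, params, extra_args=extra_args)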
- - Args: - *transforms: an arbitrary number of `(name, tx)` pairs, constituted of a - string `name` and an associated gradient transformation `tx`. The latter - is a `GradientTransformation` or `GradientTransformationWithExtraArgs`. - - Returns: - A single (init_fn, update_fn) tuple. - """ - - names = [name for name, _ in transforms] - if len(names) != len(set(names)): - raise ValueError( - f'Named transformations must have unique names, but got {names}') - - def init_fn(params, *, extra_args=None): - states = {} - for (name, tx) in transforms: - _assert_is_gradient_transformation(tx) - if (extra_args is not None and - isinstance(tx, GradientTransformationWithExtraArgs)): - states[name] = tx.init( - params, extra_args=extra_args.get(name)) - else: - states[name] = tx.init(params) - return states - - def update_fn(updates, state, params=None, *, extra_args=None): - new_state = {} - for (name, tx) in transforms: - _assert_is_gradient_transformation(tx) - if (extra_args is not None and - isinstance(tx, GradientTransformationWithExtraArgs)): - updates, new_state[name] = tx.update( - updates, state[name], params, extra_args=extra_args.get(name)) - else: - updates, new_state[name] = tx.update(updates, state[name], params) - return updates, new_state - - return GradientTransformationWithExtraArgs(init_fn, update_fn) - - -def _assert_is_gradient_transformation(tx): - valid_types = ( - base.GradientTransformation, - GradientTransformationWithExtraArgs) - if not isinstance(tx, valid_types): - raise ValueError( - 'The transformation `tx` must be a valid gradient transformation, ' - 'that is an instance of either `GradientTransformation` or ' - 'an instance of `GradientTransformationWithExtraArgs`') diff --git a/optax_add_eve/_src/experimental/extra_args_test.py b/optax_add_eve/_src/experimental/extra_args_test.py deleted file mode 100644 index b24fca48..00000000 --- a/optax_add_eve/_src/experimental/extra_args_test.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for extra_kwargs.""" - -from absl.testing import absltest -import chex -import jax -import jax.numpy as jnp - -from optax_add_eve._src import base -from optax_add_eve._src import transform -from optax_add_eve._src.experimental import extra_args as extra - - -def scale_by_loss(): - """Scale the gradient by the absolute value of the loss.""" - - def init_fn(params, *, extra_args): - del params, extra_args - return base.EmptyState() - - def update_fn(updates, state, params, *, extra_args): - del params - updates = jax.tree_util.tree_map( - lambda u: u / extra_args['loss'], updates) - return updates, state - - return extra.GradientTransformationWithExtraArgs(init_fn, update_fn) - - -class ExtraArgsTest(absltest.TestCase): - - def test_named_chain(self): - tx = extra.named_chain( - ('scale', transform.scale(0.1)), - ('scale_by_policy_loss', scale_by_loss()), - ('scale_by_value_loss', scale_by_loss()), - ) - - params = {'a': jnp.ones((4,))} - grads = params - extra_args = { - 'scale_by_policy_loss': {'loss': 0.01}, - 'scale_by_value_loss': {'loss': 10.0}} - - opt_state = tx.init(params, extra_args=extra_args) - updates, opt_state = tx.update( - grads, opt_state, params, extra_args=extra_args) - chex.assert_trees_all_close(updates, {'a': jnp.ones((4,))}) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/factorized.py b/optax_add_eve/_src/factorized.py deleted file mode 100644 index b3bbec45..00000000 --- a/optax_add_eve/_src/factorized.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Factorized optimizers.""" - -import dataclasses -from typing import NamedTuple, Optional, Tuple, Callable - -import chex -import jax -import jax.numpy as jnp -import numpy as np - -from optax_add_eve._src import base -from optax_add_eve._src import numerics -from optax_add_eve._src import utils - -# pylint:disable=no-value-for-parameter - - -def _decay_rate_pow(i: int, exponent: float = 0.8) -> float: - """Second-order moment decay schedule.""" - t = jnp.array(i, jnp.float32) + 1.0 - return 1.0 - t**(-exponent) - - -def _factored_dims( - shape: base.Shape, - factored: bool, - min_dim_size_to_factor: int -) -> Optional[Tuple[int, int]]: - """Whether to use a factored second moment estimator. - - This function returns a tuple with the two largest axes to reduce over. - If no two dimensions have size >= min_dim_size_to_factor, return None. - - Args: - shape: an input shape - factored: whether to use factored second-moment estimator for 2d vars. - min_dim_size_to_factor: only factor accumulator if two array dimensions - have at least this size. 
- - Returns: - None or a tuple of ints - """ - if not factored or len(shape) < 2: - return None - sorted_dims = np.argsort(shape) - if shape[sorted_dims[-2]] < min_dim_size_to_factor: - return None - return int(sorted_dims[-2]), int(sorted_dims[-1]) - - -@dataclasses.dataclass -class _UpdateResult: - """Opaque containter that is not traversed by jax.tree_util.tree_map.""" - update: chex.Array # the update to apply to params - v_row: chex.Array # used for factored params. - v_col: chex.Array # used for factored params. - v: chex.Array # used for params where factoring is skipped. - - -class FactoredState(NamedTuple): - """Overall state of the gradient transformation.""" - count: chex.Array # number of update steps. - v_row: chex.ArrayTree # Tree of factored params. - v_col: chex.ArrayTree # Tree of factored params. - v: chex.ArrayTree # Tree for params where factoring is skipped. - - -def scale_by_factored_rms( - factored: bool = True, - decay_rate: float = 0.8, - step_offset: int = 0, - min_dim_size_to_factor: int = 128, - epsilon: float = 1e-30, - decay_rate_fn: Callable[[int, float], chex.Array] = _decay_rate_pow): - """Scaling by a factored estimate of the gradient rms (as in Adafactor). - - This is a so-called "1+epsilon" scaling algorithms, that is extremely memory - efficient compared to RMSProp/Adam, and has had wide success when applied to - large-scale training of attention-based models. - - References: - [Shazeer et al, 2018](https://arxiv.org/abs/1804.04235) - - Args: - factored: boolean: whether to use factored second-moment estimates.. - decay_rate: float: controls second-moment exponential decay schedule. - step_offset: for finetuning, one may set this to the starting step-number - of the fine tuning phase. - min_dim_size_to_factor: only factor accumulator if two array dimensions - are at least this size. - epsilon: Regularization constant for squared gradient. - decay_rate_fn: A function that accepts the current step, the decay rate - parameter and controls the schedule for the second momentum. Defaults to - the original adafactor's power decay schedule. One potential shortcoming - of the orignal schedule is the fact that second momentum converges to 1, - which effectively freezes the second momentum. To prevent this the user - can opt for a custom schedule that sets an upper bound for the second - momentum, like in [Zhai et al., 2021](https://arxiv.org/abs/2106.04560). - - Returns: - the corresponding `GradientTransformation`. 
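A short usage sketch of the factored RMS scaling described above, assuming the `optax_add_eve._src` module path from this patch; the shapes are illustrative (the first parameter is large enough to be factored, the second is not):

import jax.numpy as jnp
from optax_add_eve._src import factorized

params = (jnp.ones((128, 256)), jnp.ones((3,)))  # (factored, unfactored)
grads = (jnp.full((128, 256), 0.1), jnp.array([1., 2., 3.]))
scaler = factorized.scale_by_factored_rms()
opt_state = scaler.init(params)
# `params` must be passed: the factored statistics depend on parameter shapes.
updates, opt_state = scaler.update(grads, opt_state, params)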
- """ - - def _to_state(count: chex.Array, result_tree): - """Maps from a tree of (factored) values to separate trees of values.""" - return FactoredState( - count=count, - v_row=jax.tree_util.tree_map(lambda o: o.v_row, result_tree), - v_col=jax.tree_util.tree_map(lambda o: o.v_col, result_tree), - v=jax.tree_util.tree_map(lambda o: o.v, result_tree)) - - def init_fn(params): - """Initialise the optimiser's state.""" - - def _init(param): - shape = param.shape - factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor) - if factored_dims is not None: - d1, d0 = factored_dims - vr_shape = np.delete(shape, d0) - vc_shape = np.delete(shape, d1) - return _UpdateResult( - update=jnp.zeros((1,)), - v_row=jnp.zeros(vr_shape), - v_col=jnp.zeros(vc_shape), - v=jnp.zeros((1,))) - else: - return _UpdateResult( - update=jnp.zeros((1,)), - v_row=jnp.zeros((1,)), - v_col=jnp.zeros((1,)), - v=jnp.zeros(param.shape)) - - return _to_state( - jnp.zeros([], jnp.int32), jax.tree_util.tree_map(_init, params)) - - def update_fn(grads, state, params): - """Apply gradient transformation.""" - if params is None: - raise ValueError(base.NO_PARAMS_MSG) - - def _update(grad, v_row, v_col, v, param, step): - shape = param.shape - decay_rate_t = decay_rate_fn(step - step_offset, decay_rate) - - # Scaled by factorized second moment statistics. - new_v_row = jnp.zeros((1,)) - new_v_col = jnp.zeros((1,)) - new_v = jnp.zeros((1,)) - - factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor) - if factored_dims is not None: - d1, d0 = factored_dims - grad_sqr = numerics.abs_sq(grad) + epsilon - new_v_row = ( - decay_rate_t * v_row + - (1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d0)) - new_v_col = ( - decay_rate_t * v_col + - (1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d1)) - reduced_d1 = d1-1 if d1 > d0 else d1 - row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True) - row_factor = (new_v_row / row_col_mean) ** -0.5 - col_factor = (new_v_col) ** -0.5 - update = ( - grad * - jnp.expand_dims(row_factor, axis=d0) * - jnp.expand_dims(col_factor, axis=d1)) - else: - grad_sqr = numerics.abs_sq(grad) + epsilon - new_v = decay_rate_t * v + (1. - decay_rate_t) * grad_sqr - update = grad * (new_v)**-0.5 - - return _UpdateResult(update, new_v_row, new_v_col, new_v) - - # Transform grad and compute new per-parameter stats. - output = jax.tree_util.tree_map( - lambda *args: _update(*args, state.count), - grads, state.v_row, state.v_col, state.v, params) - - # Unpack updates / stats and return. - updates = jax.tree_util.tree_map(lambda o: o.update, output) - return updates, _to_state(utils.safe_int32_increment(state.count), output) - - return base.GradientTransformation(init_fn, update_fn) diff --git a/optax_add_eve/_src/factorized_test.py b/optax_add_eve/_src/factorized_test.py deleted file mode 100644 index d0e2f90a..00000000 --- a/optax_add_eve/_src/factorized_test.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for `factorized.py`.""" - -from absl.testing import parameterized - -import chex -import jax.numpy as jnp - -from optax_add_eve._src import factorized - - -class FactorizedTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.])) - self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.])) - - @chex.all_variants - def test_scale_by_factored_rms(self): - params = self.init_params - - scaler = factorized.scale_by_factored_rms() - init_fn = self.variant(scaler.init) - transform_fn = self.variant(scaler.update) - - state = init_fn(params) - chex.assert_tree_all_finite(state) - - updates, state = transform_fn(self.per_step_updates, state, params) - chex.assert_tree_all_finite((params, updates, state)) - chex.assert_tree_all_equal_shapes(params, updates) diff --git a/optax_add_eve/_src/float64_test.py b/optax_add_eve/_src/float64_test.py deleted file mode 100644 index 9f22516d..00000000 --- a/optax_add_eve/_src/float64_test.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests that types are preserved by the `update` calls when jax_enbable_x64.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import chex -import jax -from jax.config import config -import jax.numpy as jnp - -from optax_add_eve._src import alias -from optax_add_eve._src import base -from optax_add_eve._src import clipping -from optax_add_eve._src import transform -from optax_add_eve._src import update - - -ALL_MODULES = [ - ('identity', base.identity, {}), - ('clip', clipping.clip, dict(max_delta=1.0)), - ('clip_by_global_norm', clipping.clip_by_global_norm, dict(max_norm=1.0)), - ('trace', transform.trace, dict(decay=0.5, nesterov=False)), - ('trace_with_nesterov', transform.trace, dict(decay=0.5, nesterov=True)), - ('scale_by_rss', transform.scale_by_rss, {}), - ('scale_by_rms', transform.scale_by_rms, {}), - ('scale_by_stddev', transform.scale_by_stddev, {}), - ('adam', transform.scale_by_adam, {}), - ('scale', transform.scale, dict(step_size=3.0)), - ('additive_weight_decay', transform.additive_weight_decay, - dict(weight_decay=0.1)), - ('scale_by_schedule', transform.scale_by_schedule, - dict(step_size_fn=lambda x: x * 0.1)), - ('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}), - ('add_noise', transform.add_noise, dict(eta=1.0, gamma=0.1, seed=42)), - ('apply_every_k', transform.apply_every, {}), - ('adagrad', alias.adagrad, dict(learning_rate=0.1)), - ('adam', alias.adam, dict(learning_rate=0.1)), - ('adamw', alias.adamw, dict(learning_rate=0.1)), - ('fromage', alias.fromage, dict(learning_rate=0.1)), - ('lamb', alias.lamb, dict(learning_rate=0.1)), - ('noisy_sgd', alias.noisy_sgd, dict(learning_rate=0.1)), - ('rmsprop', alias.rmsprop, dict(learning_rate=0.1)), - ('sgd', alias.sgd, dict(learning_rate=0.1)), - ('dpsgd', alias.dpsgd, - dict(learning_rate=0.1, l2_norm_clip=0.9, noise_multiplier=1.1, seed=42)), -] - - -class Float64Test(parameterized.TestCase): - - def _assert_dtype_equals(self, tree1, tree2): - tree1_types = jax.tree_util.tree_map(lambda t: t.dtype, tree1) - tree2_types = jax.tree_util.tree_map(lambda t: t.dtype, tree2) - self.assertEqual(tree1_types, tree2_types) - - @chex.all_variants - @parameterized.named_parameters(ALL_MODULES) - def test_mixed_dtype_input_outputs(self, transform_constr, transform_kwargs): - initial_params = ( - jnp.array([1., 2.], dtype=jnp.float32), - jnp.array([3., 4.], dtype=jnp.float64)) - updates = ( - jnp.array([10., 21.], dtype=jnp.float32), - jnp.array([33., 42.], dtype=jnp.float64)) - scaler = transform_constr(**transform_kwargs) - init_fn = self.variant(scaler.init) - update_fn = self.variant(scaler.update) - - initial_state = init_fn(initial_params) - updates, new_state = update_fn( - updates, initial_state, params=initial_params) - new_params = update.apply_updates(initial_params, updates) - - self._assert_dtype_equals(initial_state, new_state) - self._assert_dtype_equals(initial_params, new_params) - - -if __name__ == '__main__': - config.update('jax_enable_x64', True) - absltest.main() diff --git a/optax_add_eve/_src/linear_algebra.py b/optax_add_eve/_src/linear_algebra.py deleted file mode 100644 index 0caedd69..00000000 --- a/optax_add_eve/_src/linear_algebra.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Linear algebra utilities used in optimisation.""" - -import chex -import jax -from jax import lax -import jax.numpy as jnp -import numpy as np - -from optax_add_eve._src import base -from optax_add_eve._src import numerics - - -def global_norm(updates: base.Updates) -> base.Updates: - """Compute the global norm across a nested structure of tensors.""" - return jnp.sqrt(sum( - jnp.sum(numerics.abs_sq(x)) for x in jax.tree_util.tree_leaves(updates))) - - -def power_iteration(matrix: chex.Array, - num_iters: int = 100, - error_tolerance: float = 1e-6, - precision: lax.Precision = lax.Precision.HIGHEST): - r"""Power iteration algorithm. - - The power iteration algorithm takes a symmetric PSD matrix `A`, and produces - a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue - of `A`, and a vector v, which is the corresponding eigenvector of `A`. - - References: - [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration) - - Args: - matrix: the symmetric PSD matrix. - num_iters: Number of iterations. - error_tolerance: Iterative exit condition. - precision: precision XLA related flag, the available options are: - a) lax.Precision.DEFAULT (better step time, but not precise); - b) lax.Precision.HIGH (increased precision, slower); - c) lax.Precision.HIGHEST (best possible precision, slowest). - - Returns: - eigen vector, eigen value - """ - matrix_size = matrix.shape[-1] - def _iter_condition(state): - i, unused_v, unused_s, unused_s_v, run_step = state - return jnp.logical_and(i < num_iters, run_step) - - def _iter_body(state): - """One step of power iteration.""" - i, new_v, s, s_v, unused_run_step = state - new_v = new_v / jnp.linalg.norm(new_v) - - s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision) - s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision) - return (i + 1, s_v, s_new, s_v, - jnp.greater(jnp.abs(s_new - s), error_tolerance)) - - # Figure out how to use step as seed for random. - v_0 = np.random.uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype) - - init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True]) - _, v_out, s_out, _, _ = lax.while_loop( - _iter_condition, _iter_body, init_state) - v_out = v_out / jnp.linalg.norm(v_out) - return v_out, s_out - - -def matrix_inverse_pth_root(matrix: chex.Array, - p: int, - num_iters: int = 100, - ridge_epsilon: float = 1e-6, - error_tolerance: float = 1e-6, - precision: lax.Precision = lax.Precision.HIGHEST): - """Computes `matrix^(-1/p)`, where `p` is a positive integer. - - This function uses the Coupled newton iterations algorithm for - the computation of a matrix's inverse pth root. - - - References: - [Functions of Matrices, Theory and Computation, - Nicholas J Higham, Pg 184, Eq 7.18]( - https://epubs.siam.org/doi/book/10.1137/1.9780898717778) - - Args: - matrix: the symmetric PSD matrix whose power it to be computed - p: exponent, for p a positive integer. - num_iters: Maximum number of iterations. - ridge_epsilon: Ridge epsilon added to make the matrix positive definite. 
- error_tolerance: Error indicator, useful for early termination. - precision: precision XLA related flag, the available options are: - a) lax.Precision.DEFAULT (better step time, but not precise); - b) lax.Precision.HIGH (increased precision, slower); - c) lax.Precision.HIGHEST (best possible precision, slowest). - - Returns: - matrix^(-1/p) - """ - - # We use float32 for the matrix inverse pth root. - # Switch to f64 if you have hardware that supports it. - matrix_size = matrix.shape[0] - alpha = jnp.asarray(-1.0 / p, jnp.float32) - identity = jnp.eye(matrix_size, dtype=jnp.float32) - _, max_ev = power_iteration( - matrix=matrix, num_iters=100, - error_tolerance=1e-6, precision=precision) - ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16) - - def _unrolled_mat_pow_1(mat_m): - """Computes mat_m^1.""" - return mat_m - - def _unrolled_mat_pow_2(mat_m): - """Computes mat_m^2.""" - return jnp.matmul(mat_m, mat_m, precision=precision) - - def _unrolled_mat_pow_4(mat_m): - """Computes mat_m^4.""" - mat_pow_2 = _unrolled_mat_pow_2(mat_m) - return jnp.matmul( - mat_pow_2, mat_pow_2, precision=precision) - - def _unrolled_mat_pow_8(mat_m): - """Computes mat_m^4.""" - mat_pow_4 = _unrolled_mat_pow_4(mat_m) - return jnp.matmul( - mat_pow_4, mat_pow_4, precision=precision) - - def mat_power(mat_m, p): - """Computes mat_m^p, for p == 1, 2, 4 or 8. - - Args: - mat_m: a square matrix - p: a positive integer - - Returns: - mat_m^p - """ - # We unrolled the loop for performance reasons. - exponent = jnp.round(jnp.log2(p)) - return lax.switch( - jnp.asarray(exponent, jnp.int32), [ - _unrolled_mat_pow_1, - _unrolled_mat_pow_2, - _unrolled_mat_pow_4, - _unrolled_mat_pow_8, - ], (mat_m)) - - def _iter_condition(state): - (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, - run_step) = state - error_above_threshold = jnp.logical_and( - error > error_tolerance, run_step) - return jnp.logical_and(i < num_iters, error_above_threshold) - - def _iter_body(state): - (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state - mat_m_i = (1 - alpha) * identity + alpha * mat_m - new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision) - new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision) - new_error = jnp.max(jnp.abs(new_mat_m - identity)) - # sometimes error increases after an iteration before decreasing and - # converging. 1.2 factor is used to bound the maximal allowed increase. 
- return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, - new_error < error * 1.2) - - if matrix_size == 1: - resultant_mat_h = (matrix + ridge_epsilon)**alpha - error = 0 - else: - damped_matrix = matrix + ridge_epsilon * identity - - z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix)) - new_mat_m_0 = damped_matrix * z - new_error = jnp.max(jnp.abs(new_mat_m_0 - identity)) - new_mat_h_0 = identity * jnp.power(z, 1.0 / p) - init_state = tuple( - [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True]) - _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop( - _iter_condition, _iter_body, init_state) - error = jnp.max(jnp.abs(mat_m - identity)) - is_converged = jnp.asarray(convergence, old_mat_h.dtype) - resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h - resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype) - return resultant_mat_h, error diff --git a/optax_add_eve/_src/linear_algebra_test.py b/optax_add_eve/_src/linear_algebra_test.py deleted file mode 100644 index 5ad8172b..00000000 --- a/optax_add_eve/_src/linear_algebra_test.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for optax._src.linear_algebra.""" - -from absl.testing import absltest - -import jax.numpy as jnp -import numpy as np -from optax_add_eve._src import linear_algebra -import scipy.stats - - -class LinearAlgebraTest(absltest.TestCase): - - def test_global_norm(self): - flat_updates = jnp.array([2., 4., 3., 5.], dtype=jnp.float32) - nested_updates = dict( - a=jnp.array([2., 4.], dtype=jnp.float32), - b=jnp.array([3., 5.], dtype=jnp.float32)) - np.testing.assert_array_equal( - jnp.sqrt(jnp.sum(flat_updates**2)), - linear_algebra.global_norm(nested_updates)) - - def test_matrix_inverse_pth_root(self): - """Test for matrix inverse pth root.""" - - def _gen_symmetrix_matrix(dim, condition_number): - u = scipy.stats.ortho_group.rvs(dim=dim).astype(np.float64) - v = u.T - diag = np.diag([condition_number ** (-i/(dim-1)) for i in range(dim)]) - return u @ diag @ v - - # Fails after it reaches a particular condition number. - for e in range(2, 12): - condition_number = 10 ** e - ms = _gen_symmetrix_matrix(16, condition_number) - self.assertLess( - np.abs(np.linalg.cond(ms) - condition_number), - condition_number * 0.01) - error = linear_algebra.matrix_inverse_pth_root( - ms.astype(np.float32), 4, ridge_epsilon=1e-12)[1] - if e < 7: - self.assertLess(error, 0.1) - else: - # No guarantee of success after e >= 7 - pass - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/lookahead.py b/optax_add_eve/_src/lookahead.py deleted file mode 100644 index 97b3a6e9..00000000 --- a/optax_add_eve/_src/lookahead.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""A lookahead optimization wrapper.""" - -from typing import NamedTuple, Tuple - -from absl import logging -import jax -import jax.numpy as jnp - -from optax_add_eve._src import base - -# pylint:disable=no-value-for-parameter - - -class LookaheadState(NamedTuple): - """State of the `GradientTransformation` returned by `lookahead`. - - Attributes: - fast_state: Optimizer state of the fast optimizer. - steps_since_sync: Number of fast optimizer steps taken since slow and fast - parameters were synchronized. - """ - fast_state: base.OptState - steps_since_sync: jnp.ndarray - - -class LookaheadParams(NamedTuple): - """Holds a pair of slow and fast parameters for the lookahead optimizer. - - Gradients should always be calculated with the fast parameters. The slow - parameters should be used for testing and inference as they generalize better. - See the reference for a detailed discussion. - - References: - [Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf) - - Attributes: - fast: Fast parameters. - slow: Slow parameters. - """ - fast: base.Params - slow: base.Params - - @classmethod - def init_synced(cls, params: base.Params) -> 'LookaheadParams': - """Initialize a pair of synchronized lookahead parameters.""" - return cls(slow=params, fast=params) - - -def lookahead( - fast_optimizer: base.GradientTransformation, - sync_period: int, - slow_step_size: float, - reset_state: bool = False -) -> base.GradientTransformation: - """Lookahead optimizer. - - Performs steps with a fast optimizer and periodically updates a set of slow - parameters. Optionally resets the fast optimizer state after synchronization - by calling the init function of the fast optimizer. - - Updates returned by the lookahead optimizer should not be modified before they - are applied, otherwise fast and slow parameters are not synchronized - correctly. - - References: - [Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf) - - Args: - fast_optimizer: The optimizer to use in the inner loop of lookahead. - sync_period: Number of fast optimizer steps to take before synchronizing - parameters. Must be >= 1. - slow_step_size: Step size of the slow parameter updates. - reset_state: Whether to reset the optimizer state of the fast opimizer after - each synchronization. - - Returns: - A `GradientTransformation` with init and update functions. The updates - passed to the update function should be calculated using the fast lookahead - parameters only. - """ - if sync_period < 1: - raise ValueError('Synchronization period must be >= 1.') - - def init_fn(params: base.Params) -> LookaheadState: - try: - fast_params = params.fast - except AttributeError: - # Allowing init_fn to be called with fast parameters reduces the - # modifications necessary to adapt code to use lookahead in some cases. - logging.warning( - '`params` has no attribute `fast`. 
Continuing by assuming that ' - 'only fast parameters were passed to lookahead init.') - fast_params = params - - return LookaheadState( - fast_state=fast_optimizer.init(fast_params), - steps_since_sync=jnp.zeros(shape=(), dtype=jnp.int32)) - - def update_fn( - updates: base.Updates, state: LookaheadState, - params: LookaheadParams) -> Tuple[LookaheadParams, LookaheadState]: - updates, fast_state = fast_optimizer.update(updates, state.fast_state, - params.fast) - - sync_next = (state.steps_since_sync == sync_period - 1) - updates = _lookahead_update(updates, sync_next, params, slow_step_size) - if reset_state: - # Jittable way of resetting the fast optimizer state if parameters will be - # synchronized after this update step. - initial_state = fast_optimizer.init(params.fast) - fast_state = jax.tree_util.tree_map( - lambda current, init: (1 - sync_next) * current + sync_next * init, - fast_state, initial_state) - - steps_since_sync = (state.steps_since_sync + 1) % sync_period - return updates, LookaheadState(fast_state, steps_since_sync) - - return base.GradientTransformation(init_fn, update_fn) - - -def _lookahead_update( - updates: base.Updates, sync_next: bool, params: LookaheadParams, - slow_step_size: float) -> LookaheadParams: - """Returns the updates corresponding to one lookahead step. - - References: - [Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf) - - Args: - updates: Updates returned by the fast optimizer. - sync_next: Wether fast and slow parameters should be synchronized after the - fast optimizer step. - params: Current fast and slow parameters as `LookaheadParams` object. - slow_step_size: Step size of the slow optimizer. - - Returns: - The updates for the lookahead parameters. - """ - # In the paper, lookahead is presented as two nested loops. To write lookahead - # as optax wrapper, these loops have to be broken into successive updates. - # This leads to two types of update steps: - # - # Non-synchronization steps (sync_next == False): - # The updates returned by the fast optimizer are used for the fast parameters - # without change and the slow parameter updates are zero (i.e. fast_updates = - # updates, slow_updates = 0). - # - # Synchronisation step (sync_next == True): - # This consists of two substeps: a last fast optimizer step and the - # synchronization. - # Substep 1 (last fast optimizer step): - # last_fast_params = fast_params + updates - # Substep 2 (synchronization): - # new_slow_params = slow_params + slow_step_size * ( - # last_fast_params - slow_params) - # new_fast_params = new_slow_params - # - # Merging into a single update step we get the update rules: - # slow_updates = slow_step_size * (fast_params + updates - slow_params) - # fast_updates = new_slow_params - fast_params = updates - (1 - - # slow_step_size) * (fast_params + updates - slow_params) - # - # To make the equations jittable, the two types of steps are merged. 
Defining - # last_difference = fast_params + updates - slow_params, this yields the - # following equtions which are implemented below: - # slow_updates = slow_step_size * sync_next * last_difference - # fast_updates = updates - ( - # 1 - slow_step_size) * sync_next * last_difference - last_difference = jax.tree_util.tree_map( - lambda f, u, s: f + u - s, params.fast, updates, params.slow) - slow_updates = jax.tree_util.tree_map( - lambda diff: slow_step_size * sync_next * diff, last_difference) - fast_updates = jax.tree_util.tree_map( - lambda up, diff: up - sync_next * (1 - slow_step_size) * diff, updates, - last_difference) - - return LookaheadParams(fast=fast_updates, slow=slow_updates) - diff --git a/optax_add_eve/_src/lookahead_test.py b/optax_add_eve/_src/lookahead_test.py deleted file mode 100644 index 99964a1d..00000000 --- a/optax_add_eve/_src/lookahead_test.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for `lookahead.py`.""" - -from typing import NamedTuple - -from absl.testing import absltest -from absl.testing import parameterized -import chex -import jax -import jax.numpy as jnp -import numpy as np -from optax_add_eve._src import alias -from optax_add_eve._src import base -from optax_add_eve._src import lookahead -from optax_add_eve._src import update - - -def _build_sgd(): - return alias.sgd(1.) - - -class TestOptimizerState(NamedTuple): - """Fast optimizer state for the lookahead tests.""" - aggregate_grads: base.Params - # Include a variable with non-zero initial value to check that it is reset - # correctly by the lookahead optimizer. - is_reset: bool = True - - -def _test_optimizer(step_size: float) -> base.GradientTransformation: - """Fast optimizer for the lookahead tests.""" - - # Use SGD for simplicity but add non-trivial optimizer state so that the - # resetting behaviour of lookahead can be tested. - def init_fn(params): - aggregate_grads = jax.tree_util.tree_map(jnp.zeros_like, params) - return TestOptimizerState(aggregate_grads, is_reset=True) - - def update_fn(updates, state, params): - # The test optimizer does not use the parameters, but we check that they - # have been passed correctly. 
- chex.assert_trees_all_equal_shapes(updates, params) - aggregate_grads = update.apply_updates(state.aggregate_grads, updates) - updates = jax.tree_util.tree_map(lambda u: step_size * u, updates) - return updates, TestOptimizerState(aggregate_grads, is_reset=False) - - return base.GradientTransformation(init_fn, update_fn) - - -class LookaheadTest(chex.TestCase): - """Tests for the lookahead optimizer.""" - - def setUp(self): - super().setUp() - self.grads = {'x': np.array(2.), 'y': np.array(-2.)} - self.initial_params = {'x': np.array(3.), 'y': np.array(-3.)} - self.synced_initial_params = lookahead.LookaheadParams.init_synced( - self.initial_params) - - def loop(self, optimizer, num_steps, params): - """Performs a given number of optimizer steps.""" - init_fn, update_fn = optimizer - # Use the chex variant to check various function versions (jit, pmap, etc). - step = self.variant(update_fn) - opt_state = self.variant(init_fn)(params) - for _ in range(num_steps): - updates, opt_state = step(self.grads, opt_state, params) - params = update.apply_updates(params, updates) - - return params, opt_state - - @chex.all_variants - def test_lookahead(self): - """Tests the lookahead optimizer in an analytically tractable setting.""" - sync_period = 3 - optimizer = lookahead.lookahead( - _test_optimizer(-0.5), sync_period=sync_period, slow_step_size=1 / 3) - - final_params, _ = self.loop(optimizer, 2 * sync_period, - self.synced_initial_params) - # x steps must be: 3 -> 2 -> 1 -> 2 (sync) -> 1 -> 0 -> 1 (sync). - # Similarly for y (with sign flipped). - correct_final_params = {'x': 1, 'y': -1} - chex.assert_trees_all_close(final_params.slow, correct_final_params) - - @chex.all_variants - @parameterized.parameters([False], [True]) - def test_lookahead_state_reset(self, reset_state): - """Checks that lookahead resets the fast optimizer state correctly.""" - num_steps = sync_period = 3 - fast_optimizer = _test_optimizer(-0.5) - optimizer = lookahead.lookahead( - fast_optimizer, - sync_period=sync_period, - slow_step_size=0.5, - reset_state=reset_state) - - _, opt_state = self.loop(optimizer, num_steps, self.synced_initial_params) - fast_state = opt_state.fast_state - if reset_state: - correct_state = fast_optimizer.init(self.initial_params) - else: - _, correct_state = self.loop(fast_optimizer, num_steps, - self.initial_params) - - chex.assert_trees_all_close(fast_state, correct_state) - - @chex.all_variants - @parameterized.parameters( - [1, 0.5, {'x': np.array(1.), 'y': np.array(-1.)}], - [1, 0, {'x': np.array(3.), 'y': np.array(-3.)}], - [1, 1, {'x': np.array(-1.), 'y': np.array(1.)}], - [2, 1, {'x': np.array(-1.), 'y': np.array(1.)}]) # pyformat: disable - def test_lookahead_edge_cases(self, sync_period, slow_step_size, - correct_result): - """Checks special cases of the lookahed optimizer parameters.""" - # These edge cases are important to check since users might use them as - # simple ways of disabling lookahead in experiments. - optimizer = lookahead.lookahead( - _test_optimizer(-1), sync_period, slow_step_size) - final_params, _ = self.loop( - optimizer, num_steps=2, params=self.synced_initial_params) - chex.assert_trees_all_close(final_params.slow, correct_result) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/loss.py b/optax_add_eve/_src/loss.py deleted file mode 100644 index 578741f1..00000000 --- a/optax_add_eve/_src/loss.py +++ /dev/null @@ -1,521 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Standard losses used in optimisation. - -We provide implementations of the most canonical losses used in deep -learning. These operate transparently on batches, and do not perform any -reduction over the batch dimensions, leaving it to the user to, for instance, -mean or sum losses across batch dimensions. -""" - -from typing import Optional, Tuple - -import chex -import jax -import jax.numpy as jnp - -from optax_add_eve._src import utils - - -def l2_loss( - predictions: chex.Array, - targets: Optional[chex.Array] = None, -) -> chex.Array: - """Calculates the L2 loss for a set of predictions. - - Note: the 0.5 term is standard in "Pattern Recognition and Machine Learning" - by Bishop, but not "The Elements of Statistical Learning" by Tibshirani. - - References: - [Chris Bishop, 2006](https://bit.ly/3eeP0ga) - - Args: - predictions: a vector of arbitrary shape `[...]`. - targets: a vector with shape broadcastable to that of `predictions`; - if not provided then it is assumed to be a vector of zeros. - - Returns: - elementwise squared differences, with same shape as `predictions`. - """ - chex.assert_type([predictions], float) - if targets is not None: - # Avoid broadcasting logic for "-" operator. - chex.assert_equal_shape((predictions, targets)) - errors = (predictions - targets) if (targets is not None) else predictions - return 0.5 * (errors)**2 - - -def huber_loss( - predictions: chex.Array, - targets: Optional[chex.Array] = None, - delta: float = 1.) -> chex.Array: - """Huber loss, similar to L2 loss close to zero, L1 loss away from zero. - - If gradient descent is applied to the `huber loss`, it is equivalent to - clipping gradients of an `l2_loss` to `[-delta, delta]` in the backward pass. - - References: - [Huber, 1964](www.projecteuclid.org/download/pdf_1/euclid.aoms/1177703732) - - Args: - predictions: a vector of arbitrary shape `[...]`. - targets: a vector with shape broadcastable to that of `predictions`; - if not provided then it is assumed to be a vector of zeros. - delta: the bounds for the huber loss transformation, defaults at 1. - - Returns: - elementwise huber losses, with the same shape of `predictions`. - """ - chex.assert_type([predictions], float) - errors = (predictions - targets) if (targets is not None) else predictions - # 0.5 * err^2 if |err| <= d - # 0.5 * d^2 + d * (|err| - d) if |err| > d - abs_errors = jnp.abs(errors) - quadratic = jnp.minimum(abs_errors, delta) - # Same as max(abs_x - delta, 0) but avoids potentially doubling gradient. - linear = abs_errors - quadratic - return 0.5 * quadratic ** 2 + delta * linear - - -def smooth_labels( - labels: chex.Array, - alpha: float, -) -> jnp.ndarray: - """Apply label smoothing. - - Label smoothing is often used in combination with a cross-entropy loss. 
- Smoothed labels favour small logit gaps, and it has been shown that this can - provide better model calibration by preventing overconfident predictions. - - References: - [Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf) - - Args: - labels: one hot labels to be smoothed. - alpha: the smoothing factor, the greedy category with be assigned - probability `(1-alpha) + alpha / num_categories` - - Returns: - a smoothed version of the one hot input labels. - - """ - chex.assert_type([labels], float) - num_categories = labels.shape[-1] - return (1.0 - alpha) * labels + alpha / num_categories - - -def sigmoid_binary_cross_entropy(logits, labels): - """Computes element-wise sigmoid cross entropy given logits and labels. - - This can be used to measure the error in discrete classification tasks in - which each class is an independent binary prediction and different classes - are not mutually exclusive. This may be used for multilabel image - classification for instance a model may predict that an image contains both a - cat and a dog. - - References: - [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) - - Args: - logits: Each element is the unnormalized log probability of a binary - prediction. - labels: The target probabilities, must have a shape broadcastable to that of - `logits`. - - Returns: - cross entropy for each binary prediction, same shape as `logits`. - """ - chex.assert_type([logits], float) - log_p = jax.nn.log_sigmoid(logits) - # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter more numerically stable - log_not_p = jax.nn.log_sigmoid(-logits) - return -labels * log_p - (1. - labels) * log_not_p - - -def softmax_cross_entropy( - logits: chex.Array, - labels: chex.Array, -) -> chex.Array: - """Computes the softmax cross entropy between sets of logits and labels. - - Measures the probability error in discrete classification tasks in which - the classes are mutually exclusive (each entry is in exactly one class). - For example, each CIFAR-10 image is labeled with one and only one label: - an image can be a dog or a truck, but not both. - - References: - [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) - - Args: - logits: Unnormalized log probabilities, with shape `[..., num_classes]`. - labels: Valid probability distributions (non-negative, sum to 1), e.g a - one hot encoding specifying the correct class for each input; - must have a shape broadcastable to `[..., num_classes]`` - - Returns: - cross entropy between each prediction and the corresponding target - distributions, with shape `[...]`. - """ - chex.assert_type([logits], float) - return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) - - -def softmax_cross_entropy_with_integer_labels( - logits: chex.Array, - labels: chex.Array, -) -> chex.Array: - """Computes softmax cross entropy between sets of logits and integer labels. - - Measures the probability error in discrete classification tasks in which - the classes are mutually exclusive (each entry is in exactly one class). - For example, each CIFAR-10 image is labeled with one and only one label: - an image can be a dog or a truck, but not both. - - References: - [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) - - Args: - logits: Unnormalized log probabilities, with shape `[..., num_classes]`. - labels: Integers specifying the correct class for each input, with shape - `[...]`. 
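A minimal sketch of the pairing mentioned above, label smoothing feeding a softmax cross-entropy (toy logits; assumes the top-level `optax` re-exports both functions):

import jax
import jax.numpy as jnp
import optax

logits = jnp.array([[2.0, 0.5, -1.0]])
labels = jax.nn.one_hot(jnp.array([0]), num_classes=3)

# alpha=0.1 moves 0.1 of the probability mass to the uniform distribution.
smoothed = optax.smooth_labels(labels, alpha=0.1)           # [[0.933, 0.033, 0.033]]
loss_value = optax.softmax_cross_entropy(logits, smoothed)  # shape [1]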
- - Returns: - Cross entropy between each prediction and the corresponding target - distributions, with shape `[...]`. - """ - chex.assert_type([logits], float) - chex.assert_type([labels], int) - # This is like jnp.take_along_axis(jax.nn.log_softmax(...), ...) except that - # we avoid subtracting the normalizer from all values, just from the values - # for the correct labels. - logits_max = jnp.max(logits, axis=-1, keepdims=True) - logits -= jax.lax.stop_gradient(logits_max) - label_logits = jnp.take_along_axis(logits, labels[..., None], axis=-1)[..., 0] - log_normalizers = jnp.log(jnp.sum(jnp.exp(logits), axis=-1)) - return log_normalizers - label_logits - - -def cosine_similarity( - predictions: chex.Array, - targets: chex.Array, - epsilon: float = 0., -) -> chex.Array: - r"""Computes the cosine similarity between targets and predictions. - - The cosine **similarity** is a measure of similarity between vectors defined - as the cosine of the angle between them, which is also the inner product of - those vectors normalized to have unit norm. - - References: - [Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity) - - Args: - predictions: The predicted vectors, with shape `[..., dim]`. - targets: Ground truth target vectors, with shape `[..., dim]`. - epsilon: minimum norm for terms in the denominator of the cosine similarity. - - Returns: - cosine similarity measures, with shape `[...]`. - """ - chex.assert_type([predictions, targets], float) - # vectorize norm fn, to treat all dimensions except the last as batch dims. - batched_norm_fn = jnp.vectorize( - utils.safe_norm, signature='(k)->()', excluded={1}) - # normalise the last dimension of targets and predictions. - unit_targets = targets / jnp.expand_dims( - batched_norm_fn(targets, epsilon), axis=-1) - unit_predictions = predictions / jnp.expand_dims( - batched_norm_fn(predictions, epsilon), axis=-1) - # return cosine similarity. - return jnp.sum(unit_targets * unit_predictions, axis=-1) - - -def cosine_distance( - predictions: chex.Array, - targets: chex.Array, - epsilon: float = 0., -) -> chex.Array: - r"""Computes the cosine distance between targets and predictions. - - The cosine **distance**, implemented here, measures the **dissimilarity** - of two vectors as the opposite of cosine **similarity**: `1 - cos(\theta)`. - - References: - [Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity) - - Args: - predictions: The predicted vectors, with shape `[..., dim]`. - targets: Ground truth target vectors, with shape `[..., dim]`. - epsilon: minimum norm for terms in the denominator of the cosine similarity. - - Returns: - cosine distances, with shape `[...]`. - """ - chex.assert_type([predictions, targets], float) - # cosine distance = 1 - cosine similarity. - return 1. - cosine_similarity(predictions, targets, epsilon) - - -def log_cosh( - predictions: chex.Array, - targets: Optional[chex.Array] = None, -) -> chex.Array: - """Calculates the log-cosh loss for a set of predictions. - - log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)` - for large x. It is a twice differentiable alternative to the Huber loss. - - References: - [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym) - - Args: - predictions: a vector of arbitrary shape `[...]`. - targets: a vector with shape broadcastable to that of `predictions`; - if not provided then it is assumed to be a vector of zeros. - - Returns: - the log-cosh loss, with same shape as `predictions`. 
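A quick numerical check of the two regimes described for `log_cosh`, assuming the top-level `optax` re-export:

import jax.numpy as jnp
import optax

small, large = jnp.array(0.01), jnp.array(20.0)
# Quadratic regime: log(cosh(x)) ~ x**2 / 2 for small x.
print(optax.log_cosh(small), 0.5 * small**2)        # both ~5.0e-05
# Linear regime: log(cosh(x)) ~ |x| - log(2) for large x.
print(optax.log_cosh(large), large - jnp.log(2.0))  # both ~19.307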
- """ - chex.assert_type([predictions], float) - errors = (predictions - targets) if (targets is not None) else predictions - # log(cosh(x)) = log((exp(x) + exp(-x))/2) = log(exp(x) + exp(-x)) - log(2) - return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype) - - -def ctc_loss_with_forward_probs( - logits: chex.Array, - logit_paddings: chex.Array, - labels: chex.Array, - label_paddings: chex.Array, - blank_id: int = 0, - log_epsilon: float = -1e5) -> Tuple[chex.Array, chex.Array, chex.Array]: - r"""Computes CTC loss and CTC forward-probabilities. - - The CTC loss is a loss function based on log-likelihoods of the model that - introduces a special blank symbol :math:`\phi` to represent variable-length - output sequences. - - Forward probabilities returned by this function, as auxiliary results, are - grouped into two part: blank alpha-probability and non-blank alpha - probability. Those are defined as follows: - - .. math:: - \alpha_{\mathrm{BLANK}}(t, n) = - \sum_{\pi_{1:t-1}} p(\pi_t = \phi | \pi_{1:t-1}, y_{1:n-1}, \cdots), \\ - \alpha_{\mathrm{LABEL}}(t, n) = - \sum_{\pi_{1:t-1}} p(\pi_t = y_n | \pi_{1:t-1}, y_{1:n-1}, \cdots). - - Here, :math:`\pi` denotes the alignment sequence in the reference - [Graves et al, 2006] that is blank-inserted representations of ``labels``. - The return values are the logarithms of the above probabilities. - - References: - [Graves et al, 2006](https://dl.acm.org/doi/abs/10.1145/1143844.1143891) - - Args: - logits: (B, T, K)-array containing logits of each class where B denotes - the batch size, T denotes the max time frames in ``logits``, and K - denotes the number of classes including a class for blanks. - logit_paddings: (B, T)-array. Padding indicators for ``logits``. Each - element must be either 1.0 or 0.0, and ``logitpaddings[b, t] == 1.0`` - denotes that ``logits[b, t, :]`` are padded values. - labels: (B, N)-array containing reference integer labels where N denotes - the max time frames in the label sequence. - label_paddings: (B, N)-array. Padding indicators for ``labels``. Each - element must be either 1.0 or 0.0, and ``labelpaddings[b, n] == 1.0`` - denotes that ``labels[b, n]`` is a padded label. In the current - implementation, ``labels`` must be right-padded, i.e. each row - ``labelpaddings[b, :]`` must be repetition of zeroes, followed by - repetition of ones. - blank_id: Id for blank token. ``logits[b, :, blank_id]`` are used as - probabilities of blank symbols. - log_epsilon: Numerically-stable approximation of log(+0). - - Returns: - A tuple ``(loss_value, logalpha_blank, logalpha_nonblank)``. Here, - ``loss_value`` is a (B,)-array containing the loss values for each sequence - in the batch, ``logalpha_blank`` and ``logalpha_nonblank`` are - (T, B, N+1)-arrays where the (t, b, n)-th element denotes - \log \alpha_B(t, n) and \log \alpha_L(t, n), respectively, for ``b``-th - sequence in the batch. - """ - - chex.assert_rank(logits, 3) - chex.assert_rank(labels, 2) - batchsize, unused_maxinputlen, num_classes = logits.shape - batchsize_of_labels, maxlabellen = labels.shape - chex.assert_equal(batchsize, batchsize_of_labels) - chex.assert_equal(labels.shape, label_paddings.shape) - chex.assert_equal(logits.shape[:2], logit_paddings.shape) - - logprobs = jax.nn.log_softmax(logits) - labellens = maxlabellen - jnp.sum(label_paddings, axis=1).astype(jnp.int32) - - # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. 
- repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32) - repeat = jnp.pad(repeat, ((0, 0), (0, 1))) - - logprobs_phi = logprobs[:, :, blank_id:blank_id + 1] # [B, T, 1] - logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] - - one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K] - logprobs_emit = jnp.einsum('btk,bnk->btn', logprobs, one_hot) - logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] - - logalpha_phi_init = jnp.ones( - (batchsize, maxlabellen + 1)) * log_epsilon # [B, N] - logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) - logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon - - def update_phi_score(phi, added_score): - # Update `phi[:, 1:]`` with adding `added_score` in log space. - return jnp.concatenate( - [phi[:, :1], jnp.logaddexp(phi[:, 1:], added_score)], axis=-1) - - def loop_body(prev, x): - prev_phi, prev_emit = prev - # emit-to-phi epsilon transition, except if the next label is repetition - prev_phi_orig = prev_phi - prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat) - - logprob_emit, logprob_phi, pad = x - - # phi-to-emit transition - next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, - prev_emit + logprob_emit) - # self-loop transition - next_phi = prev_phi + logprob_phi - # emit-to-phi blank transition only when the next label is repetition - next_phi = update_phi_score( - next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)) - - pad = pad.reshape((batchsize, 1)) - next_emit = pad * prev_emit + (1.0 - pad) * next_emit - next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi - - return (next_phi, next_emit), (next_phi, next_emit) - - xs = (logprobs_emit, logprobs_phi, logit_paddings.transpose((1, 0))) - _, (logalpha_phi, - logalpha_emit) = jax.lax.scan(loop_body, - (logalpha_phi_init, logalpha_emit_init), xs) - - # last row needs to be updated with the last epsilon transition - logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1]) - logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) - - # extract per_seq_loss - one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1] - per_seq_loss = -jnp.einsum('bn,bn->b', logalpha_phi_last, one_hot) - - return per_seq_loss, logalpha_phi, logalpha_emit - - -def ctc_loss(logits: chex.Array, - logit_paddings: chex.Array, - labels: chex.Array, - label_paddings: chex.Array, - blank_id: int = 0, - log_epsilon: float = -1e5) -> chex.Array: - """Computes CTC loss. - - See docstring for ``ctc_loss_with_forward_probs`` for details. - - Args: - logits: (B, T, K)-array containing logits of each class where B denotes - the batch size, T denotes the max time frames in ``logits``, and K - denotes the number of classes including a class for blanks. - logit_paddings: (B, T)-array. Padding indicators for ``logits``. Each - element must be either 1.0 or 0.0, and ``logitpaddings[b, t] == 1.0`` - denotes that ``logits[b, t, :]`` are padded values. - labels: (B, N)-array containing reference integer labels where N denotes - the max time frames in the label sequence. - label_paddings: (B, N)-array. Padding indicators for ``labels``. Each - element must be either 1.0 or 0.0, and ``labelpaddings[b, n] == 1.0`` - denotes that ``labels[b, n]`` is a padded label. In the current - implementation, ``labels`` must be right-padded, i.e. each row - ``labelpaddings[b, :]`` must be repetition of zeroes, followed by - repetition of ones. - blank_id: Id for blank token. 
``logits[b, :, blank_id]`` are used as - probabilities of blank symbols. - log_epsilon: Numerically-stable approximation of log(+0). - - Returns: - (B,)-array containing loss values for each sequence in the batch. - """ - per_seq_loss, _, _ = ctc_loss_with_forward_probs( - logits, logit_paddings, labels, label_paddings, - blank_id=blank_id, log_epsilon=log_epsilon) - return per_seq_loss - - -def kl_divergence(log_predictions: chex.Array, - targets: chex.Array) -> chex.Array: - """Computes the Kullback-Leibler divergence (relative entropy) loss. - - Measures the information gain achieved if target probability distribution - would be used instead of predicted probability distribution. - - References: - [Kullback, Leibler, 1951](https://www.jstor.org/stable/2236703) - - Args: - log_predictions: Probabilities of predicted distribution with shape - [..., dim]. Expected to be in the log-space to avoid underflow. - targets: Probabilities of target distribution with shape [..., dim]. - Expected to be strictly positive. - - Returns: - Kullback-Leibler divergence of predicted distribution from target - distribution with shape [...]. - """ - chex.assert_type([log_predictions, targets], float) - loss = targets * (jnp.log(targets) - log_predictions) - return jnp.sum(loss, axis=-1) - - -def kl_divergence_with_log_targets(log_predictions: chex.Array, - log_targets: chex.Array) -> chex.Array: - """Computes the Kullback-Leibler divergence (relative entropy) loss. - - Version of kl_div_loss where targets are given in log-space. - - Args: - log_predictions: Probabilities of predicted distribution with shape - [..., dim]. Expected to be in the log-space to avoid underflow. - log_targets: Probabilities of target distribution with shape [..., dim]. - Expected to be in the log-space. - - Returns: - Kullback-Leibler divergence of predicted distribution from target - distribution with shape [...]. - """ - chex.assert_type([log_predictions, log_targets], float) - loss = jnp.exp(log_targets) * (log_targets - log_predictions) - return jnp.sum(loss, axis=-1) - - -def hinge_loss(predictor_outputs: chex.Array, - targets: chex.Array) -> chex.Array: - """Computes the hinge loss for binary classification. - - Args: - predictor_outputs: Outputs of the decision function. - targets: Target values. Target values should be strictly in the set {-1, 1}. - - Returns: - Binary Hinge Loss. - """ - return jnp.maximum(0, 1 - predictor_outputs * targets) diff --git a/optax_add_eve/_src/loss_test.py b/optax_add_eve/_src/loss_test.py deleted file mode 100644 index dd183177..00000000 --- a/optax_add_eve/_src/loss_test.py +++ /dev/null @@ -1,500 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
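A minimal sketch relating the two KL variants above (toy distributions; assumes both functions are re-exported at the top level of the package):

import jax
import jax.numpy as jnp
import optax

targets = jnp.array([[0.2, 0.5, 0.3]])
log_preds = jax.nn.log_softmax(jnp.array([[1.0, 2.0, 0.5]]))

# The two forms agree when the same targets are passed in log-space.
print(optax.kl_divergence(log_preds, targets))
print(optax.kl_divergence_with_log_targets(log_preds, jnp.log(targets)))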
-# ============================================================================== -"""Tests for optax._src.loss.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import chex -import jax -import jax.numpy as jnp -import numpy as np - -from optax_add_eve._src import loss - - -class L2LossTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.ys = jnp.array([-2., -1., 0.5, 1.]) - self.ts = jnp.array([-1.5, 0., -1, 1.]) - # compute expected outputs in numpy. - self.exp = 0.5 * (self.ts - self.ys) ** 2 - - @chex.all_variants - def test_scalar(self): - np.testing.assert_allclose( - self.variant(loss.l2_loss)(self.ys[0], self.ts[0]), self.exp[0]) - - @chex.all_variants - def test_batched(self): - np.testing.assert_allclose( - self.variant(loss.l2_loss)(self.ys, self.ts), self.exp) - - @chex.all_variants - def test_shape_mismatch(self): - with self.assertRaises(AssertionError): - _ = self.variant(loss.l2_loss)(self.ys, jnp.expand_dims(self.ts, axis=-1)) - - -class HuberLossTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.ys = np.array([-2.0, 0.5, 0., 0.5, 2.0, 4.0, 132.]) - self.ts = np.array([0.0, -0.5, 0., 1., 1.0, 2.0, 0.3]) - # computed expected outputs manually. - self.exp = np.array([1.5, 0.5, 0., 0.125, 0.5, 1.5, 131.2]) - - @chex.all_variants - def test_scalar(self): - np.testing.assert_allclose( - self.variant(loss.huber_loss)(self.ys[0], self.ts[0], delta=1.0), - self.exp[0]) - - @chex.all_variants - def test_batched(self): - np.testing.assert_allclose( - self.variant(loss.huber_loss)(self.ys, self.ts, delta=1.0), - self.exp) - - -class SmoothLabelsTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.ts = np.array([[0., 1., 0.], [1., 0., 0.]], dtype=np.float32) - # compute expected outputs in numpy. - self.exp_alpha_zero = self.ts - self.exp_alpha_zero_point_one = 0.9 * self.ts + 0.1 / self.ts.shape[-1] - self.exp_alpha_one = jnp.ones_like(self.ts) / self.ts.shape[-1] - - @chex.all_variants - def test_scalar(self): - """Tests for a full batch.""" - np.testing.assert_allclose( - self.variant(loss.smooth_labels)(self.ts[0], 0.), - self.exp_alpha_zero[0], atol=1e-4) - np.testing.assert_allclose( - self.variant(loss.smooth_labels)(self.ts[0], 0.1), - self.exp_alpha_zero_point_one[0], atol=1e-4) - np.testing.assert_allclose( - self.variant(loss.smooth_labels)(self.ts[0], 1.), - self.exp_alpha_one[0], atol=1e-4) - - @chex.all_variants - def test_batched(self): - """Tests for a full batch.""" - np.testing.assert_allclose( - self.variant(loss.smooth_labels)(self.ts, 0.), - self.exp_alpha_zero, atol=1e-4) - np.testing.assert_allclose( - self.variant(loss.smooth_labels)(self.ts, 0.1), - self.exp_alpha_zero_point_one, atol=1e-4) - np.testing.assert_allclose( - self.variant(loss.smooth_labels)(self.ts, 1.), - self.exp_alpha_one, atol=1e-4) - - -class SoftmaxCrossEntropyTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.ys = np.array([[10., 1., -2.], [1., 4., 0.2]], dtype=np.float32) - self.ts = np.array([[0., 1., 0.], [1., 0., 0.]], dtype=np.float32) - # taken expected outputs from rlax. 
- self.exp = np.array([9.00013, 3.0696733], dtype=np.float32) - - @chex.all_variants - def test_scalar(self): - """Tests for a full batch.""" - np.testing.assert_allclose( - self.variant(loss.softmax_cross_entropy)(self.ys[0], self.ts[0]), - self.exp[0], atol=1e-4) - - @chex.all_variants - def test_batched(self): - """Tests for a full batch.""" - np.testing.assert_allclose( - self.variant(loss.softmax_cross_entropy)(self.ys, self.ts), - self.exp, atol=1e-4) - - -class SoftmaxCrossEntropyWithIntegerLabelsTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.ys = np.array([[10., 1., -2.], [1., 4., 0.2]], dtype=np.float32) - self.ts = np.array([1, 0], dtype=np.int32) - - @chex.all_variants - def test_consistent_with_softmax_cross_entropy_scalar(self): - """Tests for a scalar.""" - exp = loss.softmax_cross_entropy(self.ys[0], jax.nn.one_hot(self.ts[0], 3)) - np.testing.assert_allclose( - self.variant(loss.softmax_cross_entropy_with_integer_labels)( - self.ys[0], self.ts[0]), - exp, rtol=1e-6) - - @chex.all_variants - def test_consistent_with_softmax_cross_entropy_batched(self): - """Tests for a full batch.""" - exp = loss.softmax_cross_entropy(self.ys, jax.nn.one_hot(self.ts, 3)) - np.testing.assert_allclose( - self.variant(loss.softmax_cross_entropy_with_integer_labels)( - self.ys, self.ts), - exp, rtol=1e-6) - - -class SigmoidCrossEntropyTest(parameterized.TestCase): - - @parameterized.parameters( - dict(preds=np.array([-1e+09, -1e-09]), - labels=np.array([1., 0.]), - expected=5e+08), - dict(preds=np.array([-1e+09, -1e-09]), - labels=np.array([0., 1.]), - expected=0.3465736), - dict(preds=np.array([1e+09, 1e-09]), - labels=np.array([1., 0.]), - expected=0.3465736), - dict(preds=np.array([1e+09, 1e-09]), - labels=np.array([0., 1.]), - expected=5e+08), - dict(preds=np.array([-1e+09, 1e-09]), - labels=np.array([1., 0.]), - expected=5e+08), - dict(preds=np.array([-1e+09, 1e-09]), - labels=np.array([0., 1.]), - expected=0.3465736), - dict(preds=np.array([1e+09, -1e-09]), - labels=np.array([1., 0.]), - expected=0.3465736), - dict(preds=np.array([1e+09, -1e-09]), - labels=np.array([0., 1.]), - expected=5e+08), - dict(preds=np.array([0., 0.]), - labels=np.array([1., 0.]), - expected=0.6931472), - dict(preds=np.array([0., 0.]), - labels=np.array([0., 1.]), - expected=0.6931472), - ) - def testSigmoidCrossEntropy(self, preds, labels, expected): - tested = jnp.mean(loss.sigmoid_binary_cross_entropy(preds, labels)) - np.testing.assert_allclose(tested, expected, rtol=1e-6, atol=1e-6) - - -class CosineDistanceTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.ys = np.array([[10., 1., -2.], [1., 4., 0.2]], dtype=np.float32) - self.ts = np.array([[0., 1.2, 0.2], [1., -0.3, 0.]], dtype=np.float32) - # distance computed expected output from `scipy 1.20`. - self.exp = np.array([0.9358251989, 1.0464068465], dtype=np.float32) - - @chex.all_variants - def test_scalar_distance(self): - """Tests for a full batch.""" - np.testing.assert_allclose( - self.variant(loss.cosine_distance)(self.ys[0], self.ts[0]), - self.exp[0], atol=1e-4) - - @chex.all_variants - def test_scalar_similarity(self): - """Tests for a full batch.""" - np.testing.assert_allclose( - self.variant(loss.cosine_similarity)(self.ys[0], self.ts[0]), - 1. 
- self.exp[0], atol=1e-4) - - @chex.all_variants - def test_batched_distance(self): - """Tests for a full batch.""" - np.testing.assert_allclose( - self.variant(loss.cosine_distance)(self.ys, self.ts), - self.exp, atol=1e-4) - - @chex.all_variants - def test_batched_similarity(self): - """Tests for a full batch.""" - np.testing.assert_allclose( - self.variant(loss.cosine_similarity)(self.ys, self.ts), - 1. - self.exp, atol=1e-4) - - -# TODO(b/188419459): add test for grad and second order grad. -class LogCoshTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - # Test large values for overflow - self.ys = jnp.array([500, -2., -1., 0.5, 1.]) - self.ts = jnp.array([-200, -1.5, 0., -1, 1.]) - # computed using tensorflow.keras.losses.log_cosh v2.4.1 - self.exp = jnp.array([699.3068, 0.12011445, 0.4337809, 0.85544014, 0.]) - self.exp_ys_only = jnp.array( - [499.30685, 1.3250027, 0.4337809, 0.12011451, 0.43378082]) - - @chex.all_variants - def test_scalar(self): - out = self.variant(loss.log_cosh)(self.ys[0], self.ts[0]) - np.testing.assert_allclose(out, self.exp[0], atol=1e-5) - - @chex.all_variants - def test_batched(self): - out = self.variant(loss.log_cosh)(self.ys, self.ts) - np.testing.assert_allclose(out, self.exp, atol=1e-5) - - @chex.all_variants - def test_scalar_predictions_only(self): - out = self.variant(loss.log_cosh)(self.ys[0]) - np.testing.assert_allclose(out, self.exp_ys_only[0], atol=1e-5) - - @chex.all_variants - def test_batched_predictions_only(self): - out = self.variant(loss.log_cosh)(self.ys) - np.testing.assert_allclose(out, self.exp_ys_only, atol=1e-5) - - -def _lengths_to_paddings(lengths: chex.Array, maxlength: int) -> chex.Array: - indices = jnp.arange(maxlength).reshape((1,) * lengths.ndim + (maxlength,)) - lengths = jnp.expand_dims(lengths, axis=-1) - elem_valid = indices < lengths - return np.logical_not(elem_valid).astype(np.float32) - - -def _average_ctc_loss(logprobs: chex.Array, logprob_paddings: chex.Array, - labels: chex.Array, - label_paddings: chex.Array) -> chex.Array: - return jnp.average( - loss.ctc_loss(logprobs, logprob_paddings, labels, label_paddings)) - - -class CTCTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - np.random.seed(1234) - self._rtol = 5e-3 if jax.default_backend() != 'cpu' else 1e-6 - - @chex.all_variants - def test_with_one_to_one_alignment(self): - # when inputsteps and outputsteps are equal, no blank will be allowed. - batchsize = 8 - steps = 50 - nclasses = 40 - logits = np.random.randn(batchsize, steps, nclasses) - labels = np.random.uniform( - 1, nclasses, size=(batchsize, steps)).astype(np.int32) - - # This function only covers the cases without same-label repetition. - # `test_repeat_with_one_to_one_alignment` below complements those cases. - # So, redraw the samples for satisfying the non-repetition constraint. - for n in range(labels.shape[0]): - for t in range(1, labels.shape[1]): - while labels[n, t] == labels[n, t - 1]: - labels[n, t] = np.random.uniform(1, nclasses) - - results = self.variant(loss.ctc_loss_with_forward_probs)( - logits, np.zeros(logits.shape[:2]), - labels, np.zeros(labels.shape)) - (per_seq_loss, logalpha_blank, logalpha_emit) = results - - logprobs = jax.nn.log_softmax(logits) - for b in range(batchsize): - p = 0.0 - for t in range(steps): - p += logprobs[b, t, labels[b, t]] - np.testing.assert_allclose( - np.array(-p), per_seq_loss[b], rtol=self._rtol) - - # Check forward-probabilities. - # 1. 
All-phi path: logalpha_blank[-1, b, 0] must be a probability of - # the path that outputs blank symbols for all the frames. - np.testing.assert_allclose(logalpha_blank[-1, b, 0], - np.sum(logprobs[b, :, 0]), - rtol=self._rtol) - - # 2. After emitting all the labels - # the negated loss must be identical with the forward probability of - # paths after consuming all the labels (because one-to-one alignment - # doesn't allow extra blank symbols) - np.testing.assert_allclose(logalpha_emit[-1, b, steps - 1], - -per_seq_loss[b], - rtol=self._rtol) - # and, this forward probability must be copied to the blank forward - # probability of the next step. - np.testing.assert_allclose(logalpha_blank[-1, b, steps], - -per_seq_loss[b], - rtol=self._rtol) - - @chex.all_variants - def test_with_one_to_one_alignment_and_paddings(self): - batch_size = 5 - nclasses = 13 - steps = 7 - logits = np.random.normal(size=[batch_size, steps, nclasses]) - logprobs = jax.nn.log_softmax(logits) - - labels = [] - for n in range(batch_size): - row = list(range(1, nclasses)) - np.random.shuffle(row) - labels.append(row[:steps]) - labels = np.array(labels) - - lengths = np.random.randint(3, 6, size=(batch_size,)) - paddings = _lengths_to_paddings(lengths, steps) - - actual_loss = self.variant(loss.ctc_loss)(logits, paddings, labels, - paddings) - - value_and_grad = self.variant(jax.value_and_grad(_average_ctc_loss)) - unused_avg_loss, actual_gradients = value_and_grad(logits, paddings, labels, - paddings) - - for n in range(batch_size): - expected_loss = -sum(logprobs[n, t, k] - for t, k in enumerate(labels[n, :lengths[n]])) - np.testing.assert_allclose(expected_loss, actual_loss[n], rtol=self._rtol) - - expected_gradients = np.array(jax.nn.softmax(logits[n])) - expected_gradients[lengths[n]:] = 0.0 - for t, k in enumerate(labels[n, :lengths[n]]): - expected_gradients[t, k] -= 1.0 - expected_gradients /= batch_size - np.testing.assert_allclose( - expected_gradients, actual_gradients[n], rtol=self._rtol) - - @chex.all_variants - def test_repeat_with_one_to_one_alignment(self): - # test if it can correctly handle the same-label repetition. - nclasses = 5 - labels = np.array([ - [1, 2, 2, 3], - [2, 3, 4, 4], - [1, 1, 1, 1], - [1, 1, 2, 3], - [1, 1, 1, 2], - ]) - expected_alignment = [ # expected minimal alignment - [1, 2, 0, 2, 3], - [2, 3, 4, 0, 4], - [1, 0, 1, 0, 1, 0, 1], - [1, 0, 1, 2, 3], - [1, 0, 1, 0, 1, 2], - ] - batch_size = len(labels) - label_lens = np.array([4] * batch_size) - label_steps = 6 - # Designed to have two padding elements on the right. 
- labels = np.pad(labels, [(0, 0), (0, label_steps - labels.shape[1])]) - label_paddings = _lengths_to_paddings(label_lens, label_steps) - - logit_lengths = np.array([len(seq) for seq in expected_alignment]) - logit_steps = max(logit_lengths) - logits = np.random.randn(batch_size, logit_steps, nclasses) - logit_paddings = _lengths_to_paddings(logit_lengths, logit_steps) - - per_seq_loss = self.variant(loss.ctc_loss)(logits, logit_paddings, labels, - label_paddings) - - logprobs = jax.nn.log_softmax(logits) - for n in range(batch_size): - expected_loss = -sum(logprobs[n, t, k] - for t, k in enumerate(expected_alignment[n])) - np.testing.assert_allclose( - jnp.array(expected_loss), per_seq_loss[n], rtol=self._rtol) - - -class KLDivergenceTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.log_ps = np.array( - [[-2.9957, -3.5066, -3.9120, -1.2040, -0.6931, -2.3026], - [-1.6094, -1.6094, -1.6094, -2.3026, -1.8971, -1.8971]]) - self.qs = np.array([[0.2, 0.2, 0.2, 0.1, 0.15, 0.15], - [0.05, 0.03, 0.02, 0.3, 0.5, 0.1]]) - # Computed kullback-leibler divergence of P from Q. - self.exp = np.array([0.8875625, 0.7187435584901326]) - - @chex.all_variants - def test_scalar(self): - np.testing.assert_allclose( - self.variant(loss.kl_divergence)(self.log_ps[0], self.qs[0]), - self.exp[0], - atol=1e-4) - - @chex.all_variants - def test_batched(self): - np.testing.assert_allclose( - self.variant(loss.kl_divergence)(self.log_ps, self.qs), - self.exp, - atol=1e-4) - - -class KLDivergenceWithLogTargetsTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.log_ps = np.array( - [[-2.9957, -3.5066, -3.9120, -1.2040, -0.6931, -2.3026], - [-1.6094, -1.6094, -1.6094, -2.3026, -1.8971, -1.8971]]) - self.qs = np.array([[-1.6094, -1.6094, -1.6094, -2.3026, -1.8971, -1.8971], - [-2.9957, -3.5066, -3.9120, -1.2040, -0.6931, -2.3026]]) - # Computed kullback-leibler divergence of P from Q. - self.exp = np.array([0.8875625, 0.7187435584901326]) - - @chex.all_variants - def test_scalar(self): - np.testing.assert_allclose( - self.variant(loss.kl_divergence_with_log_targets)(self.log_ps[0], - self.qs[0]), - self.exp[0], - atol=1e-4) - - @chex.all_variants - def test_batched(self): - np.testing.assert_allclose( - self.variant(loss.kl_divergence_with_log_targets)(self.log_ps, self.qs), - self.exp, - atol=1e-4) - - -class HingeLossTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.ys = np.array([ - -0.97740268, -1.01812625, -0.81675726, -0.73605974, 2.08235648, - 1.84101354, -1.0581002 - ]) - self.ts = np.array([-1, -1, -1, -1, 1, 1, -1]) - # Computed expected outputs. - self.correct_result = np.array( - [0.02259731, 0., 0.18324274, 0.26394027, 0., 0., 0.]) - - @chex.all_variants - def test_batched(self): - np.testing.assert_allclose( - self.variant(loss.hinge_loss)(self.ys, self.ts), - self.correct_result, - atol=1e-4) - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/numerics_test.py b/optax_add_eve/_src/numerics_test.py deleted file mode 100644 index 89c7a706..00000000 --- a/optax_add_eve/_src/numerics_test.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for optax._src.numerics.""" - -import functools -import itertools -import re - -from absl.testing import absltest -from absl.testing import parameterized - -import chex -import jax -import jax.numpy as jnp -import numpy as np - -from optax_add_eve._src import numerics - -_ALL_ORDS = [None, np.inf, -np.inf, 'fro', 'nuc', 0, 1, 2, -2, -2, -1.5, 1.5] - -int32_array = lambda i: jnp.array(i, dtype=jnp.int32) -float32_array = lambda i: jnp.array(i, dtype=jnp.float32) - - -def _invalid_ord_axis_inputs(ord_axis_keepdims): - ord_, axis = ord_axis_keepdims[0], ord_axis_keepdims[1] - return any(((ord_ == 0 and axis is None), - (isinstance(ord_, float) and axis is None), - (isinstance(ord_, str) and axis is not None))) - - -class NumericsTest(chex.TestCase): - - @chex.all_variants - def test_safe_int32_increments(self): - inc_fn = self.variant(numerics.safe_int32_increment) - # increment small numbers correctly. - base = int32_array(3) - incremented = inc_fn(base) - np.testing.assert_array_equal(incremented, int32_array(4)) - # avoid overflow when incrementing maxint. - base = int32_array(np.iinfo(np.int32).max) - incremented = inc_fn(base) - np.testing.assert_array_equal(incremented, base) - - @chex.all_variants - @parameterized.parameters( - itertools.filterfalse( - _invalid_ord_axis_inputs, - itertools.product(_ALL_ORDS, [None, 0, 1], [False, True]))) - def test_safe_norm(self, ord, axis, keepdims): # pylint: disable=redefined-builtin - dnorm_dx = self.variant( - jax.jacfwd( - functools.partial( - numerics.safe_norm, ord=ord, axis=axis, keepdims=keepdims), - argnums=0)) - # Test gradient is 0. in 0. when zero min norm is used. - g = dnorm_dx(float32_array(jnp.zeros((3, 4))), float32_array(0.)) - np.testing.assert_array_equal(g, jnp.zeros_like(g)) - # Test gradient is 0. in 0. when non zero min norm is used. - g = dnorm_dx(float32_array(jnp.zeros((3, 4))), float32_array(3.)) - np.testing.assert_array_equal(g, jnp.zeros_like(g)) - - @chex.all_variants - def test_safe_rms(self): - drms_dx = self.variant(jax.grad(numerics.safe_root_mean_squares)) - # Test gradient is 0. in 0. when zero min rms is used. - g = drms_dx(float32_array(0.), float32_array(0.)) - np.testing.assert_array_equal(g, jnp.zeros_like(g)) - # Test gradient is 0. in 0. when non zero min rms is used. - g = drms_dx(float32_array(0.), float32_array(3.)) - np.testing.assert_array_equal(g, jnp.zeros_like(g)) - - def test_complex_vs_real_abs_sqr(self): - # Tests that JAX generates the same HLO from `numerics.abs_sq`, - # `jnp.square(x)`, `x * x`, and `x**2`. - real_sq_fns = (lambda x: x**2, lambda x: x * x, jnp.square) - - def _get_hlo_repr(f, x): - hlo_string = jax.xla_computation(f)(x).as_hlo_text() - return re.sub('HloModule.*?\n', '', - re.sub('ENTRY.*?{', 'ENTRY XXXX', hlo_string)) - - # Real arg (same HLO). - for real_sq_fn in real_sq_fns: - for real_x in (3, 3.0, np.array([4, 5.2])): - self.assertEqual( - _get_hlo_repr(real_sq_fn, real_x), - _get_hlo_repr(numerics.abs_sq, real_x)) - - # Complex arg (different HLOs). 
- for real_sq_fn in real_sq_fns: - for complex_x in (1j, 3. + 1j, np.array([4 + 1j, 5.2 + 1j])): - self.assertNotEqual( - _get_hlo_repr(real_sq_fn, complex_x), - _get_hlo_repr(numerics.abs_sq, complex_x)) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/privacy.py b/optax_add_eve/_src/privacy.py deleted file mode 100644 index 78c58210..00000000 --- a/optax_add_eve/_src/privacy.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Differential Privacy utilities.""" - -from typing import NamedTuple - -import jax -import jax.numpy as jnp - -from optax_add_eve._src import base -from optax_add_eve._src import clipping - - -# pylint:disable=no-value-for-parameter -class DifferentiallyPrivateAggregateState(NamedTuple): - """State containing PRNGKey for `differentially_private_aggregate`.""" - rng_key: jnp.array - - -def differentially_private_aggregate( - l2_norm_clip: float, - noise_multiplier: float, - seed: int -) -> base.GradientTransformation: - """Aggregates gradients based on the DPSGD algorithm. - - WARNING: Unlike other transforms, `differentially_private_aggregate` expects - the input updates to have a batch dimension in the 0th axis. That is, this - function expects per-example gradients as input (which are easy to obtain in - JAX using `jax.vmap`). It can still be composed with other transformations as - long as it is the first in the chain. - - References: - [Abadi et al, 2016](https://arxiv.org/abs/1607.00133) - - Args: - l2_norm_clip: maximum L2 norm of the per-example gradients. - noise_multiplier: ratio of standard deviation to the clipping norm. - seed: initial seed used for the jax.random.PRNGKey - - Returns: - A `GradientTransformation`. - """ - noise_std = l2_norm_clip * noise_multiplier - - def init_fn(params): - del params - return DifferentiallyPrivateAggregateState(rng_key=jax.random.PRNGKey(seed)) - - def update_fn(updates, state, params=None): - del params - grads_flat, grads_treedef = jax.tree_util.tree_flatten(updates) - bsize = grads_flat[0].shape[0] - clipped, _ = clipping.per_example_global_norm_clip(grads_flat, l2_norm_clip) - - new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat) + 1) - noised = [(g + noise_std * jax.random.normal(r, g.shape, g.dtype)) / bsize - for g, r in zip(clipped, rngs)] - return (jax.tree_util.tree_unflatten(grads_treedef, noised), - DifferentiallyPrivateAggregateState(rng_key=new_key)) - - return base.GradientTransformation(init_fn, update_fn) diff --git a/optax_add_eve/_src/privacy_test.py b/optax_add_eve/_src/privacy_test.py deleted file mode 100644 index 82455063..00000000 --- a/optax_add_eve/_src/privacy_test.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 
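A minimal sketch of how `differentially_private_aggregate` is meant to be driven with per-example gradients, as its docstring describes (the loss and data below are hypothetical; assumes the usual top-level re-exports):

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x, y):          # toy least-squares loss
  return jnp.mean((x @ params - y) ** 2)

params = jnp.zeros(3)
xs, ys = jnp.ones((8, 3)), jnp.ones(8)

# Per-example gradients via vmap over the leading batch axis.
per_eg_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, xs, ys)

# The DP aggregation must be the first transform in the chain.
tx = optax.chain(
    optax.differentially_private_aggregate(
        l2_norm_clip=1.0, noise_multiplier=1.1, seed=0),
    optax.sgd(learning_rate=0.1),
)
state = tx.init(params)
updates, state = tx.update(per_eg_grads, state, params)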
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for `privacy.py`.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import chex -import jax -import jax.numpy as jnp - -from optax_add_eve._src import privacy - - -class DifferentiallyPrivateAggregateTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.batch_size = 8 - self.params = {'key_a': (jnp.zeros((2, 3, 4)), jnp.zeros([])), - 'key_b': jnp.zeros((6, 7))} - # Example `i`'s grads are full of `i`s. Important to include 0 to ensure - # there are no divisons by 0 (e.g. in norm clipping) - a = jnp.arange(self.batch_size) - self.per_eg_grads = jax.tree_util.tree_map( - lambda p: jnp.moveaxis(a * jnp.ones(p.shape+(self.batch_size,)), -1, 0), - self.params) - - @chex.all_variants - def test_no_privacy(self): - """l2_norm_clip=MAX_FLOAT32 and noise_multiplier=0 should recover SGD.""" - dp_agg = privacy.differentially_private_aggregate( - l2_norm_clip=jnp.finfo(jnp.float32).max, - noise_multiplier=0., - seed=0) - state = dp_agg.init(self.params) - update_fn = self.variant(dp_agg.update) - mean_grads = jax.tree_util.tree_map(lambda g: g.mean(0), self.per_eg_grads) - - for _ in range(3): - updates, state = update_fn(self.per_eg_grads, state) - chex.assert_trees_all_close(updates, mean_grads) - - @chex.all_variants - @parameterized.parameters(0.5, 10.0, 20.0, 40.0, 80.0) - def test_clipping_norm(self, l2_norm_clip): - dp_agg = privacy.differentially_private_aggregate( - l2_norm_clip=l2_norm_clip, - noise_multiplier=0., - seed=42) - state = dp_agg.init(self.params) - update_fn = self.variant(dp_agg.update) - - # Shape of the three arrays below is (self.batch_size, ) - norms = [jnp.linalg.norm(g.reshape(self.batch_size, -1), axis=1) - for g in jax.tree_util.tree_leaves(self.per_eg_grads)] - global_norms = jnp.linalg.norm(jnp.stack(norms), axis=0) - divisors = jnp.maximum(global_norms / l2_norm_clip, 1.) - # Since the values of all the parameters are the same within each example, - # we can easily compute what the values should be: - expected_val = jnp.mean(jnp.arange(self.batch_size) / divisors) - expected_tree = jax.tree_util.tree_map( - lambda p: jnp.broadcast_to(expected_val, p.shape), self.params) - - for _ in range(3): - updates, state = update_fn(self.per_eg_grads, state, self.params) - chex.assert_trees_all_close(updates, expected_tree, rtol=2e-7) - - @chex.all_variants - @parameterized.parameters((3.0, 2.0), (1.0, 5.0), (100.0, 4.0), (1.0, 90.0)) - def test_noise_multiplier(self, l2_norm_clip, noise_multiplier): - """Standard dev. 
of noise should be l2_norm_clip * noise_multiplier.""" - dp_agg = privacy.differentially_private_aggregate( - l2_norm_clip=l2_norm_clip, - noise_multiplier=noise_multiplier, - seed=1337) - state = dp_agg.init(None) - update_fn = self.variant(dp_agg.update) - expected_std = l2_norm_clip * noise_multiplier - - grads = [jnp.ones((1, 100, 100))] # batch size 1 - for _ in range(3): - updates, state = update_fn(grads, state) - chex.assert_trees_all_close(expected_std, - jnp.std(updates[0]), - atol=0.1 * expected_std) - - def test_aggregated_updates_as_input_fails(self): - """Expect per-example gradients as input to this transform.""" - dp_agg = privacy.differentially_private_aggregate(l2_norm_clip=0.1, - noise_multiplier=1.1, - seed=2021) - state = dp_agg.init(self.params) - mean_grads = jax.tree_util.tree_map(lambda g: g.mean(0), self.per_eg_grads) - with self.assertRaises(ValueError): - dp_agg.update(mean_grads, state, self.params) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/schedule.py b/optax_add_eve/_src/schedule.py deleted file mode 100644 index 4fcdca7d..00000000 --- a/optax_add_eve/_src/schedule.py +++ /dev/null @@ -1,620 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""JAX Schedules. - -Schedules may be used to anneal the value of a hyper-parameter over time; for -instance, they may be used to anneal the learning rate used to update an agent's -parameters or the exploration factor used to select actions. -""" - -import functools -import inspect -from typing import Callable, Dict, Union, NamedTuple, Optional, Iterable, Sequence - -from absl import logging -import chex -import jax -import jax.numpy as jnp - -from optax_add_eve._src import base -from optax_add_eve._src import numerics - - -def constant_schedule( - value: Union[float, int] -) -> base.Schedule: - """Constructs a constant schedule. - - Args: - value: value to be held constant throughout. - - Returns: - schedule: A function that maps step counts to values. - """ - return lambda count: value - - -def polynomial_schedule( - init_value: chex.Scalar, - end_value: chex.Scalar, - power: chex.Scalar, - transition_steps: int, - transition_begin: int = 0 -) -> base.Schedule: - """Constructs a schedule with polynomial transition from init to end value. - - Args: - init_value: initial value for the scalar to be annealed. - end_value: end value of the scalar to be annealed. - power: the power of the polynomial used to transition from init to end. - transition_steps: number of steps over which annealing takes place, - the scalar starts changing at `transition_begin` steps and completes - the transition by `transition_begin + transition_steps` steps. - If `transition_steps <= 0`, then the entire annealing process is disabled - and the value is held fixed at `init_value`. - transition_begin: must be positive. 
After how many steps to start annealing - (before this many steps the scalar value is held fixed at `init_value`). - - Returns: - schedule: A function that maps step counts to values. - """ - if transition_steps <= 0: - logging.info( - 'A polynomial schedule was set with a non-positive `transition_steps` ' - 'value; this results in a constant schedule with value `init_value`.') - return lambda count: init_value - - if transition_begin < 0: - logging.info( - 'An exponential schedule was set with a negative `transition_begin` ' - 'value; this will result in `transition_begin` falling back to `0`.') - transition_begin = 0 - - def schedule(count): - count = jnp.clip(count - transition_begin, 0, transition_steps) - frac = 1 - count / transition_steps - return (init_value - end_value) * (frac**power) + end_value - return schedule - - -# Alias polynomial schedule to linear schedule for convenience. -def linear_schedule( - init_value: chex.Scalar, - end_value: chex.Scalar, - transition_steps: int, - transition_begin: int = 0 -) -> base.Schedule: - return polynomial_schedule( - init_value=init_value, end_value=end_value, power=1, - transition_steps=transition_steps, transition_begin=transition_begin) - - -def piecewise_constant_schedule( - init_value: float, - boundaries_and_scales: Optional[Dict[int, float]] = None -) -> base.Schedule: - """Returns a function which implements a piecewise constant schedule. - - Args: - init_value: An initial value `init_v`. - boundaries_and_scales: A map from boundaries `b_i` to non-negative scaling - factors `f_i`. For any step count `s`, the schedule returns `init_v` - scaled by the product of all factors `f_i` such that `b_i` < `s`. - - Returns: - schedule: A function that maps step counts to values. - """ - if boundaries_and_scales is not None: - all_positive = all(scale >= 0. for scale in boundaries_and_scales.values()) - if not all_positive: - raise ValueError( - '`piecewise_constant_schedule` expects non-negative scale factors') - - def schedule(count): - v = init_value - if boundaries_and_scales is not None: - for threshold, scale in sorted(boundaries_and_scales.items()): - indicator = jnp.maximum(0., jnp.sign(threshold - count)) - v = v * indicator + (1 - indicator) * scale * v - return v - - return schedule - - -def exponential_decay( - init_value: float, - transition_steps: int, - decay_rate: float, - transition_begin: int = 0, - staircase: bool = False, - end_value: Optional[float] = None -) -> base.Schedule: - """Constructs a schedule with either continuous or discrete exponential decay. - - This function applies an exponential decay function to a provided initial - value. The function returns the decayed value as follows: - - ``` - decayed_value = init_value * decay_rate ^ (count / transition_steps) - ``` - - If the argument `staircase` is `True`, then `count / transition_steps` is - an integer division and the decayed value follows a staircase function. - - Args: - init_value: the initial learning rate. - transition_steps: must be positive. See the decay computation above. - decay_rate: must not be zero. The decay rate. - transition_begin: must be positive. After how many steps to start annealing - (before this many steps the scalar value is held fixed at `init_value`). - staircase: if `True`, decay the values at discrete intervals. - end_value: the value at which the exponential decay stops. When - `decay_rate` < 1, `end_value` is treated as a lower bound, otherwise as - an upper bound. Has no effect when `decay_rate` = 0. 
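A small numerical illustration of the decay formula and the `end_value` clipping described above (assumes the top-level `optax.exponential_decay` re-export):

import optax

# decayed_value = init_value * decay_rate ** (count / transition_steps)
schedule = optax.exponential_decay(
    init_value=1.0, transition_steps=100, decay_rate=0.5, end_value=0.1)

print(schedule(0))    # 1.0
print(schedule(100))  # 0.5
print(schedule(400))  # 0.1, clipped at end_value rather than 0.0625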
- - Returns: - schedule: A function that maps step counts to values. - """ - - if transition_steps <= 0: - logging.info( - 'An exponential schedule was set with a non-positive `transition_steps`' - ' value; this will result in a constant schedule with value ' - '`init_value`.') - return lambda count: init_value - - if decay_rate == 0: - logging.info( - 'An exponential schedule was set with a zero `decay_rate` value; ' - 'this will result in a constant schedule with value `init_value`.') - return lambda count: init_value - - if transition_begin < 0: - logging.info( - 'An exponential schedule was set with a negative `transition_begin` ' - 'value; this will result in `transition_begin` falling back to `0`.') - transition_begin = 0 - - if end_value is not None: - clip_fn = jnp.maximum if decay_rate < 1.0 else jnp.minimum - - def schedule(count): - decreased_count = count - transition_begin - p = decreased_count / transition_steps - if staircase: - p = jnp.floor(p) - decayed_value = jnp.where( - decreased_count <= 0, init_value, init_value * jnp.power(decay_rate, p)) - if end_value is not None: - decayed_value = clip_fn(decayed_value, end_value) - return decayed_value - - return schedule - - -def cosine_decay_schedule( - init_value: float, - decay_steps: int, - alpha: float = 0.0 -) -> base.Schedule: - """Returns a function which implements cosine learning rate decay. - - The schedule does not restart when ``decay_steps`` has been reached. Instead, - the learning rate remains constant afterwards. For a cosine schedule with - restarts, :func:`optax.join_schedules` can be used to join several cosine - decay schedules. - - For more details see: https://arxiv.org/abs/1608.03983. - - Args: - init_value: An initial value `init_v`. - decay_steps: Positive integer - the number of steps for which to apply - the decay for. - alpha: Float. The minimum value of the multiplier used to adjust the - learning rate. - - Returns: - schedule: A function that maps step counts to values. - """ - if not decay_steps > 0: - raise ValueError('The cosine_decay_schedule requires positive decay_steps!') - - def schedule(count): - count = jnp.minimum(count, decay_steps) - cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * count / decay_steps)) - decayed = (1 - alpha) * cosine_decay + alpha - return init_value * decayed - - return schedule - - -def _linear_interpolate(start: float, end: float, pct: float): - return (end-start) * pct + start - - -def _cosine_interpolate(start: float, end: float, pct: float): - return end + (start-end) / 2.0 * (jnp.cos(jnp.pi * pct) + 1) - - -def piecewise_interpolate_schedule( - interpolate_type: str, - init_value: float, - boundaries_and_scales: Optional[Dict[int, float]] = None -) -> base.Schedule: - """Returns a function which implements a piecewise interpolated schedule. - - Args: - interpolate_type: 'linear' or 'cosine', specifying the interpolation - strategy. - init_value: An initial value `init_v`. - boundaries_and_scales: A map from boundaries `b_i` to non-negative scaling - factors `f_i`. At boundary step `b_i`, the schedule returns `init_v` - scaled by the product of all factors `f_j` such that `b_j` <= `b_i`. The - values in between each boundary will be interpolated as per `type`. - - Returns: - schedule: A function that maps step counts to values. 
- """ - if interpolate_type == 'linear': - interpolate_fn = _linear_interpolate - elif interpolate_type == 'cosine': - interpolate_fn = _cosine_interpolate - else: - raise ValueError('`interpolate_type` must be either \'cos\' or \'linear\'') - - if boundaries_and_scales: - boundaries, scales = zip(*sorted(boundaries_and_scales.items())) - if not all(scale >= 0. for scale in scales): - raise ValueError( - '`piecewise_interpolate_schedule` expects non-negative scale factors') - else: - boundaries, scales = (), () - - bounds = jnp.stack((0,) + boundaries) - values = jnp.cumprod(jnp.stack((init_value,) + scales)) - interval_sizes = (bounds[1:] - bounds[:-1]) - - def schedule(count): - indicator = (bounds[:-1] <= count) & (count < bounds[1:]) - pct = (count - bounds[:-1]) / interval_sizes - interp_vals = interpolate_fn(values[:-1], values[1:], pct) - return indicator.dot(interp_vals) + (bounds[-1] <= count) * values[-1] - - return schedule - - -def linear_onecycle_schedule( - transition_steps: int, - peak_value: float, - pct_start: float = 0.3, - pct_final: float = 0.85, - div_factor: float = 25.0, - final_div_factor: float = 1e4 -) -> base.Schedule: - """Returns a function which implements the onecycle learning rate schedule. - - This function uses a linear annealing strategy. - For more details see: https://arxiv.org/abs/1708.07120 - - Args: - transition_steps: Number of steps over which annealing takes place. - peak_value: Maximum value attained by schedule at pct_start percent - of the cycle (in number of steps). - pct_start: The percentage of the cycle (in number of steps) spent - increasing the learning rate. - pct_final: The percentage of the cycle (in number of steps) spent - increasing to peak_value then decreasing back to init_value. - div_factor: Determines the initial value via init_value = - peak_value / div_factor - final_div_factor: Determines the final value via final_value = - init_value / final_div_factor - - Returns: - schedule: A function that maps step counts to values. - """ - if transition_steps <= 0: - raise ValueError( - 'A linear onecycle schedule was set with a non-positive ' - '`transition_steps`') - - return piecewise_interpolate_schedule( - 'linear', - peak_value / div_factor, - {int(pct_start * transition_steps): div_factor, - int(pct_final * transition_steps): 1. / div_factor, - transition_steps: 1. / final_div_factor}) - - -def cosine_onecycle_schedule( - transition_steps: int, - peak_value: float, - pct_start: float = 0.3, - div_factor: float = 25.0, - final_div_factor: float = 1e4 -) -> base.Schedule: - """Returns a function which implements the onecycle learning rate schedule. - - This function uses a cosine annealing strategy. - For more details see: https://arxiv.org/abs/1708.07120 - - Args: - transition_steps: Number of steps over which annealing takes place. - peak_value: Maximum value attained by schedule at pct_start percent - of the cycle (in number of steps). - pct_start: The percentage of the cycle (in number of steps) spent - increasing the learning rate. - div_factor: Determines the initial value via init_value = - peak_value / div_factor - final_div_factor: Determines the final value via final_value = - init_value / final_div_factor - - Returns: - schedule: A function that maps step counts to values. 
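A short sketch of the onecycle shape implied by these parameters (assumes the top-level `optax.cosine_onecycle_schedule` re-export):

import optax

schedule = optax.cosine_onecycle_schedule(
    transition_steps=1_000, peak_value=1e-2, pct_start=0.3,
    div_factor=25.0, final_div_factor=1e4)

print(schedule(0))    # peak_value / div_factor = 4e-4
print(schedule(300))  # peak_value = 1e-2, reached at pct_start * transition_steps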
- """ - if transition_steps <= 0: - raise ValueError( - 'A linear onecycle schedule was set with a non-positive ' - '`transition_steps`') - - return piecewise_interpolate_schedule( - 'cosine', - peak_value / div_factor, - {int(pct_start * transition_steps): div_factor, - int(transition_steps): 1. / (div_factor * final_div_factor)}) - - -def join_schedules(schedules: Sequence[base.Schedule], - boundaries: Sequence[int]) -> base.Schedule: - """Sequentially apply multiple schedules. - - Args: - schedules: A list of callables (expected to be optax schedules). Each - schedule will receive a step count indicating the number of steps since - the previous boundary transition. - boundaries: A list of integers (of length one less than schedules) that - indicate when to transition between schedules. - Returns: - schedule: A function that maps step counts to values. - """ - def schedule(step: jnp.DeviceArray) -> jnp.DeviceArray: - output = schedules[0](step) - for boundary, schedule in zip(boundaries, schedules[1:]): - output = jnp.where(step < boundary, output, schedule(step - boundary)) - return output - return schedule - - -def warmup_cosine_decay_schedule( - init_value: float, - peak_value: float, - warmup_steps: int, - decay_steps: int, - end_value: float = 0.0 -) -> base.Schedule: - """Linear warmup followed by cosine decay. - - Args: - init_value: Initial value for the scalar to be annealed. - peak_value: Peak value for scalar to be annealed at end of warmup. - warmup_steps: Positive integer, the length of the linear warmup. - decay_steps: Positive integer, the total length of the schedule. Note that - this includes the warmup time, so the number of steps during which cosine - annealing is applied is `decay_steps - warmup_steps`. - end_value: End value of the scalar to be annealed. - Returns: - schedule: A function that maps step counts to values. - """ - schedules = [ - linear_schedule( - init_value=init_value, - end_value=peak_value, - transition_steps=warmup_steps), - cosine_decay_schedule( - init_value=peak_value, - decay_steps=decay_steps - warmup_steps, - alpha=end_value/peak_value)] - return join_schedules(schedules, [warmup_steps]) - - -def warmup_exponential_decay_schedule( - init_value: float, - peak_value: float, - warmup_steps: int, - transition_steps: int, - decay_rate: float, - transition_begin: int = 0, - staircase: bool = False, - end_value: Optional[float] = None -) -> base.Schedule: - """Linear warmup followed by exponential decay. - - Args: - init_value: Initial value for the scalar to be annealed. - peak_value: Peak value for scalar to be annealed at end of warmup. - warmup_steps: Positive integer, the length of the linear warmup. - transition_steps: must be positive. See `exponential_decay` for more - details. - decay_rate: must not be zero. The decay rate. - transition_begin: must be positive. After how many steps to start annealing - (before this many steps the scalar value is held fixed at `peak_value`). - staircase: if `True`, decay the values at discrete intervals. - end_value: the value at which the exponential decay stops. When - `decay_rate` < 1, `end_value` is treated as a lower bound, otherwise as - an upper bound. Has no effect when `decay_rate` = 0. - Returns: - schedule: A function that maps step counts to values. 
- """ - schedules = [ - linear_schedule( - init_value=init_value, - end_value=peak_value, - transition_steps=warmup_steps), - exponential_decay( - init_value=peak_value, - transition_steps=transition_steps, - decay_rate=decay_rate, - transition_begin=transition_begin, - staircase=staircase, - end_value=end_value)] - return join_schedules(schedules, [warmup_steps]) - - -def sgdr_schedule(cosine_kwargs: Iterable[Dict[str, chex.Numeric]] - ) -> base.Schedule: - """SGD with warm restarts, from Loschilov & Hutter (arXiv:1608.03983). - - This learning rate schedule applies multiple joined cosine decay cycles. - For more details see: https://arxiv.org/abs/1608.03983 - - Args: - cosine_kwargs: An Iterable of dicts, where each element specifies the - arguments to pass to each cosine decay cycle. The `decay_steps` kwarg - will specify how long each cycle lasts for, and therefore when to - transition to the next cycle. - Returns: - schedule: A function that maps step counts to values. - """ - boundaries = [] - schedules = [] - step = 0 - for kwargs in cosine_kwargs: - schedules += [warmup_cosine_decay_schedule(**kwargs)] - boundaries += [step + kwargs['decay_steps']] - step += kwargs['decay_steps'] - return join_schedules(schedules, boundaries[:-1]) - - -def _convert_floats(x, dtype): - """Convert float-like inputs to dtype, rest pass through.""" - if jax.dtypes.scalar_type_of(x) == float: - return jnp.asarray(x, dtype=dtype) - return x - - -class InjectHyperparamsState(NamedTuple): - """Maintains inner transform state, hyperparameters, and step count.""" - count: jnp.ndarray # shape=(), dtype=jnp.int32 - hyperparams: Dict[str, chex.Numeric] - inner_state: base.OptState - - -def inject_hyperparams( - inner_factory: Callable[..., base.GradientTransformation], - static_args: Union[str, Iterable[str]] = (), - hyperparam_dtype: Optional[jnp.dtype] = None, -) -> Callable[..., base.GradientTransformation]: - """Wrapper that injects hyperparameters into the inner GradientTransformation. - - This wrapper allows you to pass schedules (i.e. a function that returns a - numeric value given a step count) instead of constants for - hyperparameters. You may only schedule numeric hyperparameters (i.e. boolean - flags cannot be scheduled). - - For example, to use ``scale_by_adam`` with a piecewise linear - schedule for beta_1 and constant for beta_2:: - - scheduled_adam = optax.inject_hyperparams(optax.scale_by_adam)( - b1=optax.piecewise_linear_schedule(...), - b2=0.99) - - You may manually change numeric hyperparameters that were not scheduled - through the ``hyperparams`` dict in the ``InjectHyperparamState``:: - - state = scheduled_adam.init(params) - updates, state = scheduled_adam.update(grads, state) - state.hyperparams['b2'] = 0.95 - updates, state = scheduled_adam.update(updates, state) # uses b2 = 0.95 - - Manually overriding scheduled hyperparameters will have no effect (e.g. - in the code sample above, you cannot manually adjust ``b1``). - - Args: - inner_factory: a function that returns the inner - ``optax.GradientTransformation`` given the hyperparameters. - static_args: a string or iterable of strings specifying which - callable parameters are not schedules. inject_hyperparams treats all - callables as schedules by default, so if a hyperparameter is a - non-schedule callable, you must specify that using this argument. - hyperparam_dtype: Optional datatype override. If specified, all float - hyperparameters will be cast to this type. 
- - Returns: - A callable that returns a ``optax.GradientTransformation``. This callable - accepts the same arguments as ``inner_factory``, except you may provide - schedules in place of the constant arguments. - """ - static_args = ({static_args} if isinstance(static_args, str) else - set(static_args)) - inner_signature = inspect.signature(inner_factory) - - if not static_args.issubset(inner_signature.parameters): - raise ValueError( - '`static_args` must specify a subset of `inner_factory`\'s parameters. ' - f'Given `static_args`: {static_args}. `inner_factory` parameters: ' - f'{set(inner_signature.parameters.keys())}') - - @functools.wraps(inner_factory) - def wrapped_transform(*args, **kwargs) -> base.GradientTransformation: - bound_arguments = inner_signature.bind(*args, **kwargs) - bound_arguments.apply_defaults() - - sched_hps, numeric_hps, other_hps = {}, {}, {} - for name, value in bound_arguments.arguments.items(): - if name in static_args or isinstance(value, bool): - other_hps[name] = value - elif callable(value): - sched_hps[name] = value - elif isinstance(value, (int, float, chex.Array)): - numeric_hps[name] = value - else: - other_hps[name] = value - - def schedule_fn(count, dtype): - return {k: _convert_floats(f(count), dtype) for k, f in sched_hps.items()} - - def init_fn(params): - count = jnp.zeros([], jnp.int32) - if hyperparam_dtype is None: - dtype = getattr(next(iter( - jax.tree_util.tree_leaves(params)), None), 'dtype', None) - else: - dtype = hyperparam_dtype - hparams = { - k: jnp.asarray(_convert_floats(v, dtype)) - for k, v in numeric_hps.items()} - hparams.update(schedule_fn(count, dtype)) - return InjectHyperparamsState( # pylint:disable=too-many-function-args - count, hparams, inner_factory(**other_hps, **hparams).init(params)) - - def update_fn(updates, state, params=None): - if hyperparam_dtype is None: - dtype = getattr(next(iter( - jax.tree_util.tree_leaves(updates)), None), 'dtype', None) - else: - dtype = hyperparam_dtype - hparams = {k: _convert_floats(v, dtype) - for k, v in state.hyperparams.items()} - hparams.update(schedule_fn(state.count, dtype)) - updates, inner_state = inner_factory(**other_hps, **hparams).update( - updates, state.inner_state, params) - count_inc = numerics.safe_int32_increment(state.count) - - # pylint:disable=too-many-function-args - return updates, InjectHyperparamsState(count_inc, hparams, inner_state) - # pylint:enable=too-many-function-args - - return base.GradientTransformation(init_fn, update_fn) - - return wrapped_transform diff --git a/optax_add_eve/_src/schedule_test.py b/optax_add_eve/_src/schedule_test.py deleted file mode 100644 index a862c442..00000000 --- a/optax_add_eve/_src/schedule_test.py +++ /dev/null @@ -1,649 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
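A minimal sketch of the composition pattern that `join_schedules` and the warmup helpers above implement, assuming the public optax aliases (`optax.constant_schedule`, `optax.exponential_decay`, `optax.join_schedules`); note that each joined schedule restarts its step count at its boundary:

import optax

# Constant warmup for 500 steps, then exponential decay.
warmup = optax.constant_schedule(1e-3)
decay = optax.exponential_decay(
    init_value=1e-3, transition_steps=1000, decay_rate=0.5)
schedule_fn = optax.join_schedules([warmup, decay], boundaries=[500])

schedule_fn(0)     # 1e-3 during warmup
schedule_fn(500)   # 1e-3: the decay schedule starts counting from 0 at the boundary
schedule_fn(1500)  # 1e-3 * 0.5**1 = 5e-4, one decay period past the boundary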
-# ============================================================================== -"""Tests for `schedule.py`.""" - -import functools - -from absl.testing import absltest -from absl.testing import parameterized - -import chex -import jax -import jax.numpy as jnp -import numpy as np - -from optax_add_eve._src import clipping -from optax_add_eve._src import schedule -from optax_add_eve._src import transform -from optax_add_eve._src import wrappers - - -class ConstantTest(chex.TestCase): - - @chex.all_variants - def test_constant(self): - """Check constant schedule.""" - # Get schedule function. - const_value = 10 - num_steps = 15 - schedule_fn = self.variant(schedule.constant_schedule(const_value)) - # Test that generated values equal the expected schedule values. - generated_vals = [] - for count in range(num_steps): - # Compute next value. - generated_vals.append(schedule_fn(count)) - # Test output. - expected_vals = np.array([const_value] * num_steps, dtype=np.float32) - np.testing.assert_allclose( - expected_vals, np.array(generated_vals), atol=1e-3) - - -class PolynomialTest(chex.TestCase): - - @chex.all_variants - def test_linear(self): - """Check linear schedule.""" - # Get schedule function. - schedule_fn = self.variant( - schedule.polynomial_schedule( - init_value=10., end_value=20., power=1, transition_steps=10)) - # Test that generated values equal the expected schedule values. - generated_vals = [] - for count in range(15): - # Compute next value. - generated_vals.append(schedule_fn(count)) - # Test output. - expected_vals = np.array(list(range(10, 20)) + [20] * 5, dtype=np.float32) - np.testing.assert_allclose( - expected_vals, np.array(generated_vals), atol=1e-3) - - @chex.all_variants - def test_zero_steps_schedule(self): - # Get schedule function. - initial_value = 10. - end_value = 20. - - for num_steps in [-1, 0]: - schedule_fn = self.variant( - schedule.polynomial_schedule( - init_value=initial_value, end_value=end_value, - power=1, transition_steps=num_steps)) - for count in range(15): - np.testing.assert_allclose(schedule_fn(count), initial_value) - - @chex.all_variants - def test_nonlinear(self): - """Check non-linear (quadratic) schedule.""" - # Get schedule function. - schedule_fn = self.variant( - schedule.polynomial_schedule( - init_value=25., end_value=10., power=2, transition_steps=10)) - # Test that generated values equal the expected schedule values. - generated_vals = [] - for count in range(15): - # Compute next value. - generated_vals.append(schedule_fn(count)) - # Test output. - expected_vals = np.array( - [10. + 15. * (1. - n / 10)**2 for n in range(10)] + [10] * 5, - dtype=np.float32) - np.testing.assert_allclose( - expected_vals, np.array(generated_vals), atol=1e-3) - - @chex.all_variants - def test_with_decay_begin(self): - """Check quadratic schedule with non-zero schedule begin.""" - # Get schedule function. - schedule_fn = self.variant( - schedule.polynomial_schedule( - init_value=30., end_value=10., power=2, - transition_steps=10, transition_begin=4)) - # Test that generated values equal the expected schedule values. - generated_vals = [] - for count in range(20): - # Compute next value. - generated_vals.append(schedule_fn(count)) - # Test output. - expected_vals = np.array( - [30.] * 4 + [10. + 20. * (1. 
- n / 10)**2 for n in range(10)] + - [10] * 6, - dtype=np.float32) - np.testing.assert_allclose( - expected_vals, np.array(generated_vals), atol=1e-3) - - -class PiecewiseConstantTest(chex.TestCase): - - @chex.all_variants - def test_positive(self): - """Check piecewise constant schedule of positive values.""" - # Get schedule function. - schedule_fn = self.variant( - schedule.piecewise_constant_schedule(0.1, {3: 2., 6: 0.5})) - # Test that generated values equal the expected schedule values. - generated_vals = [] - for count in range(10): - # Compute next value. - generated_vals.append(schedule_fn(count)) - # Test output. - expected_vals = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1]) - np.testing.assert_allclose( - expected_vals, np.array(generated_vals), atol=1e-3) - - @chex.all_variants - def test_negative(self): - """Check piecewise constant schedule of negative values.""" - # Get schedule function. - schedule_fn = self.variant( - schedule.piecewise_constant_schedule(-0.1, {3: 2., 6: 0.5})) - # Test that generated values equal the expected schedule values. - generated_vals = [] - for count in range(10): - # Compute next value. - generated_vals.append(schedule_fn(count)) - # Test output. - expected_vals = -1 * np.array( - [0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1]) - np.testing.assert_allclose( - expected_vals, np.array(generated_vals), atol=1e-3) - - -class ExponentialTest(chex.TestCase): - - @chex.all_variants - @parameterized.parameters(False, True) - def test_constant_schedule(self, staircase): - """Checks constant schedule for exponential decay schedule.""" - num_steps = 15 - # Get schedule function. - init_value = 1. - schedule_fn = self.variant( - schedule.exponential_decay( - init_value=init_value, transition_steps=num_steps, - decay_rate=1., staircase=staircase)) - # Test that generated values equal the expected schedule values. - generated_vals = [] - for count in range(num_steps): - generated_vals.append(schedule_fn(count)) - expected_vals = np.array([init_value] * num_steps, dtype=np.float32) - np.testing.assert_allclose( - expected_vals, np.array(generated_vals), atol=1e-3) - - @chex.all_variants - @parameterized.parameters(False, True) - def test_nonvalid_transition_steps(self, staircase): - """Checks nonvalid decay steps results in a constant schedule.""" - init_value = 1. - for transition_steps in [-1, 0]: - schedule_fn = self.variant( - schedule.exponential_decay( - init_value=init_value, transition_steps=transition_steps, - decay_rate=1., staircase=staircase)) - for count in range(15): - np.testing.assert_allclose(schedule_fn(count), init_value) - - @chex.all_variants - @parameterized.parameters(False, True) - def test_nonvalid_decay_rate(self, staircase): - """Checks nonvalid decay steps results in a constant schedule.""" - init_value = 1. - schedule_fn = self.variant( - schedule.exponential_decay( - init_value=init_value, transition_steps=2, - decay_rate=0., staircase=staircase)) - for count in range(15): - np.testing.assert_allclose(schedule_fn(count), init_value) - - @chex.all_variants - @parameterized.parameters((False, 0), (True, 0), (False, 5), (True, 5)) - def test_exponential(self, staircase, transition_begin): - """Checks non-linear (quadratic) schedule.""" - # Get schedule function. - init_value = 1. - num_steps = 15 - transition_steps = 2 - decay_rate = 2. 
- schedule_fn = self.variant( - schedule.exponential_decay( - init_value=init_value, transition_steps=transition_steps, - decay_rate=decay_rate, transition_begin=transition_begin, - staircase=staircase)) - - # Test that generated values equal the expected schedule values. - def _staircased(count): - p = count / transition_steps - if staircase: - p = np.floor(p) - return p - - generated_vals = [] - for count in range(num_steps + transition_begin): - generated_vals.append(schedule_fn(count)) - expected_vals = np.array( - [init_value] * transition_begin + [ - init_value * np.power(decay_rate, _staircased(count)) - for count in range(num_steps) - ], - dtype=np.float32) - np.testing.assert_allclose( - expected_vals, np.array(generated_vals), atol=1e-3) - - @chex.all_variants - @parameterized.parameters( - (0.2, 0.1, False), (1.0, 0.1, False), (2.0, 3.0, False), - (0.2, 0.1, True), (1.0, 0.1, True), (2.0, 3.0, True)) - def test_end_value_with_staircase(self, decay_rate, end_value, staircase): - # Get schedule function. - init_value = 1. - num_steps = 11 - transition_steps = 2 - transition_begin = 3 - schedule_fn = self.variant( - schedule.exponential_decay( - init_value=init_value, transition_steps=transition_steps, - decay_rate=decay_rate, transition_begin=transition_begin, - staircase=staircase, end_value=end_value)) - - # Test that generated values equal the expected schedule values. - def _staircased(count): - p = count / transition_steps - if staircase: - p = np.floor(p) - return p - - generated_vals = [] - for count in range(num_steps + transition_begin): - generated_vals.append(schedule_fn(count)) - expected_vals = np.array( - [init_value] * transition_begin + [ - init_value * np.power(decay_rate, _staircased(count)) - for count in range(num_steps) - ], - dtype=np.float32) - - if decay_rate < 1.0: - expected_vals = np.maximum(expected_vals, end_value) - else: - expected_vals = np.minimum(expected_vals, end_value) - - np.testing.assert_allclose( - expected_vals, np.array(generated_vals), atol=1e-3) - - @chex.all_variants - def test_immutable_count(self): - """Checks constant schedule for exponential decay schedule.""" - num_steps = 5 - # Get schedule function. - init_value = 32. - schedule_fn = self.variant( - schedule.exponential_decay( - init_value=init_value, transition_steps=1, - decay_rate=0.5)) - # Test that generated values equal the expected schedule values. - generated_vals = [] - for count in range(num_steps): - # Jax arrays are read-only in ChexVariantType.WITHOUT_DEVICE. - immutable_count = jnp.array(count, dtype=jnp.float32) - generated_vals.append(schedule_fn(immutable_count)) - expected_vals = np.array([32, 16, 8, 4, 2], dtype=np.float32) - np.testing.assert_allclose( - expected_vals, np.array(generated_vals), atol=1e-3) - - -class CosineDecayTest(chex.TestCase): - - @chex.all_variants - def test_decay_count_smaller_count(self): - """Check cosine schedule decay for the entire training schedule.""" - initial_value = 0.1 - schedule_fn = self.variant( - schedule.cosine_decay_schedule(initial_value, 10, 0.0)) - # Test that generated values equal the expected schedule values. - generated_vals = [] - for count in range(10): - # Compute next value. - generated_vals.append(schedule_fn(count)) - # Test output. 
- expected_multipliers = np.array( - 0.5 + 0.5 * np.cos( - np.pi * np.array( - [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]))) - np.testing.assert_allclose( - initial_value * expected_multipliers, - np.array(generated_vals), atol=1e-3) - - @chex.all_variants - def test_decay_count_greater_count(self): - """Check cosine schedule decay for a part of the training schedule.""" - initial_value = 0.1 - schedule_fn = self.variant( - schedule.cosine_decay_schedule(initial_value, 5, 0.0)) - # Test that generated values equal the expected schedule values. - generated_vals = [] - for count in range(12): - # Compute next value. - generated_vals.append(schedule_fn(count)) - - # Test output. - expected_multipliers = np.array( - 0.5 + 0.5 * np.cos( - np.pi * np.array( - [0.0, 0.2, 0.4, 0.6, 0.8, 1., 1., 1., 1., 1., 1., 1.]))) - np.testing.assert_allclose( - initial_value * expected_multipliers, - np.array(generated_vals), atol=1e-3) - - @chex.all_variants - def test_decay_count_greater_count_with_alpha(self): - """Check cosine schedule decay for a part of the training schedule.""" - # Get schedule function. - initial_value = 0.1 - schedule_fn = self.variant( - schedule.cosine_decay_schedule(initial_value, 5, 0.1)) - # Test that generated values equal the expected schedule values. - generated_vals = [] - for count in range(12): - # Compute next value. - generated_vals.append(schedule_fn(count)) - - # Test output. - expected_multipliers = np.array( - 0.5 + 0.5 * np.cos( - np.pi * np.array( - [0.0, 0.2, 0.4, 0.6, 0.8, 1., 1., 1., 1., 1., 1., 1.]))) - expected_multipliers = 0.9 * expected_multipliers + 0.1 - np.testing.assert_allclose( - initial_value * expected_multipliers, - np.array(generated_vals), atol=1e-3) - - -class WarmupCosineDecayTest(chex.TestCase): - - @chex.all_variants - @parameterized.named_parameters( - ('with end value', 10, 0.5, 1e-4), - ('without end value', 5, 3, 0.),) - def test_limits(self, init_value, peak_value, end_value): - """Check cosine schedule decay for the entire training schedule.""" - schedule_fn = self.variant(schedule.warmup_cosine_decay_schedule( - init_value=init_value, - peak_value=peak_value, - warmup_steps=100, - decay_steps=1000, - end_value=end_value, - )) - - np.testing.assert_allclose(init_value, schedule_fn(0)) - np.testing.assert_allclose(peak_value, schedule_fn(100)) - np.testing.assert_allclose(end_value, schedule_fn(1000), rtol=1e-3) - - -class SGDRTest(chex.TestCase): - - @chex.all_variants - @parameterized.named_parameters( - ('with step decay', 1.6, 0.8, 0.4), - ('without step_decay', 1.6, 1.6, 1.6),) - def test_limits(self, lr0, lr1, lr2): - """Check cosine schedule decay for the entire training schedule.""" - lr_kwargs = [] - for step, lr in zip([2e3, 3e3, 5e3], [lr0, lr1, lr2]): - lr_kwargs += [dict(decay_steps=int(step), peak_value=lr, - init_value=0, end_value=0.0, warmup_steps=500)] - schedule_fn = self.variant(schedule.sgdr_schedule(lr_kwargs)) - np.testing.assert_allclose(lr0, schedule_fn(500)) - np.testing.assert_allclose(lr1, schedule_fn(2500)) - np.testing.assert_allclose(lr2, schedule_fn(5500)) - - -class PiecewiseInterpolateTest(chex.TestCase): - - @chex.all_variants - def test_linear_piecewise(self): - schedule_fn = self.variant(schedule.piecewise_interpolate_schedule( - 'linear', 200., {5: 1.5, 10: 0.25})) - generated_vals = [schedule_fn(step) for step in range(13)] - expected_vals = [200., 220., 240., 260., 280., 300., 255., 210., 165., - 120., 75., 75., 75.] 
- np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3) - - @chex.all_variants - def test_cos_piecewise(self): - schedule_fn = self.variant(schedule.piecewise_interpolate_schedule( - 'cosine', 400., {5: 1.2, 3: 0.6, 7: 1.})) - generated_vals = [schedule_fn(step) for step in range(9)] - expected_vals = [400., 360., 280., 240., 264., 288., 288., 288., 288.] - np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3) - - @chex.all_variants - def test_empty_dict(self): - schedule_fn = self.variant(schedule.piecewise_interpolate_schedule( - 'linear', 13., {})) - generated_vals = [schedule_fn(step) for step in range(5)] - expected_vals = [13., 13., 13., 13., 13.] - np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3) - - @chex.all_variants - def test_no_dict(self): - schedule_fn = self.variant(schedule.piecewise_interpolate_schedule( - 'cosine', 17.)) - generated_vals = [schedule_fn(step) for step in range(3)] - expected_vals = [17., 17., 17.] - np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3) - - def test_invalid_type(self): - # pytype: disable=wrong-arg-types - with self.assertRaises(ValueError): - schedule.piecewise_interpolate_schedule('linar', 13.) - with self.assertRaises(ValueError): - schedule.piecewise_interpolate_schedule('', 13., {5: 3.}) - with self.assertRaises(ValueError): - schedule.piecewise_interpolate_schedule(None, 13., {}) - # pytype: enable=wrong-arg-types - - def test_invalid_scale(self): - with self.assertRaises(ValueError): - schedule.piecewise_interpolate_schedule('linear', 13., {5: -3}) - - -class OneCycleTest(chex.TestCase): - - @chex.all_variants - def test_linear(self): - schedule_fn = self.variant(schedule.linear_onecycle_schedule( - transition_steps=10, - peak_value=1000, - pct_start=0.3, - pct_final=0.7, - div_factor=10., - final_div_factor=100.)) - - generated_vals = [schedule_fn(step) for step in range(12)] - expected_vals = [100., 400., 700., 1000., 775., 550., 325., 100., 67., - 34., 1., 1.] - np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3) - - @chex.all_variants - def test_cosine(self): - schedule_fn = self.variant(schedule.cosine_onecycle_schedule( - transition_steps=5, - peak_value=1000., - pct_start=0.4, - div_factor=10., - final_div_factor=100.)) - - generated_vals = [schedule_fn(step) for step in range(7)] - expected_vals = [100., 550., 1000., 750.25, 250.75, 1., 1.] - np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3) - - def test_nonpositive_transition_steps(self): - with self.assertRaises(ValueError): - schedule.cosine_onecycle_schedule(transition_steps=0, peak_value=5.) - with self.assertRaises(ValueError): - schedule.linear_onecycle_schedule(transition_steps=0, peak_value=5.) 
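The onecycle shape exercised by these tests can be sketched directly, assuming the public `optax.linear_onecycle_schedule` alias; the commented values mirror `expected_vals` in `test_linear` above:

import optax

sched = optax.linear_onecycle_schedule(
    transition_steps=10, peak_value=1000.,
    pct_start=0.3, pct_final=0.7,
    div_factor=10., final_div_factor=100.)

sched(0)   # 100.0  = peak_value / div_factor (the initial value)
sched(3)   # 1000.0 = peak_value, reached at pct_start * transition_steps
sched(7)   # 100.0  = back to the initial value at pct_final * transition_steps
sched(10)  # 1.0    = init_value / final_div_factor at the end of the cycle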
- - -class InjectHyperparamsTest(chex.TestCase): - """Tests for the inject_hyperparams wrapper.""" - - @chex.all_variants - def test_updates(self): - optim = schedule.inject_hyperparams(transform.scale)( # stateless - step_size=schedule.piecewise_constant_schedule( - 3.0, {1: 5, 7: 2, 12: 1.5})) - - params = [jnp.zeros([], dtype=jnp.float32)] - state = self.variant(optim.init)(params) - update_fn = self.variant(optim.update) - expected_step_size = [3.0]*2 + [15.0]*6 + [30.0]*5 + [45.0]*3 - - grads = [jnp.ones([], dtype=jnp.float32)] - for i in range(15): - updates, state = update_fn(grads, state, params=params) - np.testing.assert_almost_equal(updates[0], expected_step_size[i+1]) - - @chex.all_variants - def test_hyperparams_state(self): - optim = schedule.inject_hyperparams(transform.trace)( # stateful - decay=schedule.piecewise_constant_schedule( - 0.8, {3: 0.5, 9: 1.25}), - nesterov=True) - - params = [jnp.zeros([2, 3]) for _ in range(3)] - state = self.variant(optim.init)(params) - update_fn = self.variant(optim.update) - - expected_mom = [0.8]*4 + [0.4]*6 + [0.5]*2 - grads = jax.tree_util.tree_map(jnp.ones_like, params) - for i in range(12): - np.testing.assert_almost_equal(state.hyperparams['decay'], - expected_mom[i]) - _, state = update_fn(grads, state) - - np.testing.assert_almost_equal(state.hyperparams['decay'], - expected_mom[-1]) - - @chex.all_variants - def test_constant_hyperparams(self): - optim = schedule.inject_hyperparams(transform.scale_by_adam)(b1=0., b2=0.) - - params = [jnp.zeros([2, 3]) for _ in range(3)] - state = self.variant(optim.init)(params) - update_fn = self.variant(optim.update) - - grads = jax.tree_util.tree_map(jnp.ones_like, params) - for _ in range(5): - updates, state = update_fn(grads, state, params) - np.testing.assert_almost_equal(state.hyperparams['b1'], 0.0) - np.testing.assert_almost_equal(state.hyperparams['b2'], 0.0) - np.testing.assert_almost_equal(state.hyperparams['eps'], 1e-8) - np.testing.assert_almost_equal(state.hyperparams['eps_root'], 0.0) - assert 'eps' in state.hyperparams - chex.assert_trees_all_close(updates, grads) - - @chex.all_variants - def test_overriding_hyperparam(self): - optim = schedule.inject_hyperparams(clipping.clip_by_global_norm)(0.1) - params = jnp.zeros((3, 5, 7)) - state = self.variant(optim.init)(params) - update_fn = self.variant(optim.update) - - grads = jnp.ones_like(params) - for i in range(5): - state.hyperparams['max_norm'] = i - updates, state = update_fn(grads, state) - assert np.isclose(jnp.linalg.norm(updates.ravel()), i) - - @chex.all_variants - @parameterized.named_parameters(('string', 'mask'), ('list', ['mask'])) - def test_static_args(self, static_args): - @functools.partial(schedule.inject_hyperparams, static_args=static_args) - def custom_optim(learning_rate, mask): - return wrappers.masked(transform.scale(-learning_rate), mask) - - optim = custom_optim( - 0.1, functools.partial(jax.tree_util.tree_map, lambda x: x.ndim > 1)) - params = [jnp.ones((1, 2)), jnp.ones(2), jnp.ones((1, 1, 1))] - grads = params - state = self.variant(optim.init)(params) - updates, state = self.variant(optim.update)(grads, state) - expected_updates = jax.tree_util.tree_map( - lambda x: -0.1 * x if x.ndim > 1 else x, grads) - - assert set(state.hyperparams.keys()) == {'learning_rate'}, state.hyperparams - chex.assert_trees_all_close(updates, expected_updates) - - @chex.all_variants - @parameterized.named_parameters(('one_arg', 'b1'), ('two_arg', ['b1', 'b2'])) - def test_numeric_static_args(self, static_args): - optim = 
schedule.inject_hyperparams( - transform.scale_by_adam, static_args=static_args)(b1=0.9, b2=0.95) - - params = [jnp.ones((1, 2)), jnp.ones(2), jnp.ones((1, 1, 1))] - grads = params - state = self.variant(optim.init)(params) - _, state = self.variant(optim.update)(grads, state) - - assert not set(state.hyperparams.keys()).intersection(set(static_args)) - - @chex.all_variants - @parameterized.named_parameters( - ('bf16hyp f32param bf16grad', jnp.bfloat16, jnp.float32, jnp.bfloat16), - ('bf16hyp f32param f32_grads', jnp.bfloat16, jnp.float32, jnp.float32), - ('f32hyp bf16param bf16grad', jnp.float32, jnp.bfloat16, jnp.bfloat16), - ('f32hyp f32param bf16grad', jnp.float32, jnp.float32, jnp.bfloat16), - ('f32hyp bf16param f32grad', jnp.float32, jnp.bfloat16, jnp.float32), - ) - def test_hyperparam_dtypes(self, - hyperparam_dtype, - param_dtype, - grad_dtype): - """Tests that hyperparam dtype override works as desired.""" - optim = schedule.inject_hyperparams( - transform.scale_by_adam, - hyperparam_dtype=hyperparam_dtype)(b1=0.9, b2=0.95) - - params = [jnp.ones((1, 2), dtype=param_dtype), - jnp.ones(2, dtype=param_dtype), - jnp.ones((1, 1, 1), dtype=param_dtype)] - grads = jax.tree_map(lambda x: x.astype(grad_dtype), params) - state = self.variant(optim.init)(params) - # Check that the hyperparams are overriden - self.assertEqual(state.hyperparams['b1'].dtype, hyperparam_dtype) - self.assertEqual(state.hyperparams['b2'].dtype, hyperparam_dtype) - - _, state = self.variant(optim.update)(grads, state) - - self.assertEqual(state.hyperparams['b1'].dtype, hyperparam_dtype) - self.assertEqual(state.hyperparams['b2'].dtype, hyperparam_dtype) - - @parameterized.named_parameters(('string', 'lr'), ('list', ['lr'])) - def test_static_args_error(self, static_args): - with self.assertRaises(ValueError): - schedule.inject_hyperparams(transform.scale, static_args=static_args) - - @chex.all_variants - def test_inject_hyperparams_starts_with_step_count_zero(self): - """Checks that inject_hyperparams uses step count 0 in the first update.""" - # See also: https://github.com/deepmind/optax/issues/415. - opt = schedule.inject_hyperparams(transform.scale)(lambda count: count) - params = jnp.zeros(3) - grads = jnp.array([-1, 0, 1]) - updates, _ = self.variant(opt.update)(grads, opt.init(params)) - np.testing.assert_array_equal(updates, np.zeros(3)) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/second_order_test.py b/optax_add_eve/_src/second_order_test.py deleted file mode 100644 index 820f1ed8..00000000 --- a/optax_add_eve/_src/second_order_test.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
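A minimal sketch of the behaviour these tests pin down, assuming the public `optax.inject_hyperparams`, `optax.scale_by_adam` and `optax.piecewise_constant_schedule` aliases: scheduled hyperparameters are re-evaluated at every update, and their realized values (as well as the constant ones) are visible in the wrapper state:

import jax.numpy as jnp
import optax

opt = optax.inject_hyperparams(optax.scale_by_adam)(
    b1=optax.piecewise_constant_schedule(0.9, {100: 0.5}),  # scheduled
    b2=0.999)  # constant numeric hyperparam, can be overridden via state.hyperparams
params = {'w': jnp.zeros(3)}
state = opt.init(params)
updates, state = opt.update({'w': jnp.ones(3)}, state)
state.hyperparams['b1']  # 0.9 while the step count is below 100, 0.45 once it reaches 100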
-# ============================================================================== -"""Tests for `second_order.py`.""" - -import collections -import functools -import itertools - -from absl.testing import absltest - -import chex -import haiku as hk -import jax -import jax.numpy as jnp -import numpy as np - -from optax_add_eve._src import second_order - - -NUM_CLASSES = 2 -NUM_SAMPLES = 3 -NUM_FEATURES = 4 - - -class SecondOrderTest(chex.TestCase): - - def setUp(self): - super().setUp() - - self.data = np.random.rand(NUM_SAMPLES, NUM_FEATURES) - self.labels = np.random.randint(NUM_CLASSES, size=NUM_SAMPLES) - - def net_fn(z): - mlp = hk.Sequential( - [hk.Linear(10), jax.nn.relu, hk.Linear(NUM_CLASSES)], name='mlp') - return jax.nn.log_softmax(mlp(z)) - - net = hk.without_apply_rng(hk.transform(net_fn)) - self.parameters = net.init(jax.random.PRNGKey(0), self.data) - - def loss(params, inputs, targets): - log_probs = net.apply(params, inputs) - return -jnp.mean(hk.one_hot(targets, NUM_CLASSES) * log_probs) - - self.loss_fn = loss - - def jax_hessian_diag(loss_fun, params, inputs, targets): - """This is the 'ground-truth' obtained via the JAX library.""" - hess = jax.hessian(loss_fun)(params, inputs, targets) - - # Extracts the diagonal components. - hess_diag = collections.defaultdict(dict) - for k0, k1 in itertools.product(params.keys(), ['w', 'b']): - params_shape = params[k0][k1].shape - n_params = np.prod(params_shape) - hess_diag[k0][k1] = jnp.diag(hess[k0][k1][k0][k1].reshape( - n_params, n_params)).reshape(params_shape) - for k, v in hess_diag.items(): - hess_diag[k] = v - return second_order.ravel(hess_diag) - - self.hessian = jax_hessian_diag( - self.loss_fn, self.parameters, self.data, self.labels) - - @chex.all_variants - def test_hessian_diag(self): - hessian_diag_fn = self.variant( - functools.partial(second_order.hessian_diag, self.loss_fn)) - actual = hessian_diag_fn(self.parameters, self.data, self.labels) - np.testing.assert_array_almost_equal(self.hessian, actual, 5) - - @chex.all_variants - def test_fisher_diag_shape(self): - fisher_diag_fn = self.variant( - functools.partial(second_order.fisher_diag, self.loss_fn)) - fisher_diagonal = fisher_diag_fn(self.parameters, self.data, self.labels) - chex.assert_equal_shape([fisher_diagonal, self.hessian]) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/stochastic_gradient_estimators.py b/optax_add_eve/_src/stochastic_gradient_estimators.py deleted file mode 100644 index 82d0d0f5..00000000 --- a/optax_add_eve/_src/stochastic_gradient_estimators.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -r"""Stochastic Monte Carlo gradient estimators. 
- -Utility functions to approximate gradients of the form using Monte Carlo -estimation: - \nabla_{\theta} E_{p(x; \theta)} f(x) - -Here f is assumed to have no dependence on the parameters theta - if f has -dependence on theta, the functions below need to be called with `stop_grad(f)` -and the chain rule needs to be applied outside these functions in order -to obtain unbiased gradient. - -For more details, see: -S. Mohamed, M. Rosca, M. Figurnov, A Mnih. - Monte Carlo Gradient Estimation in Machine Learning. JMLR, 2020. -""" - -import math -from typing import Any, Callable, Sequence - -import chex -import jax -import jax.numpy as jnp -import numpy as np -from optax_add_eve._src import base -from optax_add_eve._src import utils - - -def score_function_jacobians( - function: Callable[[chex.Array], float], - params: base.Params, - dist_builder: Callable[..., Any], - rng: chex.PRNGKey, - num_samples: int) -> Sequence[chex.Array]: - r"""Score function gradient estimation. - - Approximates: - \nabla_{\theta} E_{p(x; \theta)} f(x) - With: - E_{p(x; \theta)} f(x) \nabla_{\theta} \log p(x; \theta) - - Requires: p to be differentiable wrt to theta. Applicable to both continuous - and discrete random variables. No requirements on f. - - Args: - function: Function f(x) for which to estimate grads_{params} E_dist f(x). - The function takes in one argument (a sample from the distribution) and - returns a floating point value. - params: A tuple of jnp arrays. - The parameters for which to construct the distribution. - dist_builder: a constructor which builds a distribution given the input - parameters specified by params. `dist_builder(params)` should return a - valid distribution. - rng: a PRNGKey key. - num_samples: Int, the number of samples used to compute the grads. - - Returns: - A tuple of size `params`, each element is `num_samples x param.shape` - jacobian vector containing the estimates of the gradients obtained for - each sample. - The mean of this vector is the gradient wrt to parameters that can be used - for learning. The entire jacobian vector can be used to assess estimator - variance. - """ - def surrogate(params): - dist = dist_builder(*params) - one_sample_surrogate_fn = lambda x: function(x) * dist.log_prob(x) - samples = jax.lax.stop_gradient(dist.sample((num_samples,), seed=rng)) - # We vmap the function application over samples - this ensures that the - # function we use does not have to be vectorized itself. - return jax.vmap(one_sample_surrogate_fn)(samples) - - return jax.jacfwd(surrogate)(params) - - -def pathwise_jacobians( - function: Callable[[chex.Array], float], - params: base.Params, - dist_builder: Callable[..., Any], - rng: chex.PRNGKey, - num_samples: int) -> Sequence[chex.Array]: - r"""Pathwise gradient estimation. - - Approximates: - \nabla_{\theta} E_{p(x; \theta)} f(x) - With: - E_{p(\epsilon)} \nabla_{\theta} f(g(\epsilon, \theta)) - where x = g(\epsilon, \theta). g depends on the distribution p. - - Requires: p to be reparametrizable and the reparametrization to be implemented - in tensorflow_probability. Applicable to continuous random variables. - f needs to be differentiable. - - Args: - function: Function f(x) for which to estimate grads_{params} E_dist f(x). - The function takes in one argument (a sample from the distribution) and - returns a floating point value. - params: A tuple of jnp arrays. - The parameters for which to construct the distribution. 
- dist_builder: a constructor which builds a distribution given the input - parameters specified by params. `dist_builder(params)` should return a - valid distribution. - rng: a PRNGKey key. - num_samples: Int, the number of samples used to compute the grads. - - Returns: - A tuple of size `params`, each element is `num_samples x param.shape` - jacobian vector containing the estimates of the gradients obtained for - each sample. - The mean of this vector is the gradient wrt to parameters that can be used - for learning. The entire jacobian vector can be used to assess estimator - variance. - """ - def surrogate(params): - # We vmap the function application over samples - this ensures that the - # function we use does not have to be vectorized itself. - dist = dist_builder(*params) - return jax.vmap(function)(dist.sample((num_samples,), seed=rng)) - - return jax.jacfwd(surrogate)(params) - - -def measure_valued_jacobians( - function: Callable[[chex.Array], float], - params: base.Params, - dist_builder: Callable[..., Any], - rng: chex.PRNGKey, - num_samples: int, - coupling: bool = True) -> Sequence[chex.Array]: - r"""Measure valued gradient estimation. - - Approximates: - \nabla_{\theta} E_{p(x; \theta)} f(x) - With: - 1./ c (E_{p1(x; \theta)} f(x) - E_{p2(x; \theta)} f(x)) where p1 and p2 are - measures which depend on p. - - Currently only supports computing gradients of expectations of Gaussian RVs. - - Args: - function: Function f(x) for which to estimate grads_{params} E_dist f(x). - The function takes in one argument (a sample from the distribution) and - returns a floating point value. - params: A tuple of jnp arrays. - The parameters for which to construct the distribution. - dist_builder: a constructor which builds a distribution given the input - parameters specified by params. `dist_builder(params)` should return a - valid distribution. - rng: a PRNGKey key. - num_samples: Int, the number of samples used to compute the grads. - coupling: A boolean. Whether or not to use coupling for the positive and - negative samples. Recommended: True, as this reduces variance. - - Returns: - A tuple of size `params`, each element is `num_samples x param.shape` - jacobian vector containing the estimates of the gradients obtained for - each sample. - The mean of this vector is the gradient wrt to parameters that can be used - for learning. The entire jacobian vector can be used to assess estimator - variance. - """ - if dist_builder is not utils.multi_normal: - raise ValueError( - 'Unsupported distribution builder for measure_valued_jacobians!') - dist = dist_builder(*params) - # Need to apply chain rule for log scale grad (instead of scale grad). - return [ - measure_valued_estimation_mean( - function, dist, rng, num_samples, coupling=coupling), - jnp.exp(dist.log_scale) * measure_valued_estimation_std( - function, dist, rng, num_samples, coupling=coupling)] - - -def measure_valued_estimation_mean( - function: Callable[[chex.Array], float], - dist: Any, - rng: chex.PRNGKey, - num_samples: int, - coupling: bool = True) -> chex.Array: - """Measure valued grads of a Gaussian expectation of `function` wrt the mean. - - Args: - function: Function f(x) for which to estimate grads_{mean} E_dist f(x). - The function takes in one argument (a sample from the distribution) and - returns a floating point value. - dist: a distribution on which we can call `sample`. - rng: a PRNGKey key. - num_samples: Int, the number of samples used to compute the grads. - coupling: A boolean. 
Whether or not to use coupling for the positive and - negative samples. Recommended: True, as this reduces variance. - - Returns: - A `num_samples x D` vector containing the estimates of the gradients - obtained for each sample. The mean of this vector can be used to update - the mean parameter. The entire vector can be used to assess estimator - variance. - """ - mean, log_std = dist.params - std = jnp.exp(log_std) - - dist_samples = dist.sample((num_samples,), seed=rng) - - pos_rng, neg_rng = jax.random.split(rng) - pos_sample = jax.random.weibull_min( - pos_rng, scale=math.sqrt(2.), concentration=2., shape=dist_samples.shape) - - if coupling: - neg_sample = pos_sample - else: - neg_sample = jax.random.weibull_min( - neg_rng, - scale=math.sqrt(2.), - concentration=2., - shape=dist_samples.shape) - - # N x D - positive_diag = mean + std * pos_sample - # N x D - negative_diag = mean - std * neg_sample - - # NOTE: you can sample base samples here if you use the same rng - # Duplicate the D dimension - N x D x D. - base_dist_samples = utils.tile_second_to_last_dim(dist_samples) - positive = utils.set_diags(base_dist_samples, positive_diag) - negative = utils.set_diags(base_dist_samples, negative_diag) - - c = np.sqrt(2 * np.pi) * std # D - # Apply function. We apply the function to each element of N x D x D. - # We apply a function that takes a sample and returns one number, so the - # output will be N x D (which is what we want, batch by dimension). - # We apply a function in parallel to the batch. - # Broadcast the division. - vmaped_function = jax.vmap(jax.vmap(function, 1, 0)) - grads = (vmaped_function(positive) - vmaped_function(negative)) / c - - chex.assert_shape(grads, (num_samples,) + std.shape) - return grads - - -def measure_valued_estimation_std( - function: Callable[[chex.Array], float], - dist: Any, - rng: chex.PRNGKey, - num_samples: int, - coupling: bool = True) -> chex.Array: - """Measure valued grads of a Gaussian expectation of `function` wrt the std. - - Args: - function: Function f(x) for which to estimate grads_{std} E_dist f(x). - The function takes in one argument (a sample from the distribution) and - returns a floating point value. - dist: a distribution on which we can call `sample`. - rng: a PRNGKey key. - num_samples: Int, the number of samples used to compute the grads. - coupling: A boolean. Whether or not to use coupling for the positive and - negative samples. Recommended: True, as this reduces variance. - - Returns: - A `num_samples x D` vector containing the estimates of the gradients - obtained for each sample. The mean of this vector can be used to update - the scale parameter. The entire vector can be used to assess estimator - variance. - """ - mean, log_std = dist.params - std = jnp.exp(log_std) - - dist_samples = dist.sample((num_samples,), seed=rng) - - pos_rng, neg_rng = jax.random.split(rng) - - # The only difference between mean and std gradients is what we sample. - pos_sample = jax.random.double_sided_maxwell( - pos_rng, loc=0.0, scale=1.0, shape=dist_samples.shape) - if coupling: - unif_rvs = jax.random.uniform(neg_rng, dist_samples.shape) - neg_sample = unif_rvs * pos_sample - else: - neg_sample = jax.random.normal(neg_rng, dist_samples.shape) - - # Both need to be positive in the case of the scale. - # N x D - positive_diag = mean + std * pos_sample - # N x D - negative_diag = mean + std * neg_sample - - # NOTE: you can sample base samples here if you use the same rng - # Duplicate the D dimension - N x D x D. 
- base_dist_samples = utils.tile_second_to_last_dim(dist_samples) - positive = utils.set_diags(base_dist_samples, positive_diag) - negative = utils.set_diags(base_dist_samples, negative_diag) - - # Different C for the scale - c = std # D - # Apply function. We apply the function to each element of N x D x D. - # We apply a function that takes a sample and returns one number, so the - # output will be N x D (which is what we want, batch by dimension). - # We apply a function in parallel to the batch. - # Broadcast the division. - vmaped_function = jax.vmap(jax.vmap(function, 1, 0)) - grads = (vmaped_function(positive) - vmaped_function(negative)) / c - - chex.assert_shape(grads, (num_samples,) + std.shape) - return grads - diff --git a/optax_add_eve/_src/stochastic_gradient_estimators_test.py b/optax_add_eve/_src/stochastic_gradient_estimators_test.py deleted file mode 100644 index e89532d4..00000000 --- a/optax_add_eve/_src/stochastic_gradient_estimators_test.py +++ /dev/null @@ -1,371 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for `stochastic_gradient_estimators.py`.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import chex -import jax -import jax.numpy as jnp -import numpy as np - -from optax_add_eve._src import stochastic_gradient_estimators as sge -from optax_add_eve._src import utils - - -# Set seed for deterministic sampling. 
-np.random.seed(42) - - -_estimator_to_num_samples = { - sge.score_function_jacobians: 5 * 10**5, - sge.measure_valued_jacobians: 10**5, - sge.pathwise_jacobians: 5 * 10**4, -} - -_weighted_estimator_to_num_samples = { - sge.score_function_jacobians: 5 * 10**6, - sge.measure_valued_jacobians: 5 * 10**5, - sge.pathwise_jacobians: 5 * 10**4, -} - - -def _ones(dims): - return jnp.ones(shape=(dims), dtype=jnp.float32) - - -def _assert_equal(actual, expected, rtol=1e-2, atol=1e-2): - """Asserts that arrays are equal.""" - # Note: assert_allclose does not check shapes - chex.assert_equal_shape((actual, expected)) - - # We get around the bug https://github.com/numpy/numpy/issues/13801 - zero_indices = np.argwhere(expected == 0) - if not np.all(np.abs(actual[zero_indices]) <= atol): - raise AssertionError(f'Larger than {atol} diff in {actual[zero_indices]}') - - non_zero_indices = np.argwhere(expected != 0) - np.testing.assert_allclose( - np.asarray(actual)[non_zero_indices], - expected[non_zero_indices], rtol, atol) - - -def _estimator_variant(variant, estimator): - return variant(estimator, static_argnums=(0, 2, 4)) - - -def _measure_valued_variant(variant): - return variant( - sge.measure_valued_jacobians, - static_argnums=(0, 2, 4, 5)) - - -class GradientEstimatorsTest(chex.TestCase): - - @chex.all_variants - @parameterized.named_parameters( - chex.params_product([ - ('_score_function_jacobians', sge.score_function_jacobians), - ('_pathwise_jacobians', sge.pathwise_jacobians), - ('_measure_valued_jacobians', sge.measure_valued_jacobians), - ], [ - ('0.1', 0.1), - ('0.5', 0.5), - ('0.9', 0.9), - ], - named=True)) - def testConstantFunction(self, estimator, constant): - data_dims = 3 - num_samples = _estimator_to_num_samples[estimator] - - effective_mean = 1.5 - mean = effective_mean * _ones(data_dims) - - effective_log_scale = 0.0 - log_scale = effective_log_scale * _ones(data_dims) - rng = jax.random.PRNGKey(1) - - jacobians = _estimator_variant(self.variant, estimator)( - lambda x: jnp.array(constant), [mean, log_scale], - utils.multi_normal, rng, num_samples) - - # Average over the number of samples. 
- mean_jacobians = jacobians[0] - chex.assert_shape(mean_jacobians, (num_samples, data_dims)) - mean_grads = np.mean(mean_jacobians, axis=0) - expected_mean_grads = np.zeros(data_dims, dtype=np.float32) - - log_scale_jacobians = jacobians[1] - chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) - log_scale_grads = np.mean(log_scale_jacobians, axis=0) - expected_log_scale_grads = np.zeros(data_dims, dtype=np.float32) - - _assert_equal(mean_grads, expected_mean_grads, atol=5e-3) - _assert_equal(log_scale_grads, expected_log_scale_grads, atol=5e-3) - - @chex.all_variants - @parameterized.named_parameters( - chex.params_product([ - ('_score_function_jacobians', sge.score_function_jacobians), - ('_pathwise_jacobians', sge.pathwise_jacobians), - ('_measure_valued_jacobians', sge.measure_valued_jacobians), - ], [ - ('0.5_-1.', 0.5, -1.), - ('0.7_0.0)', 0.7, 0.0), - ('0.8_0.1', 0.8, 0.1), - ], - named=True)) - def testLinearFunction(self, estimator, effective_mean, effective_log_scale): - data_dims = 3 - num_samples = _estimator_to_num_samples[estimator] - rng = jax.random.PRNGKey(1) - - mean = effective_mean * _ones(data_dims) - log_scale = effective_log_scale * _ones(data_dims) - - jacobians = _estimator_variant(self.variant, estimator)( - np.sum, [mean, log_scale], - utils.multi_normal, rng, num_samples) - - mean_jacobians = jacobians[0] - chex.assert_shape(mean_jacobians, (num_samples, data_dims)) - mean_grads = np.mean(mean_jacobians, axis=0) - expected_mean_grads = np.ones(data_dims, dtype=np.float32) - - log_scale_jacobians = jacobians[1] - chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) - log_scale_grads = np.mean(log_scale_jacobians, axis=0) - expected_log_scale_grads = np.zeros(data_dims, dtype=np.float32) - - _assert_equal(mean_grads, expected_mean_grads) - _assert_equal(log_scale_grads, expected_log_scale_grads) - - @chex.all_variants - @parameterized.named_parameters( - chex.params_product([ - ('_score_function_jacobians', sge.score_function_jacobians), - ('_pathwise_jacobians', sge.pathwise_jacobians), - ('_measure_valued_jacobians', sge.measure_valued_jacobians), - ], [ - ('1.0_0.3', 1.0, 0.3), - ], - named=True)) - def testQuadraticFunction( - self, estimator, effective_mean, effective_log_scale): - data_dims = 3 - num_samples = _estimator_to_num_samples[estimator] - rng = jax.random.PRNGKey(1) - - mean = effective_mean * _ones(data_dims) - log_scale = effective_log_scale * _ones(data_dims) - - jacobians = _estimator_variant(self.variant, estimator)( - lambda x: np.sum(x**2) / 2, [mean, log_scale], - utils.multi_normal, rng, num_samples) - - mean_jacobians = jacobians[0] - chex.assert_shape(mean_jacobians, (num_samples, data_dims)) - mean_grads = np.mean(mean_jacobians, axis=0) - expected_mean_grads = effective_mean * np.ones( - data_dims, dtype=np.float32) - - log_scale_jacobians = jacobians[1] - chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) - log_scale_grads = np.mean(log_scale_jacobians, axis=0) - expected_log_scale_grads = np.exp(2 * effective_log_scale) * np.ones( - data_dims, dtype=np.float32) - - _assert_equal(mean_grads, expected_mean_grads, atol=5e-2) - _assert_equal(log_scale_grads, expected_log_scale_grads, atol=5e-2) - - @chex.all_variants - @parameterized.named_parameters( - chex.params_product([ - ('_score_function_jacobians', sge.score_function_jacobians), - ('_pathwise_jacobians', sge.pathwise_jacobians), - ('_measure_valued_jacobians', sge.measure_valued_jacobians), - ], [ - ('case_1', [1.0, 2.0, 3.], [-1., 0.3, 
-2.], [1., 1., 1.]), - ('case_2', [1.0, 2.0, 3.], [-1., 0.3, -2.], [4., 2., 3.]), - ('case_3', [1.0, 2.0, 3.], [0.1, 0.2, 0.1], [10., 5., 1.]), - ], - named=True)) - def testWeightedLinear( - self, estimator, effective_mean, effective_log_scale, weights): - num_samples = _weighted_estimator_to_num_samples[estimator] - rng = jax.random.PRNGKey(1) - - mean = jnp.array(effective_mean) - log_scale = jnp.array(effective_log_scale) - weights = jnp.array(weights) - - data_dims = len(effective_mean) - - function = lambda x: jnp.sum(x * weights) - jacobians = _estimator_variant(self.variant, estimator)( - function, [mean, log_scale], - utils.multi_normal, rng, num_samples) - - mean_jacobians = jacobians[0] - chex.assert_shape(mean_jacobians, (num_samples, data_dims)) - mean_grads = np.mean(mean_jacobians, axis=0) - - log_scale_jacobians = jacobians[1] - chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) - log_scale_grads = np.mean(log_scale_jacobians, axis=0) - - expected_mean_grads = weights - expected_log_scale_grads = np.zeros(data_dims, dtype=np.float32) - - _assert_equal(mean_grads, expected_mean_grads, atol=5e-2) - _assert_equal(log_scale_grads, expected_log_scale_grads, atol=5e-2) - - @chex.all_variants - @parameterized.named_parameters( - chex.params_product([ - ('_score_function_jacobians', sge.score_function_jacobians), - ('_pathwise_jacobians', sge.pathwise_jacobians), - ('_measure_valued_jacobians', sge.measure_valued_jacobians), - ], [ - ('case_1', [1.0, 2.0, 3.], [-1., 0.3, -2.], [1., 1., 1.]), - ('case_2', [1.0, 2.0, 3.], [-1., 0.3, -2.], [4., 2., 3.]), - ('case_3', [1.0, 2.0, 3.], [0.1, 0.2, 0.1], [3., 5., 1.]), - ], - named=True)) - def testWeightedQuadratic( - self, estimator, effective_mean, effective_log_scale, weights): - num_samples = _weighted_estimator_to_num_samples[estimator] - rng = jax.random.PRNGKey(1) - - mean = jnp.array(effective_mean, dtype=jnp.float32) - log_scale = jnp.array(effective_log_scale, dtype=jnp.float32) - weights = jnp.array(weights, dtype=jnp.float32) - - data_dims = len(effective_mean) - - function = lambda x: jnp.sum(x * weights) ** 2 - jacobians = _estimator_variant(self.variant, estimator)( - function, [mean, log_scale], - utils.multi_normal, rng, num_samples) - - mean_jacobians = jacobians[0] - chex.assert_shape(mean_jacobians, (num_samples, data_dims)) - mean_grads = np.mean(mean_jacobians, axis=0) - - log_scale_jacobians = jacobians[1] - chex.assert_shape(log_scale_jacobians, (num_samples, data_dims)) - log_scale_grads = np.mean(log_scale_jacobians, axis=0) - - expected_mean_grads = 2 * weights * np.sum(weights * mean) - effective_scale = np.exp(log_scale) - expected_scale_grads = 2 * weights ** 2 * effective_scale - expected_log_scale_grads = expected_scale_grads * effective_scale - - _assert_equal(mean_grads, expected_mean_grads, atol=1e-1, rtol=1e-1) - _assert_equal( - log_scale_grads, expected_log_scale_grads, atol=1e-1, rtol=1e-1) - - @chex.all_variants - @parameterized.named_parameters( - chex.params_product( - [ - ('_sum_cos_x', [1.0], [1.0], lambda x: jnp.sum(jnp.cos(x))), - # Need to ensure that the mean is not too close to 0. 
- ('_sum_log_x', [10.0], [0.0], lambda x: jnp.sum(jnp.log(x))), - ('_sum_cos_2x', [1.0, 2.0], [1.0, -2 - ], lambda x: jnp.sum(jnp.cos(2 * x))), - ('_cos_sum_2x', [1.0, 2.0], [1.0, -2 - ], lambda x: jnp.cos(jnp.sum(2 * x))), - ], - [ - ('coupling', True), - ('nocoupling', False), - ], - named=True)) - def testNonPolynomialFunctionConsistencyWithPathwise(self, effective_mean, - effective_log_scale, - function, coupling): - num_samples = 10**5 - rng = jax.random.PRNGKey(1) - measure_rng, pathwise_rng = jax.random.split(rng) - - mean = jnp.array(effective_mean, dtype=jnp.float32) - log_scale = jnp.array(effective_log_scale, dtype=jnp.float32) - data_dims = len(effective_mean) - - measure_valued_jacobians = _measure_valued_variant(self.variant)( - function, [mean, log_scale], - utils.multi_normal, measure_rng, num_samples, coupling) - - measure_valued_mean_jacobians = measure_valued_jacobians[0] - chex.assert_shape(measure_valued_mean_jacobians, (num_samples, data_dims)) - measure_valued_mean_grads = np.mean(measure_valued_mean_jacobians, axis=0) - - measure_valued_log_scale_jacobians = measure_valued_jacobians[1] - chex.assert_shape( - measure_valued_log_scale_jacobians, (num_samples, data_dims)) - measure_valued_log_scale_grads = np.mean( - measure_valued_log_scale_jacobians, axis=0) - - pathwise_jacobians = _estimator_variant( - self.variant, sge.pathwise_jacobians)(function, [mean, log_scale], - utils.multi_normal, pathwise_rng, - num_samples) - - pathwise_mean_jacobians = pathwise_jacobians[0] - chex.assert_shape(pathwise_mean_jacobians, (num_samples, data_dims)) - pathwise_mean_grads = np.mean(pathwise_mean_jacobians, axis=0) - - pathwise_log_scale_jacobians = pathwise_jacobians[1] - chex.assert_shape(pathwise_log_scale_jacobians, (num_samples, data_dims)) - pathwise_log_scale_grads = np.mean(pathwise_log_scale_jacobians, axis=0) - - _assert_equal( - pathwise_mean_grads, measure_valued_mean_grads, rtol=5e-1, atol=1e-1) - _assert_equal( - pathwise_log_scale_grads, measure_valued_log_scale_grads, - rtol=5e-1, atol=1e-1) - - -class MeasuredValuedEstimatorsTest(chex.TestCase): - - @chex.all_variants - @parameterized.parameters([True, False]) - def testRaisesErrorForNonGaussian(self, coupling): - num_samples = 10**5 - rng = jax.random.PRNGKey(1) - - function = lambda x: jnp.sum(x) ** 2 - - mean = jnp.array(0, dtype=jnp.float32) - log_scale = jnp.array(0., dtype=jnp.float32) - - class TestDist(): - - def __init__(self, params): - self._params = params - - def sample(self, n): - return np.zeros(n) - - with self.assertRaises(ValueError): - _measure_valued_variant(self.variant)( - function, [mean, log_scale], - TestDist, rng, num_samples, coupling) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/transform.py b/optax_add_eve/_src/transform.py deleted file mode 100644 index ba4037ac..00000000 --- a/optax_add_eve/_src/transform.py +++ /dev/null @@ -1,1206 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Gradient transformations.""" - -import functools -from typing import Any, Callable, NamedTuple, Optional, Union - -import chex -import jax -import jax.numpy as jnp - -from optax_add_eve._src import base -from optax_add_eve._src import clipping -from optax_add_eve._src import numerics -from optax_add_eve._src import utils -from optax_add_eve._src import wrappers - -# pylint:disable=no-value-for-parameter - -_abs_sq = numerics.abs_sq - - -class TraceState(NamedTuple): - """Holds an aggregation of past updates.""" - trace: base.Params - - -def trace( - decay: float, - nesterov: bool = False, - accumulator_dtype: Optional[Any] = None, -) -> base.GradientTransformation: - """Compute a trace of past updates. - - Note: `trace` and `ema` have very similar but distinct updates; - `trace = decay * trace + t`, while `ema = decay * ema + (1-decay) * t`. - Both are frequently found in the optimization literature. - - Args: - decay: Decay rate for the trace of past updates. - nesterov: Whether to use Nesterov momentum. - accumulator_dtype: Optional `dtype` to be used for the accumulator; if - `None` then the `dtype` is inferred from `params` and `updates`. - - Returns: - A `GradientTransformation` object. - """ - - accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) - - def init_fn(params): - return TraceState( - trace=jax.tree_util.tree_map( - lambda t: jnp.zeros_like(t, dtype=accumulator_dtype), params)) - - def update_fn(updates, state, params=None): - del params - f = lambda g, t: g + decay * t - new_trace = jax.tree_util.tree_map(f, updates, state.trace) - updates = ( - jax.tree_util.tree_map(f, updates, new_trace) if nesterov - else new_trace) - new_trace = utils.cast_tree(new_trace, accumulator_dtype) - return updates, TraceState(trace=new_trace) - - return base.GradientTransformation(init_fn, update_fn) - - -def update_moment(updates, moments, decay, order): - """Compute the exponential moving average of the `order`-th moment.""" - return jax.tree_util.tree_map( - lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments) - - -def update_infinity_moment(updates, moments, decay, eps): - """Compute the exponential moving average of the infinity norm.""" - return jax.tree_util.tree_map( - lambda g, t: jnp.maximum(jnp.abs(g) + eps, decay * t), updates, moments) - - -def update_moment_per_elem_norm(updates, moments, decay, order): - """Compute the EMA of the `order`-th moment of the element-wise norm.""" - - def orderth_norm(g): - if jnp.isrealobj(g): - return g ** order - else: - half_order = order / 2 - # JAX generates different HLO for int and float `order` - if half_order.is_integer(): - half_order = int(half_order) - return _abs_sq(g) ** half_order - - return jax.tree_util.tree_map( - lambda g, t: (1 - decay) * orderth_norm(g) + decay * t, updates, moments) - - -@functools.partial(jax.jit, inline=True) -def bias_correction(moment, decay, count): - """Performs bias correction. It becomes a no-op as count goes to infinity.""" - # The conversion to the data type of the moment ensures that bfloat16 remains - # bfloat16 in the optimizer state. This conversion has to be done after - # `bias_correction_` is calculated as calculating `decay**count` in low - # precision can result in it being rounded to 1 and subsequently a - # "division by zero" error. 
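  # For intuition: with decay=0.9 the factor below is 1 - 0.9 = 0.1 at count=1
  # (so the first-step moment is scaled up by 10x) and roughly 1 - 0.9**50 ~ 0.995
  # by count=50, i.e. the correction quickly approaches a no-op.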
- bias_correction_ = 1 - decay**count - - # Perform division in the original precision. - return jax.tree_util.tree_map( - lambda t: t / bias_correction_.astype(t.dtype), moment) - - -def _reject_complex(params): - if any(jnp.iscomplexobj(x) for x in jax.tree_util.tree_leaves(params)): - raise ValueError('This transformation does not support complex parameters.') - - -class EmaState(NamedTuple): - """Holds an exponential moving average of past updates.""" - count: chex.Array # shape=(), dtype=jnp.int32. - ema: base.Params - - -def ema( - decay: float, - debias: bool = True, - accumulator_dtype: Optional[Any] = None -) -> base.GradientTransformation: - """Compute an exponential moving average of past updates. - - Note: `trace` and `ema` have very similar but distinct updates; - `ema = decay * ema + (1-decay) * t`, while `trace = decay * trace + t`. - Both are frequently found in the optimization literature. - - Args: - decay: Decay rate for the exponential moving average. - debias: Whether to debias the transformed gradient. - accumulator_dtype: Optional `dtype` to used for the accumulator; if `None` - then the `dtype` is inferred from `params` and `updates`. - - Returns: - A `GradientTransformation` object. - """ - - accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) - - def init_fn(params): - return EmaState( - count=jnp.zeros([], jnp.int32), - ema=jax.tree_util.tree_map( - lambda t: jnp.zeros_like(t, dtype=accumulator_dtype), params)) - - def update_fn(updates, state, params=None): - del params - updates = new_ema = update_moment(updates, state.ema, decay, order=1) - count_inc = utils.safe_int32_increment(state.count) - if debias: - updates = bias_correction(new_ema, decay, count_inc) - state_ema = utils.cast_tree(new_ema, accumulator_dtype) - return updates, EmaState(count=count_inc, ema=state_ema) - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByRssState(NamedTuple): - """State holding the sum of gradient squares to date.""" - sum_of_squares: base.Updates - - -def scale_by_rss( - initial_accumulator_value: float = 0.1, - eps: float = 1e-7 -) -> base.GradientTransformation: - """Rescale updates by the root of the sum of all squared gradients to date. - - References: - [Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) - [McMahan et al., 2010](https://arxiv.org/abs/1002.4908) - - Args: - initial_accumulator_value: Starting value for accumulators, must be >= 0. - eps: A small floating point value to avoid zero denominator. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - sum_of_squares = jax.tree_util.tree_map( - lambda t: jnp.full_like(t, initial_accumulator_value), params) - return ScaleByRssState(sum_of_squares=sum_of_squares) - - def update_fn(updates, state, params=None): - del params - sum_of_squares = jax.tree_util.tree_map( - lambda g, t: _abs_sq(g) + t, updates, state.sum_of_squares) - inv_sqrt_g_square = jax.tree_util.tree_map( - lambda t: jnp.where(t > 0, jax.lax.rsqrt(t + eps), 0.0), sum_of_squares) - updates = jax.tree_util.tree_map( - lambda scale, g: scale * g, inv_sqrt_g_square, updates) - return updates, ScaleByRssState(sum_of_squares=sum_of_squares) - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByRmsState(NamedTuple): - """State for exponential root mean-squared (RMS)-normalized updates.""" - nu: base.Updates - - -def scale_by_rms( - decay: float = 0.9, - eps: float = 1e-8, - initial_scale: float = 0. 
-) -> base.GradientTransformation: - """Rescale updates by the root of the exp. moving avg of the square. - - References: - [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) - - Args: - decay: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - initial_scale: Initial value for second moment. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - nu = jax.tree_util.tree_map( - lambda n: jnp.full_like(n, initial_scale), params) # second moment - return ScaleByRmsState(nu=nu) - - def update_fn(updates, state, params=None): - del params - nu = update_moment_per_elem_norm(updates, state.nu, decay, 2) - updates = jax.tree_util.tree_map( - lambda g, n: g * jax.lax.rsqrt(n + eps), updates, nu) - return updates, ScaleByRmsState(nu=nu) - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByRStdDevState(NamedTuple): - """State for centered exponential moving average of squares of updates.""" - mu: base.Updates - nu: base.Updates - - -def scale_by_stddev( - decay: float = 0.9, - eps: float = 1e-8, - initial_scale: float = 0. -) -> base.GradientTransformation: - """Rescale updates by the root of the centered exp. moving average of squares. - - References: - [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) - - Args: - decay: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - initial_scale: Initial value for second moment. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - mu = jax.tree_util.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_util.tree_map( - lambda n: jnp.full_like(n, initial_scale), params) # second moment - return ScaleByRStdDevState(mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = update_moment(updates, state.mu, decay, 1) - nu = update_moment_per_elem_norm(updates, state.nu, decay, 2) - updates = jax.tree_util.tree_map( - lambda g, m, n: g * jax.lax.rsqrt(n - _abs_sq(m) + eps), - updates, mu, nu) - return updates, ScaleByRStdDevState(mu=mu, nu=nu) - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByAdamState(NamedTuple): - """State for the Adam algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - mu: base.Updates - nu: base.Updates - - -def scale_by_adam( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - mu_dtype: Optional[Any] = None, -) -> base.GradientTransformation: - """Rescale updates according to the Adam algorithm. - - References: - [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) - - Args: - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - mu_dtype: Optional `dtype` to be used for the first order accumulator; if - `None` then the `dtype is inferred from `params` and `updates`. - - Returns: - A `GradientTransformation` object. 
- """ - - mu_dtype = utils.canonicalize_dtype(mu_dtype) - - def init_fn(params): - mu = jax.tree_util.tree_map( # First moment - lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) - nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = update_moment(updates, state.mu, b1, 1) - nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) - count_inc = numerics.safe_int32_increment(state.count) - mu_hat = bias_correction(mu, b1, count_inc) - nu_hat = bias_correction(nu, b2, count_inc) - updates = jax.tree_util.tree_map( - lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat) - mu = utils.cast_tree(mu, mu_dtype) - return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByAmsgradState(NamedTuple): - """State for the AMSGrad algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - mu: base.Updates - nu: base.Updates - nu_max: base.Updates - - -def scale_by_amsgrad( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - mu_dtype: Optional[Any] = None, -) -> base.GradientTransformation: - """Rescale updates according to the AMSGrad algorithm. - - References: - [Reddi et al, 2018](https://openreview.net/forum?id=ryQu7f-RZ) - - Args: - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - mu_dtype: Optional `dtype` to be used for the first order accumulator; if - `None` then the `dtype is inferred from `params` and `updates`. - - Returns: - A `GradientTransformation` object. - """ - - mu_dtype = utils.canonicalize_dtype(mu_dtype) - - def init_fn(params): - mu = jax.tree_util.tree_map( # First moment - lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) - nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment - nu_max = jax.tree_util.tree_map(jnp.zeros_like, params) - return ScaleByAmsgradState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, - nu_max=nu_max) - - def update_fn(updates, state, params=None): - del params - mu = update_moment(updates, state.mu, b1, 1) - nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) - count_inc = numerics.safe_int32_increment(state.count) - mu_hat = bias_correction(mu, b1, count_inc) - nu_hat = bias_correction(nu, b2, count_inc) - nu_max = jax.tree_util.tree_map(jnp.maximum, state.nu_max, nu_hat) - updates = jax.tree_util.tree_map( - lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_max) - mu = utils.cast_tree(mu, mu_dtype) - return updates, ScaleByAmsgradState(count=count_inc, mu=mu, nu=nu, - nu_max=nu_max) - - return base.GradientTransformation(init_fn, update_fn) - - -def scale_by_adamax( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8 -) -> base.GradientTransformation: - """Rescale updates according to the Adamax algorithm. 
-
-  References:
-    [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
-
-  Args:
-    b1: Decay rate for the exponentially weighted average of grads.
-    b2: Decay rate for the exponentially weighted maximum of grads.
-    eps: Term added to the denominator to improve numerical stability.
-
-  Returns:
-    A `GradientTransformation` object.
-  """
-
-  def init_fn(params):
-    mu = jax.tree_util.tree_map(jnp.zeros_like, params)  # First moment
-    nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Infinite moment
-    return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)
-
-  def update_fn(updates, state, params=None):
-    del params
-    count_inc = numerics.safe_int32_increment(state.count)
-    mu = update_moment(updates, state.mu, b1, 1)
-    nu = update_infinity_moment(updates, state.nu, b2, eps)
-    # Bias correction for mean. No bias correction needed for infinity moment.
-    mu_hat = bias_correction(mu, b1, count_inc)
-    updates = jax.tree_util.tree_map(lambda m, v: m / v, mu_hat, nu)
-    return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)
-
-  return base.GradientTransformation(init_fn, update_fn)
-
-
-class ScaleByEveState(NamedTuple):
-  """State for the Eve algorithm."""
-  count: chex.Array  # shape=(), dtype=jnp.int32.
-  mu: base.Updates
-  nu: base.Updates
-  d: float
-  f_prev: float
-
-
-def scale_by_eve(b1: float = 0.9,
-                 b2: float = 0.999,
-                 b3: float = 0.999,
-                 c: float = 10.,
-                 eps: float = 1e-8,
-                 f_star: float = 0.,
-                 mu_dtype: Optional[Any] = None,
-) -> base.GradientTransformation:
-  """Rescale updates according to the Eve algorithm.
-
-  References:
-    [Hayashi et al, 2018](https://arxiv.org/abs/1611.01505)
-
-  Args:
-    b1: The exponential decay rate to track the first moment of past gradients.
-    b2: The exponential decay rate to track the second moment of past gradients.
-    b3: The exponential decay rate to track the sub-optimality.
-    c: The clipping limit to prevent extreme global learning rate changes.
-    eps: A small constant applied to the denominator outside of the square root
-      (as in the Adam paper) to avoid dividing by zero when rescaling.
-    f_star: An estimate of the global minimum of the loss.
-    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
-      `None` then the `dtype` is inferred from `params` and `updates`.
-
-  Returns:
-    An (init_fn, update_fn) tuple.
-  """
-  mu_dtype = utils.canonicalize_dtype(mu_dtype)
-
-  def init_fn(params):
-    mu = jax.tree_util.tree_map(  # First moment
-        lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
-    nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
-    return ScaleByEveState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f_prev=1.)
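  # Schematically, `update_fn` below computes a running sub-optimality ratio
  #   d_new   = |f_t - f_{t-1}| / (min(f_t, f_{t-1}) - f_star)
  #   d_tilde = clip(d_new, 1/c, c)
  #   d_t     = b3 * d_{t-1} + (1 - b3) * d_tilde      (with d_1 = 1)
  # and divides the Adam-style step m_hat / (sqrt(v_hat) + eps) by d_t.
  #
  # A minimal usage sketch (hypothetical driver code; note that the third
  # argument passed to `update` is the current loss, not `params`):
  #   tx = scale_by_eve()
  #   opt_state = tx.init(params)
  #   updates, opt_state = tx.update(grads, opt_state, loss)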
-
-  def update_fn(updates: base.Updates, state: ScaleByEveState, f: float):
-    """Applies the Eve update.
-
-    Eve requires an additional argument: the loss for the current iteration,
-    f = f_t. `ScaleByEveState` holds the loss from the previous iteration,
-    state.f_prev = f_{t-1}.
-    """
-    mu = update_moment(updates, state.mu, b1, 1)
-    nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
-    count_inc = numerics.safe_int32_increment(state.count)
-    mu_hat = jax.tree_util.tree_map(lambda m: jnp.asarray(m / (1-b1)), mu)
-    nu_hat = jax.tree_util.tree_map(lambda v: jnp.asarray(v / (1-b2)), nu)
-    # The sub-optimality ratio is only meaningful from the second step onwards;
-    # `jnp.where` keeps d = 1 on the first step and remains jit-compatible.
-    d_new = jnp.abs(f - state.f_prev) / (jnp.minimum(f, state.f_prev) - f_star)
-    d_tilde = jnp.clip(d_new, 1 / c, c)
-    d = jnp.where(count_inc > 1, b3 * state.d + (1 - b3) * d_tilde, 1.)
-    updates = jax.tree_util.tree_map(
-        lambda m, v: m / (jnp.sqrt(v) + eps) / d, mu_hat, nu_hat)
-    mu = utils.cast_tree(mu, mu_dtype)
-    return updates, ScaleByEveState(count=count_inc, mu=mu, nu=nu, d=d, f_prev=f)
-
-  return base.GradientTransformation(init_fn, update_fn)
-
-
-ScaleState = base.EmptyState
-
-
-def scale(
-    step_size: float
-) -> base.GradientTransformation:
-  """Scale updates by some fixed scalar `step_size`.
-
-  Args:
-    step_size: A scalar corresponding to a fixed scaling factor for updates.
-
-  Returns:
-    A `GradientTransformation` object.
-  """
-
-  def init_fn(params):
-    del params
-    return ScaleState()
-
-  def update_fn(updates, state, params=None):
-    del params
-    updates = jax.tree_util.tree_map(lambda g: step_size * g, updates)
-    return updates, state
-
-  return base.GradientTransformation(init_fn, update_fn)
-
-
-def scale_by_param_block_norm(
-    min_scale: float = 1e-3
-) -> base.GradientTransformation:
-  """Scale updates for each param block by the norm of that block's parameters.
-
-  A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix
-  (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.
-
-  Args:
-    min_scale: Minimum scaling factor.
-
-  Returns:
-    A `GradientTransformation` object.
-  """
-
-  def init_fn(params):
-    del params
-    return base.EmptyState()
-
-  def update_fn(updates, state, params):
-    if params is None:
-      raise ValueError(base.NO_PARAMS_MSG)
-    updates = jax.tree_util.tree_map(
-        lambda u, p: u * numerics.safe_norm(p, min_scale),
-        updates, params)
-    return updates, state
-
-  return base.GradientTransformation(init_fn, update_fn)
-
-
-def scale_by_param_block_rms(
-    min_scale: float = 1e-3
-) -> base.GradientTransformation:
-  """Scale updates by rms of the gradient for each param vector or matrix.
-
-  A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix
-  (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.
-
-  Args:
-    min_scale: Minimum scaling factor.
-
-  Returns:
-    A `GradientTransformation` object.
-  """
-
-  def init_fn(params):
-    del params
-    return base.EmptyState()
-
-  def update_fn(updates, state, params):
-    if params is None:
-      raise ValueError(base.NO_PARAMS_MSG)
-    updates = jax.tree_util.tree_map(
-        lambda u, p: u * numerics.safe_root_mean_squares(p, min_scale),
-        updates, params)
-    return updates, state
-
-  return base.GradientTransformation(init_fn, update_fn)
-
-
-class ScaleByBeliefState(NamedTuple):
-  """State for the rescaling by AdaBelief algorithm."""
-  count: chex.Array  # shape=(), dtype=jnp.int32.
- mu: base.Updates - nu: base.Updates - - -def scale_by_belief( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-16, - eps_root: float = 1e-16 -) -> base.GradientTransformation: - """Rescale updates according to the AdaBelief algorithm. - - References: - [Zhuang et al, 2020](https://arxiv.org/abs/2010.07468) - - Args: - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of variance of grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the second moment of the prediction error to - improve numerical stability. If backpropagating gradients through the - gradient transformation (e.g. for meta-learning), this must be non-zero. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - mu = jax.tree_util.tree_map(jnp.zeros_like, params) # First moment - s = jax.tree_util.tree_map(jnp.zeros_like, params) # Second Central moment - return ScaleByBeliefState(count=jnp.zeros([], jnp.int32), mu=mu, nu=s) - - def update_fn(updates, state, params=None): - del params - mu = update_moment(updates, state.mu, b1, 1) - prediction_error = jax.tree_util.tree_map( - lambda g, m: g-m, updates, state.mu) - nu = update_moment_per_elem_norm(prediction_error, state.nu, b2, 2) - nu = jax.tree_util.tree_map(lambda v: v + eps_root, nu) - count_inc = numerics.safe_int32_increment(state.count) - mu_hat = bias_correction(mu, b1, count_inc) - nu_hat = bias_correction(nu, b2, count_inc) - updates = jax.tree_util.tree_map( - lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat) - return updates, ScaleByBeliefState(count=count_inc, mu=mu, nu=nu) - - return base.GradientTransformation(init_fn, update_fn) - - -def scale_by_yogi( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-3, - eps_root: float = 0.0, - initial_accumulator_value: float = 1e-6 -) -> base.GradientTransformation: - """Rescale updates according to the Yogi algorithm. - - Supports complex numbers, see - https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 - - References: - [Zaheer et al, 2018](https://papers.nips.cc/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html) #pylint:disable=line-too-long - - Args: - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of variance of grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - initial_accumulator_value: The starting value for accumulators. - Only positive values are allowed. - - Returns: - A `GradientTransformation` object. 
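  For reference, the second-moment rule implemented in `update_fn` below is
    v_t = v_{t-1} - (1 - b2) * sign(v_{t-1} - g_t**2) * g_t**2,
  so unlike Adam's EMA the accumulator can shrink again when gradients become small.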
- """ - - def init_fn(params): - value_like = lambda p: jnp.full_like(p, initial_accumulator_value) - mu = jax.tree_util.tree_map(value_like, params) # First moment - nu = jax.tree_util.tree_map(value_like, params) # Second Central moment - return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = update_moment(updates, state.mu, b1, 1) - nu = jax.tree_util.tree_map( - lambda g, v: v - (1 - b2) * jnp.sign(v - _abs_sq(g)) * _abs_sq(g), - updates, state.nu) - count_inc = numerics.safe_int32_increment(state.count) - mu_hat = bias_correction(mu, b1, count_inc) - nu_hat = bias_correction(nu, b2, count_inc) - updates = jax.tree_util.tree_map( - lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat) - return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) - - return base.GradientTransformation(init_fn, update_fn) - - -def scale_by_radam( - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - threshold: float = 5.0 -) -> base.GradientTransformation: - """Rescale updates according to the Rectified Adam algorithm. - - References: - [Liu et al, 2020](https://arxiv.org/abs/1908.03265) - - Args: - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - threshold: Threshold for variance tractability. - - Returns: - A `GradientTransformation` object. - """ - - ro_inf = 2./(1 - b2) - 1 - def _radam_update(params): - ro = params[0] - mu_hat = params[1] - nu_hat = params[2] - r = jnp.sqrt((ro - 4)*(ro - 2)*ro_inf/((ro_inf - 4)*(ro_inf - 2)*ro)) - updates = jax.tree_util.tree_map( - lambda m, v: r*m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat) - return updates - - def init_fn(params): - mu = jax.tree_util.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = update_moment(updates, state.mu, b1, 1) - nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) - count_inc = numerics.safe_int32_increment(state.count) - b2t = b2**count_inc - ro = ro_inf - 2 * count_inc * b2t / (1 - b2t) - mu_hat = bias_correction(mu, b1, count_inc) - nu_hat = bias_correction(nu, b2, count_inc) - updates = jax.lax.cond( - ro >= threshold, _radam_update, lambda _: mu_hat, - (ro, mu_hat, nu_hat)) - return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) - - return base.GradientTransformation(init_fn, update_fn) - - -AddDecayedWeightsState = base.EmptyState - - -def add_decayed_weights( - weight_decay: float = 0.0, - mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None -) -> base.GradientTransformation: - """Add parameter scaled by `weight_decay`. - - Args: - weight_decay: A scalar weight decay rate. - mask: A tree with same structure as (or a prefix of) the params PyTree, - or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the transformation to, and `False` for those you want to skip. - - Returns: - A `GradientTransformation` object. 
- """ - - def init_fn(params): - del params - return AddDecayedWeightsState() - - def update_fn(updates, state, params): - if params is None: - raise ValueError(base.NO_PARAMS_MSG) - updates = jax.tree_util.tree_map( - lambda g, p: g + weight_decay * p, updates, params) - return updates, state - - # If mask is not `None`, apply mask to the gradient transformation. - # E.g. it is common to skip weight decay on bias units and batch stats. - if mask is not None: - return wrappers.masked( - base.GradientTransformation(init_fn, update_fn), mask) - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByScheduleState(NamedTuple): - """Maintains count for scale scheduling.""" - count: chex.Array # shape=(), dtype=jnp.int32 - - -def scale_by_schedule( - step_size_fn: base.Schedule -) -> base.GradientTransformation: - """Scale updates using a custom schedule for the `step_size`. - - Args: - step_size_fn: A function that takes an update count as input and proposes - the step_size to multiply the updates by. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return ScaleByScheduleState(count=jnp.zeros([], jnp.int32)) - - def update_fn(updates, state, params=None): - del params - step_size = step_size_fn(state.count) - updates = jax.tree_util.tree_map( - lambda g: jnp.array(step_size, dtype=g.dtype) * g, updates) - return updates, ScaleByScheduleState( - count=numerics.safe_int32_increment(state.count)) - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByTrustRatioState(NamedTuple): - """The scale and decay trust ratio transformation is stateless.""" - - -def scale_by_trust_ratio( - min_norm: float = 0.0, - trust_coefficient: float = 1., - eps: float = 0., -) -> base.GradientTransformation: - """Scale updates by trust ratio`. - - References: - [You et. al 2020](https://arxiv.org/abs/1904.00962) - - Args: - min_norm: Minimum norm for params and gradient norms; by default is zero. - trust_coefficient: A multiplier for the trust ratio. - eps: Additive constant added to the denominator for numerical stability. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return ScaleByTrustRatioState() - - def update_fn(updates, state, params): - if params is None: - raise ValueError(base.NO_PARAMS_MSG) - - def _scale_update(update, param): - - # Clip norms to minimum value, by default no clipping. - param_norm = numerics.safe_norm(param, min_norm) - update_norm = numerics.safe_norm(update, min_norm) - trust_ratio = trust_coefficient * param_norm / (update_norm + eps) - - # If no minimum norm clipping is used - # Set trust_ratio to 1 in case where parameters would never be updated. - zero_norm = jnp.logical_or(param_norm == 0., update_norm == 0.) - safe_trust_ratio = jnp.where( - zero_norm, jnp.array(1.0, dtype=param.dtype), trust_ratio) - - return update * safe_trust_ratio - - updates = jax.tree_util.tree_map(_scale_update, updates, params) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -class AddNoiseState(NamedTuple): - """State for adding gradient noise. Contains a count for annealing.""" - count: chex.Array - rng_key: chex.PRNGKey - - -def add_noise( - eta: float, - gamma: float, - seed: int -) -> base.GradientTransformation: - """Add gradient noise. - - References: - [Neelakantan et al, 2014](https://arxiv.org/abs/1511.06807) - - Args: - eta: Base variance of the gaussian noise added to the gradient. 
- gamma: Decay exponent for annealing of the variance. - seed: Seed for random number generation. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return AddNoiseState( - count=jnp.zeros([], jnp.int32), rng_key=jax.random.PRNGKey(seed)) - - def update_fn(updates, state, params=None): # pylint: disable=missing-docstring - del params - num_vars = len(jax.tree_util.tree_leaves(updates)) - treedef = jax.tree_util.tree_structure(updates) - count_inc = numerics.safe_int32_increment(state.count) - variance = eta / count_inc**gamma - standard_deviation = jnp.sqrt(variance) - all_keys = jax.random.split(state.rng_key, num=num_vars + 1) - noise = jax.tree_util.tree_map( - lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype), - updates, jax.tree_util.tree_unflatten(treedef, all_keys[1:])) - updates = jax.tree_util.tree_map( - lambda g, n: g + standard_deviation.astype(g.dtype) * n, - updates, noise) - return updates, AddNoiseState(count=count_inc, rng_key=all_keys[0]) - - return base.GradientTransformation(init_fn, update_fn) - - -class ApplyEvery(NamedTuple): - """Contains a counter and a gradient accumulator.""" - count: chex.Array - grad_acc: base.Updates - - -def apply_every( - k: int = 1 -) -> base.GradientTransformation: - """Accumulate gradients and apply them every k steps. - - Note that if this transformation is part of a chain, the states of the other - transformations will still be updated at every step. In particular, using - `apply_every` with a batch size of N/2 and k=2 is not necessarily equivalent - to not using `apply_every` with a batch size of N. If this equivalence is - important for you, consider using the `optax.MultiSteps`. - - Args: - k: Emit non-zero gradients every k steps, otherwise accumulate them. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - grad_acc = jax.tree_util.tree_map(jnp.zeros_like, params) - return ApplyEvery(count=jnp.zeros([], jnp.int32), grad_acc=grad_acc) - - def update_fn(updates, state, params=None): - del params - c = state.count % k - acc = c != 0 - grad_acc = jax.tree_util.tree_map( - lambda g, ga: acc * ga + g, updates, state.grad_acc) - emit = c == (k - 1) - updates = jax.tree_util.tree_map(lambda ga: emit * ga, grad_acc) - count_inc = numerics.safe_int32_increment(state.count) - return updates, ApplyEvery(count=count_inc % k, grad_acc=grad_acc) - - return base.GradientTransformation(init_fn, update_fn) - - -def _subtract_mean(g): - if len(g.shape) > 1: - return g - g.mean(tuple(range(1, len(g.shape))), keepdims=True) - else: - return g - - -CentralState = base.EmptyState - - -def centralize() -> base.GradientTransformation: - """Centralize gradients. - - References: - [Yong et al, 2020](https://arxiv.org/abs/2004.01461) - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - del params - return CentralState() - - def update_fn(updates, state, params=None): - del params - updates = jax.tree_util.tree_map(_subtract_mean, updates) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleBySM3State(NamedTuple): - """State for the SM3 algorithm.""" - mu: base.Updates - nu: base.Updates - - -def scale_by_sm3( - b1: float = 0.9, - b2: float = 1.0, - eps: float = 1e-8 -) -> base.GradientTransformation: - """Scale updates by sm3`. - - References: - [Anil et. al 2019](https://arxiv.org/abs/1901.11150) - - Args: - b1: Decay rate for the exponentially weighted average of grads. 
- b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - - Returns: - A `GradientTransformation` object. - """ - - def zeros_for_dim(p): - return [jnp.zeros([s]) for s in p.shape] - - def init_fn(params): - _reject_complex(params) - mu = jax.tree_util.tree_map(zeros_for_dim, params) - nu = jax.tree_util.tree_map(jnp.zeros_like, params) - return ScaleBySM3State(mu, nu) - - def _expanded_shape(shape, axis): - # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i. - # For eg: i = 1 returns [1, N, 1]. - rank = len(shape) - return [1] * axis + [shape[axis]] + [1] * (rank - axis - 1) - - def _new_accum(g, v): - coeffs = ((1.0 - b2) if b2 != 1.0 else 1.0, b2) - if g.ndim < 2: - return coeffs[0]*g**2 + coeffs[1]*v[0] - else: - return coeffs[0]*g**2 + coeffs[1]*functools.reduce(jnp.minimum, v) - - def _new_mu(g, i): - if g.ndim < 2: - return g - else: - return jnp.max(g, axis=other_axes(i, g.ndim)) - - def other_axes(idx, ndim): - return list(range(idx)) + list(range(idx+1, ndim)) - - def update_fn(updates, state, params=None): - del params - mu = jax.tree_util.tree_map( - lambda g, v: # pylint:disable=g-long-lambda - [jnp.reshape(v[i], _expanded_shape(g.shape, i)) for i in range(g.ndim)], - updates, state.mu) - accum = jax.tree_util.tree_map(_new_accum, updates, mu) - accum_inv_sqrt = jax.tree_util.tree_map( - lambda t: jnp.where(t > 0, jax.lax.rsqrt(t + eps), 0.0), accum) - up = jax.tree_util.tree_map(lambda g, a: g*a, updates, accum_inv_sqrt) - nu = update_moment(up, state.nu, b1, 1) - mu = jax.tree_util.tree_map( - lambda g: [_new_mu(g, i) for i in range(g.ndim)], accum) - - return nu, ScaleBySM3State(mu=mu, nu=nu) - - return base.GradientTransformation(init_fn, update_fn) - - -class ScaleByNovogradState(NamedTuple): - """State for Novograd.""" - count: chex.Array - mu: base.Updates - nu: base.Updates - - -def scale_by_novograd( - b1: float = 0.9, - b2: float = 0.25, - eps: float = 1e-8, - eps_root: float = 0.0, - weight_decay: float = 0.0, - mu_dtype: Optional[Any] = None, -) -> base.GradientTransformation: - """Computes NovoGrad updates. - - References: - [Ginsburg et al, 2019](https://arxiv.org/abs/1905.11286) - - Args: - b1: A decay rate for the exponentially weighted average of grads. - b2: A decay rate for the exponentially weighted average of squared grads. - eps: A term added to the denominator to improve numerical stability. - eps_root: A term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - weight_decay: A scalar weight decay rate. - mu_dtype: An optional `dtype` to be used for the first order accumulator; if - `None` then the `dtype is inferred from `params` and `updates`. - - Returns: - The corresponding `GradientTransformation`. 
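  For reference, unlike Adam the second moment tracked below is a single scalar
  per parameter leaf (an exponential moving average of ||g||**2), which is why
  `init_fn` initialises `nu` with 0.0 rather than with `jnp.zeros_like`.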
- """ - - mu_dtype = utils.canonicalize_dtype(mu_dtype) - - def init_fn(params): - mu = jax.tree_util.tree_map( # First moment - lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) - nu = jax.tree_util.tree_map(lambda _: 0.0, params) # Second moment - return ScaleByNovogradState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def nu_addition(grads): - return jnp.linalg.norm(grads)**2 - - def mu_addition(grads, params, nu): - return grads / (jnp.sqrt(nu + eps_root) + eps) + weight_decay * params - - def init_nu(grads, nu): - del nu - return jax.tree_util.tree_map(nu_addition, grads) - - def update_nu(grads, nu): - updates = jax.tree_util.tree_map(nu_addition, grads) - return update_moment(updates, nu, b2, 1) - - def init_mu(grads, params, mu, nu): - del mu - return jax.tree_util.tree_map(mu_addition, grads, params, nu) - - def update_mu(grads, params, mu, nu): - updates = jax.tree_util.tree_map(mu_addition, grads, params, nu) - return jax.tree_util.tree_map(lambda m, u: b1 * m + u, mu, updates) - - # Second moment - def update_fn(updates, state, params): - count_inc = numerics.safe_int32_increment(state.count) - - nu = jax.lax.cond(count_inc == 1, init_nu, update_nu, updates, state.nu) - - mu = jax.lax.cond(count_inc == 1, init_mu, update_mu, updates, params, - state.mu, nu) - - mu = utils.cast_tree(mu, mu_dtype) - updates = mu - return updates, ScaleByNovogradState(count=count_inc, mu=mu, nu=nu) - - return base.GradientTransformation(init_fn, update_fn) - - -def scale_by_optimistic_gradient(alpha: float = 1.0, - beta: float = 1.0 - ) -> base.GradientTransformation: - """Compute generalized optimistic gradients. - - References: - [Mokhtari et al, 2019](https://arxiv.org/abs/1901.08511v2) - - Args: - alpha: Coefficient for generalized optimistic gradient descent. - beta: Coefficient for negative momentum. - - Returns: - A `GradientTransformation` object. - """ - - def init_fn(params): - prev_grads = jax.tree_util.tree_map(jnp.zeros_like, params) - return TraceState(trace=prev_grads) - - def update_fn(updates, state, params=None): - del params - - new_updates = jax.tree_util.tree_map( - lambda grad_t, grad_tm1: (alpha + beta) * grad_t - beta * grad_tm1, - updates, state.trace) - return new_updates, TraceState(trace=updates) - - return base.GradientTransformation(init_fn, update_fn) - - -# TODO(b/183800387): remove legacy aliases. -# These legacy aliases are here for checkpoint compatibility -# To be removed once checkpoints have updated. -_safe_int32_increment = numerics.safe_int32_increment -safe_int32_increment = numerics.safe_int32_increment -AdditiveWeightDecayState = AddDecayedWeightsState -additive_weight_decay = add_decayed_weights -ClipState = clipping.ClipState -ClipByGlobalNormState = clipping.ClipByGlobalNormState diff --git a/optax_add_eve/_src/transform_test.py b/optax_add_eve/_src/transform_test.py deleted file mode 100644 index 8218c2d9..00000000 --- a/optax_add_eve/_src/transform_test.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - - -"""Tests for `transform.py`.""" - -from absl.testing import absltest -from absl.testing import parameterized - -import chex -import jax -import jax.numpy as jnp -import numpy as np - -from optax_add_eve._src import alias -from optax_add_eve._src import combine -from optax_add_eve._src import transform -from optax_add_eve._src import update - -STEPS = 50 -LR = 1e-2 - - -class TransformTest(parameterized.TestCase): - - def setUp(self): - super().setUp() - self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.])) - self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.])) - - @chex.all_variants - @parameterized.named_parameters([ - ('adam', transform.scale_by_adam), - ('adamax', transform.scale_by_adamax), - ('rmsprop', transform.scale_by_rms), - ('stddev', transform.scale_by_stddev), - ('trust_ratio', transform.scale_by_trust_ratio), - ('param_block_norm', transform.scale_by_param_block_norm), - ('param_block_rms', transform.scale_by_param_block_rms), - ]) - def test_scalers(self, scaler_constr): - params = self.init_params - - scaler = scaler_constr() - init_fn = self.variant(scaler.init) - transform_fn = self.variant(scaler.update) - - state = init_fn(params) - chex.assert_tree_all_finite(state) - - updates, state = transform_fn(self.per_step_updates, state, params) - chex.assert_tree_all_finite((params, updates, state)) - jax.tree_util.tree_map( - lambda *args: chex.assert_equal_shape(args), params, updates) - - @chex.all_variants - def test_add_decayed_weights(self): - # Define a transform that add decayed weights. - # We can define a mask either as a pytree, or as a function that - # returns the pytree. Below we define the pytree directly. - mask = (True, dict(a=True, b=False)) - tx = transform.add_decayed_weights(0.1, mask=mask) - # Define input updates and weights. - updates = ( - jnp.zeros((2,), dtype=jnp.float32), - dict( - a=jnp.zeros((2,), dtype=jnp.float32), - b=jnp.zeros((2,), dtype=jnp.float32),)) - weights = ( - jnp.ones((2,), dtype=jnp.float32), - dict( - a=jnp.ones((2,), dtype=jnp.float32), - b=jnp.ones((2,), dtype=jnp.float32),)) - # This mask means that we will add decayed weights to the first two - # terms in the input updates, but not to the last element. - expected_tx_updates = ( - 0.1*jnp.ones((2,), dtype=jnp.float32), - dict( - a=0.1*jnp.ones((2,), dtype=jnp.float32), - b=jnp.zeros((2,), dtype=jnp.float32),)) - # Apply transform - state = tx.init(weights) - transform_fn = self.variant(tx.update) - new_updates, _ = transform_fn(updates, state, weights) - # Assert output as expected. 
- chex.assert_trees_all_close(new_updates, expected_tx_updates) - - @chex.all_variants - def test_ema(self): - values = jnp.array([5.0, 7.0]) - decay = 0.9 - d = decay - - ema = transform.ema(decay=decay, debias=False) - state = ema.init(values[0]) # init to zeroes - - transform_fn = self.variant(ema.update) - mean, state = transform_fn(values[0], state) - np.testing.assert_allclose(mean, (1-d) * values[0], atol=1e-4) - - mean, state = transform_fn(values[1], state) - np.testing.assert_allclose( - mean, - (1 - d) * (values[1] + d * values[0]), atol=1e-2) - - @chex.all_variants - def test_ema_debias(self): - values = jnp.array([5.0, 7.0]) - decay = 0.9 - d = decay - - ema = transform.ema(decay=decay) - state = ema.init(values[0]) - - transform_fn = self.variant(ema.update) - mean, state = transform_fn(values[0], state) - np.testing.assert_allclose(mean, values[0], atol=1e-4) - - mean, state = transform_fn(values[1], state) - np.testing.assert_allclose( - mean, - ((1 - d) * values[1] + d * (1 - d) * values[0]) / (1 - d**2), - atol=1e-2) - # The state must not be debiased. - np.testing.assert_allclose( - state.ema, - (1 - d) * values[1] + d * (1 - d) * values[0], - atol=1e-2) - - @chex.all_variants - def test_update_infinity_moment(self): - values = jnp.array([5.0, 7.0]) - decay = 0.9 - d = decay - - transform_fn = self.variant(transform.update_infinity_moment) - - # identity if updating with itself (and positive decay) - np.testing.assert_allclose( - transform_fn(values, values, decay=d, eps=0.), - values, - atol=1e-4 - ) - # return (decayed) max when updating with zeros - np.testing.assert_allclose( - transform_fn(jnp.zeros_like(values), values, decay=d, eps=0.), - d * values, - atol=1e-4 - ) - # infinity norm takes absolute values - np.testing.assert_allclose( - transform_fn(-values, jnp.zeros_like(values), decay=d, eps=0.), - values, - atol=1e-4 - ) - # return at least `eps` - np.testing.assert_allclose( - transform_fn(jnp.zeros_like(values), jnp.zeros_like(values), - decay=d, eps=1e-2), - jnp.ones_like(values) * 1e-2, - atol=1e-4 - ) - - @chex.all_variants - def test_apply_every(self): - # The frequency of the application of sgd - k = 4 - zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.])) - - # optax sgd - optax_sgd_params = self.init_params - sgd = alias.sgd(LR, 0.0) - state_sgd = sgd.init(optax_sgd_params) - - # optax sgd plus apply every - optax_sgd_apply_every_params = self.init_params - sgd_apply_every = combine.chain( - transform.apply_every(k=k), - transform.trace(decay=0, nesterov=False), - transform.scale(-LR)) - state_sgd_apply_every = sgd_apply_every.init(optax_sgd_apply_every_params) - transform_fn = self.variant(sgd_apply_every.update) - - for i in range(STEPS): - # Apply a step of sgd - updates_sgd, state_sgd = sgd.update(self.per_step_updates, state_sgd) - optax_sgd_params = update.apply_updates(optax_sgd_params, updates_sgd) - - # Apply a step of sgd_apply_every - updates_sgd_apply_every, state_sgd_apply_every = transform_fn( - self.per_step_updates, state_sgd_apply_every) - optax_sgd_apply_every_params = update.apply_updates( - optax_sgd_apply_every_params, updates_sgd_apply_every) - - # Every k steps, check equivalence. - if i % k == k-1: - chex.assert_trees_all_close( - optax_sgd_apply_every_params, optax_sgd_params, - atol=1e-6, rtol=1e-5) - # Otherwise, check update is zero. 
- else: - chex.assert_trees_all_close( - updates_sgd_apply_every, zero_update, atol=0.0, rtol=0.0) - - def test_scale(self): - updates = self.per_step_updates - for i in range(1, STEPS + 1): - factor = 0.1 ** i - rescaler = transform.scale(factor) - # Apply rescaling. - scaled_updates, _ = rescaler.update(updates, None) - # Manually scale updates. - def rescale(t): - return t * factor # pylint:disable=cell-var-from-loop - manual_updates = jax.tree_util.tree_map(rescale, updates) - # Check the rescaled updates match. - chex.assert_trees_all_close(scaled_updates, manual_updates) - - @parameterized.named_parameters([ - ('1d', [1.0, 2.0], [1.0, 2.0]), - ('2d', [[1.0, 2.0], [3.0, 4.0]], [[-0.5, 0.5], [-0.5, 0.5]]), - ('3d', [[[1., 2.], [3., 4.]], - [[5., 6.], [7., 8.]]], [[[-1.5, -0.5], [0.5, 1.5]], - [[-1.5, -0.5], [0.5, 1.5]]]), - ]) - def test_centralize(self, inputs, outputs): - inputs = jnp.asarray(inputs) - outputs = jnp.asarray(outputs) - centralizer = transform.centralize() - centralized_inputs, _ = centralizer.update(inputs, None) - chex.assert_trees_all_close(centralized_inputs, outputs) - - @chex.all_variants - def test_add_noise_has_correct_variance_scaling(self): - # Prepare to compare noise with a rescaled unit-variance substitute. - eta = 0.3 - gamma = 0.55 - seed = 314 - noise = transform.add_noise(eta, gamma, seed) - noise_unit = transform.add_noise(1.0, 0.0, seed) - - params = self.init_params - state = noise.init(params) - state_unit = noise_unit.init(params) - - # Check the noise itself by adding it to zeros. - updates = jax.tree_util.tree_map(jnp.zeros_like, params) - - for i in range(1, STEPS + 1): - updates_i, state = self.variant(noise.update)(updates, state) - updates_i_unit, state_unit = noise_unit.update(updates, state_unit) - - scale = jnp.sqrt(eta / i**gamma) - - updates_i_rescaled = jax.tree_util.tree_map( - lambda g, s=scale: g * s, updates_i_unit) - - chex.assert_trees_all_close(updates_i, updates_i_rescaled, rtol=1e-4) - - def test_scale_by_optimistic_gradient(self): - - def f(params: jnp.ndarray) -> jnp.ndarray: - return params['x'] ** 2 - - initial_params = { - 'x': jnp.array(2.0) - } - - og = transform.scale_by_optimistic_gradient() - og_state = og.init(initial_params) - # Provide some arbitrary previous gradient. - og_state.trace['x'] = 1.5 - - g = jax.grad(f)(initial_params) - og_true = 2 * g['x'] - og_state.trace['x'] - og, og_state = og.update(g, og_state) - - # Compare transformation output with manually computed optimistic gradient. - chex.assert_trees_all_close(og_true, og['x']) - - @chex.all_variants - def test_bias_correction_bf16(self): - bias_correction_fn = self.variant(transform.bias_correction) - m = jnp.logspace(-10, 10, num=21, dtype=jnp.bfloat16) # 1e-10 ... 1e10 - for decay in (0.9, 0.99, 0.999, 0.9995): - for count in (1, 10, 100, 1000): - chex.assert_tree_all_finite( - bias_correction_fn(m, decay, count), - custom_message=f'failed with decay={decay}, count={count}') - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/update.py b/optax_add_eve/_src/update.py deleted file mode 100644 index ad88eee8..00000000 --- a/optax_add_eve/_src/update.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Apply transformed gradient updates to parameters.""" - -import chex -import jax -import jax.numpy as jnp - -from optax_add_eve._src import base - - -def apply_updates(params: base.Params, updates: base.Updates) -> base.Params: - """Applies an update to the corresponding parameters. - - This is a utility functions that applies an update to a set of parameters, and - then returns the updated parameters to the caller. As an example, the update - may be a gradient transformed by a sequence of`GradientTransformations`. This - function is exposed for convenience, but it just adds updates and parameters; - you may also apply updates to parameters manually, using `tree_map` - (e.g. if you want to manipulate updates in custom ways before applying them). - - Args: - params: a tree of parameters. - updates: a tree of updates, the tree structure and the shape of the leaf - nodes must match that of `params`. - - Returns: - Updated parameters, with same structure, shape and type as `params`. - """ - return jax.tree_util.tree_map( - lambda p, u: jnp.asarray(p + u).astype(jnp.asarray(p).dtype), - params, updates) - - -def incremental_update( - new_tensors: base.Params, - old_tensors: base.Params, - step_size: chex.Numeric -) -> base.Params: - """Incrementally update parameters via polyak averaging. - - Polyak averaging tracks an (exponential moving) average of the past - parameters of a model, for use at test/evaluation time. - - References: - [Polyak et al, 1991](https://epubs.siam.org/doi/10.1137/0330046) - - Args: - new_tensors: the latest value of the tensors. - old_tensors: a moving average of the values of the tensors. - step_size: the step_size used to update the polyak average on each step. - - Returns: - an updated moving average `step_size*new+(1-step_size)*old` of the params. - """ - return jax.tree_util.tree_map( - lambda new, old: step_size * new + (1.0 - step_size) * old, - new_tensors, old_tensors) - - -def periodic_update( - new_tensors: base.Params, - old_tensors: base.Params, - steps: chex.Array, - update_period: int -) -> base.Params: - """Periodically update all parameters with new values. - - A slow copy of a model's parameters, updated every K actual updates, can be - used to implement forms of self-supervision (in supervised learning), or to - stabilise temporal difference learning updates (in reinforcement learning). - - References: - [Grill et al., 2020](https://arxiv.org/abs/2006.07733) - [Mnih et al., 2015](https://arxiv.org/abs/1312.5602) - - Args: - new_tensors: the latest value of the tensors. - old_tensors: a slow copy of the model's parameters. - steps: number of update steps on the "online" network. - update_period: every how many steps to update the "target" network. - - Returns: - a slow copy of the model's parameters, updated every `update_period` steps. 
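  A minimal usage sketch (the names here are illustrative only):
    target_params = periodic_update(online_params, target_params,
                                    steps=step_count, update_period=1000)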
- """ - return jax.lax.cond( - jnp.mod(steps, update_period) == 0, - lambda _: new_tensors, - lambda _: old_tensors, - None) - diff --git a/optax_add_eve/_src/update_test.py b/optax_add_eve/_src/update_test.py deleted file mode 100644 index 73f57128..00000000 --- a/optax_add_eve/_src/update_test.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for `update.py`.""" - -from absl.testing import absltest - -import chex -import jax -import jax.numpy as jnp - -from optax_add_eve._src import update - - -class UpdateTest(chex.TestCase): - - @chex.all_variants - def test_apply_updates(self): - params = ({'a': jnp.ones((3, 2))}, jnp.ones((1,))) - grads = jax.tree_util.tree_map(lambda t: 2 * t, params) - exp_params = jax.tree_util.tree_map(lambda t: 3 * t, params) - new_params = self.variant(update.apply_updates)(params, grads) - - chex.assert_trees_all_close( - exp_params, new_params, atol=1e-10, rtol=1e-5) - - @chex.all_variants - def test_apply_updates_mixed_precision(self): - params = ( - {'a': jnp.ones((3, 2), dtype=jnp.bfloat16)}, - jnp.ones((1,), dtype=jnp.bfloat16)) - grads = jax.tree_util.tree_map( - lambda t: (2 * t).astype(jnp.float32), params) - new_params = self.variant(update.apply_updates)(params, grads) - - for leaf in jax.tree_util.tree_leaves(new_params): - assert leaf.dtype == jnp.bfloat16 - - @chex.all_variants - def test_incremental_update(self): - params_1 = ({'a': jnp.ones((3, 2))}, jnp.ones((1,))) - params_2 = jax.tree_util.tree_map(lambda t: 2 * t, params_1) - exp_params = jax.tree_util.tree_map(lambda t: 1.5 * t, params_1) - new_params = self.variant( - update.incremental_update)(params_2, params_1, 0.5) - - chex.assert_trees_all_close( - exp_params, new_params, atol=1e-10, rtol=1e-5) - - @chex.all_variants - def test_periodic_update(self): - params_1 = ({'a': jnp.ones((3, 2))}, jnp.ones((1,))) - params_2 = jax.tree_util.tree_map(lambda t: 2 * t, params_1) - - update_period = 5 - update_fn = self.variant(update.periodic_update) - - for j in range(3): - for i in range(1, update_period): - new_params = update_fn( - params_2, params_1, j*update_period+i, update_period) - chex.assert_trees_all_close( - params_1, new_params, atol=1e-10, rtol=1e-5) - - new_params = update_fn( - params_2, params_1, (j+1)*update_period, update_period) - chex.assert_trees_all_close( - params_2, new_params, atol=1e-10, rtol=1e-5) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/utils.py b/optax_add_eve/_src/utils.py deleted file mode 100644 index a61febff..00000000 --- a/optax_add_eve/_src/utils.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility functions for testing.""" - -from typing import Optional, Tuple, Sequence - -import chex -import jax -import jax.numpy as jnp -import jax.scipy.stats.norm as multivariate_normal - -from optax_add_eve._src import linear_algebra -from optax_add_eve._src import numerics - - -def tile_second_to_last_dim(a: chex.Array) -> chex.Array: - ones = jnp.ones_like(a) - a = jnp.expand_dims(a, axis=-1) - return jnp.expand_dims(ones, axis=-2) * a - - -def canonicalize_dtype( - dtype: Optional[chex.ArrayDType]) -> Optional[chex.ArrayDType]: - """Canonicalise a dtype, skip if None.""" - if dtype is not None: - return jax.dtypes.canonicalize_dtype(dtype) - return dtype - - -def cast_tree(tree: chex.ArrayTree, - dtype: Optional[chex.ArrayDType]) -> chex.ArrayTree: - """Cast tree to given dtype, skip if None.""" - if dtype is not None: - return jax.tree_util.tree_map(lambda t: t.astype(dtype), tree) - else: - return tree - - -def set_diags(a: chex.Array, new_diags: chex.Array) -> chex.Array: - """Set the diagonals of every DxD matrix in an input of shape NxDxD. - - Args: - a: rank 3, tensor NxDxD. - new_diags: NxD matrix, the new diagonals of each DxD matrix. - - Returns: - NxDxD tensor, with the same contents as `a` but with the diagonal - changed to `new_diags`. - """ - n, d, d1 = a.shape - assert d == d1 - - indices1 = jnp.repeat(jnp.arange(n), d) - indices2 = jnp.tile(jnp.arange(d), n) - indices3 = indices2 - - # Use numpy array setting - a = a.at[indices1, indices2, indices3].set(new_diags.flatten()) - return a - - -class MultiNormalDiagFromLogScale(): - """MultiNormalDiag which directly exposes its input parameters.""" - - def __init__(self, loc: chex.Array, log_scale: chex.Array): - self._log_scale = log_scale - self._scale = jnp.exp(log_scale) - self._mean = loc - self._param_shape = jax.lax.broadcast_shapes( - self._mean.shape, self._scale.shape) - - def sample(self, shape: Sequence[int], - seed: chex.PRNGKey) -> chex.Array: - sample_shape = tuple(shape) + self._param_shape - return jax.random.normal( - seed, shape=sample_shape) * self._scale + self._mean - - def log_prob(self, x: chex.Array) -> chex.Array: - log_prob = multivariate_normal.logpdf(x, loc=self._mean, scale=self._scale) - # Sum over parameter axes. - sum_axis = [-(i + 1) for i in range(len(self._param_shape))] - return jnp.sum(log_prob, axis=sum_axis) - - @property - def log_scale(self) -> chex.Array: - return self._log_scale - - @property - def params(self) -> Sequence[chex.Array]: - return [self._mean, self._log_scale] - - -def multi_normal(loc: chex.Array, - log_scale: chex.Array) -> MultiNormalDiagFromLogScale: - return MultiNormalDiagFromLogScale(loc=loc, log_scale=log_scale) - - -@jax.custom_vjp -def _scale_gradient(inputs: chex.ArrayTree, scale: float) -> chex.ArrayTree: - """Internal gradient scaling implementation.""" - del scale # Only used for the backward pass defined in _scale_gradient_bwd. 
- return inputs - - -def _scale_gradient_fwd(inputs: chex.ArrayTree, - scale: float) -> Tuple[chex.ArrayTree, float]: - return _scale_gradient(inputs, scale), scale - - -def _scale_gradient_bwd(scale: float, - g: chex.ArrayTree) -> Tuple[chex.ArrayTree, None]: - return (jax.tree_util.tree_map(lambda g_: g_ * scale, g), None) - - -_scale_gradient.defvjp(_scale_gradient_fwd, _scale_gradient_bwd) - - -def scale_gradient(inputs: chex.ArrayTree, scale: float) -> chex.ArrayTree: - """Scales gradients for the backwards pass. - - Args: - inputs: A nested array. - scale: The scale factor for the gradient on the backwards pass. - - Returns: - An array of the same structure as `inputs`, with scaled backward gradient. - """ - # Special case scales of 1. and 0. for more efficiency. - if scale == 1.: - return inputs - elif scale == 0.: - return jax.lax.stop_gradient(inputs) - else: - return _scale_gradient(inputs, scale) - - -# TODO(b/183800387): remove legacy aliases. -safe_norm = numerics.safe_norm -safe_int32_increment = numerics.safe_int32_increment -global_norm = linear_algebra.global_norm diff --git a/optax_add_eve/_src/utils_test.py b/optax_add_eve/_src/utils_test.py deleted file mode 100644 index 03f13d3d..00000000 --- a/optax_add_eve/_src/utils_test.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for `utils.py`.""" - -from unittest import mock - -from absl.testing import absltest -from absl.testing import parameterized - -import jax - -from optax_add_eve._src import utils - - -class ScaleGradientTest(parameterized.TestCase): - - @parameterized.product(inputs=[-1., 0., 1.], scale=[-0.5, 0., 0.5, 1., 2.]) - @mock.patch.object(jax.lax, 'stop_gradient', wraps=jax.lax.stop_gradient) - def test_scale_gradient(self, mock_sg, inputs, scale): - - def fn(inputs): - outputs = utils.scale_gradient(inputs, scale) - return outputs ** 2 - - grad = jax.grad(fn) - self.assertEqual(grad(inputs), 2 * inputs * scale) - if scale == 0.: - mock_sg.assert_called_once_with(inputs) - else: - self.assertFalse(mock_sg.called) - self.assertEqual(fn(inputs), inputs ** 2) - - @parameterized.product(scale=[-0.5, 0., 0.5, 1., 2.]) - def test_scale_gradient_pytree(self, scale): - - def fn(inputs): - outputs = utils.scale_gradient(inputs, scale) - outputs = jax.tree_util.tree_map(lambda x: x ** 2, outputs) - return sum(jax.tree_util.tree_leaves(outputs)) - - inputs = dict(a=-1., b=dict(c=(2.,), d=0.)) - - grad = jax.grad(fn) - grads = grad(inputs) - jax.tree_util.tree_map( - lambda i, g: self.assertEqual(g, 2 * i * scale), inputs, grads) - self.assertEqual( - fn(inputs), - sum(jax.tree_util.tree_leaves( - jax.tree_util.tree_map(lambda x: x**2, inputs)))) - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/_src/wrappers.py b/optax_add_eve/_src/wrappers.py deleted file mode 100644 index 3ae66026..00000000 --- a/optax_add_eve/_src/wrappers.py +++ /dev/null @@ -1,547 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Transformation wrappers.""" - -import functools -from typing import Any, Callable, NamedTuple, Optional, Tuple, Union - -import chex -import jax -from jax import lax -import jax.numpy as jnp -from jax.tree_util import tree_flatten -from jax.tree_util import tree_map -from jax.tree_util import tree_unflatten -import numpy as np -from optax_add_eve._src import base -from optax_add_eve._src import numerics -import typing_extensions - -Array = jnp.ndarray - - -def flatten( - inner: base.GradientTransformation -) -> base.GradientTransformation: - """Flattens parameters and gradients for init and update of inner transform. - - This can reduce the overhead of performing many calculations on lots of small - variables, at the cost of slightly increased memory usage. - - Args: - inner: Inner transformation to flatten inputs for. - - Returns: - New GradientTransformation. 
- """ - - def _flatten(params): - """Flattens and concatenates all tensors in params to a single vector.""" - params, _ = tree_flatten(params) - return jnp.concatenate([jnp.reshape(param, [-1]) for param in params]) - - def _unflatten(updates, flat): - """Extracts tensors from flat, using the structure and shapes of params.""" - updates_flat, treedef = tree_flatten(updates) - offsets = [] - for update in updates_flat: - size = np.prod(update.shape) - if offsets: - offsets.append(size + offsets[-1]) - else: - offsets.append(size) - del offsets[-1] - flat_split = jnp.split(flat, offsets) - reshaped = [ - jnp.reshape(flat_update, update.shape) - for flat_update, update in zip(flat_split, updates_flat) - ] - return tree_unflatten(treedef, reshaped) - - def init_fn(params): - flat = _flatten(params) - return inner.init(flat) - - def update_fn(updates, state, params=None): - if params is not None: - params = _flatten(params) - updates_flat, state = inner.update(_flatten(updates), state, params) - updates = _unflatten(updates, updates_flat) - return updates, state - - return base.GradientTransformation(init_fn, update_fn) - - -class ApplyIfFiniteState(NamedTuple): - """State of the `GradientTransformation` returned by `apply_if_finite`. - - Fields: - notfinite_count: Number of consecutive gradient updates containing an Inf or - a NaN. This number is reset to 0 whenever a gradient update without an Inf - or a NaN is done. - last_finite: Whether or not the last gradient update contained an Inf of a - NaN. - total_notfinite: Total number of gradient updates containing an Inf or - a NaN since this optimizer was initialised. This number is never reset. - inner_state: The state of the inner `GradientTransformation`. - """ - notfinite_count: jnp.array - last_finite: jnp.array - total_notfinite: jnp.array - inner_state: Any - - -def apply_if_finite( - inner: base.GradientTransformation, - max_consecutive_errors: int -) -> base.GradientTransformation: - """A function that wraps an optimizer to make it robust to a few NaNs or Infs. - - The purpose of this function is to prevent any optimization to happen if the - gradients contain NaNs or Infs. That is, when a NaN of Inf is detected in the - gradients, the wrapped optimizer ignores that gradient update. If the NaNs or - Infs persist after a given number of updates, the wrapped optimizer gives up - and accepts the update. - - Args: - inner: Inner transformation to be wrapped. - max_consecutive_errors: Maximum number of consecutive gradient updates - containing NaNs of Infs that the wrapped optimizer will ignore. After - that many ignored updates, the optimizer will give up and accept. - - Returns: - New GradientTransformation. 
- """ - - def init(params): - return ApplyIfFiniteState( - notfinite_count=jnp.zeros([], jnp.int32), - last_finite=jnp.array(True, jnp.bool_), - total_notfinite=jnp.zeros([], jnp.int32), - inner_state=inner.init(params)) - - def update(updates, state, params=None): - inner_state = state.inner_state - flat_updates = tree_flatten(updates)[0] - isfinite = jnp.all( - jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates])) - notfinite_count = jnp.where( - isfinite, jnp.zeros([], jnp.int32), - numerics.safe_int32_increment(state.notfinite_count)) - - def do_update(_): - return inner.update(updates, inner_state, params) - def reject_update(_): - return (tree_map(jnp.zeros_like, updates), inner_state) - - updates, new_inner_state = lax.cond( - jnp.logical_or(isfinite, notfinite_count > max_consecutive_errors), - do_update, reject_update, operand=None) - - return updates, ApplyIfFiniteState( - notfinite_count=notfinite_count, - last_finite=isfinite, - total_notfinite=jnp.where( - isfinite, state.total_notfinite, - numerics.safe_int32_increment(state.total_notfinite)), - inner_state=new_inner_state) - - return base.GradientTransformation(init=init, update=update) - - -def _zeros_tree_like(inp_tree): - return jax.tree_util.tree_map(jnp.zeros_like, inp_tree) - - -class MultiStepsState(NamedTuple): - """State of the `GradientTransformation` returned by `MultiSteps`. - - Fields: - mini_step: current mini-step counter. At an update, this either increases by - 1 or is reset to 0. - gradient_step: gradient step counter. This only increases after enough - mini-steps have been accumulated. - inner_opt_state: the state of the wrapped otpimiser. - acc_grads: accumulated gradients over multiple mini-steps. - skip_state: an arbitrarily nested tree of arrays. This is only - relevant when passing a `should_skip_update_fn` to `MultiSteps`. This - structure will then contain values for debugging and or monitoring. The - actual structure will vary depending on the choice of - `ShouldSkipUpdateFunction`. - """ - mini_step: Array - gradient_step: Array - inner_opt_state: Any - acc_grads: Any - skip_state: chex.ArrayTree = () - - -class ShouldSkipUpdateFunction(typing_extensions.Protocol): - - def __call__(self, updates: base.Updates, gradient_step: Array, - params: Optional[base.Params]) -> Tuple[Array, chex.ArrayTree]: - """Returns true to indicate that updates should be skipped in a multi-step. - - Args: - updates: The updates that the gradient transformation has proposed - to apply - gradient_step: The current gradient step (see - `MultiStepsState.gradient_step`). This can be used for example to reject - large gradients with an annealed maximum allowed gradient norm. - params: If known, the current parameter tree of the function being - transformed. - Returns: - A tuple: - * First element is an array with a single bool indicating whether or not - the updates should be applied. - * Second element is an arbitrarily nested structure of arrays that will be - stored in `MultiStepsState.skip_state`. The structure will vary from - function to function. Debugging info, or values to monitor, can be put - in this structure. - """ - - -def skip_not_finite( - updates: base.Updates, gradient_step: Array, - params: Optional[base.Params]) -> Tuple[Array, chex.ArrayTree]: - """Returns True iff any of the `updates` contains an inf or a NaN. - - Args: - updates: see `ShouldSkipUpdateFunction`. - gradient_step: see `ShouldSkipUpdateFunction`. - params: see `ShouldSkipUpdateFunction`. 
- - Returns: - A tuple: - * First element is a scalar array of type bool. - * Second element is a dictionary with keys: - - `should_skip`: True iff `updates` contains an inf or a NaN. - - `num_not_finite`: total number of inf and NaN found in `updates`. - """ - del gradient_step, params - all_is_finite = [jnp.sum(jnp.logical_not(jnp.isfinite(p))) - for p in jax.tree_util.tree_leaves(updates)] - num_not_finite = jnp.sum(jnp.array(all_is_finite)) - should_skip = num_not_finite > 0 - return should_skip, dict(should_skip=should_skip, - num_not_finite=num_not_finite) - - -def skip_large_updates(updates: base.Updates, - gradient_step: Array, - params: Optional[base.Params], - max_squared_norm: float) -> Tuple[Array, chex.ArrayTree]: - """Returns True if the global norm square of `updates` is small enough. - - Args: - updates: see `ShouldSkipUpdateFunction`. - gradient_step: see `ShouldSkipUpdateFunction`. - params: see `ShouldSkipUpdateFunction`. - max_squared_norm: only updates with a norm square strictly less than this - value will be accepted. - - Returns: - A tuple: - * First element is a scalar array of type bool. - * Second element is a dictionary with keys: - - `should_skip`: True iff square norm of `updates` is larger or equal than - `max_squared_norm`. - - `norm_squared`: overall norm square of the `updates`. - """ - del gradient_step, params - norm_sq = jnp.sum( - jnp.array([jnp.sum(p**2) for p in jax.tree_util.tree_leaves(updates)])) - # This will also return True if `norm_sq` is NaN. - should_skip = jnp.logical_not(norm_sq < max_squared_norm) - return should_skip, dict(should_skip=should_skip, norm_squared=norm_sq) - - -class MultiSteps: - """An optimizer wrapper to accumulate gradients over multiple steps. - - This wrapper collects together the updates passed to its `update` function - over consecutive steps until a given number of scheduled steps is reached. - In each of these intermediate steps, the returned value from the optimizer is - a tree of zeros of the same shape of the updates passed as input. - - Once the scheduled number of intermediate 'mini-steps' has been reached, the - gradients accumulated to the current time will be passed to the wrapped - optimizer's update function, (with the inner optimizer's state being updated - appropriately) and then returned to the caller. The wrapper's accumulated - gradients are then set back to zero and the process starts again. - - The number of mini-steps per gradient update is controlled by a function, and - it can vary over training. This offers a means of varying batch size over - training. - """ - - def __init__( - self, - opt: base.GradientTransformation, - every_k_schedule: Union[int, Callable[[Array], Array]], - use_grad_mean: bool = True, - should_skip_update_fn: Optional[ShouldSkipUpdateFunction] = None): - """Initialiser. - - Args: - opt: the wrapped optimizer. - every_k_schedule: an int or f a function. - * As a function, it returns how many mini-steps should be accumulated - in a single gradient step. Its only argument is the current - gradient step count. By varying the returned value, users can vary the - overall training batch size. - * If an `int`, this is the constant number of mini-steps per gradient - update. - use_grad_mean: if `True` (the default), gradients accumulated over - multiple mini-steps are averaged. Otherwise, they are summed. - should_skip_update_fn: if provided, this function is used to decide when - to accept or reject the updates from a mini-step. 
When a mini-step is - rejected, the inner state of `MultiSteps` is not updated. In other - words, it is as if this mini-step never happened. For example: - * to ignore updates containing inf or NaN, do - `should_skip_update_fn=skip_not_finite`; - * to ignore updates with a norm square larger then 42, do - `should_skip_update_fn=functools.partial(skip_large_updates, - max_norm_sq=42.)`. - Note that the optimizer's state `MultiStepsState` contains a field - `skip_state` in which debugging and monitoring information returned - by `should_skip_update_fn` is written. - """ - self._opt = opt - if isinstance(every_k_schedule, int): - self._every_k_schedule = lambda step: every_k_schedule - else: - self._every_k_schedule = every_k_schedule - self._use_grad_mean = use_grad_mean - - if self._use_grad_mean: - # Use Welford algorithm for numerically stable aggregation of mean. - self._acc_update = ( - lambda grad, acc, *, n_acc: acc + (grad - acc) / (n_acc + 1)) - else: - self._acc_update = lambda grad, acc, *, n_acc: grad + acc - - if should_skip_update_fn is None: - - def should_skip_update_fn(*unused_args, **unused_kwargs): - return jnp.array(False, dtype=jnp.bool_), () - - self._should_skip_update_fn = should_skip_update_fn - - @property - def inner_opt(self): - return self._opt - - def init(self, params: Any) -> MultiStepsState: - """Builds and returns initial `MultiStepsState`.""" - updates = _zeros_tree_like(params) - gradient_step = jnp.zeros([], dtype=jnp.int32) - _, skip_state = self._should_skip_update_fn(updates, gradient_step, params) - init_state = MultiStepsState( - mini_step=jnp.zeros([], dtype=jnp.int32), - gradient_step=gradient_step, - inner_opt_state=self._opt.init(params), - acc_grads=updates, - skip_state=skip_state) - return init_state - - def update(self, - updates: base.Updates, - state: MultiStepsState, - params: Optional[base.Params] = None - ) -> Tuple[base.Updates, MultiStepsState]: - """Accumulates gradients and proposes non-zero updates every `k_steps`.""" - k_steps = self._every_k_schedule(state.gradient_step) - acc_grads = jax.tree_util.tree_map( - functools.partial(self._acc_update, n_acc=state.mini_step), - updates, state.acc_grads) - - should_skip_update, skip_state = self._should_skip_update_fn( - updates, state.gradient_step, params) - - def final_step(args): - del args - final_updates, new_inner_state = self._opt.update( - acc_grads, state.inner_opt_state, params=params) - new_state = MultiStepsState( - mini_step=jnp.zeros([], dtype=jnp.int32), - gradient_step=numerics.safe_int32_increment(state.gradient_step), - inner_opt_state=new_inner_state, - acc_grads=_zeros_tree_like(acc_grads), - skip_state=skip_state) - return final_updates, new_state - - def mid_step(args): - del args - updates_shape_dtype, _ = jax.eval_shape( - self._opt.update, acc_grads, state.inner_opt_state, params=params) - mid_updates = jax.tree_util.tree_map( - lambda sd: jnp.zeros(sd.shape, sd.dtype), updates_shape_dtype) - new_state = MultiStepsState( - mini_step=numerics.safe_int32_increment(state.mini_step), - gradient_step=state.gradient_step, - inner_opt_state=state.inner_opt_state, - acc_grads=acc_grads, - skip_state=skip_state) - return mid_updates, new_state - - new_updates, new_state = jax.lax.cond( - state.mini_step < k_steps - 1, (), mid_step, (), final_step) - - if (should_skip_update.dtype, should_skip_update.shape) != (jnp.bool_, ()): - raise ValueError( - 'The `should_skip_update_fn` function should return a boolean scalar ' - f'array, but it returned an array of dtype 
{should_skip_update.dtype}' - f' and shape {should_skip_update.shape}') - - multi_state_when_skip = MultiStepsState( - mini_step=state.mini_step, - gradient_step=state.gradient_step, - inner_opt_state=state.inner_opt_state, - acc_grads=state.acc_grads, - skip_state=skip_state) - zero_updates = jax.tree_util.tree_map(jnp.zeros_like, updates) - new_updates, new_state = jax.lax.cond( - should_skip_update, - (), lambda args: (zero_updates, multi_state_when_skip), - (), lambda args: (new_updates, new_state)) - - return new_updates, new_state - - def has_updated(self, state: MultiStepsState) -> Array: - return jnp.logical_and(state.mini_step == 0, state.gradient_step > 0) - - def gradient_transformation(self) -> base.GradientTransformation: - return base.GradientTransformation(init=self.init, update=self.update) - - -class MaskedState(NamedTuple): - """Maintains inner transform state for masked transformations.""" - inner_state: Any - - -class MaskedNode(NamedTuple): - """A node used to mask out unspecified parts of a tree. - - This node is ignored when mapping functions across the tree e.g. using - `jax.tree_util.tree_map` since it is a container without children. It can - therefore be used to mask out parts of a tree. - """ - - -def masked( - inner: base.GradientTransformation, - mask: Union[base.PyTree, Callable[[base.Params], base.PyTree]] -) -> base.GradientTransformation: - """Mask updates so only some are transformed, the rest are passed through. - - For example, it is common to skip weight decay for BatchNorm scale and all - bias parameters. In many networks, these are the only parameters with only - one dimension. So, you may create a mask function to mask these out as - follows:: - - mask_fn = lambda p: jax.tree_util.tree_map(lambda x: x.ndim != 1, p) - weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask_fn) - - You may alternatively create the mask pytree upfront:: - - mask = jax.tree_util.tree_map(lambda x: x.ndim != 1, params) - weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask) - - For the ``inner`` transform, state will only be stored for the parameters that - have a mask value of ``True``. - - Args: - inner: Inner transformation to mask. - mask: a PyTree with same structure as (or a prefix of) the params PyTree, or - a Callable that returns such a pytree given the params/updates. The leaves - should be booleans, ``True`` for leaves/subtrees you want to apply the - transformation to, and ``False`` for those you want to skip. The mask must - be static for the gradient transformation to be jit-compilable. - - Returns: - New GradientTransformation wrapping ``inner``. 
- """ - def mask_pytree(pytree, mask_tree): - return tree_map(lambda m, p: p if m else MaskedNode(), mask_tree, pytree) - - def init_fn(params): - mask_tree = mask(params) if callable(mask) else mask - masked_params = mask_pytree(params, mask_tree) - return MaskedState(inner_state=inner.init(masked_params)) - - def update_fn(updates, state, params=None): - mask_tree = mask(updates) if callable(mask) else mask - masked_updates = mask_pytree(updates, mask_tree) - masked_params = None if params is None else mask_pytree(params, mask_tree) - - new_masked_updates, new_inner_state = inner.update( - masked_updates, state.inner_state, masked_params) - - new_updates = tree_map( - lambda m, new_u, old_u: new_u if m else old_u, - mask_tree, new_masked_updates, updates) - return new_updates, MaskedState(inner_state=new_inner_state) - - return base.GradientTransformation(init_fn, update_fn) - - -class MaybeUpdateState(NamedTuple): - """Maintains inner transform state and adds a step counter.""" - inner_state: Any - step: Array - - -def maybe_update( - inner: base.GradientTransformation, - should_update_fn: Callable[[Array], Array] -) -> base.GradientTransformation: - """Calls the inner update function only at certain steps. - - Creates a transformation wrapper which counts the number of times the `update` - function has been called. This counter is passed to the `should_update_fn` to - decide when to call the inner update function. - - When not calling the inner update function, the `updates` and the inner state - are left untouched and just passed through. The step counter is increased - regardless. - - Args: - inner: the inner transformation. - should_update_fn: this function takes in a step counter (array of shape [] - and dtype int32), and returns a boolean array of shape []. - - Returns: - An `optax.GradientTransformation`. - """ - - def init_fn(params): - return MaybeUpdateState( - inner_state=inner.init(params), step=jnp.zeros([], dtype=jnp.int32)) - - def update_fn(updates, state, params=None): - - def do_update(_): - return inner.update(updates, state.inner_state, params) - - def reject_update(_): - return updates, state.inner_state - - updates, new_inner_state = lax.cond( - should_update_fn(state.step), do_update, reject_update, operand=None) - return updates, MaybeUpdateState(new_inner_state, - numerics.safe_int32_increment(state.step)) - - return base.GradientTransformation(init_fn, update_fn) diff --git a/optax_add_eve/_src/wrappers_test.py b/optax_add_eve/_src/wrappers_test.py deleted file mode 100644 index 1bdfa95f..00000000 --- a/optax_add_eve/_src/wrappers_test.py +++ /dev/null @@ -1,623 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for `wrappers.py`.""" - -import copy - -from absl.testing import absltest -from absl.testing import parameterized - -import chex -import haiku as hk -import jax -import jax.numpy as jnp -import numpy as np -from optax_add_eve._src import alias -from optax_add_eve._src import combine -from optax_add_eve._src import constrain -from optax_add_eve._src import transform -from optax_add_eve._src import update -from optax_add_eve._src import wrappers -import tree - - -def _build_sgd(): - return alias.sgd(1.) - - -def _build_stateful_sgd(): - # This SGD behaves like _build_sgd but also tests the optimizer state. The - # momentum is set to zero rather than None so that the momentum terms are - # calculated, but do not change the results. - return alias.sgd(1., momentum=0.) - - -class WrappersTest(parameterized.TestCase): - - def test_flatten(self): - def init_params(): - return (jnp.array([1., 2.]), jnp.array([3., 4.])) - - per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.])) - - # First calculate new params without flattening - optax_sgd_params = init_params() - sgd = alias.sgd(1e-2, 0.0) - state_sgd = sgd.init(optax_sgd_params) - updates_sgd, state_sgd = sgd.update(per_step_updates, state_sgd) - sgd_params_no_flatten = update.apply_updates(optax_sgd_params, updates_sgd) - - # And now calculate new params with flattening - optax_sgd_params = init_params() - sgd = wrappers.flatten(sgd) - state_sgd = sgd.init(optax_sgd_params) - updates_sgd, state_sgd = sgd.update(per_step_updates, state_sgd) - sgd_params_flatten = update.apply_updates(optax_sgd_params, updates_sgd) - - # Test that both give the same result - chex.assert_trees_all_close( - sgd_params_no_flatten, sgd_params_flatten, atol=1e-7, rtol=1e-7) - - @chex.variants(with_jit=True, without_jit=True, with_pmap=True) - @parameterized.named_parameters( - ('sgd', _build_sgd), - ('stateful_sgd', _build_stateful_sgd), - ) - def test_apply_if_finite(self, opt_builder): - one = jnp.ones([]) - nan = jnp.array(jnp.nan) - def fn(x): - return x * hk.get_parameter('p', [], init=hk.initializers.Constant(0.)) - - fn = hk.without_apply_rng(hk.transform(fn)) - params = fn.init(jax.random.PRNGKey(1905), one) - opt = wrappers.apply_if_finite(opt_builder(), 2) - state = opt.init(params) - grads_fn = jax.grad(self.variant(fn.apply)) - # Do one successful param update - grads = grads_fn(params, one) - updates, state = opt.update(grads, state, params) - params = update.apply_updates(params, updates) - # We know exactly what should be the value of params since we are - # effectively using sgd in all cases. 
- self.assertEqual(-1., float(jax.tree_util.tree_flatten(params)[0][0])) - self.assertTrue(bool(state.last_finite)) - # Check 2 rejected param updates - for step in range(2): - grads = grads_fn(params, nan) - updates, state = opt.update(grads, state, params) - params = update.apply_updates(params, updates) - self.assertEqual(-1., float(jax.tree_util.tree_flatten(params)[0][0])) - self.assertFalse(bool(state.last_finite)) - self.assertEqual(step + 1, int(state.notfinite_count)) - # Next successful param update - grads = grads_fn(params, one) - updates, state = opt.update(grads, state, params) - params = update.apply_updates(params, updates) - self.assertEqual(-2., float(jax.tree_util.tree_flatten(params)[0][0])) - self.assertTrue(bool(state.last_finite)) - # Again 2 rejected param updates - for step in range(2): - grads = grads_fn(params, nan) - updates, state = opt.update(grads, state, params) - params = update.apply_updates(params, updates) - self.assertEqual(-2., float(jax.tree_util.tree_flatten(params)[0][0])) - self.assertFalse(bool(state.last_finite)) - self.assertEqual(step + 1, int(state.notfinite_count)) - # Next param update with NaN is accepted since we reached maximum - grads = grads_fn(params, nan) - updates, state = opt.update(grads, state, params) - params = update.apply_updates(params, updates) - self.assertTrue(bool(jnp.isnan(jax.tree_util.tree_flatten(params)[0][0]))) - self.assertEqual(5, int(state.total_notfinite)) - - def test_apply_if_finite_pmap(self): - # Unlike in `test_apply_if_finite`: - # * pmap is applied to the gradient computation and the optimisation; - # * the NaNs are caused inside the function and do not come from the inputs. - half = jnp.ones([1]) / 2. - two = jnp.ones([1]) * 2. # Causes a NaN in arctanh - def fn(x): - return jnp.arctanh(x) * hk.get_parameter( - 'p', [], init=hk.initializers.Constant(0.)) - fn = hk.without_apply_rng(hk.transform(fn)) - - opt = wrappers.apply_if_finite(alias.sgd(1.), 2) - def fn_update(params, opt_state, x): - grads = jax.grad(fn.apply)(params, x) - grads = jax.lax.psum(grads, axis_name='i') - updates, new_opt_state = opt.update(grads, opt_state, params) - new_params = update.apply_updates(params, updates) - return new_params, new_opt_state - fn_update = jax.pmap(fn_update, axis_name='i') - - params = fn.init(jax.random.PRNGKey(1905), half) - opt_state = opt.init(params) - params = jax.tree_util.tree_map(lambda x: x[None], params) - opt_state = jax.tree_util.tree_map(lambda x: x[None], opt_state) - # Do one successful param update - params, opt_state = fn_update(params, opt_state, half) - self.assertTrue(bool(opt_state.last_finite)) - # Check 2 rejected param updates - for step in range(2): - params, opt_state = fn_update(params, opt_state, two) - self.assertFalse(bool(opt_state.last_finite)) - self.assertEqual(step + 1, int(opt_state.notfinite_count)) - # Next successful param update - params, opt_state = fn_update(params, opt_state, half) - self.assertTrue(bool(opt_state.last_finite)) - # Again 2 rejected param updates - for step in range(2): - params, opt_state = fn_update(params, opt_state, two) - self.assertFalse(bool(opt_state.last_finite)) - self.assertEqual(step + 1, int(opt_state.notfinite_count)) - # Next param update with NaN is accepted since we reached maximum - params, opt_state = fn_update(params, opt_state, two) - self.assertEqual(5, int(opt_state.total_notfinite)) - - @chex.variants(with_jit=True, without_jit=True, with_pmap=True) - def test_multi_steps(self): - batch_size = 32 - x_size = 7 - # Parameters 
should be updated only every `k_steps` optimisation steps. - k_steps = 4 - data = jnp.ones([batch_size, x_size]) - - def get_loss(x): - loss = jnp.sum(hk.Linear(10)(x)**2) - return loss - - loss_init, loss_apply = hk.without_apply_rng(hk.transform(get_loss)) - params = loss_init(jax.random.PRNGKey(1915), data) - - ms_opt = wrappers.MultiSteps( - # Use a non-trivial inner optimiser: - # * it has a state, - # * it requires the params for the update. - combine.chain(transform.scale_by_adam(), - transform.additive_weight_decay(1e-2), - transform.scale(-1e-4)), k_steps) - opt_init, opt_update = ms_opt.gradient_transformation() - - # Put the training in one function, to check that the update is indeed - # jittable. - def train_step(data, opt_state, params): - grad = jax.grad(loss_apply)(params, data) - updates, opt_state = opt_update(grad, opt_state, params) - return updates, opt_state - - opt_state = opt_init(params) - - prev_loss = loss_apply(params, data) - for idx in range(5 * k_steps): - updates, opt_state = self.variant(train_step)(data, opt_state, params) - new_params = update.apply_updates(params, updates) - new_loss = loss_apply(new_params, data) - if idx % k_steps < k_steps - 1: - # The parameters should not have changed and the loss should be - # constant. - jax.tree_util.tree_map( - np.testing.assert_array_equal, new_params, params) - np.testing.assert_equal(new_loss, prev_loss) - self.assertFalse(ms_opt.has_updated(opt_state)) - else: - # This is a step where parameters should actually have been updated, and - # the loss should accordingly go down. - np.testing.assert_array_less(new_loss, prev_loss) - prev_loss = new_loss - self.assertTrue(ms_opt.has_updated(opt_state)) - params = new_params - - def test_multi_steps_every_k_schedule(self): - # Test a non-trivial schedule which varies over time. - ms_opt = wrappers.MultiSteps( - alias.sgd(1e-4), lambda grad_step: jnp.where(grad_step < 2, 1, 3)) - opt_init, opt_update = ms_opt.gradient_transformation() - params = dict(a=jnp.zeros([])) - opt_state = opt_init(params) - grad = dict(a=jnp.zeros([])) - self.assertFalse(ms_opt.has_updated(opt_state)) - # First two steps have 1 mini-step per update. - for _ in range(2): - _, opt_state = opt_update(grad, opt_state, params) - self.assertTrue(ms_opt.has_updated(opt_state)) - # Subsequently, mini-steps should have 3 mini-steps per update. - for _ in range(5): - for _ in range(2): - _, opt_state = opt_update(grad, opt_state, params) - self.assertFalse(ms_opt.has_updated(opt_state)) - _, opt_state = opt_update(grad, opt_state, params) - self.assertTrue(ms_opt.has_updated(opt_state)) - - def test_multi_steps_computes_mean(self): - k_steps = 4 - ms_opt = wrappers.MultiSteps( - transform.scale(1.0), k_steps, use_grad_mean=True) - opt_init, opt_update = ms_opt.gradient_transformation() - params = dict(a=jnp.zeros([])) - opt_state = opt_init(params) - grads = [dict(a=jnp.ones([]) * i) for i in [1, 2, 3, 4]] - self.assertFalse(ms_opt.has_updated(opt_state)) - - # First 3 steps don't update. - for grad in grads[:-1]: - _, opt_state = opt_update(grad, opt_state, params) - self.assertFalse(ms_opt.has_updated(opt_state)) - - # Actual update. 
- new_params, opt_state = opt_update(grads[-1], opt_state, params) - self.assertTrue(ms_opt.has_updated(opt_state)) - np.testing.assert_array_equal(new_params['a'], 2.5) - - def test_skip_not_finite(self): - step = jnp.zeros([], dtype=jnp.int32) - - with self.subTest('test_pos_inf'): - should_skip, skip_state = wrappers.skip_not_finite( - [jnp.array(float('inf')), jnp.zeros([])], step, None) - self.assertTrue(bool(should_skip)) - self.assertTrue(bool(skip_state['should_skip'])) - self.assertEqual(int(skip_state['num_not_finite']), 1) - - with self.subTest('test_neg_inf'): - should_skip, skip_state = wrappers.skip_not_finite( - [jnp.array(-float('inf')), jnp.zeros([])], step, None) - self.assertTrue(bool(should_skip)) - self.assertTrue(bool(skip_state['should_skip'])) - self.assertEqual(int(skip_state['num_not_finite']), 1) - - with self.subTest('test_nan'): - should_skip, skip_state = wrappers.skip_not_finite( - [jnp.array(float('nan')), jnp.zeros([])], step, None) - self.assertTrue(bool(should_skip)) - self.assertTrue(bool(skip_state['should_skip'])) - self.assertEqual(int(skip_state['num_not_finite']), 1) - - with self.subTest('test_finite'): - should_skip, skip_state = wrappers.skip_not_finite( - [jnp.array(11.), jnp.zeros([])], step, None) - self.assertFalse(bool(should_skip)) - self.assertFalse(bool(skip_state['should_skip'])) - self.assertEqual(int(skip_state['num_not_finite']), 0) - - def test_skip_large_updates(self): - step = jnp.zeros([], dtype=jnp.int32) - - with self.subTest('test_inf'): - should_skip, skip_state = wrappers.skip_large_updates( - [jnp.array(float('inf')), jnp.zeros([])], step, None, 100.) - self.assertTrue(bool(should_skip)) - self.assertTrue(bool(skip_state['should_skip'])) - self.assertEqual(float(skip_state['norm_squared']), float('inf')) - - with self.subTest('test_nan'): - should_skip, skip_state = wrappers.skip_large_updates( - [jnp.array(float('nan')), jnp.zeros([])], step, None, 100.) - self.assertTrue(bool(should_skip)) - self.assertTrue(bool(skip_state['should_skip'])) - # Recall that NaN != NaN. - norm_squared = float(skip_state['norm_squared']) - self.assertNotEqual(norm_squared, norm_squared) - - with self.subTest('test_large'): - should_skip, skip_state = wrappers.skip_large_updates( - [jnp.array(11.), jnp.zeros([])], step, None, 100.) - self.assertTrue(bool(should_skip)) - self.assertTrue(bool(skip_state['should_skip'])) - self.assertEqual(float(skip_state['norm_squared']), 121.) - - with self.subTest('test_small'): - should_skip, skip_state = wrappers.skip_large_updates( - [jnp.zeros([]), jnp.zeros([])], step, None, 100.) - self.assertFalse(bool(should_skip)) - self.assertFalse(bool(skip_state['should_skip'])) - self.assertEqual(float(skip_state['norm_squared']), 0.) 
- - def test_multi_steps_skip_not_finite(self): - k_steps = 2 - ms_opt = wrappers.MultiSteps( - alias.sgd(1.), k_steps, should_skip_update_fn=wrappers.skip_not_finite) - opt_init, opt_update = ms_opt.gradient_transformation() - opt_init = jax.jit(opt_init) - opt_update = jax.jit(opt_update) - params = dict(a=jnp.zeros([])) - opt_state = opt_init(params) - - with self.subTest('test_good_updates'): - updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) - self.assertEqual(int(opt_state.mini_step), 1) - params = update.apply_updates(params, updates) - updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) - self.assertEqual(int(opt_state.mini_step), 0) - params = update.apply_updates(params, updates) - np.testing.assert_array_equal(params['a'], -jnp.ones([])) - - with self.subTest('test_inf_updates'): - updates, opt_state = opt_update( - dict(a=jnp.array(float('inf'))), opt_state, params) - self.assertEqual(int(opt_state.mini_step), 0) # No increase in mini_step - params = update.apply_updates(params, updates) - np.testing.assert_array_equal(params['a'], -jnp.ones([])) - - with self.subTest('test_nan_updates'): - updates, opt_state = opt_update( - dict(a=jnp.full([], float('nan'))), opt_state, params) - self.assertEqual(int(opt_state.mini_step), 0) # No increase in mini_step - params = update.apply_updates(params, updates) - np.testing.assert_array_equal(params['a'], -jnp.ones([])) - - with self.subTest('test_final_good_updates'): - updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) - self.assertEqual(int(opt_state.mini_step), 1) - params = update.apply_updates(params, updates) - updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) - self.assertEqual(int(opt_state.mini_step), 0) - params = update.apply_updates(params, updates) - np.testing.assert_array_equal(params['a'], -jnp.full([], 2.)) - - -class MaskedTest(chex.TestCase): - """Tests for the masked wrapper.""" - - @chex.all_variants - @parameterized.named_parameters( - ('sgd', _build_sgd, False), - ('stateful_sgd', _build_stateful_sgd, False), - ('sgd_w_mask_fn', _build_sgd, True), - ('stateful_sgd_w_mask_fn', _build_stateful_sgd, True), - ) - def test_masked(self, opt_builder, use_fn): - mask = {'a': True, - 'b': [False, True], - 'c': {'d': True, 'e': (False, True)}} - mask_arg = lambda _: mask if use_fn else mask - params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}} - params = jax.tree_util.tree_map(jnp.asarray, params) - input_updates = jax.tree_util.tree_map(lambda x: x/10., params) - - # Negate the updates wherever the mask is True - def masked_negate(updates): - return jax.tree_util.tree_map( - lambda upd, m: -upd if m else upd, updates, mask) - correct_updates = masked_negate(input_updates) - - init_fn, update_fn = wrappers.masked(opt_builder(), mask_arg) - update_fn = self.variant(update_fn) - state = self.variant(init_fn)(params) - updates, state = update_fn(input_updates, state, params) - chex.assert_trees_all_close(updates, correct_updates) - - # Check repeated application, this time with no params. 
- correct_updates = masked_negate(correct_updates) - updates, state = update_fn(updates, state) - chex.assert_trees_all_close(updates, correct_updates) - - @chex.all_variants - @parameterized.named_parameters( - ('sgd', _build_sgd), - ('stateful_sgd', _build_stateful_sgd), - ) - def test_prefix_mask(self, opt_builder): - """Test when the mask is a prefix of the updates PyTree.""" - mask = {'a': True, 'b': False, 'c': {'d': False, 'e': True}} - params = {'a': 1., 'b': {'f': 2.}, 'c': {'d': 3., 'e': ([4., 5.], 6.)}} - params = jax.tree_util.tree_map(jnp.asarray, params) - input_updates = jax.tree_util.tree_map(lambda x: x/10., params) - - # Negate the updates wherever the mask (or mask parent) is True - def _masked_sgd_on_updates(m, upd): - return jax.tree_util.tree_map(lambda x: -x, upd) if m else upd - correct_updates = jax.tree_util.tree_map( - _masked_sgd_on_updates, mask, input_updates) - - init_fn, update_fn = wrappers.masked(opt_builder(), mask) - update_fn = self.variant(update_fn) - state = self.variant(init_fn)(params) - updates, state = update_fn(input_updates, state, params) - chex.assert_trees_all_close(updates, correct_updates) - - # Check repeated application, this time with no params. - correct_updates = jax.tree_util.tree_map( - _masked_sgd_on_updates, mask, correct_updates) - updates, state = update_fn(updates, state) - chex.assert_trees_all_close(updates, correct_updates) - - @chex.all_variants - def test_update_requires_params(self): - weight_decay = 0.1 - mask = {'a': True, - 'b': [False, True], - 'c': {'d': True, 'e': (False, True)}} - params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}} - params = jax.tree_util.tree_map(jnp.asarray, params) - input_updates = jax.tree_util.tree_map(lambda x: x/10., params) - - correct_updates = jax.tree_util.tree_map( - lambda m, u, p: u + weight_decay * p if m else u, - mask, input_updates, params) - - init_fn, update_fn = wrappers.masked( - transform.additive_weight_decay(weight_decay), mask) - update_fn = self.variant(update_fn) - - state = self.variant(init_fn)(params) - updates, state = update_fn(input_updates, state, params) - chex.assert_trees_all_close(updates, correct_updates) - - params = update.apply_updates(params, updates) - - # Test repeated application - new_correct_updates = jax.tree_util.tree_map( - lambda m, u, p: u + weight_decay * p if m else u, - mask, correct_updates, params) - updates, state = update_fn(correct_updates, state, params) - chex.assert_trees_all_close(updates, new_correct_updates) - - @parameterized.parameters(list, tuple, dict) - def test_empty(self, container): - init_fn, update_fn = wrappers.masked(_build_sgd(), container()) - update_fn(container(), init_fn(container())) - - @parameterized.parameters( - (False, False), (False, True), (True, False), (True, True)) - def test_tree_mismatch_fails(self, extra_key_in_mask, use_fn): - mask = {'a': True, - 'b': [False, True], - 'c': {'d': True, 'e': (False, True)}} - mask_arg = lambda _: mask if use_fn else mask - params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}} - params = jax.tree_util.tree_map(jnp.asarray, params) - - if extra_key_in_mask: - mask['c']['extra'] = True - else: - params['c']['extra'] = 7 - - init_fn = wrappers.masked(_build_sgd(), mask_arg)[0] - with self.assertRaises(ValueError): - init_fn(params) - - @chex.all_variants - def test_mask_fn(self): - params = {'a': jnp.ones((1, 2)), 'b': (jnp.ones((1,)), np.ones((1, 2, 3)))} - mask_fn = lambda p: jax.tree_util.tree_map(lambda x: x.ndim > 1, p) - init_fn, update_fn 
= wrappers.masked(transform.add_decayed_weights(0.1), - mask_fn) - update_fn = self.variant(update_fn) - - state = self.variant(init_fn)(params) - grads = jax.tree_util.tree_map(lambda x: x*2, params) - updates, state = update_fn(grads, state, params) - np.testing.assert_allclose(updates['a'], grads['a'] + 0.1*params['a']) - np.testing.assert_allclose(updates['b'][0], grads['b'][0]) - np.testing.assert_allclose(updates['b'][1], - grads['b'][1] + 0.1*params['b'][1]) - - @chex.all_variants - @parameterized.named_parameters( - ('sgd', _build_sgd), - ('stateful_sgd', _build_stateful_sgd), - ) - def test_nested_mask(self, opt_builder): - # https://github.com/deepmind/optax/issues/271 - params = {'linear_1': {'w': jnp.zeros((1, 1)), 'b': jnp.zeros(1)}, - 'linear_2': {'w': jnp.zeros((1, 2)), 'b': jnp.zeros(2)}, - 'linear_3': {'w': jnp.zeros((2, 3)), 'b': jnp.zeros(3)}} - - outer_mask = lambda p: jax.tree_util.tree_map(lambda x: x.ndim > 1, p) - inner_mask = jax.tree_util.tree_map(lambda _: True, params) - inner_mask['linear_2'] = False - - inner = wrappers.masked(opt_builder(), inner_mask) - init_fn, update_fn = wrappers.masked(inner, outer_mask) - - input_updates = jax.tree_util.tree_map(jnp.ones_like, params) - correct_updates = copy.deepcopy(input_updates) - correct_updates['linear_1']['w'] *= -1.0 - correct_updates['linear_3']['w'] *= -1.0 - - state = self.variant(init_fn)(params) - updates, state = self.variant(update_fn)(input_updates, state, params) - chex.assert_trees_all_close(updates, correct_updates) - - @chex.all_variants - def test_masked_state_structure(self): - # https://github.com/deepmind/optax/issues/271 - params = {'a': [jnp.ones(1), (jnp.ones(2), jnp.ones(3))], - 'b': {'c': jnp.ones(4), 'd': jnp.ones(5)}} - mask = {'a': [True, (True, False)], 'b': False} - tx = wrappers.masked(_build_stateful_sgd(), mask) - trace = self.variant(tx.init)(params).inner_state[0].trace - expected_trace = { - 'a': [jnp.zeros(1), (jnp.zeros(2), wrappers.MaskedNode())], - 'b': wrappers.MaskedNode() - } - chex.assert_tree_all_equal_structs(trace, expected_trace) - - def test_masked_state_is_compatible_with_deepmind_tree(self): - """Checks that the masked state is compatible with deepmind/tree. - - DeepMind's tree library and `jax.tree_util` have slightly different - behavior: jax treats `None`s as tree nodes without children while - deepmind/tree treats them as leaves with `None` values. This has led to bugs - when users used deepmind/tree to manipulate masked optimizer states. - - This test ensures that masked parts of the optimizer state are also ignored - by deepmind/tree. - """ - params = { - 'a': [jnp.ones(1), (jnp.ones(2), jnp.ones(3))], - 'b': [jnp.ones(4)] - } - mask = {'a': [True, (True, False)], 'b': False} - opt_init, _ = wrappers.masked(_build_stateful_sgd(), mask) - state = opt_init(params) - chex.assert_trees_all_equal(tree.map_structure(np.array, state), state) - - -class MaybeUpdateTest(chex.TestCase): - """Tests for the maybe_update wrapper.""" - - NUM_STEPS = 3 - - @chex.all_variants - def test_stateless_inner(self): - params = jnp.zeros([]) - grads = jnp.ones([]) - - def should_update(step): - return step < MaybeUpdateTest.NUM_STEPS - - opt = wrappers.maybe_update(transform.scale(2.), should_update) - state = opt.init(params) - update_fn = self.variant(opt.update) - for _ in range(MaybeUpdateTest.NUM_STEPS): - updates, state = update_fn(grads, state) - self.assertEqual(updates, 2.) - # Further updates stop calling the inner optimiser. 
- for _ in range(5): - updates, state = update_fn(grads, state) - self.assertEqual(updates, 1.) - - @chex.all_variants - def test_statefull_inner(self): - params = jnp.zeros([]) - grads_with_nan = jnp.array(float('nan')) - grads = jnp.ones([]) - - def should_update(step): - return step < MaybeUpdateTest.NUM_STEPS - - opt = wrappers.maybe_update(constrain.zero_nans(), should_update) - state = opt.init(params) - update_fn = self.variant(opt.update) - for _ in range(MaybeUpdateTest.NUM_STEPS - 1): - updates, state = update_fn(grads_with_nan, state) - self.assertEqual(updates, 0.) - self.assertEqual(state.inner_state.found_nan, True) - updates, state = update_fn(grads, state) - self.assertEqual(updates, 1.) - self.assertEqual(state.inner_state.found_nan, False) - # Further updates stop calling the inner optimiser. - for _ in range(5): - updates, state = update_fn(grads_with_nan, state) - # Warning: do not use assertEqual with a NaN as NaN == NaN returns False. - self.assertTrue(jnp.isnan(updates)) - # Inner state is not be updated. - self.assertEqual(state.inner_state.found_nan, False) - - -if __name__ == '__main__': - absltest.main() diff --git a/optax_add_eve/experimental/__init__.py b/optax_add_eve/experimental/__init__.py deleted file mode 100644 index 61cb5150..00000000 --- a/optax_add_eve/experimental/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Experimental features in Optax. - -Features may be removed or modified at any time. -""" - -from optax_add_eve._src.experimental.complex_valued import split_real_and_imaginary -from optax_add_eve._src.experimental.complex_valued import SplitRealAndImaginaryState -from optax_add_eve._src.experimental.extra_args import GradientTransformationWithExtraArgs -from optax_add_eve._src.experimental.extra_args import named_chain diff --git a/optax_add_eve/optax_test.py b/optax_add_eve/optax_test.py deleted file mode 100644 index ea6af7b9..00000000 --- a/optax_add_eve/optax_test.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for optax.""" - -from absl.testing import absltest -import optax_add_eve - - -class OptaxTest(absltest.TestCase): - """Test optax can be imported correctly.""" - - def test_import(self): - self.assertTrue(hasattr(optax_add_eve, 'GradientTransformation')) - - -if __name__ == '__main__': - absltest.main() diff --git a/setup.py b/setup.py index 75209e72..9998a7f4 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def _parse_requirements(path): setup( - name='optax', + name='optax_add_eve', version=_get_version(), url='https://github.com/deepmind/optax', license='Apache 2.0', From 8a0cccd383eeb69e56c2487990d35a78aec8b838 Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 19:37:09 -0600 Subject: [PATCH 05/35] added eve to build --- optax/__init__.py | 2 ++ setup.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/optax/__init__.py b/optax/__init__.py index 278255dc..40c8f360 100644 --- a/optax/__init__.py +++ b/optax/__init__.py @@ -23,6 +23,7 @@ from optax._src.alias import adamaxw from optax._src.alias import adamw from optax._src.alias import amsgrad +from optax._src.alias import eve from optax._src.alias import dpsgd from optax._src.alias import fromage from optax._src.alias import lamb @@ -223,6 +224,7 @@ "ema", "EmaState", "EmptyState", + "eve" "exponential_decay", "FactoredState", "fisher_diag", diff --git a/setup.py b/setup.py index 9998a7f4..75209e72 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def _parse_requirements(path): setup( - name='optax_add_eve', + name='optax', version=_get_version(), url='https://github.com/deepmind/optax', license='Apache 2.0', From b2db79af60cad566a3df13fe2fc2c5d39902f5c7 Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 19:40:16 -0600 Subject: [PATCH 06/35] reverted accidental deletion --- optax/_src/transform.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 58dd3c29..3a91d7bb 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -419,10 +419,9 @@ def scale_by_adamax( ) -> base.GradientTransformation: """Rescale updates according to the Adamax algorithm. - References:nu = update_infinity_moment(updates, state.nu, b2, eps) - count_inc = utils.numerics.safe_int32_increment(state.count) - mu_hat = jax.tree_util.tree_map(lambda m: jnp.asarray(m / (1-b1)), mu) - nu_hat = jax.tree_util.tree_map(lambda v: jnp.asarray(v / (1-b2)), nu) + References:References: + [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) + Args: b1: Decay rate for the exponentially weighted average of grads. b2: Decay rate for the exponentially weighted maximum of grads. eps: Term added to the denominator to improve numerical stability. From 578742692e1c2c22cb8416f8c7dcfe59b30879ad Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 19:40:40 -0600 Subject: [PATCH 07/35] typo --- optax/_src/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 3a91d7bb..c6e2f61f 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -419,7 +419,7 @@ def scale_by_adamax( ) -> base.GradientTransformation: """Rescale updates according to the Adamax algorithm. - References:References: + References: [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) Args: b1: Decay rate for the exponentially weighted average of grads. 
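For quick verification of the export wiring in the patches above, here is a minimal smoke test in the spirit of the import test removed earlier. It is illustrative only and not part of the series; the only assumptions are that `optax.eve` is importable once the `__init__.py` change lands and that the default constructor works, as exercised by the `alias_test.py` entry added later in the series.

    import jax.numpy as jnp
    import optax

    # The alias registered in optax/__init__.py should be visible at top level.
    assert hasattr(optax, 'eve')

    # Build the optimizer with default hyper-parameters and initialize its
    # state for a small parameter tree.
    opt = optax.eve()
    state = opt.init({'w': jnp.zeros(3)})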
From 5c46d4e494a2c18f8f02ce629bc751e233cb0181 Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 19:41:43 -0600 Subject: [PATCH 08/35] typo --- optax/_src/transform.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index c6e2f61f..f9b57751 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -421,6 +421,7 @@ def scale_by_adamax( References: [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) + Args: b1: Decay rate for the exponentially weighted average of grads. b2: Decay rate for the exponentially weighted maximum of grads. From b61e47c7a20361fa9af78e886dc449f46a2f1f55 Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 19:42:33 -0600 Subject: [PATCH 09/35] typo --- optax/_src/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index f9b57751..c56eb96d 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -421,7 +421,7 @@ def scale_by_adamax( References: [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) - + Args: b1: Decay rate for the exponentially weighted average of grads. b2: Decay rate for the exponentially weighted maximum of grads. From 923ae6235ed8eada3080ebc54547bedcd23a8d33 Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 19:43:24 -0600 Subject: [PATCH 10/35] alphabetized format --- optax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/__init__.py b/optax/__init__.py index 40c8f360..18e16aa0 100644 --- a/optax/__init__.py +++ b/optax/__init__.py @@ -23,8 +23,8 @@ from optax._src.alias import adamaxw from optax._src.alias import adamw from optax._src.alias import amsgrad -from optax._src.alias import eve from optax._src.alias import dpsgd +from optax._src.alias import eve from optax._src.alias import fromage from optax._src.alias import lamb from optax._src.alias import lars From 06cb64d3af65e48feac3fdab1bafadf49786615c Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 19:46:33 -0600 Subject: [PATCH 11/35] typo --- optax/_src/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index c56eb96d..2ab42454 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -512,7 +512,7 @@ def update_fn(updates: base.Updates, state: ScaleByEveState, f: float): updates = jax.tree_util.tree_map( lambda m, v: m / (jnp.sqrt(v) + eps) / d, mu_hat, nu_hat) mu = utils.cast_tree(mu, mu_dtype) - return updates, ScaleByEveState(count=count_inc, mu=mu, nu=nu, d=d, f=f) + return updates, ScaleByEveState(count=count_inc, mu=mu, nu=nu, d=d, f_prev=f) return base.GradientTransformation(init_fn, update_fn) From 23211e17e875a94556a9316f721675fc80844b3b Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 20:39:18 -0600 Subject: [PATCH 12/35] conform to optax api --- optax/_src/alias_test.py | 1 + optax/_src/transform.py | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index be1a68b3..26099139 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -35,6 +35,7 @@ dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1e-1)), dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1e-1)), dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name='eve', opt_kwargs=dict()), dict(opt_name='lars', opt_kwargs=dict(learning_rate=1.0)), dict(opt_name='lamb', 
opt_kwargs=dict(learning_rate=1e-3)), dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1e-3, eta=1e-4)), diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 2ab42454..98a8ce82 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -455,6 +455,7 @@ class ScaleByEveState(NamedTuple): mu: base.Updates nu: base.Updates d: float + f: float f_prev: float @@ -491,20 +492,22 @@ def init_fn(params): mu = jax.tree_util.tree_map( # First moment lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByEveState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f_prev=1.) + return ScaleByEveState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f= 1., f_prev=1.) - def update_fn(updates: base.Updates, state: ScaleByEveState, f: float): + def update_fn(updates: base.Updates, state: ScaleByEveState, params=None): """ - Eve requires an additional parameter: the loss for the current iteration: f = f_t - ScaleByEveState holds the loss from the previous iteration: state.f_prev = f_{t-1} + Eve requires an additional parameter: the loss for the current iteration: state.f = f_t + ScaleByEveState also holds the loss from the previous iteration: state.f_prev = f_{t-1} + It is up to the user to update the state with the current loss before injecting. """ + del params mu = update_moment(updates, state.mu, b1, 1) nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) count_inc = utils.numerics.safe_int32_increment(state.count) - mu_hat = jax.tree_util.tree_map(lambda m: jnp.asarray(m / (1-b1)), mu) - nu_hat = jax.tree_util.tree_map(lambda v: jnp.asarray(v / (1-b2)), nu) + mu_hat = jax.tree_util.tree_map(lambda m: m / (1-b1), mu) + nu_hat = jax.tree_util.tree_map(lambda v: v / (1-b2), nu) if count_inc > 1: - d_new = jnp.abs(f - state.f_prev) / (jnp.min(jnp.array([f,state.f_prev])) - f_star) + d_new = jnp.abs(state.f - state.f_prev) / (jnp.min(jnp.array([state.f,state.f_prev])) - f_star) d_tilde = jnp.clip(d_new,1/c,c) d = b3*state.d + (1-b3)*d_tilde else: @@ -512,7 +515,7 @@ def update_fn(updates: base.Updates, state: ScaleByEveState, f: float): updates = jax.tree_util.tree_map( lambda m, v: m / (jnp.sqrt(v) + eps) / d, mu_hat, nu_hat) mu = utils.cast_tree(mu, mu_dtype) - return updates, ScaleByEveState(count=count_inc, mu=mu, nu=nu, d=d, f_prev=f) + return updates, ScaleByEveState(count=count_inc, mu=mu, nu=nu, d=d, f=state.f, f_prev=state.f) return base.GradientTransformation(init_fn, update_fn) From 01ad74f8740e7c84d699c916b2d6fb72daf42320 Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 20:54:53 -0600 Subject: [PATCH 13/35] tests for eve --- optax/_src/transform.py | 11 ++++------- optax/_src/transform_test.py | 1 + 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 98a8ce82..bcadb76c 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -492,7 +492,7 @@ def init_fn(params): mu = jax.tree_util.tree_map( # First moment lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByEveState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f= 1., f_prev=1.) + return ScaleByEveState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f= 1., f_prev=10.) 
def update_fn(updates: base.Updates, state: ScaleByEveState, params=None): """ @@ -506,12 +506,9 @@ def update_fn(updates: base.Updates, state: ScaleByEveState, params=None): count_inc = utils.numerics.safe_int32_increment(state.count) mu_hat = jax.tree_util.tree_map(lambda m: m / (1-b1), mu) nu_hat = jax.tree_util.tree_map(lambda v: v / (1-b2), nu) - if count_inc > 1: - d_new = jnp.abs(state.f - state.f_prev) / (jnp.min(jnp.array([state.f,state.f_prev])) - f_star) - d_tilde = jnp.clip(d_new,1/c,c) - d = b3*state.d + (1-b3)*d_tilde - else: - d = 1. + d_new = jnp.abs(state.f - state.f_prev) / (jnp.min(jnp.array([state.f,state.f_prev])) - f_star) + d_tilde = jnp.clip(d_new,1/c,c) + d = jnp.where(count_inc > 1, b3*state.d + (1-b3)*d_tilde, 1.) updates = jax.tree_util.tree_map( lambda m, v: m / (jnp.sqrt(v) + eps) / d, mu_hat, nu_hat) mu = utils.cast_tree(mu, mu_dtype) diff --git a/optax/_src/transform_test.py b/optax/_src/transform_test.py index 2c4ea948..628db3ac 100644 --- a/optax/_src/transform_test.py +++ b/optax/_src/transform_test.py @@ -44,6 +44,7 @@ def setUp(self): @parameterized.named_parameters([ ('adam', transform.scale_by_adam), ('adamax', transform.scale_by_adamax), + ('eve', transform.scale_by_eve), ('rmsprop', transform.scale_by_rms), ('stddev', transform.scale_by_stddev), ('trust_ratio', transform.scale_by_trust_ratio), From 6f1b3f97f8ac0058535ecbcfcebb8234f34dfc4a Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 21:42:58 -0600 Subject: [PATCH 14/35] added custom update function for eve state --- optax/_src/alias.py | 19 +++++++++++++++---- optax/_src/alias_test.py | 8 +++++++- optax/_src/transform.py | 3 ++- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 6dbcf7fd..b4d66c04 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -14,7 +14,7 @@ # ============================================================================== """Aliases for popular optimizers.""" -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, Tuple import jax.numpy as jnp @@ -373,13 +373,24 @@ def eve( `None` then the `dtype` is inferred from `params` and `updates`. Returns: - the corresponding `GradientTransformation`. + the corresponding `GradientTransformation` and a function with which to update + the optimizer state with the required loss parameter before injecting. 
""" + def update_opt_state(opt_state: Tuple[transform.ScaleByEveState,transform.ScaleState], f: float): + return transform.ScaleByEveState( + count=opt_state[0].count, + mu=opt_state[0].mu, + nu=opt_state[0].nu, + d=opt_state[0].d, + f=f, + f_prev=opt_state[0].f_prev + ), transform.ScaleState() + return combine.chain( transform.scale_by_eve( b1=b1, b2=b2, b3=b3, c=c, eps=eps, f_star=f_star, mu_dtype=mu_dtype), - _scale_by_learning_rate(learning_rate), - ) + _scale_by_learning_rate(learning_rate) + ), update_opt_state def fromage( learning_rate: float, diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index 26099139..f44377ae 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -109,7 +109,10 @@ def test_optimization(self, opt_name, opt_kwargs, target, dtype): raise absltest.SkipTest( f'{opt_name} does not support complex parameters.') - opt = getattr(alias, opt_name)(**opt_kwargs) + if opt_name != 'eve': + opt = getattr(alias, opt_name)(**opt_kwargs) + else: + opt, state_update = getattr(alias, opt_name)(**opt_kwargs) initial_params, final_params, get_updates = target(dtype) @jax.jit @@ -117,6 +120,9 @@ def step(params, state): updates = get_updates(params) if opt_name == 'dpsgd': updates = updates[None] + elif opt_name == 'eve': + f = jnp.mean(jnp.square(params-final_params)) + state = state_update(opt_state=state,f=f) # Complex gradients need to be conjugated before being added to parameters # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 updates = jax.tree_util.tree_map(lambda x: x.conj(), updates) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index bcadb76c..69b2c476 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -512,7 +512,8 @@ def update_fn(updates: base.Updates, state: ScaleByEveState, params=None): updates = jax.tree_util.tree_map( lambda m, v: m / (jnp.sqrt(v) + eps) / d, mu_hat, nu_hat) mu = utils.cast_tree(mu, mu_dtype) - return updates, ScaleByEveState(count=count_inc, mu=mu, nu=nu, d=d, f=state.f, f_prev=state.f) + # assign a decayed f to ScaleByEveState for testing purposes. 
User will need to update f in practice + return updates, ScaleByEveState(count=count_inc, mu=mu, nu=nu, d=d, f=b3*state.f, f_prev=state.f) return base.GradientTransformation(init_fn, update_fn) From 895d20f5fd38abe3ae23770d7f097efadff6864f Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 21:48:34 -0600 Subject: [PATCH 15/35] update init --- optax/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/optax/__init__.py b/optax/__init__.py index 18e16aa0..d3242dca 100644 --- a/optax/__init__.py +++ b/optax/__init__.py @@ -131,6 +131,7 @@ from optax._src.transform import scale_by_adamax from optax._src.transform import scale_by_amsgrad from optax._src.transform import scale_by_belief +from optax._src.transform import scale_by_eve from optax._src.transform import scale_by_novograd from optax._src.transform import scale_by_optimistic_gradient from optax._src.transform import scale_by_param_block_norm @@ -146,6 +147,7 @@ from optax._src.transform import ScaleByAdamState from optax._src.transform import ScaleByAmsgradState from optax._src.transform import ScaleByBeliefState +from optax._src.transform import ScaleByEveState from optax._src.transform import ScaleByNovogradState from optax._src.transform import ScaleByRmsState from optax._src.transform import ScaleByRssState @@ -286,6 +288,7 @@ "scale_by_adamax", "scale_by_amsgrad", "scale_by_belief", + "scale_by_eve", "scale_by_factored_rms", "scale_by_novograd", "scale_by_param_block_norm", @@ -303,6 +306,7 @@ "ScaleByAdamState", "ScaleByAmsgradState", "ScaleByBeliefState", + "ScaleByEveState", "ScaleByNovogradState", "ScaleByRmsState", "ScaleByRssState", From 349516f9033dc8b93f27fd33e503e3f43b992c8e Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 22:08:49 -0600 Subject: [PATCH 16/35] documentation --- optax/_src/alias.py | 4 ++-- optax/_src/transform.py | 16 +++++++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index b4d66c04..9fcb5ee1 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -373,8 +373,8 @@ def eve( `None` then the `dtype` is inferred from `params` and `updates`. Returns: - the corresponding `GradientTransformation` and a function with which to update - the optimizer state with the required loss parameter before injecting. + the corresponding `GradientTransformation` + a function with which to update the optimizer state with the required loss parameter before injecting. """ def update_opt_state(opt_state: Tuple[transform.ScaleByEveState,transform.ScaleState], f: float): return transform.ScaleByEveState( diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 69b2c476..9301516e 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -498,7 +498,21 @@ def update_fn(updates: base.Updates, state: ScaleByEveState, params=None): """ Eve requires an additional parameter: the loss for the current iteration: state.f = f_t ScaleByEveState also holds the loss from the previous iteration: state.f_prev = f_{t-1} - It is up to the user to update the state with the current loss before injecting. + It is up to the user to update the state with the current loss before injecting using the + second returned function from optax.eve() as follows: + + Example + -------- + Initialize: + >>> optimizer, state_update_fn = optax.eve() + >>> opt_state = optimizer.init(params) + + Train: + >>> while training: + ... loss, grads = jax.value_and_grad(loss_fn)(params, data) + ... 
opt_state = state_update_fn(opt_state, loss) # <-- Update state here + ... updates, opt_state = optimizer.update(grads, opt_state) + ... params = optax.apply_updates(params, updates) """ del params mu = update_moment(updates, state.mu, b1, 1) From 118d669e9d741ea0aa1a17ab59364333b4f7d094 Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 22:09:54 -0600 Subject: [PATCH 17/35] clearer documentation --- optax/_src/alias.py | 4 ++-- optax/_src/alias_test.py | 4 ++-- optax/_src/transform.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 9fcb5ee1..7c7947b8 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -376,7 +376,7 @@ def eve( the corresponding `GradientTransformation` a function with which to update the optimizer state with the required loss parameter before injecting. """ - def update_opt_state(opt_state: Tuple[transform.ScaleByEveState,transform.ScaleState], f: float): + def eve_update_state(opt_state: Tuple[transform.ScaleByEveState,transform.ScaleState], f: float): return transform.ScaleByEveState( count=opt_state[0].count, mu=opt_state[0].mu, @@ -390,7 +390,7 @@ def update_opt_state(opt_state: Tuple[transform.ScaleByEveState,transform.ScaleS transform.scale_by_eve( b1=b1, b2=b2, b3=b3, c=c, eps=eps, f_star=f_star, mu_dtype=mu_dtype), _scale_by_learning_rate(learning_rate) - ), update_opt_state + ), eve_update_state def fromage( learning_rate: float, diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index f44377ae..1b7eb8c4 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -112,7 +112,7 @@ def test_optimization(self, opt_name, opt_kwargs, target, dtype): if opt_name != 'eve': opt = getattr(alias, opt_name)(**opt_kwargs) else: - opt, state_update = getattr(alias, opt_name)(**opt_kwargs) + opt, eve_update_state = getattr(alias, opt_name)(**opt_kwargs) initial_params, final_params, get_updates = target(dtype) @jax.jit @@ -122,7 +122,7 @@ def step(params, state): updates = updates[None] elif opt_name == 'eve': f = jnp.mean(jnp.square(params-final_params)) - state = state_update(opt_state=state,f=f) + state = eve_update_state(opt_state=state,f=f) # Complex gradients need to be conjugated before being added to parameters # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 updates = jax.tree_util.tree_map(lambda x: x.conj(), updates) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 9301516e..b497d134 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -510,7 +510,7 @@ def update_fn(updates: base.Updates, state: ScaleByEveState, params=None): Train: >>> while training: ... loss, grads = jax.value_and_grad(loss_fn)(params, data) - ... opt_state = state_update_fn(opt_state, loss) # <-- Update state here + ... opt_state = eve_update_state(opt_state, loss) # <-- Update state here ... updates, opt_state = optimizer.update(grads, opt_state) ... 
params = optax.apply_updates(params, updates) """ From d6dc2f01f9a5cf4edb1b29a5849d5843b9b06314 Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 22:10:15 -0600 Subject: [PATCH 18/35] clearer documentation --- optax/_src/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index b497d134..e3a8c250 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -504,7 +504,7 @@ def update_fn(updates: base.Updates, state: ScaleByEveState, params=None): Example -------- Initialize: - >>> optimizer, state_update_fn = optax.eve() + >>> optimizer, eve_update_state = optax.eve() >>> opt_state = optimizer.init(params) Train: From 46a089f83c5ec0c687fddb03392362ab12b0788a Mon Sep 17 00:00:00 2001 From: wglao Date: Fri, 20 Jan 2023 22:11:44 -0600 Subject: [PATCH 19/35] clearer documentation --- optax/_src/alias.py | 4 ++-- optax/_src/alias_test.py | 4 ++-- optax/_src/transform.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 7c7947b8..7382f0e0 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -376,7 +376,7 @@ def eve( the corresponding `GradientTransformation` a function with which to update the optimizer state with the required loss parameter before injecting. """ - def eve_update_state(opt_state: Tuple[transform.ScaleByEveState,transform.ScaleState], f: float): + def eve_update_state_fn(opt_state: Tuple[transform.ScaleByEveState,transform.ScaleState], f: float): return transform.ScaleByEveState( count=opt_state[0].count, mu=opt_state[0].mu, @@ -390,7 +390,7 @@ def eve_update_state(opt_state: Tuple[transform.ScaleByEveState,transform.ScaleS transform.scale_by_eve( b1=b1, b2=b2, b3=b3, c=c, eps=eps, f_star=f_star, mu_dtype=mu_dtype), _scale_by_learning_rate(learning_rate) - ), eve_update_state + ), eve_update_state_fn def fromage( learning_rate: float, diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index 1b7eb8c4..95ddd7d9 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -112,7 +112,7 @@ def test_optimization(self, opt_name, opt_kwargs, target, dtype): if opt_name != 'eve': opt = getattr(alias, opt_name)(**opt_kwargs) else: - opt, eve_update_state = getattr(alias, opt_name)(**opt_kwargs) + opt, eve_update_state_fn = getattr(alias, opt_name)(**opt_kwargs) initial_params, final_params, get_updates = target(dtype) @jax.jit @@ -122,7 +122,7 @@ def step(params, state): updates = updates[None] elif opt_name == 'eve': f = jnp.mean(jnp.square(params-final_params)) - state = eve_update_state(opt_state=state,f=f) + state = eve_update_state_fn(opt_state=state,f=f) # Complex gradients need to be conjugated before being added to parameters # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 updates = jax.tree_util.tree_map(lambda x: x.conj(), updates) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index e3a8c250..026711f2 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -504,13 +504,13 @@ def update_fn(updates: base.Updates, state: ScaleByEveState, params=None): Example -------- Initialize: - >>> optimizer, eve_update_state = optax.eve() + >>> optimizer, eve_update_state_fn = optax.eve() >>> opt_state = optimizer.init(params) Train: >>> while training: ... loss, grads = jax.value_and_grad(loss_fn)(params, data) - ... opt_state = eve_update_state(opt_state, loss) # <-- Update state here + ... 
opt_state = eve_update_state_fn(opt_state, loss) # <-- Update state here ... updates, opt_state = optimizer.update(grads, opt_state) ... params = optax.apply_updates(params, updates) """ From a3485a4db05333ffb2f587497db319747fd91cc1 Mon Sep 17 00:00:00 2001 From: wglao Date: Sun, 22 Jan 2023 18:01:26 -0600 Subject: [PATCH 20/35] eve passes all tests --- optax/_src/alias.py | 88 ++++++++++++++++++++++++++++++++-------- optax/_src/alias_test.py | 13 +++--- optax/_src/transform.py | 28 +++---------- 3 files changed, 82 insertions(+), 47 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 7382f0e0..19be3835 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -23,6 +23,7 @@ from optax._src import combine from optax._src import factorized from optax._src import privacy +from optax._src import schedule from optax._src import transform from optax._src import wrappers @@ -339,17 +340,18 @@ def amsgrad( _scale_by_learning_rate(learning_rate), ) -def eve( - learning_rate: float = 1e-3, +def _eve( + a1: float = 1e-3, b1: float = 0.9, b2: float = 0.999, b3: float = 0.999, c: float = 10., eps: float = 1e-8, + f: float = 1., f_star: float = 0., mu_dtype: Optional[Any] = None, ) -> base.GradientTransformation: - """The Eve optimizer. + """The Eve optimizer (uninjectable, see eve() below). Eve is an SGD variant with adaptive global and local learning rates. The `learning_rate` used for each weight is computed from estimates of first- and second-order @@ -361,7 +363,7 @@ def eve( Hayashi et al, 2018: https://arXiv.org/abs/1611.01505 Args: - learning_rate: this is the initial global scaling factor. + a1: this is the initial global scaling factor. b1: the exponential decay rate to track the first moment of past gradients. b2: the exponential decay rate to track the second moment of past gradients. b3: the exponential decay rate to track the sub-optimality. @@ -374,23 +376,73 @@ def eve( Returns: the corresponding `GradientTransformation` - a function with which to update the optimizer state with the required loss parameter before injecting. + + Note: + Eve requires an additional parameter: the loss for the current iteration: state.f = f_t + ScaleByEveState also holds the loss from the previous iteration: state.f_prev = f_{t-1} + Since it is up to the user to inject the current loss before calling update on the + parameters, the eve alias returns an injectable state by default by wrapping _eve in + inject_hyperparameters. """ - def eve_update_state_fn(opt_state: Tuple[transform.ScaleByEveState,transform.ScaleState], f: float): - return transform.ScaleByEveState( - count=opt_state[0].count, - mu=opt_state[0].mu, - nu=opt_state[0].nu, - d=opt_state[0].d, - f=f, - f_prev=opt_state[0].f_prev - ), transform.ScaleState() - return combine.chain( transform.scale_by_eve( - b1=b1, b2=b2, b3=b3, c=c, eps=eps, f_star=f_star, mu_dtype=mu_dtype), - _scale_by_learning_rate(learning_rate) - ), eve_update_state_fn + b1=b1, b2=b2, b3=b3, c=c, eps=eps, f=f, f_star=f_star, mu_dtype=mu_dtype), + _scale_by_learning_rate(a1) + ) + + +def eve( + a1: float = 1e-3, + b1: float = 0.9, + b2: float = 0.999, + b3: float = 0.999, + c: float = 10., + eps: float = 1e-8, + f: float = 1., + f_star: float = 0., + mu_dtype: Optional[Any] = None, +): + """Injectable Eve optimizer. 
+ + Eve requires an additional parameter: the loss for the current iteration: state.f = f_t + ScaleByEveState also holds the loss from the previous iteration: state.f_prev = f_{t-1} + Since it is up to the user to inject the current loss before calling update on the + parameters, the eve alias returns an injectable state by default. + + Args: + a1: this is the initial global scaling factor. + b1: the exponential decay rate to track the first moment of past gradients. + b2: the exponential decay rate to track the second moment of past gradients. + b3: the exponential decay rate to track the sub-optimality. + c: the clipping limit to prevent extreme global learning rate changes + eps: a small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + f_star: estimation of the global minimum + mu_dtype: optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + the corresponding `GradientTransformation` wrapped in inject_hyperparameters + + Inject the current loss as follows: + + Initialize:: + + optimizer = optax.eve() + opt_state = optimizer.init(params) + + Train:: + + while training: + loss, grads = jax.value_and_grad(loss_fn)(params, data) + opt_state.hyperparams['f'] = loss # <-- Update state here + updates, opt_state = optimizer.update(grads, opt_state) + params = optax.apply_updates(params, updates) + """ + return schedule.inject_hyperparams(_eve)( + a1=a1, b1=b1, b2=b2, b3=b3, c=c, eps=eps, f=f, f_star=f_star, mu_dtype=mu_dtype + ) + def fromage( learning_rate: float, diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index 95ddd7d9..f8986e5a 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -35,7 +35,7 @@ dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1e-1)), dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1e-1)), dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='eve', opt_kwargs=dict()), + dict(opt_name='eve', opt_kwargs=dict(f=10)), dict(opt_name='lars', opt_kwargs=dict(learning_rate=1.0)), dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1e-3)), dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1e-3, eta=1e-4)), @@ -109,10 +109,7 @@ def test_optimization(self, opt_name, opt_kwargs, target, dtype): raise absltest.SkipTest( f'{opt_name} does not support complex parameters.') - if opt_name != 'eve': - opt = getattr(alias, opt_name)(**opt_kwargs) - else: - opt, eve_update_state_fn = getattr(alias, opt_name)(**opt_kwargs) + opt = getattr(alias, opt_name)(**opt_kwargs) initial_params, final_params, get_updates = target(dtype) @jax.jit @@ -122,7 +119,7 @@ def step(params, state): updates = updates[None] elif opt_name == 'eve': f = jnp.mean(jnp.square(params-final_params)) - state = eve_update_state_fn(opt_state=state,f=f) + state.hyperparams['f'] = f # Complex gradients need to be conjugated before being added to parameters # https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 updates = jax.tree_util.tree_map(lambda x: x.conj(), updates) @@ -151,6 +148,10 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams( # https://github.com/deepmind/optax/issues/412. opt_inject = schedule.inject_hyperparams( opt_factory, static_args=('min_dim_size_to_factor',))(**opt_kwargs) + elif opt_name == 'eve': + # Eve is injectable by default. 
Reassign opt to uninjectable _eve alias + opt = alias._eve(**opt_kwargs) + opt_inject = opt_factory(**opt_kwargs) else: opt_inject = schedule.inject_hyperparams(opt_factory)(**opt_kwargs) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 026711f2..af8adf2b 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -455,7 +455,6 @@ class ScaleByEveState(NamedTuple): mu: base.Updates nu: base.Updates d: float - f: float f_prev: float @@ -464,6 +463,7 @@ def scale_by_eve(b1: float = 0.9, b3: float = 0.999, c: float = 10., eps: float = 1e-8, + f: float = 1., f_star: float = 0., mu_dtype: Optional[Any] = None, ) -> base.GradientTransformation: @@ -492,42 +492,24 @@ def init_fn(params): mu = jax.tree_util.tree_map( # First moment lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByEveState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f= 1., f_prev=10.) + return ScaleByEveState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f_prev=10.) + def update_fn(updates: base.Updates, state: ScaleByEveState, params=None): - """ - Eve requires an additional parameter: the loss for the current iteration: state.f = f_t - ScaleByEveState also holds the loss from the previous iteration: state.f_prev = f_{t-1} - It is up to the user to update the state with the current loss before injecting using the - second returned function from optax.eve() as follows: - - Example - -------- - Initialize: - >>> optimizer, eve_update_state_fn = optax.eve() - >>> opt_state = optimizer.init(params) - - Train: - >>> while training: - ... loss, grads = jax.value_and_grad(loss_fn)(params, data) - ... opt_state = eve_update_state_fn(opt_state, loss) # <-- Update state here - ... updates, opt_state = optimizer.update(grads, opt_state) - ... params = optax.apply_updates(params, updates) - """ del params mu = update_moment(updates, state.mu, b1, 1) nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) count_inc = utils.numerics.safe_int32_increment(state.count) mu_hat = jax.tree_util.tree_map(lambda m: m / (1-b1), mu) nu_hat = jax.tree_util.tree_map(lambda v: v / (1-b2), nu) - d_new = jnp.abs(state.f - state.f_prev) / (jnp.min(jnp.array([state.f,state.f_prev])) - f_star) + d_new = jnp.abs(f - state.f_prev) / (jnp.min(jnp.array([f,state.f_prev])) - f_star) d_tilde = jnp.clip(d_new,1/c,c) d = jnp.where(count_inc > 1, b3*state.d + (1-b3)*d_tilde, 1.) updates = jax.tree_util.tree_map( lambda m, v: m / (jnp.sqrt(v) + eps) / d, mu_hat, nu_hat) mu = utils.cast_tree(mu, mu_dtype) # assign a decayed f to ScaleByEveState for testing purposes. 
User will need to update f in practice - return updates, ScaleByEveState(count=count_inc, mu=mu, nu=nu, d=d, f=b3*state.f, f_prev=state.f) + return updates, ScaleByEveState(count=count_inc, mu=mu, nu=nu, d=d, f_prev=f) return base.GradientTransformation(init_fn, update_fn) From b19077d2cb76db734dd50914e0eecce58b7f3492 Mon Sep 17 00:00:00 2001 From: wglao Date: Sun, 22 Jan 2023 18:05:47 -0600 Subject: [PATCH 21/35] typo --- optax/_src/alias.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 19be3835..8242233b 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -382,7 +382,7 @@ def _eve( ScaleByEveState also holds the loss from the previous iteration: state.f_prev = f_{t-1} Since it is up to the user to inject the current loss before calling update on the parameters, the eve alias returns an injectable state by default by wrapping _eve in - inject_hyperparameters. + inject_hyperparams. """ return combine.chain( transform.scale_by_eve( @@ -407,7 +407,8 @@ def eve( Eve requires an additional parameter: the loss for the current iteration: state.f = f_t ScaleByEveState also holds the loss from the previous iteration: state.f_prev = f_{t-1} Since it is up to the user to inject the current loss before calling update on the - parameters, the eve alias returns an injectable state by default. + parameters, the eve alias returns an injectable state by default by wrapping _eve in + inject_hyperparams. Args: a1: this is the initial global scaling factor. @@ -422,7 +423,7 @@ def eve( `None` then the `dtype` is inferred from `params` and `updates`. Returns: - the corresponding `GradientTransformation` wrapped in inject_hyperparameters + the corresponding `GradientTransformation` wrapped in inject_hyperparams Inject the current loss as follows: From 3099b6cfe01bf32739020022b25b876dd50c66d8 Mon Sep 17 00:00:00 2001 From: wglao Date: Sun, 22 Jan 2023 18:06:39 -0600 Subject: [PATCH 22/35] remove unnecessary import --- optax/_src/alias.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 8242233b..d8798861 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -14,7 +14,7 @@ # ============================================================================== """Aliases for popular optimizers.""" -from typing import Any, Callable, Optional, Union, Tuple +from typing import Any, Callable, Optional, Union import jax.numpy as jnp From 252be75cf9eaa6fae30a90321ccc7f0797135a40 Mon Sep 17 00:00:00 2001 From: wglao Date: Sun, 22 Jan 2023 18:10:01 -0600 Subject: [PATCH 23/35] update documentation --- optax/_src/alias.py | 3 ++- optax/_src/transform.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index d8798861..768207d8 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -357,7 +357,8 @@ def _eve( used for each weight is computed from estimates of first- and second-order moments of the gradients (using suitable exponential moving averages) as in ADAM. The global learning rate is scaled by some notion of sub-optimality and is increased - when far from optimal and is decreased when approaching optimality + when far from optimal and is decreased when approaching optimality. This is also computed + with exponential moving averages, similar to the first and second moments. 
References: Hayashi et al, 2018: https://arXiv.org/abs/1611.01505 diff --git a/optax/_src/transform.py b/optax/_src/transform.py index af8adf2b..cbda28b4 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -479,6 +479,7 @@ def scale_by_eve(b1: float = 0.9, c: the clipping limit to prevent extreme global learning rate changes eps: a small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. + f: the current loss value. (needs to be injected before update is called) f_star: estimation of the global minimum mu_dtype: optional `dtype` to be used for the first order accumulator; if `None` then the `dtype` is inferred from `params` and `updates`. From 12ef13e6bcec00de8059e126b9514e87900c093f Mon Sep 17 00:00:00 2001 From: wglao Date: Sun, 22 Jan 2023 18:11:59 -0600 Subject: [PATCH 24/35] update documentation --- optax/_src/alias.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 768207d8..3bf326ce 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -379,11 +379,10 @@ def _eve( the corresponding `GradientTransformation` Note: - Eve requires an additional parameter: the loss for the current iteration: state.f = f_t + Eve requires an additional parameter: the loss for the current iteration: f = f_t ScaleByEveState also holds the loss from the previous iteration: state.f_prev = f_{t-1} - Since it is up to the user to inject the current loss before calling update on the - parameters, the eve alias returns an injectable state by default by wrapping _eve in - inject_hyperparams. + Since it is up to the user to inject the current loss before calling the update function, + the eve alias returns an injectable state by default by wrapping _eve in inject_hyperparams. """ return combine.chain( transform.scale_by_eve( @@ -405,11 +404,10 @@ def eve( ): """Injectable Eve optimizer. - Eve requires an additional parameter: the loss for the current iteration: state.f = f_t + Eve requires an additional parameter: the loss for the current iteration: f = f_t ScaleByEveState also holds the loss from the previous iteration: state.f_prev = f_{t-1} - Since it is up to the user to inject the current loss before calling update on the - parameters, the eve alias returns an injectable state by default by wrapping _eve in - inject_hyperparams. + Since it is up to the user to inject the current loss before calling the update function, + the eve alias returns an injectable state by default by wrapping _eve in inject_hyperparams. Args: a1: this is the initial global scaling factor. From ea0d0cad07c963491003d90d0110267c4a785010 Mon Sep 17 00:00:00 2001 From: wglao Date: Sun, 22 Jan 2023 18:19:42 -0600 Subject: [PATCH 25/35] update doc strings --- optax/_src/alias.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 3bf326ce..e86faa3d 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -351,14 +351,14 @@ def _eve( f_star: float = 0., mu_dtype: Optional[Any] = None, ) -> base.GradientTransformation: - """The Eve optimizer (uninjectable, see eve() below). + """The Eve optimizer (uninjectable, see `eve()`). - Eve is an SGD variant with adaptive global and local learning rates. 
The `learning_rate` - used for each weight is computed from estimates of first- and second-order - moments of the gradients (using suitable exponential moving averages) as in ADAM. - The global learning rate is scaled by some notion of sub-optimality and is increased - when far from optimal and is decreased when approaching optimality. This is also computed - with exponential moving averages, similar to the first and second moments. + Eve is an SGD variant with adaptive global and local learning rates. The local learning rate + used for each weight is computed from estimates of first- and second-order moments of the + gradients (using suitable exponential moving averages) as in ADAM. These are then scaled by the + global learning rate `a1`, which is scaled by some notion of sub-optimality `d`: increasing the + global rate when far from optimal and decreasing it when approaching optimality. This is also + computed with exponential moving averages, similar to the first and second moments. References: Hayashi et al, 2018: https://arXiv.org/abs/1611.01505 @@ -371,6 +371,7 @@ def _eve( c: the clipping limit to prevent extreme global learning rate changes eps: a small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. + f: the current loss value. (needs to be injected before update is called) f_star: estimation of the global minimum mu_dtype: optional `dtype` to be used for the first order accumulator; if `None` then the `dtype` is inferred from `params` and `updates`. @@ -379,10 +380,10 @@ def _eve( the corresponding `GradientTransformation` Note: - Eve requires an additional parameter: the loss for the current iteration: f = f_t - ScaleByEveState also holds the loss from the previous iteration: state.f_prev = f_{t-1} + Eve requires an additional parameter: the loss for the current iteration: `f` = `f_t` + ScaleByEveState also holds the loss from the previous iteration: `state.f_prev` = `f_{t-1}` Since it is up to the user to inject the current loss before calling the update function, - the eve alias returns an injectable state by default by wrapping _eve in inject_hyperparams. + the `eve` alias returns an injectable state by default by wrapping `_eve` in `inject_hyperparams`. """ return combine.chain( transform.scale_by_eve( @@ -404,10 +405,10 @@ def eve( ): """Injectable Eve optimizer. - Eve requires an additional parameter: the loss for the current iteration: f = f_t - ScaleByEveState also holds the loss from the previous iteration: state.f_prev = f_{t-1} + Eve requires an additional parameter: the loss for the current iteration: `f` = `f_t` + ScaleByEveState also holds the loss from the previous iteration: `state.f_prev` = `f_{t-1}` Since it is up to the user to inject the current loss before calling the update function, - the eve alias returns an injectable state by default by wrapping _eve in inject_hyperparams. + the `eve` alias returns an injectable state by default by wrapping `_eve` in `inject_hyperparams`. Args: a1: this is the initial global scaling factor. @@ -417,6 +418,7 @@ def eve( c: the clipping limit to prevent extreme global learning rate changes eps: a small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. + f: the current loss value. 
(needs to be injected before update is called) f_star: estimation of the global minimum mu_dtype: optional `dtype` to be used for the first order accumulator; if `None` then the `dtype` is inferred from `params` and `updates`. @@ -425,6 +427,7 @@ def eve( the corresponding `GradientTransformation` wrapped in inject_hyperparams Inject the current loss as follows: + ----------------------------------- Initialize:: From c23520f4c1d635e32e20bf720d74a67a919a9082 Mon Sep 17 00:00:00 2001 From: wglao Date: Sun, 22 Jan 2023 18:21:53 -0600 Subject: [PATCH 26/35] update doc string --- optax/_src/alias.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index e86faa3d..3e115b0d 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -356,9 +356,9 @@ def _eve( Eve is an SGD variant with adaptive global and local learning rates. The local learning rate used for each weight is computed from estimates of first- and second-order moments of the gradients (using suitable exponential moving averages) as in ADAM. These are then scaled by the - global learning rate `a1`, which is scaled by some notion of sub-optimality `d`: increasing the - global rate when far from optimal and decreasing it when approaching optimality. This is also - computed with exponential moving averages, similar to the first and second moments. + global learning rate `a1`, which is adaptively modified by some notion of sub-optimality `d`: + increasing the global rate when far from optimal and decreasing it when approaching optimality. + This is also computed with exponential moving averages, similar to the first and second moments. References: Hayashi et al, 2018: https://arXiv.org/abs/1611.01505 From 74e4a1229952355183cfbc17a54688bf8f021856 Mon Sep 17 00:00:00 2001 From: wglao Date: Sun, 22 Jan 2023 18:24:08 -0600 Subject: [PATCH 27/35] formatting --- optax/_src/transform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index cbda28b4..2eabfd07 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -458,7 +458,8 @@ class ScaleByEveState(NamedTuple): f_prev: float -def scale_by_eve(b1: float = 0.9, +def scale_by_eve( + b1: float = 0.9, b2: float = 0.999, b3: float = 0.999, c: float = 10., From 58d26593851c82e2f9d43de36e930d3831b1c1aa Mon Sep 17 00:00:00 2001 From: wglao Date: Sun, 22 Jan 2023 18:49:44 -0600 Subject: [PATCH 28/35] formatting and typo --- optax/__init__.py | 2 +- optax/_src/alias.py | 46 ++++++++++++++++++++++++++------------- optax/_src/transform.py | 48 ++++++++++++++++++++++------------------- 3 files changed, 58 insertions(+), 38 deletions(-) diff --git a/optax/__init__.py b/optax/__init__.py index d3242dca..54062529 100644 --- a/optax/__init__.py +++ b/optax/__init__.py @@ -226,7 +226,7 @@ "ema", "EmaState", "EmptyState", - "eve" + "eve", "exponential_decay", "FactoredState", "fisher_diag", diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 3e115b0d..ce79a5c6 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -353,12 +353,14 @@ def _eve( ) -> base.GradientTransformation: """The Eve optimizer (uninjectable, see `eve()`). - Eve is an SGD variant with adaptive global and local learning rates. The local learning rate - used for each weight is computed from estimates of first- and second-order moments of the - gradients (using suitable exponential moving averages) as in ADAM. 
These are then scaled by the - global learning rate `a1`, which is adaptively modified by some notion of sub-optimality `d`: - increasing the global rate when far from optimal and decreasing it when approaching optimality. - This is also computed with exponential moving averages, similar to the first and second moments. + Eve is an SGD variant with adaptive global and local learning rates. + The local learning rate used for each weight is computed from estimates of + first- and second-order moments of the gradients (using suitable exponential + moving averages) as in ADAM. These are then scaled by the global learning + rate `a1`, which is adaptively modified by some notion of sub-optimality `d`: + increasing the global rate when far from optimal and decreasing it when + approaching optimality. This is also computed with exponential moving + averages, similar to the first and second moments. References: Hayashi et al, 2018: https://arXiv.org/abs/1611.01505 @@ -378,12 +380,19 @@ def _eve( Returns: the corresponding `GradientTransformation` - + Note: - Eve requires an additional parameter: the loss for the current iteration: `f` = `f_t` - ScaleByEveState also holds the loss from the previous iteration: `state.f_prev` = `f_{t-1}` - Since it is up to the user to inject the current loss before calling the update function, - the `eve` alias returns an injectable state by default by wrapping `_eve` in `inject_hyperparams`. + Eve requires an additional parameter: the loss for the current iteration:: + + f := f_t + + ScaleByEveState also holds the loss from the previous iteration:: + + state.f_prev := f_{t-1} + + Since it is up to the user to inject the current loss before calling the + update function, the `eve` alias returns an injectable state by default by + wrapping `_eve` in `inject_hyperparams`. """ return combine.chain( transform.scale_by_eve( @@ -405,10 +414,17 @@ def eve( ): """Injectable Eve optimizer. - Eve requires an additional parameter: the loss for the current iteration: `f` = `f_t` - ScaleByEveState also holds the loss from the previous iteration: `state.f_prev` = `f_{t-1}` - Since it is up to the user to inject the current loss before calling the update function, - the `eve` alias returns an injectable state by default by wrapping `_eve` in `inject_hyperparams`. + Eve requires an additional parameter: the loss for the current iteration:: + + f := f_t + + ScaleByEveState also holds the loss from the previous iteration:: + + state.f_prev := f_{t-1} + + Since it is up to the user to inject the current loss before calling the + update function, the `eve` alias returns an injectable state by default by + wrapping `_eve` in `inject_hyperparams`. Args: a1: this is the initial global scaling factor. diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 2eabfd07..1d6f7acf 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -450,12 +450,12 @@ def update_fn(updates, state, params=None): class ScaleByEveState(NamedTuple): - """State for the Eve algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - mu: base.Updates - nu: base.Updates - d: float - f_prev: float + """State for the Eve algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: base.Updates + nu: base.Updates + d: float + f_prev: float def scale_by_eve( @@ -471,22 +471,22 @@ def scale_by_eve( """Rescale updates according to the Eve algorithm. 
References: - [Hayashi et al, 2018](https://arxiv.org/abs/1611.01505) + [Hayashi et al, 2018](https://arxiv.org/abs/1611.01505) Args: - b1: the exponential decay rate to track the first moment of past gradients. - b2: the exponential decay rate to track the second moment of past gradients. - b3: the exponential decay rate to track the sub-optimality. - c: the clipping limit to prevent extreme global learning rate changes - eps: a small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. - f: the current loss value. (needs to be injected before update is called) - f_star: estimation of the global minimum - mu_dtype: optional `dtype` to be used for the first order accumulator; if - `None` then the `dtype` is inferred from `params` and `updates`. + b1: the exponential decay rate to track the first moment of past gradients. + b2: the exponential decay rate to track the second moment of past gradients. + b3: the exponential decay rate to track the sub-optimality. + c: the clipping limit to prevent extreme global learning rate changes + eps: a small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + f: the current loss value. (needs to be injected before update is called) + f_star: estimation of the global minimum + mu_dtype: optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. Returns: - An (init_fn, update_fn) tuple. + An (init_fn, update_fn) tuple. """ mu_dtype = utils.canonicalize_dtype(mu_dtype) @@ -494,7 +494,9 @@ def init_fn(params): mu = jax.tree_util.tree_map( # First moment lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByEveState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f_prev=10.) + return ScaleByEveState( + count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, d=1., f_prev=10. + ) def update_fn(updates: base.Updates, state: ScaleByEveState, params=None): @@ -504,14 +506,16 @@ def update_fn(updates: base.Updates, state: ScaleByEveState, params=None): count_inc = utils.numerics.safe_int32_increment(state.count) mu_hat = jax.tree_util.tree_map(lambda m: m / (1-b1), mu) nu_hat = jax.tree_util.tree_map(lambda v: v / (1-b2), nu) - d_new = jnp.abs(f - state.f_prev) / (jnp.min(jnp.array([f,state.f_prev])) - f_star) + d_new = jnp.abs(f-state.f_prev) /\ + (jnp.min(jnp.array([f,state.f_prev]))-f_star) d_tilde = jnp.clip(d_new,1/c,c) d = jnp.where(count_inc > 1, b3*state.d + (1-b3)*d_tilde, 1.) updates = jax.tree_util.tree_map( lambda m, v: m / (jnp.sqrt(v) + eps) / d, mu_hat, nu_hat) mu = utils.cast_tree(mu, mu_dtype) - # assign a decayed f to ScaleByEveState for testing purposes. 
User will need to update f in practice - return updates, ScaleByEveState(count=count_inc, mu=mu, nu=nu, d=d, f_prev=f) + return updates, ScaleByEveState( + count=count_inc, mu=mu, nu=nu, d=d, f_prev=f + ) return base.GradientTransformation(init_fn, update_fn) From 77df2f3839adccd4ca34fffdbb57134a0a376642 Mon Sep 17 00:00:00 2001 From: wglao Date: Sun, 22 Jan 2023 18:54:45 -0600 Subject: [PATCH 29/35] formatting --- optax/_src/alias.py | 141 ++++++++++++++++++++++---------------------- 1 file changed, 71 insertions(+), 70 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index ce79a5c6..682c9af1 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -353,46 +353,46 @@ def _eve( ) -> base.GradientTransformation: """The Eve optimizer (uninjectable, see `eve()`). - Eve is an SGD variant with adaptive global and local learning rates. - The local learning rate used for each weight is computed from estimates of - first- and second-order moments of the gradients (using suitable exponential - moving averages) as in ADAM. These are then scaled by the global learning - rate `a1`, which is adaptively modified by some notion of sub-optimality `d`: - increasing the global rate when far from optimal and decreasing it when - approaching optimality. This is also computed with exponential moving - averages, similar to the first and second moments. - - References: - Hayashi et al, 2018: https://arXiv.org/abs/1611.01505 - - Args: - a1: this is the initial global scaling factor. - b1: the exponential decay rate to track the first moment of past gradients. - b2: the exponential decay rate to track the second moment of past gradients. - b3: the exponential decay rate to track the sub-optimality. - c: the clipping limit to prevent extreme global learning rate changes - eps: a small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. - f: the current loss value. (needs to be injected before update is called) - f_star: estimation of the global minimum - mu_dtype: optional `dtype` to be used for the first order accumulator; if - `None` then the `dtype` is inferred from `params` and `updates`. + Eve is an SGD variant with adaptive global and local learning rates. + The local learning rate used for each weight is computed from estimates of + first- and second-order moments of the gradients (using suitable exponential + moving averages) as in ADAM. These are then scaled by the global learning + rate `a1`, which is adaptively modified by some notion of sub-optimality `d`: + increasing the global rate when far from optimal and decreasing it when + approaching optimality. This is also computed with exponential moving + averages, similar to the first and second moments. + + References: + Hayashi et al, 2018: https://arXiv.org/abs/1611.01505 - Returns: - the corresponding `GradientTransformation` + Args: + a1: this is the initial global scaling factor. + b1: the exponential decay rate to track the first moment of past gradients. + b2: the exponential decay rate to track the second moment of past gradients. + b3: the exponential decay rate to track the sub-optimality. + c: the clipping limit to prevent extreme global learning rate changes + eps: a small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + f: the current loss value. 
(needs to be injected before update is called) + f_star: estimation of the global minimum + mu_dtype: optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + the corresponding `GradientTransformation` - Note: - Eve requires an additional parameter: the loss for the current iteration:: + Note: + Eve requires an additional parameter: the loss for the current iteration:: - f := f_t + f := f_t - ScaleByEveState also holds the loss from the previous iteration:: + ScaleByEveState also holds the loss from the previous iteration:: - state.f_prev := f_{t-1} + state.f_prev := f_{t-1} - Since it is up to the user to inject the current loss before calling the - update function, the `eve` alias returns an injectable state by default by - wrapping `_eve` in `inject_hyperparams`. + Since it is up to the user to inject the current loss before calling the + update function, the `eve` alias returns an injectable state by default by + wrapping `_eve` in `inject_hyperparams`. """ return combine.chain( transform.scale_by_eve( @@ -411,55 +411,56 @@ def eve( f: float = 1., f_star: float = 0., mu_dtype: Optional[Any] = None, -): +) -> base.GradientTransformation: """Injectable Eve optimizer. - - Eve requires an additional parameter: the loss for the current iteration:: - f := f_t + Eve requires an additional parameter: the loss for the current iteration:: - ScaleByEveState also holds the loss from the previous iteration:: + f := f_t - state.f_prev := f_{t-1} + ScaleByEveState also holds the loss from the previous iteration:: - Since it is up to the user to inject the current loss before calling the - update function, the `eve` alias returns an injectable state by default by - wrapping `_eve` in `inject_hyperparams`. + state.f_prev := f_{t-1} - Args: - a1: this is the initial global scaling factor. - b1: the exponential decay rate to track the first moment of past gradients. - b2: the exponential decay rate to track the second moment of past gradients. - b3: the exponential decay rate to track the sub-optimality. - c: the clipping limit to prevent extreme global learning rate changes - eps: a small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. - f: the current loss value. (needs to be injected before update is called) - f_star: estimation of the global minimum - mu_dtype: optional `dtype` to be used for the first order accumulator; if - `None` then the `dtype` is inferred from `params` and `updates`. + Since it is up to the user to inject the current loss before calling the + update function, the `eve` alias returns an injectable state by default by + wrapping `_eve` in `inject_hyperparams`. + + Args: + a1: this is the initial global scaling factor. + b1: the exponential decay rate to track the first moment of past gradients. + b2: the exponential decay rate to track the second moment of past gradients. + b3: the exponential decay rate to track the sub-optimality. + c: the clipping limit to prevent extreme global learning rate changes + eps: a small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + f: the current loss value. (needs to be injected before update is called) + f_star: estimation of the global minimum + mu_dtype: optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. 
+
+  Returns:
+    the corresponding `GradientTransformation` wrapped in inject_hyperparams
 
-  Returns:
-    the corresponding `GradientTransformation` wrapped in inject_hyperparams
 
-  Inject the current loss as follows:
-  -----------------------------------
+  Inject the current loss as follows:
+  -----------------------------------
 
-  Initialize::
+  Initialize::
 
-    optimizer = optax.eve()
-    opt_state = optimizer.init(params)
+    optimizer = optax.eve()
+    opt_state = optimizer.init(params)
 
-  Train::
+  Train::
-
-    while training:
-      loss, grads = jax.value_and_grad(loss_fn)(params, data)
-      opt_state.hyperparams['f'] = loss # <-- Update state here
-      updates, opt_state = optimizer.update(grads, opt_state)
-      params = optax.apply_updates(params, updates)
+
+    while training:
+      loss, grads = jax.value_and_grad(loss_fn)(params, data)
+      opt_state.hyperparams['f'] = loss # <-- Update state here
+      updates, opt_state = optimizer.update(grads, opt_state)
+      params = optax.apply_updates(params, updates)
   """
   return schedule.inject_hyperparams(_eve)(
-      a1=a1, b1=b1, b2=b2, b3=b3, c=c, eps=eps, f=f, f_star=f_star, mu_dtype=mu_dtype
+      a1=a1, b1=b1, b2=b2, b3=b3, c=c, eps=eps,
+      f=f, f_star=f_star, mu_dtype=mu_dtype
   )
 

From 13975db01429746f2c1221d725c202d8ad20a31c Mon Sep 17 00:00:00 2001
From: wglao
Date: Sun, 22 Jan 2023 19:04:00 -0600
Subject: [PATCH 30/35] test

---
 optax-0.1.5.dev0-py3-none-any.whl | Bin 0 -> 156407 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 optax-0.1.5.dev0-py3-none-any.whl

diff --git a/optax-0.1.5.dev0-py3-none-any.whl b/optax-0.1.5.dev0-py3-none-any.whl
new file mode 100644
index 0000000000000000000000000000000000000000..28a2c86228f0badd42c41ba4f23a378efddadca1
GIT binary patch
literal 156407

[156407 bytes of base85-encoded binary wheel data omitted]
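
The loss-injection pattern documented in the `eve()` docstring above can be exercised end-to-end roughly as follows. This is a minimal, self-contained sketch, assuming the patched optax from this series (which exposes `optax.eve` wrapped in `inject_hyperparams`) is importable; the quadratic `loss_fn`, the toy `params` and `data` arrays, and the step count are illustrative placeholders rather than part of the patch::

    import jax
    import jax.numpy as jnp
    import optax  # assumes the Eve patches in this series are applied

    def loss_fn(params, data):
      # Toy quadratic objective; stands in for a real training loss.
      return jnp.sum((params - data) ** 2)

    params = jnp.zeros(3)
    data = jnp.array([1., 2., 3.])

    optimizer = optax.eve()              # injectable Eve; f_star defaults to 0.
    opt_state = optimizer.init(params)

    for _ in range(100):
      loss, grads = jax.value_and_grad(loss_fn)(params, data)
      opt_state.hyperparams['f'] = loss  # inject the current loss before update
      updates, opt_state = optimizer.update(grads, opt_state)
      params = optax.apply_updates(params, updates)

Keeping the injection as an explicit, user-visible assignment (rather than threading `f` through `update` as an extra argument) follows the existing `inject_hyperparams` convention in optax for hyperparameters that change during training, which is presumably why the `eve` alias wraps `_eve` by default.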
z!6mb@35iBu9EjTG_J(^C<`GY5bUihLYL*))dT%Bc0_Qn3)_d%_sF(xFWJGSLQB2-0 z)FZ&7)$9)RyFo?~dr&RDbtVsR=^c8X&P!@o&LW#{o~B0@%&-5JnF9Rf_$+N8;==qE z%_<@Q0ATT4ehIhKSNqzRHAItc+KBvkcCfFv`5gEp}HjY1*^ z1L9+&Tbg28GB45J(QPjQo8sz!Y-wuYTwl|oMa6U06)ZeHX;^TN zYZeTqv!gQT$PBF50~L*b;d_(v}kB|e}%b?;7Rth2ef>2V#)HpNjcL5Wquo5WUssAM!X*=pNc7k(qnQb1j z&)?%h4~?`y&<)gOB7$g?dr(nup4toZ_0-bRQnHN?0Z~IuRWUGj#+mpnMI0g%3DJcg zRNfHMoLHku$w&ccui^%xk)dCv0%v2^FrQ}*S?3|ajL2P~NH^saFFk;k-!Mv$U%vQ= zwb4bAAt7V7Cyz?$N+~Ggv&2QOa?y70)@9-0kE5fLCu8uhtWR9iYk&xm0d}wNSQDh@ zej}a4==}h^0=M^{Rr`4zRHJ~_(;+stfQKh4VRfAakj|Y&0y_x53z55s06^-bdE1+5 z7kBARGmYe(tYN(K>ed}l3J>lw%a9-{P6}5U$j_&!zTfzi=ZHm0!2T%%7gzw0h&A&m zc3kd-!%Mi)D=I!|Sz3~ko^_f@h?^t7S-H8xKc0j4oO^d4fUux}#i`}n7Tx>7mbSjA z!G+~AyHw&kh&Hf`%>qJh=Ot2*P#c}Hp8~+U-xya<6@8UCBxnK)NtQA+Hddb~>PQwE z$vnkDTI)=9=8CIlX^df=AM%UeBiCh1ivB%Z~Ovjler<$?PKMY^a9A(oAm2(JH znEwIG>k{^JalHl@-N+X_NbvlJu|Dst5+2d_T~$qt(?ave79AA(nDz}t(5VIlPBueP zH4lW=s$Dbt*L0fket{1%$vnoUs4I1sQ1i?Ds(D^1#d|x0#K6m}D^&OT2@hFW5MTqk{(jzLVim)r|fW^HoBZm%&F1g0$!<+{b zP}46#dLU;^IdIOlAhd`4D6cuNxofME)3XVf587V{xj~HvQ5kDpLA!tT1+?|y zyolEP;*89~XA2>rqfz@{2h>)sk51KPe|wr`>Grx;G7}7J$amOQnlosHmVE&sLGB;p z^`H=@?&Vs)JTK~GTeK--pdRvK|3_7p;|&_5xanG2inG3cB3)WU!){Ucf>>ngcT?8Y zkh`1C#vOA~Kxwu!qfxSRz%c9H_Iat7xYCmo#yvh!!^VLtPtRZYu>1GYb`RM1&K7pU zu5fKc&x@w$9(;#t+&70%vu{;{DAwl48r8H?IRPjpWU%g*Z6#A!jULz?6>kswq-^*K ztC?(}+(>~u(^hU(9Bi5s9?VVcA0VS<7Td;QW~57BgoIK}c?^T)<5rzoRSofOZ_(U7 zd45#P9X8}ipAK(Ic)h1C-Ns;ET5L0>O$iVtj%1bR(aK2<#l-!p3(7UZ&~jW+yz$ty zAHe@ndH->VjI}@QUj7y}RsPn2{$_9gyUMe7{AIxzdHjmJYE^^(NZkIb(Q~lxmB=aL zpqvB5W9qZ&n6f$nY3ZzifC@{MW73LTl9-77ymS*EukS9=0-w?&QY`I!-Rb5&#_VEx zUYT$_-uJfm1xpKmw)=C+Flt>K!&EE+zS@w8_lP1U)t_91D>Gl9vuN-{<0*!SX&J`J z!X*yv8^TZ_D(2Tg?(>FQog_Bs!R4P0?73Z}`v7dRmI+2x4{F210!JC*tRNBM50ANm z_cwG(xj9kD3Q z(fE8Z3Zh6bE9$HDbY6I4Jsw%XfY;6vRACv{I|Uvbo#M9A5!U`*VNEAW$@8P$zEz7< zBMMyaGmT=yNrJra&}DVP27Gz>RRXlN?1~_(n;>WI(f|@xEZam<-fu3Q!e-z2VHt^7 zvJ)e*KjY5f-w;HAd6-ip=jjsDlGerXW$djGm7P;D4s_>CCqftX+du^XmDI?;ZMaOA zN@~=*op%MziaMp)sPy;(-T+AhJF% z1tqE~M+pLCX?t{h@T-u#5E*jZd*UuuQX`qkMNmyr0KuCM=Qi6CGUq0gA%$M3uu1~@<-fD_LtFD2!!*#0AY@e@B_P?C#**zbK z41|leTq)pqRuFwvlmuKl?k#R>_$t&qQm-^hNAEi2Jg3xg9JQtR(?eroI1ge+6p7I# z>o4Eo`sV!`RUQ{ibZIYNuzA!F@G^83GX{D2`|-|e9gTav$C*1gjOgt44T?N7+c=(v z792B3WG|sfEe3im(AEq%%6^?<1Fq#IL1D5Y(NMYhovZ~kYC zxRnB>+V5uL0X#852ZZp=rr)xw`+z?-);S^0@chWM=+(QXvp;mHS~%JXQBV+@@pdra zz=cO22Sk5#3P{=6s9eia8Z>#Q)x0@_eJDhcCLSfD8AMrRBA_cgB?%DJ{ zSG$xe?{V#(NGZ|zy@5p@4YYkv$h;RM%(d^m}&qC7q|G}>RPjz%nm;;de zZLyaT{};Rd|F4e!-C;NXUr>S#lwtbo4Pyp5dI@$zYo;OdP9g{tjRK>Vrn2}XT-P<< zU0l-5hU^-d)bl{%_|n5}?w2jf54(bhV6^oi!Pf2TKFZT!%Xo6T;X{dA?lx{f#yNJdE(7{{fwdK%P%4 zIX}o2YGp_zR)Fmda#%4`LHRY@=kO)SS$aCV!a1*sWWSFLhcy)nhK zTu)zV6D9Z*XW)5;nkEiD?)2$H9EYJ$aU-!hWJaOtU}EFL{5m1lm5zI&lnsp8w<#q{ z2d54I^;}L`^t|NwNghj`wBuG|#8tMXzDN$Jo|46>wBu3&gkdvAstd82)Gv{s1Rx~x zxQ-p8->wOq4_ls6f{t$g=-|`IjkW8~XlW|LA7A)h z*xA3pdl*#QA-kv}Nm*APpD#kmgvq0_ z(ZQ=j4o!qi2$+65~lsKKcp@-bSC#y_CQu+>A!J!l+Z8Dyep zpu^2q#rA0d+d9bDcu&Gq2L68M)qNw&(DuCAFV+*b*3am6Cj1(1Y0O=V%}?(~-hYCUtPM#5G!kW120{p( zA!V~7$UvP8XXDXNj;cO@$~egEMrRgPWl9<&08fR)0UEwBu7Y0G1F9ZH#Q&}Opnp!2 z34v#p`Nlf96pT5NSuOPLS4otP1E}3>j4CApdNd*D-%m;f)IGITxBC>o<|8cUMnqnaZ1#yxo zYPew|0TxCTOifl0H3ECyTtnbqTqK7ov{{MD#8aCutD&>`r86aXdhV`hx5hEJNT{qP z3R;A&suhtFi7GM(M#>SzDbpFE`@PztHzC_0#h3EZxuMZ;Ofcbv!M?N1-W*s=4l%!8 zP^)%kdVvbL%1bT%{Q9gS`15@k3L5OAHhPQ-8Ii%D9OgyEVl~GLkDs1R!sN-Q%jK2D zPnOFH{15uvE%M6;j_()Qzm(%+q$?PBwgv6Z54AmWUTzS&za2loIu=w|^2+q3FkJm! 
z1oY-D{G100)8)mtN!&{0@$?`~mQDaSwjnc9T3#-%aiE!3Bygo_!*qBZ2!XM zpuUV>L82W|Ix!PT;W9BKhJrNJffwYvGo|eMVSUfYN7G?Pu~vVx1-G+=rpzWG+UiT!aSktQ#*LqkHyKEtBLj;L$3we)y2c}wL#W&D`1I+Og3_ zB_x|fYO~qh;zPC7j-o2Wbzu;kM*dPk9;B8gk+0@8AVcCnQcq$coL-4?Jz<#_Ti2fs zrkexmH#zGZ>Xra*`Y33@5D6sA{synI;YZni!&<%Vbg|odbJ2!7d!&y7XL%pRc^1$8 z$PL=8)N>i1<2&pw&T%YhUZ6g+K}l+DIrO0#nVmx^pb(3RBKk}U352jlc1OXD8#orT z7PoxefB7Hg{QtZ|Es&BNup-mPZo{gNd&e3eT5L_~!~K{>p3L0O|xorJDIx@%}9 z*{z&P*=GM_ue;xKj!53U*u)@lR*CYW7zy1ny^rsGT<%jQSy|)gz}$%p$m&j*&jZN^ z2*qVswCYIqNQ9Ox83@Uo^u)X*s6%S01I?T+!Jv6YA)Q=utK=ymWM0@0_8wlQAZ-On zL&~f-MAxx5JC2nx%6GCuCsy{at*tGw#FVH=`=9AHedqV(<3@DVTAeI_!KLF`_mT?9 zqzhj*b4bPPn;r>P0U`6lKO*TjO%S37YB)rfi#w zkUAJw2Gd@^FsYM+@nLHFw~O%C*htmX+7SN3&B;|=pNG}egU!fPdiok2-5xj69(T$d zZ^x(4e^1Bks&|);_V!Qk_qwU7v!R0iWqKISt#t9edQ2jW2^T}u#aJ*?&?aRXl_c4v zkWeNlJcGG;Tya`D_{zXsNR$Ac0lWCa((iUyj9ja7{6100{ zk}Lz~aA0{&1y#ddi9(x1l~b^Ks7}i8C&u+iPd3q0XJWHPDf#wr!w)Rwp?dxNwiGpf ziR-^S9(=FH5k@06l+->Pb{|l*8ohtOC6uMgCKNyfGWooN;dJ4(mgzU-58!s>rWlS;^{dx?Fj%t%d2_6dOBH+MUgGCF7vsWHDs7NL0 zCr{G^DCD__Ut?+`5S)G_xrF%+i2xxIQpy+=dJ=+KNs$&uiKc^q(|(9F%508b7sl#V zXp;ttinFN|J4x*~R%S6npp~%7+VsY-&rLPL%9X4L8%KSRaV3oW>q8`;y(9{C-ztmg zM=lwNwDHdI*Z&NXC#pyjS^i#zpWv{9AjPHAgg(k0H+Iji3|CWm-mDf+Ridot5-cu5 zh-|sDDC=Q8iboBv1EcEwG39qv{09Negph?20r_f5*~D`E7ppcNta4fqYg*hgK9#{h zPl_U?#~uf~!`xk;x`Rc#%RG(LTaR-d4|5PGjY+1XAHlg(I{4`E?Uqf4PbNX$QI6{PcKn(4)AcL}=GN<@G*PnSlzs|dQLgH*!CxHNy4Xn;W zmRxlIoG=72^q%&PV~8>B@{H8tpfVah#(@h9W;C%FJjN%;NHCxn8L+Mkv#Q``=`7IY zxea;oBM-#6W^5;uDX_P5Z{(Tz7c`EHt$fPL;45j>+BBZUZW%(}6OO|qPtKWs;zH-Y z#{bY*F>=?}T~7Q7p#CCU(Do!E2E+!tce1_{hx2>n*j$xB~^ueJlJ3B!YwVK_B zArNdp-4)G6ed(Z*NVSq|f57p}PTRC8&Wa~9!JJxQ_h8Pr7QsP7tqZL%cXkn@1_e-o z2*lN47BKk3oC};a@jZu0@P^Qt2%75EdzOEy$9D*;Ohj6~bmBlnLz=*vFeD6sj3JJ^ zbpoj;)XBXf&M~?jL-&CoUuZZNRr)9(82O5R)V(Rc%&=*l4Uxv&C$?Wrs?m@kI3X}< zR21;5k(h>fv2f%oHk+?}ipY5$-W~_9-XY^VV6eTJSmuN{dfyuYn617|j@UvLubil0 zXvrR9iU=DvXSh-vK)L!HeOi+Wb<*V64D+<9IT2%^k!@Xq>$ycE)Jy^4SIvS5-kLqk zisg8QxfS3LZvKmipKSjQf0*fW%j(bPGQVO{6s^u-n`j+uXSY2uZS~WgaME&=^JHp5IS4y1L31 z5F7vLpBV4~mU-EoR-3sl#h_PP#H&aQK^ZIldcGLy$zd=79ENcU(vc^GUOK{^MRWeH zSUfgdy{*`HW$0j{?GhzvmUeze0}qulJm#$AQ9v>Opr}%_CFf^jvu&1vq(wE|G+FC5 zIT z6LjlZr#-AQgb$>`0qB;zVusr_nbdvxovdOnY<(R>U;02!?PO1n{~YiqIR4c+m47>@ z%mL>*6wBbQX)bwx#Ni3?9dP;)H+(6sXAtQ`>&!iPF`m2nEH>h6JU{!uE&A zpYOke@feJ#?aNL#a@UQSdg3;lyJ;XNls(0}FlVe`{Xh`oN9lmWUGw}#!3JafFpqGHuv)aVw4qC?wTX+Wbou

@13h&;xyFkBEb%kd zvq7ylp=@|)Hk3D6uGM*PG2;^<()TouK>@d+3}~PfRP58c z=l1X~JEnZ&7X!;Ys$c}`stDKxv>7VqnLtwG>gCWDd^eE}fQ)Pz@;`8EKiYTmkvW>5 zEJnLY$bZ%sE{Cj5u#9?s)LB8@YMW5 zs5Zr9Hs^UQv;hb#$Y`zt+_vyYTh_a3@V!F~eZ*8gvva0g#iI5 zYXT3!l~Fc}SH4I!Z!mf+AzR=J@zjE!j5oNZ32+8{DTqC^%zeSiGP+;`8VLtVE-~P$ z1TJx#rlC~D_nz*F?5?U#qj+SV3>1oYANIIUp5A_O-s#_roE*IzS8iVe)_YF_PovjE zxzgF!iu)saBKhoC(4E((u~2wL;H_4o2px(&&}=h0h*5eJJxgD| zXEAnv-agOO7k6D)?Ve)9lnP=xdREGN%3t@#WuDbnI{nS6$>!5Z2&kgzIU`Kn{V7`% z!L)N9VE~EEWPssWcccIU!MU_D6bo#C!gH|dO;uvq6!sybsE<4J_Ehz7R77x*ePT}G z2y0Z83x}pbTFV~?FY3#U+X(;i7UY!qCS8|c( zElayV=jn7_o)hW|4Z0lwV;LNrV2t$z(D2t}8r_D?kpD}s?@|Bk;h8Pg4g5?1+-z0( zG+L)l@U9&9or^r>jClsn@WTi4z{l|T5|m$pa~ElN#=8Zh<(aptqH5M(uf7O-wLFG5 z)oa7Q5RJS2<5(WE3jcTHc9_K`*d+Ttla@;nPr2HlUI;OkW64qEI9RyXMK0V=k&CWM z&__zda+NY_2s)j{x_uy}%lwB@-VA(!?@zHJe51)G=r3s7paNCvpWG(OU;Q<#;6<5E zufxCmsrmlCukqzxfG6s6BLcyZiYGfLW}xd&;#ZVA)Mh#FY^=zC6Jk4fu{p?EaOE{V za7g0HA{&yp6ud6+VzsO;p1{0wjBnnrLc-1z3q;IC{y2F61r2WqGGnuYFi5c;14`{$ z`^;PMK9xCS>N~=bRR*K!6N4ZB_7dZo>NRYah|q-otle zhG&h`SFFwmdoE?~tr~Q@PtM@JDZc)ewWhZ8j~i-5bWuiyc;!v(_8}Y9|1R`bv)&ZF z`$Q2I%OW#mk6x3kKuBvMG~OcPTIv8H8m6T?6=>DHI-zxkFN0&&ON5p7?M82s_H!e> z{pj##Gv?PxeP>UK2&38@%Z`Twcd6Wk#p@ywY3?*;OJ3%jxH9hGNj+h=?u|_OzaM7i zjdaQK&E56nU0%=dN0C3>dXYZyn&OY(lyf;d_M?I0y<=P{YP<*7^^uJ^M{#^1%ic;k zaNF^x-(&f{;Rvtu!9fpWj)#UAN?d>qFOv4Y=w@I!QNQeM>KtwEegN-oIpDte=X!s- zQ)5E;?_;`|6~30r``HJNcYD}Bz}gcSJ(E0rSCOUbTjcLf#uBt8<@Zu<_|czeg4 zhUuktw;qcEK9GCpa~va%{wIUjLq5sDAy1 zyk~8bB~ZY!S_-#l)s^qG*BP&yb;PF((K%4F1X+pC*+nD;-F-J7cx1mF0FFkPP+mG- zHchNF6!?~_x&lS@ML*~Vq|G%^R4$7lLfA9{4Gj;q2~$*Nz;VD%*FPuJtC%K|^*D1F zR~xE3t6JjWp^J4DLIaSnTN~wLivz~GIi!p!oZ33ST9~0h4}BO~etzxm-7~tv*lL(_ zIu{GIcOR1??CE!ZsV zR3Xr-fM|jjFMMQ}7S){^@g4$#4bkcd=Tr_OvB}+O( z%bhBD7$5Z1U1W?9=*BGv7#TBfkrA;)7TY$n{8QZEMd(9O@{=_sT~ac0}U{ZaEdjE+tC(xtp)i`&98M^rH~)f?GLhkX-+qI9YLz?oT>u~cV+J_lu8SQ?OdJUAVXYw83;1>AbkPBKW@2#!kdaVJGZLAH}mGSLoE| zuV`Q76Z5kA1TZ=A$;0)dykoGq^zWcPD0Uqx(ZDrGI!l@i`595JP*_1{}TxwhEgugX;P(ubYNzD|h>oQgsV5R}~I<0%3_T28Bo zrVgkD8j(uDE@WsU07`SFJ9@Z8SQ_8HP(P=A!g*WL_(BwO0 zS-X|*VcMbRvP(t*y0(j;Kzk#R<>JDzEUiy{eq17J?C6zS9|32 z%RJZXnEyOZ^11c;77-UEzip`Q@eF`!hd703cYd9y2gyq;qY>Ns*!MAIu8Fh$G2z+J z8aV0ydcMr2&Hh7-+lq=Uv*@*wzrD&{AhQ%c5+4CyYC)Ytb1#LvG5imnigcee>TNJq z=nLPue~d7-%{xER(%qHDv-s~r#x#2qgO;nC#i}MGMBB{Wlcp{Pwfin=9ga3J=wpoL zM>?>#!A6yH$=*|Qj3-zV=hwdVtv40p5SMNgm6ygp8W3iZ%nOZicl6TYzEmL1A=-h=pdc}EmZ zP}ZJMk0RevJ^>s+8qi_vmmM&5q6m4RuvlgxkmPN`tm{kWBdQ2AwuI;_UHwOK44;m| zB(aaTI_8-In^%~*Mcxc)t)oMm^2FiR(C(H+*dE?Zd94BPZy14g1-~0H!RGD`xXg~7 zL=r=@)*S&Y?bt{3@eu$?ay1Y>#{5Od8roQSo+uNewW#)hLqp442+|JawJ=8^Q#Nkc zqaI8rI84(@Mst(L%8c-nqtu4P9A#73g3A=|#@|KItsc%+>#A_oQ8!S~6XMIcGW;>) zj*`(4BC8~z3uAyXxffWgk%cE6(SjUN(i^-Tw%S{m<-qTk1!__E%zf3;P<44Bw4H;N@SaB@dpCibpO9!Wmmhcc7&m42v!@n;Q@+q`G2nJ+~ zrkk%DV7@;~|3=TNgE%!Zp%+IeG0Kr6*e5kQcuSh&3KgZi2RaC99DBLy%RgF?vQ=Xb zfT$DrMj#ouTngD>TZ-LV6|NS$O8T8XvzudO>|n7Ou0R}~=g(4;h4yM?W%qV+ z^8tjYC+Y&{l0k-_`=02)2zBcqZ~^(?d|dDh<^bg`YM*a*v~#2#Rf5HZ?RCtyO_j|aW%6#{|Fm^4{HpMB8t%aX_Q z`6`0UPtfRP7DVZE#@j+Tk`LBx$)X=%@DFLC`(Ba$8F}>z37qLrkE|DQaL6AhdGN`) zvUX9U(dVcVmggQh_Ci?io$-&@77t612h6=0*|AZAgkyC~l>3?iir@|vB(4VH(U$@U zP8$!p(e+2}=xYMN@Tm-{tq}kW9;bWngtJ2JB#~g3rC&OA2emH?V zM$>uhzOv(Sj~!JAL8hdX(MU8{bwul186a|DFc<+Mw3l)XdgDkk`8AgTm^h4m%d+E1 zU@R^3H)_#Ae#M)-flL_i&UX?cY_$XFGgw6CMh+|{$6>PMS>RLd{thKdlOaS$!8lp| zn!t|afZb9Y+CtpaNNjSR{*{!5s~B$wY6I?2M}Hv4Ta9SKV7OwggV_{>8Oi7!5M9GE z>4-RkGLZS27JJ6S2W|Dy!3aV}%0cVwJ{j^e^jelQ$$uqqFx|kie;z}5JmyxJDSq{K z=mLL!zt~(1?>hX~ZDlMdjBI!Na&qnovKW zkzzF=Cju9TR&pa^c_3rIM3s`tEy92vQE-5)8#N0{+&vqhZ6G)3tO|yi8_wd9>A4y0 
ze7X&q(0zqt6~tA_2xnMSF#Wor$i7A3eF+u@cJ#Odj(5d-s*}7uWsi1M#$Hd<;Yy=_ z*#2RCxsPCAZOGN8allA%LJT}#!_bM?=>smR_=OMWkf(|A?gp|Dsv;>;xSsmji{fq?YRAh#lfmB3Q8>d zuHKt{C8Zgn4;2?!!FHdIcV43ZAtHx(Fh-^>Ebh3?;&J6EdTOqXw|w|-_q%Sz2jbCX zSJFL~=gYNCJakgPn8n$2jD6?F0Wmk|hCdTOf?oPnVgO%dEXmo^E)mL|4a42a1bZk1 z+AT3pbCA*0l#gt(j^`&wR(lmU2g_4%->X$Jx{TYV()vc|!^G~$B@X7~Ls+Fn-jo~F zurZgKuMF^0((@ajR`&;E}l2QQ6l4GyL=C;Usz9znl`qU)8QyIhOQRN|wxa@lObeKy58> zi$jNh446ShOmWuM&RsG0eW_B!g|R4^jwr-HYE)iAN4tg>o!kx%sQ7Fl1C^trP$&mK z{F-hHgQh0FhH{b=7bB2@dmWs(6G632cu4DP4Vx`vkPoRxlPJYlbixBCv9RV@PfV4; z=x%-EgU^ig=Wo}YQf2825!sxQ^PtjVjcZUySsW-EbOY$Go$wnnXNc` zmJ%sVb8#qwqMX@CSv9ATaulsDF@HMlg*z*5qJaP>r!4<{kjDgmbF%0$>8$~(Ewu3C z>uA+XZ`b45dAlPERC{o(AxJJMe5YF1Q-h)~U1$27&7L{ilmI++KXQD0`d&>8RH!R! z)N9Ud%a&Q%20vv0R5N+9tr%gr2qWejd#_Tg#V1IPawXB59KrD0zc5F zt4BSzy=|){r>X}VvgT>p&GU18aHVB}CiM;u$7|AOK+2SUUys98b*Lt}q)Mjr&fxhH z!*!$_isKx{v9p?!)5zdoJgKSl;X-fx$6zh5+6b)KY|Bb7&-{u`45F2+y%2F9W&$+M z!wTK-+CY6f7;t{(BxTSYAn7vXgp=HDpLSsHWo&yO$qkC%w~$9(h`e@WfE zN;h~*;*KkqJC@5rAGtIsELeZq_4V?IFv@c$?(E5YOT8N9DoS6s<+4%Mh7~+!SuSxA zZ6S^yW1VBtX^NVNAAdgGk^yte4&_n&T07F)inO+c@Y1)^dXUHL{T`#wMQTw zMQex7L?h<9UdTjS!)O6~dwF$b)vUYkKmUJ4_J%b*9>rA?>a52%7nDNEksBy!n%ipA zAZVKhktuCZSChFrK5qOpg<9=lh*Rl}0K4+QCRJ9bRFjiT;ve?t^gLZ@FEF}%ExdqE zC8T1pQ3kGOm%VQBY)q7DI{5U!8==JqWqnJ3GQk}@=>799Im$~8W`!F4xU^Yk^6^u3 zPP)SZ@BJ1Hte8`=*sFI&{HH^#J|=I@!CXrCY;S zf)Q2VVEiITsTd`6_42Y_G45}qU6Mw)4OR_Zyn|PHeL4GZ|B-F-ZE=NscxXLi#3%YY zVAd`e+5gFQm@rDLL$~@yd>rC|K06nmN6xGFod_&I{_M_d&|yE1a9vk@BxZFPlGIt& z!dC=#LMw)jY;Of;Pt?IL?#;{tmkBT#s!SbCGDhI7B4EaOk&Vg@T02wsVcJ!DA3>(g@aJN`t){7wFd$;yOhgP@!D8BBlVju}*= zxBaIXo&mTZx2gEUH)bLxpvUugpf;yCw{ zG75YP?9pQPnQh=FLn|)~M<5LT`#0+EP-J(8&OG=>;-KKAj#(-+wS*Qy4#4XhHooDN ze(2po){+>)$n~9=+*qNqZ@wA_okeR1(>RY@=N>t$Lw*$@G=NVd>$eA6%IX3WtIES` z9>Mr7=I*m^ZwM`ST7rNewJAfuU?%J*5AR>tefDnDUePXEoB)o|Yk&edhQl58yrf<5 zPgwp7?%>*6u8@sd@(OLmKJeN!Koq}0GJ0lCq?AO1=_n7+WuymaTWPYQZRUfN1*c-w9hm^oTS4FHo8D7@Hf~ z$1i%NFS(<}J^%+oi-9#JkbN&!s(jI^Z4CL2&qu$3ISjaqXJpd{-Ij5li_@UkTYs&pYy;&C^xvxqz|NvEG)L;P^Y)6Xw zq!Gy;DvD-g>Uk}YOKf<@=muReRI@LpxUZ)nmtv` zT&A_NW{RoZvZ_8r?2i00r=^qd;rj%wjQ3JF`=e-5uWIEw-2|*2zGtW@+$v)vOcLd^ z+mra}%V6WqXX(`%&nuXLX2>6Ifwr3vKB;MBt8mi8QW6zcCCos2qOxuxWk0=!W6TPWvLy)d(T3`Rw8)EAMCvK;vXS(l+O^WJhd?1*TYdFB?#J|>C z&buSCL;~Ys=+TiMH}TrD^>@}{J}crW5ci*P-l*{KaCe3h)BKcj7!^wxj92UF9(ho(!|=JV*?)(=FE@j^%#()?1U7AP3yl7(P(NgZr_5 za0CWCVNl05=-G1s(6rp4Keu>I%iA;+yuY*B^LH8aSO=Eotl@Qc>E;v@5SbMou|@Hb zz7ndXlD%T?TlJ1~yMpi<|g#$v0^l`Gg3|}to=KzC2Gvx`Vk~r zFJX2yTHJ>Ll`&`#X^_%kU$H~f+s(ONM-rmbzE=4=5SBr5Z%7d}&}FHaxtPsnL|HW& zIAvh@mJPc7zCO@^n&-VgAVEbeYoNvIMYZ{maGU7K%E`jRll2dMYu*oP6iQt%I|02R z{J46|PbG28z`=V03wgB3Jedl1X5Nu~(cJ6r>TB32qkThk0G?#&e)=C;?f>u! 
zv_QsDU_k)@AmRR(R@>3o(ALIC*Ve(v*x^?vH~l~CEdByOIHZ&zHV<=vItv9moTmTD z8kZCP%^F{qERs}k=+f=wl6Sajxk<^#IH7M6-`VbTy_NLbUq(i*i zEt=2^Ki4={Ku4nON=kN08)l%DYCG9PtC=q_H)nQ3;V!_Yq>993p%MbNiDsx4k(Baf z3gMemD@WyynI$-aF8nq_`4Y!#>F}4d5-qM~xR}iSE{I#(a|2kS$ zw9|GV`~BFvp}QRdY@j=I7ZrPa;e###9&E3*csO{rt=Q_%`bIWTN{f5V{_4)qLHziL87( zw&Li_SX^{Wo7wHFZG1b%h$|~#a8tz|`Aa~$xRwlu76uOBaQ%ywRZO2u`lld53>75U zYKBq_%$gyoL03ZRhUreJPB7ril}j?LpX0Nv%YYS~!%p=Ch{;4ZYk!TRe+34_8lXMzie%w`?DiB2oyAy#yLi|;Grqcr zNtr-zvzCdVn50$bAci2WB({uW?Z4q;A!0tF4u+z(ZVAkzQkog5n=_sie-{J+f13ka zHUk&iv^TXjzdS7HVGmqV?pVFn#u z*qTz|7nCGl@kJp3hOB6|P$ z-)|=*fcP7%Lbx@^%d-=!I;~b`jJr>2${Qno&YPAEiPX{Gw}An_XY*yRQ_pA6)c(jN zcf(Ycs|f#^$wXn~)~s#Ptq9A|V8J`c9*=%2JI_&I9jHevu5PHhxsSfBnod9h8uGOhiMW z|0e35i!||eB&BII-o->Cn%mA-u!2&hG=N0n%fP$h?F-Sdi?^7^Q8NP(v?u#KhE134 z!q>F-XI=D;DElcFo#;luC;ZE_FBvb;KX2@e3XAnj7T)Ef|X`C zr_rhWT0%xQ+RILqP6QOR2YQQ}lx*Na`5>NhYi z=X5>`10#NBk|KkI(=B`2V*Z@P8K3{~q{OtNfoR&i@z#dDZYx zNY7i~ClT~?GHi4dOY5?oMG(rGyjs*1i$oP%wsd=mpA0y!>uk~xhZeT)b58N_VtM^7 zm2X&et7*Ib^u7Y>5T)3*Y4yQ8G%E>K|G{cI25oH<*QWrH)@t&H#9Gp15+$F96gi@R zrOQizUea5I7)&^RTy{fx#cke~9Eg0xF+iWQ)ave!P13KWqFBQ#YLLS$LfPs|fCllJ zsBh{cJ*uDpHDwZ8v+sOnkU=T{yv;zzbh)D9reE21ZJp!Im;)JnxH;d4g<50+>Z?#Rc2-^y%2U#MRbRRbs;7l$5yoY;q@4&!1=NEqT z>(v|)DQ5UBI{Xllx}wO(OKM&K5;D9jX)HD*v@1(GfmzsSV;2;wfY!AmmXiXgS|cA? z%{Ym7jg&sSPrsG`ndo(!%iymo-la`cf zuRmSeaTe3|2B3d?b9y~Iuk`v}?&@yCn7$y}zQ&{+1*8=Qfk_~C#c2^EMrUH(`J~25KJjAA(IRKi2xRa46*_XO%9lr zPn%tFFh9)WathT$G zCiGonh{05rin}@QON*FXVM#ZuGliZ}vBy;#2l+0CV+Y^j5%9IydsNk!f8}WIXt?xo zgrCFMb~yR!h5gBsXI+)P{fa6}{9{FQ%wl8a1Ow(z_%!F}t+_gX2q@ey56B zPEsuajG>FK4=O}S3aLhV6>!St@}sf6ZMK<3Ue|7?D#2c$u;45&Z#7M`y|**81P%YC z&FBB!^|MwE9PN{?-iQR*;^V!R#m8$!ho@+JV_>Y(dFgmAcVPFH2Ym%)j?8gy7QV}M zYITi^u+2p$;zj`PCh@|8;S{ec_#8k|=j`BwqkvRG07%y>Jj=L=*Lpv%oxhGIvPBD# z*?b<F(ndxR@SfI2-tRG)j+5w8NjAE1u2aT$2@z(xtz z)$uVsf1h|)cEKRypd`#3R+dR^v0r)Wcdwdy_R?xqJC{U9aIq6*KK+b|99%$TF70Um zUY!W`$N#ohu!xlXv9hE@Fo1X5LCUVOmXkU^!ozyrgyQjrbT5Uaxv-Y4u*`w zVoof9>Og9Y$yMDL^J^i%Ck2(|wV~>E$;`K{XhxL%^?$Ua|5=@NC~VB%{~D{G=m7v| z{-0{I(B3v|mnQ`PiTn?EjihoC92Y&&3 zsUu1bj>lXUA!A)HF0v<8_(abcdepEeoCttR2t~OM=)N&2U76#kfxLUQsmRecuzPpG`potNTjrVKyo8k*5~ z+-kL};dUcr#X8($UxE*NZ}*3Viw}j~7+1nNsO>BvOawGR^!h#%?%Z5Lq-$r+R(K5q>3XJtd-XEYSGSj?VSMoNTAozv*)i~u|2rC__?EV z_b0Cho#?f6XVb4&R$dFjqOK@vKYFQ$~gwat?rtZvl*Dl(T zHiB4RJVvEYs7^D#0QFbPM_AvDc^8<2LM;TG-vnw$Pt@HJ9EgB-VtFf$sOcg6aewtc znP0YJc$KxFI68tdyhyH%EpAIHbD1zAMP#D0D9UWD2|}YbpQk^1oZ~I$INHC~Gi^q# zZ*6_DB71O<5uq4NFH$<<;B8Au$I?_merC&nF9j}|l^{+iQ`IhvQ=bEoP~FuEZDyr% zOg8Ay6ew8|e*u^jT3sHNh7=7TD+v2mL|Lv7hG9ypbO;QIb71-N7h!``YubQQ;;f3L z+)nlHa>0s-JLY1VEL5{J!i7ytE-MGGMPq+SR$i#Ge@Ef~s)gVWkGl^V@H!kG8F12+ zp5Fw0GpOX+v0lw2>blG@CEUyCFxd#6`wT5^;qT-d=RySz!+u$-7kPWKs7w_ATmMX3 zp7mwZv+dZ0nmV!=xp2{B6w|)(*qdB<-Cm)cY+?SCtT zQF!YxpRVC2diELF{0D;)=(w*0)IKRa_45*f^cHM^L;}3+N1Y~=t=ypH`G{oBp3twMnqwHrC*)*j#(hi{xMT%{#!D2(g1cd6{* z*C%QiXqSo~392G0TKW@WQ2o()zC78HGF)0CKezXjrW&dqXh4F5b03Vr+f^i)J=rb=E|%_z2_8`vATz=10qN;xHgS= zF<)KEsBZ5KQVA26^kspCohr1SIy41>vt3#7EB>RdaS2`i0~RS_63EcfwKL5*!(1Cb zdtIn6wcJ?&_|k5P z3%x=8;+;(Qywod|10n{dUG=#Rnp}@)>9*_I5G$&R=c8mSGLD%xBk1}{Vi2;Ld73N- zTc8&xF-{G=aNOEf#9Hv(l`N>IlF|j*=*IQ<5Y^($a^jce>k4wEiVqaj2 z{%+EB821WA>@koZgsM-$G9|VNd`8ihmf8gK6?1P7vnkcm1)hDB3D5)1UJ88+u-rhl zr9fC_xuRfjyp2`;@SZ&Him{m@Z5w$l3*&0#Gji-|^_^n+zzfUl<5bx?oi(L1Ae#lk z6C)9)?_$4?@b?26p}Z#}`SkA)VnhAppIK<_i5Q2gFxTY){pSyc>w-hTBILV%(GJ^& z51E0!=|Fw_jq|vJfpfoC{P?F}p)|j|Y=<)9)e$i|tufZtU|%WiyRsn2*umIYPNoD^ zfQ4fOgq)nPb#$A)hLz*!`)jhB7CD_B-jQm-K0?SY{=A1hmi&G_NOQ0m6LZgW?IaO( zq2h!i0icdol~o?VGJ_M8oN5!ol}~!SE%w=Eh6=^hR^m`j#P0(!wJFc5-sE=Muva%G 
z^`$oDg#1OyxDO?>z8jvJHuV@jc@C)Xux<}!5afg}pkfi`4+kvYuM5h52Qjr*+#Imr zH(U@_?cf(QGDiVY6b^W%;mGY8`vDWhT~~J0Oj|L1ua+X#XU6?)o9(TjE30ehj+(<; zxRD#rY*sY?>=oUGlK@5vp4Iv=hQMK4ZfvcdHCukpj>q;yNO>)N^pHfZm3}0rfj6#1 zx*9`PGQBKtH)?@ zy#?lfZVX>me{V~|a2QVKdaID_U~?}h0$oZf$Y-f>9%H{5f3fu2>b6pF)>f#w%y*&T zCEq&0Qu;}r#P$fC+!vpPt{BOJP^7p?hXJ(x>}9;8;xPBtge}Xou&{HQjmf3QTsf6W z&|Ot3;AtVBu!aWPmRLAMg2P;`T54ASySD(guXZ_b8gcic9%n?^1lb|arvt2J$a42SPe{!Z4wN{HBYs*<`Lhub$F4N2zj<)jUG zM$OazqI9&dqc&s2S|J3{Uz_ScL0=GqdhaoRfRSFL7}nWPWG#Mb!ZeudhoqXe@t=Pvvm-(a@wYe^X3Q(BlIas9|3%IVqZCAr4OTu3E zPw(Re5MYlUCi^PmCFmy{t=JK1a1B{|nIhwE?L%_RJx*{Y`L-a@V&|ckyQ(rz#o89& zNBF6aZa2&aC?+LvSKdY#_c9N0f8YT9WGq3)0cKyL{>qYu?-=HxG}fa|!ZPZ=F`G9? zAKtnICWbU~rMSPctW0=@**cTgBd@_Ex|K<}UG4ogtthIp2d?NRe#MGXF$Nx|)4gAz zc4BkPZ>3|*Ju|w#%S4DZkMLy9VTa;D`YMDwn=WyR*ot@n-X+XYYtzX#SE8adF5Mp+ zGCJ59v~+I#Q(K$Ptgha7Z4*7{-w9x}A#ImMoNiX#+>(G}^VLOEUI=NltD{$Uq9H)d zq_{0pcIxUgsZ}WD#`TD*mT4!P16n%8A@B(yQ%WzR+S?rWn{=&9!M{qEqO+SRSIMsM zbn-}>Wf$vWOXsbh%|EkvF`w@)*{84Il36)$&DJHKc?RQ1%B3D?@rN%M58d5Z4?Z8M z;JO*K*=KjpDP5V(bRgSPocCv3JNWmayn}a^k{8&BwU5P7BBs_{FiWw3uL4fK1}^YN8mFl8Y(E z=Jd7L>qN_0!wg@jT&(SVe|zz~l+rIFIn?8k1X%oRt&N@S?mNN5un77@3KS#9UACl3 z+Bvu4JApleI6PcJ#CJ<0f7mJu?#n{)=NpJN#m2cjtBUc5#390od7& zJ?NV`>}cq-Ore87ks7!*MZ5whWdB9?h-D-ft6uEzovrksxpP01;vH68ZrR6 zjz*4=tG1>>O=9^cBOnD?EGXDRm9Vd$F0HSxw;Ms#BTRh>oB!EVD01JGCBm^sHOKRK zOSB+JV%7ZI@AFf!q#8<`&4c2O+S-K#5fHOpuPswI5FyTDm?HYgiqfFofFgIy+4e1b z)1(bAiS<;9i|W&Pm`Er8hw`lI}@t#?(4JYPX`-G4=G<_n1Cv$`$Nx1$mBIn z7;O=j-}*i`XerDn+y5R`7Lq?5p@TI{WBSl^+h?=`1Axr-H%Ft@x)(qeRDi)3TNIF{ z{U;90;Fq|M6VJ0oyfuAovojS`5?PKucw>g%Ag|L(#8&X`GL2HP@X~MX4R&!1vj^U2 zH9F~AV`vg33!kx$Sx0{-JebJ1ueih$R2P4yCX$MjeJ0o9G#nl8lM?v|)UAKQhOC5_ z`b4Ku3RD{`^h#ZvCkPCSM`L4#%15U4w+SMq-KypUsB=w|5wGy84x?O8P(jRz{0yjG z45&Pxrhyh>jRZJvji7E}mJUV$P>&7~K^DYF>~UAspi&t>#{K5V`d&Ikk-iERH}{?L z*ZjEv%?BO`JK$z;s|pAp3y`~?2Y%XZ+Z_4@jerCU;|j0i(w=+dN)ime2$CBR0VAL9 zzyNM!D8qr%3v<@UpBxjlMuzkc&Icxm`I@8Vq*(gf)5yT29~4?hUs@cIaWqx&m)UzR*o0 z>u{j4rsN$JG?q%FUXQJ5Q_4v}kk@3$tIxhr?S)vPpiVlA?ZCGi7+4!{K6=&)qHr(F znNiRSHX+t5*H3j-k#O73pS|HL$)M1(kE#I&6yH`EI!_hub`q@S7| zmiA-q577y6{d+JiR&bn>Td4d@LoQ&Ooo>1sm{_~Z$Tc6)4%4!3DJ>2cRN-$b;A$k# zR^eB4Ee4w}JyjgJ=vRByOfNfllJ@MG{~mAb884fJOr%A+ERe)#V6j1!BHN@Um2)yv z%VuR!i8k0{OKG4<^bZT#Iv^}Bi;0Eaxa)01J}ZdRsBVbPNX8f$HeA*JjMk?jL|=b; zFCOrre!3}kip89zsqpu}LJSmOhYSTUCVf8MI481f#3TybKN?N^n|#2V3BT<#f-18z zolb2wcjR=SkKBa`Vpop>b>_sDfA~Fns}3ZgUP{0^YwBw5V`D9ywJ&`j=g5ob5S~a- zCo>JYq0uHCop(9-jx-@Fh=b1*q*;1VRhlxiIm4N>rR!>=u~yyKn~tDjA1Rrc33$yQ zzkYQDqn4zOKXQYD@|aywAj5z!tk=VZILFpvK>mc5+Ee=s8||EdU4ii$>pDSTjAog% z^sZiEsBIddJk4Naqm6A{v|#E9ew98)EVwCx1nrS6|07C6}}rn)i?A^Y@bmDtm-Ubg|XkJ zyHuujxov0bs%eaeXWYlQR8zeS`5$9V461l)PzZ|g|Hqd1Z^ z*@^G>{5-CBdNSVqE<64~yN8=X#7jsfDL_2#e#`o=#9`*1TYPSm;$%^rhPCTY3k`CR zVv7cf4XnCS+9E^|Cx9^&trN@}K&KIh{CmV3b=EcBxQQ>?ygGk}OV`gk8t07F!z2qk z@kQNK-@h>u&*L}JJLxOgDNrW&&;U&4O1+|YyjJ^Jld$m1>-cR(O>Z2lB~2Qyd$ORl zS}Qdb)rm|Aq)i_6r15gA^!%(~cNg(H`^iuqq12}E3XSZ^RPE*!ZCuc^<+zi3>H$La z#{3Ui{6%aL-IQC*H&{^d5?mJQ?69TH$qo*3cLkg+!OPxU!H z64~hvGXgOO5HJuMj(WsoyAI#jJ<)r|Y@Z#PxhNm@#=KEM~{a zKpoA&(|_*Y4XYIL%TR|I^7cX%x>Z(%{N*65{23dPDw>)U_?h17>vPSPTSdI(F|-tI zD${d&&xzOG4GQN*hS!NAz6Rd$&g>l_A4=Jw`Ac~=m(VKLV1~w<_TFaIDcM&#&K20; z!DE(GsiPRJlVEDvmoqnDHE=Z!_G|C^StzP>2j8))@ABzARm(Isl8%M6tGaiU1i#df zXj?Eat&Ub%g^9E)_^j{q86?|y7vy-2JzyR`geFc&3yc{<_#KU}nU~R;k=w>nFizaM z)zS=yY=eUG52&M(fP$NgR{^YR?rN>@;x5YCAH-aSS3acSpT%qG6KnFK9pcsYjb(s8 zxc1jysxHP3{G73AfgQ#M1nd28Now%JB5`n!MAjJx*C@vmU#5I987vYiB~=FK#r2kn z7VbtF^ri&-jPgpm3CIw9BMjYYm7cu07P}8V6vqj7cKCr8eyaENGZt*V0RPWC_@C>3 
z*3BC49uNRP{qJ!K{eRAbPQQCoU1ukAtN*M`tCZwrH|P+2W@^}yLP!+EFk+;n&i%}n za+A{Og32kJ6C}^co47X!9(QgoixCDd?&sgHJ7#vgc|Krl3ORlWqf}ZtT0Z`oLQlQI zkQk_(1X${hOm;Ibn+9pUjpll1vtY?)ay@c%Y|zsX^0{{LQtd;0W^qma$~5{^6_;uNt}ay`abrHJH;v{nA^0K8q(F*<1C0#5wog1SX=U@IxC2B z!vh+jAsDA|1B~UPOn6-S7Am6;SWQB%_VcBp>m#1*TT7!3HPF%NERW0*+eiWx5jJZ zs3(Q@Nt*j%3*G(Pz1tU3Ly*3Le?hrf!wAzD|AjYyNNW*=g=6-3>{%z-{~2bi~X zMQ`6t@r?~@0@w*?1@dE6_X84*IgaF<+|>q~;mEH+vtD<<7}FRSUk4&4&YSy?ZMC*B!$SXr2M8L1Kj z(!pO}z>_M%N<1C%b)~5c@NV!BZ?-nJelsO>DS`$S+%6B0-O#dU^`|Yte9|N9+VWW@ zbS({w?NVZO0spZ-jFRYM?P*Jm4fVnkx1)DCUK@fupPj3cyM{XiwHcz&P7RUmufd&C z?RVtHE#6o0*f+`FvwwgW>NqMRU^9;urct(l{ePuWXV9o143otG@wLW=t4UK^TZX7i zY7!!8^BeuH|F%!*mvU`&=W-d6x(xgnz!6sUsK79 zYUneWdzR-}{e13SeLPIOQ1hkoyjku-lB*N>$ZkUX_8+{1oD96yNl&_ zlafsa3*O4Se$u^KYcS;%+O!u&Xkgp-2Wj+P>H)p6U`E$RAF|2>al!tVT^Us{YkFt! z@AV?(OK;A|w(Vr^lX;E(CD<`(^7bZ%F=(iZM2h1xt{)h;jn(&O_`z3$e;>HPe#{5j zgn3DuEn;toPo(+GVX#Q;Y5I8+BfsUG+X?^gN^-?EY28JYmpd9@MF7CE9*qg~4D3GA zE+h7zr7Qqw`p z%tiD-_E-8xY$Nzs2!t?5X}IzINXcI3<6c#ZYV_$`YMD^BW!g}HTy@7mwXV_?ITalJ zhapHTPlA|?Md`H;Z71qja>$yNMABIMN`3!_c z!c#035o^$bdsO^2{It>}IVOKWfHdE7^*g>cQ7)Ok{J>|TNbd2Z@0{nb8b;^(pcvr~ z_ZLdB$AHG6m|t2wdOoMX5@EL1>69#m;DSMPh!BI~^rPO-)Q%${P-3AA9D`sMnP32? z%B<)S5Kk;&qp@t-=V9}}=<~0%AhCaEiB22znN4B(CO=6oR~4`W@MxkyOu`H2>)jE+ z^vJkgas;}{X_;;CAr3uQ?GF0bN-1HmpkhK}$0IUv>ritP7y0Q|vjLw=zwh6uEQ5cz zu4Axw_mGqnmU;CV4tpJKg}H4Q`uChLDo7UG3SY?$2>zv8>Fpp8F5DxZDrmul8H{veACvF-y6#ER| zF#${8!t^?@N>IQTL1=-e&f~lA6Mfl5vmhsG1nKxcSO0swss*(W`Y85-t~Nb_2am4@G~6ZsAdi3m7-@LZAoWqp4R*?nQJN)fNeFM~a^eXl6UeshNJ^<|v%q(X`2H~<>Ig%Up=CD4O&s;Y^>M8iN+4$@YFgRh z7z|MNghDt+7sjYP#|`oL1ybUR1e)*mqE*=qmx6l1%mH&1m5PZ+p{AZ(LZOet1-3=Q zT1Y)`wVdJ zurI$FImG8Tg5J&V;2qPdntKc>w>S^LaE&CY{lpsSxJ@{B?n}TzHXPjx(5sQ^7ds|3 zWK^!2{7-oPU6L@$cl*Q3@UdY#S_wpuNu3ngT;q!j=ue1wf+F0sdYAqtQ@%I?-m_h@ z-&;ihsaq$7lLLs1A|_iN?kc&>I=j_+7@Q|mF^i~#xm**+wMV?31+#T!zthn^`O|2S zLHKOL*mSurf#PeExR{Y=wfjOchVqo;fDTQjKjoRZ_bMBQ%WH_XtE+eJcbo zG%+^iK9Dh4p7@+L*A<0th!}8oeAaZ}q2XhzKf^2A-<#U%T(XruHvjYlW@CoL@$V3k zx>G-VE6^~$40Rw_{%pBK z$i-%$(2ctG@FfDwv1}LNi=}3gT>-g~Gp)dzfPwISI#|+2gtn8jIb#nhtde>B!}SdT z$JwWj!R&S^Ki?Z>`>%XY#gWWzF7p=Ohpy21H^yl9nP;v?L|rFOYw(^k7x(uNX2;=h z=IWVCMjm0qU<7KQwK2Z?)9oxIsoRH*EtwM2F1Iaq8|0-=%<}$Akwfx%TZOhZmD6rg zL~em+90i^N&m3{QGQr0Z5U)`Fr~n|4{fAlSQDP8xNC7c|3(d8$2n&-^x`)gfaMlD4lp@H%BFbc+L)PIOB znpPOtp9_m<3{SFBZI_i;{Ln@(E&{C$r~B>2)!&Sbo78>B(22~xhstZYge#W{U&|-E zVT_M!wa`#sgd>&IC4b=(X(?k~M-cRPrz9pL0 zCB6)L^qs7r@HI!pSS~i261QW{uQ)e%DqYLgVKoG1*&qp7v)|#Tk9hb%9Z%Jz`${$Q zyhXeE2UaeU7`SkAfu?*H%q@;LB*~@KEd?*pL@2iYsL0nVF#mvDd1ps_>Q=%m{64PB za~B6P;?RUuxP<|Zq8yCo?UhRwmkJah0c`Tl%POgIAhnYzIh5 z@uS98#qz0AO7gJ5e&SPc^Cf;x!SxXg6w(QCs zBsrNbI+I0}4N;a>mq%b7Ddyv_X^zkCk^7||WD#uVMaOhG-!hbhEfanNiBjTaG*vB@ zFp}=qddi-G3T0o@$(WOmpFwYxQ)bs=LNiIqB+T&|W>FrB1&gEd-HE;6SS=wscPx0KOEKJwnG4fxZOTq&a82)0f_CNc=>{ z)vXOce_8+o2Y*~w*rQN?k~%S4Xuj{3?tPm+b*X=Fse2TZIe{zgG4C&Hr!)C&lzp8qL=^qw6UJ*(Sj7LEUSNM5Eb$d^5BVSa9 z3EE4q1FdvwB~Aem{HmIZuu!Goeh8Wk3g&g@mmWGfzWe|7zdR&;phe8#GkMjE&CW@n ze`CX&AgrhE5Hr&r8@S@|a#VU*ogn*;Xkal#JHk6p+ySIAj|2Ox=I%_V1q!zc^Gs z_h9$O&K_nhh4Pre;=kdt*Gc)ClNk318wm?mCLhk=>PdFcp)%(Wb5&K3d~1!pL1gNn zXBe6BRo+v4&;^UAI!@OE#D|d<4j~zgiy{;wVQG57Gw{L zZOqCqUXOd_&I>S$9}*6rrwe<=o!8t(}^^T#Hi+oEYg#pDAMJgaA==t(dOrjBSP*bbH<3H2* z`gvSkicfEzzNoaP;@evar(@x-=8U5{N>O$Y0#079cZ(EK>ULk{kaV zvYw?a9TtdJ-3~zRu$jD=j#)5O8p+?Zl|SIgB%w%Zh?hFAT`$Hx_=o=*2~sPEb~c}@ zm8);7E(asvMt072E4Rysgh7C8vmnNoGXMNHmk z5>*=rdI)X2%JD3;j#D#X^Fhuvs!ZU)r( z5*a5fDcRVWz)~zLd58(reMkFhnCM$Of`^|>4rtHJss#bbGjj@KHJys=Ren-MA^#|e z&1l5-sWAXX4ER`z3bAQcNGcSU1Rzli?dg%2R1K3y42(MkV6-xU9ydkJ5rDdg@{9H8 z)JCO@3X>UMj 
z@o+vQ?+(9bN-dun_?<_H`qEt}pIL0ICdp3Y1<~DoP_D&-;|dtn)LC^A;f-#qTAHh0 zB$DU!H)wu^VgUCcS=`MZdG;;QWWWRgj5_kO=hirbx!ul@;gFgy%uH}ByiiOB$vOvN zZ9-%!3keHD6&6|bN@>%lpM=`iT+lSxWi^(G+_n3!XFZK;lC1WxkzDGeU?=H75X7q!a+aOT(AK+iy{MDeXcRyyH=p`+h}2oP(t zq<{FbT3;BYp(RibF`*t;kN3%Bam2NJ#|HPYffo~Q7dGC@C0dPc*EL&Bv<}Y3!GS8= zO;7PSeW{)Mi?vmFAaFh396r1G{z4H8yS+m)Ba8^~X}gfNH1bs{7Td zDD5D$snhB~NY%I4et@`@wBe3r?MYFB;vGWOZ(EU?7KTKZ`V~#ytW^Z&wvfK<8v;!7 z;E^6+ep&#s_O_G2it-EsKiAzYgp|}5ky92Y4sd#UEHmk)WNJApqjpYvwB5}ucmxW4 z;~mnsu?lE|6U-=D^Zc+|9v(jJZ{U$kz&lsVo%MXL>RL@Wz&snsk8I`JvL+ml>5wi6~1y2!v#(nPPKSd&%~;BK3ia| zA@85^F)&uLIg{bijDJSi5WRvb6M)~#t-T^<47TjCcx8NtfBf!TOVR?&Q9(P+4-M@M zR(ZY zs-ainv~*A4oIuC=>c=v)V|0V(j3|4@o%1v;|G7VoKYO?pB6fgs5k*-pw84y#_YWQ#lJum=9Mn`N@*VD5 z-J833PjuxV3k8BJ>lPyYRkMdno=BMsNd&9;HRu>KrSen?;Kwx{PG1TQ=h?>21Vm7e zH$WtHx7&v|0$bFi?m%&XX{G$3q>$%+ppsuf%;kjV5IiR@S;eO|0J3>NqIHP^8>hg` z?#}+@pI_i~c47CV&F?kxofyYhxGERaTa;@s23xP%<-=F(Y(*d1!fM2BkyDi-K9z3s6JFiRC5uE9(P;)Wy0hz!f6cFxc>Nd!(;q|_$_5In)X(#)t1 zYO=Wm=u-#_NbP_daA-wC349eVD+Hw3Yl`ZLWziFR!X=0c5>Vp^`R7iLp%n%C>oFzb z0nZ09LaSB!cF8L&70re@&2ER5Cr*Hk)Xj8(=8vN}SrY1j7DUWWfD3cTknL*qH=-*t zE{>~C5*0Y?9_+Y~&s1_hVzwBhEHbJvu}~3A^^%Jsw`ZTZf*rRlM6fw&Vv9g_K z{Eb9buTGEkd8&iKobLp#_mG06aEv(Jvm5pR$a$D~F>vQ~_9mB8T5l%vmo%vP6t{ZY zEMX~M@~BePM?`)@A`$c7c(_8xh8VMg7|5g-H$uhWG%`)V&3UM7x)~>{>Hz&H`QBr- zOF4CUDRP`m7yFq`& z%c%p>Ce9U@G|mpvj#&}*ewGZgO7}?mR+}4vMC4D-n-GkLVk$l^>IZ@x8jX$s8nZGk zupX?V5Jq_?6)QV=WJ%t|gG4l6 zbK5W?a%?PC`IXe}*mO&zCWr)*U<#URR~4Mn4$?EI3q}}Enz;R=PH3rn&&^C53R_1c z)zq@$d|GYgu?F*D@-Lx(Q*;*v(<9?yOC%en@0`rOmq$#<5^aShfnSsz${yLn#tpuJ zD;^VbDrS9>7xjW@)9!lKfeyB99RnH}-mdgA%c@KHgdcZ9(X!TP__QdeVS;@GrcL1NNab9;=UAb;;AJF96Y_;}s&n@(`HeGQ<9 z5BcnRsH@a{cu1?%Svq%YrR4ljHOT5vKB=WCia181y|?*zt8@ zpaHztE;XTxXyJHPRw<+t+`^*k3nf7Rvh?tc>anr8W$mg&@=LZ5`nL^(cH>M0UL76W z>=BwcF2t=if}BOe3jDx&o);1jGpiT3dqDXPrXyPk05ID4w+U^fd~ z`51UT)wiisbW}sK^QrCv%^Ip`0PUz2%7N{ugPoJJySqtM)`d7XMxpe$mH`27?k2H) zLyvvgKY$wmK$j7{AqmpcqF{L}0Zwcc#KK!P+@vs9o zdE~yKXz)_;Q%ZUuWc!@67l_2dzRR8&(OO+z3JgbEXQQ36?fEua|0w-i_-B8U)y1a?({YmVx?{_T zpLq(z8kI*pGVpKmS&_W$@E8$A@loaVwb~Bkt^pTNKk|H1?;MeMlAh&}t$L4M3L#H< z3hctJZutp3svFIB(^{ZNd6D($_fmn1lrTt1zt30kmeY2Rg<(Sx@xvvpqu^j~EQj&B zkv7BOX9ZN9_PR;wVFkud=}blFTM70C;7dA+0zeUt9)}&DF^Nsxp4XYxhl&TfUwWK| zRX5H~W?5I|+euX`4M-m_7XIogqq#v%Taq0gT%?y_QS!~mdzTbLDK25Ugzd5+>MP!6 zImo!e1V+jb=9iKDcAS;D0}+_YQSTG3PaIq-oFUJ9 znysHZ`JN2TC@zNWoG-1>}taq^1P5OR&#DY1q(g|(U^3dSx2)b{pz zspGM5je(Eh%o;dQjY9G}E@-#{ELer!6+-I;YT*?vkxkB7xC2yr6;WlD1$7%Cs8%!_ z{?o2f&FCX#QxO0=83A3QZaG$rKyHNM5J(Z+Exxyc`xsY7rQ9JwR$-;$Ro;@h*!pcC zck3im(JZ}GTRiV489}lT^7uZr56+?Wfqh(FfA%{M{rtBO$T%vq${)uSw!`0MaugW= z0QUbC0{O3N_kZPor!XyT4_NLyd_m)e2oxwRuZ$YHrusr8TQs%oHmu?bRgwDyBqbXM zL~7!duv<=jdobeT$qxv^J5FF-20CzE+EHkx~YjNV}uk z>Fw(NWdj9cnRxS%C@v22(Y6mHq-_I9?VpoM<3EvTBSrwB4Grd4&ga7#K$%?w{5K+L zU9QhhqQ3~yOT8?*-r}0i1Px$O*wMkxN1CP@epK~j1#Jaz_H}RnI0vu0hk!RzS3660 zTdT8!1h0B@13pDpcT>x&HK@gQ^zA6-P!}yl7d*(nTo2<}4eAgt^{`aLn!{mAEyQ10 zA3dHp;w^jxU$wjXG(FDwu2%?dKQ=xrQj)0{dNYo{$-xxLLi7`zI4i3&H z6@*vlA1^t-Bc!alnhkF-9cx;pP~4yPMDtPrxHYgc=)6aCxG2FyrP)Ywl2=-)G^tCn z*Ac_7Pb6o&fzI*Xxl$tDBV+FE$i7(hax#=&5rsim%H2r^E5uCiF-I1btxzs7xg75^TD-F$Lr!Ix;DaTm@oCPxOCGj&QYz zNlJ0OTO5&l0BmL&{BsP{`yXWBulk!jix7Z%KklCz!!g~?yqm>wV4#$F34D4>w(>Ck zqL3GfQ5#R(_6D7j1-p$AmRYKTHTRfn9#N)z%+$auwLEZEpUW)~0!~H%W@i~$B(T&6 z1&H%yKwFHcc|Fw#@$Q&2^R-w6W}$6TZh;xY{a{8rPgI8dw)7qJ^vmXv5_gMS#*L4a zVUUO6R%vn$LPHwa&ddIp;tOKibMfwM7b{fEluzkNP8j=O1!!_z ze(16Lnp1qZ`3R;7L25zxR2oon>IqA6GU7pxZ0^HNBdam;~F35 zb4~7Sg;DBk@JxN6-wz!Yd+zE>=KQ4B)ijxG8|cV{7qn8y^AFL$wd(oJ#G1RO6x5|0 z>^9G=^qH)-uXzzWfiXEpJm@Mo$!|D55t&~YE-ym1ZQhuDHBokfPnFHi!==wCc9XEE 
z^>T{vY-Gu?r;L;|L`EOd`e%IM0lUMF+-vmYuh5-(Usm!21C;#=$#q&7Sc0$++2C*m zxCI2Bfj`(^;yB+jh&?Okw7XjJ#hN9beRCv#dCt+V)Ny$5Dgg7M7a-Cbq_H>(d_u3l zm!>ItQ{kKoOOY`_1B9!t?C35fa9j@>V<#^vIrR#;0|+Ndv0(7bG|Ug$nHAZfM^pAL z3-;7FvH2$ZS&wnDpHjtH`D&4D?EtyEW`Fv`tRRmnVC8Wb0ZQuv_x4d$gh2NY-bpml z*Q`Dh?PDV?aW>fq*J{*DWz>c9VjDa7= zML0Yh4gr^#<3FKgfK0_bgmt zL$^#~@gyMCm+irz>k8gfJqn)@6blh&mN9sc31f^LNA9!nIpBZjlQWobi$1f}AwRu^ z%nY|24HlCL+GuXhxR>iC)eBKxy9w?_QgfEIqQ&phU;vDD$7>lrK%98Ib%bLgUCXPE z7ve;17g1}gcR8GTMIcycP98|vDz|E92%5JNu>RGcxTvV(vQleqXfKseqsLze2_Be_ zRfguTug+P*>!X=1+rljcPT4F=H~g#iiP)vayuruE=Ic$o-U`0G?CKGyXC!?Q1_Sxd z_*^OJoB$wSkg)J?5vnK|i7RM5X>-g2oMb^9c#XdI>dg~ap?Hb#9D&t9cMDIGenN2W zW8TGA;95Pmg%E7v?>zgU>?ujZ_3Vo49(Ez#DF9v|Tr(fi8%R9>go3C{E7>SmhTk{< zC0-S=@2Y0UCsgf=&YC(*&_)k>rmh0iW$ax>M|B)=1WJLXoP}7a26G{{GYypAa0-Y< z63#|iAa0<%&t-%UzoB|V_MjvAb;P)86na(>@cGf9>>zcH`|pDBhXkV{B65=Y_Te|+ zzbw3|5A-XZw8;2*r!U|{xXjG3nnv5ACmVl zhgE?f6Dd)m#8R~MCU8i$4t@*@PtGGtAkq(AWSMIq7YWZi#>;$|6jEBExnNc@FW5HS z=>2P=6E`f3EPt&w`5~V|9nJ5jxj zq-|MHcMC6jMc!FQOBoK!liv|AM|?j&Gwcy-4?Ym))3HpY-8-uRC{S(7nvT0)tS(4N zJ76F*UY<71>>u&}h(?FQ2HJc=GpXTnnt=xnyJi%>P%&xb!4p@!OXW9ogWe7-4NHYn zZcuV%WT$!rOei%0XN?lf-DZ7q>afnfFDK}SjhInQ5l1`%4H`L(9w(y#XKxayDeu+RPSfIYt<55qE^owY&pBW`o9KqyeEGJk%Fd-!f569e))$S@j>LTiV6q4=v9M;#ee%8#? z2=;7(0e>qj$A*9CvI=I0P9YxI|KmZH`NyNS8PJ;5TBFi9EM)f#JD=+N zNV%CUyQ4)}(p#I^x2lFyFJ#h`2u9E?eI~utcs&uJqf`AA%7SIP-2z}e17-|0V!)GAd=+YJ`vuV1Rho=@r)MSWH!^|=9XXY)m5vrODbxB zMtgJHf6#9(zvwqXiTWcVoPT-zxXEtrx+j^E$A7x!ho(oD*5}Ub9s_h3(eMf-&cYL# z=F@bRUc z+ox#klCdjeT1Rxy45k0pTHotiOyQjc?DRDLjm*EO8nGUf)PWpnu}GyzjlSyfUz7m5 zQtepLlb%LlO26bWS|e-HX+Cyr7rEYaTg=C{ujGv7$%v_Pi*dEN3;r(}_nLg1YGyOCQ2=C?)Lnu-#htj|prEZ5NU`p!cD&!#EWy^J(7;8&=sd4RsW zYB&}h_qtt8xice@t?1+;2@0-Fjn?o|*Da{}%T6A+?R)^Tw@@f4*CI&V*6pNbsh6-Y zD#O4vgwBU5fHwS>0g~2-axFej@W}OEcLIW>8e&i5i=$6#Yv(0&<;C=oJh@lP@(1R? 
zm33L%nN6Vz*#E`WJ1~hBZcUnH+qP|=vTfV8ZQHhO+qO>GcAc`OZcltO(|0=dZ+Q30 z$hGoG>hmvQ5Xc0O@)D6R3(J>o{)ieb-9bHOZ(&oHkh39@zyvCVTey79xS=h98}<&0 zle@)3chn?pFgLCJs%76Wn#t+5ZijPn%tWB3RztG|jN(}p^;_hesSm)_2n|qO1lDe( zBVaruH1McD6L4ftRZvlPNJ%Bq090XkAUSm+12fUn>NG*jB}zvbAF~B76P-ke)h8-M zkMTPqBXLwimtG(dm4r6;?hrUDh(ez^Kl)m-K_GWshPv}cznCJIt>#Yj14J^VUU5*= ze#5I3ENJRz#4cR{x$>*S1@JT6ntzAOeOF>FaR=&!Uz*E`q@hj>^2ROARr!%^>~PF7 zS=?CgRa{a(CM#TYw$-0mV+NnYO-yFQNwe%=1OMU&b7O4CY)K^7{>pk21`N=s=F9RQ zALgX#T27=9%;v5|2Fg(~{cg&tzBcUV0)T8T<++j2K&q8y3ke#TvI~PbH5@`#`njr613*5oO;^3*T>AvgW=& zO|@wTb3YRQ!XN&#Rebb`A1i;6uybzoTkwnaJOOX!G5OB3-{X9k!EBTCyxJAy?LoA) zAP)=SRn^}ZF9j7`zOgI9-$E#sHF?k(F|yj1?oKXRo$sm`j3Lx|q2=_cX11(?%-iIF zUz$wQb12B{H)ihH)+4dP{nU>e)dnU4?!oOL4v+v3xE`|u%18gku!cCi)f3p%mGjWg zIkxsNlu~X?*w5S=qFmrHLKk&v_kJSw3vlW@^ddszxcbt`_-X+mRO+a1Y1^&mbLkcu zQY1811l|4cR&TY75e^EpJcRhc%Z7t*L_JZp4Ky4Xgp&;$GP}Qfdu(d*ZFvexyj-1K zShBmu$%&CYESn20+toT2tM@qk=pV5$-QQ^IZ*lRnRK4x-AL^shG*NUw_vol$(>n-T zR~zUP)y{tF+8DTZKjSh`z0&l(nF;<$eAAhniDSx8E5OHr!t}M5kjzrpgBN?f4Z&S- z5b@dCu)x5DkIA}Ec)g`<)i*~o<8TDH<@O8!j)Ma}K~p6H*SMbvBt6YIePI^R+Ds%*U6boC3xy)_Imc-IPj3;)rsF~W zti1kGRK?+a*X>PLkv9-~r#3cY7WY|>n_sZj@Pk_cQ@TEvn&w3YKXjfjA&x2JzmnjD zmsSSBsFQ@~0SS9Sy#EkFX;ebNwr~pc9+G`Cb1=U7jY3rCrDh5L!R44%A%6)&SoH)4 z*8LJfrJ_IxX^sSm_LL$^fC~s#L?}VkCC29vq?nS}jB%VNfoA6qk-Gv`4^({TtIdJW;*Wm2w=Cr_CA0g~SEOk`o#VQM+1 zh@vuf24?P)GB2tq*I0_&xQEIeQ&Lx&8SVzk^A1jXE4X4#v8vu?-3~~*8$`Z+1MpQl zk?gpB!IFNHqT1&$(-}n>!=;60n<@6KJoRm{L{STdI?Sikf!bVJsrJ|fozHpBh+*cP z$+HNm8bL0d5i3ji)t0Y%GIdJ8!*G17bJ)r*_e9ZY?JIy}6K^BwCjU{v(e~c8F*v0j zF=gH;6t5|;`>wdRP2Crc^kHk)6P#jkmH4sjVgM-R zu#xahkz&>75r7rk%PyC2^^09;S%wZipqp-1^+B{G&17AC#FKj&V>W5($8qrNkkFZI z@h3)kpf(hiGqA&2%js!SpJ%Fs6&jikB9u1=L)Y4C+b(W4I%*mAph5M|DU{|x{cl8IP5QGO@Ku2 zXtk=_=WJCZ(uv)x#=%P#(aVbiB!DS(bn@z=hPi0{rtWrVK|ED2k$N7QEtW>{NC@sN ze#7ZqSyiiM5nXC6P0g>Ao3biTC1GB0lwKqwr6Mu3w_IBzBba4^$y47(2!2ugD>5tS zae$M>Py-`FUWY_6$tQuOLO7aFm#NTC!bVwR-o;%4kUHCA;qM)NeLbz!AHCItihcN(PUl%gTQq=jh`Fi`Cm33l+T>YPqhBmCd z?Z1Gv0(<=-#WHk_B6GT4X+EFa89qM%{g?4IqNLII)SucE(ef$U2q<^M29QbV>WXC& zMOwlp(nB396P1M6@nq^QG|fTD!sr!Cjm;k8mNNt=J-~A$&>Wc)mYlqWU9!h>ZxHN0 zY4o4=pU!+eGI;o;zlQG~vM=!e&V60kb);|2UZ06&O#2q)mi}&slPH3+J1SX>!pBE0 zoyEqUDBb|7r@#^@O{mu>N4G~dqg50)d#7~rb0|yXP!EM!0MG}dL_-U!yPZ-{xH^4w zeZ$Po&h~1iPs-cYZi{tZWSNhzQ@<&cR8 z#DeJto77*@^+W#pttE0=PGb01Q~v%UQ(b9e9uI(wRx+zWTE|?a8=^Wm8EzLj4NOHH zWkytVh^AZ$Md?h^s7%-5E(N_)qS}B70~R4HQ^ZO{eAp;Xq$Qfcz$U4T7v2C(m~Bqm zD)fH0wzGEcFRy5RrUgfMG)1AbYeEnj58;O$5~yTiM20|&gl%HWr`-^m+PXTDfV64x zR2%jW90|{eCbE&=&-j9mNzxvP3B&%dW6Hq|BaS=R4O%=G=Zd`q2XyvD$t=68vmrwE zqn1jMIukIt3h2O2c6?^!YOt_jXkj!+nj025m;Z{_YoY5_$A7$2WDNXsVeS)XS8R6T z=F|_$YgZV(s$rNeU0v~sD(Z(_mkn)&qVR+YT_rxkI|HK8sFCpV+7O>}EgL8!!$*nc zu7e>FTE5t{4@69P+=cixW)kGRqdAsSJx9qny@%A@Ov){bJDxk;VQ3mkRarzuD>Vx} zK;I18JQ#RA!@S**08lvUfP0$&hyn3+z^b7h+-mhN9_I@JE`PuC8KJ!`o;FxFx!D>w z=5=TiV#hEF6+!xiX9 zopZ><_#|n7%=M?ph)^Hvg_tgJu4?RZgm@60@xb*UbthW050y8sqaWx`c$&1&+Qb^fBl>>_d0 zNl^uT<5nyM)zJ@?m=y}9EBCs}EXzHy!SW@3yw1zq?i3F$SgsrcIm$E&7hm3+ZRhZ! 
z(i=M+ZnIaN3q_LHgzRAu+^;KXR#xTCn|J--J?)%#L-~-__ZOM`=GExXbuj`51|euo;4Da5Ch$ht}JGFnPX-9OzjAY|7KF z#pA`X&No=k!W+OnKE@rP@G4_eYwd=un@>Q3p!RD(OIeIfwHOI_&&?HrQHoLdq(-hd z_wvE*!!Ulr(5eD0M-*oLr94c{V+xO<{UzORdA)kn=DV+Q+nKJUA3&gL7uU_H7@>>m zyy?(3U-`?jXijrevkL06oQbkOS#e$yn{IbUV=>%-dGCAm-0)y^d|ijGCM>u1lkQ^z zw~X-k?7m*v_M4Z-1AE}Ld+ltk(?x4`DM7p9Maoga1{@%j zVpA?iH;T)_dvT}oE-kVxJ3$AX^bXwyxRa@uQ^IMZTSL3J#;_l-W1pX!e)pruKYf6< zj6+_oxIz-0Gt~i)S8eR4tbbbs({621fxRBNnKI#+*M^*@4pheNMw~IZnmr&t97EJB zc<{RJlI8ert5pA(DqVX>V?Cf{sU4#L)Oiv6LTIfp*+?o)L6xusR&8l(Va4hh~!7Wtj zL{>Go{RrW_1_%89Ysw~^42dNCifT)U|6er!f2H#OM^m>b@BZJK`WF_BM}$q4gt0CK z2K;gwU&DN4wiDJHAQyx5&J)kV9{%Q^l-E5vaj-exrj*?(t!x&aNyWqb` zAk8{=M*mBS_8)~(8A{g?x&j6pV8!^BhVWY9A9<0d{w3}QBA`JLBHE`FxDr^7E0>h) zxrElE+AiTC^M`|ZLm^mRbBH)at~FD~9{P8s*8i~0o^MuFUAs=XOQljO=}yqi>-OYq z|CvsMuM4-&dgx944{v`JLa3znrJwbF;x$uUvKQ3EKMic+H2JiPbugk+y1#PJrb9q* zBE&_s8ryBLIv01v8;Ae;fEA`xDxDVmZpxWci!Auwq4oXx{tJp%7Cq|Ro7GEw@DS(* z@ARc<<;p>}i=83IFL2lF@dG4YQ)FyU7JDvF9BOye)wy1nhfuyeHhv{NLeN-QO^2;p zm|_JvGjkF!rW9!vehH~3TfP__!mZ{QU*ye*?B^UE7Ay!=Q?a1zqr^c_v-Cv|<8M(p zLnRT~AScv)3)Nx1;@`Z1ku_DO+s77=wlY`KLb3Gf%lYKpvuMATXjpB zRkP;iLW)U}vQ~O#x_h-Dr+}8nqeEoj6B)OygFCiJ0Np&U76p%_4)x_u2iJ$OD(uvL z)-T(2GsUzpvP0I7il$a*d?wSpFuWPHITp9Sp%qVRBqT($B)g>c7CDO+nBh8PB$~C? zB(HxX){KtDUn1H|uPeaStQF8CHcyrl54cMMwA7T+w3Z8?m|EKqln+1g$aKcm-TKSK z3CHrSqdoD(TJaUs(? zH<;m6z`B!xNKPsW-6)@K72!8Ui<;L@aTDB!Ch%VPaR)I==WwOqa4L58p}F4FM6xps zD_>%F>KLYfEL`U}eI>uUm&x6#A}`e+EOVUJw+Cx$Sk9GtvwCeRU%?TO-jFxi=P+GA zg-*%-^(qFv7WW8VN_0(;-KUc-?|3`Tyazi;K48Z=(d4f>lxFs05+su%QP}zAj}Z(8 zedpfFdmUgp;LOJV#Rn#4i(rTZO5?GNcdP>)Qy=)Viw9SNzuoC z3z+Pn!58)Y*E4=q2Y-%72LMo41OPz$-^H?$^XI<&+u3|ZH_1Wk_1k=@ufWt zULD>XrPH~thd5VEj+05vsE-u`2pEd%fdFa15Jq|?{rY^TkwGXYv)~n0Y9;hb%esC) zy&=2d7ndfxs2Lsdoj#uzy#sK{th`7nbatk$s*A=Q6vc=>w9UdK4;s2jMeo4{N$*6- zR`@&nhnXX6jD!iv3GII>bXp7B5)KQ+n4d+v!UwmFDk3@&FZWuEBY^b(s?eMGfP#It zqij0SRkkB^B>)sC8t5b`kBrE&WKmMaFKm?1EK(#Is5`keMS%fXhWS{+Jy%K?xI*q> zy|q3Lg*OlfU)|PRWPg-@fA1*7$nCf-NRbE0ZN7RKGXS zY2!OLqM8A{8=XPbQ-3>uYsA@Wg57jtUT-!EsVikjlzVPmEvNw{b<1mMRejJIru4yC zK7jpoEJ$^=wWFpk(GR+!E~hRb^gq2l9{KS3cEj`Nd3`$DJ3Bt_HN8JcdVk!XhK^^B z$EaVQ1+=xlgg$SNSKt%k`>E7LdD&j=C>qYQj!v#daegm)J!u$q<2GHQ8`!^t4=O z^FdP7!&HB-5E;@}V8@-7RnE$Ys8XIVy4CNrj*`Z-pPm-x9Yi95N}3Sa41A`IpKYI1 z`ee(HcI-gIN;wBvsV%4aJ8RMKN^g^i-gq&%ojpB#$(M)40M1{(fR9m)B(pL2>{919~Vk7qav%&v>mZ1JiF{qcopUi_0)~7QIAh0?XAMbpFM^)!IbdSZ z2|t2E`sR zaQjmUZF81Vk3+)9;v}+WvhkmJq--$fY$P+qV+=DnOihWHT`9f=-k zgJD=7Q0LUUoWB%bClF`_+KW&O$q5H;OThOy`bYqSE}|@#@l%hh=3>y>ar8_o7!B%dBW$;MDOV>l3Fs#Z zeJPAHtYm%6Py5?slUULD zb|oBP4Jag$m>_oX``65Y-C$O&5-v;-tlrmGbTr9QnVlSh;9+PP4*zcT1CuNz0C14( z8GVDryH*f^(kwR0Vw@* z$lL~ap%Dh2H2+P4!^Uz zIUnW!4cZ2>pKV(@K(2+D0-OZJz+6VB6|1L_M5hAch`f&Z_4>2K#5p`sIk7P#LuGb^4)UY~~(nXdpB2 zuk@NUPEVjY)BOODn6IN!U|CnQJ*Zyv=yUfq%HpA&PTZ!S~)YoIGSG6D= zZ{JDr@~_IqHN_3f1;p;p#_F?RN%mA_!2<>Rk$sDtJO1%n`gievtq-ZpXtV% zxVo))vb?*#pwO=PhP3*PS^qt`DAC`7JY0iwG>=*o4kwaluq1#owUL9buG9VapB7R8( z0giuv@^kG%BIp}LMhAD?*YBKN^jJ(K^tCr?q3*agfXXY21;yL1QAqPy)iSMRI$AVK zyu||qj>=gH@K7MaO9ItPr;#;>!F;5Soa=!}s0icM|0O!M+w}C;X{RluFVhLF~tD;~eR-Dq-g1=Y^elqNT z@^y6Rlkb8b1jB%U_d6hnK} zV7}4J8p_q-7HcnDh7a=^bQfa+B?DHkp-C_ma2yTHFr{&|xW;c5M`?oM{+4d55C90y zYt++Jw=FKWUy&Qy34@fDnx3#a0ATIW^Joo>QZu6{*SqHS(HiM!A)Rs zsOHwDISGU64l~+*#VP`CVW1HlY#tNd3e<%qrq%e^Rn-OM%wHVus$t2p(baL8$e?iJ zSkPuqcXO`cbs)2+P;(7Yu{MTi^xhhXn3!AU)?HA5Dc*c7Cg#M#E`R}q4y+>53_En! 
z4wgsDzag1o)k>DkoKLW*nI#P9T2la$JjWdx*is7UY!8|2i@#_t!;r$)@`@oFHd$X9 zG#Cgwr&SX%{?liO3PCOwsvcmQV75)u?dEvrV*7HBUw0NaXl}nU1TeRboC2T}+Zfyd zFLyG%P9vyon93m4K8zd<2Pig+19DD;BNr-V9U!Ste8sNJcMC`zM>wT6bMNb~xYO`9 zEvdW7Bv>29Mkq?=WSDMt(FOpY=CTnjnv)=cADUgx+;vQ}`B3RjQpk^MG}!Jp){~LG3mzbG>dxh_{Qhg3&uDDYJ-DUHKq}4!~r$(JT31FuF zJVl!9#EEjFdX;Aqh%Ez&uFUtVgR9@bguuu5&!f z5c^D(+O64KqnxpIuK{_&Gb62vjdYX(kPuFTc-HMaKkRk6z`Gnc5&H+$6sh+O-M_Sk ztAGmSS60YvK;iG!7>l^jX0FVfvtD>+OrTsvMST7q1NMa%Xh!js865^;ySqhCO@L^$ z`=%3|A)!0HWMUTp&{fJ?sRwiioD$S}ZZOQ?b3|mMzmi}<9Yrdc+r$`KuuTA?ahMzW{guDOEBk6| z+!==rDTIP8IOdPdfUMlk9+VIS451NWRve@gEqh9RsTxu&LwL zZ&uOYB?b6N@e^jW-RqiH98H?520@LchxKg<_DpZOQ@P6$KWthsD(uC41rZ84t3K7M5odgVXxXS4HYJ6O2v+W(9k&dyDKP> z4)cz(Hv20(9lV3^c|$vSEj|>^%{767o;6eEG|cGyW~=yOqBgjY7$uPb34Srj3;Eav zx8jF9h)Y^1Sbl5|8wVTB4Y`9%JTea6`NY|`UO$*heICpnjH5rPemouX;e9o@@?22gci-J2HtbF? zvc#I1=LY$%Ot0L}VZB+dt!RGk_;abURkfB$GO(YJUY2l9s=jpm)eQ)TTbA@WA~W}H zpcD@2iEeN1k}4V9l?`ZUXsJH2$&lk2)b;<+s?4O6zCBeI3e@@>apMz11ahN!9 zH}q)qU0TN{AH$5wB|CMFEB3@~UvtSmYLdkUqGyj50`3C}23{fxjE$=`{$bx^c8}#5 z0GK1Ue`D)AxV4o)2&Rz?^xing-BSs5{%v*~sut6Nj6VSIspc~mGa~$SEI$DyLMRZ3 z7^b$raNt4B9Bwm=yd{Vav}3*(l|Cos@hfv;2E_(sd)%8j0Y5Mf_tsaPw> z9$Ozjuj&=7%J7d2K=d@}-MPIRu)e|M*{q!|*Z>c&th z_=%CNunN*t$lA?CP zl`T4pwyQrP67HX>1s~}N*{;hs8jERD=^F`XaFf<(j|*Bg5<1TZ9)Fp5+U13uRtU;af+CVr4QRAx* zWNn`46e3=iJ89g;PCmT9fsi^I=j3H@-YBMZ=cM+{r3v=YYS7tHPXmr|4?V!qIts^= ziG~E#qo6%MW?_g+e;NdTr-tWVa6S6RA5~6HlsZpgF481v%)RbSi{3n*Wcke+$*cV2 zgyk)Ge{TRwJLvygYev}Y*jt6_h$#~ZHe2O4nSU(lJMl-1%Dal z&0kP3jobSQ%nQ%wT9!EQ=F>wGwPvjFp~LOFJexkFf6+twb*_-HvRAWv7BPvj1wV2# z-7q)4%t16(GR5eJ*<^gv9~g45F6{OnFRm*#+$+JiG3-C9dpMnn@kwGoCp2?Yo&NKJ zgs2wr*zWx zPL_v7;C9~(M?|X?H@dc=H3nJ<>jZ4?6U{>+)K^V7rZWo890J zb*$fP`<|n38cmLbKo!j&hp)y?s+3Bv4E#&8FG^ZxUSs=Hb#i$IkX@u;0p^SjSuRJ3 zalMzy*|F3ffHL+pw&O5rh;G7FwRb~49=q~>|7!_lbjJ~f%*b;kAIKmq7J>Kk=k5lh z^nH*&doa{1Df(-My)@ANsfSR_%Wj>{ymLD=R@b<#S8znLs;cRZqrrE*G_fv6<$gvd;TGs;GdR4^MwT60&Q)lJB@?aPOxUIlZflN&J{Zh?4BxLxw-huw(iTep<3qAs}_`yK;k<_LZ$?TUBX zMd96HJD0<4k2^=%gjv@(v2g}7kZg`J-x#B6z-FAiW72vwA6}^lG#byzz35 z&1D;*xclDTk{g+VL};R?7E@1^Bndg`rB~KBAFso33V7GFk`kH{uKD|r%Q?xRIDVJ1ZlaOpVU{<-QY(u%fpbc)4%lca&0e9@M*n6 z^eHVhmaH6tJ{0FlWS-LFg&$w+cQu%Bq;%8;Cw6lU+8?&|GW#*&4?GyXrbz_N@Sna@-*8?ER zTbWUxMhSJnFuv}<6Q%xX(!(ZX8gkv#E}M!1a^f2$$c*`NbG@1+=Mj`{B~g$eHjuq} z>H}(}gJ1*+)wj_;-dJXC?85#+KHzRMr-uw6K%C@(W6ZsT+Dp1kSwP<&Jv0oCt`FiM zB#H>tNm#P@1G`0*5&mU+(7x&|99?i6r0lxbLdlEbxhN)nM8(XQy`#uaEWF;{DN4*N zF8Byb+Bau9-=eNo1gV#k+^6P9h0!eDA7D^?W1UzXg1BMKw|g^uCAPlLE9fO+w6A=I z`{ADR!}{KX(0}s2&Z$94|Lo!a&o=H_$K%(?Z}`#6GbCD}4uk@UNd9Uqw&CSid(R<1E>en@H62X(Rg4#5 zSNC(jG`DuOJ+j$`p0-w}r_+$P(}>Qe)9dZ?^C z^8sFZltL7X%XP#fCcuFV6Nwpb^pGRiw-afQH`u^W}NegUw_>oJ1GU zr5>0Hdx^oqB|;+}TuJwi=Q~6U^{*E1+HRY`XMu5mejC2?;)lMikEaDEC~Yv(?NWv; zyG6Wkj9C*<^@VE>#hDaad%;!reVSdZj@aI*WHW@$jxlIbK#qG&IDA0B#dVTOHhiGwf+vi2ex zvyt*nTAnTh4dV$;ihVb_sV8#ACpF7+x^Y5h%HSa2)4sT<$II^EcJT~;;*mU=Aj~gO z!mpNCNJ{UjXGR%w*BrRJ6$jeJ%t9n_n9~ZtNQ)#s7D$S#HwEHitS& z;F*gI)t_b$CdAoj5npKx3HJ(Wp#!XTxSYXmU^xpnkl~4}tNh2Ai&@`;!C4|_ju87N zX5iUR0M$3(8_xI=pp=Rg!)QugR!Z%eM2;Me1ncEAda3qkbbX2=1uCuiJ!oF1=q!4y{|GDq*DmoU5*JL&VKV}Uyb32&qeuiPfiZZ!Jp)ET1E4e zS&AtnF;T3&L1@-9Q*GU?hdbcQAjLj#S;FMDFcERo83yHj<|FR0*MW0vEp<-QQExmh zOL(j{@)YBk^%sN@8w69SLy}g$M!W-o0xCfqvZ^-&)zT2m2q|9rUepiuRL39Y^UD#C zvgtk(AD?=XzTn=C#k7PS6Z;6{bb{d#1PuJ3NY1g6sBoDe4<6dGOO*|nu{8N}Nq9zt za8QP(A@%C|_|sPDsA>qLU2#<~;HpAQsD>*@6x#beVPurg(wfm2I=At9x|p%(R7HEm zh+b$B^r6q#)Jn+?l0YkxRTvfdgFIN$>41R}PO{fpB(;SD<>7>Xp@&Scu&95wgE>ix zIyh6qmCPkLyEh3i*N7T+>Zf82D#bcRd|Hz|ATo~X>rM0psjyhm!Vna)jtx@rFWWH;|vzBuPU|wH?Zd0Op`^Y~mjzgLf=(p9?k& 
z4fzh)=z~fi_IW3_Qu9r`Ss#zE;?JyL2sLaw{hN2vPg6-N9kde$V?Q{ySB6G6dvxmVscBIw#|3T~AmsJfrIjH9swq91cN)D~FB+K|>5VnRj6 zR)bhG!lHQWJrPRRd38E&qWesN%xsqUO9k88f$Beb;M~ z&Ut`sK|Z}T#W+ECa zdldXV%xTxNX^)~OkT7T25*0{;T0^>nB4j#8-gAn$Ig;E~l%w!z0++B)>aQ%>V$3JH zkJNO*Y^eb?p{bZiKoy0pR04EVDis=D(I6^16-sk=JNI9u3Xkt%MCG=z!ci3%y2v+D zxoEb18E*mI(66s7IbOa{1()N0Vhg7~PLd<(E&73b13fQO`a|@cg_=Y4{%`SeL3C19 zu&GUMO=^bjO=8$}F)&vmUamc3g_qW!cl`Yl<_{8|fQszh@;SPq$%L7WGQrV3sPPF; zA!CKhF*3j)+`(2fWS&|!_!TYuy*qME=04R2{_H*epRrbxh^n)!;GXFjlTxrry7N^C zI}Bu`+cL})8TH4fOV>I%vE8-b4nqU~JWXFwU3jG3Q=lta5ch+c+?$#RpDIYbAA>0R z+&@cAokb)!aZ8UXv5)sn{5op1CVgpQ4U`pB@|$UEBE|7@M%&-xf1Q*O{ZQtxrY5HyBKV!vrNr*6IZkA}_ zTv%_$_`UZKVHtC_>u-h_RlC3?*PhFpU9TUQ%6Xpx+PwZv#Px`6RXAXVaB=I@SX3}a ze0t%(g#Z)^fD|u)Vd@)ItdCI`_m1f5tV;8zYKES#`TZj zNcIZ1?jO3Fg@Up&D7btqE*)WEhZVI|?#On#Q;mDjJ6%?81Mw@rfm}&8eX1| zxHV>f8m;7@E91J$Ii}t?a*-!%x{J%RL9w(onOprQF5W93Xm8k~tPmhL6I`v60gC5m z1o3kQRR+jjSrD$)W--2Ryzs4rWnqJ1b0?x!cX8r5_fpP*sfJ+Mx`M4IX%md5q{XkK zAvs%x+as5XI;z#gT2ND)A?Y_Yk~oeV%Y=J&5$$jf{@7;s@&>cSmtd`gi9O-^SRGID z{)N=eb?=$OF8Tr^s@;B)SP+5$(lo0!_N#^;Cd~Q^paj_liix#l9fA9+^XC+ov`(I;2SF3_`_cHemC{4Z=5*xZNX zOg>{4L6bqD*k@XUo|1LaR}FYKCeoa&Feu5us74`keUktUfcIz6j7s;|BQ=X(<+608 z?TOJ}bFv}F*N9}gGB@b-(cU0#bfjU4u|+i+?#Go@C*PJ@rQc`>Tc;vT^gj6ubPjQ@ zb2H0%#zaCpcwTz2*q*>&!-Ku#IPTyx>5b(h_Qul!=eaHPa>eabD^)0|ICR01Q5lbJ z*v_7jph?K8!r4o)2u?SPefel^yYeLzTWqRu$EuMUxv#echyfV_dbJJ}XZ|q$;Of;K zXjLz^n;{ENEPJlh0Snjc6>i0WMtEYx;ipS?$sW0r;q(Q9@iql@>PxkGE|>9;s%PzM z9i^3QNl&13SFp4%bJ_wM6B1R*SQc?HLRR&+dvVg~51PPYWRBU#3+xtqNGmY`|?7 zcmqyuQ&Fy4W9ywo)18qSD0Tds(kvnyw|XrA;P>oQJ;mcFZSI&lSLa#4{m7qa?3ub= zE(amL{>jvNW!mqM)~qovUKHbuYfuS|vh=zLH=5+?I*|Zu_%wQ$ z*>h5`BEVc0A`i_|<1+%vXoi?sLRz`ZLf_a9;%g&TM|^1!0o+(I_!YCd8B(3WA&bK%$>kk0Ko3?)H77AC?OfpPBsxF$g(ia z3hKm0?{OLThl_pMl!7ov(PLrV{h}22^thwIg(mhGFs677thuA9G)ZrbuN1Wc5D}TZ zc*fjF#hnqbjw;FghW-l8Z9DmVGq*F!zwU{q9i84hnk-{NF_yppnkZiEB=Jo&Q=eg} zNPNQwRYRuVLehTME`_e~sUL47q<2nelNZcF+*s zTTe7jo|PJ7VujW&ksChwsK(4 zg?pnDYLK1*z0l|VO!uVwT*H1?A*!U9hU0IS1*8LhvX^8aN@fIIr?l`exc2q_{3$m> zRGF%)2qP058m)5k(1ft1Kj`UkpZ4~|n|VptjfEdD$$s|*2QB@P0>LX4!1?$u*O8te zf~EBrya0*-06_V_yN)LAzjwnHHYT>t2G;aOb~g6bChmH!2G%Yn#{Z!)sO~0h{HAPQ z^csY8^Pe(E;B(GhYMF0HZ7h)JfQv0kmjDf>MAfuhN#FdxpSDwsgXsA^nfF2CpU2v| zc5f~)w^W~WBfO3Fjvv$L_fSutxCxci?~IA73MWC=?o^Wp?^pucDfC4bu}rm{p8UfG zn)WayNwh6<1&Nhv`>i>KS*7l#c7yM@)ijlnbVcV3#DReQSTVl*(VkU+NjHP4VOTPY zAPWeQLBV^a+->;9#VZ@>K$0NCY8bFDN3jJk?k2LzS3?U}gx;`zOvW7hqd}$fVi=#E z+`G4JiA@JYQzpkkke3hJk~Y_t-Y%&@tW|i&1Fc{qf0iIZ3Vw3|mcvQ)tc5h|XjGsy zYB;yjBJORH7=(`q&M6tRN;oyeg4m?1G3TUJB4Ld9(E6R}!2&#C6W#!PBaah2!r1ua zQF^z$GIz)X5nRddo#xGR3Nbxg6!UE)Pr#Zpd|~b za3|R>0LCf6#a6lD(j*vZ)7Ys_<5l3bB%#2pY9;f^(lK|6&58&D=?&z0pdz-my%*#d zFko%NpuJt5`XWLZm9U1dYlh*0dxxP?#fSwu7BH(csAeNb3v?nAj)WN-b>)&;BbJj5 zYo}Dfv?6g-Y@x#)kd~$D6R?iyf}y3d83VHjNH)vTdyEg!-ug3l-ZP8|=3!IG4EZIA z-TQ9!j-6Gh>B0HD!u$~9LK7-Cl1mxvkrFYkQpsd$0`MVX7oSZq1w@DSX5@=ci>cQ` zLC7*&T!>qlcK;V+?-V3T)NSdOZQHiJ%eHOXwr%XPZQHhY*|zPjzazSDpN<k$HOCtB8>>MNm3Pri{`ufhE^+PsWy-{~@N9ccUxcKS0FS(uZtj=R8nT}wtIHysZy-16R9RIqqXg*pITkZ z2Z1Daf=QELsUyDIcT);$xkN z#eOYS5yO~d?Z47Ec*iPaUyZDiz``Y>@A!#D!+PU!smUxopt%kII6!HvL!tfdq80Bo{f~OubCKB&h|^uy|D$ z9v(q-1rUHJhW3O@l$rV@bHLL`XYW!+cs>BNr=^Y93M;O|^dcKvV1ilQ6I)8D{hFk7 z(4Xv~hMYIB5?{yhnxtZW2&+UvZtrA%%Rg!lUHxiv;(4d zEIAnh41d}@)Km4!5(%-Hxjx5i)z+z=&g+E3y{jb$kG{)(yZ?NuLi&N;@~7&D>pXMOb<8P$a}U! zO>*7O|J5t%*<_n?PICs}0! 
z_Em%i0ZAf$JYPmr%JT)>_EM1T^;C~d6y;iN{Cscx3eB^59doot)I?K=mt~u--D9)jvWOFv+qI*+Zc73lt-i~s*Rg{oGSvfB_v=($jXISd+4F3~$37V$4@iGq?b?sG%} z3GJ^ML+plYH+D@uA}(-q#euPkT|&?rY|xzaQjSDS~ZX3&}L+3WcQ%zze!ZU z+|qa}D#{!qz<4JVHKzIUk#u|oY21hrV@Bhi#i@)^Yu7W%Z&AljtsTL@5#R5Z{P%TNpmuL_+O z1LQGotAol+RpVO-zB0CtH#`Zdpwsx`mY2~-N27*`(CAj%7Tz58P^a%k?~a~~J$-}$ zxszrq_>K4vdSm`2pI3n@DDjl1DI|xe^A`&oVx3K?;@k*IQLEqulpnhivLH#4#Vk24 zuVx!Q%`41oKq=GZ>vLobxN`;bi7no-5JlG>boScu?)0zBUtO6yviIZ+Ow(#_OV^CJ z9b`gwe}NW`v%i}0v_}GzD<6WTHXpxIijSK(R|iq<%K|)1fE!y%A8D5)r%qre5Faiv zTmEtuVmgvTGc72(qUE8;!s}%3sVH3^zWu&nLI?M7u?z{rKDOq8QpX!9uqG;Odq;~H zFs58o0@KooT8gF;l=f3vWJd~GSuOKLBG8((4cIW86${21{Ctr%;_59W_cIu7i_Ud> z{w9odR#-}PM9f>L(cFBoQ%@X0lUW zz~7*nr-;-?)rIjtwZ^U{P#?A^&Iv@=?x%{@AY_xl46fCGHILalH^!M`3Pjufm`I1c z-S|$u#n$KwS|oQ(L<=1AICJS~gJ~P?mnAnu(}4;)Owjh;4IL_=pfHi~>TlSA??j8t zX*8x~?*Hn8Be|-v?sq zM-p-}^8vN6N79JYrpn*w(`PJY?U&DhDmDjwMd@}>CCvdB(3=h)=T~YK#C2hHD;S!? zDX@a@Q~wjF9W^CymsS>wGs7*l;%=}rF>{+|7sC8^h`GR_bH{>_g3*Ye!bcHvV;Cw}QN zn8&Qsxx;iFwLJL29{AY8X!qPg-2wRy=0ekiW@GQ?hmL#cXrO~nx5>5`X#dic ztGwVlgo)KrV}rXx}p`Hc2Oo}+hpcgpJU zO#I@2klwk$R+198PS(_~6!@}1yAJt?9_>ixd7VDv@fA33%0i*s<3qA#52=SLA^4-R z%aC78qLPT?IQrfT%{XbU{tf(Ji{7CgW5?q!Si<`^ikkGlDS9UEzbU+W299R`iKbps z)3Mu_K=e7OP2+<93lt);$&OB!N7hq4r_EiLjIE6k+Ek#ruq{tAsH2G%`1z8Pt|M7Q z%PE)%RE=ltdpVPn&go?O8n}HlT*4AnT-eBNW_Q-TJX}kbIWJB=Nh9|VGXIx1EyckT zO^HwDEL2UaV_r4&#T}zNTvuPDfVt94*@RgqQOoiiovOe0wnVJ=Y$1cEu<;ev#l2s` z-DTkmn#TG*Qq@vOb znT}K=%U%SvEAbD7LaDi=i=>%?b-O+5u14xWy5%%s`oy+c5|*0TNXMci#R{}i6g)Q* z8i6ilUcyg2Xb9tyli% z-0l5j_uMCrMKHwZQAmNw*PjJPR00>a*%mPE1TUV}+3=VwKyR%;5n3kdQ#TcJGhZ;3 z+K4k9?bfs|kx3vqiCivg--yoloSuu56H`?rAypRO0$h#iJrIJPM&cN~I0G}Ab0BKU za9VOMTQehVstxj>L_%dMc`=D2GW1Q7`ASks%H>8VF+`ZmgN?aMD@T?U4Fd;C_i83$ z{X~6c(>kdfbv0pOMbZS0OkT1@@^4X9nxvDJ(feh3^c1mB2McwYk_F4kgbHbRtTu=m zbVwfYx8?C&pW`hi#mmPNX7Yq~`YhGIslywt`XXp~|a_uqVBzqu783I1F|LA%{=KR+UNh%yo~cC!kC50wc;tG%{2>wCzDiHzB02 z+pmiPKk&gBnW-aK|&K#=^7?tATZ$QNs_IXMt~Q?Nf( zraleNT)wf1wh~&l#y8`Q-d)Giol?!R+?+2|_If6%^96^VI@N>96Bm)ydFq-#YCWh6 zMe{;V{v1#&nfj_`dEYGs+%LB#H0&E}?2Q6$gVi9OnU4924uur9(gxeeyIrw>h^*y? 
z6B`#dRz`Ed`fykai$X8AuQ#Mz7ioA1U%*LtOto|D8{3xqDKF=X5ng?;+aDXqfJ|Pg zul!p8<4qak0m8I~a7dyw-;0PcupeR?^yExACfX{Zc<^FgobJw7;E#6=-&Eubjdg$s zb-zsr0AL0O=GhH(kQP2abw`wPJA*16MuQhsg4Nnw?kcv4Q#&IXEo;zTTRj2u;{-YR zaHgoB^zG>d5`dP!1{i>@O>PyWBi#HB={8dILibQzqSoJf=s5oJ#dhN#6WKh~>WS5} z#SAsH%=}xx38c9D%*|#g1kS~;PYv_;d_5a_As>wlcVyYCDW5wbQlQw7s6LJSGknYl zLN0m@=(q^e>zIkJdK$UVeGL#Hrt!!ZiJ1t4tI2n`nB~Y8F4SIVL>myocf=UrBZy_M z{-NwmsK7D8qU#w{grGU7{Sdn2+m+uWr6TNugr0E&ji=1P%`QER6V2sdwDq}(NObjA zz-9F$0O`l0PHG>mLedXqqGwgizxlHb3BiDt+1hn#+aTou9jboS?uw`#rL0?Zd>+OzF zU(e7_oGGJik%Fgd29d7ziAVFt?fx$7i@p-=dt3*2lpQCldvL>a#gg)5Q6uVe*q~ug zHj2i=53YdUEm17J+lg?}8+hqIIiwS8STeKz3tk7~+oykf!LcW&o4h)0mx%|i+OsGD z>1fGFoG6qpcNoyDt#_RVZ$`4^GbFkgM&HhaZmONPMHg}ODFH;^BT#M{e?8XL&L=*V zxn7eb6T!%KUR6wcRM(Bp$Mc%_1ALM`Zc4X2m*IVF+TR<7Iddw2(Fn=BUZ$B-w>FMC zul79pI7LNW&PGOJFM?w?!ns$qPU6Knbu-av-zxM#QOLi2YMh7OikhKh5D&8}TNx>4M&5)AMnr}vaemK{` z^e>c{vjiBlFs;I-YQF;_zbWFx0^F}RRcmrJ54yrqN4*WN5%U<%@KXg`VHy+DW421- z1#5_3H^U^=9i_j!Dj4zcNP8X5;a#qrAB}ksrgvnHw>R^_0z8Tsy+UfIx)(niuL=wf z=Z_1qAP=AP{_C~@rP9VT<(HK4@w;`P`hVIs{O>+Vx!?AH4Wavj8f>QDF@Hj~jvSq{ zEn)*}i_8CLOTV51L=$U#Q}GvDS|3C7vrR}kEV+;N=O20!b6lLADV__cTLrI$?ltY^ zlx@u(Uj28>Bu-1(RWVYrk?zMqW z?|Kd1a?=MS+6DdXg{2QZpkn1I7SHnf8F|8H^L!i}(vTVu)DjdEt7+FR<`U^-XbMEJ zryy-PPMLz*W}5+4Vk&Z!J($6=*f#1>P5YOgU%xradI2FiSNpFAx!6zMPYPLy4B{77 zq|Z%kC`Cb@PNlm3o(2uNm1F9Axw5lrk2zUF^uPeTW4y&!v93tO%60MPI3M4%OQqrkJmCU$jj0Yf}5E1oGmzr^u)s}_7 zm1C(KxCtyL&9b_&gKL`!ysO@R$VO|d&OA+p-0wEeAJ{vz2lDvlv{1^jE-Izx*HvZd zC29VNA?x{GGGlQ=XRbD+k3$sr)_KeR9?i9IROS)q>Wnfw8%dgTZh4 z*nIJ#=^gze>3E9W4zImXlHBAs{^>V*^<9v2;Tqxxc`d)uBT^VYvHV?_aEt7sXsQ80ipW= zmH2r6zJ85&*N_PW4PlI zZY^6lJkyuPmpX#;p2iw|sUrbw6D%_DxXp%AssLrmRgk72Sk4Bqwdm^tZh_X}G73g`H zyrvIFRC+zg{Rau;)nlA>kBw&3sxM}JvX9Aa=qXU(3}|)rB&>FU+x-n@The{$>+t3h z>Qt`V_dPGQ105B}v#kIT(5yoia2j}nCoHq~5%!l2S1V78Xs&0bNi{8NOkq!8#}`XP z8>Sa2Rd@tCe+FBvo6`U!lovu}vti#wNWw5i=LKXRCHZZ1LE2*_YuZutI?n0|&8HHo z(75LkY}P^yxDwPLqoB}1zSNg6rLP44Oy0@ zlp}o%AXNx0P&N$*v98|q+V10@^&Ch17o$t}0n_@bdeCpNfQCiWOlge8T z++7{f{{m)0J=#&SwkV?>3Ha8WqW^!t4@_`&)uyKXP#bsv09_gY0HXi(hS4(6G19Tp z8JoB=&>34eIn!F$n%dDzO9+X`Dv4}qYB_C;p!&Ylt~!<|S%8&fvshrCMARRKRts4O zx7(5>g^|_~w1U-*(Gf&`J^e}T4k$8qdsmeyVU@UCYdpEnzR4PfzH9Flp`Xk~x^XGz zD%ItAIOI!qS$E#Qc)Jk%lyNO=0>68@SlQueH)n_W%2ca;{`k^fU~@8Rt;O1&==KbN zzNx+O)r=ZGm+b1+(4oCaWv^^@cXJ}=t*V7~aP#qUb9&mN_0^eEQk#7G+PV*ee}#p7 zwWO<2uIB#Z?`ZoyV8#4Eo&njKBtxu8%FWm(0dAs2Zmbru;EY;})Ld`qT!Enaw5)6t zp@*)E!mDM$uHKD#=S7%ZW-p}@F)dJ~GA`Hy^tK$mk^L&0hQ({{K9g)z*|DciZzP98 zE#p_`*@+#z<~9K(4~cfZM{vk-8NKvV^3|2)L>-!q17h9LjD+T@F}tsp*fot3f}~2dpbljbjH3hFfI;MOY1mkEPhrx47GCq}JcC&{Tca zMe>ukdSNi?=swu)Zbox9AJ@W;z&Ktpzt=CuOm8A8=CCiQ_lrDLYXNxCxJ9DM=l*sh z1S6J&lpQ_kLivuM(HEiSZdW;Gt~C0T1XdB!S)EYdcg3*J0% z?GTe_uKAsQZXk{9d3}Dmp|{`$J~kKx7dY4jtK)hH*fdUoMiMv_bBn{;t=*^jnVPM0tQ+<-M+V=4rSKQk!`mJ39Ik0>yXBd+lu>FYS-*TN6 zBqhYz5}uf*ol138$kC&f4BAhyR#VA>J&S|}Do#~r1LinNd~gACkhycEgWGK(^HbgE ztz#1*M={Vv*e0=nE)1W1@MrClYttilTd@MB_onqXXuf?#$0gh+Zkv%c1)rsr=4<`t zI9wAPB#^>9r?+CLe}Ns=Pd%DqorkzKEOvM>|C+GwKGyht0@R@mN z+4sKFD|xsR+e#|hVf)huBR`hsMM#RvJ*kB>dkA9MV*l#_Mnj;wv(I`>4ZwIF@vSi# zq^KUXx(8wsC8OumaVWKkbMU(TAS81HcJ5e!&4Vu3&$Dil76Ey0VqC7;LggxZ z_-Nm7yb-F+$jy-Ux}7t})Br#ZD}OA2bc>}RGBrb#APest2v9UOHg6!r?P(7-4+Hh1 z>CKO6@ zbS}}#DoeO{@GWwtO|SycgQBGB4j9=L0O6Luq^J=@l?n4@m8aJ4DGJSboa~ zqnrpzY#a1gbL84(#R~8!Y(@&3yI80e=VG1e5Z&uA+y!n`9hT6+3uhoWlZ^5&)e#m0mmo%?HwDD~7kz~53_ zfJEnD=}AVG2ym=!%IboW6$#&vu6(Mzrql^mN22sh80o?SrZ!F3 z_M48*%10KXe{UBS$KrzB%UQ}t^@Gnegm?xNv}-5lqu~-;Q=X9pb7d|m`iI9|&6gO0PgtW2AJH=1i zyK88H=CUAkf-4cl=PEnx{OpD1Zv>*o8~qu%gIplxxEY150;LwrdyWa_kK6H|(YD0( 
[GIT binary patch data omitted]

From af275fbb2e3a6765856276b77cfa3b52cd9baf23 Mon Sep 17 00:00:00 2001
From: wglao
Date: Sun, 22 Jan 2023 19:04:05 -0600
Subject: [PATCH 31/35] docs

---
 docs/api.rst | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/docs/api.rst b/docs/api.rst
index a3c09bb8..a0821bc6 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -13,6 +13,7 @@ Common Optimizers
     adamax
     adamaxw
     amsgrad
+    eve
     fromage
     lamb
     lars
@@ -67,6 +68,11 @@ AMSGrad
 
 .. autofunction:: amsgrad
 
+Eve
+~~~
+
+.. autofunction:: eve
+
 Fromage
 ~~~~~~~
 
@@ -289,6 +295,7 @@ Optax Transforms and States
 .. autofunction:: scale_by_adamax
 .. autofunction:: scale_by_amsgrad
 .. autofunction:: scale_by_belief
+.. autofunction:: scale_by_eve
 .. autofunction:: scale_by_factored_rms
 .. autofunction:: scale_by_novograd
 .. autofunction:: scale_by_param_block_norm
@@ -310,6 +317,9 @@ Optax Transforms and States
 .. autoclass:: ScaleByNovogradState
   :members:
 
+.. autoclass:: ScaleByEveState
+  :members:
+
 .. autoclass:: ScaleByRmsState
   :members:
 

From 608e03bd51850ccb08d72932f267fb7947d29745 Mon Sep 17 00:00:00 2001
From: wglao
Date: Sun, 22 Jan 2023 19:09:21 -0600
Subject: [PATCH 32/35] remove test artifacts

---
 optax-0.1.5.dev0-py3-none-any.whl | Bin 156407 -> 0 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 optax-0.1.5.dev0-py3-none-any.whl

diff --git a/optax-0.1.5.dev0-py3-none-any.whl b/optax-0.1.5.dev0-py3-none-any.whl
deleted file mode 100644
index 28a2c86228f0badd42c41ba4f23a378efddadca1..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

[GIT binary patch data for the deleted optax-0.1.5.dev0-py3-none-any.whl (156407 bytes) omitted]
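For context, the docs hunks above expose the new Eve API (eve, scale_by_eve, ScaleByEveState) in the rendered reference. The following is a minimal usage sketch, not part of the patch series, showing how the documented eve alias might be driven once these patches are applied. The toy data and loss_fn are illustrative only, and passing the current loss as the third argument of update is an assumption based on the Eve transform's extra loss parameter in this series, not upstream optax behaviour.

# Minimal sketch, assuming the patched optax is installed.  The extra
# `loss` argument to `update` reflects Eve's use of the current loss to
# scale its global step; this calling convention is an assumption.
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x, y):
    # Simple least-squares objective, used only for illustration.
    return jnp.mean((x @ params - y) ** 2)

params = jnp.zeros(3)
opt = optax.eve(learning_rate=1e-3)   # alias documented by the hunks above
opt_state = opt.init(params)

x = jnp.ones((8, 3))
y = jnp.ones(8)
for _ in range(10):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # The current loss is threaded through so Eve can update its
    # sub-optimality estimate alongside the moment estimates.
    updates, opt_state = opt.update(grads, opt_state, loss)
    params = optax.apply_updates(params, updates)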
zn$7i|zuNZ@PrUtso0Q%l!g*@mvvl1?Wv5D2Z_^Kr7Ye>C201-eUOJ~3($|TZJ;?>_+~RNlc?o+$XE(e(-!{LH+tszD3 z?g)iIP!H}Ge~>`gk~wUl&jakn$)B49^4H#lz-DAj4T~;=hv(nKpF1o!>Re@RFl#Ne z*TOgs4(GRI+|e-PhN-(H)W}Sl6OQy&nqo};A(J=8pa$}+q}tf?jF$?lrIpL4&l2bv z1)LJ~mTM1hlUbKeUKuW$##LLdToW%MWi|bG_zZmUePPfNq^G&0KF;MD;4S+O>b~Jj zn@;N+Tz7xqfP(KM*$2ZS@%7?2fHsX=vyHIpxvLlp3*NYadj|+9zc@&kzvk|+0Z~s1 zD|~RvLIH3lmkx7$FsD=l9!#-q=iOQrm6WB>`HjF1JuG1ErkIpod z^08sKiL=lB$dJTPRfXsfLrK!61qt-rIS_qRBQ?Z&Ygj6y zugtv7*3QPE*8?aiq)T@BtKRXp9C~IRzSOS(jJIe^lJ+Ox$^x2qz|7Z-w9;gVikZ&CV{Tqd!3(;uu$C4C(BW=%Wgm{922Xaiki&7Zw_ji}W zG|)?Bu}-8QD~cCuJ)-L~4nb7s>KXi0gNXDBfZ+gLdIAZ+0hlv|6th2ZWPT8Vrm|Kn zk-*y*Us9;c1d>paMXN$as*=$b13UtZVmeqqCh-TV1GCbH-DaDJG0Zi&v+|k@?_?f0 zyufgZj^2g9R{9If8SP68OlskfEWnGL3eGW_?9b<9pwGQclr-BXw^!VTIaUXX4&+9j zO69LJtdQLxDRnz^%obTB3Y0w0Ggbo6gB(yp5vxhc1oEo$m<$vHgfb&gkGww#!hcN2 zJ`{8C4bv;C2Gh*u%M;E#P7uLt~?J<4xqQ zSr=W!mWf5cSIZ z5<)-S9mC;tlv>E!C<;(uqM%^^2LoUf-Y8%#dBmN?=h_Od0>LjWm~v18(Z(Day?=ZP?iR2l3nRRR*~cA82dI15ZDfSyGS zJb$3l1U^a;XCe>=2PZbi5I|wf+F+FIFB>TDHF&=1q^L0ZD5V_hOL}KUD_osT=An;< z0*^0kQ@$jGX9y)?mSUBOf@NevKL`$D3Cc0631!tZj@XUR(w=nDss`Q4Dp?F7wo(@^ ztXy5nC~a48k#_-@HBuvuXG)ZS@EUJhT}55AxC9ucB7|5lnX`l!Lwj?@0{B?*Ur23j zr5&i+eUttVhis68Z`KCcIhxVSsTv7Es0Y5&BiyQ{Z1&mVw82FT{6*80H7ico*O;Cu zk&VjouxnMvq?-~El0X2FK`?ooP|~sHU=cKpPmc%72d1v_sNrMH34_x7uf1Ae2}u$ZBpHeVBX_Eyu3ULYH9GcL(vb?oax%+!IhHN5bL9u> zezFKvLj3fOYCP%(#(Esw(|9|_>J6;37%{}DRmacmO!ap~R(}EiTWW`Y2i)4!9Qodr z#8_4#$s~T50ZCZ`w=yN*i7eSDcO-Uo4Qulf8zP&wHqh3?1!^FZTPW(U~2-0KwCcp_wI(S{7|Aixm+-JRx%*PX-aM z=hL1&&Gf@2c&d#cW^@*?ch<(txOU)~SzgKwEpXO4)a;Hfos~N(N6OIhJy9^IcJab7 zL8wxhK!yrt)BR{dswf;;RMgYOt?X}g2JN4aaEtrBHE6}5Z(AgIUl-MFlhnSeT55!4 z^|96|a4Nr!{J3wtan&%hcyv$%gOvxokm_bo11ns!s6)l?>g$?*y}Z&gb@;#eeN+?_ zu0`^-4O^XrwapP9YeYw`e}S~?U?|WfXK-*R#%~V>wC-)_Yt9wA3d!!=KBrg($gQa{)U#9{L$=Oh#|wi9X^Yd@dbV ziKl&YJ9(DN_1iTaS3hmr8{Zc{qX7P-U~+#?gBpAOsrO%)-n*3f=i9kSLc40Ml>_&) z(2DEObE2N>ZUAW)cPGrX>d58E8$f^WQq?mWR_lv6=pkLZ+q|5H{V(`xa>E6}0afVV z{;002*4Or*$-PY#@8scHId$)b%AMdpl|3YjC(nrSyG4b9W$*(4p&)~~%G3=Fb zG2KGlH*39U=f9jN>BZZ_byhsN)6pKrf8>p|K>)V(9YGXdKa8RYE=93NLZ2(vk8>aO zD0g`S6|Kw%`piFw_J;C58?{1!%I&+?p7}$(HG8Ssjkq<=bBm4K+pcL^2xa-3^&9_c zM)hv5Ic+-+pP!)lcxSkIquU3RuE zHYSc1M*m?7>sGtA-5f*xS6kX<{~b;#J8H|0f720yMXv7B=xD&5MHkqIl_2ba6FE2q zLrm2V^oenOd7bpMl~F)5eo0`Cj;fzQboQ1(Hw-lO_JUAV*-ES?wZZSmP=!>|wcQ%0Jkzn;oQm{9d*M&J z)h!v8AjXptOM#i3fTloCf32fZZ15xCA<@JtpEMOvAKowML!6z|98nW1;Wz)05pv71>f{rU9Smp4m0PH$U&qm!pE z8z-(_>@3;abD~&M-r48<5esR=pZ=IDtY6!l#IVWlMi82^6^OzGVvLPRt#~3qjn{5i zdqLYTxkQX>Nup69HOm4~7mykiScJpxRPN8!$&uSXn3tCqU2c_v4gqpYSoUhj2zy&c z!b=1=^rc87cgENYS3=kgmHir<;%Ts;2-hky?Y`72b#t)4NFAbu4Vl)PEF1aQ=0(ZC zU@@SUC+13Q;uf14VB0|^^V+p=`1m+@e5eiH+A}?rd3eWuo-ar1MF-=DAdb{Ct!r1_=brNNE2sAw+sm_xYf>3C__UN7E>-v2PPS$3D|7R3kULD^=~zKz1^8#RL4zi zIr@u8>`NOA-3478+a3AeHN;c)sGO@MH&nF5s?FC)TZHFd3}nf)6aI@MfmDrApm>if zv^<73YQjwQT2$DjjL2TaK53{?W#N4GlyI)uX7su7oq0h@XW$dVa_!UfVV2=IB{@78 zwq+cv<=DFpE#CQeZB}4Ok7`0@K;&B7L>as4?tSU-rn*d5+^_)jQsYlteSRQ>>n6du=HG$tzKz z0Ii;Gp(A)k3Z~+VnYEP+x<$7!9AArATXv#Rrk+>>odjaq&w!D^or$eTL4$mmFtDe- zdq4JO*?87iB%aSl^Y>gthDYCw34O|i$5$HEHU%FLaSuO&FgydWhoOo$#KtqrG&~fO zDVDLy13FrYG7^uWswE3L>rmp$6m$;ioX8T4C~W(8mr$5 zCntZe{;7TC`J3EgE$?tAmGb)YR?qJpz2tLM@B5nml5O;Omhv2;Cwc~EQL?J9TEmdI zDITc%*$UosWp}X=+c@jp6fmT)+Fg@3KF|_i_tc*oe|Emd=&c_8t3~U3v{+3aIp2^c z)yX3N)DR9iF?xjuz-H6-Nndyl}d4;##)%{U>L&HHE|fH2$a>q+YwQeS}Mi42Fiy<8TD`P;hevpfrvHd0%C6&_o) zJslcPjK0!g=ww=8n>Tt;JUQ3)i?T$_$Z%V!1R@%9wluG4JD`V@-0mRw#=0q+Tt>lm4IYWui@Oz8ob#=r*jLq8lJ%1a?k^|&SMTm|0bbw^(!AXOPyHJ5ydV7Mw+27Z$1QH-Q-1BQ zIOdVdCU9=3Tj)lmFYcltF2s2C$-jku)-cAb4>JauJj!oem 
z!6mb@35iBu9EjTG_J(^C<`GY5bUihLYL*))dT%Bc0_Qn3)_d%_sF(xFWJGSLQB2-0 z)FZ&7)$9)RyFo?~dr&RDbtVsR=^c8X&P!@o&LW#{o~B0@%&-5JnF9Rf_$+N8;==qE z%_<@Q0ATT4ehIhKSNqzRHAItc+KBvkcCfFv`5gEp}HjY1*^ z1L9+&Tbg28GB45J(QPjQo8sz!Y-wuYTwl|oMa6U06)ZeHX;^TN zYZeTqv!gQT$PBF50~L*b;d_(v}kB|e}%b?;7Rth2ef>2V#)HpNjcL5Wquo5WUssAM!X*=pNc7k(qnQb1j z&)?%h4~?`y&<)gOB7$g?dr(nup4toZ_0-bRQnHN?0Z~IuRWUGj#+mpnMI0g%3DJcg zRNfHMoLHku$w&ccui^%xk)dCv0%v2^FrQ}*S?3|ajL2P~NH^saFFk;k-!Mv$U%vQ= zwb4bAAt7V7Cyz?$N+~Ggv&2QOa?y70)@9-0kE5fLCu8uhtWR9iYk&xm0d}wNSQDh@ zej}a4==}h^0=M^{Rr`4zRHJ~_(;+stfQKh4VRfAakj|Y&0y_x53z55s06^-bdE1+5 z7kBARGmYe(tYN(K>ed}l3J>lw%a9-{P6}5U$j_&!zTfzi=ZHm0!2T%%7gzw0h&A&m zc3kd-!%Mi)D=I!|Sz3~ko^_f@h?^t7S-H8xKc0j4oO^d4fUux}#i`}n7Tx>7mbSjA z!G+~AyHw&kh&Hf`%>qJh=Ot2*P#c}Hp8~+U-xya<6@8UCBxnK)NtQA+Hddb~>PQwE z$vnkDTI)=9=8CIlX^df=AM%UeBiCh1ivB%Z~Ovjler<$?PKMY^a9A(oAm2(JH znEwIG>k{^JalHl@-N+X_NbvlJu|Dst5+2d_T~$qt(?ave79AA(nDz}t(5VIlPBueP zH4lW=s$Dbt*L0fket{1%$vnoUs4I1sQ1i?Ds(D^1#d|x0#K6m}D^&OT2@hFW5MTqk{(jzLVim)r|fW^HoBZm%&F1g0$!<+{b zP}46#dLU;^IdIOlAhd`4D6cuNxofME)3XVf587V{xj~HvQ5kDpLA!tT1+?|y zyolEP;*89~XA2>rqfz@{2h>)sk51KPe|wr`>Grx;G7}7J$amOQnlosHmVE&sLGB;p z^`H=@?&Vs)JTK~GTeK--pdRvK|3_7p;|&_5xanG2inG3cB3)WU!){Ucf>>ngcT?8Y zkh`1C#vOA~Kxwu!qfxSRz%c9H_Iat7xYCmo#yvh!!^VLtPtRZYu>1GYb`RM1&K7pU zu5fKc&x@w$9(;#t+&70%vu{;{DAwl48r8H?IRPjpWU%g*Z6#A!jULz?6>kswq-^*K ztC?(}+(>~u(^hU(9Bi5s9?VVcA0VS<7Td;QW~57BgoIK}c?^T)<5rzoRSofOZ_(U7 zd45#P9X8}ipAK(Ic)h1C-Ns;ET5L0>O$iVtj%1bR(aK2<#l-!p3(7UZ&~jW+yz$ty zAHe@ndH->VjI}@QUj7y}RsPn2{$_9gyUMe7{AIxzdHjmJYE^^(NZkIb(Q~lxmB=aL zpqvB5W9qZ&n6f$nY3ZzifC@{MW73LTl9-77ymS*EukS9=0-w?&QY`I!-Rb5&#_VEx zUYT$_-uJfm1xpKmw)=C+Flt>K!&EE+zS@w8_lP1U)t_91D>Gl9vuN-{<0*!SX&J`J z!X*yv8^TZ_D(2Tg?(>FQog_Bs!R4P0?73Z}`v7dRmI+2x4{F210!JC*tRNBM50ANm z_cwG(xj9kD3Q z(fE8Z3Zh6bE9$HDbY6I4Jsw%XfY;6vRACv{I|Uvbo#M9A5!U`*VNEAW$@8P$zEz7< zBMMyaGmT=yNrJra&}DVP27Gz>RRXlN?1~_(n;>WI(f|@xEZam<-fu3Q!e-z2VHt^7 zvJ)e*KjY5f-w;HAd6-ip=jjsDlGerXW$djGm7P;D4s_>CCqftX+du^XmDI?;ZMaOA zN@~=*op%MziaMp)sPy;(-T+AhJF% z1tqE~M+pLCX?t{h@T-u#5E*jZd*UuuQX`qkMNmyr0KuCM=Qi6CGUq0gA%$M3uu1~@<-fD_LtFD2!!*#0AY@e@B_P?C#**zbK z41|leTq)pqRuFwvlmuKl?k#R>_$t&qQm-^hNAEi2Jg3xg9JQtR(?eroI1ge+6p7I# z>o4Eo`sV!`RUQ{ibZIYNuzA!F@G^83GX{D2`|-|e9gTav$C*1gjOgt44T?N7+c=(v z792B3WG|sfEe3im(AEq%%6^?<1Fq#IL1D5Y(NMYhovZ~kYC zxRnB>+V5uL0X#852ZZp=rr)xw`+z?-);S^0@chWM=+(QXvp;mHS~%JXQBV+@@pdra zz=cO22Sk5#3P{=6s9eia8Z>#Q)x0@_eJDhcCLSfD8AMrRBA_cgB?%DJ{ zSG$xe?{V#(NGZ|zy@5p@4YYkv$h;RM%(d^m}&qC7q|G}>RPjz%nm;;de zZLyaT{};Rd|F4e!-C;NXUr>S#lwtbo4Pyp5dI@$zYo;OdP9g{tjRK>Vrn2}XT-P<< zU0l-5hU^-d)bl{%_|n5}?w2jf54(bhV6^oi!Pf2TKFZT!%Xo6T;X{dA?lx{f#yNJdE(7{{fwdK%P%4 zIX}o2YGp_zR)Fmda#%4`LHRY@=kO)SS$aCV!a1*sWWSFLhcy)nhK zTu)zV6D9Z*XW)5;nkEiD?)2$H9EYJ$aU-!hWJaOtU}EFL{5m1lm5zI&lnsp8w<#q{ z2d54I^;}L`^t|NwNghj`wBuG|#8tMXzDN$Jo|46>wBu3&gkdvAstd82)Gv{s1Rx~x zxQ-p8->wOq4_ls6f{t$g=-|`IjkW8~XlW|LA7A)h z*xA3pdl*#QA-kv}Nm*APpD#kmgvq0_ z(ZQ=j4o!qi2$+65~lsKKcp@-bSC#y_CQu+>A!J!l+Z8Dyep zpu^2q#rA0d+d9bDcu&Gq2L68M)qNw&(DuCAFV+*b*3am6Cj1(1Y0O=V%}?(~-hYCUtPM#5G!kW120{p( zA!V~7$UvP8XXDXNj;cO@$~egEMrRgPWl9<&08fR)0UEwBu7Y0G1F9ZH#Q&}Opnp!2 z34v#p`Nlf96pT5NSuOPLS4otP1E}3>j4CApdNd*D-%m;f)IGITxBC>o<|8cUMnqnaZ1#yxo zYPew|0TxCTOifl0H3ECyTtnbqTqK7ov{{MD#8aCutD&>`r86aXdhV`hx5hEJNT{qP z3R;A&suhtFi7GM(M#>SzDbpFE`@PztHzC_0#h3EZxuMZ;Ofcbv!M?N1-W*s=4l%!8 zP^)%kdVvbL%1bT%{Q9gS`15@k3L5OAHhPQ-8Ii%D9OgyEVl~GLkDs1R!sN-Q%jK2D zPnOFH{15uvE%M6;j_()Qzm(%+q$?PBwgv6Z54AmWUTzS&za2loIu=w|^2+q3FkJm! 
z1oY-D{G100)8)mtN!&{0@$?`~mQDaSwjnc9T3#-%aiE!3Bygo_!*qBZ2!XM zpuUV>L82W|Ix!PT;W9BKhJrNJffwYvGo|eMVSUfYN7G?Pu~vVx1-G+=rpzWG+UiT!aSktQ#*LqkHyKEtBLj;L$3we)y2c}wL#W&D`1I+Og3_ zB_x|fYO~qh;zPC7j-o2Wbzu;kM*dPk9;B8gk+0@8AVcCnQcq$coL-4?Jz<#_Ti2fs zrkexmH#zGZ>Xra*`Y33@5D6sA{synI;YZni!&<%Vbg|odbJ2!7d!&y7XL%pRc^1$8 z$PL=8)N>i1<2&pw&T%YhUZ6g+K}l+DIrO0#nVmx^pb(3RBKk}U352jlc1OXD8#orT z7PoxefB7Hg{QtZ|Es&BNup-mPZo{gNd&e3eT5L_~!~K{>p3L0O|xorJDIx@%}9 z*{z&P*=GM_ue;xKj!53U*u)@lR*CYW7zy1ny^rsGT<%jQSy|)gz}$%p$m&j*&jZN^ z2*qVswCYIqNQ9Ox83@Uo^u)X*s6%S01I?T+!Jv6YA)Q=utK=ymWM0@0_8wlQAZ-On zL&~f-MAxx5JC2nx%6GCuCsy{at*tGw#FVH=`=9AHedqV(<3@DVTAeI_!KLF`_mT?9 zqzhj*b4bPPn;r>P0U`6lKO*TjO%S37YB)rfi#w zkUAJw2Gd@^FsYM+@nLHFw~O%C*htmX+7SN3&B;|=pNG}egU!fPdiok2-5xj69(T$d zZ^x(4e^1Bks&|);_V!Qk_qwU7v!R0iWqKISt#t9edQ2jW2^T}u#aJ*?&?aRXl_c4v zkWeNlJcGG;Tya`D_{zXsNR$Ac0lWCa((iUyj9ja7{6100{ zk}Lz~aA0{&1y#ddi9(x1l~b^Ks7}i8C&u+iPd3q0XJWHPDf#wr!w)Rwp?dxNwiGpf ziR-^S9(=FH5k@06l+->Pb{|l*8ohtOC6uMgCKNyfGWooN;dJ4(mgzU-58!s>rWlS;^{dx?Fj%t%d2_6dOBH+MUgGCF7vsWHDs7NL0 zCr{G^DCD__Ut?+`5S)G_xrF%+i2xxIQpy+=dJ=+KNs$&uiKc^q(|(9F%508b7sl#V zXp;ttinFN|J4x*~R%S6npp~%7+VsY-&rLPL%9X4L8%KSRaV3oW>q8`;y(9{C-ztmg zM=lwNwDHdI*Z&NXC#pyjS^i#zpWv{9AjPHAgg(k0H+Iji3|CWm-mDf+Ridot5-cu5 zh-|sDDC=Q8iboBv1EcEwG39qv{09Negph?20r_f5*~D`E7ppcNta4fqYg*hgK9#{h zPl_U?#~uf~!`xk;x`Rc#%RG(LTaR-d4|5PGjY+1XAHlg(I{4`E?Uqf4PbNX$QI6{PcKn(4)AcL}=GN<@G*PnSlzs|dQLgH*!CxHNy4Xn;W zmRxlIoG=72^q%&PV~8>B@{H8tpfVah#(@h9W;C%FJjN%;NHCxn8L+Mkv#Q``=`7IY zxea;oBM-#6W^5;uDX_P5Z{(Tz7c`EHt$fPL;45j>+BBZUZW%(}6OO|qPtKWs;zH-Y z#{bY*F>=?}T~7Q7p#CCU(Do!E2E+!tce1_{hx2>n*j$xB~^ueJlJ3B!YwVK_B zArNdp-4)G6ed(Z*NVSq|f57p}PTRC8&Wa~9!JJxQ_h8Pr7QsP7tqZL%cXkn@1_e-o z2*lN47BKk3oC};a@jZu0@P^Qt2%75EdzOEy$9D*;Ohj6~bmBlnLz=*vFeD6sj3JJ^ zbpoj;)XBXf&M~?jL-&CoUuZZNRr)9(82O5R)V(Rc%&=*l4Uxv&C$?Wrs?m@kI3X}< zR21;5k(h>fv2f%oHk+?}ipY5$-W~_9-XY^VV6eTJSmuN{dfyuYn617|j@UvLubil0 zXvrR9iU=DvXSh-vK)L!HeOi+Wb<*V64D+<9IT2%^k!@Xq>$ycE)Jy^4SIvS5-kLqk zisg8QxfS3LZvKmipKSjQf0*fW%j(bPGQVO{6s^u-n`j+uXSY2uZS~WgaME&=^JHp5IS4y1L31 z5F7vLpBV4~mU-EoR-3sl#h_PP#H&aQK^ZIldcGLy$zd=79ENcU(vc^GUOK{^MRWeH zSUfgdy{*`HW$0j{?GhzvmUeze0}qulJm#$AQ9v>Opr}%_CFf^jvu&1vq(wE|G+FC5 zIT z6LjlZr#-AQgb$>`0qB;zVusr_nbdvxovdOnY<(R>U;02!?PO1n{~YiqIR4c+m47>@ z%mL>*6wBbQX)bwx#Ni3?9dP;)H+(6sXAtQ`>&!iPF`m2nEH>h6JU{!uE&A zpYOke@feJ#?aNL#a@UQSdg3;lyJ;XNls(0}FlVe`{Xh`oN9lmWUGw}#!3JafFpqGHuv)aVw4qC?wTX+Wbou

@13h&;xyFkBEb%kd zvq7ylp=@|)Hk3D6uGM*PG2;^<()TouK>@d+3}~PfRP58c z=l1X~JEnZ&7X!;Ys$c}`stDKxv>7VqnLtwG>gCWDd^eE}fQ)Pz@;`8EKiYTmkvW>5 zEJnLY$bZ%sE{Cj5u#9?s)LB8@YMW5 zs5Zr9Hs^UQv;hb#$Y`zt+_vyYTh_a3@V!F~eZ*8gvva0g#iI5 zYXT3!l~Fc}SH4I!Z!mf+AzR=J@zjE!j5oNZ32+8{DTqC^%zeSiGP+;`8VLtVE-~P$ z1TJx#rlC~D_nz*F?5?U#qj+SV3>1oYANIIUp5A_O-s#_roE*IzS8iVe)_YF_PovjE zxzgF!iu)saBKhoC(4E((u~2wL;H_4o2px(&&}=h0h*5eJJxgD| zXEAnv-agOO7k6D)?Ve)9lnP=xdREGN%3t@#WuDbnI{nS6$>!5Z2&kgzIU`Kn{V7`% z!L)N9VE~EEWPssWcccIU!MU_D6bo#C!gH|dO;uvq6!sybsE<4J_Ehz7R77x*ePT}G z2y0Z83x}pbTFV~?FY3#U+X(;i7UY!qCS8|c( zElayV=jn7_o)hW|4Z0lwV;LNrV2t$z(D2t}8r_D?kpD}s?@|Bk;h8Pg4g5?1+-z0( zG+L)l@U9&9or^r>jClsn@WTi4z{l|T5|m$pa~ElN#=8Zh<(aptqH5M(uf7O-wLFG5 z)oa7Q5RJS2<5(WE3jcTHc9_K`*d+Ttla@;nPr2HlUI;OkW64qEI9RyXMK0V=k&CWM z&__zda+NY_2s)j{x_uy}%lwB@-VA(!?@zHJe51)G=r3s7paNCvpWG(OU;Q<#;6<5E zufxCmsrmlCukqzxfG6s6BLcyZiYGfLW}xd&;#ZVA)Mh#FY^=zC6Jk4fu{p?EaOE{V za7g0HA{&yp6ud6+VzsO;p1{0wjBnnrLc-1z3q;IC{y2F61r2WqGGnuYFi5c;14`{$ z`^;PMK9xCS>N~=bRR*K!6N4ZB_7dZo>NRYah|q-otle zhG&h`SFFwmdoE?~tr~Q@PtM@JDZc)ewWhZ8j~i-5bWuiyc;!v(_8}Y9|1R`bv)&ZF z`$Q2I%OW#mk6x3kKuBvMG~OcPTIv8H8m6T?6=>DHI-zxkFN0&&ON5p7?M82s_H!e> z{pj##Gv?PxeP>UK2&38@%Z`Twcd6Wk#p@ywY3?*;OJ3%jxH9hGNj+h=?u|_OzaM7i zjdaQK&E56nU0%=dN0C3>dXYZyn&OY(lyf;d_M?I0y<=P{YP<*7^^uJ^M{#^1%ic;k zaNF^x-(&f{;Rvtu!9fpWj)#UAN?d>qFOv4Y=w@I!QNQeM>KtwEegN-oIpDte=X!s- zQ)5E;?_;`|6~30r``HJNcYD}Bz}gcSJ(E0rSCOUbTjcLf#uBt8<@Zu<_|czeg4 zhUuktw;qcEK9GCpa~va%{wIUjLq5sDAy1 zyk~8bB~ZY!S_-#l)s^qG*BP&yb;PF((K%4F1X+pC*+nD;-F-J7cx1mF0FFkPP+mG- zHchNF6!?~_x&lS@ML*~Vq|G%^R4$7lLfA9{4Gj;q2~$*Nz;VD%*FPuJtC%K|^*D1F zR~xE3t6JjWp^J4DLIaSnTN~wLivz~GIi!p!oZ33ST9~0h4}BO~etzxm-7~tv*lL(_ zIu{GIcOR1??CE!ZsV zR3Xr-fM|jjFMMQ}7S){^@g4$#4bkcd=Tr_OvB}+O( z%bhBD7$5Z1U1W?9=*BGv7#TBfkrA;)7TY$n{8QZEMd(9O@{=_sT~ac0}U{ZaEdjE+tC(xtp)i`&98M^rH~)f?GLhkX-+qI9YLz?oT>u~cV+J_lu8SQ?OdJUAVXYw83;1>AbkPBKW@2#!kdaVJGZLAH}mGSLoE| zuV`Q76Z5kA1TZ=A$;0)dykoGq^zWcPD0Uqx(ZDrGI!l@i`595JP*_1{}TxwhEgugX;P(ubYNzD|h>oQgsV5R}~I<0%3_T28Bo zrVgkD8j(uDE@WsU07`SFJ9@Z8SQ_8HP(P=A!g*WL_(BwO0 zS-X|*VcMbRvP(t*y0(j;Kzk#R<>JDzEUiy{eq17J?C6zS9|32 z%RJZXnEyOZ^11c;77-UEzip`Q@eF`!hd703cYd9y2gyq;qY>Ns*!MAIu8Fh$G2z+J z8aV0ydcMr2&Hh7-+lq=Uv*@*wzrD&{AhQ%c5+4CyYC)Ytb1#LvG5imnigcee>TNJq z=nLPue~d7-%{xER(%qHDv-s~r#x#2qgO;nC#i}MGMBB{Wlcp{Pwfin=9ga3J=wpoL zM>?>#!A6yH$=*|Qj3-zV=hwdVtv40p5SMNgm6ygp8W3iZ%nOZicl6TYzEmL1A=-h=pdc}EmZ zP}ZJMk0RevJ^>s+8qi_vmmM&5q6m4RuvlgxkmPN`tm{kWBdQ2AwuI;_UHwOK44;m| zB(aaTI_8-In^%~*Mcxc)t)oMm^2FiR(C(H+*dE?Zd94BPZy14g1-~0H!RGD`xXg~7 zL=r=@)*S&Y?bt{3@eu$?ay1Y>#{5Od8roQSo+uNewW#)hLqp442+|JawJ=8^Q#Nkc zqaI8rI84(@Mst(L%8c-nqtu4P9A#73g3A=|#@|KItsc%+>#A_oQ8!S~6XMIcGW;>) zj*`(4BC8~z3uAyXxffWgk%cE6(SjUN(i^-Tw%S{m<-qTk1!__E%zf3;P<44Bw4H;N@SaB@dpCibpO9!Wmmhcc7&m42v!@n;Q@+q`G2nJ+~ zrkk%DV7@;~|3=TNgE%!Zp%+IeG0Kr6*e5kQcuSh&3KgZi2RaC99DBLy%RgF?vQ=Xb zfT$DrMj#ouTngD>TZ-LV6|NS$O8T8XvzudO>|n7Ou0R}~=g(4;h4yM?W%qV+ z^8tjYC+Y&{l0k-_`=02)2zBcqZ~^(?d|dDh<^bg`YM*a*v~#2#Rf5HZ?RCtyO_j|aW%6#{|Fm^4{HpMB8t%aX_Q z`6`0UPtfRP7DVZE#@j+Tk`LBx$)X=%@DFLC`(Ba$8F}>z37qLrkE|DQaL6AhdGN`) zvUX9U(dVcVmggQh_Ci?io$-&@77t612h6=0*|AZAgkyC~l>3?iir@|vB(4VH(U$@U zP8$!p(e+2}=xYMN@Tm-{tq}kW9;bWngtJ2JB#~g3rC&OA2emH?V zM$>uhzOv(Sj~!JAL8hdX(MU8{bwul186a|DFc<+Mw3l)XdgDkk`8AgTm^h4m%d+E1 zU@R^3H)_#Ae#M)-flL_i&UX?cY_$XFGgw6CMh+|{$6>PMS>RLd{thKdlOaS$!8lp| zn!t|afZb9Y+CtpaNNjSR{*{!5s~B$wY6I?2M}Hv4Ta9SKV7OwggV_{>8Oi7!5M9GE z>4-RkGLZS27JJ6S2W|Dy!3aV}%0cVwJ{j^e^jelQ$$uqqFx|kie;z}5JmyxJDSq{K z=mLL!zt~(1?>hX~ZDlMdjBI!Na&qnovKW zkzzF=Cju9TR&pa^c_3rIM3s`tEy92vQE-5)8#N0{+&vqhZ6G)3tO|yi8_wd9>A4y0 
ze7X&q(0zqt6~tA_2xnMSF#Wor$i7A3eF+u@cJ#Odj(5d-s*}7uWsi1M#$Hd<;Yy=_ z*#2RCxsPCAZOGN8allA%LJT}#!_bM?=>smR_=OMWkf(|A?gp|Dsv;>;xSsmji{fq?YRAh#lfmB3Q8>d zuHKt{C8Zgn4;2?!!FHdIcV43ZAtHx(Fh-^>Ebh3?;&J6EdTOqXw|w|-_q%Sz2jbCX zSJFL~=gYNCJakgPn8n$2jD6?F0Wmk|hCdTOf?oPnVgO%dEXmo^E)mL|4a42a1bZk1 z+AT3pbCA*0l#gt(j^`&wR(lmU2g_4%->X$Jx{TYV()vc|!^G~$B@X7~Ls+Fn-jo~F zurZgKuMF^0((@ajR`&;E}l2QQ6l4GyL=C;Usz9znl`qU)8QyIhOQRN|wxa@lObeKy58> zi$jNh446ShOmWuM&RsG0eW_B!g|R4^jwr-HYE)iAN4tg>o!kx%sQ7Fl1C^trP$&mK z{F-hHgQh0FhH{b=7bB2@dmWs(6G632cu4DP4Vx`vkPoRxlPJYlbixBCv9RV@PfV4; z=x%-EgU^ig=Wo}YQf2825!sxQ^PtjVjcZUySsW-EbOY$Go$wnnXNc` zmJ%sVb8#qwqMX@CSv9ATaulsDF@HMlg*z*5qJaP>r!4<{kjDgmbF%0$>8$~(Ewu3C z>uA+XZ`b45dAlPERC{o(AxJJMe5YF1Q-h)~U1$27&7L{ilmI++KXQD0`d&>8RH!R! z)N9Ud%a&Q%20vv0R5N+9tr%gr2qWejd#_Tg#V1IPawXB59KrD0zc5F zt4BSzy=|){r>X}VvgT>p&GU18aHVB}CiM;u$7|AOK+2SUUys98b*Lt}q)Mjr&fxhH z!*!$_isKx{v9p?!)5zdoJgKSl;X-fx$6zh5+6b)KY|Bb7&-{u`45F2+y%2F9W&$+M z!wTK-+CY6f7;t{(BxTSYAn7vXgp=HDpLSsHWo&yO$qkC%w~$9(h`e@WfE zN;h~*;*KkqJC@5rAGtIsELeZq_4V?IFv@c$?(E5YOT8N9DoS6s<+4%Mh7~+!SuSxA zZ6S^yW1VBtX^NVNAAdgGk^yte4&_n&T07F)inO+c@Y1)^dXUHL{T`#wMQTw zMQex7L?h<9UdTjS!)O6~dwF$b)vUYkKmUJ4_J%b*9>rA?>a52%7nDNEksBy!n%ipA zAZVKhktuCZSChFrK5qOpg<9=lh*Rl}0K4+QCRJ9bRFjiT;ve?t^gLZ@FEF}%ExdqE zC8T1pQ3kGOm%VQBY)q7DI{5U!8==JqWqnJ3GQk}@=>799Im$~8W`!F4xU^Yk^6^u3 zPP)SZ@BJ1Hte8`=*sFI&{HH^#J|=I@!CXrCY;S zf)Q2VVEiITsTd`6_42Y_G45}qU6Mw)4OR_Zyn|PHeL4GZ|B-F-ZE=NscxXLi#3%YY zVAd`e+5gFQm@rDLL$~@yd>rC|K06nmN6xGFod_&I{_M_d&|yE1a9vk@BxZFPlGIt& z!dC=#LMw)jY;Of;Pt?IL?#;{tmkBT#s!SbCGDhI7B4EaOk&Vg@T02wsVcJ!DA3>(g@aJN`t){7wFd$;yOhgP@!D8BBlVju}*= zxBaIXo&mTZx2gEUH)bLxpvUugpf;yCw{ zG75YP?9pQPnQh=FLn|)~M<5LT`#0+EP-J(8&OG=>;-KKAj#(-+wS*Qy4#4XhHooDN ze(2po){+>)$n~9=+*qNqZ@wA_okeR1(>RY@=N>t$Lw*$@G=NVd>$eA6%IX3WtIES` z9>Mr7=I*m^ZwM`ST7rNewJAfuU?%J*5AR>tefDnDUePXEoB)o|Yk&edhQl58yrf<5 zPgwp7?%>*6u8@sd@(OLmKJeN!Koq}0GJ0lCq?AO1=_n7+WuymaTWPYQZRUfN1*c-w9hm^oTS4FHo8D7@Hf~ z$1i%NFS(<}J^%+oi-9#JkbN&!s(jI^Z4CL2&qu$3ISjaqXJpd{-Ij5li_@UkTYs&pYy;&C^xvxqz|NvEG)L;P^Y)6Xw zq!Gy;DvD-g>Uk}YOKf<@=muReRI@LpxUZ)nmtv` zT&A_NW{RoZvZ_8r?2i00r=^qd;rj%wjQ3JF`=e-5uWIEw-2|*2zGtW@+$v)vOcLd^ z+mra}%V6WqXX(`%&nuXLX2>6Ifwr3vKB;MBt8mi8QW6zcCCos2qOxuxWk0=!W6TPWvLy)d(T3`Rw8)EAMCvK;vXS(l+O^WJhd?1*TYdFB?#J|>C z&buSCL;~Ys=+TiMH}TrD^>@}{J}crW5ci*P-l*{KaCe3h)BKcj7!^wxj92UF9(ho(!|=JV*?)(=FE@j^%#()?1U7AP3yl7(P(NgZr_5 za0CWCVNl05=-G1s(6rp4Keu>I%iA;+yuY*B^LH8aSO=Eotl@Qc>E;v@5SbMou|@Hb zz7ndXlD%T?TlJ1~yMpi<|g#$v0^l`Gg3|}to=KzC2Gvx`Vk~r zFJX2yTHJ>Ll`&`#X^_%kU$H~f+s(ONM-rmbzE=4=5SBr5Z%7d}&}FHaxtPsnL|HW& zIAvh@mJPc7zCO@^n&-VgAVEbeYoNvIMYZ{maGU7K%E`jRll2dMYu*oP6iQt%I|02R z{J46|PbG28z`=V03wgB3Jedl1X5Nu~(cJ6r>TB32qkThk0G?#&e)=C;?f>u! 
zv_QsDU_k)@AmRR(R@>3o(ALIC*Ve(v*x^?vH~l~CEdByOIHZ&zHV<=vItv9moTmTD z8kZCP%^F{qERs}k=+f=wl6Sajxk<^#IH7M6-`VbTy_NLbUq(i*i zEt=2^Ki4={Ku4nON=kN08)l%DYCG9PtC=q_H)nQ3;V!_Yq>993p%MbNiDsx4k(Baf z3gMemD@WyynI$-aF8nq_`4Y!#>F}4d5-qM~xR}iSE{I#(a|2kS$ zw9|GV`~BFvp}QRdY@j=I7ZrPa;e###9&E3*csO{rt=Q_%`bIWTN{f5V{_4)qLHziL87( zw&Li_SX^{Wo7wHFZG1b%h$|~#a8tz|`Aa~$xRwlu76uOBaQ%ywRZO2u`lld53>75U zYKBq_%$gyoL03ZRhUreJPB7ril}j?LpX0Nv%YYS~!%p=Ch{;4ZYk!TRe+34_8lXMzie%w`?DiB2oyAy#yLi|;Grqcr zNtr-zvzCdVn50$bAci2WB({uW?Z4q;A!0tF4u+z(ZVAkzQkog5n=_sie-{J+f13ka zHUk&iv^TXjzdS7HVGmqV?pVFn#u z*qTz|7nCGl@kJp3hOB6|P$ z-)|=*fcP7%Lbx@^%d-=!I;~b`jJr>2${Qno&YPAEiPX{Gw}An_XY*yRQ_pA6)c(jN zcf(Ycs|f#^$wXn~)~s#Ptq9A|V8J`c9*=%2JI_&I9jHevu5PHhxsSfBnod9h8uGOhiMW z|0e35i!||eB&BII-o->Cn%mA-u!2&hG=N0n%fP$h?F-Sdi?^7^Q8NP(v?u#KhE134 z!q>F-XI=D;DElcFo#;luC;ZE_FBvb;KX2@e3XAnj7T)Ef|X`C zr_rhWT0%xQ+RILqP6QOR2YQQ}lx*Na`5>NhYi z=X5>`10#NBk|KkI(=B`2V*Z@P8K3{~q{OtNfoR&i@z#dDZYx zNY7i~ClT~?GHi4dOY5?oMG(rGyjs*1i$oP%wsd=mpA0y!>uk~xhZeT)b58N_VtM^7 zm2X&et7*Ib^u7Y>5T)3*Y4yQ8G%E>K|G{cI25oH<*QWrH)@t&H#9Gp15+$F96gi@R zrOQizUea5I7)&^RTy{fx#cke~9Eg0xF+iWQ)ave!P13KWqFBQ#YLLS$LfPs|fCllJ zsBh{cJ*uDpHDwZ8v+sOnkU=T{yv;zzbh)D9reE21ZJp!Im;)JnxH;d4g<50+>Z?#Rc2-^y%2U#MRbRRbs;7l$5yoY;q@4&!1=NEqT z>(v|)DQ5UBI{Xllx}wO(OKM&K5;D9jX)HD*v@1(GfmzsSV;2;wfY!AmmXiXgS|cA? z%{Ym7jg&sSPrsG`ndo(!%iymo-la`cf zuRmSeaTe3|2B3d?b9y~Iuk`v}?&@yCn7$y}zQ&{+1*8=Qfk_~C#c2^EMrUH(`J~25KJjAA(IRKi2xRa46*_XO%9lr zPn%tFFh9)WathT$G zCiGonh{05rin}@QON*FXVM#ZuGliZ}vBy;#2l+0CV+Y^j5%9IydsNk!f8}WIXt?xo zgrCFMb~yR!h5gBsXI+)P{fa6}{9{FQ%wl8a1Ow(z_%!F}t+_gX2q@ey56B zPEsuajG>FK4=O}S3aLhV6>!St@}sf6ZMK<3Ue|7?D#2c$u;45&Z#7M`y|**81P%YC z&FBB!^|MwE9PN{?-iQR*;^V!R#m8$!ho@+JV_>Y(dFgmAcVPFH2Ym%)j?8gy7QV}M zYITi^u+2p$;zj`PCh@|8;S{ec_#8k|=j`BwqkvRG07%y>Jj=L=*Lpv%oxhGIvPBD# z*?b<F(ndxR@SfI2-tRG)j+5w8NjAE1u2aT$2@z(xtz z)$uVsf1h|)cEKRypd`#3R+dR^v0r)Wcdwdy_R?xqJC{U9aIq6*KK+b|99%$TF70Um zUY!W`$N#ohu!xlXv9hE@Fo1X5LCUVOmXkU^!ozyrgyQjrbT5Uaxv-Y4u*`w zVoof9>Og9Y$yMDL^J^i%Ck2(|wV~>E$;`K{XhxL%^?$Ua|5=@NC~VB%{~D{G=m7v| z{-0{I(B3v|mnQ`PiTn?EjihoC92Y&&3 zsUu1bj>lXUA!A)HF0v<8_(abcdepEeoCttR2t~OM=)N&2U76#kfxLUQsmRecuzPpG`potNTjrVKyo8k*5~ z+-kL};dUcr#X8($UxE*NZ}*3Viw}j~7+1nNsO>BvOawGR^!h#%?%Z5Lq-$r+R(K5q>3XJtd-XEYSGSj?VSMoNTAozv*)i~u|2rC__?EV z_b0Cho#?f6XVb4&R$dFjqOK@vKYFQ$~gwat?rtZvl*Dl(T zHiB4RJVvEYs7^D#0QFbPM_AvDc^8<2LM;TG-vnw$Pt@HJ9EgB-VtFf$sOcg6aewtc znP0YJc$KxFI68tdyhyH%EpAIHbD1zAMP#D0D9UWD2|}YbpQk^1oZ~I$INHC~Gi^q# zZ*6_DB71O<5uq4NFH$<<;B8Au$I?_merC&nF9j}|l^{+iQ`IhvQ=bEoP~FuEZDyr% zOg8Ay6ew8|e*u^jT3sHNh7=7TD+v2mL|Lv7hG9ypbO;QIb71-N7h!``YubQQ;;f3L z+)nlHa>0s-JLY1VEL5{J!i7ytE-MGGMPq+SR$i#Ge@Ef~s)gVWkGl^V@H!kG8F12+ zp5Fw0GpOX+v0lw2>blG@CEUyCFxd#6`wT5^;qT-d=RySz!+u$-7kPWKs7w_ATmMX3 zp7mwZv+dZ0nmV!=xp2{B6w|)(*qdB<-Cm)cY+?SCtT zQF!YxpRVC2diELF{0D;)=(w*0)IKRa_45*f^cHM^L;}3+N1Y~=t=ypH`G{oBp3twMnqwHrC*)*j#(hi{xMT%{#!D2(g1cd6{* z*C%QiXqSo~392G0TKW@WQ2o()zC78HGF)0CKezXjrW&dqXh4F5b03Vr+f^i)J=rb=E|%_z2_8`vATz=10qN;xHgS= zF<)KEsBZ5KQVA26^kspCohr1SIy41>vt3#7EB>RdaS2`i0~RS_63EcfwKL5*!(1Cb zdtIn6wcJ?&_|k5P z3%x=8;+;(Qywod|10n{dUG=#Rnp}@)>9*_I5G$&R=c8mSGLD%xBk1}{Vi2;Ld73N- zTc8&xF-{G=aNOEf#9Hv(l`N>IlF|j*=*IQ<5Y^($a^jce>k4wEiVqaj2 z{%+EB821WA>@koZgsM-$G9|VNd`8ihmf8gK6?1P7vnkcm1)hDB3D5)1UJ88+u-rhl zr9fC_xuRfjyp2`;@SZ&Him{m@Z5w$l3*&0#Gji-|^_^n+zzfUl<5bx?oi(L1Ae#lk z6C)9)?_$4?@b?26p}Z#}`SkA)VnhAppIK<_i5Q2gFxTY){pSyc>w-hTBILV%(GJ^& z51E0!=|Fw_jq|vJfpfoC{P?F}p)|j|Y=<)9)e$i|tufZtU|%WiyRsn2*umIYPNoD^ zfQ4fOgq)nPb#$A)hLz*!`)jhB7CD_B-jQm-K0?SY{=A1hmi&G_NOQ0m6LZgW?IaO( zq2h!i0icdol~o?VGJ_M8oN5!ol}~!SE%w=Eh6=^hR^m`j#P0(!wJFc5-sE=Muva%G 
z^`$oDg#1OyxDO?>z8jvJHuV@jc@C)Xux<}!5afg}pkfi`4+kvYuM5h52Qjr*+#Imr zH(U@_?cf(QGDiVY6b^W%;mGY8`vDWhT~~J0Oj|L1ua+X#XU6?)o9(TjE30ehj+(<; zxRD#rY*sY?>=oUGlK@5vp4Iv=hQMK4ZfvcdHCukpj>q;yNO>)N^pHfZm3}0rfj6#1 zx*9`PGQBKtH)?@ zy#?lfZVX>me{V~|a2QVKdaID_U~?}h0$oZf$Y-f>9%H{5f3fu2>b6pF)>f#w%y*&T zCEq&0Qu;}r#P$fC+!vpPt{BOJP^7p?hXJ(x>}9;8;xPBtge}Xou&{HQjmf3QTsf6W z&|Ot3;AtVBu!aWPmRLAMg2P;`T54ASySD(guXZ_b8gcic9%n?^1lb|arvt2J$a42SPe{!Z4wN{HBYs*<`Lhub$F4N2zj<)jUG zM$OazqI9&dqc&s2S|J3{Uz_ScL0=GqdhaoRfRSFL7}nWPWG#Mb!ZeudhoqXe@t=Pvvm-(a@wYe^X3Q(BlIas9|3%IVqZCAr4OTu3E zPw(Re5MYlUCi^PmCFmy{t=JK1a1B{|nIhwE?L%_RJx*{Y`L-a@V&|ckyQ(rz#o89& zNBF6aZa2&aC?+LvSKdY#_c9N0f8YT9WGq3)0cKyL{>qYu?-=HxG}fa|!ZPZ=F`G9? zAKtnICWbU~rMSPctW0=@**cTgBd@_Ex|K<}UG4ogtthIp2d?NRe#MGXF$Nx|)4gAz zc4BkPZ>3|*Ju|w#%S4DZkMLy9VTa;D`YMDwn=WyR*ot@n-X+XYYtzX#SE8adF5Mp+ zGCJ59v~+I#Q(K$Ptgha7Z4*7{-w9x}A#ImMoNiX#+>(G}^VLOEUI=NltD{$Uq9H)d zq_{0pcIxUgsZ}WD#`TD*mT4!P16n%8A@B(yQ%WzR+S?rWn{=&9!M{qEqO+SRSIMsM zbn-}>Wf$vWOXsbh%|EkvF`w@)*{84Il36)$&DJHKc?RQ1%B3D?@rN%M58d5Z4?Z8M z;JO*K*=KjpDP5V(bRgSPocCv3JNWmayn}a^k{8&BwU5P7BBs_{FiWw3uL4fK1}^YN8mFl8Y(E z=Jd7L>qN_0!wg@jT&(SVe|zz~l+rIFIn?8k1X%oRt&N@S?mNN5un77@3KS#9UACl3 z+Bvu4JApleI6PcJ#CJ<0f7mJu?#n{)=NpJN#m2cjtBUc5#390od7& zJ?NV`>}cq-Ore87ks7!*MZ5whWdB9?h-D-ft6uEzovrksxpP01;vH68ZrR6 zjz*4=tG1>>O=9^cBOnD?EGXDRm9Vd$F0HSxw;Ms#BTRh>oB!EVD01JGCBm^sHOKRK zOSB+JV%7ZI@AFf!q#8<`&4c2O+S-K#5fHOpuPswI5FyTDm?HYgiqfFofFgIy+4e1b z)1(bAiS<;9i|W&Pm`Er8hw`lI}@t#?(4JYPX`-G4=G<_n1Cv$`$Nx1$mBIn z7;O=j-}*i`XerDn+y5R`7Lq?5p@TI{WBSl^+h?=`1Axr-H%Ft@x)(qeRDi)3TNIF{ z{U;90;Fq|M6VJ0oyfuAovojS`5?PKucw>g%Ag|L(#8&X`GL2HP@X~MX4R&!1vj^U2 zH9F~AV`vg33!kx$Sx0{-JebJ1ueih$R2P4yCX$MjeJ0o9G#nl8lM?v|)UAKQhOC5_ z`b4Ku3RD{`^h#ZvCkPCSM`L4#%15U4w+SMq-KypUsB=w|5wGy84x?O8P(jRz{0yjG z45&Pxrhyh>jRZJvji7E}mJUV$P>&7~K^DYF>~UAspi&t>#{K5V`d&Ikk-iERH}{?L z*ZjEv%?BO`JK$z;s|pAp3y`~?2Y%XZ+Z_4@jerCU;|j0i(w=+dN)ime2$CBR0VAL9 zzyNM!D8qr%3v<@UpBxjlMuzkc&Icxm`I@8Vq*(gf)5yT29~4?hUs@cIaWqx&m)UzR*o0 z>u{j4rsN$JG?q%FUXQJ5Q_4v}kk@3$tIxhr?S)vPpiVlA?ZCGi7+4!{K6=&)qHr(F znNiRSHX+t5*H3j-k#O73pS|HL$)M1(kE#I&6yH`EI!_hub`q@S7| zmiA-q577y6{d+JiR&bn>Td4d@LoQ&Ooo>1sm{_~Z$Tc6)4%4!3DJ>2cRN-$b;A$k# zR^eB4Ee4w}JyjgJ=vRByOfNfllJ@MG{~mAb884fJOr%A+ERe)#V6j1!BHN@Um2)yv z%VuR!i8k0{OKG4<^bZT#Iv^}Bi;0Eaxa)01J}ZdRsBVbPNX8f$HeA*JjMk?jL|=b; zFCOrre!3}kip89zsqpu}LJSmOhYSTUCVf8MI481f#3TybKN?N^n|#2V3BT<#f-18z zolb2wcjR=SkKBa`Vpop>b>_sDfA~Fns}3ZgUP{0^YwBw5V`D9ywJ&`j=g5ob5S~a- zCo>JYq0uHCop(9-jx-@Fh=b1*q*;1VRhlxiIm4N>rR!>=u~yyKn~tDjA1Rrc33$yQ zzkYQDqn4zOKXQYD@|aywAj5z!tk=VZILFpvK>mc5+Ee=s8||EdU4ii$>pDSTjAog% z^sZiEsBIddJk4Naqm6A{v|#E9ew98)EVwCx1nrS6|07C6}}rn)i?A^Y@bmDtm-Ubg|XkJ zyHuujxov0bs%eaeXWYlQR8zeS`5$9V461l)PzZ|g|Hqd1Z^ z*@^G>{5-CBdNSVqE<64~yN8=X#7jsfDL_2#e#`o=#9`*1TYPSm;$%^rhPCTY3k`CR zVv7cf4XnCS+9E^|Cx9^&trN@}K&KIh{CmV3b=EcBxQQ>?ygGk}OV`gk8t07F!z2qk z@kQNK-@h>u&*L}JJLxOgDNrW&&;U&4O1+|YyjJ^Jld$m1>-cR(O>Z2lB~2Qyd$ORl zS}Qdb)rm|Aq)i_6r15gA^!%(~cNg(H`^iuqq12}E3XSZ^RPE*!ZCuc^<+zi3>H$La z#{3Ui{6%aL-IQC*H&{^d5?mJQ?69TH$qo*3cLkg+!OPxU!H z64~hvGXgOO5HJuMj(WsoyAI#jJ<)r|Y@Z#PxhNm@#=KEM~{a zKpoA&(|_*Y4XYIL%TR|I^7cX%x>Z(%{N*65{23dPDw>)U_?h17>vPSPTSdI(F|-tI zD${d&&xzOG4GQN*hS!NAz6Rd$&g>l_A4=Jw`Ac~=m(VKLV1~w<_TFaIDcM&#&K20; z!DE(GsiPRJlVEDvmoqnDHE=Z!_G|C^StzP>2j8))@ABzARm(Isl8%M6tGaiU1i#df zXj?Eat&Ub%g^9E)_^j{q86?|y7vy-2JzyR`geFc&3yc{<_#KU}nU~R;k=w>nFizaM z)zS=yY=eUG52&M(fP$NgR{^YR?rN>@;x5YCAH-aSS3acSpT%qG6KnFK9pcsYjb(s8 zxc1jysxHP3{G73AfgQ#M1nd28Now%JB5`n!MAjJx*C@vmU#5I987vYiB~=FK#r2kn z7VbtF^ri&-jPgpm3CIw9BMjYYm7cu07P}8V6vqj7cKCr8eyaENGZt*V0RPWC_@C>3 
z*3BC49uNRP{qJ!K{eRAbPQQCoU1ukAtN*M`tCZwrH|P+2W@^}yLP!+EFk+;n&i%}n za+A{Og32kJ6C}^co47X!9(QgoixCDd?&sgHJ7#vgc|Krl3ORlWqf}ZtT0Z`oLQlQI zkQk_(1X${hOm;Ibn+9pUjpll1vtY?)ay@c%Y|zsX^0{{LQtd;0W^qma$~5{^6_;uNt}ay`abrHJH;v{nA^0K8q(F*<1C0#5wog1SX=U@IxC2B z!vh+jAsDA|1B~UPOn6-S7Am6;SWQB%_VcBp>m#1*TT7!3HPF%NERW0*+eiWx5jJZ zs3(Q@Nt*j%3*G(Pz1tU3Ly*3Le?hrf!wAzD|AjYyNNW*=g=6-3>{%z-{~2bi~X zMQ`6t@r?~@0@w*?1@dE6_X84*IgaF<+|>q~;mEH+vtD<<7}FRSUk4&4&YSy?ZMC*B!$SXr2M8L1Kj z(!pO}z>_M%N<1C%b)~5c@NV!BZ?-nJelsO>DS`$S+%6B0-O#dU^`|Yte9|N9+VWW@ zbS({w?NVZO0spZ-jFRYM?P*Jm4fVnkx1)DCUK@fupPj3cyM{XiwHcz&P7RUmufd&C z?RVtHE#6o0*f+`FvwwgW>NqMRU^9;urct(l{ePuWXV9o143otG@wLW=t4UK^TZX7i zY7!!8^BeuH|F%!*mvU`&=W-d6x(xgnz!6sUsK79 zYUneWdzR-}{e13SeLPIOQ1hkoyjku-lB*N>$ZkUX_8+{1oD96yNl&_ zlafsa3*O4Se$u^KYcS;%+O!u&Xkgp-2Wj+P>H)p6U`E$RAF|2>al!tVT^Us{YkFt! z@AV?(OK;A|w(Vr^lX;E(CD<`(^7bZ%F=(iZM2h1xt{)h;jn(&O_`z3$e;>HPe#{5j zgn3DuEn;toPo(+GVX#Q;Y5I8+BfsUG+X?^gN^-?EY28JYmpd9@MF7CE9*qg~4D3GA zE+h7zr7Qqw`p z%tiD-_E-8xY$Nzs2!t?5X}IzINXcI3<6c#ZYV_$`YMD^BW!g}HTy@7mwXV_?ITalJ zhapHTPlA|?Md`H;Z71qja>$yNMABIMN`3!_c z!c#035o^$bdsO^2{It>}IVOKWfHdE7^*g>cQ7)Ok{J>|TNbd2Z@0{nb8b;^(pcvr~ z_ZLdB$AHG6m|t2wdOoMX5@EL1>69#m;DSMPh!BI~^rPO-)Q%${P-3AA9D`sMnP32? z%B<)S5Kk;&qp@t-=V9}}=<~0%AhCaEiB22znN4B(CO=6oR~4`W@MxkyOu`H2>)jE+ z^vJkgas;}{X_;;CAr3uQ?GF0bN-1HmpkhK}$0IUv>ritP7y0Q|vjLw=zwh6uEQ5cz zu4Axw_mGqnmU;CV4tpJKg}H4Q`uChLDo7UG3SY?$2>zv8>Fpp8F5DxZDrmul8H{veACvF-y6#ER| zF#${8!t^?@N>IQTL1=-e&f~lA6Mfl5vmhsG1nKxcSO0swss*(W`Y85-t~Nb_2am4@G~6ZsAdi3m7-@LZAoWqp4R*?nQJN)fNeFM~a^eXl6UeshNJ^<|v%q(X`2H~<>Ig%Up=CD4O&s;Y^>M8iN+4$@YFgRh z7z|MNghDt+7sjYP#|`oL1ybUR1e)*mqE*=qmx6l1%mH&1m5PZ+p{AZ(LZOet1-3=Q zT1Y)`wVdJ zurI$FImG8Tg5J&V;2qPdntKc>w>S^LaE&CY{lpsSxJ@{B?n}TzHXPjx(5sQ^7ds|3 zWK^!2{7-oPU6L@$cl*Q3@UdY#S_wpuNu3ngT;q!j=ue1wf+F0sdYAqtQ@%I?-m_h@ z-&;ihsaq$7lLLs1A|_iN?kc&>I=j_+7@Q|mF^i~#xm**+wMV?31+#T!zthn^`O|2S zLHKOL*mSurf#PeExR{Y=wfjOchVqo;fDTQjKjoRZ_bMBQ%WH_XtE+eJcbo zG%+^iK9Dh4p7@+L*A<0th!}8oeAaZ}q2XhzKf^2A-<#U%T(XruHvjYlW@CoL@$V3k zx>G-VE6^~$40Rw_{%pBK z$i-%$(2ctG@FfDwv1}LNi=}3gT>-g~Gp)dzfPwISI#|+2gtn8jIb#nhtde>B!}SdT z$JwWj!R&S^Ki?Z>`>%XY#gWWzF7p=Ohpy21H^yl9nP;v?L|rFOYw(^k7x(uNX2;=h z=IWVCMjm0qU<7KQwK2Z?)9oxIsoRH*EtwM2F1Iaq8|0-=%<}$Akwfx%TZOhZmD6rg zL~em+90i^N&m3{QGQr0Z5U)`Fr~n|4{fAlSQDP8xNC7c|3(d8$2n&-^x`)gfaMlD4lp@H%BFbc+L)PIOB znpPOtp9_m<3{SFBZI_i;{Ln@(E&{C$r~B>2)!&Sbo78>B(22~xhstZYge#W{U&|-E zVT_M!wa`#sgd>&IC4b=(X(?k~M-cRPrz9pL0 zCB6)L^qs7r@HI!pSS~i261QW{uQ)e%DqYLgVKoG1*&qp7v)|#Tk9hb%9Z%Jz`${$Q zyhXeE2UaeU7`SkAfu?*H%q@;LB*~@KEd?*pL@2iYsL0nVF#mvDd1ps_>Q=%m{64PB za~B6P;?RUuxP<|Zq8yCo?UhRwmkJah0c`Tl%POgIAhnYzIh5 z@uS98#qz0AO7gJ5e&SPc^Cf;x!SxXg6w(QCs zBsrNbI+I0}4N;a>mq%b7Ddyv_X^zkCk^7||WD#uVMaOhG-!hbhEfanNiBjTaG*vB@ zFp}=qddi-G3T0o@$(WOmpFwYxQ)bs=LNiIqB+T&|W>FrB1&gEd-HE;6SS=wscPx0KOEKJwnG4fxZOTq&a82)0f_CNc=>{ z)vXOce_8+o2Y*~w*rQN?k~%S4Xuj{3?tPm+b*X=Fse2TZIe{zgG4C&Hr!)C&lzp8qL=^qw6UJ*(Sj7LEUSNM5Eb$d^5BVSa9 z3EE4q1FdvwB~Aem{HmIZuu!Goeh8Wk3g&g@mmWGfzWe|7zdR&;phe8#GkMjE&CW@n ze`CX&AgrhE5Hr&r8@S@|a#VU*ogn*;Xkal#JHk6p+ySIAj|2Ox=I%_V1q!zc^Gs z_h9$O&K_nhh4Pre;=kdt*Gc)ClNk318wm?mCLhk=>PdFcp)%(Wb5&K3d~1!pL1gNn zXBe6BRo+v4&;^UAI!@OE#D|d<4j~zgiy{;wVQG57Gw{L zZOqCqUXOd_&I>S$9}*6rrwe<=o!8t(}^^T#Hi+oEYg#pDAMJgaA==t(dOrjBSP*bbH<3H2* z`gvSkicfEzzNoaP;@evar(@x-=8U5{N>O$Y0#079cZ(EK>ULk{kaV zvYw?a9TtdJ-3~zRu$jD=j#)5O8p+?Zl|SIgB%w%Zh?hFAT`$Hx_=o=*2~sPEb~c}@ zm8);7E(asvMt072E4Rysgh7C8vmnNoGXMNHmk z5>*=rdI)X2%JD3;j#D#X^Fhuvs!ZU)r( z5*a5fDcRVWz)~zLd58(reMkFhnCM$Of`^|>4rtHJss#bbGjj@KHJys=Ren-MA^#|e z&1l5-sWAXX4ER`z3bAQcNGcSU1Rzli?dg%2R1K3y42(MkV6-xU9ydkJ5rDdg@{9H8 z)JCO@3X>UMj 
z@o+vQ?+(9bN-dun_?<_H`qEt}pIL0ICdp3Y1<~DoP_D&-;|dtn)LC^A;f-#qTAHh0 zB$DU!H)wu^VgUCcS=`MZdG;;QWWWRgj5_kO=hirbx!ul@;gFgy%uH}ByiiOB$vOvN zZ9-%!3keHD6&6|bN@>%lpM=`iT+lSxWi^(G+_n3!XFZK;lC1WxkzDGeU?=H75X7q!a+aOT(AK+iy{MDeXcRyyH=p`+h}2oP(t zq<{FbT3;BYp(RibF`*t;kN3%Bam2NJ#|HPYffo~Q7dGC@C0dPc*EL&Bv<}Y3!GS8= zO;7PSeW{)Mi?vmFAaFh396r1G{z4H8yS+m)Ba8^~X}gfNH1bs{7Td zDD5D$snhB~NY%I4et@`@wBe3r?MYFB;vGWOZ(EU?7KTKZ`V~#ytW^Z&wvfK<8v;!7 z;E^6+ep&#s_O_G2it-EsKiAzYgp|}5ky92Y4sd#UEHmk)WNJApqjpYvwB5}ucmxW4 z;~mnsu?lE|6U-=D^Zc+|9v(jJZ{U$kz&lsVo%MXL>RL@Wz&snsk8I`JvL+ml>5wi6~1y2!v#(nPPKSd&%~;BK3ia| zA@85^F)&uLIg{bijDJSi5WRvb6M)~#t-T^<47TjCcx8NtfBf!TOVR?&Q9(P+4-M@M zR(ZY zs-ainv~*A4oIuC=>c=v)V|0V(j3|4@o%1v;|G7VoKYO?pB6fgs5k*-pw84y#_YWQ#lJum=9Mn`N@*VD5 z-J833PjuxV3k8BJ>lPyYRkMdno=BMsNd&9;HRu>KrSen?;Kwx{PG1TQ=h?>21Vm7e zH$WtHx7&v|0$bFi?m%&XX{G$3q>$%+ppsuf%;kjV5IiR@S;eO|0J3>NqIHP^8>hg` z?#}+@pI_i~c47CV&F?kxofyYhxGERaTa;@s23xP%<-=F(Y(*d1!fM2BkyDi-K9z3s6JFiRC5uE9(P;)Wy0hz!f6cFxc>Nd!(;q|_$_5In)X(#)t1 zYO=Wm=u-#_NbP_daA-wC349eVD+Hw3Yl`ZLWziFR!X=0c5>Vp^`R7iLp%n%C>oFzb z0nZ09LaSB!cF8L&70re@&2ER5Cr*Hk)Xj8(=8vN}SrY1j7DUWWfD3cTknL*qH=-*t zE{>~C5*0Y?9_+Y~&s1_hVzwBhEHbJvu}~3A^^%Jsw`ZTZf*rRlM6fw&Vv9g_K z{Eb9buTGEkd8&iKobLp#_mG06aEv(Jvm5pR$a$D~F>vQ~_9mB8T5l%vmo%vP6t{ZY zEMX~M@~BePM?`)@A`$c7c(_8xh8VMg7|5g-H$uhWG%`)V&3UM7x)~>{>Hz&H`QBr- zOF4CUDRP`m7yFq`& z%c%p>Ce9U@G|mpvj#&}*ewGZgO7}?mR+}4vMC4D-n-GkLVk$l^>IZ@x8jX$s8nZGk zupX?V5Jq_?6)QV=WJ%t|gG4l6 zbK5W?a%?PC`IXe}*mO&zCWr)*U<#URR~4Mn4$?EI3q}}Enz;R=PH3rn&&^C53R_1c z)zq@$d|GYgu?F*D@-Lx(Q*;*v(<9?yOC%en@0`rOmq$#<5^aShfnSsz${yLn#tpuJ zD;^VbDrS9>7xjW@)9!lKfeyB99RnH}-mdgA%c@KHgdcZ9(X!TP__QdeVS;@GrcL1NNab9;=UAb;;AJF96Y_;}s&n@(`HeGQ<9 z5BcnRsH@a{cu1?%Svq%YrR4ljHOT5vKB=WCia181y|?*zt8@ zpaHztE;XTxXyJHPRw<+t+`^*k3nf7Rvh?tc>anr8W$mg&@=LZ5`nL^(cH>M0UL76W z>=BwcF2t=if}BOe3jDx&o);1jGpiT3dqDXPrXyPk05ID4w+U^fd~ z`51UT)wiisbW}sK^QrCv%^Ip`0PUz2%7N{ugPoJJySqtM)`d7XMxpe$mH`27?k2H) zLyvvgKY$wmK$j7{AqmpcqF{L}0Zwcc#KK!P+@vs9o zdE~yKXz)_;Q%ZUuWc!@67l_2dzRR8&(OO+z3JgbEXQQ36?fEua|0w-i_-B8U)y1a?({YmVx?{_T zpLq(z8kI*pGVpKmS&_W$@E8$A@loaVwb~Bkt^pTNKk|H1?;MeMlAh&}t$L4M3L#H< z3hctJZutp3svFIB(^{ZNd6D($_fmn1lrTt1zt30kmeY2Rg<(Sx@xvvpqu^j~EQj&B zkv7BOX9ZN9_PR;wVFkud=}blFTM70C;7dA+0zeUt9)}&DF^Nsxp4XYxhl&TfUwWK| zRX5H~W?5I|+euX`4M-m_7XIogqq#v%Taq0gT%?y_QS!~mdzTbLDK25Ugzd5+>MP!6 zImo!e1V+jb=9iKDcAS;D0}+_YQSTG3PaIq-oFUJ9 znysHZ`JN2TC@zNWoG-1>}taq^1P5OR&#DY1q(g|(U^3dSx2)b{pz zspGM5je(Eh%o;dQjY9G}E@-#{ELer!6+-I;YT*?vkxkB7xC2yr6;WlD1$7%Cs8%!_ z{?o2f&FCX#QxO0=83A3QZaG$rKyHNM5J(Z+Exxyc`xsY7rQ9JwR$-;$Ro;@h*!pcC zck3im(JZ}GTRiV489}lT^7uZr56+?Wfqh(FfA%{M{rtBO$T%vq${)uSw!`0MaugW= z0QUbC0{O3N_kZPor!XyT4_NLyd_m)e2oxwRuZ$YHrusr8TQs%oHmu?bRgwDyBqbXM zL~7!duv<=jdobeT$qxv^J5FF-20CzE+EHkx~YjNV}uk z>Fw(NWdj9cnRxS%C@v22(Y6mHq-_I9?VpoM<3EvTBSrwB4Grd4&ga7#K$%?w{5K+L zU9QhhqQ3~yOT8?*-r}0i1Px$O*wMkxN1CP@epK~j1#Jaz_H}RnI0vu0hk!RzS3660 zTdT8!1h0B@13pDpcT>x&HK@gQ^zA6-P!}yl7d*(nTo2<}4eAgt^{`aLn!{mAEyQ10 zA3dHp;w^jxU$wjXG(FDwu2%?dKQ=xrQj)0{dNYo{$-xxLLi7`zI4i3&H z6@*vlA1^t-Bc!alnhkF-9cx;pP~4yPMDtPrxHYgc=)6aCxG2FyrP)Ywl2=-)G^tCn z*Ac_7Pb6o&fzI*Xxl$tDBV+FE$i7(hax#=&5rsim%H2r^E5uCiF-I1btxzs7xg75^TD-F$Lr!Ix;DaTm@oCPxOCGj&QYz zNlJ0OTO5&l0BmL&{BsP{`yXWBulk!jix7Z%KklCz!!g~?yqm>wV4#$F34D4>w(>Ck zqL3GfQ5#R(_6D7j1-p$AmRYKTHTRfn9#N)z%+$auwLEZEpUW)~0!~H%W@i~$B(T&6 z1&H%yKwFHcc|Fw#@$Q&2^R-w6W}$6TZh;xY{a{8rPgI8dw)7qJ^vmXv5_gMS#*L4a zVUUO6R%vn$LPHwa&ddIp;tOKibMfwM7b{fEluzkNP8j=O1!!_z ze(16Lnp1qZ`3R;7L25zxR2oon>IqA6GU7pxZ0^HNBdam;~F35 zb4~7Sg;DBk@JxN6-wz!Yd+zE>=KQ4B)ijxG8|cV{7qn8y^AFL$wd(oJ#G1RO6x5|0 z>^9G=^qH)-uXzzWfiXEpJm@Mo$!|D55t&~YE-ym1ZQhuDHBokfPnFHi!==wCc9XEE 
z^>T{vY-Gu?r;L;|L`EOd`e%IM0lUMF+-vmYuh5-(Usm!21C;#=$#q&7Sc0$++2C*m zxCI2Bfj`(^;yB+jh&?Okw7XjJ#hN9beRCv#dCt+V)Ny$5Dgg7M7a-Cbq_H>(d_u3l zm!>ItQ{kKoOOY`_1B9!t?C35fa9j@>V<#^vIrR#;0|+Ndv0(7bG|Ug$nHAZfM^pAL z3-;7FvH2$ZS&wnDpHjtH`D&4D?EtyEW`Fv`tRRmnVC8Wb0ZQuv_x4d$gh2NY-bpml z*Q`Dh?PDV?aW>fq*J{*DWz>c9VjDa7= zML0Yh4gr^#<3FKgfK0_bgmt zL$^#~@gyMCm+irz>k8gfJqn)@6blh&mN9sc31f^LNA9!nIpBZjlQWobi$1f}AwRu^ z%nY|24HlCL+GuXhxR>iC)eBKxy9w?_QgfEIqQ&phU;vDD$7>lrK%98Ib%bLgUCXPE z7ve;17g1}gcR8GTMIcycP98|vDz|E92%5JNu>RGcxTvV(vQleqXfKseqsLze2_Be_ zRfguTug+P*>!X=1+rljcPT4F=H~g#iiP)vayuruE=Ic$o-U`0G?CKGyXC!?Q1_Sxd z_*^OJoB$wSkg)J?5vnK|i7RM5X>-g2oMb^9c#XdI>dg~ap?Hb#9D&t9cMDIGenN2W zW8TGA;95Pmg%E7v?>zgU>?ujZ_3Vo49(Ez#DF9v|Tr(fi8%R9>go3C{E7>SmhTk{< zC0-S=@2Y0UCsgf=&YC(*&_)k>rmh0iW$ax>M|B)=1WJLXoP}7a26G{{GYypAa0-Y< z63#|iAa0<%&t-%UzoB|V_MjvAb;P)86na(>@cGf9>>zcH`|pDBhXkV{B65=Y_Te|+ zzbw3|5A-XZw8;2*r!U|{xXjG3nnv5ACmVl zhgE?f6Dd)m#8R~MCU8i$4t@*@PtGGtAkq(AWSMIq7YWZi#>;$|6jEBExnNc@FW5HS z=>2P=6E`f3EPt&w`5~V|9nJ5jxj zq-|MHcMC6jMc!FQOBoK!liv|AM|?j&Gwcy-4?Ym))3HpY-8-uRC{S(7nvT0)tS(4N zJ76F*UY<71>>u&}h(?FQ2HJc=GpXTnnt=xnyJi%>P%&xb!4p@!OXW9ogWe7-4NHYn zZcuV%WT$!rOei%0XN?lf-DZ7q>afnfFDK}SjhInQ5l1`%4H`L(9w(y#XKxayDeu+RPSfIYt<55qE^owY&pBW`o9KqyeEGJk%Fd-!f569e))$S@j>LTiV6q4=v9M;#ee%8#? z2=;7(0e>qj$A*9CvI=I0P9YxI|KmZH`NyNS8PJ;5TBFi9EM)f#JD=+N zNV%CUyQ4)}(p#I^x2lFyFJ#h`2u9E?eI~utcs&uJqf`AA%7SIP-2z}e17-|0V!)GAd=+YJ`vuV1Rho=@r)MSWH!^|=9XXY)m5vrODbxB zMtgJHf6#9(zvwqXiTWcVoPT-zxXEtrx+j^E$A7x!ho(oD*5}Ub9s_h3(eMf-&cYL# z=F@bRUc z+ox#klCdjeT1Rxy45k0pTHotiOyQjc?DRDLjm*EO8nGUf)PWpnu}GyzjlSyfUz7m5 zQtepLlb%LlO26bWS|e-HX+Cyr7rEYaTg=C{ujGv7$%v_Pi*dEN3;r(}_nLg1YGyOCQ2=C?)Lnu-#htj|prEZ5NU`p!cD&!#EWy^J(7;8&=sd4RsW zYB&}h_qtt8xice@t?1+;2@0-Fjn?o|*Da{}%T6A+?R)^Tw@@f4*CI&V*6pNbsh6-Y zD#O4vgwBU5fHwS>0g~2-axFej@W}OEcLIW>8e&i5i=$6#Yv(0&<;C=oJh@lP@(1R? 
zm33L%nN6Vz*#E`WJ1~hBZcUnH+qP|=vTfV8ZQHhO+qO>GcAc`OZcltO(|0=dZ+Q30 z$hGoG>hmvQ5Xc0O@)D6R3(J>o{)ieb-9bHOZ(&oHkh39@zyvCVTey79xS=h98}<&0 zle@)3chn?pFgLCJs%76Wn#t+5ZijPn%tWB3RztG|jN(}p^;_hesSm)_2n|qO1lDe( zBVaruH1McD6L4ftRZvlPNJ%Bq090XkAUSm+12fUn>NG*jB}zvbAF~B76P-ke)h8-M zkMTPqBXLwimtG(dm4r6;?hrUDh(ez^Kl)m-K_GWshPv}cznCJIt>#Yj14J^VUU5*= ze#5I3ENJRz#4cR{x$>*S1@JT6ntzAOeOF>FaR=&!Uz*E`q@hj>^2ROARr!%^>~PF7 zS=?CgRa{a(CM#TYw$-0mV+NnYO-yFQNwe%=1OMU&b7O4CY)K^7{>pk21`N=s=F9RQ zALgX#T27=9%;v5|2Fg(~{cg&tzBcUV0)T8T<++j2K&q8y3ke#TvI~PbH5@`#`njr613*5oO;^3*T>AvgW=& zO|@wTb3YRQ!XN&#Rebb`A1i;6uybzoTkwnaJOOX!G5OB3-{X9k!EBTCyxJAy?LoA) zAP)=SRn^}ZF9j7`zOgI9-$E#sHF?k(F|yj1?oKXRo$sm`j3Lx|q2=_cX11(?%-iIF zUz$wQb12B{H)ihH)+4dP{nU>e)dnU4?!oOL4v+v3xE`|u%18gku!cCi)f3p%mGjWg zIkxsNlu~X?*w5S=qFmrHLKk&v_kJSw3vlW@^ddszxcbt`_-X+mRO+a1Y1^&mbLkcu zQY1811l|4cR&TY75e^EpJcRhc%Z7t*L_JZp4Ky4Xgp&;$GP}Qfdu(d*ZFvexyj-1K zShBmu$%&CYESn20+toT2tM@qk=pV5$-QQ^IZ*lRnRK4x-AL^shG*NUw_vol$(>n-T zR~zUP)y{tF+8DTZKjSh`z0&l(nF;<$eAAhniDSx8E5OHr!t}M5kjzrpgBN?f4Z&S- z5b@dCu)x5DkIA}Ec)g`<)i*~o<8TDH<@O8!j)Ma}K~p6H*SMbvBt6YIePI^R+Ds%*U6boC3xy)_Imc-IPj3;)rsF~W zti1kGRK?+a*X>PLkv9-~r#3cY7WY|>n_sZj@Pk_cQ@TEvn&w3YKXjfjA&x2JzmnjD zmsSSBsFQ@~0SS9Sy#EkFX;ebNwr~pc9+G`Cb1=U7jY3rCrDh5L!R44%A%6)&SoH)4 z*8LJfrJ_IxX^sSm_LL$^fC~s#L?}VkCC29vq?nS}jB%VNfoA6qk-Gv`4^({TtIdJW;*Wm2w=Cr_CA0g~SEOk`o#VQM+1 zh@vuf24?P)GB2tq*I0_&xQEIeQ&Lx&8SVzk^A1jXE4X4#v8vu?-3~~*8$`Z+1MpQl zk?gpB!IFNHqT1&$(-}n>!=;60n<@6KJoRm{L{STdI?Sikf!bVJsrJ|fozHpBh+*cP z$+HNm8bL0d5i3ji)t0Y%GIdJ8!*G17bJ)r*_e9ZY?JIy}6K^BwCjU{v(e~c8F*v0j zF=gH;6t5|;`>wdRP2Crc^kHk)6P#jkmH4sjVgM-R zu#xahkz&>75r7rk%PyC2^^09;S%wZipqp-1^+B{G&17AC#FKj&V>W5($8qrNkkFZI z@h3)kpf(hiGqA&2%js!SpJ%Fs6&jikB9u1=L)Y4C+b(W4I%*mAph5M|DU{|x{cl8IP5QGO@Ku2 zXtk=_=WJCZ(uv)x#=%P#(aVbiB!DS(bn@z=hPi0{rtWrVK|ED2k$N7QEtW>{NC@sN ze#7ZqSyiiM5nXC6P0g>Ao3biTC1GB0lwKqwr6Mu3w_IBzBba4^$y47(2!2ugD>5tS zae$M>Py-`FUWY_6$tQuOLO7aFm#NTC!bVwR-o;%4kUHCA;qM)NeLbz!AHCItihcN(PUl%gTQq=jh`Fi`Cm33l+T>YPqhBmCd z?Z1Gv0(<=-#WHk_B6GT4X+EFa89qM%{g?4IqNLII)SucE(ef$U2q<^M29QbV>WXC& zMOwlp(nB396P1M6@nq^QG|fTD!sr!Cjm;k8mNNt=J-~A$&>Wc)mYlqWU9!h>ZxHN0 zY4o4=pU!+eGI;o;zlQG~vM=!e&V60kb);|2UZ06&O#2q)mi}&slPH3+J1SX>!pBE0 zoyEqUDBb|7r@#^@O{mu>N4G~dqg50)d#7~rb0|yXP!EM!0MG}dL_-U!yPZ-{xH^4w zeZ$Po&h~1iPs-cYZi{tZWSNhzQ@<&cR8 z#DeJto77*@^+W#pttE0=PGb01Q~v%UQ(b9e9uI(wRx+zWTE|?a8=^Wm8EzLj4NOHH zWkytVh^AZ$Md?h^s7%-5E(N_)qS}B70~R4HQ^ZO{eAp;Xq$Qfcz$U4T7v2C(m~Bqm zD)fH0wzGEcFRy5RrUgfMG)1AbYeEnj58;O$5~yTiM20|&gl%HWr`-^m+PXTDfV64x zR2%jW90|{eCbE&=&-j9mNzxvP3B&%dW6Hq|BaS=R4O%=G=Zd`q2XyvD$t=68vmrwE zqn1jMIukIt3h2O2c6?^!YOt_jXkj!+nj025m;Z{_YoY5_$A7$2WDNXsVeS)XS8R6T z=F|_$YgZV(s$rNeU0v~sD(Z(_mkn)&qVR+YT_rxkI|HK8sFCpV+7O>}EgL8!!$*nc zu7e>FTE5t{4@69P+=cixW)kGRqdAsSJx9qny@%A@Ov){bJDxk;VQ3mkRarzuD>Vx} zK;I18JQ#RA!@S**08lvUfP0$&hyn3+z^b7h+-mhN9_I@JE`PuC8KJ!`o;FxFx!D>w z=5=TiV#hEF6+!xiX9 zopZ><_#|n7%=M?ph)^Hvg_tgJu4?RZgm@60@xb*UbthW050y8sqaWx`c$&1&+Qb^fBl>>_d0 zNl^uT<5nyM)zJ@?m=y}9EBCs}EXzHy!SW@3yw1zq?i3F$SgsrcIm$E&7hm3+ZRhZ! 
z(i=M+ZnIaN3q_LHgzRAu+^;KXR#xTCn|J--J?)%#L-~-__ZOM`=GExXbuj`51|euo;4Da5Ch$ht}JGFnPX-9OzjAY|7KF z#pA`X&No=k!W+OnKE@rP@G4_eYwd=un@>Q3p!RD(OIeIfwHOI_&&?HrQHoLdq(-hd z_wvE*!!Ulr(5eD0M-*oLr94c{V+xO<{UzORdA)kn=DV+Q+nKJUA3&gL7uU_H7@>>m zyy?(3U-`?jXijrevkL06oQbkOS#e$yn{IbUV=>%-dGCAm-0)y^d|ijGCM>u1lkQ^z zw~X-k?7m*v_M4Z-1AE}Ld+ltk(?x4`DM7p9Maoga1{@%j zVpA?iH;T)_dvT}oE-kVxJ3$AX^bXwyxRa@uQ^IMZTSL3J#;_l-W1pX!e)pruKYf6< zj6+_oxIz-0Gt~i)S8eR4tbbbs({621fxRBNnKI#+*M^*@4pheNMw~IZnmr&t97EJB zc<{RJlI8ert5pA(DqVX>V?Cf{sU4#L)Oiv6LTIfp*+?o)L6xusR&8l(Va4hh~!7Wtj zL{>Go{RrW_1_%89Ysw~^42dNCifT)U|6er!f2H#OM^m>b@BZJK`WF_BM}$q4gt0CK z2K;gwU&DN4wiDJHAQyx5&J)kV9{%Q^l-E5vaj-exrj*?(t!x&aNyWqb` zAk8{=M*mBS_8)~(8A{g?x&j6pV8!^BhVWY9A9<0d{w3}QBA`JLBHE`FxDr^7E0>h) zxrElE+AiTC^M`|ZLm^mRbBH)at~FD~9{P8s*8i~0o^MuFUAs=XOQljO=}yqi>-OYq z|CvsMuM4-&dgx944{v`JLa3znrJwbF;x$uUvKQ3EKMic+H2JiPbugk+y1#PJrb9q* zBE&_s8ryBLIv01v8;Ae;fEA`xDxDVmZpxWci!Auwq4oXx{tJp%7Cq|Ro7GEw@DS(* z@ARc<<;p>}i=83IFL2lF@dG4YQ)FyU7JDvF9BOye)wy1nhfuyeHhv{NLeN-QO^2;p zm|_JvGjkF!rW9!vehH~3TfP__!mZ{QU*ye*?B^UE7Ay!=Q?a1zqr^c_v-Cv|<8M(p zLnRT~AScv)3)Nx1;@`Z1ku_DO+s77=wlY`KLb3Gf%lYKpvuMATXjpB zRkP;iLW)U}vQ~O#x_h-Dr+}8nqeEoj6B)OygFCiJ0Np&U76p%_4)x_u2iJ$OD(uvL z)-T(2GsUzpvP0I7il$a*d?wSpFuWPHITp9Sp%qVRBqT($B)g>c7CDO+nBh8PB$~C? zB(HxX){KtDUn1H|uPeaStQF8CHcyrl54cMMwA7T+w3Z8?m|EKqln+1g$aKcm-TKSK z3CHrSqdoD(TJaUs(? zH<;m6z`B!xNKPsW-6)@K72!8Ui<;L@aTDB!Ch%VPaR)I==WwOqa4L58p}F4FM6xps zD_>%F>KLYfEL`U}eI>uUm&x6#A}`e+EOVUJw+Cx$Sk9GtvwCeRU%?TO-jFxi=P+GA zg-*%-^(qFv7WW8VN_0(;-KUc-?|3`Tyazi;K48Z=(d4f>lxFs05+su%QP}zAj}Z(8 zedpfFdmUgp;LOJV#Rn#4i(rTZO5?GNcdP>)Qy=)Viw9SNzuoC z3z+Pn!58)Y*E4=q2Y-%72LMo41OPz$-^H?$^XI<&+u3|ZH_1Wk_1k=@ufWt zULD>XrPH~thd5VEj+05vsE-u`2pEd%fdFa15Jq|?{rY^TkwGXYv)~n0Y9;hb%esC) zy&=2d7ndfxs2Lsdoj#uzy#sK{th`7nbatk$s*A=Q6vc=>w9UdK4;s2jMeo4{N$*6- zR`@&nhnXX6jD!iv3GII>bXp7B5)KQ+n4d+v!UwmFDk3@&FZWuEBY^b(s?eMGfP#It zqij0SRkkB^B>)sC8t5b`kBrE&WKmMaFKm?1EK(#Is5`keMS%fXhWS{+Jy%K?xI*q> zy|q3Lg*OlfU)|PRWPg-@fA1*7$nCf-NRbE0ZN7RKGXS zY2!OLqM8A{8=XPbQ-3>uYsA@Wg57jtUT-!EsVikjlzVPmEvNw{b<1mMRejJIru4yC zK7jpoEJ$^=wWFpk(GR+!E~hRb^gq2l9{KS3cEj`Nd3`$DJ3Bt_HN8JcdVk!XhK^^B z$EaVQ1+=xlgg$SNSKt%k`>E7LdD&j=C>qYQj!v#daegm)J!u$q<2GHQ8`!^t4=O z^FdP7!&HB-5E;@}V8@-7RnE$Ys8XIVy4CNrj*`Z-pPm-x9Yi95N}3Sa41A`IpKYI1 z`ee(HcI-gIN;wBvsV%4aJ8RMKN^g^i-gq&%ojpB#$(M)40M1{(fR9m)B(pL2>{919~Vk7qav%&v>mZ1JiF{qcopUi_0)~7QIAh0?XAMbpFM^)!IbdSZ z2|t2E`sR zaQjmUZF81Vk3+)9;v}+WvhkmJq--$fY$P+qV+=DnOihWHT`9f=-k zgJD=7Q0LUUoWB%bClF`_+KW&O$q5H;OThOy`bYqSE}|@#@l%hh=3>y>ar8_o7!B%dBW$;MDOV>l3Fs#Z zeJPAHtYm%6Py5?slUULD zb|oBP4Jag$m>_oX``65Y-C$O&5-v;-tlrmGbTr9QnVlSh;9+PP4*zcT1CuNz0C14( z8GVDryH*f^(kwR0Vw@* z$lL~ap%Dh2H2+P4!^Uz zIUnW!4cZ2>pKV(@K(2+D0-OZJz+6VB6|1L_M5hAch`f&Z_4>2K#5p`sIk7P#LuGb^4)UY~~(nXdpB2 zuk@NUPEVjY)BOODn6IN!U|CnQJ*Zyv=yUfq%HpA&PTZ!S~)YoIGSG6D= zZ{JDr@~_IqHN_3f1;p;p#_F?RN%mA_!2<>Rk$sDtJO1%n`gievtq-ZpXtV% zxVo))vb?*#pwO=PhP3*PS^qt`DAC`7JY0iwG>=*o4kwaluq1#owUL9buG9VapB7R8( z0giuv@^kG%BIp}LMhAD?*YBKN^jJ(K^tCr?q3*agfXXY21;yL1QAqPy)iSMRI$AVK zyu||qj>=gH@K7MaO9ItPr;#;>!F;5Soa=!}s0icM|0O!M+w}C;X{RluFVhLF~tD;~eR-Dq-g1=Y^elqNT z@^y6Rlkb8b1jB%U_d6hnK} zV7}4J8p_q-7HcnDh7a=^bQfa+B?DHkp-C_ma2yTHFr{&|xW;c5M`?oM{+4d55C90y zYt++Jw=FKWUy&Qy34@fDnx3#a0ATIW^Joo>QZu6{*SqHS(HiM!A)Rs zsOHwDISGU64l~+*#VP`CVW1HlY#tNd3e<%qrq%e^Rn-OM%wHVus$t2p(baL8$e?iJ zSkPuqcXO`cbs)2+P;(7Yu{MTi^xhhXn3!AU)?HA5Dc*c7Cg#M#E`R}q4y+>53_En! 
z4wgsDzag1o)k>DkoKLW*nI#P9T2la$JjWdx*is7UY!8|2i@#_t!;r$)@`@oFHd$X9 zG#Cgwr&SX%{?liO3PCOwsvcmQV75)u?dEvrV*7HBUw0NaXl}nU1TeRboC2T}+Zfyd zFLyG%P9vyon93m4K8zd<2Pig+19DD;BNr-V9U!Ste8sNJcMC`zM>wT6bMNb~xYO`9 zEvdW7Bv>29Mkq?=WSDMt(FOpY=CTnjnv)=cADUgx+;vQ}`B3RjQpk^MG}!Jp){~LG3mzbG>dxh_{Qhg3&uDDYJ-DUHKq}4!~r$(JT31FuF zJVl!9#EEjFdX;Aqh%Ez&uFUtVgR9@bguuu5&!f z5c^D(+O64KqnxpIuK{_&Gb62vjdYX(kPuFTc-HMaKkRk6z`Gnc5&H+$6sh+O-M_Sk ztAGmSS60YvK;iG!7>l^jX0FVfvtD>+OrTsvMST7q1NMa%Xh!js865^;ySqhCO@L^$ z`=%3|A)!0HWMUTp&{fJ?sRwiioD$S}ZZOQ?b3|mMzmi}<9Yrdc+r$`KuuTA?ahMzW{guDOEBk6| z+!==rDTIP8IOdPdfUMlk9+VIS451NWRve@gEqh9RsTxu&LwL zZ&uOYB?b6N@e^jW-RqiH98H?520@LchxKg<_DpZOQ@P6$KWthsD(uC41rZ84t3K7M5odgVXxXS4HYJ6O2v+W(9k&dyDKP> z4)cz(Hv20(9lV3^c|$vSEj|>^%{767o;6eEG|cGyW~=yOqBgjY7$uPb34Srj3;Eav zx8jF9h)Y^1Sbl5|8wVTB4Y`9%JTea6`NY|`UO$*heICpnjH5rPemouX;e9o@@?22gci-J2HtbF? zvc#I1=LY$%Ot0L}VZB+dt!RGk_;abURkfB$GO(YJUY2l9s=jpm)eQ)TTbA@WA~W}H zpcD@2iEeN1k}4V9l?`ZUXsJH2$&lk2)b;<+s?4O6zCBeI3e@@>apMz11ahN!9 zH}q)qU0TN{AH$5wB|CMFEB3@~UvtSmYLdkUqGyj50`3C}23{fxjE$=`{$bx^c8}#5 z0GK1Ue`D)AxV4o)2&Rz?^xing-BSs5{%v*~sut6Nj6VSIspc~mGa~$SEI$DyLMRZ3 z7^b$raNt4B9Bwm=yd{Vav}3*(l|Cos@hfv;2E_(sd)%8j0Y5Mf_tsaPw> z9$Ozjuj&=7%J7d2K=d@}-MPIRu)e|M*{q!|*Z>c&th z_=%CNunN*t$lA?CP zl`T4pwyQrP67HX>1s~}N*{;hs8jERD=^F`XaFf<(j|*Bg5<1TZ9)Fp5+U13uRtU;af+CVr4QRAx* zWNn`46e3=iJ89g;PCmT9fsi^I=j3H@-YBMZ=cM+{r3v=YYS7tHPXmr|4?V!qIts^= ziG~E#qo6%MW?_g+e;NdTr-tWVa6S6RA5~6HlsZpgF481v%)RbSi{3n*Wcke+$*cV2 zgyk)Ge{TRwJLvygYev}Y*jt6_h$#~ZHe2O4nSU(lJMl-1%Dal z&0kP3jobSQ%nQ%wT9!EQ=F>wGwPvjFp~LOFJexkFf6+twb*_-HvRAWv7BPvj1wV2# z-7q)4%t16(GR5eJ*<^gv9~g45F6{OnFRm*#+$+JiG3-C9dpMnn@kwGoCp2?Yo&NKJ zgs2wr*zWx zPL_v7;C9~(M?|X?H@dc=H3nJ<>jZ4?6U{>+)K^V7rZWo890J zb*$fP`<|n38cmLbKo!j&hp)y?s+3Bv4E#&8FG^ZxUSs=Hb#i$IkX@u;0p^SjSuRJ3 zalMzy*|F3ffHL+pw&O5rh;G7FwRb~49=q~>|7!_lbjJ~f%*b;kAIKmq7J>Kk=k5lh z^nH*&doa{1Df(-My)@ANsfSR_%Wj>{ymLD=R@b<#S8znLs;cRZqrrE*G_fv6<$gvd;TGs;GdR4^MwT60&Q)lJB@?aPOxUIlZflN&J{Zh?4BxLxw-huw(iTep<3qAs}_`yK;k<_LZ$?TUBX zMd96HJD0<4k2^=%gjv@(v2g}7kZg`J-x#B6z-FAiW72vwA6}^lG#byzz35 z&1D;*xclDTk{g+VL};R?7E@1^Bndg`rB~KBAFso33V7GFk`kH{uKD|r%Q?xRIDVJ1ZlaOpVU{<-QY(u%fpbc)4%lca&0e9@M*n6 z^eHVhmaH6tJ{0FlWS-LFg&$w+cQu%Bq;%8;Cw6lU+8?&|GW#*&4?GyXrbz_N@Sna@-*8?ER zTbWUxMhSJnFuv}<6Q%xX(!(ZX8gkv#E}M!1a^f2$$c*`NbG@1+=Mj`{B~g$eHjuq} z>H}(}gJ1*+)wj_;-dJXC?85#+KHzRMr-uw6K%C@(W6ZsT+Dp1kSwP<&Jv0oCt`FiM zB#H>tNm#P@1G`0*5&mU+(7x&|99?i6r0lxbLdlEbxhN)nM8(XQy`#uaEWF;{DN4*N zF8Byb+Bau9-=eNo1gV#k+^6P9h0!eDA7D^?W1UzXg1BMKw|g^uCAPlLE9fO+w6A=I z`{ADR!}{KX(0}s2&Z$94|Lo!a&o=H_$K%(?Z}`#6GbCD}4uk@UNd9Uqw&CSid(R<1E>en@H62X(Rg4#5 zSNC(jG`DuOJ+j$`p0-w}r_+$P(}>Qe)9dZ?^C z^8sFZltL7X%XP#fCcuFV6Nwpb^pGRiw-afQH`u^W}NegUw_>oJ1GU zr5>0Hdx^oqB|;+}TuJwi=Q~6U^{*E1+HRY`XMu5mejC2?;)lMikEaDEC~Yv(?NWv; zyG6Wkj9C*<^@VE>#hDaad%;!reVSdZj@aI*WHW@$jxlIbK#qG&IDA0B#dVTOHhiGwf+vi2ex zvyt*nTAnTh4dV$;ihVb_sV8#ACpF7+x^Y5h%HSa2)4sT<$II^EcJT~;;*mU=Aj~gO z!mpNCNJ{UjXGR%w*BrRJ6$jeJ%t9n_n9~ZtNQ)#s7D$S#HwEHitS& z;F*gI)t_b$CdAoj5npKx3HJ(Wp#!XTxSYXmU^xpnkl~4}tNh2Ai&@`;!C4|_ju87N zX5iUR0M$3(8_xI=pp=Rg!)QugR!Z%eM2;Me1ncEAda3qkbbX2=1uCuiJ!oF1=q!4y{|GDq*DmoU5*JL&VKV}Uyb32&qeuiPfiZZ!Jp)ET1E4e zS&AtnF;T3&L1@-9Q*GU?hdbcQAjLj#S;FMDFcERo83yHj<|FR0*MW0vEp<-QQExmh zOL(j{@)YBk^%sN@8w69SLy}g$M!W-o0xCfqvZ^-&)zT2m2q|9rUepiuRL39Y^UD#C zvgtk(AD?=XzTn=C#k7PS6Z;6{bb{d#1PuJ3NY1g6sBoDe4<6dGOO*|nu{8N}Nq9zt za8QP(A@%C|_|sPDsA>qLU2#<~;HpAQsD>*@6x#beVPurg(wfm2I=At9x|p%(R7HEm zh+b$B^r6q#)Jn+?l0YkxRTvfdgFIN$>41R}PO{fpB(;SD<>7>Xp@&Scu&95wgE>ix zIyh6qmCPkLyEh3i*N7T+>Zf82D#bcRd|Hz|ATo~X>rM0psjyhm!Vna)jtx@rFWWH;|vzBuPU|wH?Zd0Op`^Y~mjzgLf=(p9?k& 
z4fzh)=z~fi_IW3_Qu9r`Ss#zE;?JyL2sLaw{hN2vPg6-N9kde$V?Q{ySB6G6dvxmVscBIw#|3T~AmsJfrIjH9swq91cN)D~FB+K|>5VnRj6 zR)bhG!lHQWJrPRRd38E&qWesN%xsqUO9k88f$Beb;M~ z&Ut`sK|Z}T#W+ECa zdldXV%xTxNX^)~OkT7T25*0{;T0^>nB4j#8-gAn$Ig;E~l%w!z0++B)>aQ%>V$3JH zkJNO*Y^eb?p{bZiKoy0pR04EVDis=D(I6^16-sk=JNI9u3Xkt%MCG=z!ci3%y2v+D zxoEb18E*mI(66s7IbOa{1()N0Vhg7~PLd<(E&73b13fQO`a|@cg_=Y4{%`SeL3C19 zu&GUMO=^bjO=8$}F)&vmUamc3g_qW!cl`Yl<_{8|fQszh@;SPq$%L7WGQrV3sPPF; zA!CKhF*3j)+`(2fWS&|!_!TYuy*qME=04R2{_H*epRrbxh^n)!;GXFjlTxrry7N^C zI}Bu`+cL})8TH4fOV>I%vE8-b4nqU~JWXFwU3jG3Q=lta5ch+c+?$#RpDIYbAA>0R z+&@cAokb)!aZ8UXv5)sn{5op1CVgpQ4U`pB@|$UEBE|7@M%&-xf1Q*O{ZQtxrY5HyBKV!vrNr*6IZkA}_ zTv%_$_`UZKVHtC_>u-h_RlC3?*PhFpU9TUQ%6Xpx+PwZv#Px`6RXAXVaB=I@SX3}a ze0t%(g#Z)^fD|u)Vd@)ItdCI`_m1f5tV;8zYKES#`TZj zNcIZ1?jO3Fg@Up&D7btqE*)WEhZVI|?#On#Q;mDjJ6%?81Mw@rfm}&8eX1| zxHV>f8m;7@E91J$Ii}t?a*-!%x{J%RL9w(onOprQF5W93Xm8k~tPmhL6I`v60gC5m z1o3kQRR+jjSrD$)W--2Ryzs4rWnqJ1b0?x!cX8r5_fpP*sfJ+Mx`M4IX%md5q{XkK zAvs%x+as5XI;z#gT2ND)A?Y_Yk~oeV%Y=J&5$$jf{@7;s@&>cSmtd`gi9O-^SRGID z{)N=eb?=$OF8Tr^s@;B)SP+5$(lo0!_N#^;Cd~Q^paj_liix#l9fA9+^XC+ov`(I;2SF3_`_cHemC{4Z=5*xZNX zOg>{4L6bqD*k@XUo|1LaR}FYKCeoa&Feu5us74`keUktUfcIz6j7s;|BQ=X(<+608 z?TOJ}bFv}F*N9}gGB@b-(cU0#bfjU4u|+i+?#Go@C*PJ@rQc`>Tc;vT^gj6ubPjQ@ zb2H0%#zaCpcwTz2*q*>&!-Ku#IPTyx>5b(h_Qul!=eaHPa>eabD^)0|ICR01Q5lbJ z*v_7jph?K8!r4o)2u?SPefel^yYeLzTWqRu$EuMUxv#echyfV_dbJJ}XZ|q$;Of;K zXjLz^n;{ENEPJlh0Snjc6>i0WMtEYx;ipS?$sW0r;q(Q9@iql@>PxkGE|>9;s%PzM z9i^3QNl&13SFp4%bJ_wM6B1R*SQc?HLRR&+dvVg~51PPYWRBU#3+xtqNGmY`|?7 zcmqyuQ&Fy4W9ywo)18qSD0Tds(kvnyw|XrA;P>oQJ;mcFZSI&lSLa#4{m7qa?3ub= zE(amL{>jvNW!mqM)~qovUKHbuYfuS|vh=zLH=5+?I*|Zu_%wQ$ z*>h5`BEVc0A`i_|<1+%vXoi?sLRz`ZLf_a9;%g&TM|^1!0o+(I_!YCd8B(3WA&bK%$>kk0Ko3?)H77AC?OfpPBsxF$g(ia z3hKm0?{OLThl_pMl!7ov(PLrV{h}22^thwIg(mhGFs677thuA9G)ZrbuN1Wc5D}TZ zc*fjF#hnqbjw;FghW-l8Z9DmVGq*F!zwU{q9i84hnk-{NF_yppnkZiEB=Jo&Q=eg} zNPNQwRYRuVLehTME`_e~sUL47q<2nelNZcF+*s zTTe7jo|PJ7VujW&ksChwsK(4 zg?pnDYLK1*z0l|VO!uVwT*H1?A*!U9hU0IS1*8LhvX^8aN@fIIr?l`exc2q_{3$m> zRGF%)2qP058m)5k(1ft1Kj`UkpZ4~|n|VptjfEdD$$s|*2QB@P0>LX4!1?$u*O8te zf~EBrya0*-06_V_yN)LAzjwnHHYT>t2G;aOb~g6bChmH!2G%Yn#{Z!)sO~0h{HAPQ z^csY8^Pe(E;B(GhYMF0HZ7h)JfQv0kmjDf>MAfuhN#FdxpSDwsgXsA^nfF2CpU2v| zc5f~)w^W~WBfO3Fjvv$L_fSutxCxci?~IA73MWC=?o^Wp?^pucDfC4bu}rm{p8UfG zn)WayNwh6<1&Nhv`>i>KS*7l#c7yM@)ijlnbVcV3#DReQSTVl*(VkU+NjHP4VOTPY zAPWeQLBV^a+->;9#VZ@>K$0NCY8bFDN3jJk?k2LzS3?U}gx;`zOvW7hqd}$fVi=#E z+`G4JiA@JYQzpkkke3hJk~Y_t-Y%&@tW|i&1Fc{qf0iIZ3Vw3|mcvQ)tc5h|XjGsy zYB;yjBJORH7=(`q&M6tRN;oyeg4m?1G3TUJB4Ld9(E6R}!2&#C6W#!PBaah2!r1ua zQF^z$GIz)X5nRddo#xGR3Nbxg6!UE)Pr#Zpd|~b za3|R>0LCf6#a6lD(j*vZ)7Ys_<5l3bB%#2pY9;f^(lK|6&58&D=?&z0pdz-my%*#d zFko%NpuJt5`XWLZm9U1dYlh*0dxxP?#fSwu7BH(csAeNb3v?nAj)WN-b>)&;BbJj5 zYo}Dfv?6g-Y@x#)kd~$D6R?iyf}y3d83VHjNH)vTdyEg!-ug3l-ZP8|=3!IG4EZIA z-TQ9!j-6Gh>B0HD!u$~9LK7-Cl1mxvkrFYkQpsd$0`MVX7oSZq1w@DSX5@=ci>cQ` zLC7*&T!>qlcK;V+?-V3T)NSdOZQHiJ%eHOXwr%XPZQHhY*|zPjzazSDpN<k$HOCtB8>>MNm3Pri{`ufhE^+PsWy-{~@N9ccUxcKS0FS(uZtj=R8nT}wtIHysZy-16R9RIqqXg*pITkZ z2Z1Daf=QELsUyDIcT);$xkN z#eOYS5yO~d?Z47Ec*iPaUyZDiz``Y>@A!#D!+PU!smUxopt%kII6!HvL!tfdq80Bo{f~OubCKB&h|^uy|D$ z9v(q-1rUHJhW3O@l$rV@bHLL`XYW!+cs>BNr=^Y93M;O|^dcKvV1ilQ6I)8D{hFk7 z(4Xv~hMYIB5?{yhnxtZW2&+UvZtrA%%Rg!lUHxiv;(4d zEIAnh41d}@)Km4!5(%-Hxjx5i)z+z=&g+E3y{jb$kG{)(yZ?NuLi&N;@~7&D>pXMOb<8P$a}U! zO>*7O|J5t%*<_n?PICs}0! 
From 296094af24909c9c03c4a3701578b5689ef4bcea Mon Sep 17 00:00:00 2001 From: wglao Date: Wed, 25 Jan 2023 11:16:50 -0600 Subject: [PATCH 33/35] correct version name --- optax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/__init__.py b/optax/__init__.py index 54062529..a76aede1 100644 --- a/optax/__init__.py +++ b/optax/__init__.py @@ -180,7 +180,7 @@ from optax._src.wrappers import skip_large_updates from optax._src.wrappers import skip_not_finite -__version__ = "0.1.5.dev" +__version__ = "0.1.5.dev0" __all__ = ( "adabelief", From fe000572327220974bcaf063889c97f25474d273 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 25 Apr 2023 12:58:09 -0500 Subject: [PATCH 34/35] limit the injectable hyperparams --- optax/_src/alias.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index 682c9af1..369f4c54 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -458,10 +458,7 @@ def eve( updates, opt_state = optimizer.update(grads, opt_state) params = optax.apply_updates(params, updates) """ - return schedule.inject_hyperparams(_eve)( - a1=a1, b1=b1, b2=b2, b3=b3, c=c, eps=eps, - f=f, f_star=f_star, mu_dtype=mu_dtype - ) + return schedule.inject_hyperparams(_eve)(f=f) def fromage( From aba8dd4dea624beb10127d0ae19afea103faa161 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 25 Apr 2023 13:14:48 -0500 Subject: [PATCH 35/35] bug fix with None --- optax/_src/schedule.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/optax/_src/schedule.py b/optax/_src/schedule.py index dd780e6a..3885a391 100644 --- a/optax/_src/schedule.py +++ b/optax/_src/schedule.py @@ -576,8 +576,9 @@ def wrapped_transform(*args, **kwargs) -> base.GradientTransformation: other_hps[name] = value elif callable(value): sched_hps[name] = value - elif isinstance(value, (int, float, chex.Array)): - numeric_hps[name] = value + elif value is not None: + if isinstance(value, (int, float, chex.Array)): + numeric_hps[name] = value else: other_hps[name] = value
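Taken together, the last two patches leave the current loss `f` as the only hyperparameter routed through `inject_hyperparams`, so it can be overwritten in the optimizer state at every step. Below is a minimal usage sketch of that pattern, mirroring the docstring fragment shown in the hunk above. It assumes the `eve` alias is exported at the package level as `optax.eve`, that `f` has a numeric default so it lands in the injectable hyperparameters dict, and that the model, batch, and names used here are placeholders rather than anything defined in this series.

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
  pred = batch['x'] @ params['w']
  return jnp.mean((pred - batch['y']) ** 2)

params = {'w': jnp.zeros((3, 1))}
optimizer = optax.eve(learning_rate=1e-3)
opt_state = optimizer.init(params)

batch = {'x': jnp.ones((8, 3)), 'y': jnp.ones((8, 1))}
loss, grads = jax.value_and_grad(loss_fn)(params, batch)

# Overwrite the injected `f` hyperparameter with the current loss before the
# update so the transform can refresh its sub-optimality estimate next step.
opt_state.hyperparams['f'] = loss
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)

The same in-place mutation of `opt_state.hyperparams['f']` would be repeated inside a training loop; only the loss changes between steps, while the remaining Eve settings stay fixed at the values passed to `eve(...)`.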