diff --git a/optax/__init__.py b/optax/__init__.py
index 6e7ff09d6..c4f13b89c 100644
--- a/optax/__init__.py
+++ b/optax/__init__.py
@@ -16,6 +16,7 @@
 from optax import contrib
 from optax import experimental
+from optax import second_order
 from optax._src.alias import adabelief
 from optax._src.alias import adafactor
 from optax._src.alias import adagrad
@@ -115,9 +116,6 @@
 from optax._src.schedule import sgdr_schedule
 from optax._src.schedule import warmup_cosine_decay_schedule
 from optax._src.schedule import warmup_exponential_decay_schedule
-from optax._src.second_order import fisher_diag
-from optax._src.second_order import hessian_diag
-from optax._src.second_order import hvp
 from optax._src.state_utils import tree_map_params
 from optax._src.stochastic_gradient_estimators import measure_valued_jacobians
 from optax._src.stochastic_gradient_estimators import pathwise_jacobians
@@ -234,16 +232,13 @@
     "EmptyState",
     "exponential_decay",
     "FactoredState",
-    "fisher_diag",
     "flatten",
     "fromage",
     "global_norm",
     "GradientTransformation",
     "GradientTransformationExtraArgs",
     "hinge_loss",
-    "hessian_diag",
     "huber_loss",
-    "hvp",
     "identity",
     "incremental_update",
     "inject_hyperparams",
diff --git a/optax/second_order/__init__.py b/optax/second_order/__init__.py
new file mode 100644
index 000000000..eeca2c216
--- /dev/null
+++ b/optax/second_order/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The second order optimization sub-package."""
+
+from optax.second_order.fisher import fisher_diag
+from optax.second_order.hessian import hessian_diag
+from optax.second_order.hessian import hvp
diff --git a/optax/second_order/fisher.py b/optax/second_order/fisher.py
new file mode 100644
index 000000000..6e78a1d97
--- /dev/null
+++ b/optax/second_order/fisher.py
@@ -0,0 +1,55 @@
+# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions for computing diagonals of the Fisher information matrix.
+
+Computing the Fisher matrix for neural networks is typically intractable due to
+the quadratic memory requirements. Solving for the diagonal can be done cheaply,
+with sub-quadratic memory requirements.
+"""
+
+from typing import Any, Callable
+
+import jax
+import jax.numpy as jnp
+
+# TODO(b/160876114): use the pytypes defined in Chex.
+Array = jnp.ndarray
+
+
+def _ravel(p: Any) -> Array:
+  return jax.flatten_util.ravel_pytree(p)[0]
+
+
+def fisher_diag(
+    negative_log_likelihood: Callable[[Any, Array, Array], Array],
+    params: Any,
+    inputs: jnp.ndarray,
+    targets: jnp.ndarray,
+) -> jax.Array:
+  """Computes the diagonal of the (observed) Fisher information matrix.
+
+  Args:
+    negative_log_likelihood: the negative log likelihood function with
+      expected signature `loss = fn(params, inputs, targets)`.
+    params: model parameters.
+    inputs: inputs at which `negative_log_likelihood` is evaluated.
+    targets: targets at which `negative_log_likelihood` is evaluated.
+
+  Returns:
+    An Array with the diagonal of the (observed) Fisher information matrix
+    of `negative_log_likelihood` evaluated at `(params, inputs, targets)`.
+  """
+  return jnp.square(
+      _ravel(jax.grad(negative_log_likelihood)(params, inputs, targets)))
diff --git a/optax/_src/second_order.py b/optax/second_order/hessian.py
similarity index 55%
rename from optax/_src/second_order.py
rename to optax/second_order/hessian.py
index ea619ba73..1ca152f7e 100644
--- a/optax/_src/second_order.py
+++ b/optax/second_order/hessian.py
@@ -12,37 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Functions for computing diagonals of Hessians & Fisher info of parameters.
-
-Computing the Hessian or Fisher information matrices for neural networks is
-typically intractible due to the quadratic memory requirements. Solving for the
-diagonals of these matrices is often a better solution.
-
-This module provides two functions for computing these diagonals, `hessian_diag`
-and `fisher_diag`., each with sub-quadratic memory requirements.
+"""Functions for computing diagonals of the Hessian wrt a set of parameters.
+Computing the Hessian for neural networks is typically intractable due to the
+quadratic memory requirements. Solving for the diagonal can be done cheaply,
+with sub-quadratic memory requirements.
 """
 
 from typing import Any, Callable
 
 import jax
-from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
-
-# This covers both Jax and Numpy arrays.
 # TODO(b/160876114): use the pytypes defined in Chex.
 Array = jnp.ndarray
-# LossFun of type f(params, inputs, targets).
-LossFun = Callable[[Any, Array, Array], Array]
 
 
-def ravel(p: Any) -> Array:
-  return ravel_pytree(p)[0]
+def _ravel(p: Any) -> Array:
+  return jax.flatten_util.ravel_pytree(p)[0]
 
 
 def hvp(
-    loss: LossFun,
+    loss: Callable[[Any, Array, Array], Array],  # loss(params, inputs, targets)
     v: jax.Array,
     params: Any,
     inputs: jax.Array,
@@ -61,13 +52,13 @@
     An Array corresponding to the product of `v` and the Hessian of `loss`
     evaluated at `(params, inputs, targets)`.
   """
-  _, unravel_fn = ravel_pytree(params)
+  _, unravel_fn = jax.flatten_util.ravel_pytree(params)
   loss_fn = lambda p: loss(p, inputs, targets)
   return jax.jvp(jax.grad(loss_fn), [params], [unravel_fn(v)])[1]
 
 
 def hessian_diag(
-    loss: LossFun,
+    loss: Callable[[Any, Array, Array], Array],
     params: Any,
     inputs: jax.Array,
     targets: jax.Array,
@@ -84,28 +75,6 @@
-    A DeviceArray corresponding to the product to the Hessian of
-    `loss` evaluated at `(params, inputs, targets)`.
+    An Array corresponding to the diagonal of the Hessian of `loss`
+    evaluated at `(params, inputs, targets)`.
""" - vs = jnp.eye(ravel(params).size) - comp = lambda v: jnp.vdot(v, ravel(hvp(loss, v, params, inputs, targets))) + vs = jnp.eye(_ravel(params).size) + comp = lambda v: jnp.vdot(v, _ravel(hvp(loss, v, params, inputs, targets))) return jax.vmap(comp)(vs) - - -def fisher_diag( - negative_log_likelihood: LossFun, - params: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray, -) -> jax.Array: - """Computes the diagonal of the (observed) Fisher information matrix. - - Args: - negative_log_likelihood: the negative log likelihood function. - params: model parameters. - inputs: inputs at which `negative_log_likelihood` is evaluated. - targets: targets at which `negative_log_likelihood` is evaluated. - - Returns: - An Array corresponding to the product to the Hessian of - `negative_log_likelihood` evaluated at `(params, inputs, targets)`. - """ - return jnp.square( - ravel(jax.grad(negative_log_likelihood)(params, inputs, targets))) diff --git a/optax/second_order/hessian_test.py b/optax/second_order/hessian_test.py new file mode 100644 index 000000000..29a7b3b5a --- /dev/null +++ b/optax/second_order/hessian_test.py @@ -0,0 +1,86 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `hessian.py`.""" + +import collections +import functools +import itertools + +from absl.testing import absltest + +import chex +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + +from optax import second_order + + +NUM_CLASSES = 2 +NUM_SAMPLES = 3 +NUM_FEATURES = 4 + + +class HessianTest(chex.TestCase): + + def setUp(self): + super().setUp() + + self.data = np.random.rand(NUM_SAMPLES, NUM_FEATURES) + self.labels = np.random.randint(NUM_CLASSES, size=NUM_SAMPLES) + + def net_fn(z): + mlp = hk.Sequential( + [hk.Linear(10), jax.nn.relu, hk.Linear(NUM_CLASSES)], name='mlp') + return jax.nn.log_softmax(mlp(z)) + + net = hk.without_apply_rng(hk.transform(net_fn)) + self.parameters = net.init(jax.random.PRNGKey(0), self.data) + + def loss(params, inputs, targets): + log_probs = net.apply(params, inputs) + return -jnp.mean(hk.one_hot(targets, NUM_CLASSES) * log_probs) + + self.loss_fn = loss + + def jax_hessian_diag(loss_fun, params, inputs, targets): + """This is the 'ground-truth' obtained via the JAX library.""" + hess = jax.hessian(loss_fun)(params, inputs, targets) + + # Extracts the diagonal components. 
+ hess_diag = collections.defaultdict(dict) + for k0, k1 in itertools.product(params.keys(), ['w', 'b']): + params_shape = params[k0][k1].shape + n_params = np.prod(params_shape) + hess_diag[k0][k1] = jnp.diag(hess[k0][k1][k0][k1].reshape( + n_params, n_params)).reshape(params_shape) + for k, v in hess_diag.items(): + hess_diag[k] = v + return jax.flatten_util.ravel_pytree(hess_diag)[0] + + self.hessian = jax_hessian_diag( + self.loss_fn, self.parameters, self.data, self.labels) + + @chex.all_variants + def test_hessian_diag(self): + hessian_diag_fn = self.variant( + functools.partial(second_order.hessian_diag, self.loss_fn)) + actual = hessian_diag_fn(self.parameters, self.data, self.labels) + np.testing.assert_array_almost_equal(self.hessian, actual, 5) + + +if __name__ == '__main__': + absltest.main()
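With this change, `hvp`, `hessian_diag` and `fisher_diag` are imported from the `optax.second_order` subpackage rather than the top-level `optax` namespace. A minimal usage sketch under that assumption follows; the toy `loss`, `params`, `inputs`, `targets` and `v` below are illustrative placeholders and not part of the repository.

import jax.numpy as jnp
from optax import second_order

# Toy quadratic model; `loss(params, inputs, targets)` matches the signature
# expected by `hvp`, `hessian_diag` and `fisher_diag`.
def loss(params, inputs, targets):
  preds = inputs @ params['w'] + params['b']
  return jnp.mean((preds - targets) ** 2)

params = {'w': jnp.ones((4, 1)), 'b': jnp.zeros((1,))}
inputs = jnp.ones((3, 4))
targets = jnp.zeros((3, 1))

# Diagonal of the Hessian of `loss`, one entry per flattened parameter.
h_diag = second_order.hessian_diag(loss, params, inputs, targets)

# Hessian-vector product: `v` is a flat vector with as many entries as there
# are parameters; the result is a pytree with the same structure as `params`.
v = jnp.ones(h_diag.size)
hv = second_order.hvp(loss, v, params, inputs, targets)

# Diagonal of the (observed) Fisher information, i.e. the squared gradient.
f_diag = second_order.fisher_diag(loss, params, inputs, targets)

As in `fisher.py` above, `fisher_diag` squares the flattened gradient, so its output is only meaningful when the supplied loss plays the role of a negative log-likelihood.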