Move second order utilities to a dedicated sub-package.
This is part of a broader effort to restructure the current, fully flat API surface: optax has outgrown the flat structure as its scope and complexity have grown.

PiperOrigin-RevId: 570042410
mtthss authored and OptaxDev committed Oct 3, 2023
1 parent e43f9f5 commit ab8e7a4
Showing 5 changed files with 172 additions and 48 deletions.
7 changes: 1 addition & 6 deletions optax/__init__.py
@@ -16,6 +16,7 @@

from optax import contrib
from optax import experimental
from optax import second_order
from optax._src.alias import adabelief
from optax._src.alias import adafactor
from optax._src.alias import adagrad
@@ -115,9 +116,6 @@
from optax._src.schedule import sgdr_schedule
from optax._src.schedule import warmup_cosine_decay_schedule
from optax._src.schedule import warmup_exponential_decay_schedule
from optax._src.second_order import fisher_diag
from optax._src.second_order import hessian_diag
from optax._src.second_order import hvp
from optax._src.state_utils import tree_map_params
from optax._src.stochastic_gradient_estimators import measure_valued_jacobians
from optax._src.stochastic_gradient_estimators import pathwise_jacobians
@@ -234,16 +232,13 @@
"EmptyState",
"exponential_decay",
"FactoredState",
"fisher_diag",
"flatten",
"fromage",
"global_norm",
"GradientTransformation",
"GradientTransformationExtraArgs",
"hinge_loss",
"hessian_diag",
"huber_loss",
"hvp",
"identity",
"incremental_update",
"inject_hyperparams",
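For readers of this change: the diff above drops `fisher_diag`, `hessian_diag`, and `hvp` from the flat `optax` namespace and re-exports them through the new `optax.second_order` sub-package. A minimal sketch of the resulting call-site change, using a hypothetical toy quadratic loss that is not part of this commit:

import jax.numpy as jnp
import optax

def toy_loss(params, inputs, targets):  # hypothetical loss, for illustration only
  return jnp.mean((inputs @ params - targets) ** 2)

params, inputs, targets = jnp.zeros(3), jnp.ones((5, 3)), jnp.ones(5)

# Before this commit (flat namespace):
#   diag = optax.hessian_diag(toy_loss, params, inputs, targets)
# After this commit (dedicated sub-package):
diag = optax.second_order.hessian_diag(toy_loss, params, inputs, targets)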
19 changes: 19 additions & 0 deletions optax/second_order/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optax: composable gradient processing and optimization, in JAX."""

from optax.second_order.fisher import fisher_diag
from optax.second_order.hessian import hessian_diag
from optax.second_order.hessian import hvp
55 changes: 55 additions & 0 deletions optax/second_order/fisher.py
@@ -0,0 +1,55 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for computing diagonals of the fisher information matrix.
Computing the Fisher matrix for neural networks is typically intractible due to
the quadratic memory requirements. Solving for the diagonal can be done cheaply,
with sub-quadratic memory requirements.
"""

from typing import Any, Callable

import jax
import jax.numpy as jnp

# TODO(b/160876114): use the pytypes defined in Chex.
Array = jnp.ndarray


def _ravel(p: Any) -> Array:
return jax.flatten_util.ravel_pytree(p)[0]


def fisher_diag(
negative_log_likelihood: Callable[[Any, Array, Array], Array],
params: Any,
inputs: jnp.ndarray,
targets: jnp.ndarray,
) -> jax.Array:
"""Computes the diagonal of the (observed) Fisher information matrix.
Args:
negative_log_likelihood: the negative log likelihood function with
expected signature `loss = fn(params, inputs, targets)`.
params: model parameters.
inputs: inputs at which `negative_log_likelihood` is evaluated.
targets: targets at which `negative_log_likelihood` is evaluated.
Returns:
    An Array corresponding to the diagonal of the (observed) Fisher information
    matrix of `negative_log_likelihood` evaluated at `(params, inputs, targets)`.
"""
return jnp.square(
_ravel(jax.grad(negative_log_likelihood)(params, inputs, targets)))
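A minimal usage sketch of `fisher_diag` as exposed by the new sub-package; the logistic-regression style `nll` below is a hypothetical example, not part of this commit. For a flat parameter vector, the returned diagonal is simply the element-wise squared gradient of the negative log likelihood, which is what the implementation above computes:

import jax
import jax.numpy as jnp
from optax import second_order

def nll(params, inputs, targets):  # hypothetical negative log likelihood
  logits = inputs @ params
  return -jnp.mean(targets * jax.nn.log_sigmoid(logits)
                   + (1.0 - targets) * jax.nn.log_sigmoid(-logits))

params = jnp.array([0.1, -0.2, 0.3])
inputs = jnp.ones((4, 3))
targets = jnp.array([0.0, 1.0, 1.0, 0.0])

diag = second_order.fisher_diag(nll, params, inputs, targets)
# With a flat parameter vector this equals the squared gradient of the NLL.
assert jnp.allclose(diag, jax.grad(nll)(params, inputs, targets) ** 2)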
53 changes: 11 additions & 42 deletions optax/_src/second_order.py → optax/second_order/hessian.py
@@ -12,37 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for computing diagonals of Hessians & Fisher info of parameters.
Computing the Hessian or Fisher information matrices for neural networks is
typically intractable due to the quadratic memory requirements. Solving for the
diagonals of these matrices is often a better solution.
This module provides two functions for computing these diagonals, `hessian_diag`
and `fisher_diag`, each with sub-quadratic memory requirements.
"""Functions for computing diagonals of the Hessian wrt to a set of parameters.
Computing the Hessian for neural networks is typically intractible due to the
quadratic memory requirements. Solving for the diagonal can be done cheaply,
with sub-quadratic memory requirements.
"""

from typing import Any, Callable

import jax
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp


# This covers both Jax and Numpy arrays.
# TODO(b/160876114): use the pytypes defined in Chex.
Array = jnp.ndarray
# LossFun of type f(params, inputs, targets).
LossFun = Callable[[Any, Array, Array], Array]


def ravel(p: Any) -> Array:
return ravel_pytree(p)[0]
def _ravel(p: Any) -> Array:
return jax.flatten_util.ravel_pytree(p)[0]


def hvp(
loss: LossFun,
loss: Callable[[Any, Array, Array], Array], # loss(params, inputs, targets)
v: jax.Array,
params: Any,
inputs: jax.Array,
@@ -61,13 +52,13 @@ def hvp(
An Array corresponding to the product of `v` and the Hessian of `loss`
evaluated at `(params, inputs, targets)`.
"""
_, unravel_fn = ravel_pytree(params)
_, unravel_fn = jax.flatten_util.ravel_pytree(params)
loss_fn = lambda p: loss(p, inputs, targets)
return jax.jvp(jax.grad(loss_fn), [params], [unravel_fn(v)])[1]


def hessian_diag(
loss: LossFun,
loss: Callable[[Any, Array, Array], Array],
params: Any,
inputs: jax.Array,
targets: jax.Array,
@@ -84,28 +75,6 @@ def hessian_diag(
    A DeviceArray corresponding to the diagonal of the Hessian of `loss`
    evaluated at `(params, inputs, targets)`.
"""
vs = jnp.eye(ravel(params).size)
comp = lambda v: jnp.vdot(v, ravel(hvp(loss, v, params, inputs, targets)))
vs = jnp.eye(_ravel(params).size)
comp = lambda v: jnp.vdot(v, _ravel(hvp(loss, v, params, inputs, targets)))
return jax.vmap(comp)(vs)


def fisher_diag(
negative_log_likelihood: LossFun,
params: Any,
inputs: jnp.ndarray,
targets: jnp.ndarray,
) -> jax.Array:
"""Computes the diagonal of the (observed) Fisher information matrix.
Args:
negative_log_likelihood: the negative log likelihood function.
params: model parameters.
inputs: inputs at which `negative_log_likelihood` is evaluated.
targets: targets at which `negative_log_likelihood` is evaluated.
Returns:
    An Array corresponding to the diagonal of the (observed) Fisher information
    matrix of `negative_log_likelihood` evaluated at `(params, inputs, targets)`.
"""
return jnp.square(
ravel(jax.grad(negative_log_likelihood)(params, inputs, targets)))
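To make the relationship between `hvp` and `hessian_diag` concrete: `hessian_diag` vmaps `hvp` over the standard basis vectors, reading off e_i^T (H e_i) = H_ii for each coordinate. A small self-contained sanity check against `jax.hessian`, using a hypothetical quadratic loss that is not part of this commit:

import jax
import jax.numpy as jnp
from optax import second_order

def loss(params, inputs, targets):  # hypothetical quadratic loss
  return jnp.mean((inputs @ params - targets) ** 2)

params = jnp.array([0.5, -1.0])
inputs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
targets = jnp.array([1.0, 0.0])

# Hessian-vector product against a direction of the same size as the
# flattened parameters.
hv = second_order.hvp(loss, jnp.array([1.0, 0.0]), params, inputs, targets)

# The diagonal stacks e_i^T (H e_i) over all standard basis vectors e_i.
diag = second_order.hessian_diag(loss, params, inputs, targets)
assert jnp.allclose(diag, jnp.diag(jax.hessian(loss)(params, inputs, targets)),
                    atol=1e-5)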
86 changes: 86 additions & 0 deletions optax/second_order/hessian_test.py
@@ -0,0 +1,86 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `hessian.py`."""

import collections
import functools
import itertools

from absl.testing import absltest

import chex
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

from optax import second_order


NUM_CLASSES = 2
NUM_SAMPLES = 3
NUM_FEATURES = 4


class HessianTest(chex.TestCase):

def setUp(self):
super().setUp()

self.data = np.random.rand(NUM_SAMPLES, NUM_FEATURES)
self.labels = np.random.randint(NUM_CLASSES, size=NUM_SAMPLES)

def net_fn(z):
mlp = hk.Sequential(
[hk.Linear(10), jax.nn.relu, hk.Linear(NUM_CLASSES)], name='mlp')
return jax.nn.log_softmax(mlp(z))

net = hk.without_apply_rng(hk.transform(net_fn))
self.parameters = net.init(jax.random.PRNGKey(0), self.data)

def loss(params, inputs, targets):
log_probs = net.apply(params, inputs)
return -jnp.mean(hk.one_hot(targets, NUM_CLASSES) * log_probs)

self.loss_fn = loss

def jax_hessian_diag(loss_fun, params, inputs, targets):
"""This is the 'ground-truth' obtained via the JAX library."""
hess = jax.hessian(loss_fun)(params, inputs, targets)

# Extracts the diagonal components.
hess_diag = collections.defaultdict(dict)
for k0, k1 in itertools.product(params.keys(), ['w', 'b']):
params_shape = params[k0][k1].shape
n_params = np.prod(params_shape)
hess_diag[k0][k1] = jnp.diag(hess[k0][k1][k0][k1].reshape(
n_params, n_params)).reshape(params_shape)
for k, v in hess_diag.items():
hess_diag[k] = v
return jax.flatten_util.ravel_pytree(hess_diag)[0]

self.hessian = jax_hessian_diag(
self.loss_fn, self.parameters, self.data, self.labels)

@chex.all_variants
def test_hessian_diag(self):
hessian_diag_fn = self.variant(
functools.partial(second_order.hessian_diag, self.loss_fn))
actual = hessian_diag_fn(self.parameters, self.data, self.labels)
np.testing.assert_array_almost_equal(self.hessian, actual, 5)


if __name__ == '__main__':
absltest.main()
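Since the test module ends with `absltest.main()`, it can be run directly, assuming `jax`, `chex`, `haiku`, and `numpy` are available in the environment (command shown as an assumption, not part of the commit):

python optax/second_order/hessian_test.py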
