Move second order utilities to dedicated sub-package. #588

Merged 1 commit on Oct 3, 2023
7 changes: 1 addition & 6 deletions optax/__init__.py
@@ -16,6 +16,7 @@

from optax import contrib
from optax import experimental
from optax import second_order
from optax._src.alias import adabelief
from optax._src.alias import adafactor
from optax._src.alias import adagrad
@@ -115,9 +116,6 @@
from optax._src.schedule import sgdr_schedule
from optax._src.schedule import warmup_cosine_decay_schedule
from optax._src.schedule import warmup_exponential_decay_schedule
from optax._src.second_order import fisher_diag
from optax._src.second_order import hessian_diag
from optax._src.second_order import hvp
from optax._src.state_utils import tree_map_params
from optax._src.stochastic_gradient_estimators import measure_valued_jacobians
from optax._src.stochastic_gradient_estimators import pathwise_jacobians
@@ -234,16 +232,13 @@
"EmptyState",
"exponential_decay",
"FactoredState",
"fisher_diag",
"flatten",
"fromage",
"global_norm",
"GradientTransformation",
"GradientTransformationExtraArgs",
"hinge_loss",
"hessian_diag",
"huber_loss",
"hvp",
"identity",
"incremental_update",
"inject_hyperparams",
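For downstream users, the visible effect of this diff is that `fisher_diag`, `hessian_diag` and `hvp` are no longer re-exported at the top level and are reached through `optax.second_order` instead. A minimal migration sketch (the toy loss, parameters and data below are made up for illustration, and it assumes an optax build that contains this change):

```python
import jax.numpy as jnp
import optax


def loss_fn(params, inputs, targets):
  # Toy linear model, only here to give the helpers something to differentiate.
  preds = inputs @ params['w'] + params['b']
  return jnp.mean((preds - targets) ** 2)


params = {'w': jnp.ones((3,)), 'b': jnp.zeros(())}
inputs = jnp.ones((5, 3))
targets = jnp.zeros((5,))

# Before this change: optax.hessian_diag(loss_fn, params, inputs, targets)
# After this change:
diag = optax.second_order.hessian_diag(loss_fn, params, inputs, targets)
print(diag.shape)  # (4,): one entry per flattened parameter
```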
19 changes: 19 additions & 0 deletions optax/second_order/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The second order optimisation sub-package."""

from optax.second_order.fisher import fisher_diag
from optax.second_order.hessian import hessian_diag
from optax.second_order.hessian import hvp
30 changes: 30 additions & 0 deletions optax/second_order/base.py
@@ -0,0 +1,30 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base types for the second order sub-package."""

import abc
from typing import Any, Protocol

import jax


class LossFn(Protocol):
"""A loss function to be optimized."""

@abc.abstractmethod
def __call__(
self, params: Any, inputs: jax.Array, targets: jax.Array
) -> jax.Array:
...
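Since `LossFn` is a `typing.Protocol`, conformance is structural: any callable taking `(params, inputs, targets)` and returning a scalar array type-checks against it, with no subclassing required. A hypothetical example (the linear model below is not part of the PR):

```python
import jax
import jax.numpy as jnp


def mean_squared_error(
    params, inputs: jax.Array, targets: jax.Array
) -> jax.Array:
  # A plain function with the right signature satisfies the protocol.
  preds = inputs @ params['w'] + params['b']
  return jnp.mean((preds - targets) ** 2)


# Usable anywhere a `base.LossFn` is expected, e.g. by `hvp` or `hessian_diag`.
value = mean_squared_error(
    {'w': jnp.ones((2,)), 'b': jnp.zeros(())},
    jnp.ones((4, 2)),
    jnp.zeros((4,)),
)
```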
55 changes: 55 additions & 0 deletions optax/second_order/fisher.py
@@ -0,0 +1,55 @@
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for computing diagonals of the fisher information matrix.

Computing the Fisher matrix for neural networks is typically intractible due to
the quadratic memory requirements. Solving for the diagonal can be done cheaply,
with sub-quadratic memory requirements.
"""

from typing import Any

import jax
from jax import flatten_util
import jax.numpy as jnp

from optax.second_order import base


def _ravel(p: Any) -> jax.Array:
return flatten_util.ravel_pytree(p)[0]


def fisher_diag(
negative_log_likelihood: base.LossFn,
params: Any,
inputs: jax.Array,
targets: jax.Array,
) -> jax.Array:
"""Computes the diagonal of the (observed) Fisher information matrix.

Args:
negative_log_likelihood: the negative log likelihood function with
expected signature `loss = fn(params, inputs, targets)`.
params: model parameters.
inputs: inputs at which `negative_log_likelihood` is evaluated.
targets: targets at which `negative_log_likelihood` is evaluated.

Returns:
An Array corresponding to the diagonal of the (observed) Fisher information
matrix of `negative_log_likelihood`, evaluated at `(params, inputs, targets)`.
"""
return jnp.square(
_ravel(jax.grad(negative_log_likelihood)(params, inputs, targets)))
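As a usage sketch (not part of the diff), `fisher_diag` can be called with any negative log likelihood matching the `LossFn` signature; the logistic model and synthetic data here are invented for illustration:

```python
import jax
import jax.numpy as jnp
from optax import second_order


def nll(params, inputs, targets):
  # Bernoulli negative log likelihood of a linear-logit model.
  logits = inputs @ params['w'] + params['b']
  return -jnp.mean(
      targets * jax.nn.log_sigmoid(logits)
      + (1.0 - targets) * jax.nn.log_sigmoid(-logits))


params = {'w': jnp.zeros((4,)), 'b': jnp.zeros(())}
inputs = jnp.ones((8, 4))
targets = jnp.ones((8,))

fisher = second_order.fisher_diag(nll, params, inputs, targets)
print(fisher.shape)  # (5,): one squared-gradient entry per flattened parameter
```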
59 changes: 14 additions & 45 deletions optax/_src/second_order.py → optax/second_order/hessian.py
@@ -12,37 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for computing diagonals of Hessians & Fisher info of parameters.

Computing the Hessian or Fisher information matrices for neural networks is
typically intractible due to the quadratic memory requirements. Solving for the
diagonals of these matrices is often a better solution.

This module provides two functions for computing these diagonals, `hessian_diag`
and `fisher_diag`., each with sub-quadratic memory requirements.
"""Functions for computing diagonals of the Hessian wrt to a set of parameters.

Computing the Hessian for neural networks is typically intractible due to the
quadratic memory requirements. Solving for the diagonal can be done cheaply,
with sub-quadratic memory requirements.
"""

from typing import Any, Callable
from typing import Any

import jax
from jax.flatten_util import ravel_pytree
from jax import flatten_util
import jax.numpy as jnp

from optax.second_order import base

# This covers both Jax and Numpy arrays.
# TODO(b/160876114): use the pytypes defined in Chex.
Array = jnp.ndarray
# LossFun of type f(params, inputs, targets).
LossFun = Callable[[Any, Array, Array], Array]


def ravel(p: Any) -> Array:
return ravel_pytree(p)[0]
def _ravel(p: Any) -> jax.Array:
return flatten_util.ravel_pytree(p)[0]


def hvp(
loss: LossFun,
loss: base.LossFn,
v: jax.Array,
params: Any,
inputs: jax.Array,
@@ -61,13 +52,13 @@ def hvp(
An Array corresponding to the product of `v` and the Hessian of `loss`
evaluated at `(params, inputs, targets)`.
"""
_, unravel_fn = ravel_pytree(params)
_, unravel_fn = flatten_util.ravel_pytree(params)
loss_fn = lambda p: loss(p, inputs, targets)
return jax.jvp(jax.grad(loss_fn), [params], [unravel_fn(v)])[1]


def hessian_diag(
loss: LossFun,
loss: base.LossFn,
params: Any,
inputs: jax.Array,
targets: jax.Array,
@@ -84,28 +75,6 @@
An Array corresponding to the diagonal of the Hessian of `loss`
evaluated at `(params, inputs, targets)`.
"""
vs = jnp.eye(ravel(params).size)
comp = lambda v: jnp.vdot(v, ravel(hvp(loss, v, params, inputs, targets)))
vs = jnp.eye(_ravel(params).size)
comp = lambda v: jnp.vdot(v, _ravel(hvp(loss, v, params, inputs, targets)))
return jax.vmap(comp)(vs)


def fisher_diag(
negative_log_likelihood: LossFun,
params: Any,
inputs: jnp.ndarray,
targets: jnp.ndarray,
) -> jax.Array:
"""Computes the diagonal of the (observed) Fisher information matrix.

Args:
negative_log_likelihood: the negative log likelihood function.
params: model parameters.
inputs: inputs at which `negative_log_likelihood` is evaluated.
targets: targets at which `negative_log_likelihood` is evaluated.

Returns:
An Array corresponding to the product to the Hessian of
`negative_log_likelihood` evaluated at `(params, inputs, targets)`.
"""
return jnp.square(
ravel(jax.grad(negative_log_likelihood)(params, inputs, targets)))
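A usage sketch for the two helpers kept in this module (the quadratic loss and data are hypothetical, and it assumes an optax build with this change). Note that the `v` passed to `hvp` is a flat vector whose length equals the total number of parameters, and that `hvp` returns tangents with the same pytree structure as `params`, while `hessian_diag` returns a single flat vector:

```python
import jax.numpy as jnp
from jax import flatten_util
from optax import second_order


def loss(params, inputs, targets):
  preds = inputs @ params['w'] + params['b']
  return jnp.mean((preds - targets) ** 2)


params = {'w': jnp.ones((3,)), 'b': jnp.zeros(())}
inputs = jnp.linspace(0.0, 1.0, 15).reshape(5, 3)
targets = jnp.arange(5.0)

n = flatten_util.ravel_pytree(params)[0].size  # 4 flattened parameters

# Hessian-vector product with a flat vector of length `n`; the result has the
# same pytree structure as `params`.
tangents = second_order.hvp(loss, jnp.ones((n,)), params, inputs, targets)

# Full diagonal of the Hessian, one entry per flattened parameter.
diag = second_order.hessian_diag(loss, params, inputs, targets)
print(diag.shape)  # (4,)
```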
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `second_order.py`."""
"""Tests for `hessian.py`."""

import collections
import functools
@@ -26,15 +26,15 @@
import jax.numpy as jnp
import numpy as np

from optax._src import second_order
from optax import second_order


NUM_CLASSES = 2
NUM_SAMPLES = 3
NUM_FEATURES = 4


class SecondOrderTest(chex.TestCase):
class HessianTest(chex.TestCase):

def setUp(self):
super().setUp()
@@ -69,7 +69,7 @@ def jax_hessian_diag(loss_fun, params, inputs, targets):
n_params, n_params)).reshape(params_shape)
for k, v in hess_diag.items():
hess_diag[k] = v
return second_order.ravel(hess_diag)
return jax.flatten_util.ravel_pytree(hess_diag)[0]

self.hessian = jax_hessian_diag(
self.loss_fn, self.parameters, self.data, self.labels)
@@ -81,13 +81,6 @@ def test_hessian_diag(self):
actual = hessian_diag_fn(self.parameters, self.data, self.labels)
np.testing.assert_array_almost_equal(self.hessian, actual, 5)

@chex.all_variants
def test_fisher_diag_shape(self):
fisher_diag_fn = self.variant(
functools.partial(second_order.fisher_diag, self.loss_fn))
fisher_diagonal = fisher_diag_fn(self.parameters, self.data, self.labels)
chex.assert_equal_shape([fisher_diagonal, self.hessian])


if __name__ == '__main__':
absltest.main()
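In the same spirit as the test above, a standalone sanity check (a sketch with made-up toy data, not taken from the test file) can compare `hessian_diag` against the diagonal of the dense Hessian computed with `jax.hessian`:

```python
import jax
import jax.numpy as jnp
import numpy as np
from optax import second_order


def loss(params, inputs, targets):
  preds = inputs @ params
  return jnp.mean((preds - targets) ** 2)


params = jnp.array([0.5, -1.0, 2.0])
inputs = jnp.linspace(-1.0, 1.0, 12).reshape(4, 3)
targets = jnp.ones((4,))

# Dense Hessian of the loss with respect to the flat parameter vector.
dense = jax.hessian(lambda p: loss(p, inputs, targets))(params)
np.testing.assert_allclose(
    second_order.hessian_diag(loss, params, inputs, targets),
    jnp.diag(dense),
    rtol=1e-5,
)
```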