-
Notifications
You must be signed in to change notification settings - Fork 188
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move second order utilities to dedicated sub-package.
This is part of a broader effort to restructure the current fully flat API surface, since optax has outgrown the flat structure due to increased scope and complexity. PiperOrigin-RevId: 570042410
- Loading branch information
Showing
6 changed files
with
203 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Optax: composable gradient processing and optimization, in JAX.""" | ||
|
||
from optax.second_order.fisher import fisher_diag | ||
from optax.second_order.hessian import hessian_diag | ||
from optax.second_order.hessian import hvp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Base types for the second order sub-package.""" | ||
|
||
import abc | ||
from typing import Any, Protocol | ||
|
||
import jax | ||
|
||
|
||
class LossFn(Protocol): | ||
"""A loss function to be optimized.""" | ||
|
||
@abc.abstractmethod | ||
def __call__( | ||
self, params: Any, inputs: jax.Array, targets: jax.Array | ||
) -> jax.Array: | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Functions for computing diagonals of the fisher information matrix. | ||
Computing the Fisher matrix for neural networks is typically intractible due to | ||
the quadratic memory requirements. Solving for the diagonal can be done cheaply, | ||
with sub-quadratic memory requirements. | ||
""" | ||
|
||
from typing import Any | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
|
||
from optax.second_order import base | ||
|
||
|
||
def _ravel(p: Any) -> jax.Array: | ||
return jax.flatten_util.ravel_pytree(p)[0] | ||
|
||
|
||
def fisher_diag( | ||
negative_log_likelihood: base.LossFn, | ||
params: Any, | ||
inputs: jax.Array, | ||
targets: jax.Array, | ||
) -> jax.Array: | ||
"""Computes the diagonal of the (observed) Fisher information matrix. | ||
Args: | ||
negative_log_likelihood: the negative log likelihood function with | ||
expected signature `loss = fn(params, inputs, targets)`. | ||
params: model parameters. | ||
inputs: inputs at which `negative_log_likelihood` is evaluated. | ||
targets: targets at which `negative_log_likelihood` is evaluated. | ||
Returns: | ||
An Array corresponding to the product to the Hessian of | ||
`negative_log_likelihood` evaluated at `(params, inputs, targets)`. | ||
""" | ||
return jnp.square( | ||
_ravel(jax.grad(negative_log_likelihood)(params, inputs, targets))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for `hessian.py`.""" | ||
|
||
import collections | ||
import functools | ||
import itertools | ||
|
||
from absl.testing import absltest | ||
|
||
import chex | ||
import haiku as hk | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
from optax import second_order | ||
|
||
|
||
NUM_CLASSES = 2 | ||
NUM_SAMPLES = 3 | ||
NUM_FEATURES = 4 | ||
|
||
|
||
class HessianTest(chex.TestCase): | ||
|
||
def setUp(self): | ||
super().setUp() | ||
|
||
self.data = np.random.rand(NUM_SAMPLES, NUM_FEATURES) | ||
self.labels = np.random.randint(NUM_CLASSES, size=NUM_SAMPLES) | ||
|
||
def net_fn(z): | ||
mlp = hk.Sequential( | ||
[hk.Linear(10), jax.nn.relu, hk.Linear(NUM_CLASSES)], name='mlp') | ||
return jax.nn.log_softmax(mlp(z)) | ||
|
||
net = hk.without_apply_rng(hk.transform(net_fn)) | ||
self.parameters = net.init(jax.random.PRNGKey(0), self.data) | ||
|
||
def loss(params, inputs, targets): | ||
log_probs = net.apply(params, inputs) | ||
return -jnp.mean(hk.one_hot(targets, NUM_CLASSES) * log_probs) | ||
|
||
self.loss_fn = loss | ||
|
||
def jax_hessian_diag(loss_fun, params, inputs, targets): | ||
"""This is the 'ground-truth' obtained via the JAX library.""" | ||
hess = jax.hessian(loss_fun)(params, inputs, targets) | ||
|
||
# Extracts the diagonal components. | ||
hess_diag = collections.defaultdict(dict) | ||
for k0, k1 in itertools.product(params.keys(), ['w', 'b']): | ||
params_shape = params[k0][k1].shape | ||
n_params = np.prod(params_shape) | ||
hess_diag[k0][k1] = jnp.diag(hess[k0][k1][k0][k1].reshape( | ||
n_params, n_params)).reshape(params_shape) | ||
for k, v in hess_diag.items(): | ||
hess_diag[k] = v | ||
return jax.flatten_util.ravel_pytree(hess_diag)[0] | ||
|
||
self.hessian = jax_hessian_diag( | ||
self.loss_fn, self.parameters, self.data, self.labels) | ||
|
||
@chex.all_variants | ||
def test_hessian_diag(self): | ||
hessian_diag_fn = self.variant( | ||
functools.partial(second_order.hessian_diag, self.loss_fn)) | ||
actual = hessian_diag_fn(self.parameters, self.data, self.labels) | ||
np.testing.assert_array_almost_equal(self.hessian, actual, 5) | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |