diff --git a/tessellate_ipu/linalg/__init__.py b/tessellate_ipu/linalg/__init__.py index 659a04f..cd0b099 100644 --- a/tessellate_ipu/linalg/__init__.py +++ b/tessellate_ipu/linalg/__init__.py @@ -4,5 +4,5 @@ from .tile_linalg_hessenberg import ipu_hessenberg from .tile_linalg_jacobi import ipu_eigh from .tile_linalg_qr import ipu_qr -from .tile_linalg_tridiagonal_eigh import ipu_eigh_hess, ipu_tridiagonal_eigh +from .tile_linalg_tridiagonal_eigh import ipu_hess_eigh, ipu_tridiagonal_eigh from .tile_linalg_tridiagonal_solver import ipu_tridiag_solve diff --git a/tessellate_ipu/linalg/tile_linalg_tridiagonal_eigenvalue.py b/tessellate_ipu/linalg/tile_linalg_tridiagonal_eigenvalue.py index 8f2a249..ea83bcb 100755 --- a/tessellate_ipu/linalg/tile_linalg_tridiagonal_eigenvalue.py +++ b/tessellate_ipu/linalg/tile_linalg_tridiagonal_eigenvalue.py @@ -9,6 +9,16 @@ jax.config.FLAGS.jax_platform_name = "cpu" +vertex_filename = osp.join(osp.dirname(__file__), "../core", "vertex", "tile_tridiagonal_eigh.cpp") +grad = create_ipu_tile_primitive( + "Sturm", + "Sturm", + inputs=["alpha", "beta_sq", "pivmin", "alpha0_pertubation", "x", "id", "out_shape", "lower", "mid", "upper"], + outputs={"lower_out": 7, "mid_out": 8, "upper_out": 9}, + gp_filename=vertex_filename, + perf_estimate=100, +) + def ipu_tridiagonal_eigenvalue(d, e, *, select="a", select_range=None, tol=None): alpha, beta = jnp.asarray(d), jnp.asarray(e) @@ -64,16 +74,6 @@ def ipu_tridiagonal_eigenvalue(d, e, *, select="a", select_range=None, tol=None) pivmin = jnp.broadcast_to(pivmin, target_shape) alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape) - vertex_filename = osp.join(osp.dirname(__file__), "../core", "vertex", "tile_tridiagonal_eigh.cpp") - grad = create_ipu_tile_primitive( - "Sturm", - "Sturm", - inputs=["alpha", "beta_sq", "pivmin", "alpha0_pertubation", "x", "id", "out_shape", "lower", "mid", "upper"], - outputs={"lower_out": 7, "mid_out": 8, "upper_out": 9}, - gp_filename=vertex_filename, - perf_estimate=100, - ) - x = mid n = x.shape[0] tiles = tuple(range(n)) @@ -106,7 +106,6 @@ def body(j, args): import jax import scipy - np.random.seed(42) jax.config.FLAGS.jax_platform_name = "cpu" np.random.seed(42) diff --git a/tessellate_ipu/linalg/tile_linalg_tridiagonal_eigh.py b/tessellate_ipu/linalg/tile_linalg_tridiagonal_eigh.py index b3162f8..1db999e 100644 --- a/tessellate_ipu/linalg/tile_linalg_tridiagonal_eigh.py +++ b/tessellate_ipu/linalg/tile_linalg_tridiagonal_eigh.py @@ -11,7 +11,7 @@ jax.config.update("jax_enable_x64", False) -def ipu_tridiagonal_eigh(d, e, n_iter=2, seed=42): +def ipu_tridiagonal_eigh(d, e, num_iters=2, seed=42): N = d.shape[0] eig = ipu_tridiagonal_eigenvalue(d, e)[:N] @@ -33,19 +33,19 @@ def inverse_iteration(i, x): x /= jnp.linalg.norm(x, axis=1)[:, jnp.newaxis] return x - x = jax.lax.fori_loop(0, n_iter, inverse_iteration, x) + x = jax.lax.fori_loop(0, num_iters, inverse_iteration, x) return x, eig -def ipu_eigh_hess(M): +def ipu_hess_eigh(M, num_iters=2): Q, M_tri_ = ipu_hessenberg(M) M_tri = M_tri_.array d, e = jnp.diag(M_tri), jnp.diag(M_tri, k=1) - x, eig = ipu_tridiagonal_eigh(d, e) - return x @ Q.array.T, eig + x, eig = ipu_tridiagonal_eigh(d, e, num_iters) + return eig, x @ Q.array.T if __name__ == "__main__": @@ -96,7 +96,7 @@ def ipu_eigh_hess(M): print("Specify one of the options -r or -f") sys.exit(1) - x, eig = jax.jit(ipu_eigh_hess, backend="ipu")(mat) + eig, x = jax.jit(ipu_hess_eigh, backend="ipu")(mat) x = np.array(x) diff --git a/tests/linalg/test_tile_linalg_eig_hess.py b/tests/linalg/test_tile_linalg_eig_hess.py new file mode 100644 index 0000000..4daca94 --- /dev/null +++ b/tests/linalg/test_tile_linalg_eig_hess.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022 Graphcore Ltd. All rights reserved. +import unittest + +import chex +import jax +import numpy as np +import numpy.testing as npt +import pytest +from absl.testing import parameterized + +from tessellate_ipu.linalg.tile_linalg_tridiagonal_eigh import ipu_hess_eigh +from tessellate_ipu.utils import IpuTargetType + +# Skipping some tests if no local IPU hardware. +ipu_hw_available = len(jax.devices("ipu")) > 0 and jax.devices("ipu")[0].target_type == IpuTargetType.IPU +ipu_num_tiles = jax.devices("ipu")[0].num_tiles + + +@pytest.mark.ipu_hardware +class IpuTileLinalgHessEigh(chex.TestCase, parameterized.TestCase): + def setUp(self): + self.device = jax.devices("ipu")[0] + self.num_tiles = self.device.num_tiles + np.random.seed(42) + + @unittest.skipUnless(ipu_num_tiles >= 16, "Requires IPU with 16 tiles") + @parameterized.parameters( + {"N": 4}, + # {"N": 512}, + ) + def test__hess_eigh_raw__proper_eigh_result(self, N): + x = np.random.randn(N, N).astype(np.float32) + x = (x + x.T) / 2.0 + + hess_eigh_fn = jax.jit(ipu_hess_eigh, backend="ipu") + # Should be enough iterations... + eigvalues, VT = hess_eigh_fn(x, num_iters=2) + eigvalues = np.asarray(eigvalues).reshape(-1) + VT = np.asarray(VT) + # Expected eigen values and vectors (from Lapack?) + expected_eigvalues, expected_eigvectors = np.linalg.eigh(x) + + # Order raw outputs. + indices = np.argsort(eigvalues) + eigvalues_sorted = eigvalues[indices] + eigvectors_sorted = VT[indices].T + npt.assert_array_almost_equal(eigvalues_sorted, expected_eigvalues, decimal=5) + npt.assert_array_almost_equal(np.abs(eigvectors_sorted), np.abs(expected_eigvectors), decimal=5) + + # TODO: Performance test