Skip to content

Commit

Permalink
added correctness test
Browse files Browse the repository at this point in the history
  • Loading branch information
paolot-gc committed Oct 27, 2023
1 parent 922e4d4 commit 20be2d8
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 18 deletions.
2 changes: 1 addition & 1 deletion tessellate_ipu/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
from .tile_linalg_hessenberg import ipu_hessenberg
from .tile_linalg_jacobi import ipu_eigh
from .tile_linalg_qr import ipu_qr
from .tile_linalg_tridiagonal_eigh import ipu_eigh_hess, ipu_tridiagonal_eigh
from .tile_linalg_tridiagonal_eigh import ipu_hess_eigh, ipu_tridiagonal_eigh
from .tile_linalg_tridiagonal_solver import ipu_tridiag_solve
21 changes: 10 additions & 11 deletions tessellate_ipu/linalg/tile_linalg_tridiagonal_eigenvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@

jax.config.FLAGS.jax_platform_name = "cpu"

vertex_filename = osp.join(osp.dirname(__file__), "../core", "vertex", "tile_tridiagonal_eigh.cpp")
grad = create_ipu_tile_primitive(
"Sturm",
"Sturm",
inputs=["alpha", "beta_sq", "pivmin", "alpha0_pertubation", "x", "id", "out_shape", "lower", "mid", "upper"],
outputs={"lower_out": 7, "mid_out": 8, "upper_out": 9},
gp_filename=vertex_filename,
perf_estimate=100,
)


def ipu_tridiagonal_eigenvalue(d, e, *, select="a", select_range=None, tol=None):
alpha, beta = jnp.asarray(d), jnp.asarray(e)
Expand Down Expand Up @@ -64,16 +74,6 @@ def ipu_tridiagonal_eigenvalue(d, e, *, select="a", select_range=None, tol=None)
pivmin = jnp.broadcast_to(pivmin, target_shape)
alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

vertex_filename = osp.join(osp.dirname(__file__), "../core", "vertex", "tile_tridiagonal_eigh.cpp")
grad = create_ipu_tile_primitive(
"Sturm",
"Sturm",
inputs=["alpha", "beta_sq", "pivmin", "alpha0_pertubation", "x", "id", "out_shape", "lower", "mid", "upper"],
outputs={"lower_out": 7, "mid_out": 8, "upper_out": 9},
gp_filename=vertex_filename,
perf_estimate=100,
)

x = mid
n = x.shape[0]
tiles = tuple(range(n))
Expand Down Expand Up @@ -106,7 +106,6 @@ def body(j, args):
import jax
import scipy

np.random.seed(42)
jax.config.FLAGS.jax_platform_name = "cpu"
np.random.seed(42)

Expand Down
12 changes: 6 additions & 6 deletions tessellate_ipu/linalg/tile_linalg_tridiagonal_eigh.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
jax.config.update("jax_enable_x64", False)


def ipu_tridiagonal_eigh(d, e, n_iter=2, seed=42):
def ipu_tridiagonal_eigh(d, e, num_iters=2, seed=42):

N = d.shape[0]
eig = ipu_tridiagonal_eigenvalue(d, e)[:N]
Expand All @@ -33,19 +33,19 @@ def inverse_iteration(i, x):
x /= jnp.linalg.norm(x, axis=1)[:, jnp.newaxis]
return x

x = jax.lax.fori_loop(0, n_iter, inverse_iteration, x)
x = jax.lax.fori_loop(0, num_iters, inverse_iteration, x)

return x, eig


def ipu_eigh_hess(M):
def ipu_hess_eigh(M, num_iters=2):

Q, M_tri_ = ipu_hessenberg(M)
M_tri = M_tri_.array
d, e = jnp.diag(M_tri), jnp.diag(M_tri, k=1)

x, eig = ipu_tridiagonal_eigh(d, e)
return x @ Q.array.T, eig
x, eig = ipu_tridiagonal_eigh(d, e, num_iters)
return eig, x @ Q.array.T


if __name__ == "__main__":
Expand Down Expand Up @@ -96,7 +96,7 @@ def ipu_eigh_hess(M):
print("Specify one of the options -r or -f")
sys.exit(1)

x, eig = jax.jit(ipu_eigh_hess, backend="ipu")(mat)
eig, x = jax.jit(ipu_hess_eigh, backend="ipu")(mat)

x = np.array(x)

Expand Down
50 changes: 50 additions & 0 deletions tests/linalg/test_tile_linalg_eig_hess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
import unittest

import chex
import jax
import numpy as np
import numpy.testing as npt
import pytest
from absl.testing import parameterized

from tessellate_ipu.linalg.tile_linalg_tridiagonal_eigh import ipu_hess_eigh
from tessellate_ipu.utils import IpuTargetType

# Skipping some tests if no local IPU hardware.
ipu_hw_available = len(jax.devices("ipu")) > 0 and jax.devices("ipu")[0].target_type == IpuTargetType.IPU
ipu_num_tiles = jax.devices("ipu")[0].num_tiles


@pytest.mark.ipu_hardware
class IpuTileLinalgHessEigh(chex.TestCase, parameterized.TestCase):
def setUp(self):
self.device = jax.devices("ipu")[0]
self.num_tiles = self.device.num_tiles
np.random.seed(42)

@unittest.skipUnless(ipu_num_tiles >= 16, "Requires IPU with 16 tiles")
@parameterized.parameters(
{"N": 4},
# {"N": 512},
)
def test__hess_eigh_raw__proper_eigh_result(self, N):
x = np.random.randn(N, N).astype(np.float32)
x = (x + x.T) / 2.0

hess_eigh_fn = jax.jit(ipu_hess_eigh, backend="ipu")
# Should be enough iterations...
eigvalues, VT = hess_eigh_fn(x, num_iters=2)
eigvalues = np.asarray(eigvalues).reshape(-1)
VT = np.asarray(VT)
# Expected eigen values and vectors (from Lapack?)
expected_eigvalues, expected_eigvectors = np.linalg.eigh(x)

# Order raw outputs.
indices = np.argsort(eigvalues)
eigvalues_sorted = eigvalues[indices]
eigvectors_sorted = VT[indices].T
npt.assert_array_almost_equal(eigvalues_sorted, expected_eigvalues, decimal=5)
npt.assert_array_almost_equal(np.abs(eigvectors_sorted), np.abs(expected_eigvectors), decimal=5)

# TODO: Performance test

0 comments on commit 20be2d8

Please sign in to comment.