diff --git a/examples/hessenberg_example.py b/examples/hessenberg_example.py deleted file mode 100644 index e5b12ec..0000000 --- a/examples/hessenberg_example.py +++ /dev/null @@ -1,30 +0,0 @@ -import sys - -import jax -import numpy as np - -from tessellate_ipu.linalg import ipu_hessenberg - -jax.config.FLAGS.jax_platform_name = "cpu" -jax.config.update("jax_enable_x64", False) - - -d = int(sys.argv[1]) -np.random.seed(42) - -np.set_printoptions(precision=3, linewidth=120, suppress=True) - -A = np.random.normal(0, 1, (d, d)) -A = (A + A.T) / 2 - -Q, R = jax.jit(ipu_hessenberg, backend="ipu")(A) - -Q_ = np.array(Q.array) -R_ = np.array(R.array) -print("\nR matrix") -print(R_) -print("\nQ matrix") -print(Q_) -print(f"\nReconstruction Delta: {np.max(np.abs(Q_ @ R_ @ Q_.T - A))}") -print("\nQ.T @ Q") -print(Q_.T @ Q_) diff --git a/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp b/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp index 858d9aa..08249a0 100644 --- a/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp @@ -12,57 +12,9 @@ using namespace poplar; */ static constexpr size_t MIN_ALIGN = 8; -class [[poplar::constraint("elem(*x) != elem(*y)")]] DotProduct1dIndexedVertex - : public MultiVertex { - public: - using T = float; - using T2 = float2; - // Using `uint16` seems to be generating more efficient loops? - using IndexType = unsigned short; - - - Input> - x; // (N,) x vector - Input> - y; // (N,) y vector - Input> - start_idx; - - Input> - worker_offsets; // (7,) number threads + 1. - Output> partials; // float result. - - bool compute(unsigned wid) { - // Always assuming size % 2 == 0 - const IndexType wstart = worker_offsets[wid]; - const IndexType wend = worker_offsets[wid + 1]; - const IndexType wsize = wend - wstart; - - const IndexType index = start_idx[0]; - - T2* ptr_tmp_partials_f2 = reinterpret_cast(partials.data()) + wid; - // Nothing to do in this worker thread. - if (wstart == wend) { - ipu::store_postinc(&ptr_tmp_partials_f2, T2{0, 0}, 1); - return true; - } - // X and Y input pointers. - const T2* ptr_inxdata_f2 = reinterpret_cast(x.data()) + wstart; - const T2* ptr_inydata_f2 = reinterpret_cast(y.data()) + wstart; - T2 partial = T2{0, 0}; - - for (IndexType idx = 0; idx != wsize; ++idx) { - // TODO: use ld2x64pace + tapack instructions? - const T2 xin = ipu::load_postinc(&ptr_inxdata_f2, 1); - const T2 yin = ipu::load_postinc(&ptr_inydata_f2, 1); - // popc seems to recognize this pattern and optimize it. - // Using directly ipu::fma intrinsics leads to poor performance!? - partial += xin * yin; - } - ipu::store_postinc(&ptr_tmp_partials_f2, partial, 1); - return true; - } -}; +/* + The code here is just a minor modification of tile_qr_vertex.cpp +*/ /** * @brief Vertex computing the correction vector in the Hessenberg algorithm. diff --git a/tessellate_ipu/linalg/tile_linalg_hessenberg.py b/tessellate_ipu/linalg/tile_linalg_hessenberg.py index f41b08b..33896c9 100644 --- a/tessellate_ipu/linalg/tile_linalg_hessenberg.py +++ b/tessellate_ipu/linalg/tile_linalg_hessenberg.py @@ -4,7 +4,7 @@ from typing import Any, Tuple import jax.lax -import numpy as np # used for np.float32, shouldn't we use jax types? +import numpy as np from jax.core import ShapedArray from tessellate_ipu import ( @@ -21,26 +21,13 @@ Array = Any +# The code here is heavily based on tile_linalg_qr.py + def get_hessenberg_vertex_gp_filename() -> str: return os.path.join(os.path.dirname(__file__), "../core", "vertex", "tile_hessenberg_vertex.cpp") -dot_product1d_indexed_p = create_ipu_tile_primitive( - "dot_product1d_indexed", - "DotProduct1dIndexedVertex", - inputs=["x", "y", "start_idx"], - outputs={"partials": ShapedArray((12,), dtype=np.float32)}, - constants={ - "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[0].size, vector_size=2, num_workers=6, wdtype=np.uint16 - ) - }, - # tmp_space=ShapedArray((12,), dtype=np.float32), - gp_filename=get_hessenberg_vertex_gp_filename(), - perf_estimate=1000, -) - """Vertex computing Hessenberg correction vector. """ hessenberg_correction_vector_p = create_ipu_tile_primitive( @@ -81,17 +68,21 @@ def ipu_hessenberg_shard_inputs(x: Array, xsdiag: Array) -> Tuple[TileShardedArr assert x.shape[0] == x.shape[1] N = x.shape[0] n_tiles = 1472 - # Sharding R and Q - - n_per_tile = math.ceil(N / float(n_tiles)) - full_tiles = N % n_tiles - if full_tiles == 0: - full_tiles = n_tiles - Q_tiles = [i for i in range(full_tiles) for _ in range(n_per_tile)] + [ - i for i in range(full_tiles, n_tiles) for _ in range(n_per_tile - 1) - ] - R_tiles = Q_tiles + # Sharding R and Q + if N <= 736: + Q_tiles = list(range(N)) + R_tiles = list(range(N, 2 * N)) + else: + n_per_tile = math.ceil(N / float(n_tiles)) + full_tiles = N % n_tiles + if full_tiles == 0: + full_tiles = n_tiles + + Q_tiles = [i for i in range(full_tiles) for _ in range(n_per_tile)] + [ + i for i in range(full_tiles, n_tiles) for _ in range(n_per_tile - 1) + ] + R_tiles = Q_tiles # TODO: on-device construction of identity Q = tile_put_sharded(np.identity(N, dtype=x.dtype), Q_tiles) @@ -101,20 +92,20 @@ def ipu_hessenberg_shard_inputs(x: Array, xsdiag: Array) -> Tuple[TileShardedArr return Q, R, sdiag_full -# Heavily based on ipu_qr_iterations in tile_linalg_qr.py # The body of the for-loop computes # v = Householder(R[i]) # v is chosen to annihilate the elements below the first lower diagonal # R = R - 2 * v.reshape(-1, 1) @ (v.reshape(1, -1) @ R) # R = R - 2 * (R @ v.reshape(-1, 1)) @ v.reshape(1, -1) # Not present in QR algorithm # Q = Q - 2 * (Q @ v.reshape(-1, 1)) @ v.reshape(1, -1) - - def ipu_hessenberg_body( i: int, carry: Tuple[TileShardedArray, TileShardedArray, TileShardedArray] ) -> Tuple[TileShardedArray, TileShardedArray, TileShardedArray]: Q, R, sdiag_full = carry + # Extract the i-th col of R and the i-th element of sdiag_full + # Using the gather_p primitive avoids inefficient general-case processing + dim_numbers = jax.lax.GatherDimensionNumbers(offset_dims=tuple(), collapsed_slice_dims=(0,), start_index_map=(0,)) i_rep = tile_put_replicated(jax.numpy.array([[i]], dtype=np.uint32), R.tiles) @@ -131,6 +122,8 @@ def ipu_hessenberg_body( fill_value=None, ) # => TileShardedArray() (Num_tiles, 1) + # This determines also where the computation of v (Householder correction vector) takes place + # For now, the tile is picked arbitrarily. Are there better choices? R.tiles[0]? Rcol_replicated = tile_put_replicated(Rcol.array, tiles=[736]) # type:ignore sdiag = tile_map( @@ -147,24 +140,17 @@ def ipu_hessenberg_body( sdiag_rep = tile_put_replicated(sdiag.array, Rcol_replicated.tiles) # type:ignore + # Smart-indexing # start_idx = (i // 2) * 2 start_idx = 0 start_idxQ = tile_put_replicated(start_idx, Q.tiles) start_idxR = tile_put_replicated(start_idx, R.tiles) - # Alternative: we pass the whole RT and sdiag; then we extract the result from the i-th tile - - # Correction vector. Computed + # Correction vector. Computed on the tile where Rcol is located v, vrescale = tile_map( hessenberg_correction_vector_p, Rcol_replicated, sdiag_rep, tile_put_replicated(i + 1, Rcol_replicated.tiles) ) # type:ignore - # v, vrescale = tile_map( - # hessenberg_correction_vector_p, RT, sdiag_full, tile_put_replicated(i + 1, RT.tiles) - # ) # type:ignore - - # This compiles - # vi = tile_gather(v.array[i], [0], [0] # Replicate to all Q and R tiles. vQ = tile_put_replicated(v.array, Q.tiles) # 0 @@ -173,13 +159,7 @@ def ipu_hessenberg_body( vrescaleQ = tile_put_replicated(vrescale.array, Q.tiles) # 0 vrescaleR = tile_put_replicated(vrescale.array, R.tiles) # 0 - # Alternative using tile_gather - # vQ = tile_gather(v, [i]*len(Q.tiles), list(Q.tiles), copy=False) # 0 - # vR = tile_gather(v, [i]*len(RT.tiles), RT.tiles) # 0 - # # v normalization factor to pass to householder update. - # vrescaleQ = tile_gather(vrescale, [i]*len(Q.tiles), Q.tiles) # 0 - # vrescaleR = tile_gather(vrescale, [i]*len(RT.tiles), RT.tiles) # - + # Transpose R so that we can use hessenberg_householder_row_update_p() to compute R @ ... RT = tile_put_sharded(R.array.T, R.tiles) # w = R^T @ v @@ -195,6 +175,9 @@ def ipu_hessenberg_body( hessenberg_householder_row_update_p, RT, vR, w, vrescaleR, start_idxR # type:ignore ) + # We compute the Q updates. + # It is done here and is followed by tile_data_barrier() because this induces the Poplar + # to schedule it in parallel to the RT updates, when RT and Q are mapped on disjoint tiles. # w = Q @ v # w = tile_map(dot_product1d_indexed_p, vQ, Q, start_idxQ) w = tile_map(dot_product1d_p, vQ, Q) @@ -205,7 +188,7 @@ def ipu_hessenberg_body( ) RT, Q = tile_data_barrier(RT, Q) - # Transpose the RT matrix so that we can do the right product + # Transpose the RT matrix so that we can use hessenberg_householder_row_update_p() to compute ... @ R R = tile_put_sharded(RT.array.T, RT.tiles) # w = R^T @ v diff --git a/tests/linalg/test_tile_linalg_hessenberg.py b/tests/linalg/test_tile_linalg_hessenberg.py index 262e4d1..cbd59c1 100644 --- a/tests/linalg/test_tile_linalg_hessenberg.py +++ b/tests/linalg/test_tile_linalg_hessenberg.py @@ -69,4 +69,4 @@ def hessenberg_decomposition_fn(x, xsdiag): start, end = np.asarray(start)[0], np.asarray(end)[0] hessenberg_cycle_count = end[0] - start[0] - assert hessenberg_cycle_count <= 105000 + assert hessenberg_cycle_count <= 150000