From feb639f49f8948148c6b978fe06246fc0555389f Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Mon, 9 Oct 2023 16:40:15 +0100 Subject: [PATCH] Optimize IPU Jacobi `tile_gather` copies. (#44) Splitting the `tile_gather` into 2 parts helps limit the number of on-tile copies introduced by the Poplar compiler. --- tessellate_ipu/linalg/tile_linalg_jacobi.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py index 3cbfa5f..5319b56 100644 --- a/tessellate_ipu/linalg/tile_linalg_jacobi.py +++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py @@ -190,10 +190,16 @@ def tile_rotate_columns(pcols: TileShardedArray, qcols: TileShardedArray) -> Tup # Roughtly: pcols move to the right, qcols to the left. pcols_indices_new = np.concatenate([pcols_indices[0:1], qcols_indices[0:1], pcols_indices[1:-1]]) qcols_indices_new = np.concatenate([qcols_indices[1:], pcols_indices[-1:]]) + # Move columns around! - all_indices = np.concatenate([pcols_indices_new, qcols_indices_new]) - all_cols_updated = tile_gather(all_cols, all_indices.tolist(), all_cols.tiles) - return all_cols_updated[:halfN], all_cols_updated[halfN:] + pcols_updated = tile_gather(all_cols, pcols_indices_new.tolist(), pcols.tiles) + qcols_updated = tile_gather(all_cols, qcols_indices_new.tolist(), qcols.tiles) + return pcols_updated, qcols_updated + + # FIXME: understand why Poplar add a copy with the following code. 
+ # all_indices = np.concatenate([pcols_indices_new, qcols_indices_new]) + # all_cols_updated = tile_gather(all_cols, all_indices.tolist(), all_cols.tiles) + # return all_cols_updated[:halfN], all_cols_updated[halfN:] def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tuple[TileShardedArray, ...]: @@ -211,6 +217,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu halfN = Apcols.shape[0] with jax.named_scope("jacobi_eigh"): + # with jax.named_scope("Apqcols_rotation"): + # Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) + # with jax.named_scope("Vpqcols_rotation"): + # Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) + # Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) + # Sharded constant with p/q indices to ignore in second update stage. with jax.named_scope("rotset_index_ignored"): rotset_index_ignored = tile_constant_sharded(np.arange(0, halfN, dtype=np.uint32), tiles=Atiles)