From acabc7d3460f8d30767a3345c8f58323006b9c20 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 19 Oct 2023 17:03:31 +0100 Subject: [PATCH] Move tile rotation to top of IPU Jacobi loop body. (#52) Allows to optimize out one on-tile-copy, saving an additional 10% of cycles. --- tessellate_ipu/linalg/tile_linalg_jacobi.py | 24 ++++++++++----------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py index 13b75a1..770330b 100644 --- a/tessellate_ipu/linalg/tile_linalg_jacobi.py +++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py @@ -221,11 +221,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu halfN = Apcols.shape[0] with jax.named_scope("jacobi_eigh"): - # with jax.named_scope("Apqcols_rotation"): - # Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) - # with jax.named_scope("Vpqcols_rotation"): - # Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) - # Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) + with jax.named_scope("Apqcols_rotation"): + Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) + with jax.named_scope("Vpqcols_rotation"): + Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) + # Barrier, to make we sync. both set of tiles A and V and force fused comms. + Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) # Sharded constant with p/q indices to ignore in second update stage. with jax.named_scope("rotset_index_ignored"): @@ -274,13 +275,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu Vqcols, ) - # Barrier, to make we sync. both set of tiles A and V - Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) - # Move columns between tiles following Jacobi rotation pattern. 2*N commns per tile. - with jax.named_scope("Apqcols_rotation"): - Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) - with jax.named_scope("Vpqcols_rotation"): - Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) + # Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) + # # Move columns between tiles following Jacobi rotation pattern. 2*N commns per tile. + # with jax.named_scope("Apqcols_rotation"): + # Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) + # with jax.named_scope("Vpqcols_rotation"): + # Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) return Apcols, Aqcols, Vpcols, Vqcols