diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py index 13b75a1..770330b 100644 --- a/tessellate_ipu/linalg/tile_linalg_jacobi.py +++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py @@ -221,11 +221,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu halfN = Apcols.shape[0] with jax.named_scope("jacobi_eigh"): - # with jax.named_scope("Apqcols_rotation"): - # Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) - # with jax.named_scope("Vpqcols_rotation"): - # Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) - # Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) + with jax.named_scope("Apqcols_rotation"): + Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) + with jax.named_scope("Vpqcols_rotation"): + Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) + # Barrier, to make sure we sync both sets of tiles A and V and force fused comms. + Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) # Sharded constant with p/q indices to ignore in second update stage. with jax.named_scope("rotset_index_ignored"): @@ -274,13 +275,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu Vqcols, ) - # Barrier, to make we sync. both set of tiles A and V - Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) - # Move columns between tiles following Jacobi rotation pattern. 2*N commns per tile. - with jax.named_scope("Apqcols_rotation"): - Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) - with jax.named_scope("Vpqcols_rotation"): - Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) + # Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) + # # Move columns between tiles following Jacobi rotation pattern. 2*N comms per tile. 
+ # with jax.named_scope("Apqcols_rotation"): + # Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) + # with jax.named_scope("Vpqcols_rotation"): + # Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) return Apcols, Aqcols, Vpcols, Vqcols