Skip to content

Commit

Permalink
Move tile rotation to top of IPU Jacobi loop body. (#52)
Browse files Browse the repository at this point in the history
Allows to optimize out one on-tile-copy, saving an additional 10% of
cycles.
  • Loading branch information
balancap authored Oct 19, 2023
1 parent b891bd9 commit acabc7d
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions tessellate_ipu/linalg/tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu
halfN = Apcols.shape[0]

with jax.named_scope("jacobi_eigh"):
# with jax.named_scope("Apqcols_rotation"):
# Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
# with jax.named_scope("Vpqcols_rotation"):
# Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
# Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)
with jax.named_scope("Apqcols_rotation"):
Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
with jax.named_scope("Vpqcols_rotation"):
Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
# Barrier, to make we sync. both set of tiles A and V and force fused comms.
Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)

# Sharded constant with p/q indices to ignore in second update stage.
with jax.named_scope("rotset_index_ignored"):
Expand Down Expand Up @@ -274,13 +275,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu
Vqcols,
)

# Barrier, to make we sync. both set of tiles A and V
Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)
# Move columns between tiles following Jacobi rotation pattern. 2*N commns per tile.
with jax.named_scope("Apqcols_rotation"):
Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
with jax.named_scope("Vpqcols_rotation"):
Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
# Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)
# # Move columns between tiles following Jacobi rotation pattern. 2*N commns per tile.
# with jax.named_scope("Apqcols_rotation"):
# Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
# with jax.named_scope("Vpqcols_rotation"):
# Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
return Apcols, Aqcols, Vpcols, Vqcols


Expand Down

0 comments on commit acabc7d

Please sign in to comment.