Skip to content

Commit

Permalink
Optimize IPU Jacobil tile_gather copies. (#44)
Browse files Browse the repository at this point in the history
Splitting the `tile_gather` into 2 parts helps limiting the number
of on-tile copies introduce by the Poplar compiler.
  • Loading branch information
balancap authored Oct 9, 2023
1 parent 491275a commit feb639f
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions tessellate_ipu/linalg/tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,16 @@ def tile_rotate_columns(pcols: TileShardedArray, qcols: TileShardedArray) -> Tup
# Roughtly: pcols move to the right, qcols to the left.
pcols_indices_new = np.concatenate([pcols_indices[0:1], qcols_indices[0:1], pcols_indices[1:-1]])
qcols_indices_new = np.concatenate([qcols_indices[1:], pcols_indices[-1:]])

# Move columns around!
all_indices = np.concatenate([pcols_indices_new, qcols_indices_new])
all_cols_updated = tile_gather(all_cols, all_indices.tolist(), all_cols.tiles)
return all_cols_updated[:halfN], all_cols_updated[halfN:]
pcols_updated = tile_gather(all_cols, pcols_indices_new.tolist(), pcols.tiles)
qcols_updated = tile_gather(all_cols, qcols_indices_new.tolist(), qcols.tiles)
return pcols_updated, qcols_updated

# FIXME: understand why Poplar add a copy with the following code.
# all_indices = np.concatenate([pcols_indices_new, qcols_indices_new])
# all_cols_updated = tile_gather(all_cols, all_indices.tolist(), all_cols.tiles)
# return all_cols_updated[:halfN], all_cols_updated[halfN:]


def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tuple[TileShardedArray, ...]:
Expand All @@ -211,6 +217,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu
halfN = Apcols.shape[0]

with jax.named_scope("jacobi_eigh"):
# with jax.named_scope("Apqcols_rotation"):
# Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
# with jax.named_scope("Vpqcols_rotation"):
# Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
# Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)

# Sharded constant with p/q indices to ignore in second update stage.
with jax.named_scope("rotset_index_ignored"):
rotset_index_ignored = tile_constant_sharded(np.arange(0, halfN, dtype=np.uint32), tiles=Atiles)
Expand Down

0 comments on commit feb639f

Please sign in to comment.