Skip to content

Commit

Permalink
Optimize IPU Jacobil tile_gather copies.
Browse files Browse the repository at this point in the history
Splitting the `tile_gather` into 2 parts helps limiting the number
of on-tile copies introduce by the Poplar compiler.
  • Loading branch information
balancap committed Oct 9, 2023
1 parent 491275a commit bd9f991
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
4 changes: 4 additions & 0 deletions tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated,
*/
class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep
: public MultiVertex {
// class JacobiUpdateFirstStep
// : public MultiVertex {
public:
using T = float;
using T2 = float2;
Expand Down Expand Up @@ -315,6 +317,8 @@ void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated,
class [[poplar::constraint(
"elem(*vpcol) != elem(*vqcol)")]] JacobiUpdateEigenvectors
: public MultiVertex {
// class JacobiUpdateEigenvectors
// : public MultiVertex {
public:
using T = float;
using T2 = float2;
Expand Down
18 changes: 15 additions & 3 deletions tessellate_ipu/linalg/tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,16 @@ def tile_rotate_columns(pcols: TileShardedArray, qcols: TileShardedArray) -> Tup
# Roughtly: pcols move to the right, qcols to the left.
pcols_indices_new = np.concatenate([pcols_indices[0:1], qcols_indices[0:1], pcols_indices[1:-1]])
qcols_indices_new = np.concatenate([qcols_indices[1:], pcols_indices[-1:]])

# Move columns around!
all_indices = np.concatenate([pcols_indices_new, qcols_indices_new])
all_cols_updated = tile_gather(all_cols, all_indices.tolist(), all_cols.tiles)
return all_cols_updated[:halfN], all_cols_updated[halfN:]
pcols_updated = tile_gather(all_cols, pcols_indices_new.tolist(), pcols.tiles)
qcols_updated = tile_gather(all_cols, qcols_indices_new.tolist(), qcols.tiles)
return pcols_updated, qcols_updated

# FIXME: understand why Poplar add a copy with the following code.
# all_indices = np.concatenate([pcols_indices_new, qcols_indices_new])
# all_cols_updated = tile_gather(all_cols, all_indices.tolist(), all_cols.tiles)
# return all_cols_updated[:halfN], all_cols_updated[halfN:]


def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tuple[TileShardedArray, ...]:
Expand All @@ -211,6 +217,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu
halfN = Apcols.shape[0]

with jax.named_scope("jacobi_eigh"):
# with jax.named_scope("Apqcols_rotation"):
# Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols)
# with jax.named_scope("Vpqcols_rotation"):
# Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols)
# Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols)

# Sharded constant with p/q indices to ignore in second update stage.
with jax.named_scope("rotset_index_ignored"):
rotset_index_ignored = tile_constant_sharded(np.arange(0, halfN, dtype=np.uint32), tiles=Atiles)
Expand Down

0 comments on commit bd9f991

Please sign in to comment.