From 736e9b0f5cfc0461872b678686875119f7359b5b Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Sun, 1 Oct 2023 17:23:09 +0100 Subject: [PATCH] IPU Jacobi eigh `fori_loop` transition This PR is switching the IPU Jacobi `eigh` implementation from a Python loop to a JAX `fori_loop`, massively reducing code size and allowing Eigen decomposition on larger matrices. Note: in terms of performance, some issues still remain as the Poplar compiler is adding some extra on-tile copies which in theory could be elided. --- .../core/vertex/tile_jacobi_vertex.cpp | 264 +++++++++++------- tessellate_ipu/linalg/tile_linalg_jacobi.py | 239 +++++++++------- 2 files changed, 306 insertions(+), 197 deletions(-) diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp index 4c33707..0cb74c0 100644 --- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp @@ -78,6 +78,59 @@ class JacobiSymSchur2 : public Vertex { } }; +template +void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated, + T* qcol_updated, T* cs, unsigned p, unsigned q, + unsigned short wstart, + unsigned short wend) noexcept { + using T2 = float2; + using IndexType = unsigned short; + + const T Apq = pcol[q]; + const T App = pcol[p]; + const T Aqq = qcol[q]; + + // Schur2 decomposition. + const T2 cs_vec = sym_schur2(App, Aqq, Apq); + const T& c = cs_vec[0]; + const T& s = cs_vec[1]; + cs[0] = c; + cs[1] = s; + + // Worker load: start + end vectorized indexes. + constexpr unsigned ptr_step = 1; + const IndexType wsize = wend - wstart; + + // pcol, qcol and results pointers. 
+ const float2* ptr_pcol = reinterpret_cast(pcol) + wstart; + const float2* ptr_qcol = reinterpret_cast(qcol) + wstart; + float2* ptr_pcol_updated = reinterpret_cast(pcol_updated) + wstart; + float2* ptr_qcol_updated = reinterpret_cast(qcol_updated) + wstart; + + const T2 cvec = T2{c, c}; + const T2 svec = T2{s, s}; + + // Easier to vectorized + parallelize if start with normal update first. + for (IndexType idx = 0; idx != wsize; ++idx) { + // TODO: investigate assembly? + const T2 pvec = ipu::load_postinc(&ptr_pcol, 1); + const T2 qvec = ipu::load_postinc(&ptr_qcol, 1); + + const T2 pvec_updated = cvec * pvec - svec * qvec; + const T2 qvec_updated = svec * pvec + cvec * qvec; + + ipu::store_postinc(&ptr_pcol_updated, pvec_updated, 1); + ipu::store_postinc(&ptr_qcol_updated, qvec_updated, 1); + } + + // Update main values App, Apq, Aqq + pcol_updated[p] = c * c * App - 2 * s * c * Apq + s * s * Aqq; + qcol_updated[q] = s * s * App + 2 * s * c * Apq + c * c * Aqq; + // Zero on purpose with Schur decomposition! + pcol_updated[q] = 0; + qcol_updated[p] = 0; +} + /** * @brief Jacobi algorithm, update first step: schur + column update. * @@ -92,78 +145,55 @@ class [[poplar::constraint("elem(*pcol) != elem(*qcol)")]] JacobiUpdateFirstStep // Using `uint16` seems to be generating more efficient loops? using IndexType = unsigned short; - Input> - rotset; // (2,) rotation index p and q. p < q - Input> pcol; // (N,) p column - Input> qcol; // (N,) q column + // p/q cols + index prefix (2 x uint32). + Input> pcol; // (N + 2,) p column + Input> qcol; // (N + 2,) q column Input> worker_offsets; // (7,) threads work size + 1. + Output> + rotset_sorted; // (3,) rotset index sorted + was sorted? 
Output> cs; // (2,) (c, s) Schur decomposition values Output> - pcol_updated; // (N,) p column updated + pcol_updated; // (N + 2,) p column updated Output> - qcol_updated; // (N,) q column updated - - const IndexType N; // size + qcol_updated; // (N + 2,) q column updated JacobiUpdateFirstStep(); bool compute(unsigned wid) { - const unsigned p = rotset[0]; - const unsigned q = rotset[1]; - const T Apq = pcol[q]; - const T App = pcol[p]; - const T Aqq = qcol[q]; + // Size of the index prefix in pcol and qcol. + constexpr int INDEX_PREFIX = 2; + const unsigned p = *((unsigned*)pcol.data()); + const unsigned q = *((unsigned*)qcol.data()); - // Schur2 decomposition. - const T2 cs_vec = sym_schur2(App, Aqq, Apq); - const T& c = cs_vec[0]; - const T& s = cs_vec[1]; - cs[0] = c; - cs[1] = s; - - // Worker load: start + end vectorized indexes. - constexpr unsigned ptr_step = 1; const IndexType wstart = worker_offsets[wid]; const IndexType wend = worker_offsets[wid + 1]; - const IndexType wsize = wend - wstart; - // pcol, qcol and results pointers. - const float2* ptr_pcol = - reinterpret_cast(pcol.data()) + wstart; - const float2* ptr_qcol = - reinterpret_cast(qcol.data()) + wstart; - float2* ptr_pcol_updated = - reinterpret_cast(pcol_updated.data()) + wstart; - float2* ptr_qcol_updated = - reinterpret_cast(qcol_updated.data()) + wstart; - - const T2 cvec = T2{c, c}; - const T2 svec = T2{s, s}; - - // Easier to vectorized + parallelize if start with normal update first. - for (IndexType idx = 0; idx != wsize; ++idx) { - // TODO: investigate assembly? - const T2 pvec = ipu::load_postinc(&ptr_pcol, 1); - const T2 qvec = ipu::load_postinc(&ptr_qcol, 1); - - const T2 pvec_updated = cvec * pvec - svec * qvec; - const T2 qvec_updated = svec * pvec + cvec * qvec; - - ipu::store_postinc(&ptr_pcol_updated, pvec_updated, 1); - ipu::store_postinc(&ptr_qcol_updated, qvec_updated, 1); + // Forward p/q indices. 
+ pcol_updated[0] = pcol[0]; + qcol_updated[0] = qcol[0]; + + if (p <= q) { + // Proper ordering of p and q already. + jacob_update_first_step( + pcol.data() + INDEX_PREFIX, qcol.data() + INDEX_PREFIX, + pcol_updated.data() + INDEX_PREFIX, + qcol_updated.data() + INDEX_PREFIX, cs.data(), p, q, wstart, wend); + rotset_sorted[0] = p; + rotset_sorted[1] = q; + } else { + // Swap p and q columns as q < p + jacob_update_first_step( + qcol.data() + INDEX_PREFIX, pcol.data() + INDEX_PREFIX, + qcol_updated.data() + INDEX_PREFIX, + pcol_updated.data() + INDEX_PREFIX, cs.data(), q, p, wstart, wend); + rotset_sorted[0] = q; + rotset_sorted[1] = p; } - - // Update main values App, Apq, Aqq - pcol_updated[p] = c * c * App - 2 * s * c * Apq + s * s * Aqq; - qcol_updated[q] = s * s * App + 2 * s * c * Apq + c * c * Aqq; - // Zero on purpose with Schur decomposition! - pcol_updated[q] = 0; - qcol_updated[p] = 0; return true; } }; @@ -178,65 +208,104 @@ class JacobiUpdateSecondStep : public MultiVertex { InOut> cs_arr; // (N/2, 2) (c, s) values Input> - rotset_arr; // (N/2, 2) (p, q) array values. p < q + rotset_sorted_arr; // (N/2, 2) (p, q) array values. p < q Input> rotset_idx_ignored; // (1,) index in rotset to ignore. Input> worker_offsets; // (7,) threads work size + 1. - Input> pcol; // (N,) p column - Input> qcol; // (N,) q column + Input> pcol; // (N+2,) p column + Input> qcol; // (N+2,) q column Output> - pcol_updated; // (N,) p column updated + pcol_updated; // (N+2,) p column updated Output> - qcol_updated; // (N,) q column updated - - // const unsigned ignore_idx; // cs/pq index to ignore. - const IndexType halfN; // N / 2 + qcol_updated; // (N+2,) q column updated JacobiUpdateSecondStep(); bool compute(unsigned wid) { - // Use (p, q) = (1, 0) for ignore idx. - const unsigned ignore_idx = 2 * rotset_idx_ignored[0]; - cs_arr[ignore_idx] = 1; - cs_arr[ignore_idx + 1] = 0; - + // Size of the index prefix in pcol and qcol. 
+ constexpr int INDEX_PREFIX = 2; // Worker load: start + end vectorized indexes. constexpr unsigned ptr_step = 1; const IndexType wstart = worker_offsets[wid]; const IndexType wend = worker_offsets[wid + 1]; const IndexType wsize = wend - wstart; + // Use (p, q) = (1, 0) for ignore idx. + const unsigned ignore_idx = 2 * rotset_idx_ignored[0]; + cs_arr[ignore_idx] = 1; + cs_arr[ignore_idx + 1] = 0; + + auto pcol_ptr = pcol.data() + INDEX_PREFIX; + auto qcol_ptr = qcol.data() + INDEX_PREFIX; + auto pcol_updated_ptr = pcol_updated.data() + INDEX_PREFIX; + auto qcol_updated_ptr = qcol_updated.data() + INDEX_PREFIX; + + // Forward pq indices. + pcol_updated[0] = pcol[0]; + qcol_updated[0] = qcol[0]; + // Parallized loop on update using other columns coefficients - // for (IndexType half_idx = 0; half_idx != halfN; ++half_idx) { for (IndexType half_idx = 0; half_idx != wsize; ++half_idx) { - const unsigned k = rotset_arr[2 * half_idx + 2 * wstart]; - const unsigned l = rotset_arr[2 * half_idx + 1 + 2 * wstart]; + // TODO: cleaning pq indices offset. + const unsigned k = rotset_sorted_arr[2 * half_idx + 2 * wstart]; + const unsigned l = rotset_sorted_arr[2 * half_idx + 1 + 2 * wstart]; const T c = cs_arr[2 * half_idx + 2 * wstart]; const T s = cs_arr[2 * half_idx + 1 + 2 * wstart]; // 4 coefficients updates! // TODO: vectorization?! 
- const T Spk = pcol[k]; - const T Spl = pcol[l]; + const T Spk = pcol_ptr[k]; + const T Spl = pcol_ptr[l]; - const T Sqk = qcol[k]; - const T Sql = qcol[l]; + const T Sqk = qcol_ptr[k]; + const T Sql = qcol_ptr[l]; - pcol_updated[k] = c * Spk - s * Spl; - pcol_updated[l] = s * Spk + c * Spl; + pcol_updated_ptr[k] = c * Spk - s * Spl; + pcol_updated_ptr[l] = s * Spk + c * Spl; - qcol_updated[k] = c * Sqk - s * Sql; - qcol_updated[l] = s * Sqk + c * Sql; + qcol_updated_ptr[k] = c * Sqk - s * Sql; + qcol_updated_ptr[l] = s * Sqk + c * Sql; } return true; } }; +template +void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated, + T* vqcol_updated, T c, T s, + unsigned short wstart, + unsigned short wend) noexcept { + using T2 = float2; + // Using `uint16` seems to be generating more efficient loops? + using IndexType = unsigned short; + + const T2 cvec = T2{c, c}; + const T2 svec = T2{s, s}; + const IndexType wsize = wend - wstart; + + // pcol, qcol and results pointers. + const T2* ptr_pcol = reinterpret_cast(vpcol) + wstart; + const T2* ptr_qcol = reinterpret_cast(vqcol) + wstart; + T2* ptr_pcol_updated = reinterpret_cast(vpcol_updated) + wstart; + T2* ptr_qcol_updated = reinterpret_cast(vqcol_updated) + wstart; + + for (IndexType idx = 0; idx != wsize; ++idx) { + const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1); + const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1); + + const T2 vpvec_updated = cvec * vpvec - svec * vqvec; + const T2 vqvec_updated = svec * vpvec + cvec * vqvec; + + ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1); + ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1); + } +} + /** * @brief Jacobi algorithm, update of eigen vectors matrix. 
* @@ -268,32 +337,29 @@ class [[poplar::constraint( JacobiUpdateEigenvectors(); bool compute(unsigned wid) { + constexpr int INDEX_PREFIX = 2; + const unsigned p = *((unsigned*)vpcol.data()); + const unsigned q = *((unsigned*)vqcol.data()); + const T c = cs[0]; const T s = cs[1]; - const T2 cvec = T2{c, c}; - const T2 svec = T2{s, s}; - - // Worker load: start + end vectorized indexes. - constexpr unsigned ptr_step = 1; const IndexType wstart = worker_offsets[wid]; const IndexType wend = worker_offsets[wid + 1]; - const IndexType wsize = wend - wstart; - - // pcol, qcol and results pointers. - const T2* ptr_pcol = reinterpret_cast(vpcol.data()) + wstart; - const T2* ptr_qcol = reinterpret_cast(vqcol.data()) + wstart; - T2* ptr_pcol_updated = reinterpret_cast(vpcol_out.data()) + wstart; - T2* ptr_qcol_updated = reinterpret_cast(vqcol_out.data()) + wstart; - for (IndexType idx = 0; idx != wsize; ++idx) { - const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1); - const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1); - - const T2 vpvec_updated = cvec * vpvec - svec * vqvec; - const T2 vqvec_updated = svec * vpvec + cvec * vqvec; - - ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1); - ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1); + // Forwarding p/q (prefix) indices. + vpcol_out[0] = vpcol[0]; + vqcol_out[0] = vqcol[0]; + // Swapping pointers if necessary. 
+ if (p <= q) { + jacob_update_eigenvectors( + vpcol.data() + INDEX_PREFIX, vqcol.data() + INDEX_PREFIX, + vpcol_out.data() + INDEX_PREFIX, vqcol_out.data() + INDEX_PREFIX, c, + s, wstart, wend); + } else { + jacob_update_eigenvectors( + vqcol.data() + INDEX_PREFIX, vpcol.data() + INDEX_PREFIX, + vqcol_out.data() + INDEX_PREFIX, vpcol_out.data() + INDEX_PREFIX, c, + s, wstart, wend); } return true; } diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py index 1e489ee..24f1249 100644 --- a/tessellate_ipu/linalg/tile_linalg_jacobi.py +++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py @@ -7,10 +7,10 @@ import numpy as np from jax.core import ShapedArray +# import tessellate_ipu from tessellate_ipu import ( TileShardedArray, create_ipu_tile_primitive, - tile_constant_replicated, tile_constant_sharded, tile_data_barrier, tile_gather, @@ -41,11 +41,16 @@ def get_jacobi_vertex_gp_filename() -> str: jacobi_update_first_step_p = create_ipu_tile_primitive( "jacobi_update_first_step", "JacobiUpdateFirstStep", - inputs=["rotset", "pcol", "qcol"], - outputs={"cs": ShapedArray((2,), dtype=np.float32), "pcol_updated": 1, "qcol_updated": 2}, + inputs=["pcol", "qcol"], + outputs={ + "rotset_sorted": ShapedArray((2,), dtype=np.uint32), + "cs": ShapedArray((2,), dtype=np.float32), + "pcol_updated": 0, + "qcol_updated": 1, + }, constants={ "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[1].size, vector_size=2, wdtype=np.uint16 + inavals[0].size - 2, vector_size=2, wdtype=np.uint16 ) }, gp_filename=get_jacobi_vertex_gp_filename(), @@ -56,11 +61,11 @@ def get_jacobi_vertex_gp_filename() -> str: jacobi_update_second_step_p = create_ipu_tile_primitive( "jacobi_update_second_step", "JacobiUpdateSecondStep", - inputs=["cs_arr", "rotset_arr", "rotset_idx_ignored", "pcol", "qcol"], + inputs=["cs_arr", "rotset_sorted_arr", "rotset_idx_ignored", "pcol", "qcol"], outputs={"cs_arr": 0, "pcol_updated": 3, 
"qcol_updated": 4}, constants={ "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[3].size, vector_size=2, wdtype=np.uint16 + inavals[3].size - 2, vector_size=2, wdtype=np.uint16 ) }, gp_filename=get_jacobi_vertex_gp_filename(), @@ -74,7 +79,10 @@ def get_jacobi_vertex_gp_filename() -> str: outputs={"vpcol_out": 1, "vqcol_out": 2}, # Bug when inplace update? constants={ "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[1].size, vector_size=2, wdtype=np.uint16 + # Remove 2 for pq indices prefix. + inavals[1].size - 2, + vector_size=2, + wdtype=np.uint16, ) }, gp_filename=get_jacobi_vertex_gp_filename(), @@ -88,6 +96,21 @@ def jacobi_initial_rotation_set(N: int) -> NDArray[np.uint32]: return rot +def jacobi_initial_pqindices(N: int) -> Tuple[NDArray[np.uint32], NDArray[np.uint32]]: + """Jacobi initial p/q indices arrays. + Padded to (N/2, 2) for 64bits alignment. + + Returns: + A tuple of p/q indices arrays. + """ + rotset = jacobi_initial_rotation_set(N) + pindices = rotset[:, :1] + qindices = rotset[:, 1:] + pindices = np.concatenate([pindices, pindices], axis=1) + qindices = np.concatenate([qindices, qindices], axis=1) + return (pindices, qindices) + + def jacobi_next_rotation_set(rot: NDArray[np.uint32]) -> NDArray[np.uint32]: """Jacobi next rotation set (N/2, 2). @@ -112,82 +135,145 @@ def jacobi_sort_rotation_set(rotset: NDArray[np.uint32]) -> NDArray[np.uint32]: return np.stack([pindices, qindices], axis=-1) -def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtiles: Any) -> Tuple[Array, ...]: - """IPU Eigen decomposition: single iteration of the Jacobi algorithm. +def tile_rotate_columns(pcols: TileShardedArray, qcols: TileShardedArray) -> Tuple[TileShardedArray, TileShardedArray]: + """Rotate columns between tiles using a static `tile_gather`. - NOTE: the goal is to have a function which can be easily combined with `fori_loop`. 
+ We follow the Jacobi rotation patterns between tiles. In short + - moving `pcols` to the "left" + - moving `qcols` to the "right" + """ + assert pcols.shape == qcols.shape + assert pcols.tiles == qcols.tiles + halfN = pcols.shape[0] + N = halfN * 2 + # Concat all columns, in order to perform a single gather. + all_cols = TileShardedArray( + jax.lax.concatenate([pcols.array, qcols.array], dimension=0), (*pcols.tiles, *qcols.tiles) + ) + + pcols_indices = np.arange(0, halfN, dtype=np.int32) + qcols_indices = np.arange(halfN, N, dtype=np.int32) + # Rotation of columns between tiles (see Jacobi alg.) + # Roughtly: pcols move to the right, qcols to the left. + pcols_indices_new = np.concatenate([pcols_indices[0:1], qcols_indices[0:1], pcols_indices[1:-1]]) + qcols_indices_new = np.concatenate([qcols_indices[1:], pcols_indices[-1:]]) + # Move columns around! + all_indices = np.concatenate([pcols_indices_new, qcols_indices_new]) + all_cols_updated = tile_gather(all_cols, all_indices.tolist(), all_cols.tiles) + return all_cols_updated[:halfN], all_cols_updated[halfN:] + + +def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tuple[TileShardedArray, ...]: + """IPU Jacobi eigen-decomposition algorithm main body. Args: - all_AV_cols: A and V matrices p/q columns. - Atiles: A matrix tiles. - Vtiles: V matrix tiles. + idx: Loop index. + inputs: Tile sharded Apcols, Aqcols, Vpcols, Vqcols Returns: - Tuple of updated A and V matrices p/q columns. + Apcols, Aqcols, Vpcols, Vqcols after a main Jacobi update. """ - Apcols, Aqcols, Vpcols, Vqcols = all_AV_cols - N = Apcols.shape[-1] - halfN = N // 2 - # TODO: check compatibility of TileShardedArray with fori_loop - # Shard arrays across tiles. - Apcols = tile_put_sharded(Apcols, tiles=Atiles) - Aqcols = tile_put_sharded(Aqcols, tiles=Atiles) - # Initial eigenvectors (identity matrix). 
- Vpcols = tile_put_sharded(Vpcols, tiles=Vtiles) - Vqcols = tile_put_sharded(Vqcols, tiles=Vtiles) - # Constant tensor of index to ignored at every iteration. - rotset_index_ignored = tile_constant_sharded(np.arange(0, halfN, dtype=np.uint32), tiles=Atiles) - rotset = jacobi_initial_rotation_set(N) - - # All different size 2 partitions on columns. - for _ in range(1, N): - # Sorted rotation set: p < q indices. - rotset_sorted = jacobi_sort_rotation_set(rotset) - # On tile constant rotation set tensor building. - with jax.named_scope("rotset"): - rotset_replicated = tile_constant_replicated(rotset_sorted, tiles=Atiles) - rotset_sharded = tile_constant_sharded(rotset_sorted, tiles=Atiles) + Apcols, Aqcols, Vpcols, Vqcols = inputs + Atiles = Apcols.tiles + Vtiles = Vpcols.tiles + halfN = Apcols.shape[0] + + with jax.named_scope("jacobi_eigh"): + # with jax.named_scope("Apqcols_rotation"): + # Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) + # with jax.named_scope("Vpqcols_rotation"): + # Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) + # Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) + + # Sharded constant with p/q indices to ignore in second update stage. + with jax.named_scope("rotset_index_ignored"): + rotset_index_ignored = tile_constant_sharded(np.arange(0, halfN, dtype=np.uint32), tiles=Atiles) # Compute Schur decomposition + on-tile update of columns. - cs_per_tile, Apcols, Aqcols = tile_map( # type:ignore - jacobi_update_first_step_p, rotset_sharded, Apcols, Aqcols, N=N + # Note: not expecting p < q. Input pcols/qcols sorted inside the vertex. + rotset_sorted_sharded, cs_per_tile, Apcols, Aqcols = tile_map( # type:ignore + jacobi_update_first_step_p, Apcols, Aqcols ) - # Replicate Schur decomposition across all A tiles: (2*N//2) comms. + # Replicate Schur decomposition + rotset across all A tiles: (2*N//2) comms. 
+ with jax.named_scope("rotset_sorted_replicated"): + rotset_sorted_replicated = tile_put_replicated(rotset_sorted_sharded.array, tiles=Atiles) with jax.named_scope("cs_replicated_sharded"): cs_replicated = tile_put_replicated(cs_per_tile.array, tiles=Atiles) - # Just copy Schur decomposition to associated V tiles. - cs_Vtiles = tile_put_sharded(cs_per_tile.array, tiles=Vtiles) - cs_replicated, cs_Vtiles = tile_data_barrier(cs_replicated, cs_Vtiles) + # Just copy Schur decomposition to associated V tiles (no need to replicate). + cs_sharded_Vtiles = tile_put_sharded(cs_per_tile.array, tiles=Vtiles) + # Barrier to force all communications to be fused. + cs_replicated, cs_sharded_Vtiles, rotset_sorted_replicated = tile_data_barrier( + cs_replicated, cs_sharded_Vtiles, rotset_sorted_replicated + ) # Second Jacobi update step. + # Note: does not require sorting of pcols and qcols. cs_replicated, Apcols, Aqcols = tile_map( # type:ignore jacobi_update_second_step_p, cs_replicated, - rotset_replicated, + rotset_sorted_replicated, rotset_index_ignored, Apcols, Aqcols, - halfN=halfN, ) # Jacobi eigenvectors update step. Vpcols, Vqcols = tile_map( # type:ignore jacobi_update_eigenvectors_p, - cs_Vtiles, + cs_sharded_Vtiles, Vpcols, Vqcols, ) # Barrier, to make we sync. both set of tiles A and V Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) - # Move columns between tiles. 2*N commns per tile. - # NOTE: this inter-tile comm is keeping the p < q property on A and V columns. + # Move columns between tiles following Jacobi rotation pattern. 2*N commns per tile. with jax.named_scope("Apqcols_rotation"): - Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols, rotset) + Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) with jax.named_scope("Vpqcols_rotation"): - Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols, rotset) - # Next rotation set. 
- rotset = jacobi_next_rotation_set(rotset) + Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) + return Apcols, Aqcols, Vpcols, Vqcols + + +def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtiles: Any) -> Tuple[Array, ...]: + """IPU Eigen decomposition: single iteration of the Jacobi algorithm. + + NOTE: the goal is to have a function which can be easily combined with `fori_loop`. + + Args: + all_AV_cols: A and V matrices p/q columns. + Atiles: A matrix tiles. + Vtiles: V matrix tiles. + Returns: + Tuple of updated A and V matrices p/q columns. + """ + Apcols, Aqcols, Vpcols, Vqcols = all_AV_cols + N = Apcols.shape[-1] + assert N % 2 == 0 + + # p/q indices used as prefix in p/q columns. + # concatenating with the index + data is helping handling index book keeping in Jacobi algorithm iteration. + pindices, qindices = jacobi_initial_pqindices(N) + pindices_prefix = tile_constant_sharded(pindices.view(np.float32), tiles=Vtiles) + qindices_prefix = tile_constant_sharded(qindices.view(np.float32), tiles=Vtiles) + + Apcols = jax.lax.concatenate([pindices_prefix.array, Apcols], dimension=1) + Aqcols = jax.lax.concatenate([qindices_prefix.array, Aqcols], dimension=1) + + Vpcols = jax.lax.concatenate([pindices_prefix.array, Vpcols], dimension=1) + Vqcols = jax.lax.concatenate([qindices_prefix.array, Vqcols], dimension=1) + + # TODO: check compatibility of TileShardedArray with fori_loop + # Shard arrays across tiles. + Apcols = tile_put_sharded(Apcols, tiles=Atiles) + Aqcols = tile_put_sharded(Aqcols, tiles=Atiles) + # Initial eigenvectors (identity matrix). + Vpcols = tile_put_sharded(Vpcols, tiles=Vtiles) + Vqcols = tile_put_sharded(Vqcols, tiles=Vtiles) + + # Jacobi eigh iteration as a single fori_loop. + Apcols, Aqcols, Vpcols, Vqcols = jax.lax.fori_loop(1, N, ipu_jacobi_eigh_body, (Apcols, Aqcols, Vpcols, Vqcols)) - return (Apcols.array, Aqcols.array, Vpcols.array, Vqcols.array) + # Remove the p/q indices prefix. 
TODO: do it once instead of at every iteration! + return (Apcols.array[:, 2:], Aqcols.array[:, 2:], Vpcols.array[:, 2:], Vqcols.array[:, 2:]) def ipu_jacobi_eigh(x: Array, num_iters: int = 1) -> Tuple[Array, Array]: @@ -205,8 +291,9 @@ def ipu_jacobi_eigh(x: Array, num_iters: int = 1) -> Tuple[Array, Array]: assert N <= 1024 halfN = N // 2 - Atiles = tuple(range(0, halfN)) - Vtiles = tuple(range(halfN, 2 * halfN)) + tile_offset = 1 + Atiles = tuple(range(tile_offset, halfN + tile_offset)) + Vtiles = tuple(range(halfN + tile_offset, 2 * halfN + tile_offset)) # Initial "eigenvalues" matrix. Apcols = jax.lax.slice_in_dim(x, 0, N, stride=2) Aqcols = jax.lax.slice_in_dim(x, 1, N, stride=2) @@ -248,50 +335,6 @@ def permute_pq_indices( return (np.where(rotset_permute_mask, pindices, qindices), np.where(rotset_permute_mask, qindices, pindices)) -def tile_rotate_columns( - pcols: TileShardedArray, qcols: TileShardedArray, rotset: NDArray[np.uint32] -) -> Tuple[TileShardedArray, TileShardedArray]: - """Rotate columns between tiles using a static `tile_gather`. - - The tricky part of this function is to rotate the columns between tiles, but - keep the property p < q, which means taking care of the present sorting permutation applied - as well the next sorting permutation. - """ - assert pcols.shape == qcols.shape - assert pcols.tiles == qcols.tiles - halfN = pcols.shape[0] - N = halfN * 2 - # Concat all columns, in order to perform a single gather. - all_cols = TileShardedArray( - jax.lax.concatenate([pcols.array, qcols.array], dimension=0), (*pcols.tiles, *qcols.tiles) - ) - - # Start with current indices, in the concat representation of columns - pcols_indices = np.arange(0, halfN, dtype=np.int32) - qcols_indices = np.arange(halfN, N, dtype=np.int32) - # First sorting permutation correction. 
- rotset_permute_mask = rotset[:, 0] < rotset[:, 1] - pcols_indices, qcols_indices = permute_pq_indices(pcols_indices, qcols_indices, rotset_permute_mask) - - # Rotation of columns between tiles (see Jacobi alg.) - # Roughtly: pcols move to the right, qcols to the left. - pcols_indices_new = np.concatenate([pcols_indices[0:1], qcols_indices[0:1], pcols_indices[1:-1]]) - qcols_indices_new = np.concatenate([qcols_indices[1:], pcols_indices[-1:]]) - pcols_indices, qcols_indices = pcols_indices_new, qcols_indices_new - assert len(pcols_indices_new) == halfN - assert len(qcols_indices_new) == halfN - - # Second sorting permutation correction, using the next rotation set. - rotset = jacobi_next_rotation_set(rotset) - rotset_permute_mask = rotset[:, 0] < rotset[:, 1] - pcols_indices, qcols_indices = permute_pq_indices(pcols_indices, qcols_indices, rotset_permute_mask) - - # Move columns around + re-split between pcols and qcols. - all_indices = np.concatenate([pcols_indices, qcols_indices]) - all_cols_updated = tile_gather(all_cols, all_indices.tolist(), all_cols.tiles) - return all_cols_updated[:halfN], all_cols_updated[halfN:] - - def ipu_eigh( x: Array, *, lower: bool = True, symmetrize_input: bool = False, sort_eigenvalues: bool = True, num_iters: int = 1 ) -> Tuple[Array, Array]: