diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp index 18ccde5..4967608 100644 --- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp @@ -323,6 +323,11 @@ class [[poplar::constraint( Input> vpcol; // (N,) p column Input> vqcol; // (N,) q column + Input> + vpindex; // (1,) p index + Input> + vqindex; // (1,) q index + Input> worker_offsets; // (7,) threads work size + 1. @@ -334,38 +339,21 @@ class [[poplar::constraint( JacobiUpdateEigenvectors(); bool compute(unsigned wid) { + const unsigned p = vpindex[0]; + const unsigned q = vqindex[0]; + const T c = cs[0]; const T s = cs[1]; const IndexType wstart = worker_offsets[wid]; const IndexType wend = worker_offsets[wid + 1]; - jacob_update_eigenvectors(vpcol.data(), vqcol.data(), vpcol_out.data(), - vqcol_out.data(), c, s, wstart, wend); + if (p <= q) { + jacob_update_eigenvectors(vpcol.data(), vqcol.data(), vpcol_out.data(), + vqcol_out.data(), c, s, wstart, wend); + } else { + jacob_update_eigenvectors(vqcol.data(), vpcol.data(), vqcol_out.data(), + vpcol_out.data(), c, s, wstart, wend); + } return true; - - // const T2 cvec = T2{c, c}; - // const T2 svec = T2{s, s}; - - // // Worker load: start + end vectorized indexes. - // constexpr unsigned ptr_step = 1; - // const IndexType wsize = wend - wstart; - - // // pcol, qcol and results pointers. - // const T2* ptr_pcol = reinterpret_cast(vpcol.data()) + wstart; - // const T2* ptr_qcol = reinterpret_cast(vqcol.data()) + wstart; - // T2* ptr_pcol_updated = reinterpret_cast(vpcol_out.data()) + wstart; - // T2* ptr_qcol_updated = reinterpret_cast(vqcol_out.data()) + wstart; - - // for (IndexType idx = 0; idx != wsize; ++idx) { - // const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1); - // const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1); - - // const T2 vpvec_updated = cvec * vpvec - svec * vqvec; - // const T2 vqvec_updated = svec * vpvec + cvec * vqvec; - - // ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1); - // ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1); - // } - // return true; } }; diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py index 2765de0..fab6ca3 100644 --- a/tessellate_ipu/linalg/tile_linalg_jacobi.py +++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py @@ -76,7 +76,7 @@ def get_jacobi_vertex_gp_filename() -> str: jacobi_update_eigenvectors_p = create_ipu_tile_primitive( "jacobi_update_eigenvectors", "JacobiUpdateEigenvectors", - inputs=["cs", "vpcol", "vqcol"], + inputs=["cs", "vpcol", "vqcol", "vpindex", "vqindex"], outputs={"vpcol_out": 1, "vqcol_out": 2}, # Bug when inplace update? constants={ "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( @@ -198,6 +198,9 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile cs_replicated, cs_sharded_Vtiles, rotset_sorted_replicated ) + Vpindices_sharded = tile_put_sharded(pindices_sharded.array, tiles=Vtiles) + Vqindices_sharded = tile_put_sharded(qindices_sharded.array, tiles=Vtiles) + # Second Jacobi update step. # Note: does not require sorting of pcols and qcols. cs_replicated, Apcols, Aqcols = tile_map( # type:ignore @@ -215,6 +218,8 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile cs_sharded_Vtiles, Vpcols, Vqcols, + Vpindices_sharded, + Vqindices_sharded, ) # Barrier, to make we sync. both set of tiles A and V diff --git a/tests/linalg/test_tile_linalg_jacobi.py b/tests/linalg/test_tile_linalg_jacobi.py index 90490cc..92fe4ad 100644 --- a/tests/linalg/test_tile_linalg_jacobi.py +++ b/tests/linalg/test_tile_linalg_jacobi.py @@ -197,7 +197,7 @@ def test__jacobi_eigh_raw__proper_eigh_result(self): eigvalues_sorted = eigvalues[indices] eigvectors_sorted = VT[indices].T npt.assert_array_almost_equal(eigvalues_sorted, expected_eigvalues, decimal=5) - # npt.assert_array_almost_equal(np.abs(eigvectors_sorted), np.abs(expected_eigvectors), decimal=5) + npt.assert_array_almost_equal(np.abs(eigvectors_sorted), np.abs(expected_eigvectors), decimal=5) @unittest.skipUnless(ipu_num_tiles >= 16, "Requires IPU with 16 tiles") def test__jacobi_eigh__not_sorting(self):