Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Oct 2, 2023
1 parent 45933c4 commit 232a69d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 29 deletions.
42 changes: 15 additions & 27 deletions tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,11 @@ class [[poplar::constraint(
// Input columns to rotate.
Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> vpcol; // (N,) p column
Input<Vector<T, poplar::VectorLayout::ONE_PTR, 8>> vqcol; // (N,) q column

// Indices of the two columns. Only element [0] of each is read by compute(),
// which compares them to decide the order in which the columns are passed on.
Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR>>
vpindex; // (1,) p index
Input<Vector<unsigned, poplar::VectorLayout::ONE_PTR>>
vqindex; // (1,) q index

// Per-worker work partition: compute(wid) reads entries [wid] and [wid + 1]
// as the start and end offsets of the range handled by that worker.
Input<Vector<IndexType, poplar::VectorLayout::ONE_PTR>>
worker_offsets; // (7,) threads work size + 1.

Expand All @@ -334,38 +339,21 @@ class [[poplar::constraint(
JacobiUpdateEigenvectors();

/**
 * Worker entry point: rotate this worker's slice of the p/q columns.
 *
 * Reads the rotation coefficients from `cs` (cs[0] = c, cs[1] = s) and the
 * slice bounds from `worker_offsets` ([wid] = start, [wid + 1] = end), then
 * forwards everything to `jacob_update_eigenvectors`.
 *
 * The column with the smaller index is always passed first: when p > q the
 * p and q columns (inputs and outputs alike) are swapped in the call.
 *
 * @param wid Worker index, used to look up this worker's range.
 * @return Always true.
 */
bool compute(unsigned wid) {
  const unsigned p = vpindex[0];
  const unsigned q = vqindex[0];

  const T c = cs[0];
  const T s = cs[1];
  const IndexType wstart = worker_offsets[wid];
  const IndexType wend = worker_offsets[wid + 1];

  // Exactly one rotation call per invocation; the branch only selects the
  // argument order.
  if (p <= q) {
    jacob_update_eigenvectors(vpcol.data(), vqcol.data(), vpcol_out.data(),
                              vqcol_out.data(), c, s, wstart, wend);
  } else {
    jacob_update_eigenvectors(vqcol.data(), vpcol.data(), vqcol_out.data(),
                              vpcol_out.data(), c, s, wstart, wend);
  }
  return true;
}
};
7 changes: 6 additions & 1 deletion tessellate_ipu/linalg/tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def get_jacobi_vertex_gp_filename() -> str:
jacobi_update_eigenvectors_p = create_ipu_tile_primitive(
"jacobi_update_eigenvectors",
"JacobiUpdateEigenvectors",
inputs=["cs", "vpcol", "vqcol"],
inputs=["cs", "vpcol", "vqcol", "vpindex", "vqindex"],
outputs={"vpcol_out": 1, "vqcol_out": 2}, # Bug when inplace update?
constants={
"worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets(
Expand Down Expand Up @@ -198,6 +198,9 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
cs_replicated, cs_sharded_Vtiles, rotset_sorted_replicated
)

Vpindices_sharded = tile_put_sharded(pindices_sharded.array, tiles=Vtiles)
Vqindices_sharded = tile_put_sharded(qindices_sharded.array, tiles=Vtiles)

# Second Jacobi update step.
# Note: does not require sorting of pcols and qcols.
cs_replicated, Apcols, Aqcols = tile_map( # type:ignore
Expand All @@ -215,6 +218,8 @@ def ipu_jacobi_eigh_iteration(all_AV_cols: Tuple[Array, ...], Atiles: Any, Vtile
cs_sharded_Vtiles,
Vpcols,
Vqcols,
Vpindices_sharded,
Vqindices_sharded,
)

# Barrier, to make sure we sync both sets of tiles A and V
Expand Down
2 changes: 1 addition & 1 deletion tests/linalg/test_tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test__jacobi_eigh_raw__proper_eigh_result(self):
eigvalues_sorted = eigvalues[indices]
eigvectors_sorted = VT[indices].T
npt.assert_array_almost_equal(eigvalues_sorted, expected_eigvalues, decimal=5)
# npt.assert_array_almost_equal(np.abs(eigvectors_sorted), np.abs(expected_eigvectors), decimal=5)
npt.assert_array_almost_equal(np.abs(eigvectors_sorted), np.abs(expected_eigvectors), decimal=5)

@unittest.skipUnless(ipu_num_tiles >= 16, "Requires IPU with 16 tiles")
def test__jacobi_eigh__not_sorting(self):
Expand Down

0 comments on commit 232a69d

Please sign in to comment.