Skip to content

Commit

Permalink
Optimized the IPU eigh vertex JacobiUpdateEigenvectors.
Browse files Browse the repository at this point in the history
Very simple optimization, taking advantage of the previously optimized
kernel `rotation2d_f32`. 2.5x reduction in vertex cycle counts.
  • Loading branch information
balancap committed Oct 19, 2023
1 parent 0e1939e commit fbb434e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 23 deletions.
32 changes: 12 additions & 20 deletions tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,35 +363,26 @@ class JacobiUpdateSecondStep : public MultiVertex {
}
};

template <typename T>
template <class IpuTag, typename T>
void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated,
T* vqcol_updated, T c, T s,
unsigned short wstart,
unsigned short wend) noexcept {
using T2 = float2;
// Using `uint16` seems to be generating more efficient loops?
using IndexType = unsigned short;

const T2 cvec = T2{c, c};
const T2 svec = T2{s, s};
const IndexType wsize = wend - wstart;

using T2 = float2;
const T2 cs_vec = T2{c, s};

// pcol, qcol and results pointers.
const T2* ptr_pcol = reinterpret_cast<const T2*>(vpcol) + wstart;
const T2* ptr_qcol = reinterpret_cast<const T2*>(vqcol) + wstart;
T2* ptr_pcol_updated = reinterpret_cast<T2*>(vpcol_updated) + wstart;
T2* ptr_qcol_updated = reinterpret_cast<T2*>(vqcol_updated) + wstart;

for (IndexType idx = 0; idx != wsize; ++idx) {
const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1);
const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1);

const T2 vpvec_updated = cvec * vpvec - svec * vqvec;
const T2 vqvec_updated = svec * vpvec + cvec * vqvec;

ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1);
ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1);
}
// Apply Schur2 cs rotation to p/q columns (optimized kernel).
rotation2d_f32<IpuTag>(cs_vec, ptr_pcol, ptr_qcol, ptr_pcol_updated,
ptr_qcol_updated, wsize);
}

/**
Expand All @@ -400,8 +391,9 @@ void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated,
* See: Gene H. Golub, Charles F. Van Loan, MATRIX COMPUTATIONS, 3rd edition,
* Johns Hopkins Chapter 8.
*/
class [[poplar::constraint(
"elem(*vpcol) != elem(*vqcol)")]] JacobiUpdateEigenvectors
class [[poplar::constraint(
"elem(*vpcol) != elem(*vpcol_out)",
"elem(*vqcol) != elem(*vqcol_out)")]] JacobiUpdateEigenvectors
: public MultiVertex {
public:
using T = float;
Expand Down Expand Up @@ -439,12 +431,12 @@ class [[poplar::constraint(
vqcol_out[0] = vqcol[0];
// Swapping pointers if necessary.
if (p <= q) {
jacob_update_eigenvectors(
jacob_update_eigenvectors<IPU_TAG_TYPE>(
vpcol.data() + INDEX_PREFIX, vqcol.data() + INDEX_PREFIX,
vpcol_out.data() + INDEX_PREFIX, vqcol_out.data() + INDEX_PREFIX, c,
s, wstart, wend);
} else {
jacob_update_eigenvectors(
jacob_update_eigenvectors<IPU_TAG_TYPE>(
vqcol.data() + INDEX_PREFIX, vpcol.data() + INDEX_PREFIX,
vqcol_out.data() + INDEX_PREFIX, vpcol_out.data() + INDEX_PREFIX, c,
s, wstart, wend);
Expand Down
6 changes: 3 additions & 3 deletions tests/linalg/test_tile_linalg_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def jacobi_update_first_step_fn(pq, pcol, qcol):
# assert False

def test__jacobi_update_eigenvectors_vertex__benchmark_performance(self):
N = 256
N = 512
tiles = (0,)
cs = np.array([0.2, 0.5], dtype=np.float32)
pcol = np.random.randn(1, N).astype(np.float32)
Expand Down Expand Up @@ -158,13 +158,13 @@ def jacobi_update_eigenvectors_fn(cs, pcol, qcol):
# Cycle count reference for scale_add: 64(375), 128(467), 256(665), 512(1043)
start, end = np.asarray(start)[0], np.asarray(end)[0]
qr_correction_cycle_count = end[0] - start[0]
assert qr_correction_cycle_count <= 2200
assert qr_correction_cycle_count <= 1550
# print("CYCLE count:", qr_correction_cycle_count)
# assert False

@unittest.skipUnless(ipu_num_tiles >= 64, "Requires IPU with 64 tiles")
def test__jacobi_eigh__single_iteration(self):
N = 32
N = 1024
x = np.random.randn(N, N).astype(np.float32)
x = (x + x.T) / 2.0

Expand Down

0 comments on commit fbb434e

Please sign in to comment.