From 1a8d53e1586eea68b3b2869600df2d8147528972 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Wed, 18 Oct 2023 15:59:36 +0000 Subject: [PATCH] Fix IPU Jacobi eigh algorithm when size % 4 == 2 The recent improvement in PR #49 introduced a regression in Jacobi `eigh`, raising an error when size % 4 == 2. This is due to the partial loop unrolling in Jacobi second update stage. This PR is fixing the issue by passing explicitely the offset and size of the workload to the vertex. --- tessellate_ipu/core/__init__.py | 1 + .../core/tile_interpreter_vertex_utils.py | 65 +++++++++++++++++-- .../core/vertex/tile_jacobi_vertex.cpp | 15 ++--- tessellate_ipu/lax/tile_lax_small_dot.py | 3 +- .../linalg/tile_linalg_hessenberg.py | 2 +- tessellate_ipu/linalg/tile_linalg_jacobi.py | 9 +-- tessellate_ipu/linalg/tile_linalg_qr.py | 2 +- .../test_tile_interpreter_vertex_utils.py | 20 ++++++ tests/linalg/test_tile_linalg_jacobi.py | 7 +- 9 files changed, 99 insertions(+), 25 deletions(-) diff --git a/tessellate_ipu/core/__init__.py b/tessellate_ipu/core/__init__.py index 46fb340..9efdfed 100644 --- a/tessellate_ipu/core/__init__.py +++ b/tessellate_ipu/core/__init__.py @@ -45,6 +45,7 @@ primitive_clone, primitive_num_inout_alias_args, ) +from .tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets, make_ipu_vector1d_worker_offsets_and_sizes def tessellate_ipu_cleanup(): diff --git a/tessellate_ipu/core/tile_interpreter_vertex_utils.py b/tessellate_ipu/core/tile_interpreter_vertex_utils.py index 1389281..52e2e08 100644 --- a/tessellate_ipu/core/tile_interpreter_vertex_utils.py +++ b/tessellate_ipu/core/tile_interpreter_vertex_utils.py @@ -25,7 +25,7 @@ def make_num_elements_per_worker(N: int, num_workers: int) -> NDArray[np.int32]: return num_elements -def make_ipu_vector1d_worker_offsets( +def make_ipu_vector1d_worker_offsets_and_sizes( size: int, vector_size: int = 2, num_workers: int = 6, @@ -33,8 +33,8 @@ def make_ipu_vector1d_worker_offsets( allow_overlap: bool = False, grain_size: Optional[int] = None, ) -> NDArray[np.int_]: - """Make worker sizes/offsets for a 1D array workload, i.e. how many - data vectors per worker thread? + """Make worker sizes + offsets for a 1D array workload, i.e. how many + data vectors per worker thread (with starting offset)? Args: size: Size of the vector to divide. @@ -42,9 +42,62 @@ def make_ipu_vector1d_worker_offsets( num_workers: Number of workers. wdtype: Worklists dtype. allow_overlap: Allowing overlap between workers. Make it easier to deal with remainer term. + grain_size: Optional grain size. vector_size by default. Minimal size per thread. + Returns: + (NUM_WORKERS, 2) data offset + size per worker thread. + + NOTE: offsets and sizes expressed in vector size unit! + """ + grain_size = grain_size or vector_size + grain_scale = grain_size // vector_size + # TODO: support properly odd size. + assert size % 2 == 0, "Not supporting odd sizing at the moment." + # Base checks! + assert grain_size % vector_size == 0 + assert size >= grain_size, f"Requires at least a size of {grain_size}." + assert ( + size % grain_size == 0 or allow_overlap + ), f"Requires the size, {size}, divisible by the grain size {grain_size} (or overlap allowed)." + + # Offset+size array to build. + offset_size_arr = np.zeros((num_workers, 2), dtype=np.int32) + + # Base worksize on the first few workers. + base_worksize: int = math.ceil(size / (grain_size * num_workers)) + num_base_workers = size // (grain_size * base_worksize) + # Offsets + size + offset_size_arr[:num_base_workers, 0] = np.arange(num_base_workers) * base_worksize * grain_scale + offset_size_arr[:num_base_workers, 1] = base_worksize * grain_scale + if num_base_workers == num_workers: + return offset_size_arr.astype(wdtype) + + # Remainer term, for the next thread => all which is left, with potential overlap. + rem_worksize = size - base_worksize * grain_size * num_base_workers + rem_worksize = math.ceil(rem_worksize / grain_size) + offset_size_arr[num_base_workers, 0] = size / vector_size - rem_worksize * grain_scale + offset_size_arr[num_base_workers, 1] = rem_worksize * grain_scale + # Rest already filled with zeros... + return offset_size_arr.astype(wdtype) + + +def make_ipu_vector1d_worker_offsets( + size: int, + vector_size: int = 2, + num_workers: int = 6, + wdtype: DTypeLike = np.uint16, + grain_size: Optional[int] = None, +) -> NDArray[np.int_]: + """Make worker offsets (with additional padding) i.e. how many + data vectors per worker thread? + + Args: + size: Size of the vector to divide. + vector_size: Vector size (2: float, 4: half). + num_workers: Number of workers. + wdtype: Worklists dtype. grain_size: Optional grain size. vector_size by default. Returns: - (6,) number of data vectors per thread. + (NUM_WORKERS + 1,) data offset per worker thread. """ grain_size = grain_size or vector_size grain_scale = grain_size // vector_size @@ -59,9 +112,7 @@ def make_offsets_fn(sizes): # Base checks! assert grain_size % vector_size == 0 assert size >= grain_size, f"Requires at least a size of {grain_size}." - assert ( - size % grain_size == 0 or allow_overlap - ), f"Requires the size, {size}, divisible by the grain size {grain_size}, (or allowing overlap)." + assert size % grain_size == 0, f"Requires the size, {size}, divisible by the grain size {grain_size}." # Base worksize on the first few workers. base_worksize: int = math.ceil(size / (grain_size * num_workers)) diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp index 4074aaa..4719a2f 100644 --- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp @@ -194,9 +194,8 @@ template inline void jacobi_update_second_step(const unsigned* rotset_sorted_arr, const T* cs_arr, const T* pcol, const T* qcol, T* pcol_updated, - T* qcol_updated, unsigned wstart, - unsigned wend) noexcept { - const unsigned wsize = (wend - wstart) / 2; + T* qcol_updated, const unsigned wstart, + const unsigned wsize) noexcept { // Necessary for generating `rpt` loop. __builtin_assume(wsize < 4096); using T2 = float2; @@ -324,7 +323,7 @@ class JacobiUpdateSecondStep : public MultiVertex { rotset_idx_ignored; // (1,) index in rotset to ignore. Input> - worker_offsets; // (7,) threads work size + 1. + worker_offsets_sizes; // (2, 6) worker offset + size Input> pcol; // (N+2,) p column Input> qcol; // (N+2,) q column @@ -339,9 +338,9 @@ class JacobiUpdateSecondStep : public MultiVertex { bool compute(unsigned wid) { // Size of the index prefix in pcol and qcol. constexpr unsigned INDEX_PREFIX = 2; - // Worker load: start + end vectorized indexes. - const unsigned wstart = worker_offsets[wid]; - const unsigned wend = worker_offsets[wid + 1]; + // Worker load: start + size vectorized indexes. + const unsigned wstart = worker_offsets_sizes[2 * wid]; + const unsigned wsize = worker_offsets_sizes[2 * wid + 1]; // Forward pq indices. pcol_updated[0] = pcol[0]; @@ -359,7 +358,7 @@ class JacobiUpdateSecondStep : public MultiVertex { jacobi_update_second_step(rotset_sorted_arr.data(), cs_arr.data(), pcol_ptr, qcol_ptr, pcol_updated_ptr, qcol_updated_ptr, - wstart, wend); + wstart, wsize); return true; } }; diff --git a/tessellate_ipu/lax/tile_lax_small_dot.py b/tessellate_ipu/lax/tile_lax_small_dot.py index 0a4076a..95c3b2f 100644 --- a/tessellate_ipu/lax/tile_lax_small_dot.py +++ b/tessellate_ipu/lax/tile_lax_small_dot.py @@ -5,8 +5,7 @@ import numpy as np from jax.core import ShapedArray -from tessellate_ipu.core import declare_ipu_tile_primitive -from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.core import declare_ipu_tile_primitive, make_ipu_vector1d_worker_offsets def get_small_dot_vertex_gp_filename() -> str: diff --git a/tessellate_ipu/linalg/tile_linalg_hessenberg.py b/tessellate_ipu/linalg/tile_linalg_hessenberg.py index a9c88a6..5ffa926 100644 --- a/tessellate_ipu/linalg/tile_linalg_hessenberg.py +++ b/tessellate_ipu/linalg/tile_linalg_hessenberg.py @@ -15,7 +15,7 @@ tile_put_replicated, tile_put_sharded, ) -from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.core import make_ipu_vector1d_worker_offsets from .tile_linalg_qr import dot_product1d_p diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py index 0a67f0c..13b75a1 100644 --- a/tessellate_ipu/linalg/tile_linalg_jacobi.py +++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py @@ -18,7 +18,7 @@ tile_put_replicated, tile_put_sharded, ) -from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.core import make_ipu_vector1d_worker_offsets, make_ipu_vector1d_worker_offsets_and_sizes from tessellate_ipu.lax import tile_fill from tessellate_ipu.utils import NDArray @@ -71,10 +71,11 @@ def get_jacobi_vertex_gp_filename() -> str: outputs={"cs_arr": 0, "pcol_updated": 3, "qcol_updated": 4}, constants={ # NOTE: using grain_size=4 because of partial loop unrolling - # TODO: support overlap properly. - "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[3].size - INDEX_PREFIX, vector_size=2, wdtype=np.uint16, allow_overlap=False, grain_size=4 + # Rescale the size to be directly in grain size unit. + "worker_offsets_sizes": lambda inavals, *_: make_ipu_vector1d_worker_offsets_and_sizes( + inavals[3].size - INDEX_PREFIX, vector_size=2, grain_size=4, wdtype=np.uint16, allow_overlap=True ) + // np.array([[1, 2]], dtype=np.uint16) }, gp_filename=get_jacobi_vertex_gp_filename(), perf_estimate=200, diff --git a/tessellate_ipu/linalg/tile_linalg_qr.py b/tessellate_ipu/linalg/tile_linalg_qr.py index e8412e2..adb2030 100644 --- a/tessellate_ipu/linalg/tile_linalg_qr.py +++ b/tessellate_ipu/linalg/tile_linalg_qr.py @@ -7,7 +7,7 @@ from jax.core import ShapedArray from tessellate_ipu import TileShardedArray, create_ipu_tile_primitive, tile_map, tile_put_replicated, tile_put_sharded -from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.core import make_ipu_vector1d_worker_offsets Array = Any diff --git a/tests/core/test_tile_interpreter_vertex_utils.py b/tests/core/test_tile_interpreter_vertex_utils.py index b2700de..5d1d336 100644 --- a/tests/core/test_tile_interpreter_vertex_utils.py +++ b/tests/core/test_tile_interpreter_vertex_utils.py @@ -7,6 +7,7 @@ from tessellate_ipu.core.tile_interpreter_vertex_utils import ( make_ipu_vector1d_worker_offsets, + make_ipu_vector1d_worker_offsets_and_sizes, make_num_elements_per_worker, ) @@ -45,3 +46,22 @@ def test__tile_vertex_utils__make_num_elements_per_worker(self, N, expected_num_ num_elements = make_num_elements_per_worker(N, num_workers) assert np.sum(num_elements) == N npt.assert_array_equal(num_elements, expected_num_elements) + + @parameterized.parameters( + {"N": 4, "expected_offsets": [0, 2, 0, 0, 0, 0], "expected_sizes": [2, 0, 0, 0, 0, 0]}, + {"N": 6, "expected_offsets": [0, 1, 0, 0, 0, 0], "expected_sizes": [2, 2, 0, 0, 0, 0]}, + {"N": 24, "expected_offsets": [0, 2, 4, 6, 8, 10], "expected_sizes": [2, 2, 2, 2, 2, 2]}, + {"N": 30, "expected_offsets": [0, 4, 8, 11, 0, 0], "expected_sizes": [4, 4, 4, 4, 0, 0]}, + {"N": 128, "expected_offsets": [0, 12, 24, 36, 48, 60], "expected_sizes": [12, 12, 12, 12, 12, 4]}, + ) + def test__tile_vertex_utils__make_ipu_vector1d_worker_offsets_and_sizes(self, N, expected_offsets, expected_sizes): + vector_size = 2 + grain_size = 4 + num_workers = 6 + woffsets_sizes = make_ipu_vector1d_worker_offsets_and_sizes( + N, vector_size, num_workers=num_workers, wdtype=np.int16, grain_size=grain_size, allow_overlap=True + ) + assert woffsets_sizes.shape == (num_workers, 2) + assert woffsets_sizes.dtype == np.int16 + npt.assert_array_equal(woffsets_sizes[:, 0], expected_offsets) + npt.assert_array_equal(woffsets_sizes[:, 1], expected_sizes) diff --git a/tests/linalg/test_tile_linalg_jacobi.py b/tests/linalg/test_tile_linalg_jacobi.py index 5c07e0d..15e0e21 100644 --- a/tests/linalg/test_tile_linalg_jacobi.py +++ b/tests/linalg/test_tile_linalg_jacobi.py @@ -177,8 +177,11 @@ def test__jacobi_eigh__single_iteration(self): npt.assert_array_almost_equal(np.linalg.eigh(A)[0], np.linalg.eigh(x)[0], decimal=5) @unittest.skipUnless(ipu_num_tiles >= 16, "Requires IPU with 16 tiles") - def test__jacobi_eigh_raw__proper_eigh_result(self): - N = 12 + @parameterized.parameters( + {"N": 10}, # testing Jacobi 2nd update where grain size=4 + {"N": 12}, + ) + def test__jacobi_eigh_raw__proper_eigh_result(self, N): x = np.random.randn(N, N).astype(np.float32) x = (x + x.T) / 2.0