From 2aea6a7a164c015a26581f4fec1c708eb40d5b68 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Wed, 18 Oct 2023 10:32:36 +0100 Subject: [PATCH 1/5] Optimized Jacobi second-step vertex. (#49) The second step of the Jacobi algorithm is an usual "sparse" access update pattern, which can not be simply optimized with a `rpt` loop. The loop is intrinsically limited by the number of load/store operations. Nevertheless, by using `aop` outer product intrinsic + unrolling of 2 steps, the bundling of operations in the loop can be massively improved, leading to 40% decrease of cycle count. Additionally, this PR is adding a thin AMP C++ abstraction, allowing a simple implementation between IPU hardware and IPU model. --- tessellate_ipu/core/tile_array.py | 10 + .../core/tile_interpreter_vertex_utils.py | 35 +++- .../core/vertex/intrinsics_utils.hpp | 15 ++ tessellate_ipu/core/vertex/ipu_amp.hpp | 127 +++++++++++++ .../core/vertex/tile_jacobi_vertex.cpp | 176 ++++++++++++++---- tessellate_ipu/core/vertex/tile_small_dot.hpp | 16 +- tessellate_ipu/linalg/tile_linalg_jacobi.py | 13 +- tests/linalg/test_tile_linalg_jacobi.py | 2 +- 8 files changed, 339 insertions(+), 55 deletions(-) create mode 100644 tessellate_ipu/core/vertex/ipu_amp.hpp diff --git a/tessellate_ipu/core/tile_array.py b/tessellate_ipu/core/tile_array.py index 0979320..5fb8ab2 100644 --- a/tessellate_ipu/core/tile_array.py +++ b/tessellate_ipu/core/tile_array.py @@ -1,8 +1,10 @@ # Copyright (c) 2022 Graphcore Ltd. All rights reserved. +import itertools from dataclasses import dataclass from typing import Any, Sequence, Tuple, Union import chex +import jax.lax import numpy as np from jax.core import ShapedArray from jax.interpreters.xla import DeviceArray @@ -185,6 +187,14 @@ def __getitem__(self, key: Union[SliceType, MultiSliceType]) -> "TileShardedArra check_tile_array_multi_slice(key, self.array.shape) return TileShardedArray(array=self.array[key], tiles=self.tiles[key[0]]) # type:ignore + @classmethod + def concatenate(cls, arrays: Sequence["TileShardedArray"]) -> "TileShardedArray": + """Concatenate tile sharded arrays along the first axis.""" + assert all([isinstance(v, TileShardedArray) for v in arrays]) + outarray = jax.lax.concatenate([v.array for v in arrays], dimension=0) + outtiles = tuple(itertools.chain(*[v.tiles for v in arrays])) + return TileShardedArray(array=outarray, tiles=outtiles) + def tile_put_sharded(array: DeviceArray, tiles: Sequence[int]) -> TileShardedArray: """Shard a JAX array over tiles on the first axis. diff --git a/tessellate_ipu/core/tile_interpreter_vertex_utils.py b/tessellate_ipu/core/tile_interpreter_vertex_utils.py index 2ba3c2a..1389281 100644 --- a/tessellate_ipu/core/tile_interpreter_vertex_utils.py +++ b/tessellate_ipu/core/tile_interpreter_vertex_utils.py @@ -1,6 +1,6 @@ # Copyright (c) 2022 Graphcore Ltd. All rights reserved. import math -from typing import List +from typing import List, Optional import numpy as np from numpy.typing import DTypeLike, NDArray @@ -26,9 +26,14 @@ def make_num_elements_per_worker(N: int, num_workers: int) -> NDArray[np.int32]: def make_ipu_vector1d_worker_offsets( - size: int, vector_size: int = 2, num_workers: int = 6, wdtype: DTypeLike = np.uint16 + size: int, + vector_size: int = 2, + num_workers: int = 6, + wdtype: DTypeLike = np.uint16, + allow_overlap: bool = False, + grain_size: Optional[int] = None, ) -> NDArray[np.int_]: - """Make the QR householder row update worker sizes, i.e. how many + """Make worker sizes/offsets for a 1D array workload, i.e. how many data vectors per worker thread? Args: @@ -36,26 +41,38 @@ def make_ipu_vector1d_worker_offsets( vector_size: Vector size (2: float, 4: half). num_workers: Number of workers. wdtype: Worklists dtype. + allow_overlap: Allowing overlap between workers. Make it easier to deal with remainer term. + grain_size: Optional grain size. vector_size by default. Returns: (6,) number of data vectors per thread. """ + grain_size = grain_size or vector_size + grain_scale = grain_size // vector_size def make_offsets_fn(sizes): sizes = [0] + sizes - offsets = np.cumsum(np.array(sizes, wdtype), dtype=wdtype) + offsets = np.cumsum(np.array(sizes, wdtype) * grain_scale, dtype=wdtype) return offsets - assert size % vector_size == 0 + # TODO: support properly odd size. + assert size % 2 == 0, "Not supporting odd sizing at the moment." + # Base checks! + assert grain_size % vector_size == 0 + assert size >= grain_size, f"Requires at least a size of {grain_size}." + assert ( + size % grain_size == 0 or allow_overlap + ), f"Requires the size, {size}, divisible by the grain size {grain_size}, (or allowing overlap)." + # Base worksize on the first few workers. - base_worksize: int = math.ceil(size / (vector_size * num_workers)) - num_base_workers = size // (vector_size * base_worksize) + base_worksize: int = math.ceil(size / (grain_size * num_workers)) + num_base_workers = size // (grain_size * base_worksize) worker_sizes: List[int] = [base_worksize] * num_base_workers if num_base_workers == num_workers: return make_offsets_fn(worker_sizes) # Remainer term, for the next thread. - rem_worksize = size - base_worksize * vector_size * num_base_workers - rem_worksize = rem_worksize // vector_size + rem_worksize = size - base_worksize * grain_size * num_base_workers + rem_worksize = rem_worksize // grain_size worker_sizes += [rem_worksize] # Fill the rest with zeros. unused_workers = num_workers - num_base_workers - 1 diff --git a/tessellate_ipu/core/vertex/intrinsics_utils.hpp b/tessellate_ipu/core/vertex/intrinsics_utils.hpp index a254777..7860ac2 100644 --- a/tessellate_ipu/core/vertex/intrinsics_utils.hpp +++ b/tessellate_ipu/core/vertex/intrinsics_utils.hpp @@ -64,6 +64,7 @@ ALWAYS_INLINE T ipu_div_by_6(T n) noexcept { */ ALWAYS_INLINE void __builtin_ipu_put_tas(float v) noexcept { // TAS register, used for __builtin_ipu_f32v2axpy. + // TODO: use `__builtin_ipu_uput`? asm volatile( R"l( uput $TAS, %[sv] )l" @@ -72,6 +73,20 @@ ALWAYS_INLINE void __builtin_ipu_put_tas(float v) noexcept { :); } +/** + * @brief Zero AACC registers. + */ +ALWAYS_INLINE void __builtin_ipu_aacc_zero() { + asm (R"( + setzi $a0, 0x8 + uput $FP_CLR, $a0 + )" + : + : + : "$a0"); +} + + /** * @brief IPU cmac f32 instruction. */ diff --git a/tessellate_ipu/core/vertex/ipu_amp.hpp b/tessellate_ipu/core/vertex/ipu_amp.hpp new file mode 100644 index 0000000..0ef7b85 --- /dev/null +++ b/tessellate_ipu/core/vertex/ipu_amp.hpp @@ -0,0 +1,127 @@ +// Copyright (c) 2023 Graphcore Ltd. All rights reserved. +#pragma once +#include + +#include "intrinsics_utils.hpp" +#include "ipu_model_types.hpp" + +namespace ipu { + +/** + * @brief Thin abstraction of the IPU AMP unit(s) and registers, allowing + * to write generic code compiling on IPU model and IPU hardware. + * + * NOTE: zero-cost abstraction on IPU hardware. + * + * The AMP class is modelling AACC registers as well as AMP unit instructions + * on the IPU model, reproducing the expected behaviour of the hardware. + */ +template +class AMP { + public: + // TODO: support half as well. + static_assert(std::is_same_v); + using FPType = T; + /** Number of AACC register available in hw. */ + // TODO: use TFPU_AMP_UNITS_PER_SET and TFPU_AACC_PER_AMP_UNIT; + static constexpr unsigned NumAACC = 16; + + // TODO: random initialization on IPU model of registers. + AMP() noexcept = default; + // No copy + no move allowed! + AMP(const AMP&) = delete; + AMP(AMP&&) = delete; + + /** + * @brief Set the value of the TAS register, used in + * `axpy` operation. + */ + ALWAYS_INLINE void tas(FPType val) noexcept { +#ifdef __IPU__ + __builtin_ipu_put_tas(val); +#else + m_tas = val; +#endif + } + /** + * @brief Zero AACC registers. + */ + ALWAYS_INLINE void aaccZero() noexcept { +#ifdef __IPU__ + __builtin_ipu_aacc_zero(); +#else + for (unsigned idx = 0; idx < NumAACC; ++idx) { + m_aacc[idx] = 0; + } +#endif + } + + /** + * @brief Scaled-add `axpy` intrinsic. Only supported on FP32. + * NOTE: act as 1 stage pipeline, storing result in AACC[0...2] + */ + ALWAYS_INLINE float2 axpy(float2 x, float2 y) noexcept { + using T2 = float2; +#ifdef __IPU__ + // Weird ordering here? Bug in the intrinsic definition? + return __builtin_ipu_f32v2axpy(y, x); +#else + // Simulating pipeline with storing in AACC[0] and AACC[2]. + const auto res = T2{m_aacc[0], m_aacc[2]}; + // FIXME/TODO: understand ordering!? + m_aacc[0] = m_tas * x[0] + y[0]; + m_aacc[2] = m_tas * x[1] + y[1]; + return res; +#endif + } + + /** + * @brief Outer-product `aop` intrinsic. Only supported on FP32. + * Storing results in AACC[0...6] + */ + void aop(float2 x, float2 y) noexcept { +#ifdef __IPU__ + // Note: third argument not used by hw. + __builtin_ipu_f32v2aop(x, y, 0); +#else + // Multiply + accumulate. + m_aacc[0] += x[0] * y[0]; + m_aacc[2] += x[1] * y[0]; + m_aacc[4] += x[0] * y[1]; + m_aacc[6] += x[1] * y[1]; +#endif + } + + /** + * @brief `gina` instruction: get AACC register + propagate. + * FIXME: support non-zero flag/index. + */ + template + float2 gina(float2 val) noexcept { + using T2 = float2; +#ifdef __IPU__ + return __builtin_ipu_f32v2gina(val, 0); +#else + // TODO: implement GINA_IMMFLAGS__SET__GET + const auto res = T2{m_aacc[0], m_aacc[2]}; + // Propagate accumulator states. + for (unsigned idx = 4; idx < NumAACC; idx += 4) { + m_aacc[idx - 4] = m_aacc[idx]; + m_aacc[idx - 2] = m_aacc[idx + 2]; + } + m_aacc[NumAACC - 4] = val[0]; + m_aacc[NumAACC - 2] = val[1]; + return res; +#endif + } + + private: +#ifndef __IPU__ + // Simulating AACC registers on IPU model. + FPType m_aacc[NumAACC]; + // Simulating TAS register on IPU model. + FPType m_tas; +#endif +}; + +} // namespace ipu diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp index e30de5f..4074aaa 100644 --- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp @@ -3,6 +3,7 @@ #include #include "intrinsics_utils.hpp" +#include "ipu_amp.hpp" #include "tile_small_dot.hpp" using namespace poplar; @@ -80,10 +81,11 @@ class JacobiSymSchur2 : public Vertex { }; template -void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated, - T* qcol_updated, T* cs, unsigned p, unsigned q, - unsigned short wstart, - unsigned short wend) noexcept { +inline void jacob_update_first_step(const T* pcol, const T* qcol, + T* pcol_updated, T* qcol_updated, T* cs, + unsigned p, unsigned q, + unsigned short wstart, + unsigned short wend) noexcept { using T2 = float2; using IndexType = unsigned short; @@ -106,7 +108,7 @@ void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated, float2* ptr_qcol_updated = reinterpret_cast(qcol_updated) + wstart; // Apply Schur2 cs rotation to p/q columns (optimized kernel). rotation2d_f32(cs_vec, ptr_pcol, ptr_qcol, ptr_pcol_updated, - ptr_qcol_updated, wsize); + ptr_qcol_updated, wsize); // Update main values App, Apq, Aqq pcol_updated[p] = c * c * App - 2 * s * c * Apq + s * s * Aqq; qcol_updated[q] = s * s * App + 2 * s * c * Apq + c * c * Aqq; @@ -184,6 +186,129 @@ class [[poplar::constraint( } }; +/** + * @brief Jacobi update second step, using Schur2 coefficient from + * other pairs of columns. + */ +template +inline void jacobi_update_second_step(const unsigned* rotset_sorted_arr, + const T* cs_arr, const T* pcol, + const T* qcol, T* pcol_updated, + T* qcol_updated, unsigned wstart, + unsigned wend) noexcept { + const unsigned wsize = (wend - wstart) / 2; + // Necessary for generating `rpt` loop. + __builtin_assume(wsize < 4096); + using T2 = float2; + // Increment pointers. NOTE: unrolling creating "4x" factor. + rotset_sorted_arr += 2 * wstart; + const T2* cs_arr_ptr = reinterpret_cast(cs_arr) + wstart; + + // Basic usage of AMP unit with `aop` outer-product :) + ipu::AMP amp; + amp.aaccZero(); + + const T2 zeros{0, 0}; + T2 res, cs0, cs1, Sp0, Sq0, Sp1, Sq1, tmp0, tmp1; + unsigned k0, l0, k1, l1; + + // The loop body is roughly the following equations: + // const T Spk = pcol_ptr[k]; + // const T Spl = pcol_ptr[l]; + // const T Sqk = qcol_ptr[k]; + // const T Sql = qcol_ptr[l]; + + // pcol_updated_ptr[k] = c * Spk - s * Spl; + // pcol_updated_ptr[l] = s * Spk + c * Spl; + // qcol_updated_ptr[k] = c * Sqk - s * Sql; + // qcol_updated_ptr[l] = s * Sqk + c * Sql; + + // Problem: generate poor bundling of operations in the loop. + // Solution: unroll 2 steps + f32v2aop + manual re-ordering. + // NOTE: f32v2aop mostly useful for reducing register pressure, + // as results are stored in AACC registers (not AUX). Just saving 1 compute + // cycle. + + // Pre-loading due to unrolling + reordering. + k0 = ipu::load_postinc(&rotset_sorted_arr, 1); + l0 = ipu::load_postinc(&rotset_sorted_arr, 1); + cs0 = ipu::load_postinc(&cs_arr_ptr, 1); + Sp0 = {pcol[k0], pcol[l0]}; + for (unsigned half_idx = 0; half_idx < wsize; ++half_idx) { + // Pseudo bundling of instructions, to help popc. + { + Sq0[0] = qcol[k0]; + amp.aop(cs0, Sp0); + } + { + k1 = ipu::load_postinc(&rotset_sorted_arr, 1); + tmp0 = amp.template gina<0>(zeros); + } + { + l1 = ipu::load_postinc(&rotset_sorted_arr, 1); + tmp1 = amp.template gina<0>(zeros); + } + { + Sq0[1] = qcol[l0]; + pcol_updated[k0] = tmp0[0] - tmp1[1]; + } + { + pcol_updated[l0] = tmp0[1] + tmp1[0]; + amp.aop(cs0, Sq0); + } + { + cs1 = ipu::load_postinc(&cs_arr_ptr, 1); + tmp0 = amp.template gina<0>(zeros); + } + { + Sp1[0] = pcol[k1]; + tmp1 = amp.template gina<0>(zeros); + } + { + Sp1[1] = pcol[l1]; + qcol_updated[k0] = tmp0[0] - tmp1[1]; + } + // Unrolling: second part. + // NOTE: inputs already (partially) loaded. + { + qcol_updated[l0] = tmp0[1] + tmp1[0]; + amp.aop(cs1, Sp1); + } + { + Sq1[0] = qcol[k1]; + tmp0 = amp.template gina<0>(zeros); + } + { + Sq1[1] = qcol[l1]; + tmp1 = amp.template gina<0>(zeros); + } + { + k0 = ipu::load_postinc(&rotset_sorted_arr, 1); + pcol_updated[k1] = tmp0[0] - tmp1[1]; + } + { + pcol_updated[l1] = tmp0[1] + tmp1[0]; + amp.aop(cs1, Sq1); + } + { + l0 = ipu::load_postinc(&rotset_sorted_arr, 1); + tmp0 = amp.template gina<0>(zeros); + } + { + cs0 = ipu::load_postinc(&cs_arr_ptr, 1); + tmp1 = amp.template gina<0>(zeros); + } + { + Sp0[0] = pcol[k0]; + qcol_updated[k1] = tmp0[0] - tmp1[1]; + } + { + qcol_updated[l1] = tmp0[1] + tmp1[0]; + Sp0[1] = pcol[l0]; + } + } +} + class JacobiUpdateSecondStep : public MultiVertex { public: using T = float; @@ -213,11 +338,14 @@ class JacobiUpdateSecondStep : public MultiVertex { bool compute(unsigned wid) { // Size of the index prefix in pcol and qcol. - constexpr int INDEX_PREFIX = 2; + constexpr unsigned INDEX_PREFIX = 2; // Worker load: start + end vectorized indexes. - const IndexType wstart = worker_offsets[wid]; - const IndexType wend = worker_offsets[wid + 1]; - const IndexType wsize = wend - wstart; + const unsigned wstart = worker_offsets[wid]; + const unsigned wend = worker_offsets[wid + 1]; + + // Forward pq indices. + pcol_updated[0] = pcol[0]; + qcol_updated[0] = qcol[0]; // Use (p, q) = (1, 0) for ignore idx. const unsigned ignore_idx = 2 * rotset_idx_ignored[0]; @@ -229,33 +357,9 @@ class JacobiUpdateSecondStep : public MultiVertex { auto pcol_updated_ptr = pcol_updated.data() + INDEX_PREFIX; auto qcol_updated_ptr = qcol_updated.data() + INDEX_PREFIX; - // Forward pq indices. - pcol_updated[0] = pcol[0]; - qcol_updated[0] = qcol[0]; - - // Parallized loop on update using other columns coefficients - for (IndexType half_idx = 0; half_idx != wsize; ++half_idx) { - // TODO: cleaning pq indices offset. - const unsigned k = rotset_sorted_arr[2 * half_idx + 2 * wstart]; - const unsigned l = rotset_sorted_arr[2 * half_idx + 1 + 2 * wstart]; - - const T c = cs_arr[2 * half_idx + 2 * wstart]; - const T s = cs_arr[2 * half_idx + 1 + 2 * wstart]; - - // 4 coefficients updates! - // TODO: vectorization?! - const T Spk = pcol_ptr[k]; - const T Spl = pcol_ptr[l]; - - const T Sqk = qcol_ptr[k]; - const T Sql = qcol_ptr[l]; - - pcol_updated_ptr[k] = c * Spk - s * Spl; - pcol_updated_ptr[l] = s * Spk + c * Spl; - - qcol_updated_ptr[k] = c * Sqk - s * Sql; - qcol_updated_ptr[l] = s * Sqk + c * Sql; - } + jacobi_update_second_step(rotset_sorted_arr.data(), cs_arr.data(), pcol_ptr, + qcol_ptr, pcol_updated_ptr, qcol_updated_ptr, + wstart, wend); return true; } }; diff --git a/tessellate_ipu/core/vertex/tile_small_dot.hpp b/tessellate_ipu/core/vertex/tile_small_dot.hpp index 26bd338..0b380fd 100644 --- a/tessellate_ipu/core/vertex/tile_small_dot.hpp +++ b/tessellate_ipu/core/vertex/tile_small_dot.hpp @@ -1,5 +1,6 @@ // Copyright (c) 2023 Graphcore Ltd. All rights reserved. #include "intrinsics_utils.hpp" +#include "ipu_amp.hpp" /** * @brief z = a*x + b*y float32 implementation. @@ -35,9 +36,9 @@ inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y, // __builtin_assume(nblocks < 4096); using T2 = float2; const T2 av = {a, a}; - // Using TAS register for one of the scalar. - __ipu_and_ipumodel_tas tas; - tas.put(b); + // Basic AMP usage with TAS + axpy instruction. + ipu::AMP amp; + amp.tas(b); T2 res, xv, yv, zv, tmp; @@ -49,13 +50,11 @@ inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y, // popc should be able to generate an optimal rpt loop. { xv = ipu::load_postinc(&x, 1); - // TODO: fix ordering of arguments in `f32v2axpy`. - tmp = tas.f32v2axpy(res, yv); + tmp = amp.axpy(yv, res); } { yv = ipu::load_postinc(&y, 1); - // TODO: fix ordering of arguments in `f32v2axpy`. - zv = tas.f32v2axpy(tmp, tmp); + zv = amp.axpy(tmp, tmp); } { ipu::store_postinc(&z, zv, 1); @@ -139,7 +138,8 @@ template inline void rotation2d_f32(float2 cs, const float2 *inrow0, const float2 *inrow1, float2 *outrow0, float2 *outrow1, rptsize_t nblocks) { - // TODO: investigate using IPU AMP unit? + // axplusby is using one AMP unit. TODO: investigate using more! axplusby_f32(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks); + // NOTE: inrow1+0, outrow1 arguments order necessary due to bank constraints! axplusby_f32(cs[0], cs[1], inrow1, inrow0, outrow1, nblocks); } diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py index 5319b56..0a67f0c 100644 --- a/tessellate_ipu/linalg/tile_linalg_jacobi.py +++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py @@ -19,6 +19,7 @@ tile_put_sharded, ) from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.lax import tile_fill from tessellate_ipu.utils import NDArray Array = Any @@ -69,8 +70,10 @@ def get_jacobi_vertex_gp_filename() -> str: inputs=["cs_arr", "rotset_sorted_arr", "rotset_idx_ignored", "pcol", "qcol"], outputs={"cs_arr": 0, "pcol_updated": 3, "qcol_updated": 4}, constants={ + # NOTE: using grain_size=4 because of partial loop unrolling + # TODO: support overlap properly. "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[3].size - INDEX_PREFIX, vector_size=2, wdtype=np.uint16 + inavals[3].size - INDEX_PREFIX, vector_size=2, wdtype=np.uint16, allow_overlap=False, grain_size=4 ) }, gp_filename=get_jacobi_vertex_gp_filename(), @@ -232,6 +235,14 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu rotset_sorted_sharded, cs_per_tile, Apcols, Aqcols = tile_map( # type:ignore jacobi_update_first_step_p, Apcols, Aqcols ) + # Append zero indices to the rotset, for loop unrolling in `jacobi_update_second_step` + rotset_zeros = tile_fill((2,), 0, dtype=rotset_sorted_sharded.dtype, tiles=(0,)) + # Barrier to make sure communication gets fused into a single block. + rotset_zeros, rotset_sorted_sharded, cs_per_tile = tile_data_barrier( + rotset_zeros, rotset_sorted_sharded, cs_per_tile + ) + rotset_sorted_sharded = TileShardedArray.concatenate([rotset_sorted_sharded, rotset_zeros]) + # Replicate Schur decomposition + rotset across all A tiles: (2*N//2) comms. with jax.named_scope("rotset_sorted_replicated"): rotset_sorted_replicated = tile_put_replicated(rotset_sorted_sharded.array, tiles=Atiles) diff --git a/tests/linalg/test_tile_linalg_jacobi.py b/tests/linalg/test_tile_linalg_jacobi.py index e54dd6c..5c07e0d 100644 --- a/tests/linalg/test_tile_linalg_jacobi.py +++ b/tests/linalg/test_tile_linalg_jacobi.py @@ -178,7 +178,7 @@ def test__jacobi_eigh__single_iteration(self): @unittest.skipUnless(ipu_num_tiles >= 16, "Requires IPU with 16 tiles") def test__jacobi_eigh_raw__proper_eigh_result(self): - N = 8 + N = 12 x = np.random.randn(N, N).astype(np.float32) x = (x + x.T) / 2.0 From 0e1939e4703b59f7df4a876971eb9e14bcd867f8 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 19 Oct 2023 16:04:03 +0100 Subject: [PATCH 2/5] Fix IPU Jacobi eigh algorithm when size % 4 == 2 (#50) The recent improvement in PR #49 introduced a regression in Jacobi `eigh`, raising an error when size % 4 == 2. This is due to the partial loop unrolling in Jacobi second update stage. This PR is fixing the issue by passing explicitely the offset and size of the workload to the vertex. --- tessellate_ipu/core/__init__.py | 1 + .../core/tile_interpreter_vertex_utils.py | 65 +++++++++++++++++-- .../core/vertex/tile_jacobi_vertex.cpp | 15 ++--- tessellate_ipu/lax/tile_lax_small_dot.py | 3 +- .../linalg/tile_linalg_hessenberg.py | 2 +- tessellate_ipu/linalg/tile_linalg_jacobi.py | 9 +-- tessellate_ipu/linalg/tile_linalg_qr.py | 2 +- .../test_tile_interpreter_vertex_utils.py | 20 ++++++ tests/linalg/test_tile_linalg_jacobi.py | 7 +- 9 files changed, 99 insertions(+), 25 deletions(-) diff --git a/tessellate_ipu/core/__init__.py b/tessellate_ipu/core/__init__.py index 46fb340..9efdfed 100644 --- a/tessellate_ipu/core/__init__.py +++ b/tessellate_ipu/core/__init__.py @@ -45,6 +45,7 @@ primitive_clone, primitive_num_inout_alias_args, ) +from .tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets, make_ipu_vector1d_worker_offsets_and_sizes def tessellate_ipu_cleanup(): diff --git a/tessellate_ipu/core/tile_interpreter_vertex_utils.py b/tessellate_ipu/core/tile_interpreter_vertex_utils.py index 1389281..52e2e08 100644 --- a/tessellate_ipu/core/tile_interpreter_vertex_utils.py +++ b/tessellate_ipu/core/tile_interpreter_vertex_utils.py @@ -25,7 +25,7 @@ def make_num_elements_per_worker(N: int, num_workers: int) -> NDArray[np.int32]: return num_elements -def make_ipu_vector1d_worker_offsets( +def make_ipu_vector1d_worker_offsets_and_sizes( size: int, vector_size: int = 2, num_workers: int = 6, @@ -33,8 +33,8 @@ def make_ipu_vector1d_worker_offsets( allow_overlap: bool = False, grain_size: Optional[int] = None, ) -> NDArray[np.int_]: - """Make worker sizes/offsets for a 1D array workload, i.e. how many - data vectors per worker thread? + """Make worker sizes + offsets for a 1D array workload, i.e. how many + data vectors per worker thread (with starting offset)? Args: size: Size of the vector to divide. @@ -42,9 +42,62 @@ def make_ipu_vector1d_worker_offsets( num_workers: Number of workers. wdtype: Worklists dtype. allow_overlap: Allowing overlap between workers. Make it easier to deal with remainer term. + grain_size: Optional grain size. vector_size by default. Minimal size per thread. + Returns: + (NUM_WORKERS, 2) data offset + size per worker thread. + + NOTE: offsets and sizes expressed in vector size unit! + """ + grain_size = grain_size or vector_size + grain_scale = grain_size // vector_size + # TODO: support properly odd size. + assert size % 2 == 0, "Not supporting odd sizing at the moment." + # Base checks! + assert grain_size % vector_size == 0 + assert size >= grain_size, f"Requires at least a size of {grain_size}." + assert ( + size % grain_size == 0 or allow_overlap + ), f"Requires the size, {size}, divisible by the grain size {grain_size} (or overlap allowed)." + + # Offset+size array to build. + offset_size_arr = np.zeros((num_workers, 2), dtype=np.int32) + + # Base worksize on the first few workers. + base_worksize: int = math.ceil(size / (grain_size * num_workers)) + num_base_workers = size // (grain_size * base_worksize) + # Offsets + size + offset_size_arr[:num_base_workers, 0] = np.arange(num_base_workers) * base_worksize * grain_scale + offset_size_arr[:num_base_workers, 1] = base_worksize * grain_scale + if num_base_workers == num_workers: + return offset_size_arr.astype(wdtype) + + # Remainer term, for the next thread => all which is left, with potential overlap. + rem_worksize = size - base_worksize * grain_size * num_base_workers + rem_worksize = math.ceil(rem_worksize / grain_size) + offset_size_arr[num_base_workers, 0] = size / vector_size - rem_worksize * grain_scale + offset_size_arr[num_base_workers, 1] = rem_worksize * grain_scale + # Rest already filled with zeros... + return offset_size_arr.astype(wdtype) + + +def make_ipu_vector1d_worker_offsets( + size: int, + vector_size: int = 2, + num_workers: int = 6, + wdtype: DTypeLike = np.uint16, + grain_size: Optional[int] = None, +) -> NDArray[np.int_]: + """Make worker offsets (with additional padding) i.e. how many + data vectors per worker thread? + + Args: + size: Size of the vector to divide. + vector_size: Vector size (2: float, 4: half). + num_workers: Number of workers. + wdtype: Worklists dtype. grain_size: Optional grain size. vector_size by default. Returns: - (6,) number of data vectors per thread. + (NUM_WORKERS + 1,) data offset per worker thread. """ grain_size = grain_size or vector_size grain_scale = grain_size // vector_size @@ -59,9 +112,7 @@ def make_offsets_fn(sizes): # Base checks! assert grain_size % vector_size == 0 assert size >= grain_size, f"Requires at least a size of {grain_size}." - assert ( - size % grain_size == 0 or allow_overlap - ), f"Requires the size, {size}, divisible by the grain size {grain_size}, (or allowing overlap)." + assert size % grain_size == 0, f"Requires the size, {size}, divisible by the grain size {grain_size}." # Base worksize on the first few workers. base_worksize: int = math.ceil(size / (grain_size * num_workers)) diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp index 4074aaa..4719a2f 100644 --- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp @@ -194,9 +194,8 @@ template inline void jacobi_update_second_step(const unsigned* rotset_sorted_arr, const T* cs_arr, const T* pcol, const T* qcol, T* pcol_updated, - T* qcol_updated, unsigned wstart, - unsigned wend) noexcept { - const unsigned wsize = (wend - wstart) / 2; + T* qcol_updated, const unsigned wstart, + const unsigned wsize) noexcept { // Necessary for generating `rpt` loop. __builtin_assume(wsize < 4096); using T2 = float2; @@ -324,7 +323,7 @@ class JacobiUpdateSecondStep : public MultiVertex { rotset_idx_ignored; // (1,) index in rotset to ignore. Input> - worker_offsets; // (7,) threads work size + 1. + worker_offsets_sizes; // (2, 6) worker offset + size Input> pcol; // (N+2,) p column Input> qcol; // (N+2,) q column @@ -339,9 +338,9 @@ class JacobiUpdateSecondStep : public MultiVertex { bool compute(unsigned wid) { // Size of the index prefix in pcol and qcol. constexpr unsigned INDEX_PREFIX = 2; - // Worker load: start + end vectorized indexes. - const unsigned wstart = worker_offsets[wid]; - const unsigned wend = worker_offsets[wid + 1]; + // Worker load: start + size vectorized indexes. + const unsigned wstart = worker_offsets_sizes[2 * wid]; + const unsigned wsize = worker_offsets_sizes[2 * wid + 1]; // Forward pq indices. pcol_updated[0] = pcol[0]; @@ -359,7 +358,7 @@ class JacobiUpdateSecondStep : public MultiVertex { jacobi_update_second_step(rotset_sorted_arr.data(), cs_arr.data(), pcol_ptr, qcol_ptr, pcol_updated_ptr, qcol_updated_ptr, - wstart, wend); + wstart, wsize); return true; } }; diff --git a/tessellate_ipu/lax/tile_lax_small_dot.py b/tessellate_ipu/lax/tile_lax_small_dot.py index 0a4076a..95c3b2f 100644 --- a/tessellate_ipu/lax/tile_lax_small_dot.py +++ b/tessellate_ipu/lax/tile_lax_small_dot.py @@ -5,8 +5,7 @@ import numpy as np from jax.core import ShapedArray -from tessellate_ipu.core import declare_ipu_tile_primitive -from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.core import declare_ipu_tile_primitive, make_ipu_vector1d_worker_offsets def get_small_dot_vertex_gp_filename() -> str: diff --git a/tessellate_ipu/linalg/tile_linalg_hessenberg.py b/tessellate_ipu/linalg/tile_linalg_hessenberg.py index a9c88a6..5ffa926 100644 --- a/tessellate_ipu/linalg/tile_linalg_hessenberg.py +++ b/tessellate_ipu/linalg/tile_linalg_hessenberg.py @@ -15,7 +15,7 @@ tile_put_replicated, tile_put_sharded, ) -from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.core import make_ipu_vector1d_worker_offsets from .tile_linalg_qr import dot_product1d_p diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py index 0a67f0c..13b75a1 100644 --- a/tessellate_ipu/linalg/tile_linalg_jacobi.py +++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py @@ -18,7 +18,7 @@ tile_put_replicated, tile_put_sharded, ) -from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.core import make_ipu_vector1d_worker_offsets, make_ipu_vector1d_worker_offsets_and_sizes from tessellate_ipu.lax import tile_fill from tessellate_ipu.utils import NDArray @@ -71,10 +71,11 @@ def get_jacobi_vertex_gp_filename() -> str: outputs={"cs_arr": 0, "pcol_updated": 3, "qcol_updated": 4}, constants={ # NOTE: using grain_size=4 because of partial loop unrolling - # TODO: support overlap properly. - "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[3].size - INDEX_PREFIX, vector_size=2, wdtype=np.uint16, allow_overlap=False, grain_size=4 + # Rescale the size to be directly in grain size unit. + "worker_offsets_sizes": lambda inavals, *_: make_ipu_vector1d_worker_offsets_and_sizes( + inavals[3].size - INDEX_PREFIX, vector_size=2, grain_size=4, wdtype=np.uint16, allow_overlap=True ) + // np.array([[1, 2]], dtype=np.uint16) }, gp_filename=get_jacobi_vertex_gp_filename(), perf_estimate=200, diff --git a/tessellate_ipu/linalg/tile_linalg_qr.py b/tessellate_ipu/linalg/tile_linalg_qr.py index e8412e2..adb2030 100644 --- a/tessellate_ipu/linalg/tile_linalg_qr.py +++ b/tessellate_ipu/linalg/tile_linalg_qr.py @@ -7,7 +7,7 @@ from jax.core import ShapedArray from tessellate_ipu import TileShardedArray, create_ipu_tile_primitive, tile_map, tile_put_replicated, tile_put_sharded -from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.core import make_ipu_vector1d_worker_offsets Array = Any diff --git a/tests/core/test_tile_interpreter_vertex_utils.py b/tests/core/test_tile_interpreter_vertex_utils.py index b2700de..5d1d336 100644 --- a/tests/core/test_tile_interpreter_vertex_utils.py +++ b/tests/core/test_tile_interpreter_vertex_utils.py @@ -7,6 +7,7 @@ from tessellate_ipu.core.tile_interpreter_vertex_utils import ( make_ipu_vector1d_worker_offsets, + make_ipu_vector1d_worker_offsets_and_sizes, make_num_elements_per_worker, ) @@ -45,3 +46,22 @@ def test__tile_vertex_utils__make_num_elements_per_worker(self, N, expected_num_ num_elements = make_num_elements_per_worker(N, num_workers) assert np.sum(num_elements) == N npt.assert_array_equal(num_elements, expected_num_elements) + + @parameterized.parameters( + {"N": 4, "expected_offsets": [0, 2, 0, 0, 0, 0], "expected_sizes": [2, 0, 0, 0, 0, 0]}, + {"N": 6, "expected_offsets": [0, 1, 0, 0, 0, 0], "expected_sizes": [2, 2, 0, 0, 0, 0]}, + {"N": 24, "expected_offsets": [0, 2, 4, 6, 8, 10], "expected_sizes": [2, 2, 2, 2, 2, 2]}, + {"N": 30, "expected_offsets": [0, 4, 8, 11, 0, 0], "expected_sizes": [4, 4, 4, 4, 0, 0]}, + {"N": 128, "expected_offsets": [0, 12, 24, 36, 48, 60], "expected_sizes": [12, 12, 12, 12, 12, 4]}, + ) + def test__tile_vertex_utils__make_ipu_vector1d_worker_offsets_and_sizes(self, N, expected_offsets, expected_sizes): + vector_size = 2 + grain_size = 4 + num_workers = 6 + woffsets_sizes = make_ipu_vector1d_worker_offsets_and_sizes( + N, vector_size, num_workers=num_workers, wdtype=np.int16, grain_size=grain_size, allow_overlap=True + ) + assert woffsets_sizes.shape == (num_workers, 2) + assert woffsets_sizes.dtype == np.int16 + npt.assert_array_equal(woffsets_sizes[:, 0], expected_offsets) + npt.assert_array_equal(woffsets_sizes[:, 1], expected_sizes) diff --git a/tests/linalg/test_tile_linalg_jacobi.py b/tests/linalg/test_tile_linalg_jacobi.py index 5c07e0d..15e0e21 100644 --- a/tests/linalg/test_tile_linalg_jacobi.py +++ b/tests/linalg/test_tile_linalg_jacobi.py @@ -177,8 +177,11 @@ def test__jacobi_eigh__single_iteration(self): npt.assert_array_almost_equal(np.linalg.eigh(A)[0], np.linalg.eigh(x)[0], decimal=5) @unittest.skipUnless(ipu_num_tiles >= 16, "Requires IPU with 16 tiles") - def test__jacobi_eigh_raw__proper_eigh_result(self): - N = 12 + @parameterized.parameters( + {"N": 10}, # testing Jacobi 2nd update where grain size=4 + {"N": 12}, + ) + def test__jacobi_eigh_raw__proper_eigh_result(self, N): x = np.random.randn(N, N).astype(np.float32) x = (x + x.T) / 2.0 From b891bd93073650b530b553def71d0f55ac4a90a3 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 19 Oct 2023 16:41:02 +0100 Subject: [PATCH 3/5] Optimized the IPU eigh vertex `JacobiUpdateEigenvectors`. (#51) Very simple optimization, taking advantage of previously optimized kernel `rotation2d_f32`. 2.5 reduction on vertex cycle counts. --- .../core/vertex/tile_jacobi_vertex.cpp | 32 +++++++------------ tests/linalg/test_tile_linalg_jacobi.py | 4 +-- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp index 4719a2f..b7aced0 100644 --- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp @@ -363,35 +363,26 @@ class JacobiUpdateSecondStep : public MultiVertex { } }; -template +template void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated, T* vqcol_updated, T c, T s, unsigned short wstart, unsigned short wend) noexcept { - using T2 = float2; // Using `uint16` seems to be generating more efficient loops? using IndexType = unsigned short; - - const T2 cvec = T2{c, c}; - const T2 svec = T2{s, s}; const IndexType wsize = wend - wstart; + using T2 = float2; + const T2 cs_vec = T2{c, s}; + // pcol, qcol and results pointers. const T2* ptr_pcol = reinterpret_cast(vpcol) + wstart; const T2* ptr_qcol = reinterpret_cast(vqcol) + wstart; T2* ptr_pcol_updated = reinterpret_cast(vpcol_updated) + wstart; T2* ptr_qcol_updated = reinterpret_cast(vqcol_updated) + wstart; - - for (IndexType idx = 0; idx != wsize; ++idx) { - const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1); - const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1); - - const T2 vpvec_updated = cvec * vpvec - svec * vqvec; - const T2 vqvec_updated = svec * vpvec + cvec * vqvec; - - ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1); - ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1); - } + // Apply Schur2 cs rotation to p/q columns (optimized kernel). + rotation2d_f32(cs_vec, ptr_pcol, ptr_qcol, ptr_pcol_updated, + ptr_qcol_updated, wsize); } /** @@ -400,8 +391,9 @@ void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated, * See: Gene H. Golub, Charles F. Van Loan, MATRIX COMPUTATIONS, 3rd edition, * Johns Hopkins Chapter 8. */ -class [[poplar::constraint( - "elem(*vpcol) != elem(*vqcol)")]] JacobiUpdateEigenvectors +class [[poplar::constraint( + "elem(*vpcol) != elem(*vpcol_out)", + "elem(*vqcol) != elem(*vqcol_out)")]] JacobiUpdateEigenvectors : public MultiVertex { public: using T = float; @@ -439,12 +431,12 @@ class [[poplar::constraint( vqcol_out[0] = vqcol[0]; // Swapping pointers if necessary. if (p <= q) { - jacob_update_eigenvectors( + jacob_update_eigenvectors( vpcol.data() + INDEX_PREFIX, vqcol.data() + INDEX_PREFIX, vpcol_out.data() + INDEX_PREFIX, vqcol_out.data() + INDEX_PREFIX, c, s, wstart, wend); } else { - jacob_update_eigenvectors( + jacob_update_eigenvectors( vqcol.data() + INDEX_PREFIX, vpcol.data() + INDEX_PREFIX, vqcol_out.data() + INDEX_PREFIX, vpcol_out.data() + INDEX_PREFIX, c, s, wstart, wend); diff --git a/tests/linalg/test_tile_linalg_jacobi.py b/tests/linalg/test_tile_linalg_jacobi.py index 15e0e21..e3eb87a 100644 --- a/tests/linalg/test_tile_linalg_jacobi.py +++ b/tests/linalg/test_tile_linalg_jacobi.py @@ -128,7 +128,7 @@ def jacobi_update_first_step_fn(pq, pcol, qcol): # assert False def test__jacobi_update_eigenvectors_vertex__benchmark_performance(self): - N = 256 + N = 512 tiles = (0,) cs = np.array([0.2, 0.5], dtype=np.float32) pcol = np.random.randn(1, N).astype(np.float32) @@ -158,7 +158,7 @@ def jacobi_update_eigenvectors_fn(cs, pcol, qcol): # Cycle count reference for scale_add: 64(375), 128(467), 256(665), 512(1043) start, end = np.asarray(start)[0], np.asarray(end)[0] qr_correction_cycle_count = end[0] - start[0] - assert qr_correction_cycle_count <= 2200 + assert qr_correction_cycle_count <= 1550 # print("CYCLE count:", qr_correction_cycle_count) # assert False From acabc7d3460f8d30767a3345c8f58323006b9c20 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 19 Oct 2023 17:03:31 +0100 Subject: [PATCH 4/5] Move tile rotation to top of IPU Jacobi loop body. (#52) Allows to optimize out one on-tile-copy, saving an additional 10% of cycles. --- tessellate_ipu/linalg/tile_linalg_jacobi.py | 24 ++++++++++----------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py index 13b75a1..770330b 100644 --- a/tessellate_ipu/linalg/tile_linalg_jacobi.py +++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py @@ -221,11 +221,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu halfN = Apcols.shape[0] with jax.named_scope("jacobi_eigh"): - # with jax.named_scope("Apqcols_rotation"): - # Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) - # with jax.named_scope("Vpqcols_rotation"): - # Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) - # Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) + with jax.named_scope("Apqcols_rotation"): + Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) + with jax.named_scope("Vpqcols_rotation"): + Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) + # Barrier, to make we sync. both set of tiles A and V and force fused comms. + Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) # Sharded constant with p/q indices to ignore in second update stage. with jax.named_scope("rotset_index_ignored"): @@ -274,13 +275,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu Vqcols, ) - # Barrier, to make we sync. both set of tiles A and V - Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) - # Move columns between tiles following Jacobi rotation pattern. 2*N commns per tile. - with jax.named_scope("Apqcols_rotation"): - Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) - with jax.named_scope("Vpqcols_rotation"): - Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) + # Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) + # # Move columns between tiles following Jacobi rotation pattern. 2*N commns per tile. + # with jax.named_scope("Apqcols_rotation"): + # Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) + # with jax.named_scope("Vpqcols_rotation"): + # Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) return Apcols, Aqcols, Vpcols, Vqcols From 5f247c44c76fab744c37a22f5268a50e8b1cf605 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Fri, 20 Oct 2023 11:06:15 +0100 Subject: [PATCH 5/5] Remove TAS register + axpy instrinsic wrapper class (#53) Replacing the use with `ipu::AMP`, which is much more general and should be able to properly model any IPU AMP unit instruction. --- .../core/vertex/intrinsics_utils.hpp | 53 ++----------------- .../core/vertex/tile_hessenberg_vertex.cpp | 20 ++++--- .../core/vertex/tile_jacobi_vertex.cpp | 2 +- tessellate_ipu/core/vertex/tile_qr_vertex.cpp | 24 ++++----- tessellate_ipu/core/vertex/tile_small_dot.hpp | 7 +-- 5 files changed, 28 insertions(+), 78 deletions(-) diff --git a/tessellate_ipu/core/vertex/intrinsics_utils.hpp b/tessellate_ipu/core/vertex/intrinsics_utils.hpp index 7860ac2..f089cdd 100644 --- a/tessellate_ipu/core/vertex/intrinsics_utils.hpp +++ b/tessellate_ipu/core/vertex/intrinsics_utils.hpp @@ -112,13 +112,6 @@ ALWAYS_INLINE float ld32(const T* address, unsigned offset) { return result; } -struct __ipu_and_ipumodel_tas { - void put(float v) { __builtin_ipu_put_tas(v); } - float2 f32v2axpy(float2 const& x, float2 const& y) { - return __builtin_ipu_f32v2axpy(x, y); - } -}; - #else #include @@ -152,47 +145,7 @@ IpuVector fma(IpuVector const& x, IpuVector const& y, } // namespace ipu -// Reflect IPU's AXPY semantics in a way that is IPUModel compatible -// IPU-only usage: -// __builtin_ipu_put_tas(v); -// z_prev = __builtin_ipu_f32v2axpy(x, y) -// -// IPUModel-compatible usage: -// __ipu_and_ipumodel_tas tas; -// tas.put(v); -// z_prev = tas.f32v2axpy(x, y) -// -// https://docs.graphcore.ai/projects/poplar-api/en/latest/ipu_intrinsics/ipu_builtins.html#_CPPv423__builtin_ipu_f32v2axpy6float26float2 -struct __ipu_and_ipumodel_tas { - float tas; - float2 prev; - - __ipu_and_ipumodel_tas() : tas{0}, prev{0, 0} {} - - void put(float v) { tas = v; } - - float2 f32v2axpy(float2 const& x, float2 const& y) { - const auto res = prev; - prev = float2{ - // TODO: understand ordering!? - // tas * x[0] + y[0], - // tas * x[1] + y[1], - tas * y[0] + x[0], - tas * y[1] + x[1], - }; - return res; - } -}; - -// And give useful error messages when people port from IPU to IPUModel, e.g. -/* clang-format off */ // need these error messages on one line -/* -/workspaces/tessellate-ipu/tessellate/tile/vertex/intrinsics_utils.hpp:166:3: error: static_assert failed due to requirement '__ipu_false>()': *** Replace __builtin_ipu_f32v2axpy with __ipu_and_ipumodel_tas for TAS handling on IPUModel. - static_assert(__ipu_false(), "*** Replace __builtin_ipu_f32v2axpy with __ipu_and_ipumodel_tas for TAS handling on IPUModel."); - ^ ~~~~~~~~~~~~~~~~ -/workspaces/tessellate-ipu/tessellate/tile/vertex/tile_qr_vertex.cpp:231:12: note: in instantiation of function template specialization '__builtin_ipu_f32v2axpy>' requested here - rout = __builtin_ipu_f32v2axpy(rtmp, rtmp); -*/ +// And give useful error messages when people port from IPU to IPUModel. template constexpr bool __ipu_false() { return !std::is_same::value; @@ -200,12 +153,12 @@ constexpr bool __ipu_false() { template void __builtin_ipu_put_tas(T v) { - static_assert(__ipu_false(), "*** Replace __builtin_ipu_put_tas with __ipu_and_ipumodel_tas for TAS handling on IPUModel."); + static_assert(__ipu_false(), "*** Please use `ipu::AMP` class for TAS handling on IPUModel."); } template T __builtin_ipu_f32v2axpy(T const& x, T const& y) { - static_assert(__ipu_false(), "*** Replace __builtin_ipu_f32v2axpy with __ipu_and_ipumodel_tas for TAS handling on IPUModel."); + static_assert(__ipu_false(), "*** Please use `ipu::AMP::axpy` for `f32v2axpy` intrinsic on IPUModel."); return T{}; } // clang-format on diff --git a/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp b/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp index 08249a0..253fc53 100644 --- a/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp @@ -3,6 +3,7 @@ #include #include "intrinsics_utils.hpp" +#include "ipu_amp.hpp" using namespace poplar; @@ -162,9 +163,10 @@ class [[poplar::constraint( // Set the $TAS register with the proper scale. const T s = -scale1[0] * scale2[0]; - // __builtin_ipu_put_tas(s); - __ipu_and_ipumodel_tas tas; - tas.put(s); + // Basic AMP usage with TAS + axpy instruction. + // AMP code using this abstraction is compatible with IPU hw & model. + ipu::AMP amp; + amp.tas(s); // Nothing to do in this worker thread. if (wstart == wend) { @@ -183,20 +185,16 @@ class [[poplar::constraint( vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step); // TODO: use ld2x64pace + tapack instructions. for (IndexType idx = 1; idx != wsize; ++idx) { - rtmp = tas.f32v2axpy(xin, vin); - // rtmp = __builtin_ipu_f32v2axpy(xin, vin); + rtmp = amp.axpy(vin, xin); // Grouping here seems to help the compiler optimising loads? xin = ipu::load_postinc(&ptr_inxdata_f2, ptr_step); vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step); - rout = tas.f32v2axpy(rtmp, rtmp); - // rout = __builtin_ipu_f32v2axpy(rtmp, rtmp); + rout = amp.axpy(rtmp, rtmp); ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step); } // Finish the loop, getting the last computation. - // rtmp = __builtin_ipu_f32v2axpy(xin, vin); - // rout = __builtin_ipu_f32v2axpy(rtmp, rtmp); - rtmp = tas.f32v2axpy(xin, vin); - rout = tas.f32v2axpy(rtmp, rtmp); + rtmp = amp.axpy(vin, xin); + rout = amp.axpy(rtmp, rtmp); ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step); return true; diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp index b7aced0..47f9800 100644 --- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp @@ -391,7 +391,7 @@ void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated, * See: Gene H. Golub, Charles F. Van Loan, MATRIX COMPUTATIONS, 3rd edition, * Johns Hopkins Chapter 8. */ -class [[poplar::constraint( +class [[poplar::constraint( "elem(*vpcol) != elem(*vpcol_out)", "elem(*vqcol) != elem(*vqcol_out)")]] JacobiUpdateEigenvectors : public MultiVertex { diff --git a/tessellate_ipu/core/vertex/tile_qr_vertex.cpp b/tessellate_ipu/core/vertex/tile_qr_vertex.cpp index f6aa71a..51a1962 100644 --- a/tessellate_ipu/core/vertex/tile_qr_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_qr_vertex.cpp @@ -3,6 +3,7 @@ #include #include "intrinsics_utils.hpp" +#include "ipu_amp.hpp" using namespace poplar; @@ -165,8 +166,8 @@ float QRCorrectionVectorVertex::shared_partial_sqnorms[6] = {-1}; * NOTE: poplar::constraint here to make sure x and v are not part of the same * memory bank, allowing simultaneous loads (see `ld2x64pace` instruction). */ -class [[poplar::constraint( - "elem(*x) != elem(*v)")]] QRHouseholderRowUpdateVertex +class [ + [poplar::constraint("elem(*x) != elem(*v)")]] QRHouseholderRowUpdateVertex : public MultiVertex { public: using T = float; @@ -199,9 +200,10 @@ class [[poplar::constraint( // Set the $TAS register with the proper scale. const T s = -scale1[0] * scale2[0]; - // __builtin_ipu_put_tas(s); - __ipu_and_ipumodel_tas tas; - tas.put(s); + // Basic AMP usage with TAS + axpy instruction. + // AMP code using this abstraction is compatible with IPU hw & model. + ipu::AMP amp; + amp.tas(s); // Nothing to do in this worker thread. if (wstart == wend) { @@ -220,20 +222,16 @@ class [[poplar::constraint( vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step); // TODO: use ld2x64pace + tapack instructions. for (IndexType idx = 1; idx != wsize; ++idx) { - rtmp = tas.f32v2axpy(xin, vin); - // rtmp = __builtin_ipu_f32v2axpy(xin, vin); + rtmp = amp.axpy(vin, xin); // Grouping here seems to help the compiler optimising loads? xin = ipu::load_postinc(&ptr_inxdata_f2, ptr_step); vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step); - rout = tas.f32v2axpy(rtmp, rtmp); - // rout = __builtin_ipu_f32v2axpy(rtmp, rtmp); + rout = amp.axpy(rtmp, rtmp); ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step); } // Finish the loop, getting the last computation. - // rtmp = __builtin_ipu_f32v2axpy(xin, vin); - // rout = __builtin_ipu_f32v2axpy(rtmp, rtmp); - rtmp = tas.f32v2axpy(xin, vin); - rout = tas.f32v2axpy(rtmp, rtmp); + rtmp = amp.axpy(vin, xin); + rout = amp.axpy(rtmp, rtmp); ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step); return true; diff --git a/tessellate_ipu/core/vertex/tile_small_dot.hpp b/tessellate_ipu/core/vertex/tile_small_dot.hpp index 0b380fd..7aa6fbf 100644 --- a/tessellate_ipu/core/vertex/tile_small_dot.hpp +++ b/tessellate_ipu/core/vertex/tile_small_dot.hpp @@ -72,9 +72,10 @@ inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y, // Necessary if using unsigned `nblocks`. // __builtin_assume(nblocks < 4096); using T2 = float2; - // Using TAS register for the scalar `b`. - __ipu_and_ipumodel_tas tas; - tas.put(b); + // Basic AMP usage with TAS + axpy instruction. + ipu::AMP amp; + amp.tas(b); + T2 av = {a, a}; // Explicit variables passed to inline assembly.