diff --git a/tessellate_ipu/core/__init__.py b/tessellate_ipu/core/__init__.py
index 46fb340..9efdfed 100644
--- a/tessellate_ipu/core/__init__.py
+++ b/tessellate_ipu/core/__init__.py
@@ -45,6 +45,7 @@
     primitive_clone,
     primitive_num_inout_alias_args,
 )
+from .tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets, make_ipu_vector1d_worker_offsets_and_sizes
 
 
 def tessellate_ipu_cleanup():
diff --git a/tessellate_ipu/core/tile_array.py b/tessellate_ipu/core/tile_array.py
index 0979320..5fb8ab2 100644
--- a/tessellate_ipu/core/tile_array.py
+++ b/tessellate_ipu/core/tile_array.py
@@ -1,8 +1,10 @@
 # Copyright (c) 2022 Graphcore Ltd. All rights reserved.
+import itertools
 from dataclasses import dataclass
 from typing import Any, Sequence, Tuple, Union
 
 import chex
+import jax.lax
 import numpy as np
 from jax.core import ShapedArray
 from jax.interpreters.xla import DeviceArray
@@ -185,6 +187,14 @@ def __getitem__(self, key: Union[SliceType, MultiSliceType]) -> "TileShardedArra
         check_tile_array_multi_slice(key, self.array.shape)
         return TileShardedArray(array=self.array[key], tiles=self.tiles[key[0]])  # type:ignore
 
+    @classmethod
+    def concatenate(cls, arrays: Sequence["TileShardedArray"]) -> "TileShardedArray":
+        """Concatenate tile sharded arrays along the first axis."""
+        assert all([isinstance(v, TileShardedArray) for v in arrays])
+        outarray = jax.lax.concatenate([v.array for v in arrays], dimension=0)
+        outtiles = tuple(itertools.chain(*[v.tiles for v in arrays]))
+        return TileShardedArray(array=outarray, tiles=outtiles)
+
 
 def tile_put_sharded(array: DeviceArray, tiles: Sequence[int]) -> TileShardedArray:
     """Shard a JAX array over tiles on the first axis.
diff --git a/tessellate_ipu/core/tile_interpreter_vertex_utils.py b/tessellate_ipu/core/tile_interpreter_vertex_utils.py
index 2ba3c2a..52e2e08 100644
--- a/tessellate_ipu/core/tile_interpreter_vertex_utils.py
+++ b/tessellate_ipu/core/tile_interpreter_vertex_utils.py
@@ -1,6 +1,6 @@
 # Copyright (c) 2022 Graphcore Ltd. All rights reserved.
 import math
-from typing import List
+from typing import List, Optional
 
 import numpy as np
 from numpy.typing import DTypeLike, NDArray
@@ -25,10 +25,69 @@ def make_num_elements_per_worker(N: int, num_workers: int) -> NDArray[np.int32]:
     return num_elements
 
 
+def make_ipu_vector1d_worker_offsets_and_sizes(
+    size: int,
+    vector_size: int = 2,
+    num_workers: int = 6,
+    wdtype: DTypeLike = np.uint16,
+    allow_overlap: bool = False,
+    grain_size: Optional[int] = None,
+) -> NDArray[np.int_]:
+    """Make worker sizes and offsets for a 1D array workload, i.e. how many
+    data vectors per worker thread (and at which starting offset)?
+
+    Args:
+        size: Size of the vector to divide.
+        vector_size: Vector size (2: float, 4: half).
+        num_workers: Number of workers.
+        wdtype: Worklists dtype.
+        allow_overlap: Allow overlap between workers, making it easier to deal with the remainder term.
+        grain_size: Optional grain size, `vector_size` by default. Minimal work size per thread.
+    Returns:
+        (NUM_WORKERS, 2) data offset + size per worker thread.
+
+    NOTE: offsets and sizes are expressed in vector-size units!
+    """
+    grain_size = grain_size or vector_size
+    grain_scale = grain_size // vector_size
+    # TODO: properly support odd sizes.
+    assert size % 2 == 0, "Not supporting odd sizing at the moment."
+    # Base checks!
+    assert grain_size % vector_size == 0
+    assert size >= grain_size, f"Requires at least a size of {grain_size}."
+    assert (
+        size % grain_size == 0 or allow_overlap
+    ), f"Requires the size {size} to be divisible by the grain size {grain_size} (unless overlap is allowed)."
+
+    # Offset + size array to build.
+    offset_size_arr = np.zeros((num_workers, 2), dtype=np.int32)
+
+    # Base worksize on the first few workers.
+    base_worksize: int = math.ceil(size / (grain_size * num_workers))
+    num_base_workers = size // (grain_size * base_worksize)
+    # Offsets + sizes.
+    offset_size_arr[:num_base_workers, 0] = np.arange(num_base_workers) * base_worksize * grain_scale
+    offset_size_arr[:num_base_workers, 1] = base_worksize * grain_scale
+    if num_base_workers == num_workers:
+        return offset_size_arr.astype(wdtype)
+
+    # Remainder term, for the next thread: everything left over, with potential overlap.
+    rem_worksize = size - base_worksize * grain_size * num_base_workers
+    rem_worksize = math.ceil(rem_worksize / grain_size)
+    offset_size_arr[num_base_workers, 0] = size // vector_size - rem_worksize * grain_scale
+    offset_size_arr[num_base_workers, 1] = rem_worksize * grain_scale
+    # Rest already filled with zeros...
+    return offset_size_arr.astype(wdtype)
+
+
 def make_ipu_vector1d_worker_offsets(
-    size: int, vector_size: int = 2, num_workers: int = 6, wdtype: DTypeLike = np.uint16
+    size: int,
+    vector_size: int = 2,
+    num_workers: int = 6,
+    wdtype: DTypeLike = np.uint16,
+    grain_size: Optional[int] = None,
 ) -> NDArray[np.int_]:
-    """Make the QR householder row update worker sizes, i.e. how many
+    """Make worker offsets (with additional padding), i.e. how many
     data vectors per worker thread?
 
     Args:
@@ -36,26 +95,35 @@ def make_ipu_vector1d_worker_offsets(
         vector_size: Vector size (2: float, 4: half).
         num_workers: Number of workers.
         wdtype: Worklists dtype.
+        grain_size: Optional grain size, `vector_size` by default.
     Returns:
-        (6,) number of data vectors per thread.
+        (NUM_WORKERS + 1,) data offset per worker thread.
     """
+    grain_size = grain_size or vector_size
+    grain_scale = grain_size // vector_size
 
     def make_offsets_fn(sizes):
         sizes = [0] + sizes
-        offsets = np.cumsum(np.array(sizes, wdtype), dtype=wdtype)
+        offsets = np.cumsum(np.array(sizes, wdtype) * grain_scale, dtype=wdtype)
         return offsets
 
-    assert size % vector_size == 0
+    # TODO: properly support odd sizes.
+    assert size % 2 == 0, "Not supporting odd sizing at the moment."
+    # Base checks!
+    assert grain_size % vector_size == 0
+    assert size >= grain_size, f"Requires at least a size of {grain_size}."
+    assert size % grain_size == 0, f"Requires the size {size} to be divisible by the grain size {grain_size}."
+
     # Base worksize on the first few workers.
-    base_worksize: int = math.ceil(size / (vector_size * num_workers))
-    num_base_workers = size // (vector_size * base_worksize)
+    base_worksize: int = math.ceil(size / (grain_size * num_workers))
+    num_base_workers = size // (grain_size * base_worksize)
     worker_sizes: List[int] = [base_worksize] * num_base_workers
     if num_base_workers == num_workers:
        return make_offsets_fn(worker_sizes)
 
     # Remainer term, for the next thread.
-    rem_worksize = size - base_worksize * vector_size * num_base_workers
-    rem_worksize = rem_worksize // vector_size
+    rem_worksize = size - base_worksize * grain_size * num_base_workers
+    rem_worksize = rem_worksize // grain_size
     worker_sizes += [rem_worksize]
     # Fill the rest with zeros.
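A minimal usage sketch of the new `make_ipu_vector1d_worker_offsets_and_sizes` helper (the expected values come from the unit test added at the end of this diff) makes the `allow_overlap` behaviour concrete: when the size is not divisible by the grain size, the remainder worker is shifted backwards so it still processes a full grain, overlapping with the previous worker.

```python
import numpy as np
from tessellate_ipu.core import make_ipu_vector1d_worker_offsets_and_sizes

# 30 float32 elements, float2 vectors (vector_size=2), grain of 2 vectors (grain_size=4).
# Expected values match tests/core/test_tile_interpreter_vertex_utils.py in this diff.
woffsets_sizes = make_ipu_vector1d_worker_offsets_and_sizes(
    30, vector_size=2, num_workers=6, wdtype=np.int16, grain_size=4, allow_overlap=True
)
print(woffsets_sizes[:, 0])  # offsets, in float2 units: [0, 4, 8, 11, 0, 0]
print(woffsets_sizes[:, 1])  # sizes,   in float2 units: [4, 4, 4, 4, 0, 0]
# Worker 3 starts at float2 offset 11 (element 22) and covers elements 22..30,
# re-processing elements 22..24 already assigned to worker 2. Callers opting into
# `allow_overlap` must be able to tolerate this duplicated work.
```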
unused_workers = num_workers - num_base_workers - 1 diff --git a/tessellate_ipu/core/vertex/intrinsics_utils.hpp b/tessellate_ipu/core/vertex/intrinsics_utils.hpp index a254777..f089cdd 100644 --- a/tessellate_ipu/core/vertex/intrinsics_utils.hpp +++ b/tessellate_ipu/core/vertex/intrinsics_utils.hpp @@ -64,6 +64,7 @@ ALWAYS_INLINE T ipu_div_by_6(T n) noexcept { */ ALWAYS_INLINE void __builtin_ipu_put_tas(float v) noexcept { // TAS register, used for __builtin_ipu_f32v2axpy. + // TODO: use `__builtin_ipu_uput`? asm volatile( R"l( uput $TAS, %[sv] )l" @@ -72,6 +73,20 @@ ALWAYS_INLINE void __builtin_ipu_put_tas(float v) noexcept { :); } +/** + * @brief Zero AACC registers. + */ +ALWAYS_INLINE void __builtin_ipu_aacc_zero() { + asm (R"( + setzi $a0, 0x8 + uput $FP_CLR, $a0 + )" + : + : + : "$a0"); +} + + /** * @brief IPU cmac f32 instruction. */ @@ -97,13 +112,6 @@ ALWAYS_INLINE float ld32(const T* address, unsigned offset) { return result; } -struct __ipu_and_ipumodel_tas { - void put(float v) { __builtin_ipu_put_tas(v); } - float2 f32v2axpy(float2 const& x, float2 const& y) { - return __builtin_ipu_f32v2axpy(x, y); - } -}; - #else #include @@ -137,47 +145,7 @@ IpuVector fma(IpuVector const& x, IpuVector const& y, } // namespace ipu -// Reflect IPU's AXPY semantics in a way that is IPUModel compatible -// IPU-only usage: -// __builtin_ipu_put_tas(v); -// z_prev = __builtin_ipu_f32v2axpy(x, y) -// -// IPUModel-compatible usage: -// __ipu_and_ipumodel_tas tas; -// tas.put(v); -// z_prev = tas.f32v2axpy(x, y) -// -// https://docs.graphcore.ai/projects/poplar-api/en/latest/ipu_intrinsics/ipu_builtins.html#_CPPv423__builtin_ipu_f32v2axpy6float26float2 -struct __ipu_and_ipumodel_tas { - float tas; - float2 prev; - - __ipu_and_ipumodel_tas() : tas{0}, prev{0, 0} {} - - void put(float v) { tas = v; } - - float2 f32v2axpy(float2 const& x, float2 const& y) { - const auto res = prev; - prev = float2{ - // TODO: understand ordering!? - // tas * x[0] + y[0], - // tas * x[1] + y[1], - tas * y[0] + x[0], - tas * y[1] + x[1], - }; - return res; - } -}; - -// And give useful error messages when people port from IPU to IPUModel, e.g. -/* clang-format off */ // need these error messages on one line -/* -/workspaces/tessellate-ipu/tessellate/tile/vertex/intrinsics_utils.hpp:166:3: error: static_assert failed due to requirement '__ipu_false>()': *** Replace __builtin_ipu_f32v2axpy with __ipu_and_ipumodel_tas for TAS handling on IPUModel. - static_assert(__ipu_false(), "*** Replace __builtin_ipu_f32v2axpy with __ipu_and_ipumodel_tas for TAS handling on IPUModel."); - ^ ~~~~~~~~~~~~~~~~ -/workspaces/tessellate-ipu/tessellate/tile/vertex/tile_qr_vertex.cpp:231:12: note: in instantiation of function template specialization '__builtin_ipu_f32v2axpy>' requested here - rout = __builtin_ipu_f32v2axpy(rtmp, rtmp); -*/ +// And give useful error messages when people port from IPU to IPUModel. 
template constexpr bool __ipu_false() { return !std::is_same::value; @@ -185,12 +153,12 @@ constexpr bool __ipu_false() { template void __builtin_ipu_put_tas(T v) { - static_assert(__ipu_false(), "*** Replace __builtin_ipu_put_tas with __ipu_and_ipumodel_tas for TAS handling on IPUModel."); + static_assert(__ipu_false(), "*** Please use `ipu::AMP` class for TAS handling on IPUModel."); } template T __builtin_ipu_f32v2axpy(T const& x, T const& y) { - static_assert(__ipu_false(), "*** Replace __builtin_ipu_f32v2axpy with __ipu_and_ipumodel_tas for TAS handling on IPUModel."); + static_assert(__ipu_false(), "*** Please use `ipu::AMP::axpy` for `f32v2axpy` intrinsic on IPUModel."); return T{}; } // clang-format on diff --git a/tessellate_ipu/core/vertex/ipu_amp.hpp b/tessellate_ipu/core/vertex/ipu_amp.hpp new file mode 100644 index 0000000..0ef7b85 --- /dev/null +++ b/tessellate_ipu/core/vertex/ipu_amp.hpp @@ -0,0 +1,127 @@ +// Copyright (c) 2023 Graphcore Ltd. All rights reserved. +#pragma once +#include + +#include "intrinsics_utils.hpp" +#include "ipu_model_types.hpp" + +namespace ipu { + +/** + * @brief Thin abstraction of the IPU AMP unit(s) and registers, allowing + * to write generic code compiling on IPU model and IPU hardware. + * + * NOTE: zero-cost abstraction on IPU hardware. + * + * The AMP class is modelling AACC registers as well as AMP unit instructions + * on the IPU model, reproducing the expected behaviour of the hardware. + */ +template +class AMP { + public: + // TODO: support half as well. + static_assert(std::is_same_v); + using FPType = T; + /** Number of AACC register available in hw. */ + // TODO: use TFPU_AMP_UNITS_PER_SET and TFPU_AACC_PER_AMP_UNIT; + static constexpr unsigned NumAACC = 16; + + // TODO: random initialization on IPU model of registers. + AMP() noexcept = default; + // No copy + no move allowed! + AMP(const AMP&) = delete; + AMP(AMP&&) = delete; + + /** + * @brief Set the value of the TAS register, used in + * `axpy` operation. + */ + ALWAYS_INLINE void tas(FPType val) noexcept { +#ifdef __IPU__ + __builtin_ipu_put_tas(val); +#else + m_tas = val; +#endif + } + /** + * @brief Zero AACC registers. + */ + ALWAYS_INLINE void aaccZero() noexcept { +#ifdef __IPU__ + __builtin_ipu_aacc_zero(); +#else + for (unsigned idx = 0; idx < NumAACC; ++idx) { + m_aacc[idx] = 0; + } +#endif + } + + /** + * @brief Scaled-add `axpy` intrinsic. Only supported on FP32. + * NOTE: act as 1 stage pipeline, storing result in AACC[0...2] + */ + ALWAYS_INLINE float2 axpy(float2 x, float2 y) noexcept { + using T2 = float2; +#ifdef __IPU__ + // Weird ordering here? Bug in the intrinsic definition? + return __builtin_ipu_f32v2axpy(y, x); +#else + // Simulating pipeline with storing in AACC[0] and AACC[2]. + const auto res = T2{m_aacc[0], m_aacc[2]}; + // FIXME/TODO: understand ordering!? + m_aacc[0] = m_tas * x[0] + y[0]; + m_aacc[2] = m_tas * x[1] + y[1]; + return res; +#endif + } + + /** + * @brief Outer-product `aop` intrinsic. Only supported on FP32. + * Storing results in AACC[0...6] + */ + void aop(float2 x, float2 y) noexcept { +#ifdef __IPU__ + // Note: third argument not used by hw. + __builtin_ipu_f32v2aop(x, y, 0); +#else + // Multiply + accumulate. + m_aacc[0] += x[0] * y[0]; + m_aacc[2] += x[1] * y[0]; + m_aacc[4] += x[0] * y[1]; + m_aacc[6] += x[1] * y[1]; +#endif + } + + /** + * @brief `gina` instruction: get AACC register + propagate. + * FIXME: support non-zero flag/index. 
+ */ + template + float2 gina(float2 val) noexcept { + using T2 = float2; +#ifdef __IPU__ + return __builtin_ipu_f32v2gina(val, 0); +#else + // TODO: implement GINA_IMMFLAGS__SET__GET + const auto res = T2{m_aacc[0], m_aacc[2]}; + // Propagate accumulator states. + for (unsigned idx = 4; idx < NumAACC; idx += 4) { + m_aacc[idx - 4] = m_aacc[idx]; + m_aacc[idx - 2] = m_aacc[idx + 2]; + } + m_aacc[NumAACC - 4] = val[0]; + m_aacc[NumAACC - 2] = val[1]; + return res; +#endif + } + + private: +#ifndef __IPU__ + // Simulating AACC registers on IPU model. + FPType m_aacc[NumAACC]; + // Simulating TAS register on IPU model. + FPType m_tas; +#endif +}; + +} // namespace ipu diff --git a/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp b/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp index 08249a0..253fc53 100644 --- a/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_hessenberg_vertex.cpp @@ -3,6 +3,7 @@ #include #include "intrinsics_utils.hpp" +#include "ipu_amp.hpp" using namespace poplar; @@ -162,9 +163,10 @@ class [[poplar::constraint( // Set the $TAS register with the proper scale. const T s = -scale1[0] * scale2[0]; - // __builtin_ipu_put_tas(s); - __ipu_and_ipumodel_tas tas; - tas.put(s); + // Basic AMP usage with TAS + axpy instruction. + // AMP code using this abstraction is compatible with IPU hw & model. + ipu::AMP amp; + amp.tas(s); // Nothing to do in this worker thread. if (wstart == wend) { @@ -183,20 +185,16 @@ class [[poplar::constraint( vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step); // TODO: use ld2x64pace + tapack instructions. for (IndexType idx = 1; idx != wsize; ++idx) { - rtmp = tas.f32v2axpy(xin, vin); - // rtmp = __builtin_ipu_f32v2axpy(xin, vin); + rtmp = amp.axpy(vin, xin); // Grouping here seems to help the compiler optimising loads? xin = ipu::load_postinc(&ptr_inxdata_f2, ptr_step); vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step); - rout = tas.f32v2axpy(rtmp, rtmp); - // rout = __builtin_ipu_f32v2axpy(rtmp, rtmp); + rout = amp.axpy(rtmp, rtmp); ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step); } // Finish the loop, getting the last computation. 
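The `ipu::AMP` abstraction above matters mostly for its `axpy` pipeline semantics: each call returns the result of the *previous* multiply-accumulate, which is why the Householder row-update loop here (and the identical one in `tile_qr_vertex.cpp` below) issues a second `axpy(rtmp, rtmp)` to retrieve the value it just started computing. The short Python rendering below mirrors the IPU-model branch of the class (not the hardware path) and is purely illustrative; `ModelAMP` is not part of the library.

```python
import numpy as np


class ModelAMP:
    """Illustrative Python mock of the IPU-model branch of `ipu::AMP<float>`:
    `axpy` returns the previous accumulator (1-stage pipeline) and stores
    `tas * x + y` for the next call."""

    def __init__(self):
        self._tas = 0.0
        self._acc = np.zeros(2, np.float32)  # stands in for AACC[0], AACC[2]

    def tas(self, value):
        self._tas = value

    def axpy(self, x, y):
        res, self._acc = self._acc, self._tas * np.asarray(x) + np.asarray(y)
        return res


# Row-update pattern from the Hessenberg / QR vertices: out = x + s * v, on float2 pairs.
s = -0.25
x = np.arange(8, dtype=np.float32).reshape(-1, 2)
v = np.ones_like(x)

amp = ModelAMP()
amp.tas(s)
out = np.empty_like(x)
for i in range(len(x)):
    rtmp = amp.axpy(v[i], x[i])  # start s*v[i] + x[i]; returns the previous result
    rout = amp.axpy(rtmp, rtmp)  # flush the 1-stage pipeline: returns s*v[i] + x[i]
    out[i] = rout
np.testing.assert_allclose(out, x + s * v)
```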
- // rtmp = __builtin_ipu_f32v2axpy(xin, vin); - // rout = __builtin_ipu_f32v2axpy(rtmp, rtmp); - rtmp = tas.f32v2axpy(xin, vin); - rout = tas.f32v2axpy(rtmp, rtmp); + rtmp = amp.axpy(vin, xin); + rout = amp.axpy(rtmp, rtmp); ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step); return true; diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp index e30de5f..47f9800 100644 --- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp @@ -3,6 +3,7 @@ #include #include "intrinsics_utils.hpp" +#include "ipu_amp.hpp" #include "tile_small_dot.hpp" using namespace poplar; @@ -80,10 +81,11 @@ class JacobiSymSchur2 : public Vertex { }; template -void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated, - T* qcol_updated, T* cs, unsigned p, unsigned q, - unsigned short wstart, - unsigned short wend) noexcept { +inline void jacob_update_first_step(const T* pcol, const T* qcol, + T* pcol_updated, T* qcol_updated, T* cs, + unsigned p, unsigned q, + unsigned short wstart, + unsigned short wend) noexcept { using T2 = float2; using IndexType = unsigned short; @@ -106,7 +108,7 @@ void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated, float2* ptr_qcol_updated = reinterpret_cast(qcol_updated) + wstart; // Apply Schur2 cs rotation to p/q columns (optimized kernel). rotation2d_f32(cs_vec, ptr_pcol, ptr_qcol, ptr_pcol_updated, - ptr_qcol_updated, wsize); + ptr_qcol_updated, wsize); // Update main values App, Apq, Aqq pcol_updated[p] = c * c * App - 2 * s * c * Apq + s * s * Aqq; qcol_updated[q] = s * s * App + 2 * s * c * Apq + c * c * Aqq; @@ -184,6 +186,128 @@ class [[poplar::constraint( } }; +/** + * @brief Jacobi update second step, using Schur2 coefficient from + * other pairs of columns. + */ +template +inline void jacobi_update_second_step(const unsigned* rotset_sorted_arr, + const T* cs_arr, const T* pcol, + const T* qcol, T* pcol_updated, + T* qcol_updated, const unsigned wstart, + const unsigned wsize) noexcept { + // Necessary for generating `rpt` loop. + __builtin_assume(wsize < 4096); + using T2 = float2; + // Increment pointers. NOTE: unrolling creating "4x" factor. + rotset_sorted_arr += 2 * wstart; + const T2* cs_arr_ptr = reinterpret_cast(cs_arr) + wstart; + + // Basic usage of AMP unit with `aop` outer-product :) + ipu::AMP amp; + amp.aaccZero(); + + const T2 zeros{0, 0}; + T2 res, cs0, cs1, Sp0, Sq0, Sp1, Sq1, tmp0, tmp1; + unsigned k0, l0, k1, l1; + + // The loop body is roughly the following equations: + // const T Spk = pcol_ptr[k]; + // const T Spl = pcol_ptr[l]; + // const T Sqk = qcol_ptr[k]; + // const T Sql = qcol_ptr[l]; + + // pcol_updated_ptr[k] = c * Spk - s * Spl; + // pcol_updated_ptr[l] = s * Spk + c * Spl; + // qcol_updated_ptr[k] = c * Sqk - s * Sql; + // qcol_updated_ptr[l] = s * Sqk + c * Sql; + + // Problem: generate poor bundling of operations in the loop. + // Solution: unroll 2 steps + f32v2aop + manual re-ordering. + // NOTE: f32v2aop mostly useful for reducing register pressure, + // as results are stored in AACC registers (not AUX). Just saving 1 compute + // cycle. + + // Pre-loading due to unrolling + reordering. + k0 = ipu::load_postinc(&rotset_sorted_arr, 1); + l0 = ipu::load_postinc(&rotset_sorted_arr, 1); + cs0 = ipu::load_postinc(&cs_arr_ptr, 1); + Sp0 = {pcol[k0], pcol[l0]}; + for (unsigned half_idx = 0; half_idx < wsize; ++half_idx) { + // Pseudo bundling of instructions, to help popc. 
+ { + Sq0[0] = qcol[k0]; + amp.aop(cs0, Sp0); + } + { + k1 = ipu::load_postinc(&rotset_sorted_arr, 1); + tmp0 = amp.template gina<0>(zeros); + } + { + l1 = ipu::load_postinc(&rotset_sorted_arr, 1); + tmp1 = amp.template gina<0>(zeros); + } + { + Sq0[1] = qcol[l0]; + pcol_updated[k0] = tmp0[0] - tmp1[1]; + } + { + pcol_updated[l0] = tmp0[1] + tmp1[0]; + amp.aop(cs0, Sq0); + } + { + cs1 = ipu::load_postinc(&cs_arr_ptr, 1); + tmp0 = amp.template gina<0>(zeros); + } + { + Sp1[0] = pcol[k1]; + tmp1 = amp.template gina<0>(zeros); + } + { + Sp1[1] = pcol[l1]; + qcol_updated[k0] = tmp0[0] - tmp1[1]; + } + // Unrolling: second part. + // NOTE: inputs already (partially) loaded. + { + qcol_updated[l0] = tmp0[1] + tmp1[0]; + amp.aop(cs1, Sp1); + } + { + Sq1[0] = qcol[k1]; + tmp0 = amp.template gina<0>(zeros); + } + { + Sq1[1] = qcol[l1]; + tmp1 = amp.template gina<0>(zeros); + } + { + k0 = ipu::load_postinc(&rotset_sorted_arr, 1); + pcol_updated[k1] = tmp0[0] - tmp1[1]; + } + { + pcol_updated[l1] = tmp0[1] + tmp1[0]; + amp.aop(cs1, Sq1); + } + { + l0 = ipu::load_postinc(&rotset_sorted_arr, 1); + tmp0 = amp.template gina<0>(zeros); + } + { + cs0 = ipu::load_postinc(&cs_arr_ptr, 1); + tmp1 = amp.template gina<0>(zeros); + } + { + Sp0[0] = pcol[k0]; + qcol_updated[k1] = tmp0[0] - tmp1[1]; + } + { + qcol_updated[l1] = tmp0[1] + tmp1[0]; + Sp0[1] = pcol[l0]; + } + } +} + class JacobiUpdateSecondStep : public MultiVertex { public: using T = float; @@ -199,7 +323,7 @@ class JacobiUpdateSecondStep : public MultiVertex { rotset_idx_ignored; // (1,) index in rotset to ignore. Input> - worker_offsets; // (7,) threads work size + 1. + worker_offsets_sizes; // (2, 6) worker offset + size Input> pcol; // (N+2,) p column Input> qcol; // (N+2,) q column @@ -213,11 +337,14 @@ class JacobiUpdateSecondStep : public MultiVertex { bool compute(unsigned wid) { // Size of the index prefix in pcol and qcol. - constexpr int INDEX_PREFIX = 2; - // Worker load: start + end vectorized indexes. - const IndexType wstart = worker_offsets[wid]; - const IndexType wend = worker_offsets[wid + 1]; - const IndexType wsize = wend - wstart; + constexpr unsigned INDEX_PREFIX = 2; + // Worker load: start + size vectorized indexes. + const unsigned wstart = worker_offsets_sizes[2 * wid]; + const unsigned wsize = worker_offsets_sizes[2 * wid + 1]; + + // Forward pq indices. + pcol_updated[0] = pcol[0]; + qcol_updated[0] = qcol[0]; // Use (p, q) = (1, 0) for ignore idx. const unsigned ignore_idx = 2 * rotset_idx_ignored[0]; @@ -229,66 +356,33 @@ class JacobiUpdateSecondStep : public MultiVertex { auto pcol_updated_ptr = pcol_updated.data() + INDEX_PREFIX; auto qcol_updated_ptr = qcol_updated.data() + INDEX_PREFIX; - // Forward pq indices. - pcol_updated[0] = pcol[0]; - qcol_updated[0] = qcol[0]; - - // Parallized loop on update using other columns coefficients - for (IndexType half_idx = 0; half_idx != wsize; ++half_idx) { - // TODO: cleaning pq indices offset. - const unsigned k = rotset_sorted_arr[2 * half_idx + 2 * wstart]; - const unsigned l = rotset_sorted_arr[2 * half_idx + 1 + 2 * wstart]; - - const T c = cs_arr[2 * half_idx + 2 * wstart]; - const T s = cs_arr[2 * half_idx + 1 + 2 * wstart]; - - // 4 coefficients updates! - // TODO: vectorization?! 
- const T Spk = pcol_ptr[k]; - const T Spl = pcol_ptr[l]; - - const T Sqk = qcol_ptr[k]; - const T Sql = qcol_ptr[l]; - - pcol_updated_ptr[k] = c * Spk - s * Spl; - pcol_updated_ptr[l] = s * Spk + c * Spl; - - qcol_updated_ptr[k] = c * Sqk - s * Sql; - qcol_updated_ptr[l] = s * Sqk + c * Sql; - } + jacobi_update_second_step(rotset_sorted_arr.data(), cs_arr.data(), pcol_ptr, + qcol_ptr, pcol_updated_ptr, qcol_updated_ptr, + wstart, wsize); return true; } }; -template +template void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated, T* vqcol_updated, T c, T s, unsigned short wstart, unsigned short wend) noexcept { - using T2 = float2; // Using `uint16` seems to be generating more efficient loops? using IndexType = unsigned short; - - const T2 cvec = T2{c, c}; - const T2 svec = T2{s, s}; const IndexType wsize = wend - wstart; + using T2 = float2; + const T2 cs_vec = T2{c, s}; + // pcol, qcol and results pointers. const T2* ptr_pcol = reinterpret_cast(vpcol) + wstart; const T2* ptr_qcol = reinterpret_cast(vqcol) + wstart; T2* ptr_pcol_updated = reinterpret_cast(vpcol_updated) + wstart; T2* ptr_qcol_updated = reinterpret_cast(vqcol_updated) + wstart; - - for (IndexType idx = 0; idx != wsize; ++idx) { - const T2 vpvec = ipu::load_postinc(&ptr_pcol, 1); - const T2 vqvec = ipu::load_postinc(&ptr_qcol, 1); - - const T2 vpvec_updated = cvec * vpvec - svec * vqvec; - const T2 vqvec_updated = svec * vpvec + cvec * vqvec; - - ipu::store_postinc(&ptr_qcol_updated, vqvec_updated, 1); - ipu::store_postinc(&ptr_pcol_updated, vpvec_updated, 1); - } + // Apply Schur2 cs rotation to p/q columns (optimized kernel). + rotation2d_f32(cs_vec, ptr_pcol, ptr_qcol, ptr_pcol_updated, + ptr_qcol_updated, wsize); } /** @@ -298,7 +392,8 @@ void jacob_update_eigenvectors(const T* vpcol, const T* vqcol, T* vpcol_updated, * Johns Hopkins Chapter 8. */ class [[poplar::constraint( - "elem(*vpcol) != elem(*vqcol)")]] JacobiUpdateEigenvectors + "elem(*vpcol) != elem(*vpcol_out)", + "elem(*vqcol) != elem(*vqcol_out)")]] JacobiUpdateEigenvectors : public MultiVertex { public: using T = float; @@ -336,12 +431,12 @@ class [[poplar::constraint( vqcol_out[0] = vqcol[0]; // Swapping pointers if necessary. if (p <= q) { - jacob_update_eigenvectors( + jacob_update_eigenvectors( vpcol.data() + INDEX_PREFIX, vqcol.data() + INDEX_PREFIX, vpcol_out.data() + INDEX_PREFIX, vqcol_out.data() + INDEX_PREFIX, c, s, wstart, wend); } else { - jacob_update_eigenvectors( + jacob_update_eigenvectors( vqcol.data() + INDEX_PREFIX, vpcol.data() + INDEX_PREFIX, vqcol_out.data() + INDEX_PREFIX, vpcol_out.data() + INDEX_PREFIX, c, s, wstart, wend); diff --git a/tessellate_ipu/core/vertex/tile_qr_vertex.cpp b/tessellate_ipu/core/vertex/tile_qr_vertex.cpp index f6aa71a..51a1962 100644 --- a/tessellate_ipu/core/vertex/tile_qr_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_qr_vertex.cpp @@ -3,6 +3,7 @@ #include #include "intrinsics_utils.hpp" +#include "ipu_amp.hpp" using namespace poplar; @@ -165,8 +166,8 @@ float QRCorrectionVectorVertex::shared_partial_sqnorms[6] = {-1}; * NOTE: poplar::constraint here to make sure x and v are not part of the same * memory bank, allowing simultaneous loads (see `ld2x64pace` instruction). 
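For the unrolled `jacobi_update_second_step` above: the scalar rotation update being removed (`c*Spk - s*Spl`, `s*Spk + c*Spl`, and likewise for the q column) is re-expressed as one `aop` outer-product followed by two `gina<0>` reads. The numpy sketch below mirrors the IPU-model semantics of `aop`/`gina` from `ipu_amp.hpp` and checks that the recombination used in the vertex (`tmp0[0] - tmp1[1]`, `tmp0[1] + tmp1[0]`) reproduces the original formulas; it is a model of the documented behaviour, not IPU code.

```python
import numpy as np


def rotate_with_aop_gina(c, s, Spk, Spl):
    """Mimic one `aop({c,s}, {Spk,Spl})` followed by two `gina<0>(zeros)` reads,
    following the IPU-model AACC semantics of `ipu::AMP`."""
    aacc = np.zeros(16, dtype=np.float32)
    # aop: 2x2 outer product accumulated into AACC[0, 2, 4, 6].
    aacc[0] += c * Spk
    aacc[2] += s * Spk
    aacc[4] += c * Spl
    aacc[6] += s * Spl

    def gina(val=(0.0, 0.0)):
        # Return AACC[0], AACC[2], then shift the accumulator pipeline.
        res = (float(aacc[0]), float(aacc[2]))
        for idx in range(4, 16, 4):
            aacc[idx - 4] = aacc[idx]
            aacc[idx - 2] = aacc[idx + 2]
        aacc[12], aacc[14] = val
        return res

    tmp0 = gina()  # (c*Spk, s*Spk)
    tmp1 = gina()  # (c*Spl, s*Spl)
    # Recombination used in jacobi_update_second_step.
    return tmp0[0] - tmp1[1], tmp0[1] + tmp1[0]


c, s = np.cos(0.3), np.sin(0.3)
Spk, Spl = 1.5, -2.0
upd_k, upd_l = rotate_with_aop_gina(c, s, Spk, Spl)
np.testing.assert_allclose(upd_k, c * Spk - s * Spl, rtol=1e-6)  # pcol_updated[k]
np.testing.assert_allclose(upd_l, s * Spk + c * Spl, rtol=1e-6)  # pcol_updated[l]
```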
*/ -class [[poplar::constraint( - "elem(*x) != elem(*v)")]] QRHouseholderRowUpdateVertex +class [ + [poplar::constraint("elem(*x) != elem(*v)")]] QRHouseholderRowUpdateVertex : public MultiVertex { public: using T = float; @@ -199,9 +200,10 @@ class [[poplar::constraint( // Set the $TAS register with the proper scale. const T s = -scale1[0] * scale2[0]; - // __builtin_ipu_put_tas(s); - __ipu_and_ipumodel_tas tas; - tas.put(s); + // Basic AMP usage with TAS + axpy instruction. + // AMP code using this abstraction is compatible with IPU hw & model. + ipu::AMP amp; + amp.tas(s); // Nothing to do in this worker thread. if (wstart == wend) { @@ -220,20 +222,16 @@ class [[poplar::constraint( vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step); // TODO: use ld2x64pace + tapack instructions. for (IndexType idx = 1; idx != wsize; ++idx) { - rtmp = tas.f32v2axpy(xin, vin); - // rtmp = __builtin_ipu_f32v2axpy(xin, vin); + rtmp = amp.axpy(vin, xin); // Grouping here seems to help the compiler optimising loads? xin = ipu::load_postinc(&ptr_inxdata_f2, ptr_step); vin = ipu::load_postinc(&ptr_vdata_f2, ptr_step); - rout = tas.f32v2axpy(rtmp, rtmp); - // rout = __builtin_ipu_f32v2axpy(rtmp, rtmp); + rout = amp.axpy(rtmp, rtmp); ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step); } // Finish the loop, getting the last computation. - // rtmp = __builtin_ipu_f32v2axpy(xin, vin); - // rout = __builtin_ipu_f32v2axpy(rtmp, rtmp); - rtmp = tas.f32v2axpy(xin, vin); - rout = tas.f32v2axpy(rtmp, rtmp); + rtmp = amp.axpy(vin, xin); + rout = amp.axpy(rtmp, rtmp); ipu::store_postinc(&ptr_outxdata_f2, rout, ptr_step); return true; diff --git a/tessellate_ipu/core/vertex/tile_small_dot.hpp b/tessellate_ipu/core/vertex/tile_small_dot.hpp index 26bd338..7aa6fbf 100644 --- a/tessellate_ipu/core/vertex/tile_small_dot.hpp +++ b/tessellate_ipu/core/vertex/tile_small_dot.hpp @@ -1,5 +1,6 @@ // Copyright (c) 2023 Graphcore Ltd. All rights reserved. #include "intrinsics_utils.hpp" +#include "ipu_amp.hpp" /** * @brief z = a*x + b*y float32 implementation. @@ -35,9 +36,9 @@ inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y, // __builtin_assume(nblocks < 4096); using T2 = float2; const T2 av = {a, a}; - // Using TAS register for one of the scalar. - __ipu_and_ipumodel_tas tas; - tas.put(b); + // Basic AMP usage with TAS + axpy instruction. + ipu::AMP amp; + amp.tas(b); T2 res, xv, yv, zv, tmp; @@ -49,13 +50,11 @@ inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y, // popc should be able to generate an optimal rpt loop. { xv = ipu::load_postinc(&x, 1); - // TODO: fix ordering of arguments in `f32v2axpy`. - tmp = tas.f32v2axpy(res, yv); + tmp = amp.axpy(yv, res); } { yv = ipu::load_postinc(&y, 1); - // TODO: fix ordering of arguments in `f32v2axpy`. - zv = tas.f32v2axpy(tmp, tmp); + zv = amp.axpy(tmp, tmp); } { ipu::store_postinc(&z, zv, 1); @@ -73,9 +72,10 @@ inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y, // Necessary if using unsigned `nblocks`. // __builtin_assume(nblocks < 4096); using T2 = float2; - // Using TAS register for the scalar `b`. - __ipu_and_ipumodel_tas tas; - tas.put(b); + // Basic AMP usage with TAS + axpy instruction. + ipu::AMP amp; + amp.tas(b); + T2 av = {a, a}; // Explicit variables passed to inline assembly. @@ -139,7 +139,8 @@ template inline void rotation2d_f32(float2 cs, const float2 *inrow0, const float2 *inrow1, float2 *outrow0, float2 *outrow1, rptsize_t nblocks) { - // TODO: investigate using IPU AMP unit? 
+ // axplusby is using one AMP unit. TODO: investigate using more! axplusby_f32(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks); + // NOTE: inrow1+0, outrow1 arguments order necessary due to bank constraints! axplusby_f32(cs[0], cs[1], inrow1, inrow0, outrow1, nblocks); } diff --git a/tessellate_ipu/lax/tile_lax_small_dot.py b/tessellate_ipu/lax/tile_lax_small_dot.py index 0a4076a..95c3b2f 100644 --- a/tessellate_ipu/lax/tile_lax_small_dot.py +++ b/tessellate_ipu/lax/tile_lax_small_dot.py @@ -5,8 +5,7 @@ import numpy as np from jax.core import ShapedArray -from tessellate_ipu.core import declare_ipu_tile_primitive -from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.core import declare_ipu_tile_primitive, make_ipu_vector1d_worker_offsets def get_small_dot_vertex_gp_filename() -> str: diff --git a/tessellate_ipu/linalg/tile_linalg_hessenberg.py b/tessellate_ipu/linalg/tile_linalg_hessenberg.py index a9c88a6..5ffa926 100644 --- a/tessellate_ipu/linalg/tile_linalg_hessenberg.py +++ b/tessellate_ipu/linalg/tile_linalg_hessenberg.py @@ -15,7 +15,7 @@ tile_put_replicated, tile_put_sharded, ) -from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.core import make_ipu_vector1d_worker_offsets from .tile_linalg_qr import dot_product1d_p diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py index 5319b56..770330b 100644 --- a/tessellate_ipu/linalg/tile_linalg_jacobi.py +++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py @@ -18,7 +18,8 @@ tile_put_replicated, tile_put_sharded, ) -from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.core import make_ipu_vector1d_worker_offsets, make_ipu_vector1d_worker_offsets_and_sizes +from tessellate_ipu.lax import tile_fill from tessellate_ipu.utils import NDArray Array = Any @@ -69,9 +70,12 @@ def get_jacobi_vertex_gp_filename() -> str: inputs=["cs_arr", "rotset_sorted_arr", "rotset_idx_ignored", "pcol", "qcol"], outputs={"cs_arr": 0, "pcol_updated": 3, "qcol_updated": 4}, constants={ - "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[3].size - INDEX_PREFIX, vector_size=2, wdtype=np.uint16 + # NOTE: using grain_size=4 because of partial loop unrolling + # Rescale the size to be directly in grain size unit. + "worker_offsets_sizes": lambda inavals, *_: make_ipu_vector1d_worker_offsets_and_sizes( + inavals[3].size - INDEX_PREFIX, vector_size=2, grain_size=4, wdtype=np.uint16, allow_overlap=True ) + // np.array([[1, 2]], dtype=np.uint16) }, gp_filename=get_jacobi_vertex_gp_filename(), perf_estimate=200, @@ -217,11 +221,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu halfN = Apcols.shape[0] with jax.named_scope("jacobi_eigh"): - # with jax.named_scope("Apqcols_rotation"): - # Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) - # with jax.named_scope("Vpqcols_rotation"): - # Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) - # Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) + with jax.named_scope("Apqcols_rotation"): + Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) + with jax.named_scope("Vpqcols_rotation"): + Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) + # Barrier, to make we sync. both set of tiles A and V and force fused comms. 
+ Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) # Sharded constant with p/q indices to ignore in second update stage. with jax.named_scope("rotset_index_ignored"): @@ -232,6 +237,14 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu rotset_sorted_sharded, cs_per_tile, Apcols, Aqcols = tile_map( # type:ignore jacobi_update_first_step_p, Apcols, Aqcols ) + # Append zero indices to the rotset, for loop unrolling in `jacobi_update_second_step` + rotset_zeros = tile_fill((2,), 0, dtype=rotset_sorted_sharded.dtype, tiles=(0,)) + # Barrier to make sure communication gets fused into a single block. + rotset_zeros, rotset_sorted_sharded, cs_per_tile = tile_data_barrier( + rotset_zeros, rotset_sorted_sharded, cs_per_tile + ) + rotset_sorted_sharded = TileShardedArray.concatenate([rotset_sorted_sharded, rotset_zeros]) + # Replicate Schur decomposition + rotset across all A tiles: (2*N//2) comms. with jax.named_scope("rotset_sorted_replicated"): rotset_sorted_replicated = tile_put_replicated(rotset_sorted_sharded.array, tiles=Atiles) @@ -262,13 +275,12 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu Vqcols, ) - # Barrier, to make we sync. both set of tiles A and V - Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) - # Move columns between tiles following Jacobi rotation pattern. 2*N commns per tile. - with jax.named_scope("Apqcols_rotation"): - Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) - with jax.named_scope("Vpqcols_rotation"): - Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) + # Apcols, Aqcols, Vpcols, Vqcols = tile_data_barrier(Apcols, Aqcols, Vpcols, Vqcols) + # # Move columns between tiles following Jacobi rotation pattern. 2*N commns per tile. 
+ # with jax.named_scope("Apqcols_rotation"): + # Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols) + # with jax.named_scope("Vpqcols_rotation"): + # Vpcols, Vqcols = tile_rotate_columns(Vpcols, Vqcols) return Apcols, Aqcols, Vpcols, Vqcols diff --git a/tessellate_ipu/linalg/tile_linalg_qr.py b/tessellate_ipu/linalg/tile_linalg_qr.py index e8412e2..adb2030 100644 --- a/tessellate_ipu/linalg/tile_linalg_qr.py +++ b/tessellate_ipu/linalg/tile_linalg_qr.py @@ -7,7 +7,7 @@ from jax.core import ShapedArray from tessellate_ipu import TileShardedArray, create_ipu_tile_primitive, tile_map, tile_put_replicated, tile_put_sharded -from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.core import make_ipu_vector1d_worker_offsets Array = Any diff --git a/tests/core/test_tile_interpreter_vertex_utils.py b/tests/core/test_tile_interpreter_vertex_utils.py index b2700de..5d1d336 100644 --- a/tests/core/test_tile_interpreter_vertex_utils.py +++ b/tests/core/test_tile_interpreter_vertex_utils.py @@ -7,6 +7,7 @@ from tessellate_ipu.core.tile_interpreter_vertex_utils import ( make_ipu_vector1d_worker_offsets, + make_ipu_vector1d_worker_offsets_and_sizes, make_num_elements_per_worker, ) @@ -45,3 +46,22 @@ def test__tile_vertex_utils__make_num_elements_per_worker(self, N, expected_num_ num_elements = make_num_elements_per_worker(N, num_workers) assert np.sum(num_elements) == N npt.assert_array_equal(num_elements, expected_num_elements) + + @parameterized.parameters( + {"N": 4, "expected_offsets": [0, 2, 0, 0, 0, 0], "expected_sizes": [2, 0, 0, 0, 0, 0]}, + {"N": 6, "expected_offsets": [0, 1, 0, 0, 0, 0], "expected_sizes": [2, 2, 0, 0, 0, 0]}, + {"N": 24, "expected_offsets": [0, 2, 4, 6, 8, 10], "expected_sizes": [2, 2, 2, 2, 2, 2]}, + {"N": 30, "expected_offsets": [0, 4, 8, 11, 0, 0], "expected_sizes": [4, 4, 4, 4, 0, 0]}, + {"N": 128, "expected_offsets": [0, 12, 24, 36, 48, 60], "expected_sizes": [12, 12, 12, 12, 12, 4]}, + ) + def test__tile_vertex_utils__make_ipu_vector1d_worker_offsets_and_sizes(self, N, expected_offsets, expected_sizes): + vector_size = 2 + grain_size = 4 + num_workers = 6 + woffsets_sizes = make_ipu_vector1d_worker_offsets_and_sizes( + N, vector_size, num_workers=num_workers, wdtype=np.int16, grain_size=grain_size, allow_overlap=True + ) + assert woffsets_sizes.shape == (num_workers, 2) + assert woffsets_sizes.dtype == np.int16 + npt.assert_array_equal(woffsets_sizes[:, 0], expected_offsets) + npt.assert_array_equal(woffsets_sizes[:, 1], expected_sizes) diff --git a/tests/linalg/test_tile_linalg_jacobi.py b/tests/linalg/test_tile_linalg_jacobi.py index e54dd6c..e3eb87a 100644 --- a/tests/linalg/test_tile_linalg_jacobi.py +++ b/tests/linalg/test_tile_linalg_jacobi.py @@ -128,7 +128,7 @@ def jacobi_update_first_step_fn(pq, pcol, qcol): # assert False def test__jacobi_update_eigenvectors_vertex__benchmark_performance(self): - N = 256 + N = 512 tiles = (0,) cs = np.array([0.2, 0.5], dtype=np.float32) pcol = np.random.randn(1, N).astype(np.float32) @@ -158,7 +158,7 @@ def jacobi_update_eigenvectors_fn(cs, pcol, qcol): # Cycle count reference for scale_add: 64(375), 128(467), 256(665), 512(1043) start, end = np.asarray(start)[0], np.asarray(end)[0] qr_correction_cycle_count = end[0] - start[0] - assert qr_correction_cycle_count <= 2200 + assert qr_correction_cycle_count <= 1550 # print("CYCLE count:", qr_correction_cycle_count) # assert False @@ -177,8 +177,11 @@ def test__jacobi_eigh__single_iteration(self): 
         npt.assert_array_almost_equal(np.linalg.eigh(A)[0], np.linalg.eigh(x)[0], decimal=5)
 
     @unittest.skipUnless(ipu_num_tiles >= 16, "Requires IPU with 16 tiles")
-    def test__jacobi_eigh_raw__proper_eigh_result(self):
-        N = 8
+    @parameterized.parameters(
+        {"N": 10},  # Tests the Jacobi 2nd update when the workload is not divisible by grain_size=4 (overlap path).
+        {"N": 12},
+    )
+    def test__jacobi_eigh_raw__proper_eigh_result(self, N):
         x = np.random.randn(N, N).astype(np.float32)
         x = (x + x.T) / 2.0
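A closing note tying the new `N=10` test case to the Jacobi worklist constants earlier in this diff: `worker_offsets_sizes` is built with `grain_size=4` (matching the 2x unrolling of `jacobi_update_second_step`, which consumes two rotation pairs per iteration) and then divided by `[[1, 2]]`, so offsets stay in rotation-pair (float2) units while sizes are converted to unrolled-iteration units. The zero rotation pair appended via `TileShardedArray.concatenate` presumably keeps the loop's pre-loads in bounds past the last real pair. The sketch below, computed with the helper itself for the `N=10` workload, is illustrative only.

```python
import numpy as np
from tessellate_ipu.core import make_ipu_vector1d_worker_offsets_and_sizes

# Jacobi second-step workload for an N=10 matrix: pcol has N+2 entries, minus the
# 2-entry index prefix, i.e. size 10 = 5 rotation pairs in the sorted rotset.
size = 10
woffsets_sizes = make_ipu_vector1d_worker_offsets_and_sizes(
    size, vector_size=2, grain_size=4, wdtype=np.uint16, allow_overlap=True
) // np.array([[1, 2]], dtype=np.uint16)
print(woffsets_sizes[:, 0])  # rotation-pair offsets:        [0, 2, 3, 0, 0, 0]
print(woffsets_sizes[:, 1])  # sizes in unrolled iterations: [1, 1, 1, 0, 0, 0]
# Worker 2 starts at pair 3 and re-processes pair 3 (already covered by worker 1),
# so that it still handles a full grain of 2 rotation pairs.
```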