diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
new file mode 100644
index 00000000000..2b43165bc0f
--- /dev/null
+++ b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
@@ -0,0 +1,124 @@
+# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
+# Author: Denis Kristak
+# Updated by: Alex Domingo (Vrije Universiteit Brussel)
+# Updated by: Thomas Hoffmann (EMBL Heidelberg)
+easyblock = 'PythonBundle'
+
+name = 'jax'
+version = '0.4.24'
+versionsuffix = '-CUDA-%(cudaver)s'
+
+homepage = 'https://pypi.python.org/pypi/jax'
+description = """Composable transformations of Python+NumPy programs:
+differentiate, vectorize, JIT to GPU/TPU, and more"""
+
+toolchain = {'name': 'foss', 'version': '2023a'}
+cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]
+
+builddependencies = [
+    ('Bazel', '6.3.1'),
+    ('pytest-xdist', '3.3.1'),
+    # git 2.x required to fetch repository 'io_bazel_rules_docker'
+    ('git', '2.41.0', '-nodocs'),
+    ('matplotlib', '3.7.2'),  # required for tests/lobpcg_test.py
+    ('poetry', '1.5.1'),
+]
+
+dependencies = [
+    ('CUDA', '12.1.1', '', SYSTEM),
+    ('cuDNN', '8.9.2.26', versionsuffix, SYSTEM),
+    ('NCCL', '2.18.3', versionsuffix),
+    ('Python', '3.11.3'),
+    ('SciPy-bundle', '2023.07'),
+    ('flatbuffers-python', '23.5.26'),
+    ('zlib', '1.2.13'),
+    ('ml_dtypes', '0.3.2'),
+]
+
+# download the xla and tf_runtime tarballs up front, so that Bazel does not download them during the build;
+# note: these *must* be the exact same commits as used in the corresponding third_party/{xla,tf_runtime}/workspace.bzl
+local_xla_commit = '12eee889e1f2ad41e27d7b0e970cb92d282d3ec5'
+local_tfrt_commit = '4665f7483063a16b6113a05eb45f98103cc1d611'
+local_repo_opt = '--bazel_options="--override_repository=xla=%%(builddir)s/xla-%s" ' % local_xla_commit
+local_repo_opt += '--bazel_options="--override_repository=tf_runtime=%%(builddir)s/tf_runtime-%s" ' % local_tfrt_commit
+
+# deliberately not testing in parallel, as that results in (additional) failing tests;
+# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing,
+# see https://github.com/google/jax/issues/7323 and
+# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst;
+# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs;
+# use NVIDIA_TF32_OVERRIDE=0 to disable TF32 Tensor Cores and avoid losing numerical precision
+local_test = "NVIDIA_TF32_OVERRIDE=0 CUDA_VISIBLE_DEVICES=0 "
+local_test += "XLA_PYTHON_CLIENT_ALLOCATOR=platform "
+local_test += "JAX_ENABLE_X64=true pytest -vv tests "
+
+use_pip = True
+
+default_easyblock = 'PythonPackage'
+default_component_specs = {
+    'sources': [SOURCE_TAR_GZ],
+    'source_urls': [PYPI_SOURCE],
+    'start_dir': '%(name)s-%(version)s',
+    'use_pip': True,
+    'sanity_pip_check': True,
+    'download_dep_fail': True,
+}
+
+components = [
+    ('absl-py', '2.1.0', {
+        'options': {'modulename': 'absl'},
+        'checksums': ['7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff'],
+    }),
+    ('jaxlib', version, {
+        'sources': [
+            '%(name)s-v%(version)s.tar.gz',
+            {
+                'download_filename': '%s.tar.gz' % local_xla_commit,
+                'filename': 'xla-%s.tar.gz' % local_xla_commit,
+            },
+            {
+                'download_filename': '%s.tar.gz' % local_tfrt_commit,
+                'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit,
+            },
+        ],
+        'source_urls': [
+            'https://github.com/google/jax/archive/',
+            'https://github.com/tensorflow/runtime/archive/',
+            'https://github.com/openxla/xla/archive/',
+        ],
+        'patches': [
+            ('jax-0.4.24_xla-%s_indexing_analysis_small_vector.patch' % local_xla_commit[:7],
+             '../xla-%s' % local_xla_commit),
+            # cuda-noncanonical-include-paths still required?:
+            # ('jax-0.4.24_xla-%s_cuda-noncanonical-include-paths.patch' % local_xla_commit[:7],
+            #  '../xla-%s' % local_xla_commit),
+        ],
+        'checksums': [
+            {'jaxlib-v0.4.24.tar.gz':
+             'c4e6963c2c36f634a9a1765e476a1ed4e6c4a7954465ebf72e29f344c28ddc28'},
+            {'xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5.tar.gz':
+             'db007b6628cfe108c63f45d611c6de910abe3ee827e55f08314ce143c4887d66'},
+            {'tf_runtime-4665f7483063a16b6113a05eb45f98103cc1d611.tar.gz':
+             '3aa0ab30fe94dab33f20824b9c2d8e7c3b6017106c833b12070f71d2e0f1d6d6'},
+            {'jax-0.4.24_xla-12eee88_indexing_analysis_small_vector.patch':
+             '7187cdd08cce12d0af889494317cb8c32865487d1d6d9254064cb62fd3453b6d'},
+        ],
+        'start_dir': 'jax-jaxlib-v%(version)s',
+        'buildopts': local_repo_opt,
+    }),
+]
+
+exts_list = [
+    (name, version, {
+        'runtest': local_test,
+        'source_tmpl': '%(name)s-v%(version)s.tar.gz',
+        'source_urls': ['https://github.com/google/jax/archive/'],
+        'checksums': [
+            {'jax-v0.4.24.tar.gz': '6e52d8b547624bd70d423e6bf85f4fcd47336b529f1a4f6a94fac3096017a694'},
+        ],
+    }),
+]
+
+sanity_pip_check = True
+
+moduleclass = 'tools'
diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.24_xla-12eee88_indexing_analysis_small_vector.patch b/easybuild/easyconfigs/j/jax/jax-0.4.24_xla-12eee88_indexing_analysis_small_vector.patch
new file mode 100644
index 00000000000..7e8cbac30c3
--- /dev/null
+++ b/easybuild/easyconfigs/j/jax/jax-0.4.24_xla-12eee88_indexing_analysis_small_vector.patch
@@ -0,0 +1,12 @@
+diff -ru xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5_old/xla/service/gpu/model/indexing_analysis.cc xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5/xla/service/gpu/model/indexing_analysis.cc
+--- xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5_old/xla/service/gpu/model/indexing_analysis.cc	2024-02-05 19:41:29.000000000 +0100
++++ xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5/xla/service/gpu/model/indexing_analysis.cc	2024-02-12 12:09:35.301680070 +0100
+@@ -687,7 +687,7 @@
+ llvm::SmallVector DelinearizeInBoundsIndex(
+     mlir::AffineExpr linear, absl::Span sizes,
+     absl::Span strides) {
+-  llvm::SmallVector result;
++  llvm::SmallVector result;  // THEMBL; see commit c10075688d773c43c22e658c814a94ade3cbb372
+   result.reserve(sizes.size());
+   for (auto [size, stride] : llvm::zip(sizes, strides)) {
+     result.push_back(linear.floorDiv(stride) % size);
diff --git a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb
new file mode 100644
index 00000000000..df6bd3d5134
--- /dev/null
+++ b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb
@@ -0,0 +1,51 @@
+# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2024/02
+easyblock = 'PythonBundle'
+
+name = 'ml_dtypes'
+version = '0.3.2'
+
+homepage = 'https://github.com/jax-ml/ml_dtypes'
+description = """
+ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used
+in machine learning libraries, including:
+
+bfloat16: an alternative to the standard float16 format
+float8_*: several experimental 8-bit floating point representations, including:
+  float8_e4m3b11fnuz
+  float8_e4m3fn
+  float8_e4m3fnuz
+  float8_e5m2
+  float8_e5m2fnuz
+"""
+
+toolchain = {'name': 'foss', 'version': '2023a'}
+
+dependencies = [
+    ('Python', '3.11.3'),
+    ('SciPy-bundle', '2023.07'),
+]
+
+use_pip = True
+
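+# Illustrative sanity check (a sketch, not part of the build): once this module
+# is loaded, the extension dtypes should register with NumPy, e.g.:
+#   python -c "import numpy, ml_dtypes; print(numpy.dtype(ml_dtypes.bfloat16))"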
+
+default_easyblock = 'PythonPackage'
+
+exts_list = [
+    ('opt_einsum', '3.3.0', {
+        'checksums': ['59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549'],
+    }),
+    ('etils', '1.6.0', {
+        'checksums': ['c635fbd02a79fed4ad76825d31306b581d22b40671721daa8bc279cf6333e48a'],
+    }),
+    (name, version, {
+        'patches': [('ml_dtypes-0.3.2_EigenAvx512.patch', 1)],
+        'checksums': [
+            {'ml_dtypes-0.3.2.tar.gz': '533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967'},
+            {'ml_dtypes-0.3.2_EigenAvx512.patch': '197b05b0b7f611749824369f026099f6a172f9e8eab6ebb6504a16573746c892'},
+        ],
+    }),
+]
+
+sanity_pip_check = True
+
+moduleclass = 'tools'
diff --git a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch
new file mode 100644
index 00000000000..42ea0606391
--- /dev/null
+++ b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch
@@ -0,0 +1,1219 @@
+# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2024/01
+# ml_dtypes 0.3.2 ships a copy of Eigen commit 7bf2968 (https://gitlab.com/libeigen/eigen/-/commit/7bf2968).
+# This copy is missing the file src/Core/arch/AVX512/TrsmUnrolls.inc, which is added by the present patch.
+diff -ru --new-file old/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc new/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc
+--- old/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc	1970-01-01 01:00:00.000000000 +0100
++++ new/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc	2024-02-14 10:32:25.492978066 +0100
+@@ -0,0 +1,1212 @@
++// This file is part of Eigen, a lightweight C++ template library
++// for linear algebra.
++//
++// Copyright (C) 2022 Intel Corporation
++//
++// This Source Code Form is subject to the terms of the Mozilla
++// Public License v. 2.0. If a copy of the MPL was not distributed
++// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
++
++#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
++#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
++
++template <bool isARowMajor>
++EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
++  EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
++  else return i + j * LDA;
++}
++
++/**
++ * This namespace contains various classes used to generate compile-time unrolls which are
++ * used throughout the trsm/gemm kernels. The unrolls are characterized as for-loops (1-D), nested
++ * for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion
++ *
++ * Example, the 2-D for-loop is unrolled recursively by first flattening to a 1-D loop.
++ *
++ *   for(startI = 0; startI < endI; startI++)      for(startC = 0; startC < endI*endJ; startC++)
++ *     for(startJ = 0; startJ < endJ; startJ++)  ---->  startI = (startC)/(endJ)
++ *       func(startI,startJ)                            startJ = (startC)%(endJ)
++ *                                                      func(...)
++ *
++ * The 1-D loop can be unrolled recursively by using enable_if and defining an auxillary function
++ * with a template parameter used as a counter.
++ *
++ *   template <int64_t counter, ...>
++ *   std::enable_if_t<(counter <= 0)>  <---- tail case.
++ *   aux_func {}
++ *
++ *   template <int64_t counter, ...>
++ *   std::enable_if_t<(counter > 0)>  <---- actual for-loop
++ *   aux_func {
++ *     startC = endI*endJ - counter
++ *     startI = (startC)/(endJ)
++ *     startJ = (startC)%(endJ)
++ *     func(startI, startJ)
++ *     aux_func<counter - 1>()
++ *   }
++ *
++ * Note: Additional wrapper functions are provided for aux_func which hides the counter template
++ * parameter since counter usually depends on endI, endJ, etc...
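++ *
++ * As a purely illustrative example (values chosen here for exposition only),
++ * with endI = 2 and endJ = 3 the flattened counter runs startC = 0..5 and the
++ * recursion recovers the 2-D indices as:
++ *
++ *   startC : 0 1 2 3 4 5
++ *   startI : 0 0 0 1 1 1   (startC / endJ)
++ *   startJ : 0 1 2 0 1 2   (startC % endJ)
++ *
++ * so func(startI, startJ) is invoked in exactly the order of the original
++ * nested loops.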
++ * ++ * Conventions: ++ * 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++)) ++ * ++ * 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for ++ * handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW) ++ */ ++namespace unrolls { ++ ++template ++EIGEN_ALWAYS_INLINE auto remMask(int64_t m) { ++ EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); } ++ else EIGEN_IF_CONSTEXPR(N == 8) { ++ return 0xFF >> (8 - m); ++ } ++ else EIGEN_IF_CONSTEXPR(N == 4) { ++ return 0x0F >> (4 - m); ++ } ++ return 0; ++} ++ ++template ++EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel); ++ ++template <> ++EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel) { ++ __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]); ++ __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]); ++ __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]); ++ __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]); ++ __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]); ++ __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]); ++ __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]); ++ __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]); ++ ++ kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); ++ kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); ++ kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); ++ kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); ++ kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6))); ++ kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6))); ++ kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); ++ kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); ++ ++ T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E)); ++ T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0); ++ T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E)); ++ T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]); ++ T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E)); ++ T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1); ++ T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E)); ++ T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]); ++ T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E)); ++ T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2); ++ T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E)); ++ T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]); ++ T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E)); ++ T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3); ++ T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E)); ++ T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]); ++ ++ kernel.packet[0] = T0; ++ kernel.packet[1] = T1; ++ kernel.packet[2] = T2; ++ kernel.packet[3] = T3; ++ kernel.packet[4] = T4; ++ kernel.packet[5] = T5; 
++ kernel.packet[6] = T6; ++ kernel.packet[7] = T7; ++} ++ ++template <> ++EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel) { ++ ptranspose(kernel); ++} ++ ++/*** ++ * Unrolls for tranposed C stores ++ */ ++template ++class trans { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - storeC ++ *********************************** ++ */ ++ ++ /** ++ * aux_storeC ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN++) ++ * ++ * (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC ++ * ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) { ++ EIGEN_IF_CONSTEXPR(remM) { ++ pstoreu( ++ C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), ++ preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]), ++ remMask(remM_)), ++ remMask(remM_)); ++ } ++ else { ++ pstoreu(C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN), ++ preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]))); ++ } ++ } ++ else { // This block is only needed for fp32 case ++ // Reinterpret as __m512 for _mm512_shuffle_f32x4 ++ vecFullFloat zmm2vecFullFloat = preinterpret( ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]); ++ // Swap lower and upper half of avx register. ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] = ++ preinterpret(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110)); ++ ++ EIGEN_IF_CONSTEXPR(remM) { ++ pstoreu( ++ C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), ++ preinterpret( ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])), ++ remMask(remM_)); ++ } ++ else { ++ pstoreu( ++ C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN), ++ preinterpret( ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]))); ++ } ++ } ++ aux_storeC(C_arr, LDC, zmm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && endN <= PacketSize)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(C_arr); ++ EIGEN_UNUSED_VARIABLE(LDC); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC, ++ PacketBlock &zmm, ++ int64_t remM_ = 0) { ++ aux_storeC(C_arr, LDC, zmm, remM_); ++ } ++ ++ /** ++ * Transposes LxunrollN row major block of matrices stored EIGEN_AVX_MAX_NUM_ACC zmm registers to ++ * "unrollN"xL ymm registers to be stored col-major into C. ++ * ++ * For 8x48, the 8x48 block (row-major) is stored in zmm as follows: ++ * ++ * row0: zmm0 zmm1 zmm2 ++ * row1: zmm3 zmm4 zmm5 ++ * . ++ * . 
++ * row7: zmm21 zmm22 zmm23 ++ * ++ * For 8x32, the 8x32 block (row-major) is stored in zmm as follows: ++ * ++ * row0: zmm0 zmm1 ++ * row1: zmm2 zmm3 ++ * . ++ * . ++ * row7: zmm14 zmm15 ++ * ++ * ++ * In general we will have {1,2,3} groups of avx registers each of size ++ * EIGEN_AVX_MAX_NUM_ROW. packetIndexOffset is used to select which "block" of ++ * avx registers are being transposed. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void transpose(PacketBlock &zmm) { ++ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted ++ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. ++ constexpr int64_t zmmStride = unrollN / PacketSize; ++ PacketBlock r; ++ r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0]; ++ r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1]; ++ r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2]; ++ r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3]; ++ r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4]; ++ r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5]; ++ r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6]; ++ r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7]; ++ trans8x8blocks(r); ++ zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0]; ++ zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1]; ++ zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2]; ++ zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3]; ++ zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4]; ++ zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5]; ++ zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6]; ++ zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7]; ++ } ++}; ++ ++/** ++ * Unrolls for copyBToRowMajor ++ * ++ * Idea: ++ * 1) Load a block of right-hand sides to registers (using loadB). ++ * 2) Convert the block from column-major to row-major (transposeLxL) ++ * 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false). ++ * ++ * We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are ++ * used as temps for transposing. ++ * ++ * Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks ++ * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm). 
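++ *
++ * A rough sizing sketch (assuming the defaults used in this file, L =
++ * EIGEN_AVX_MAX_NUM_ROW = 8 and fp32 PacketSize = 16): the U1/U2/U3 blocks
++ * are 8x16, 8x32 and 8x48; each 8-element column of a block occupies one ymm
++ * half, so an 8x16 block fills 16 ymm registers before being transposed.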
++ */ ++template ++class transB { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - loadB ++ * - storeB ++ * - loadBBlock ++ * - storeBBlock ++ *********************************** ++ */ ++ ++ /** ++ * aux_loadB ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(remM) { ++ ymm.packet[packetIndexOffset + startN] = ++ ploadu((const Scalar *)&B_arr[startN * LDB], remMask(remM_)); ++ } ++ else ymm.packet[packetIndexOffset + startN] = ploadu((const Scalar *)&B_arr[startN * LDB]); ++ ++ aux_loadB(B_arr, LDB, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ /** ++ * aux_storeB ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(remK || remM) { ++ pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN], ++ remMask(rem_)); ++ } ++ else { ++ pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]); ++ } ++ ++ aux_storeB(B_arr, LDB, ymm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_loadBBlock ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ transB::template loadB(&B_temp[startN], LDB_, ymm); ++ aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(B_temp); ++ EIGEN_UNUSED_VARIABLE(LDB_); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ /** ++ * aux_storeBBlock ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock 
&ymm, int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(toTemp) { ++ transB::template storeB(&B_temp[startN], LDB_, ymm, remK_); ++ } ++ else { ++ transB::template storeB(&B_arr[0 + startN * LDB], LDB, ++ ymm, remM_); ++ } ++ aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(B_temp); ++ EIGEN_UNUSED_VARIABLE(LDB_); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ /******************************************************** ++ * Wrappers for aux_XXXX to hide counter parameter ++ ********************************************************/ ++ ++ template ++ static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ aux_loadB(B_arr, LDB, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB, ++ PacketBlock &ymm, ++ int64_t rem_ = 0) { ++ aux_storeB(B_arr, LDB, ymm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB(&B_arr[0], LDB, ymm, remM_); } ++ else { ++ aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock &ymm) { ++ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted ++ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. ++ PacketBlock r; ++ r.packet[0] = ymm.packet[packetIndexOffset + 0]; ++ r.packet[1] = ymm.packet[packetIndexOffset + 1]; ++ r.packet[2] = ymm.packet[packetIndexOffset + 2]; ++ r.packet[3] = ymm.packet[packetIndexOffset + 3]; ++ r.packet[4] = ymm.packet[packetIndexOffset + 4]; ++ r.packet[5] = ymm.packet[packetIndexOffset + 5]; ++ r.packet[6] = ymm.packet[packetIndexOffset + 6]; ++ r.packet[7] = ymm.packet[packetIndexOffset + 7]; ++ ptranspose(r); ++ ymm.packet[packetIndexOffset + 0] = r.packet[0]; ++ ymm.packet[packetIndexOffset + 1] = r.packet[1]; ++ ymm.packet[packetIndexOffset + 2] = r.packet[2]; ++ ymm.packet[packetIndexOffset + 3] = r.packet[3]; ++ ymm.packet[packetIndexOffset + 4] = r.packet[4]; ++ ymm.packet[packetIndexOffset + 5] = r.packet[5]; ++ ymm.packet[packetIndexOffset + 6] = r.packet[6]; ++ ymm.packet[packetIndexOffset + 7] = r.packet[7]; ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ constexpr int64_t U3 = PacketSize * 3; ++ constexpr int64_t U2 = PacketSize * 2; ++ constexpr int64_t U1 = PacketSize * 1; ++ /** ++ * Unrolls needed for each case: ++ * - AVX512 fp32 48 32 16 8 4 2 1 ++ * - AVX512 fp64 24 16 8 4 2 1 ++ * ++ * For fp32 L and U1 are 1:2 so for U3/U2 cases the loads/stores need to be split up. 
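++ * For instance (an illustrative count, assuming L = 8 and fp32 U3 = 48):
++ * maxUBlock = min(3 * 8, 48) = 24, so the 8x48 case below is handled as two
++ * 8x24 halves, the second starting at B_arr[24 * LDB] and B_temp[24].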
++ */ ++ EIGEN_IF_CONSTEXPR(unrollN == U3) { ++ // load LxU3 B col major, transpose LxU3 row major ++ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3); ++ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ ++ EIGEN_IF_CONSTEXPR(maxUBlock < U3) { ++ transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, ++ ymm, remM_); ++ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, ++ ymm, remM_); ++ } ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == U2) { ++ // load LxU2 B col major, transpose LxU2 row major ++ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2); ++ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ ++ EIGEN_IF_CONSTEXPR(maxUBlock < U2) { ++ transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, ++ &B_temp[maxUBlock], LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, ++ &B_temp[maxUBlock], LDB_, ymm, remM_); ++ } ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == U1) { ++ // load LxU1 B col major, transpose LxU1 row major ++ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); } ++ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) { ++ // load Lx4 B col major, transpose Lx4 row major ++ transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) { ++ // load Lx4 B col major, transpose Lx4 row major ++ transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == 2) { ++ // load Lx2 B col major, transpose Lx2 row major ++ transB::template loadBBlock<2, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == 1) { ++ // load Lx1 B col major, transpose Lx1 row major ++ transB::template loadBBlock<1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ } ++}; ++ ++/** ++ * Unrolls 
for triSolveKernel ++ * ++ * Idea: ++ * 1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS). ++ * 2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix) ++ * stored in AInPacket (using triSolveMicroKernel). ++ * 3) Store final results (in avx registers) back into memory (using storeRHS). ++ * ++ * RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most ++ * EIGEN_AVX_MAX_NUM_ROW registers. ++ */ ++template ++class trsm { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - loadRHS ++ * - storeRHS ++ * - divRHSByDiag ++ * - updateRHS ++ * - triSolveMicroKernel ++ ************************************/ ++ /** ++ * aux_loadRHS ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ constexpr int64_t counterReverse = endM * endK - counter; ++ constexpr int64_t startM = counterReverse / (endK); ++ constexpr int64_t startK = counterReverse % endK; ++ ++ constexpr int64_t packetIndex = startM * endK + startK; ++ constexpr int64_t startM_ = isFWDSolve ? startM : -startM; ++ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; ++ EIGEN_IF_CONSTEXPR(krem) { ++ RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex], remMask(rem)); ++ } ++ else { ++ RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex]); ++ } ++ aux_loadRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(rem); ++ } ++ ++ /** ++ * aux_storeRHS ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ constexpr int64_t counterReverse = endM * endK - counter; ++ constexpr int64_t startM = counterReverse / (endK); ++ constexpr int64_t startK = counterReverse % endK; ++ ++ constexpr int64_t packetIndex = startM * endK + startK; ++ constexpr int64_t startM_ = isFWDSolve ? 
startM : -startM; ++ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; ++ EIGEN_IF_CONSTEXPR(krem) { ++ pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask(rem)); ++ } ++ else { ++ pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]); ++ } ++ aux_storeRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(rem); ++ } ++ ++ /** ++ * aux_divRHSByDiag ++ * ++ * currM may be -1, (currM >=0) in enable_if checks for this ++ * ++ * 1-D unroll ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag( ++ PacketBlock &RHSInPacket, PacketBlock &AInPacket) { ++ constexpr int64_t counterReverse = endK - counter; ++ constexpr int64_t startK = counterReverse; ++ ++ constexpr int64_t packetIndex = currM * endK + startK; ++ RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]); ++ aux_divRHSByDiag(RHSInPacket, AInPacket); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && currM >= 0)> aux_divRHSByDiag( ++ PacketBlock &RHSInPacket, PacketBlock &AInPacket) { ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(AInPacket); ++ } ++ ++ /** ++ * aux_updateRHS ++ * ++ * 2-D unroll ++ * for(startM = initM; startM < endM; startM++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS( ++ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ constexpr int64_t counterReverse = (endM - initM) * endK - counter; ++ constexpr int64_t startM = initM + counterReverse / (endK); ++ constexpr int64_t startK = counterReverse % endK; ++ ++ // For each row of A, first update all corresponding RHS ++ constexpr int64_t packetIndex = startM * endK + startK; ++ EIGEN_IF_CONSTEXPR(currentM > 0) { ++ RHSInPacket.packet[packetIndex] = ++ pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK], ++ RHSInPacket.packet[packetIndex]); ++ } ++ ++ EIGEN_IF_CONSTEXPR(startK == endK - 1) { ++ // Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}. ++ EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) { ++ // If diagonal is not unit, we broadcast reciprocals of diagonals AinPacket.packet[currentM]. 
++ // This will be used in divRHSByDiag ++ EIGEN_IF_CONSTEXPR(isFWDSolve) ++ AInPacket.packet[currentM] = pset1(Scalar(1) / A_arr[idA(currentM, currentM, LDA)]); ++ else AInPacket.packet[currentM] = pset1(Scalar(1) / A_arr[idA(-currentM, -currentM, LDA)]); ++ } ++ else { ++ // Broadcast next off diagonal element of A ++ EIGEN_IF_CONSTEXPR(isFWDSolve) ++ AInPacket.packet[startM] = pset1(A_arr[idA(startM, currentM, LDA)]); ++ else AInPacket.packet[startM] = pset1(A_arr[idA(-startM, -currentM, LDA)]); ++ } ++ } ++ ++ aux_updateRHS( ++ A_arr, LDA, RHSInPacket, AInPacket); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS( ++ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ EIGEN_UNUSED_VARIABLE(A_arr); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(AInPacket); ++ } ++ ++ /** ++ * aux_triSolverMicroKernel ++ * ++ * 1-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel( ++ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ constexpr int64_t counterReverse = endM - counter; ++ constexpr int64_t startM = counterReverse; ++ ++ constexpr int64_t currentM = startM; ++ // Divides the right-hand side in row startM, by digonal value of A ++ // broadcasted to AInPacket.packet[startM-1] in the previous iteration. ++ // ++ // Without "if constexpr" the compiler instantiates the case <-1, numK> ++ // this is handled with enable_if to prevent out-of-bound warnings ++ // from the compiler ++ EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0) ++ trsm::template divRHSByDiag(RHSInPacket, AInPacket); ++ ++ // After division, the rhs corresponding to subsequent rows of A can be partially updated ++ // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed) ++ // to be used in the next iteration. ++ trsm::template updateRHS(A_arr, LDA, RHSInPacket, ++ AInPacket); ++ ++ // Handle division for the RHS corresponding to the final row of A. 
++ EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1) ++ trsm::template divRHSByDiag(RHSInPacket, AInPacket); ++ ++ aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, ++ AInPacket); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel( ++ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ EIGEN_UNUSED_VARIABLE(A_arr); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(AInPacket); ++ } ++ ++ /******************************************************** ++ * Wrappers for aux_XXXX to hide counter parameter ++ ********************************************************/ ++ ++ /** ++ * Load endMxendK block of B to RHSInPacket ++ * Masked loads are used for cases where endK is not a multiple of PacketSize ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB, ++ PacketBlock &RHSInPacket, int64_t rem = 0) { ++ aux_loadRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ /** ++ * Load endMxendK block of B to RHSInPacket ++ * Masked loads are used for cases where endK is not a multiple of PacketSize ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB, ++ PacketBlock &RHSInPacket, int64_t rem = 0) { ++ aux_storeRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ /** ++ * Only used if Triangular matrix has non-unit diagonal values ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ aux_divRHSByDiag(RHSInPacket, AInPacket); ++ } ++ ++ /** ++ * Update right-hand sides (stored in avx registers) ++ * Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket. ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA, ++ PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ aux_updateRHS( ++ A_arr, LDA, RHSInPacket, AInPacket); ++ } ++ ++ /** ++ * endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW ++ * numK: number of avx registers to use for each row of B (ex fp32: 48 rhs => 3 avx reg used). 1 <= endK <= 3. ++ * isFWDSolve: true => forward substitution, false => backwards substitution ++ * isUnitDiag: true => triangular matrix has unit diagonal. 
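++ *
++ * A hedged sizing example (assuming fp32, where a zmm register holds 16
++ * values): endM = 8 and numK = 3 solves an 8x8 triangular block against 48
++ * right-hand sides, held in 8 * 3 = 24 zmm registers (RHSInPacket) plus up
++ * to 8 broadcast registers (AInPacket).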
++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA, ++ PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ static_assert(numK >= 1 && numK <= 3, "numK out of range"); ++ aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, AInPacket); ++ } ++}; ++ ++/** ++ * Unrolls for gemm kernel ++ * ++ * isAdd: true => C += A*B, false => C -= A*B ++ */ ++template ++class gemm { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - setzero ++ * - updateC ++ * - storeC ++ * - startLoadB ++ * - triSolveMicroKernel ++ ************************************/ ++ ++ /** ++ * aux_setzero ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero( ++ PacketBlock &zmm) { ++ constexpr int64_t counterReverse = endM * endN - counter; ++ constexpr int64_t startM = counterReverse / (endN); ++ constexpr int64_t startN = counterReverse % endN; ++ ++ zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]); ++ aux_setzero(zmm); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero( ++ PacketBlock &zmm) { ++ EIGEN_UNUSED_VARIABLE(zmm); ++ } ++ ++ /** ++ * aux_updateC ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endM * endN - counter; ++ constexpr int64_t startM = counterReverse / (endN); ++ constexpr int64_t startN = counterReverse % endN; ++ ++ EIGEN_IF_CONSTEXPR(rem) ++ zmm.packet[startN * endM + startM] = ++ padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize], remMask(rem_)), ++ zmm.packet[startN * endM + startM], remMask(rem_)); ++ else zmm.packet[startN * endM + startM] = ++ padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]); ++ aux_updateC(C_arr, LDC, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(C_arr); ++ EIGEN_UNUSED_VARIABLE(LDC); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_storeC ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endM * endN - counter; ++ constexpr int64_t startM = counterReverse / (endN); ++ constexpr int64_t startN = counterReverse % endN; ++ ++ EIGEN_IF_CONSTEXPR(rem) ++ pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM], ++ remMask(rem_)); ++ else pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]); ++ aux_storeC(C_arr, LDC, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, 
PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(C_arr); ++ EIGEN_UNUSED_VARIABLE(LDC); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_startLoadB ++ * ++ * 1-D unroll ++ * for(startL = 0; startL < endL; startL++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endL - counter; ++ constexpr int64_t startL = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(rem) ++ zmm.packet[unrollM * unrollN + startL] = ++ ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask(rem_)); ++ else zmm.packet[unrollM * unrollN + startL] = ++ ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]); ++ ++ aux_startLoadB(B_t, LDB, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_t); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_startBCastA ++ * ++ * 1-D unroll ++ * for(startB = 0; startB < endB; startB++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA( ++ Scalar *A_t, int64_t LDA, PacketBlock &zmm) { ++ constexpr int64_t counterReverse = endB - counter; ++ constexpr int64_t startB = counterReverse; ++ ++ zmm.packet[unrollM * unrollN + numLoad + startB] = pload1(&A_t[idA(startB, 0, LDA)]); ++ ++ aux_startBCastA(A_t, LDA, zmm); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA( ++ Scalar *A_t, int64_t LDA, PacketBlock &zmm) { ++ EIGEN_UNUSED_VARIABLE(A_t); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ } ++ ++ /** ++ * aux_loadB ++ * currK: current K ++ * ++ * 1-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ if ((numLoad / endM + currK < unrollK)) { ++ constexpr int64_t counterReverse = endM - counter; ++ constexpr int64_t startM = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(rem) { ++ zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = ++ ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask(rem_)); ++ } ++ else { ++ zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = ++ ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]); ++ } ++ ++ aux_loadB(B_t, LDB, zmm, rem_); ++ } ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_t); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_microKernel ++ * ++ * 3-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel( ++ Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endM * endN * endK - counter; ++ constexpr int startK = 
counterReverse / (endM * endN); ++ constexpr int startN = (counterReverse / (endM)) % endN; ++ constexpr int startM = counterReverse % endM; ++ ++ EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) { ++ gemm::template startLoadB(B_t, LDB, zmm, rem_); ++ gemm::template startBCastA(A_t, LDA, zmm); ++ } ++ ++ { ++ // Interleave FMA and Bcast ++ EIGEN_IF_CONSTEXPR(isAdd) { ++ zmm.packet[startN * endM + startM] = ++ pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast], ++ zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]); ++ } ++ else { ++ zmm.packet[startN * endM + startM] = ++ pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast], ++ zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]); ++ } ++ // Bcast ++ EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) { ++ zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1(&A_t[idA( ++ (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]); ++ } ++ } ++ ++ // We have updated all accumlators, time to load next set of B's ++ EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) { ++ gemm::template loadB(B_t, LDB, zmm, rem_); ++ } ++ aux_microKernel(B_t, A_t, LDB, LDA, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel( ++ Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_t); ++ EIGEN_UNUSED_VARIABLE(A_t); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /******************************************************** ++ * Wrappers for aux_XXXX to hide counter parameter ++ ********************************************************/ ++ ++ template ++ static EIGEN_ALWAYS_INLINE void setzero(PacketBlock &zmm) { ++ aux_setzero(zmm); ++ } ++ ++ /** ++ * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_updateC(C_arr, LDC, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_storeC(C_arr, LDC, zmm, rem_); ++ } ++ ++ /** ++ * Use numLoad registers for loading B at start of microKernel ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_startLoadB(B_t, LDB, zmm, rem_); ++ } ++ ++ /** ++ * Use numBCast registers for broadcasting A at start of microKernel ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA, ++ PacketBlock &zmm) { ++ aux_startBCastA(A_t, LDA, zmm); ++ } ++ ++ /** ++ * Loads next set of B into vector registers between each K unroll. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_loadB(B_t, LDB, zmm, rem_); ++ } ++ ++ /** ++ * Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3} to compute C -= A*B. ++ * A matrix can be row/col-major. 
B matrix is assumed row-major. ++ * ++ * isARowMajor: is A row major ++ * endM: Number registers per row ++ * endN: Number of rows ++ * endK: Loop unroll for K. ++ * numLoad: Number of registers for loading B. ++ * numBCast: Number of registers for broadcasting A. ++ * ++ * Ex: microkernel: 8x48 unroll (24 accumulators), k unrolled 4 times, ++ * 6 register for loading B, 2 for broadcasting A. ++ * ++ * Note: Ideally the microkernel should not have any register spilling. ++ * The avx instruction counts should be: ++ * - endK*endN vbroadcasts{s,d} ++ * - endK*endM vmovup{s,d} ++ * - endK*endN*endM FMAs ++ * ++ * From testing, there are no register spills with clang. There are register spills with GNU, which ++ * causes a performance hit. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_microKernel(B_t, A_t, LDB, LDA, zmm, ++ rem_); ++ } ++}; ++} // namespace unrolls ++ ++#endif // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H