From eb56b2ff25f4f4372f73ce69bd29cdea65b02fa7 Mon Sep 17 00:00:00 2001
From: thoffman
Date: Tue, 13 Feb 2024 11:26:27 +0100
Subject: [PATCH 01/12] {tools}[foss/2023a] jax v0.4.24 w/ CUDA 12.1.1

---
 .../jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb  | 127 ++++++++++++++++++
 1 file changed, 127 insertions(+)
 create mode 100644 easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb

diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
new file mode 100644
index 00000000000..5f3d294504b
--- /dev/null
+++ b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
@@ -0,0 +1,127 @@
+# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
+# Author: Denis Kristak
+# Updated by: Alex Domingo (Vrije Universiteit Brussel)
+# Updated by: Thomas Hoffmann (EMBL Heidelberg)
+easyblock = 'PythonBundle'
+
+name = 'jax'
+version = '0.4.24'
+versionsuffix = '-CUDA-%(cudaver)s'
+
+homepage = 'https://pypi.python.org/pypi/jax'
+description = """Composable transformations of Python+NumPy programs:
+differentiate, vectorize, JIT to GPU/TPU, and more"""
+
+toolchain = {'name': 'foss', 'version': '2023a'}
+cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]
+
+builddependencies = [
+    ('Bazel', '6.3.1'),
+    ('pytest-xdist', '3.3.1'),
+    # git 2.x required to fetch repository 'io_bazel_rules_docker'
+    ('git', '2.41.0', '-nodocs'),
+    ('matplotlib', '3.7.2'),  # required for tests/lobpcg_test.py
+    ('poetry', '1.5.1'),
+]
+
+dependencies = [
+    ('CUDA', '12.1.1', '', SYSTEM),
+    ('cuDNN', '8.9.2.26', versionsuffix, SYSTEM),
+    ('NCCL', '2.18.3', versionsuffix),
+    ('Python', '3.11.3'),
+    ('SciPy-bundle', '2023.07'),
+    ('flatbuffers-python', '23.5.26'),
+    ('zlib', '1.2.13'),
+]
+
+# downloading xla and other tarballs to avoid that Bazel downloads it during the build
+# note: this *must* be the exact same commit as used in third_party/{xla,"other"}/workspace.bzl
+local_xla_commit = '12eee889e1f2ad41e27d7b0e970cb92d282d3ec5'
+local_tfrt_commit = '4665f7483063a16b6113a05eb45f98103cc1d611'
+local_repo_opt = '--bazel_options="--override_repository=xla=%%(builddir)s/xla-%s" ' % local_xla_commit
+local_repo_opt += '--bazel_options="--override_repository=runtime=%%(builddir)s/tf_runtime-%s" ' % local_tfrt_commit
+
+local_test = "NVIDIA_TF32_OVERRIDE=0 CUDA_VISIBLE_DEVICES=0 "
+local_test += "XLA_PYTHON_CLIENT_ALLOCATOR=platform "
+local_test += "JAX_ENABLE_X64=true pytest -vv tests "
+local_test = True
+
+use_pip = True
+
+default_easyblock = 'PythonPackage'
+default_component_specs = {
+    'sources': [SOURCE_TAR_GZ],
+    'source_urls': [PYPI_SOURCE],
+    'start_dir': '%(name)s-%(version)s',
+    'use_pip': True,
+    'sanity_pip_check': True,
+    'download_dep_fail': True,
+}
+
+components = [
+    ('absl-py', '2.1.0', {
+        'options': {'modulename': 'absl'},
+        'checksums': ['7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff'],
+    }),
+    ('jaxlib', version, {
+        'sources': [
+            '%(name)s-v%(version)s.tar.gz',
+            {
+                'download_filename': '%s.tar.gz' % local_xla_commit,
+                'filename': 'xla-%s.tar.gz' % local_xla_commit,
+            },
+            {
+                'download_filename': '%s.tar.gz' % local_tfrt_commit,
+                'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit,
+            },
+        ],
+        'source_urls': [
+            'https://github.com/google/jax/archive/',
+            'https://github.com/tensorflow/runtime/archive',
+            'https://github.com/openxla/xla/archive'
+        ],
+        'patches': [
+            ('jax-0.4.24_xla-%s_indexing_analysis_small_vector.patch' % local_xla_commit[:7],
+             '../xla-%s' % local_xla_commit),
+            # cuda-noncanonical-include-paths still required?:
+            # ('jax-0.4.24_xla-%s_cuda-noncanonical-include-paths.patch' % local_xla_commit[:7],
+            #  '../xla-%s' % local_xla_commit),
+        ],
+        'checksums': [
+            {'jaxlib-v0.4.24.tar.gz':
+             'c4e6963c2c36f634a9a1765e476a1ed4e6c4a7954465ebf72e29f344c28ddc28'},
+            {'xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5.tar.gz':
+             'db007b6628cfe108c63f45d611c6de910abe3ee827e55f08314ce143c4887d66'},
+            {'tf_runtime-4665f7483063a16b6113a05eb45f98103cc1d611.tar.gz':
+             '3aa0ab30fe94dab33f20824b9c2d8e7c3b6017106c833b12070f71d2e0f1d6d6'},
+        ],
+        'start_dir': 'jax-jaxlib-v%(version)s',
+        'buildopts': local_repo_opt
+    }),
+]
+
+exts_list = [
+    ('opt_einsum', '3.3.0', {
+        'checksums': ['59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549'],
+    }),
+    ('etils', '1.6.0', {
+        'checksums': ['c635fbd02a79fed4ad76825d31306b581d22b40671721daa8bc279cf6333e48a'],
+    }),
+    ('ml_dtypes', '0.3.2', {
+        'checksums': ['533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967'],
+    }),
+    (name, version, {
+        # 'patches': ['jax-0.4.24_cusparse.patch'],
+        'runtest': local_test,
+        'source_tmpl': '%(name)s-v%(version)s.tar.gz',
+        'source_urls': ['https://github.com/google/jax/archive/'],
+        'checksums': [
+            {'jax-v0.4.24.tar.gz': '6e52d8b547624bd70d423e6bf85f4fcd47336b529f1a4f6a94fac3096017a694'},
+            # {'jax-0.4.24_cusparse.patch': 'a7d61e412ef3e77b2d4a6ba3c98e5c2c571523ccf46325ea180725f77a9dfb36'},
+        ],
+    }),
+]
+
+sanity_pip_check = True
+
+moduleclass = 'tools'

From a484ba24f3b1363df93b1a573f5954f2dfc801c1 Mon Sep 17 00:00:00 2001
From: thoffman
Date: Tue, 13 Feb 2024 12:03:51 +0100
Subject: [PATCH 02/12] add checksum

---
 .../easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
index 5f3d294504b..1593970c012 100644
--- a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
+++ b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
@@ -94,6 +94,8 @@ components = [
              'db007b6628cfe108c63f45d611c6de910abe3ee827e55f08314ce143c4887d66'},
             {'tf_runtime-4665f7483063a16b6113a05eb45f98103cc1d611.tar.gz':
              '3aa0ab30fe94dab33f20824b9c2d8e7c3b6017106c833b12070f71d2e0f1d6d6'},
+            {'jax-0.4.24_xla-12eee88_indexing_analysis_small_vector.patch':
+             '7187cdd08cce12d0af889494317cb8c32865487d1d6d9254064cb62fd3453b6d'},
         ],
         'start_dir': 'jax-jaxlib-v%(version)s',
         'buildopts': local_repo_opt

From 5d7d18fd3ad4ff4578e8785122127067f7c763b3 Mon Sep 17 00:00:00 2001
From: thoffman
Date: Tue, 13 Feb 2024 12:20:31 +0100
Subject: [PATCH 03/12] fix local_test

---
 easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb | 1 -
 1 file changed, 1 deletion(-)

diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
index 1593970c012..76abbc7fa36 100644
--- a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
+++ b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
@@ -44,7 +44,6 @@ local_repo_opt += '--bazel_options="--override_repository=runtime=%%(builddir)s/
 local_test = "NVIDIA_TF32_OVERRIDE=0 CUDA_VISIBLE_DEVICES=0 "
 local_test += "XLA_PYTHON_CLIENT_ALLOCATOR=platform "
 local_test += "JAX_ENABLE_X64=true pytest -vv tests "
-local_test = True
 
 use_pip = True
 
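Note: the pinned local_xla_commit and local_tfrt_commit above must track the commits pinned in the jaxlib sources' third_party/{xla,tf_runtime}/workspace.bzl files, so they need re-checking on every version bump. A minimal sketch of such a check, assuming upstream's convention of pinning via an XLA_COMMIT = "<sha>" assignment in workspace.bzl (the script name, variable name and example path are illustrative, not part of this patch series):

    # check_pin.py - compare an easyconfig commit pin against the commit pinned
    # in a workspace.bzl file shipped with the jaxlib sources.
    import re
    import sys

    local_xla_commit = '12eee889e1f2ad41e27d7b0e970cb92d282d3ec5'  # from the easyconfig

    def pinned_commit(workspace_bzl, var='XLA_COMMIT'):
        """Return the 40-character sha assigned to `var` in the given workspace.bzl."""
        with open(workspace_bzl) as handle:
            match = re.search(r'%s\s*=\s*"([0-9a-f]{40})"' % var, handle.read())
        if match is None:
            sys.exit('no %s pin found in %s' % (var, workspace_bzl))
        return match.group(1)

    if __name__ == '__main__':
        # usage: python check_pin.py jax-jaxlib-v0.4.24/third_party/xla/workspace.bzl
        pin = pinned_commit(sys.argv[1])
        print('OK' if pin == local_xla_commit else 'MISMATCH: workspace.bzl pins ' + pin)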
From 78fc890e70fd133c9064c5abf377a241727e00a3 Mon Sep 17 00:00:00 2001
From: thoffman
Date: Tue, 13 Feb 2024 12:33:29 +0100
Subject: [PATCH 04/12] add patch

---
 ..._xla-12eee88_indexing_analysis_small_vector.patch | 12 ++++++++++++
 1 file changed, 12 insertions(+)
 create mode 100644 easybuild/easyconfigs/j/jax/jax-0.4.24_xla-12eee88_indexing_analysis_small_vector.patch

diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.24_xla-12eee88_indexing_analysis_small_vector.patch b/easybuild/easyconfigs/j/jax/jax-0.4.24_xla-12eee88_indexing_analysis_small_vector.patch
new file mode 100644
index 00000000000..7e8cbac30c3
--- /dev/null
+++ b/easybuild/easyconfigs/j/jax/jax-0.4.24_xla-12eee88_indexing_analysis_small_vector.patch
@@ -0,0 +1,12 @@
+diff -ru xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5_old/xla/service/gpu/model/indexing_analysis.cc xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5/xla/service/gpu/model/indexing_analysis.cc
+--- xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5_old/xla/service/gpu/model/indexing_analysis.cc 2024-02-05 19:41:29.000000000 +0100
++++ xla-12eee889e1f2ad41e27d7b0e970cb92d282d3ec5/xla/service/gpu/model/indexing_analysis.cc 2024-02-12 12:09:35.301680070 +0100
+@@ -687,7 +687,7 @@
+ llvm::SmallVector DelinearizeInBoundsIndex(
+     mlir::AffineExpr linear, absl::Span sizes,
+     absl::Span strides) {
+-  llvm::SmallVector result;
++  llvm::SmallVector result;  // THEMBL; see commit c10075688d773c43c22e658c814a94ade3cbb372
+   result.reserve(sizes.size());
+   for (auto [size, stride] : llvm::zip(sizes, strides)) {
+     result.push_back(linear.floorDiv(stride) % size);

From 6e3b537bd2b131c7a125be6e3885620bd0a4f11f Mon Sep 17 00:00:00 2001
From: thoffman
Date: Wed, 14 Feb 2024 11:42:48 +0100
Subject: [PATCH 05/12] add ml_dtypes avx512 patch

---
 .../jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb  |    7 +-
 ...x-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch | 1217 +++++++++++++++++
 2 files changed, 1223 insertions(+), 1 deletion(-)
 create mode 100644 easybuild/easyconfigs/j/jax/jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch

diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
index 76abbc7fa36..753df24254d 100644
--- a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
+++ b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb
@@ -109,7 +109,12 @@ exts_list = [
         'checksums': ['c635fbd02a79fed4ad76825d31306b581d22b40671721daa8bc279cf6333e48a'],
     }),
     ('ml_dtypes', '0.3.2', {
-        'checksums': ['533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967'],
+        'patches': [('jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch',1)],
+        'checksums': [
+            {'ml_dtypes-0.3.2.tar.gz': '533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967'},
+            {'jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch':
+             '01cd2956c6c1a600ea17f93686d7d952bceb6cb1fbe408a8e54d35f6a7b13c4a'},
+        ]
     }),
     (name, version, {
         # 'patches': ['jax-0.4.24_cusparse.patch'],
diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch b/easybuild/easyconfigs/j/jax/jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch
new file mode 100644
index 00000000000..0a4e63a0457
--- /dev/null
+++ b/easybuild/easyconfigs/j/jax/jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch
@@ -0,0 +1,1217 @@
+# ml_dtype 0.3.2 provides Eigen commit 7bf2968. src/Core/arch/AVX512/TrsmUnrolls.inc is missing.
+diff -ru --new-file ./ml_dtypes-0.3.2_ori/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc ./ml_dtypes-0.3.2/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +--- ./ml_dtypes-0.3.2_ori/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc 1970-01-01 01:00:00.000000000 +0100 ++++ ./ml_dtypes-0.3.2/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc 2024-02-14 10:32:25.492978066 +0100 +@@ -0,0 +1,1212 @@ ++// This file is part of Eigen, a lightweight C++ template library ++// for linear algebra. ++// ++// Copyright (C) 2022 Intel Corporation ++// ++// This Source Code Form is subject to the terms of the Mozilla ++// Public License v. 2.0. If a copy of the MPL was not distributed ++// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. ++ ++#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H ++#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H ++ ++template ++EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) { ++ EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j; ++ else return i + j * LDA; ++} ++ ++/** ++ * This namespace contains various classes used to generate compile-time unrolls which are ++ * used throughout the trsm/gemm kernels. The unrolls are characterized as for-loops (1-D), nested ++ * for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion ++ * ++ * Example, the 2-D for-loop is unrolled recursively by first flattening to a 1-D loop. ++ * ++ * for(startI = 0; startI < endI; startI++) for(startC = 0; startC < endI*endJ; startC++) ++ * for(startJ = 0; startJ < endJ; startJ++) ----> startI = (startC)/(endJ) ++ * func(startI,startJ) startJ = (startC)%(endJ) ++ * func(...) ++ * ++ * The 1-D loop can be unrolled recursively by using enable_if and defining an auxillary function ++ * with a template parameter used as a counter. ++ * ++ * template ++ * std::enable_if_t<(counter <= 0)> <---- tail case. ++ * aux_func {} ++ * ++ * template ++ * std::enable_if_t<(counter > 0)> <---- actual for-loop ++ * aux_func { ++ * startC = endI*endJ - counter ++ * startI = (startC)/(endJ) ++ * startJ = (startC)%(endJ) ++ * func(startI, startJ) ++ * aux_func() ++ * } ++ * ++ * Note: Additional wrapper functions are provided for aux_func which hides the counter template ++ * parameter since counter usually depends on endI, endJ, etc... 
++ * ++ * Conventions: ++ * 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++)) ++ * ++ * 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for ++ * handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW) ++ */ ++namespace unrolls { ++ ++template ++EIGEN_ALWAYS_INLINE auto remMask(int64_t m) { ++ EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); } ++ else EIGEN_IF_CONSTEXPR(N == 8) { ++ return 0xFF >> (8 - m); ++ } ++ else EIGEN_IF_CONSTEXPR(N == 4) { ++ return 0x0F >> (4 - m); ++ } ++ return 0; ++} ++ ++template ++EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel); ++ ++template <> ++EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel) { ++ __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]); ++ __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]); ++ __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]); ++ __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]); ++ __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]); ++ __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]); ++ __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]); ++ __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]); ++ ++ kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); ++ kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); ++ kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); ++ kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); ++ kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6))); ++ kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6))); ++ kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); ++ kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); ++ ++ T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E)); ++ T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0); ++ T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E)); ++ T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]); ++ T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E)); ++ T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1); ++ T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E)); ++ T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]); ++ T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E)); ++ T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2); ++ T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E)); ++ T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]); ++ T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E)); ++ T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3); ++ T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E)); ++ T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]); ++ ++ kernel.packet[0] = T0; ++ kernel.packet[1] = T1; ++ kernel.packet[2] = T2; ++ kernel.packet[3] = T3; ++ kernel.packet[4] = T4; ++ kernel.packet[5] = T5; 
++ kernel.packet[6] = T6; ++ kernel.packet[7] = T7; ++} ++ ++template <> ++EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel) { ++ ptranspose(kernel); ++} ++ ++/*** ++ * Unrolls for tranposed C stores ++ */ ++template ++class trans { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - storeC ++ *********************************** ++ */ ++ ++ /** ++ * aux_storeC ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN++) ++ * ++ * (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC ++ * ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) { ++ EIGEN_IF_CONSTEXPR(remM) { ++ pstoreu( ++ C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), ++ preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]), ++ remMask(remM_)), ++ remMask(remM_)); ++ } ++ else { ++ pstoreu(C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN), ++ preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]))); ++ } ++ } ++ else { // This block is only needed for fp32 case ++ // Reinterpret as __m512 for _mm512_shuffle_f32x4 ++ vecFullFloat zmm2vecFullFloat = preinterpret( ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]); ++ // Swap lower and upper half of avx register. ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] = ++ preinterpret(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110)); ++ ++ EIGEN_IF_CONSTEXPR(remM) { ++ pstoreu( ++ C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), ++ preinterpret( ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])), ++ remMask(remM_)); ++ } ++ else { ++ pstoreu( ++ C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN), ++ preinterpret( ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]))); ++ } ++ } ++ aux_storeC(C_arr, LDC, zmm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && endN <= PacketSize)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(C_arr); ++ EIGEN_UNUSED_VARIABLE(LDC); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC, ++ PacketBlock &zmm, ++ int64_t remM_ = 0) { ++ aux_storeC(C_arr, LDC, zmm, remM_); ++ } ++ ++ /** ++ * Transposes LxunrollN row major block of matrices stored EIGEN_AVX_MAX_NUM_ACC zmm registers to ++ * "unrollN"xL ymm registers to be stored col-major into C. ++ * ++ * For 8x48, the 8x48 block (row-major) is stored in zmm as follows: ++ * ++ * row0: zmm0 zmm1 zmm2 ++ * row1: zmm3 zmm4 zmm5 ++ * . ++ * . 
++ * row7: zmm21 zmm22 zmm23 ++ * ++ * For 8x32, the 8x32 block (row-major) is stored in zmm as follows: ++ * ++ * row0: zmm0 zmm1 ++ * row1: zmm2 zmm3 ++ * . ++ * . ++ * row7: zmm14 zmm15 ++ * ++ * ++ * In general we will have {1,2,3} groups of avx registers each of size ++ * EIGEN_AVX_MAX_NUM_ROW. packetIndexOffset is used to select which "block" of ++ * avx registers are being transposed. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void transpose(PacketBlock &zmm) { ++ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted ++ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. ++ constexpr int64_t zmmStride = unrollN / PacketSize; ++ PacketBlock r; ++ r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0]; ++ r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1]; ++ r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2]; ++ r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3]; ++ r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4]; ++ r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5]; ++ r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6]; ++ r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7]; ++ trans8x8blocks(r); ++ zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0]; ++ zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1]; ++ zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2]; ++ zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3]; ++ zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4]; ++ zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5]; ++ zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6]; ++ zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7]; ++ } ++}; ++ ++/** ++ * Unrolls for copyBToRowMajor ++ * ++ * Idea: ++ * 1) Load a block of right-hand sides to registers (using loadB). ++ * 2) Convert the block from column-major to row-major (transposeLxL) ++ * 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false). ++ * ++ * We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are ++ * used as temps for transposing. ++ * ++ * Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks ++ * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm). 
++ */ ++template ++class transB { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - loadB ++ * - storeB ++ * - loadBBlock ++ * - storeBBlock ++ *********************************** ++ */ ++ ++ /** ++ * aux_loadB ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(remM) { ++ ymm.packet[packetIndexOffset + startN] = ++ ploadu((const Scalar *)&B_arr[startN * LDB], remMask(remM_)); ++ } ++ else ymm.packet[packetIndexOffset + startN] = ploadu((const Scalar *)&B_arr[startN * LDB]); ++ ++ aux_loadB(B_arr, LDB, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ /** ++ * aux_storeB ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(remK || remM) { ++ pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN], ++ remMask(rem_)); ++ } ++ else { ++ pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]); ++ } ++ ++ aux_storeB(B_arr, LDB, ymm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_loadBBlock ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ transB::template loadB(&B_temp[startN], LDB_, ymm); ++ aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(B_temp); ++ EIGEN_UNUSED_VARIABLE(LDB_); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ /** ++ * aux_storeBBlock ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock 
&ymm, int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(toTemp) { ++ transB::template storeB(&B_temp[startN], LDB_, ymm, remK_); ++ } ++ else { ++ transB::template storeB(&B_arr[0 + startN * LDB], LDB, ++ ymm, remM_); ++ } ++ aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(B_temp); ++ EIGEN_UNUSED_VARIABLE(LDB_); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ /******************************************************** ++ * Wrappers for aux_XXXX to hide counter parameter ++ ********************************************************/ ++ ++ template ++ static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ aux_loadB(B_arr, LDB, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB, ++ PacketBlock &ymm, ++ int64_t rem_ = 0) { ++ aux_storeB(B_arr, LDB, ymm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB(&B_arr[0], LDB, ymm, remM_); } ++ else { ++ aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock &ymm) { ++ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted ++ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. ++ PacketBlock r; ++ r.packet[0] = ymm.packet[packetIndexOffset + 0]; ++ r.packet[1] = ymm.packet[packetIndexOffset + 1]; ++ r.packet[2] = ymm.packet[packetIndexOffset + 2]; ++ r.packet[3] = ymm.packet[packetIndexOffset + 3]; ++ r.packet[4] = ymm.packet[packetIndexOffset + 4]; ++ r.packet[5] = ymm.packet[packetIndexOffset + 5]; ++ r.packet[6] = ymm.packet[packetIndexOffset + 6]; ++ r.packet[7] = ymm.packet[packetIndexOffset + 7]; ++ ptranspose(r); ++ ymm.packet[packetIndexOffset + 0] = r.packet[0]; ++ ymm.packet[packetIndexOffset + 1] = r.packet[1]; ++ ymm.packet[packetIndexOffset + 2] = r.packet[2]; ++ ymm.packet[packetIndexOffset + 3] = r.packet[3]; ++ ymm.packet[packetIndexOffset + 4] = r.packet[4]; ++ ymm.packet[packetIndexOffset + 5] = r.packet[5]; ++ ymm.packet[packetIndexOffset + 6] = r.packet[6]; ++ ymm.packet[packetIndexOffset + 7] = r.packet[7]; ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ constexpr int64_t U3 = PacketSize * 3; ++ constexpr int64_t U2 = PacketSize * 2; ++ constexpr int64_t U1 = PacketSize * 1; ++ /** ++ * Unrolls needed for each case: ++ * - AVX512 fp32 48 32 16 8 4 2 1 ++ * - AVX512 fp64 24 16 8 4 2 1 ++ * ++ * For fp32 L and U1 are 1:2 so for U3/U2 cases the loads/stores need to be split up. 
++ */ ++ EIGEN_IF_CONSTEXPR(unrollN == U3) { ++ // load LxU3 B col major, transpose LxU3 row major ++ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3); ++ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ ++ EIGEN_IF_CONSTEXPR(maxUBlock < U3) { ++ transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, ++ ymm, remM_); ++ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, ++ ymm, remM_); ++ } ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == U2) { ++ // load LxU2 B col major, transpose LxU2 row major ++ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2); ++ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ ++ EIGEN_IF_CONSTEXPR(maxUBlock < U2) { ++ transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, ++ &B_temp[maxUBlock], LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, ++ &B_temp[maxUBlock], LDB_, ymm, remM_); ++ } ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == U1) { ++ // load LxU1 B col major, transpose LxU1 row major ++ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); } ++ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) { ++ // load Lx4 B col major, transpose Lx4 row major ++ transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) { ++ // load Lx4 B col major, transpose Lx4 row major ++ transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == 2) { ++ // load Lx2 B col major, transpose Lx2 row major ++ transB::template loadBBlock<2, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == 1) { ++ // load Lx1 B col major, transpose Lx1 row major ++ transB::template loadBBlock<1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ } ++}; ++ ++/** ++ * Unrolls 
for triSolveKernel ++ * ++ * Idea: ++ * 1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS). ++ * 2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix) ++ * stored in AInPacket (using triSolveMicroKernel). ++ * 3) Store final results (in avx registers) back into memory (using storeRHS). ++ * ++ * RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most ++ * EIGEN_AVX_MAX_NUM_ROW registers. ++ */ ++template ++class trsm { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - loadRHS ++ * - storeRHS ++ * - divRHSByDiag ++ * - updateRHS ++ * - triSolveMicroKernel ++ ************************************/ ++ /** ++ * aux_loadRHS ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ constexpr int64_t counterReverse = endM * endK - counter; ++ constexpr int64_t startM = counterReverse / (endK); ++ constexpr int64_t startK = counterReverse % endK; ++ ++ constexpr int64_t packetIndex = startM * endK + startK; ++ constexpr int64_t startM_ = isFWDSolve ? startM : -startM; ++ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; ++ EIGEN_IF_CONSTEXPR(krem) { ++ RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex], remMask(rem)); ++ } ++ else { ++ RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex]); ++ } ++ aux_loadRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(rem); ++ } ++ ++ /** ++ * aux_storeRHS ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ constexpr int64_t counterReverse = endM * endK - counter; ++ constexpr int64_t startM = counterReverse / (endK); ++ constexpr int64_t startK = counterReverse % endK; ++ ++ constexpr int64_t packetIndex = startM * endK + startK; ++ constexpr int64_t startM_ = isFWDSolve ? 
startM : -startM; ++ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; ++ EIGEN_IF_CONSTEXPR(krem) { ++ pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask(rem)); ++ } ++ else { ++ pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]); ++ } ++ aux_storeRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(rem); ++ } ++ ++ /** ++ * aux_divRHSByDiag ++ * ++ * currM may be -1, (currM >=0) in enable_if checks for this ++ * ++ * 1-D unroll ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag( ++ PacketBlock &RHSInPacket, PacketBlock &AInPacket) { ++ constexpr int64_t counterReverse = endK - counter; ++ constexpr int64_t startK = counterReverse; ++ ++ constexpr int64_t packetIndex = currM * endK + startK; ++ RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]); ++ aux_divRHSByDiag(RHSInPacket, AInPacket); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && currM >= 0)> aux_divRHSByDiag( ++ PacketBlock &RHSInPacket, PacketBlock &AInPacket) { ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(AInPacket); ++ } ++ ++ /** ++ * aux_updateRHS ++ * ++ * 2-D unroll ++ * for(startM = initM; startM < endM; startM++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS( ++ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ constexpr int64_t counterReverse = (endM - initM) * endK - counter; ++ constexpr int64_t startM = initM + counterReverse / (endK); ++ constexpr int64_t startK = counterReverse % endK; ++ ++ // For each row of A, first update all corresponding RHS ++ constexpr int64_t packetIndex = startM * endK + startK; ++ EIGEN_IF_CONSTEXPR(currentM > 0) { ++ RHSInPacket.packet[packetIndex] = ++ pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK], ++ RHSInPacket.packet[packetIndex]); ++ } ++ ++ EIGEN_IF_CONSTEXPR(startK == endK - 1) { ++ // Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}. ++ EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) { ++ // If diagonal is not unit, we broadcast reciprocals of diagonals AinPacket.packet[currentM]. 
++ // This will be used in divRHSByDiag ++ EIGEN_IF_CONSTEXPR(isFWDSolve) ++ AInPacket.packet[currentM] = pset1(Scalar(1) / A_arr[idA(currentM, currentM, LDA)]); ++ else AInPacket.packet[currentM] = pset1(Scalar(1) / A_arr[idA(-currentM, -currentM, LDA)]); ++ } ++ else { ++ // Broadcast next off diagonal element of A ++ EIGEN_IF_CONSTEXPR(isFWDSolve) ++ AInPacket.packet[startM] = pset1(A_arr[idA(startM, currentM, LDA)]); ++ else AInPacket.packet[startM] = pset1(A_arr[idA(-startM, -currentM, LDA)]); ++ } ++ } ++ ++ aux_updateRHS( ++ A_arr, LDA, RHSInPacket, AInPacket); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS( ++ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ EIGEN_UNUSED_VARIABLE(A_arr); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(AInPacket); ++ } ++ ++ /** ++ * aux_triSolverMicroKernel ++ * ++ * 1-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel( ++ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ constexpr int64_t counterReverse = endM - counter; ++ constexpr int64_t startM = counterReverse; ++ ++ constexpr int64_t currentM = startM; ++ // Divides the right-hand side in row startM, by digonal value of A ++ // broadcasted to AInPacket.packet[startM-1] in the previous iteration. ++ // ++ // Without "if constexpr" the compiler instantiates the case <-1, numK> ++ // this is handled with enable_if to prevent out-of-bound warnings ++ // from the compiler ++ EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0) ++ trsm::template divRHSByDiag(RHSInPacket, AInPacket); ++ ++ // After division, the rhs corresponding to subsequent rows of A can be partially updated ++ // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed) ++ // to be used in the next iteration. ++ trsm::template updateRHS(A_arr, LDA, RHSInPacket, ++ AInPacket); ++ ++ // Handle division for the RHS corresponding to the final row of A. 
++ EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1) ++ trsm::template divRHSByDiag(RHSInPacket, AInPacket); ++ ++ aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, ++ AInPacket); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel( ++ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ EIGEN_UNUSED_VARIABLE(A_arr); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(AInPacket); ++ } ++ ++ /******************************************************** ++ * Wrappers for aux_XXXX to hide counter parameter ++ ********************************************************/ ++ ++ /** ++ * Load endMxendK block of B to RHSInPacket ++ * Masked loads are used for cases where endK is not a multiple of PacketSize ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB, ++ PacketBlock &RHSInPacket, int64_t rem = 0) { ++ aux_loadRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ /** ++ * Load endMxendK block of B to RHSInPacket ++ * Masked loads are used for cases where endK is not a multiple of PacketSize ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB, ++ PacketBlock &RHSInPacket, int64_t rem = 0) { ++ aux_storeRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ /** ++ * Only used if Triangular matrix has non-unit diagonal values ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ aux_divRHSByDiag(RHSInPacket, AInPacket); ++ } ++ ++ /** ++ * Update right-hand sides (stored in avx registers) ++ * Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket. ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA, ++ PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ aux_updateRHS( ++ A_arr, LDA, RHSInPacket, AInPacket); ++ } ++ ++ /** ++ * endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW ++ * numK: number of avx registers to use for each row of B (ex fp32: 48 rhs => 3 avx reg used). 1 <= endK <= 3. ++ * isFWDSolve: true => forward substitution, false => backwards substitution ++ * isUnitDiag: true => triangular matrix has unit diagonal. 
++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA, ++ PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ static_assert(numK >= 1 && numK <= 3, "numK out of range"); ++ aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, AInPacket); ++ } ++}; ++ ++/** ++ * Unrolls for gemm kernel ++ * ++ * isAdd: true => C += A*B, false => C -= A*B ++ */ ++template ++class gemm { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - setzero ++ * - updateC ++ * - storeC ++ * - startLoadB ++ * - triSolveMicroKernel ++ ************************************/ ++ ++ /** ++ * aux_setzero ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero( ++ PacketBlock &zmm) { ++ constexpr int64_t counterReverse = endM * endN - counter; ++ constexpr int64_t startM = counterReverse / (endN); ++ constexpr int64_t startN = counterReverse % endN; ++ ++ zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]); ++ aux_setzero(zmm); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero( ++ PacketBlock &zmm) { ++ EIGEN_UNUSED_VARIABLE(zmm); ++ } ++ ++ /** ++ * aux_updateC ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endM * endN - counter; ++ constexpr int64_t startM = counterReverse / (endN); ++ constexpr int64_t startN = counterReverse % endN; ++ ++ EIGEN_IF_CONSTEXPR(rem) ++ zmm.packet[startN * endM + startM] = ++ padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize], remMask(rem_)), ++ zmm.packet[startN * endM + startM], remMask(rem_)); ++ else zmm.packet[startN * endM + startM] = ++ padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]); ++ aux_updateC(C_arr, LDC, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(C_arr); ++ EIGEN_UNUSED_VARIABLE(LDC); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_storeC ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endM * endN - counter; ++ constexpr int64_t startM = counterReverse / (endN); ++ constexpr int64_t startN = counterReverse % endN; ++ ++ EIGEN_IF_CONSTEXPR(rem) ++ pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM], ++ remMask(rem_)); ++ else pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]); ++ aux_storeC(C_arr, LDC, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, 
PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(C_arr); ++ EIGEN_UNUSED_VARIABLE(LDC); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_startLoadB ++ * ++ * 1-D unroll ++ * for(startL = 0; startL < endL; startL++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endL - counter; ++ constexpr int64_t startL = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(rem) ++ zmm.packet[unrollM * unrollN + startL] = ++ ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask(rem_)); ++ else zmm.packet[unrollM * unrollN + startL] = ++ ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]); ++ ++ aux_startLoadB(B_t, LDB, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_t); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_startBCastA ++ * ++ * 1-D unroll ++ * for(startB = 0; startB < endB; startB++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA( ++ Scalar *A_t, int64_t LDA, PacketBlock &zmm) { ++ constexpr int64_t counterReverse = endB - counter; ++ constexpr int64_t startB = counterReverse; ++ ++ zmm.packet[unrollM * unrollN + numLoad + startB] = pload1(&A_t[idA(startB, 0, LDA)]); ++ ++ aux_startBCastA(A_t, LDA, zmm); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA( ++ Scalar *A_t, int64_t LDA, PacketBlock &zmm) { ++ EIGEN_UNUSED_VARIABLE(A_t); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ } ++ ++ /** ++ * aux_loadB ++ * currK: current K ++ * ++ * 1-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ if ((numLoad / endM + currK < unrollK)) { ++ constexpr int64_t counterReverse = endM - counter; ++ constexpr int64_t startM = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(rem) { ++ zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = ++ ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask(rem_)); ++ } ++ else { ++ zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = ++ ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]); ++ } ++ ++ aux_loadB(B_t, LDB, zmm, rem_); ++ } ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_t); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_microKernel ++ * ++ * 3-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel( ++ Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endM * endN * endK - counter; ++ constexpr int startK = 
counterReverse / (endM * endN); ++ constexpr int startN = (counterReverse / (endM)) % endN; ++ constexpr int startM = counterReverse % endM; ++ ++ EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) { ++ gemm::template startLoadB(B_t, LDB, zmm, rem_); ++ gemm::template startBCastA(A_t, LDA, zmm); ++ } ++ ++ { ++ // Interleave FMA and Bcast ++ EIGEN_IF_CONSTEXPR(isAdd) { ++ zmm.packet[startN * endM + startM] = ++ pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast], ++ zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]); ++ } ++ else { ++ zmm.packet[startN * endM + startM] = ++ pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast], ++ zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]); ++ } ++ // Bcast ++ EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) { ++ zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1(&A_t[idA( ++ (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]); ++ } ++ } ++ ++ // We have updated all accumlators, time to load next set of B's ++ EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) { ++ gemm::template loadB(B_t, LDB, zmm, rem_); ++ } ++ aux_microKernel(B_t, A_t, LDB, LDA, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel( ++ Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_t); ++ EIGEN_UNUSED_VARIABLE(A_t); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /******************************************************** ++ * Wrappers for aux_XXXX to hide counter parameter ++ ********************************************************/ ++ ++ template ++ static EIGEN_ALWAYS_INLINE void setzero(PacketBlock &zmm) { ++ aux_setzero(zmm); ++ } ++ ++ /** ++ * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_updateC(C_arr, LDC, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_storeC(C_arr, LDC, zmm, rem_); ++ } ++ ++ /** ++ * Use numLoad registers for loading B at start of microKernel ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_startLoadB(B_t, LDB, zmm, rem_); ++ } ++ ++ /** ++ * Use numBCast registers for broadcasting A at start of microKernel ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA, ++ PacketBlock &zmm) { ++ aux_startBCastA(A_t, LDA, zmm); ++ } ++ ++ /** ++ * Loads next set of B into vector registers between each K unroll. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_loadB(B_t, LDB, zmm, rem_); ++ } ++ ++ /** ++ * Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3} to compute C -= A*B. ++ * A matrix can be row/col-major. 
B matrix is assumed row-major. ++ * ++ * isARowMajor: is A row major ++ * endM: Number registers per row ++ * endN: Number of rows ++ * endK: Loop unroll for K. ++ * numLoad: Number of registers for loading B. ++ * numBCast: Number of registers for broadcasting A. ++ * ++ * Ex: microkernel: 8x48 unroll (24 accumulators), k unrolled 4 times, ++ * 6 register for loading B, 2 for broadcasting A. ++ * ++ * Note: Ideally the microkernel should not have any register spilling. ++ * The avx instruction counts should be: ++ * - endK*endN vbroadcasts{s,d} ++ * - endK*endM vmovup{s,d} ++ * - endK*endN*endM FMAs ++ * ++ * From testing, there are no register spills with clang. There are register spills with GNU, which ++ * causes a performance hit. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_microKernel(B_t, A_t, LDB, LDA, zmm, ++ rem_); ++ } ++}; ++} // namespace unrolls ++ ++#endif // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H From 3791d23f319a60220f027fdb5d6e3c730a49fc4b Mon Sep 17 00:00:00 2001 From: thoffman Date: Wed, 14 Feb 2024 12:02:50 +0100 Subject: [PATCH 06/12] fix style --- .../easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb index 753df24254d..7985f3b3233 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb @@ -109,7 +109,7 @@ exts_list = [ 'checksums': ['c635fbd02a79fed4ad76825d31306b581d22b40671721daa8bc279cf6333e48a'], }), ('ml_dtypes', '0.3.2', { - 'patches': [('jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch',1)], + 'patches': [('jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch', 1)], 'checksums': [ {'ml_dtypes-0.3.2.tar.gz': '533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967'}, {'jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch': From 097e35582d41aa84d6fdab49e84f6390d8fff0f3 Mon Sep 17 00:00:00 2001 From: thoffman Date: Thu, 15 Feb 2024 17:52:56 +0100 Subject: [PATCH 07/12] ml_dtypes: extension->dependency --- .../jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb | 15 +- .../m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb | 52 + .../ml_dtypes-0.3.2_EigenAvx512.patch | 1217 +++++++++++++++++ 3 files changed, 1270 insertions(+), 14 deletions(-) create mode 100644 easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb create mode 100644 easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb index 7985f3b3233..febaae194bb 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb @@ -32,6 +32,7 @@ dependencies = [ ('SciPy-bundle', '2023.07'), ('flatbuffers-python', '23.5.26'), ('zlib', '1.2.13'), + ('ml_dtypes', '0.3.2'), ] # downloading xla and other tarballs to avoid that Bazel downloads it during the build @@ -102,20 +103,6 @@ components = [ ] exts_list = [ - ('opt_einsum', '3.3.0', { - 'checksums': ['59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549'], - }), - ('etils', '1.6.0', { - 'checksums': ['c635fbd02a79fed4ad76825d31306b581d22b40671721daa8bc279cf6333e48a'], - }), - 
('ml_dtypes', '0.3.2', { - 'patches': [('jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch', 1)], - 'checksums': [ - {'ml_dtypes-0.3.2.tar.gz': '533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967'}, - {'jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch': - '01cd2956c6c1a600ea17f93686d7d952bceb6cb1fbe408a8e54d35f6a7b13c4a'}, - ] - }), (name, version, { # 'patches': ['jax-0.4.24_cusparse.patch'], 'runtest': local_test, diff --git a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb new file mode 100644 index 00000000000..822c1e6757d --- /dev/null +++ b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb @@ -0,0 +1,52 @@ +# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2024/02 +easyblock = 'PythonBundle' + +name = 'ml_dtypes' +version = '0.3.2' + +homepage = 'https://github.com/jax-ml/ml_dtypes' +description = """ +ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used +in machine learning libraries, including: + +bfloat16: an alternative to the standard float16 format +float8_*: several experimental 8-bit floating point representations including: +float8_e4m3b11fnuz +float8_e4m3fn +float8_e4m3fnuz +float8_e5m2 +float8_e5m2fnuz +""" + +toolchain = {'name': 'foss', 'version': '2023a'} + +dependencies = [ + ('Python', '3.11.3'), + ('SciPy-bundle', '2023.07'), +] + + +use_pip = True + +default_easyblock = 'PythonPackage' + +exts_list = [ + ('opt_einsum', '3.3.0', { + 'checksums': ['59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549'], + }), + ('etils', '1.6.0', { + 'checksums': ['c635fbd02a79fed4ad76825d31306b581d22b40671721daa8bc279cf6333e48a'], + }), + (name, version, { + 'patches': [('ml_dtypes-0.3.2_EigenAvx512.patch', 1)], + 'checksums': [ + {'ml_dtypes-0.3.2.tar.gz': '533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967'}, + {'ml_dtypes-0.3.2_EigenAvx512.patch': + 'fef229a24515b9c03be0d2e932c499965212e3a03ae3ede5d037874f88f93c46'}, + ], + }) +] + +sanity_pip_check = True + +moduleclass = 'tools' diff --git a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch new file mode 100644 index 00000000000..79da0f03f1d --- /dev/null +++ b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch @@ -0,0 +1,1217 @@ +# ml_dtype 0.3.2 provides Eigen commit 7bf2968. src/Core/arch/AVX512/TrsmUnrolls.inc is missing. +diff -ru --new-file old/third_party_ori/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc new/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +--- old/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc 1970-01-01 01:00:00.000000000 +0100 ++++ new/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc 2024-02-14 10:32:25.492978066 +0100 +@@ -0,0 +1,1212 @@ ++// This file is part of Eigen, a lightweight C++ template library ++// for linear algebra. ++// ++// Copyright (C) 2022 Intel Corporation ++// ++// This Source Code Form is subject to the terms of the Mozilla ++// Public License v. 2.0. If a copy of the MPL was not distributed ++// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
++ ++#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H ++#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H ++ ++template ++EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) { ++ EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j; ++ else return i + j * LDA; ++} ++ ++/** ++ * This namespace contains various classes used to generate compile-time unrolls which are ++ * used throughout the trsm/gemm kernels. The unrolls are characterized as for-loops (1-D), nested ++ * for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion ++ * ++ * Example, the 2-D for-loop is unrolled recursively by first flattening to a 1-D loop. ++ * ++ * for(startI = 0; startI < endI; startI++) for(startC = 0; startC < endI*endJ; startC++) ++ * for(startJ = 0; startJ < endJ; startJ++) ----> startI = (startC)/(endJ) ++ * func(startI,startJ) startJ = (startC)%(endJ) ++ * func(...) ++ * ++ * The 1-D loop can be unrolled recursively by using enable_if and defining an auxillary function ++ * with a template parameter used as a counter. ++ * ++ * template ++ * std::enable_if_t<(counter <= 0)> <---- tail case. ++ * aux_func {} ++ * ++ * template ++ * std::enable_if_t<(counter > 0)> <---- actual for-loop ++ * aux_func { ++ * startC = endI*endJ - counter ++ * startI = (startC)/(endJ) ++ * startJ = (startC)%(endJ) ++ * func(startI, startJ) ++ * aux_func() ++ * } ++ * ++ * Note: Additional wrapper functions are provided for aux_func which hides the counter template ++ * parameter since counter usually depends on endI, endJ, etc... ++ * ++ * Conventions: ++ * 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++)) ++ * ++ * 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for ++ * handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW) ++ */ ++namespace unrolls { ++ ++template ++EIGEN_ALWAYS_INLINE auto remMask(int64_t m) { ++ EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); } ++ else EIGEN_IF_CONSTEXPR(N == 8) { ++ return 0xFF >> (8 - m); ++ } ++ else EIGEN_IF_CONSTEXPR(N == 4) { ++ return 0x0F >> (4 - m); ++ } ++ return 0; ++} ++ ++template ++EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel); ++ ++template <> ++EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel) { ++ __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]); ++ __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]); ++ __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]); ++ __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]); ++ __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]); ++ __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]); ++ __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]); ++ __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]); ++ ++ kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); ++ kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); ++ kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); ++ kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); ++ kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6))); ++ kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), 
_mm512_castps_pd(T6))); ++ kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); ++ kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); ++ ++ T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E)); ++ T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0); ++ T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E)); ++ T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]); ++ T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E)); ++ T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1); ++ T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E)); ++ T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]); ++ T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E)); ++ T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2); ++ T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E)); ++ T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]); ++ T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E)); ++ T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3); ++ T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E)); ++ T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]); ++ ++ kernel.packet[0] = T0; ++ kernel.packet[1] = T1; ++ kernel.packet[2] = T2; ++ kernel.packet[3] = T3; ++ kernel.packet[4] = T4; ++ kernel.packet[5] = T5; ++ kernel.packet[6] = T6; ++ kernel.packet[7] = T7; ++} ++ ++template <> ++EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel) { ++ ptranspose(kernel); ++} ++ ++/*** ++ * Unrolls for tranposed C stores ++ */ ++template ++class trans { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - storeC ++ *********************************** ++ */ ++ ++ /** ++ * aux_storeC ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN++) ++ * ++ * (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC ++ * ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) { ++ EIGEN_IF_CONSTEXPR(remM) { ++ pstoreu( ++ C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), ++ preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]), ++ remMask(remM_)), ++ remMask(remM_)); ++ } ++ else { ++ pstoreu(C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN), ++ preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]))); ++ } ++ } ++ else { // This block is only needed for fp32 case ++ // Reinterpret as __m512 for _mm512_shuffle_f32x4 ++ vecFullFloat zmm2vecFullFloat = preinterpret( ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]); ++ // Swap lower and upper half of avx register. 
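++      // For illustration: _mm512_shuffle_f32x4(x, x, 0b01001110) selects the 128-bit
++      // lanes {2, 3, 0, 1}, so x = [a0..a7 | b0..b7] becomes [b0..b7 | a0..a7]; the
++      // former upper ymm half can then go through the same vecHalf load/add/store
++      // path used for rows 0..7 above.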
++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] = ++ preinterpret(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110)); ++ ++ EIGEN_IF_CONSTEXPR(remM) { ++ pstoreu( ++ C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), ++ preinterpret( ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])), ++ remMask(remM_)); ++ } ++ else { ++ pstoreu( ++ C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN), ++ preinterpret( ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]))); ++ } ++ } ++ aux_storeC(C_arr, LDC, zmm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && endN <= PacketSize)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(C_arr); ++ EIGEN_UNUSED_VARIABLE(LDC); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC, ++ PacketBlock &zmm, ++ int64_t remM_ = 0) { ++ aux_storeC(C_arr, LDC, zmm, remM_); ++ } ++ ++ /** ++ * Transposes LxunrollN row major block of matrices stored EIGEN_AVX_MAX_NUM_ACC zmm registers to ++ * "unrollN"xL ymm registers to be stored col-major into C. ++ * ++ * For 8x48, the 8x48 block (row-major) is stored in zmm as follows: ++ * ++ * row0: zmm0 zmm1 zmm2 ++ * row1: zmm3 zmm4 zmm5 ++ * . ++ * . ++ * row7: zmm21 zmm22 zmm23 ++ * ++ * For 8x32, the 8x32 block (row-major) is stored in zmm as follows: ++ * ++ * row0: zmm0 zmm1 ++ * row1: zmm2 zmm3 ++ * . ++ * . ++ * row7: zmm14 zmm15 ++ * ++ * ++ * In general we will have {1,2,3} groups of avx registers each of size ++ * EIGEN_AVX_MAX_NUM_ROW. packetIndexOffset is used to select which "block" of ++ * avx registers are being transposed. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void transpose(PacketBlock &zmm) { ++ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted ++ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. ++ constexpr int64_t zmmStride = unrollN / PacketSize; ++ PacketBlock r; ++ r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0]; ++ r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1]; ++ r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2]; ++ r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3]; ++ r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4]; ++ r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5]; ++ r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6]; ++ r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7]; ++ trans8x8blocks(r); ++ zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0]; ++ zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1]; ++ zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2]; ++ zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3]; ++ zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4]; ++ zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5]; ++ zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6]; ++ zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7]; ++ } ++}; ++ ++/** ++ * Unrolls for copyBToRowMajor ++ * ++ * Idea: ++ * 1) Load a block of right-hand sides to registers (using loadB). 
++ * 2) Convert the block from column-major to row-major (transposeLxL) ++ * 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false). ++ * ++ * We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are ++ * used as temps for transposing. ++ * ++ * Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks ++ * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm). ++ */ ++template ++class transB { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - loadB ++ * - storeB ++ * - loadBBlock ++ * - storeBBlock ++ *********************************** ++ */ ++ ++ /** ++ * aux_loadB ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(remM) { ++ ymm.packet[packetIndexOffset + startN] = ++ ploadu((const Scalar *)&B_arr[startN * LDB], remMask(remM_)); ++ } ++ else ymm.packet[packetIndexOffset + startN] = ploadu((const Scalar *)&B_arr[startN * LDB]); ++ ++ aux_loadB(B_arr, LDB, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ /** ++ * aux_storeB ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(remK || remM) { ++ pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN], ++ remMask(rem_)); ++ } ++ else { ++ pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]); ++ } ++ ++ aux_storeB(B_arr, LDB, ymm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_loadBBlock ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ transB::template loadB(&B_temp[startN], LDB_, ymm); ++ aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar 
*B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(B_temp); ++ EIGEN_UNUSED_VARIABLE(LDB_); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ /** ++ * aux_storeBBlock ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(toTemp) { ++ transB::template storeB(&B_temp[startN], LDB_, ymm, remK_); ++ } ++ else { ++ transB::template storeB(&B_arr[0 + startN * LDB], LDB, ++ ymm, remM_); ++ } ++ aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(B_temp); ++ EIGEN_UNUSED_VARIABLE(LDB_); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ /******************************************************** ++ * Wrappers for aux_XXXX to hide counter parameter ++ ********************************************************/ ++ ++ template ++ static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ aux_loadB(B_arr, LDB, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB, ++ PacketBlock &ymm, ++ int64_t rem_ = 0) { ++ aux_storeB(B_arr, LDB, ymm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB(&B_arr[0], LDB, ymm, remM_); } ++ else { ++ aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock &ymm) { ++ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted ++ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. 
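++    // For illustration: for fp32, vecHalf is a ymm packet (8 lanes), so this 8x8
++    // ptranspose() transposes one Lx8 sub-block held as 8 half-width rows; for fp64,
++    // vecHalf is the full 8-lane zmm packet and the same code path applies.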
++ PacketBlock r; ++ r.packet[0] = ymm.packet[packetIndexOffset + 0]; ++ r.packet[1] = ymm.packet[packetIndexOffset + 1]; ++ r.packet[2] = ymm.packet[packetIndexOffset + 2]; ++ r.packet[3] = ymm.packet[packetIndexOffset + 3]; ++ r.packet[4] = ymm.packet[packetIndexOffset + 4]; ++ r.packet[5] = ymm.packet[packetIndexOffset + 5]; ++ r.packet[6] = ymm.packet[packetIndexOffset + 6]; ++ r.packet[7] = ymm.packet[packetIndexOffset + 7]; ++ ptranspose(r); ++ ymm.packet[packetIndexOffset + 0] = r.packet[0]; ++ ymm.packet[packetIndexOffset + 1] = r.packet[1]; ++ ymm.packet[packetIndexOffset + 2] = r.packet[2]; ++ ymm.packet[packetIndexOffset + 3] = r.packet[3]; ++ ymm.packet[packetIndexOffset + 4] = r.packet[4]; ++ ymm.packet[packetIndexOffset + 5] = r.packet[5]; ++ ymm.packet[packetIndexOffset + 6] = r.packet[6]; ++ ymm.packet[packetIndexOffset + 7] = r.packet[7]; ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ constexpr int64_t U3 = PacketSize * 3; ++ constexpr int64_t U2 = PacketSize * 2; ++ constexpr int64_t U1 = PacketSize * 1; ++ /** ++ * Unrolls needed for each case: ++ * - AVX512 fp32 48 32 16 8 4 2 1 ++ * - AVX512 fp64 24 16 8 4 2 1 ++ * ++ * For fp32 L and U1 are 1:2 so for U3/U2 cases the loads/stores need to be split up. ++ */ ++ EIGEN_IF_CONSTEXPR(unrollN == U3) { ++ // load LxU3 B col major, transpose LxU3 row major ++ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3); ++ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ ++ EIGEN_IF_CONSTEXPR(maxUBlock < U3) { ++ transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, ++ ymm, remM_); ++ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, ++ ymm, remM_); ++ } ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == U2) { ++ // load LxU2 B col major, transpose LxU2 row major ++ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2); ++ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ ++ EIGEN_IF_CONSTEXPR(maxUBlock < U2) { ++ transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, ++ &B_temp[maxUBlock], LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, ++ &B_temp[maxUBlock], LDB_, ymm, remM_); ++ } ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == U1) { ++ // load LxU1 B col major, transpose LxU1 row major ++ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); } ++ transB::template storeBBlock(B_arr, 
LDB, B_temp, LDB_, ymm, remM_);
++    }
++    else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
++      // load Lx8 B col major, transpose Lx8 row major
++      transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
++      transB::template transposeLxL<0>(ymm);
++      transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
++    }
++    else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) {
++      // load Lx4 B col major, transpose Lx4 row major
++      transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
++      transB::template transposeLxL<0>(ymm);
++      transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
++    }
++    else EIGEN_IF_CONSTEXPR(unrollN == 2) {
++      // load Lx2 B col major, transpose Lx2 row major
++      transB::template loadBBlock<2, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
++      transB::template transposeLxL<0>(ymm);
++      transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
++    }
++    else EIGEN_IF_CONSTEXPR(unrollN == 1) {
++      // load Lx1 B col major, transpose Lx1 row major
++      transB::template loadBBlock<1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
++      transB::template transposeLxL<0>(ymm);
++      transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
++    }
++  }
++};
++
++/**
++ * Unrolls for triSolveKernel
++ *
++ * Idea:
++ *  1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS).
++ *  2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix)
++ *     stored in AInPacket (using triSolveMicroKernel).
++ *  3) Store final results (in avx registers) back into memory (using storeRHS).
++ *
++ * RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most
++ * EIGEN_AVX_MAX_NUM_ROW registers.
++ */
++template <typename Scalar>
++class trsm {
++ public:
++  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
++  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
++
++  /***********************************
++   * Auxiliary Functions for:
++   *  - loadRHS
++   *  - storeRHS
++   *  - divRHSByDiag
++   *  - updateRHS
++   *  - triSolveMicroKernel
++   ************************************/
++  /**
++   * aux_loadRHS
++   *
++   * 2-D unroll
++   *   for(startM = 0; startM < endM; startM++)
++   *     for(startK = 0; startK < endK; startK++)
++   **/
++  template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
++  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
++      Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
++    constexpr int64_t counterReverse = endM * endK - counter;
++    constexpr int64_t startM = counterReverse / (endK);
++    constexpr int64_t startK = counterReverse % endK;
++
++    constexpr int64_t packetIndex = startM * endK + startK;
++    constexpr int64_t startM_ = isFWDSolve ?
startM : -startM; ++ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; ++ EIGEN_IF_CONSTEXPR(krem) { ++ RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex], remMask(rem)); ++ } ++ else { ++ RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex]); ++ } ++ aux_loadRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(rem); ++ } ++ ++ /** ++ * aux_storeRHS ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ constexpr int64_t counterReverse = endM * endK - counter; ++ constexpr int64_t startM = counterReverse / (endK); ++ constexpr int64_t startK = counterReverse % endK; ++ ++ constexpr int64_t packetIndex = startM * endK + startK; ++ constexpr int64_t startM_ = isFWDSolve ? startM : -startM; ++ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; ++ EIGEN_IF_CONSTEXPR(krem) { ++ pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask(rem)); ++ } ++ else { ++ pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]); ++ } ++ aux_storeRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(rem); ++ } ++ ++ /** ++ * aux_divRHSByDiag ++ * ++ * currM may be -1, (currM >=0) in enable_if checks for this ++ * ++ * 1-D unroll ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag( ++ PacketBlock &RHSInPacket, PacketBlock &AInPacket) { ++ constexpr int64_t counterReverse = endK - counter; ++ constexpr int64_t startK = counterReverse; ++ ++ constexpr int64_t packetIndex = currM * endK + startK; ++ RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]); ++ aux_divRHSByDiag(RHSInPacket, AInPacket); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && currM >= 0)> aux_divRHSByDiag( ++ PacketBlock &RHSInPacket, PacketBlock &AInPacket) { ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(AInPacket); ++ } ++ ++ /** ++ * aux_updateRHS ++ * ++ * 2-D unroll ++ * for(startM = initM; startM < endM; startM++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS( ++ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ constexpr int64_t counterReverse = (endM - initM) * endK - counter; ++ constexpr int64_t startM = initM + counterReverse / (endK); ++ constexpr int64_t startK = counterReverse % endK; ++ ++ // For each row of A, first update all corresponding RHS ++ constexpr int64_t packetIndex = startM * endK + startK; ++ EIGEN_IF_CONSTEXPR(currentM > 0) { ++ RHSInPacket.packet[packetIndex] = ++ pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK], ++ 
RHSInPacket.packet[packetIndex]);
++    }
++
++    EIGEN_IF_CONSTEXPR(startK == endK - 1) {
++      // Once all RHS for the previous row of A are updated, we broadcast the next element
++      // in the column A_{i, currentM}.
++      EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) {
++        // If the diagonal is not unit, we broadcast the reciprocal of the diagonal to AInPacket.packet[currentM].
++        // This will be used in divRHSByDiag.
++        EIGEN_IF_CONSTEXPR(isFWDSolve)
++        AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
++        else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);
++      }
++      else {
++        // Broadcast the next off-diagonal element of A.
++        EIGEN_IF_CONSTEXPR(isFWDSolve)
++        AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
++        else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);
++      }
++    }
++
++    aux_updateRHS<isFWDSolve, isUnitDiag, isARowMajor, initM, endM, endK, counter - 1, currentM>(
++        A_arr, LDA, RHSInPacket, AInPacket);
++  }
++
++  template <bool isFWDSolve, bool isUnitDiag, bool isARowMajor, int64_t initM, int64_t endM, int64_t endK,
++            int64_t counter, int64_t currentM>
++  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS(
++      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
++      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
++    EIGEN_UNUSED_VARIABLE(A_arr);
++    EIGEN_UNUSED_VARIABLE(LDA);
++    EIGEN_UNUSED_VARIABLE(RHSInPacket);
++    EIGEN_UNUSED_VARIABLE(AInPacket);
++  }
++
++  /**
++   * aux_triSolveMicroKernel
++   *
++   * 1-D unroll
++   *   for(startM = 0; startM < endM; startM++)
++   **/
++  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK, int64_t counter>
++  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
++      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
++      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
++    constexpr int64_t counterReverse = endM - counter;
++    constexpr int64_t startM = counterReverse;
++
++    constexpr int64_t currentM = startM;
++    // Divides the right-hand sides in row startM by the diagonal value of A
++    // broadcast to AInPacket.packet[startM-1] in the previous iteration.
++    //
++    // Without "if constexpr" the compiler instantiates the case <-1, numK>;
++    // this is handled with enable_if to prevent out-of-bound warnings
++    // from the compiler.
++    EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0)
++    trsm::template divRHSByDiag<startM - 1, numK>(RHSInPacket, AInPacket);
++
++    // After division, the rhs corresponding to subsequent rows of A can be partially updated.
++    // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
++    // to be used in the next iteration.
++    trsm::template updateRHS<isFWDSolve, isUnitDiag, isARowMajor, startM, endM, numK, currentM>(
++        A_arr, LDA, RHSInPacket, AInPacket);
++
++    // Handle division for the RHS corresponding to the final row of A.
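++    // (The division for row startM is normally issued at the start of iteration startM+1,
++    // so the last row, which has no later iteration, must be divided here.)
++    //
++    // Illustrative sketch (not part of the kernel): per right-hand side b, the solve
++    // this unroll computes is, mathematically, plain forward substitution:
++    //   for (int64_t m = 0; m < endM; m++) {
++    //     for (int64_t k = 0; k < m; k++) b[m] -= A_arr[idA<isARowMajor>(m, k, LDA)] * b[k];
++    //     if (!isUnitDiag) b[m] /= A_arr[idA<isARowMajor>(m, m, LDA)];
++    //   }
++    // (for isFWDSolve == true; the backward solve walks the rows in the opposite order).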
++    EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1)
++    trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);
++
++    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, numK, counter - 1>(A_arr, LDA, RHSInPacket,
++                                                                                         AInPacket);
++  }
++
++  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK, int64_t counter>
++  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel(
++      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
++      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
++    EIGEN_UNUSED_VARIABLE(A_arr);
++    EIGEN_UNUSED_VARIABLE(LDA);
++    EIGEN_UNUSED_VARIABLE(RHSInPacket);
++    EIGEN_UNUSED_VARIABLE(AInPacket);
++  }
++
++  /********************************************************
++   * Wrappers for aux_XXXX to hide counter parameter
++   ********************************************************/
++
++  /**
++   * Load the endM x endK block of B into RHSInPacket.
++   * Masked loads are used for cases where endK is not a multiple of PacketSize.
++   */
++  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
++  static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
++                                          PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
++    aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
++  }
++
++  /**
++   * Store the endM x endK block from RHSInPacket back to B.
++   * Masked stores are used for cases where endK is not a multiple of PacketSize.
++   */
++  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
++  static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
++                                           PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
++    aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
++  }
++
++  /**
++   * Only used if the triangular matrix has non-unit diagonal values.
++   */
++  template <int64_t currM, int64_t endK>
++  static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
++                                               PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
++    aux_divRHSByDiag<endK, endK, currM>(RHSInPacket, AInPacket);
++  }
++
++  /**
++   * Update the right-hand sides (stored in avx registers), traversing along the column A_{i,currentM},
++   * where currentM <= i <= endM, and broadcasting each value to AInPacket.
++   **/
++  template <bool isFWDSolve, bool isUnitDiag, bool isARowMajor, int64_t initM, int64_t endM, int64_t endK,
++            int64_t currentM>
++  static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
++                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
++                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
++    aux_updateRHS<isFWDSolve, isUnitDiag, isARowMajor, initM, endM, endK, (endM - initM) * endK, currentM>(
++        A_arr, LDA, RHSInPacket, AInPacket);
++  }
++
++  /**
++   * endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW
++   * numK: number of avx registers to use for each row of B (ex fp32: 48 rhs => 3 avx reg used). 1 <= numK <= 3.
++   * isFWDSolve: true => forward substitution, false => backward substitution
++   * isUnitDiag: true => triangular matrix has unit diagonal.
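++   *
++   * Example (illustrative): an 8x8 triangular solve against 48 fp32 right-hand sides
++   * would use endM = 8 and numK = 3, i.e. three zmm registers (PacketSize = 16) per
++   * row of B, for 24 accumulators in total.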
++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA, ++ PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ static_assert(numK >= 1 && numK <= 3, "numK out of range"); ++ aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, AInPacket); ++ } ++}; ++ ++/** ++ * Unrolls for gemm kernel ++ * ++ * isAdd: true => C += A*B, false => C -= A*B ++ */ ++template ++class gemm { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - setzero ++ * - updateC ++ * - storeC ++ * - startLoadB ++ * - triSolveMicroKernel ++ ************************************/ ++ ++ /** ++ * aux_setzero ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero( ++ PacketBlock &zmm) { ++ constexpr int64_t counterReverse = endM * endN - counter; ++ constexpr int64_t startM = counterReverse / (endN); ++ constexpr int64_t startN = counterReverse % endN; ++ ++ zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]); ++ aux_setzero(zmm); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero( ++ PacketBlock &zmm) { ++ EIGEN_UNUSED_VARIABLE(zmm); ++ } ++ ++ /** ++ * aux_updateC ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endM * endN - counter; ++ constexpr int64_t startM = counterReverse / (endN); ++ constexpr int64_t startN = counterReverse % endN; ++ ++ EIGEN_IF_CONSTEXPR(rem) ++ zmm.packet[startN * endM + startM] = ++ padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize], remMask(rem_)), ++ zmm.packet[startN * endM + startM], remMask(rem_)); ++ else zmm.packet[startN * endM + startM] = ++ padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]); ++ aux_updateC(C_arr, LDC, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(C_arr); ++ EIGEN_UNUSED_VARIABLE(LDC); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_storeC ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endM * endN - counter; ++ constexpr int64_t startM = counterReverse / (endN); ++ constexpr int64_t startN = counterReverse % endN; ++ ++ EIGEN_IF_CONSTEXPR(rem) ++ pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM], ++ remMask(rem_)); ++ else pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]); ++ aux_storeC(C_arr, LDC, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, 
PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(C_arr); ++ EIGEN_UNUSED_VARIABLE(LDC); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_startLoadB ++ * ++ * 1-D unroll ++ * for(startL = 0; startL < endL; startL++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endL - counter; ++ constexpr int64_t startL = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(rem) ++ zmm.packet[unrollM * unrollN + startL] = ++ ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask(rem_)); ++ else zmm.packet[unrollM * unrollN + startL] = ++ ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]); ++ ++ aux_startLoadB(B_t, LDB, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_t); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_startBCastA ++ * ++ * 1-D unroll ++ * for(startB = 0; startB < endB; startB++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA( ++ Scalar *A_t, int64_t LDA, PacketBlock &zmm) { ++ constexpr int64_t counterReverse = endB - counter; ++ constexpr int64_t startB = counterReverse; ++ ++ zmm.packet[unrollM * unrollN + numLoad + startB] = pload1(&A_t[idA(startB, 0, LDA)]); ++ ++ aux_startBCastA(A_t, LDA, zmm); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA( ++ Scalar *A_t, int64_t LDA, PacketBlock &zmm) { ++ EIGEN_UNUSED_VARIABLE(A_t); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ } ++ ++ /** ++ * aux_loadB ++ * currK: current K ++ * ++ * 1-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ if ((numLoad / endM + currK < unrollK)) { ++ constexpr int64_t counterReverse = endM - counter; ++ constexpr int64_t startM = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(rem) { ++ zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = ++ ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask(rem_)); ++ } ++ else { ++ zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = ++ ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]); ++ } ++ ++ aux_loadB(B_t, LDB, zmm, rem_); ++ } ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_t); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_microKernel ++ * ++ * 3-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel( ++ Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endM * endN * endK - counter; ++ constexpr int startK = 
counterReverse / (endM * endN); ++ constexpr int startN = (counterReverse / (endM)) % endN; ++ constexpr int startM = counterReverse % endM; ++ ++ EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) { ++ gemm::template startLoadB(B_t, LDB, zmm, rem_); ++ gemm::template startBCastA(A_t, LDA, zmm); ++ } ++ ++ { ++ // Interleave FMA and Bcast ++ EIGEN_IF_CONSTEXPR(isAdd) { ++ zmm.packet[startN * endM + startM] = ++ pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast], ++ zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]); ++ } ++ else { ++ zmm.packet[startN * endM + startM] = ++ pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast], ++ zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]); ++ } ++ // Bcast ++ EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) { ++ zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1(&A_t[idA( ++ (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]); ++ } ++ } ++ ++ // We have updated all accumlators, time to load next set of B's ++ EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) { ++ gemm::template loadB(B_t, LDB, zmm, rem_); ++ } ++ aux_microKernel(B_t, A_t, LDB, LDA, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel( ++ Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_t); ++ EIGEN_UNUSED_VARIABLE(A_t); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /******************************************************** ++ * Wrappers for aux_XXXX to hide counter parameter ++ ********************************************************/ ++ ++ template ++ static EIGEN_ALWAYS_INLINE void setzero(PacketBlock &zmm) { ++ aux_setzero(zmm); ++ } ++ ++ /** ++ * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_updateC(C_arr, LDC, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_storeC(C_arr, LDC, zmm, rem_); ++ } ++ ++ /** ++ * Use numLoad registers for loading B at start of microKernel ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_startLoadB(B_t, LDB, zmm, rem_); ++ } ++ ++ /** ++ * Use numBCast registers for broadcasting A at start of microKernel ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA, ++ PacketBlock &zmm) { ++ aux_startBCastA(A_t, LDA, zmm); ++ } ++ ++ /** ++ * Loads next set of B into vector registers between each K unroll. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_loadB(B_t, LDB, zmm, rem_); ++ } ++ ++ /** ++ * Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3} to compute C -= A*B. ++ * A matrix can be row/col-major. 
B matrix is assumed row-major. ++ * ++ * isARowMajor: is A row major ++ * endM: Number registers per row ++ * endN: Number of rows ++ * endK: Loop unroll for K. ++ * numLoad: Number of registers for loading B. ++ * numBCast: Number of registers for broadcasting A. ++ * ++ * Ex: microkernel: 8x48 unroll (24 accumulators), k unrolled 4 times, ++ * 6 register for loading B, 2 for broadcasting A. ++ * ++ * Note: Ideally the microkernel should not have any register spilling. ++ * The avx instruction counts should be: ++ * - endK*endN vbroadcasts{s,d} ++ * - endK*endM vmovup{s,d} ++ * - endK*endN*endM FMAs ++ * ++ * From testing, there are no register spills with clang. There are register spills with GNU, which ++ * causes a performance hit. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_microKernel(B_t, A_t, LDB, LDA, zmm, ++ rem_); ++ } ++}; ++} // namespace unrolls ++ ++#endif // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H From f58a3140f59f16cd722c372b5685c01052e05e1f Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Thu, 15 Feb 2024 17:54:12 +0100 Subject: [PATCH 08/12] Delete easybuild/easyconfigs/j/jax/jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch --- ...x-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch | 1217 ----------------- 1 file changed, 1217 deletions(-) delete mode 100644 easybuild/easyconfigs/j/jax/jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch b/easybuild/easyconfigs/j/jax/jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch deleted file mode 100644 index 0a4e63a0457..00000000000 --- a/easybuild/easyconfigs/j/jax/jax-0.4.24_ml_dtypes_0.3.2_EigenAvx512.patch +++ /dev/null @@ -1,1217 +0,0 @@ -# ml_dtype 0.3.2 provides Eigen commit 7bf2968. src/Core/arch/AVX512/TrsmUnrolls.inc is missing. -diff -ru --new-file ./ml_dtypes-0.3.2_ori/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc ./ml_dtypes-0.3.2/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc ---- ./ml_dtypes-0.3.2_ori/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc 1970-01-01 01:00:00.000000000 +0100 -+++ ./ml_dtypes-0.3.2/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc 2024-02-14 10:32:25.492978066 +0100 -@@ -0,0 +1,1212 @@ -+// This file is part of Eigen, a lightweight C++ template library -+// for linear algebra. -+// -+// Copyright (C) 2022 Intel Corporation -+// -+// This Source Code Form is subject to the terms of the Mozilla -+// Public License v. 2.0. If a copy of the MPL was not distributed -+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. -+ -+#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H -+#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H -+ -+template -+EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) { -+ EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j; -+ else return i + j * LDA; -+} -+ -+/** -+ * This namespace contains various classes used to generate compile-time unrolls which are -+ * used throughout the trsm/gemm kernels. The unrolls are characterized as for-loops (1-D), nested -+ * for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion -+ * -+ * Example, the 2-D for-loop is unrolled recursively by first flattening to a 1-D loop. 
-+ * -+ * for(startI = 0; startI < endI; startI++) for(startC = 0; startC < endI*endJ; startC++) -+ * for(startJ = 0; startJ < endJ; startJ++) ----> startI = (startC)/(endJ) -+ * func(startI,startJ) startJ = (startC)%(endJ) -+ * func(...) -+ * -+ * The 1-D loop can be unrolled recursively by using enable_if and defining an auxillary function -+ * with a template parameter used as a counter. -+ * -+ * template -+ * std::enable_if_t<(counter <= 0)> <---- tail case. -+ * aux_func {} -+ * -+ * template -+ * std::enable_if_t<(counter > 0)> <---- actual for-loop -+ * aux_func { -+ * startC = endI*endJ - counter -+ * startI = (startC)/(endJ) -+ * startJ = (startC)%(endJ) -+ * func(startI, startJ) -+ * aux_func() -+ * } -+ * -+ * Note: Additional wrapper functions are provided for aux_func which hides the counter template -+ * parameter since counter usually depends on endI, endJ, etc... -+ * -+ * Conventions: -+ * 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++)) -+ * -+ * 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for -+ * handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW) -+ */ -+namespace unrolls { -+ -+template -+EIGEN_ALWAYS_INLINE auto remMask(int64_t m) { -+ EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); } -+ else EIGEN_IF_CONSTEXPR(N == 8) { -+ return 0xFF >> (8 - m); -+ } -+ else EIGEN_IF_CONSTEXPR(N == 4) { -+ return 0x0F >> (4 - m); -+ } -+ return 0; -+} -+ -+template -+EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel); -+ -+template <> -+EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel) { -+ __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]); -+ __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]); -+ __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]); -+ __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]); -+ __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]); -+ __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]); -+ __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]); -+ __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]); -+ -+ kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); -+ kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); -+ kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); -+ kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); -+ kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6))); -+ kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6))); -+ kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); -+ kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); -+ -+ T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E)); -+ T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0); -+ T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E)); -+ T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]); -+ T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E)); -+ T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1); -+ T5 
= _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E)); -+ T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]); -+ T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E)); -+ T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2); -+ T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E)); -+ T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]); -+ T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E)); -+ T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3); -+ T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E)); -+ T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]); -+ -+ kernel.packet[0] = T0; -+ kernel.packet[1] = T1; -+ kernel.packet[2] = T2; -+ kernel.packet[3] = T3; -+ kernel.packet[4] = T4; -+ kernel.packet[5] = T5; -+ kernel.packet[6] = T6; -+ kernel.packet[7] = T7; -+} -+ -+template <> -+EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel) { -+ ptranspose(kernel); -+} -+ -+/*** -+ * Unrolls for tranposed C stores -+ */ -+template -+class trans { -+ public: -+ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; -+ using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; -+ static constexpr int64_t PacketSize = packet_traits::size; -+ -+ /*********************************** -+ * Auxillary Functions for: -+ * - storeC -+ *********************************** -+ */ -+ -+ /** -+ * aux_storeC -+ * -+ * 1-D unroll -+ * for(startN = 0; startN < endN; startN++) -+ * -+ * (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC -+ * -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC( -+ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { -+ constexpr int64_t counterReverse = endN - counter; -+ constexpr int64_t startN = counterReverse; -+ -+ EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) { -+ EIGEN_IF_CONSTEXPR(remM) { -+ pstoreu( -+ C_arr + LDC * startN, -+ padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), -+ preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]), -+ remMask(remM_)), -+ remMask(remM_)); -+ } -+ else { -+ pstoreu(C_arr + LDC * startN, -+ padd(ploadu((const Scalar *)C_arr + LDC * startN), -+ preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]))); -+ } -+ } -+ else { // This block is only needed for fp32 case -+ // Reinterpret as __m512 for _mm512_shuffle_f32x4 -+ vecFullFloat zmm2vecFullFloat = preinterpret( -+ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]); -+ // Swap lower and upper half of avx register. 
-+ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] = -+ preinterpret(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110)); -+ -+ EIGEN_IF_CONSTEXPR(remM) { -+ pstoreu( -+ C_arr + LDC * startN, -+ padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), -+ preinterpret( -+ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])), -+ remMask(remM_)); -+ } -+ else { -+ pstoreu( -+ C_arr + LDC * startN, -+ padd(ploadu((const Scalar *)C_arr + LDC * startN), -+ preinterpret( -+ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]))); -+ } -+ } -+ aux_storeC(C_arr, LDC, zmm, remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && endN <= PacketSize)> aux_storeC( -+ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { -+ EIGEN_UNUSED_VARIABLE(C_arr); -+ EIGEN_UNUSED_VARIABLE(LDC); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ EIGEN_UNUSED_VARIABLE(remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC, -+ PacketBlock &zmm, -+ int64_t remM_ = 0) { -+ aux_storeC(C_arr, LDC, zmm, remM_); -+ } -+ -+ /** -+ * Transposes LxunrollN row major block of matrices stored EIGEN_AVX_MAX_NUM_ACC zmm registers to -+ * "unrollN"xL ymm registers to be stored col-major into C. -+ * -+ * For 8x48, the 8x48 block (row-major) is stored in zmm as follows: -+ * -+ * row0: zmm0 zmm1 zmm2 -+ * row1: zmm3 zmm4 zmm5 -+ * . -+ * . -+ * row7: zmm21 zmm22 zmm23 -+ * -+ * For 8x32, the 8x32 block (row-major) is stored in zmm as follows: -+ * -+ * row0: zmm0 zmm1 -+ * row1: zmm2 zmm3 -+ * . -+ * . -+ * row7: zmm14 zmm15 -+ * -+ * -+ * In general we will have {1,2,3} groups of avx registers each of size -+ * EIGEN_AVX_MAX_NUM_ROW. packetIndexOffset is used to select which "block" of -+ * avx registers are being transposed. -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void transpose(PacketBlock &zmm) { -+ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted -+ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. -+ constexpr int64_t zmmStride = unrollN / PacketSize; -+ PacketBlock r; -+ r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0]; -+ r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1]; -+ r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2]; -+ r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3]; -+ r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4]; -+ r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5]; -+ r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6]; -+ r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7]; -+ trans8x8blocks(r); -+ zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0]; -+ zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1]; -+ zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2]; -+ zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3]; -+ zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4]; -+ zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5]; -+ zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6]; -+ zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7]; -+ } -+}; -+ -+/** -+ * Unrolls for copyBToRowMajor -+ * -+ * Idea: -+ * 1) Load a block of right-hand sides to registers (using loadB). 
-+ * 2) Convert the block from column-major to row-major (transposeLxL) -+ * 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false). -+ * -+ * We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are -+ * used as temps for transposing. -+ * -+ * Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks -+ * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm). -+ */ -+template -+class transB { -+ public: -+ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; -+ using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; -+ static constexpr int64_t PacketSize = packet_traits::size; -+ -+ /*********************************** -+ * Auxillary Functions for: -+ * - loadB -+ * - storeB -+ * - loadBBlock -+ * - storeBBlock -+ *********************************** -+ */ -+ -+ /** -+ * aux_loadB -+ * -+ * 1-D unroll -+ * for(startN = 0; startN < endN; startN++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( -+ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, -+ int64_t remM_ = 0) { -+ constexpr int64_t counterReverse = endN - counter; -+ constexpr int64_t startN = counterReverse; -+ -+ EIGEN_IF_CONSTEXPR(remM) { -+ ymm.packet[packetIndexOffset + startN] = -+ ploadu((const Scalar *)&B_arr[startN * LDB], remMask(remM_)); -+ } -+ else ymm.packet[packetIndexOffset + startN] = ploadu((const Scalar *)&B_arr[startN * LDB]); -+ -+ aux_loadB(B_arr, LDB, ymm, remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( -+ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, -+ int64_t remM_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_arr); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(ymm); -+ EIGEN_UNUSED_VARIABLE(remM_); -+ } -+ -+ /** -+ * aux_storeB -+ * -+ * 1-D unroll -+ * for(startN = 0; startN < endN; startN++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB( -+ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { -+ constexpr int64_t counterReverse = endN - counter; -+ constexpr int64_t startN = counterReverse; -+ -+ EIGEN_IF_CONSTEXPR(remK || remM) { -+ pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN], -+ remMask(rem_)); -+ } -+ else { -+ pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]); -+ } -+ -+ aux_storeB(B_arr, LDB, ymm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB( -+ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_arr); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(ymm); -+ EIGEN_UNUSED_VARIABLE(rem_); -+ } -+ -+ /** -+ * aux_loadBBlock -+ * -+ * 1-D unroll -+ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock( -+ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, -+ PacketBlock &ymm, int64_t remM_ = 0) { -+ constexpr int64_t counterReverse = endN - counter; -+ constexpr int64_t startN = counterReverse; -+ transB::template loadB(&B_temp[startN], LDB_, ymm); -+ aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock( -+ Scalar *B_arr, int64_t LDB, Scalar 
*B_temp, int64_t LDB_, -+ PacketBlock &ymm, int64_t remM_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_arr); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(B_temp); -+ EIGEN_UNUSED_VARIABLE(LDB_); -+ EIGEN_UNUSED_VARIABLE(ymm); -+ EIGEN_UNUSED_VARIABLE(remM_); -+ } -+ -+ /** -+ * aux_storeBBlock -+ * -+ * 1-D unroll -+ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock( -+ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, -+ PacketBlock &ymm, int64_t remM_ = 0) { -+ constexpr int64_t counterReverse = endN - counter; -+ constexpr int64_t startN = counterReverse; -+ -+ EIGEN_IF_CONSTEXPR(toTemp) { -+ transB::template storeB(&B_temp[startN], LDB_, ymm, remK_); -+ } -+ else { -+ transB::template storeB(&B_arr[0 + startN * LDB], LDB, -+ ymm, remM_); -+ } -+ aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock( -+ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, -+ PacketBlock &ymm, int64_t remM_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_arr); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(B_temp); -+ EIGEN_UNUSED_VARIABLE(LDB_); -+ EIGEN_UNUSED_VARIABLE(ymm); -+ EIGEN_UNUSED_VARIABLE(remM_); -+ } -+ -+ /******************************************************** -+ * Wrappers for aux_XXXX to hide counter parameter -+ ********************************************************/ -+ -+ template -+ static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB, -+ PacketBlock &ymm, -+ int64_t remM_ = 0) { -+ aux_loadB(B_arr, LDB, ymm, remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB, -+ PacketBlock &ymm, -+ int64_t rem_ = 0) { -+ aux_storeB(B_arr, LDB, ymm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, -+ PacketBlock &ymm, -+ int64_t remM_ = 0) { -+ EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB(&B_arr[0], LDB, ymm, remM_); } -+ else { -+ aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, -+ PacketBlock &ymm, -+ int64_t remM_ = 0) { -+ aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock &ymm) { -+ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted -+ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. 
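-+ // Same gather/transpose/scatter pattern as transpose() above: copy the eight
-+ // packets of this subblock into a scratch PacketBlock, transpose in-register
-+ // with ptranspose, and write the results back in place.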
-+ PacketBlock r; -+ r.packet[0] = ymm.packet[packetIndexOffset + 0]; -+ r.packet[1] = ymm.packet[packetIndexOffset + 1]; -+ r.packet[2] = ymm.packet[packetIndexOffset + 2]; -+ r.packet[3] = ymm.packet[packetIndexOffset + 3]; -+ r.packet[4] = ymm.packet[packetIndexOffset + 4]; -+ r.packet[5] = ymm.packet[packetIndexOffset + 5]; -+ r.packet[6] = ymm.packet[packetIndexOffset + 6]; -+ r.packet[7] = ymm.packet[packetIndexOffset + 7]; -+ ptranspose(r); -+ ymm.packet[packetIndexOffset + 0] = r.packet[0]; -+ ymm.packet[packetIndexOffset + 1] = r.packet[1]; -+ ymm.packet[packetIndexOffset + 2] = r.packet[2]; -+ ymm.packet[packetIndexOffset + 3] = r.packet[3]; -+ ymm.packet[packetIndexOffset + 4] = r.packet[4]; -+ ymm.packet[packetIndexOffset + 5] = r.packet[5]; -+ ymm.packet[packetIndexOffset + 6] = r.packet[6]; -+ ymm.packet[packetIndexOffset + 7] = r.packet[7]; -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, -+ PacketBlock &ymm, -+ int64_t remM_ = 0) { -+ constexpr int64_t U3 = PacketSize * 3; -+ constexpr int64_t U2 = PacketSize * 2; -+ constexpr int64_t U1 = PacketSize * 1; -+ /** -+ * Unrolls needed for each case: -+ * - AVX512 fp32 48 32 16 8 4 2 1 -+ * - AVX512 fp64 24 16 8 4 2 1 -+ * -+ * For fp32 L and U1 are 1:2 so for U3/U2 cases the loads/stores need to be split up. -+ */ -+ EIGEN_IF_CONSTEXPR(unrollN == U3) { -+ // load LxU3 B col major, transpose LxU3 row major -+ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3); -+ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ -+ EIGEN_IF_CONSTEXPR(maxUBlock < U3) { -+ transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, -+ ymm, remM_); -+ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, -+ ymm, remM_); -+ } -+ } -+ else EIGEN_IF_CONSTEXPR(unrollN == U2) { -+ // load LxU2 B col major, transpose LxU2 row major -+ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2); -+ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ -+ EIGEN_IF_CONSTEXPR(maxUBlock < U2) { -+ transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, -+ &B_temp[maxUBlock], LDB_, ymm, remM_); -+ transB::template transposeLxL<0>(ymm); -+ transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, -+ &B_temp[maxUBlock], LDB_, ymm, remM_); -+ } -+ } -+ else EIGEN_IF_CONSTEXPR(unrollN == U1) { -+ // load LxU1 B col major, transpose LxU1 row major -+ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0>(ymm); -+ EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); } -+ transB::template storeBBlock(B_arr, 
LDB, B_temp, LDB_, ymm, remM_); -+ } -+ else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) { -+ // load Lx4 B col major, transpose Lx4 row major -+ transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0>(ymm); -+ transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) { -+ // load Lx4 B col major, transpose Lx4 row major -+ transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0>(ymm); -+ transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ else EIGEN_IF_CONSTEXPR(unrollN == 2) { -+ // load Lx2 B col major, transpose Lx2 row major -+ transB::template loadBBlock<2, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0>(ymm); -+ transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ else EIGEN_IF_CONSTEXPR(unrollN == 1) { -+ // load Lx1 B col major, transpose Lx1 row major -+ transB::template loadBBlock<1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0>(ymm); -+ transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ } -+}; -+ -+/** -+ * Unrolls for triSolveKernel -+ * -+ * Idea: -+ * 1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS). -+ * 2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix) -+ * stored in AInPacket (using triSolveMicroKernel). -+ * 3) Store final results (in avx registers) back into memory (using storeRHS). -+ * -+ * RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most -+ * EIGEN_AVX_MAX_NUM_ROW registers. -+ */ -+template -+class trsm { -+ public: -+ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; -+ static constexpr int64_t PacketSize = packet_traits::size; -+ -+ /*********************************** -+ * Auxillary Functions for: -+ * - loadRHS -+ * - storeRHS -+ * - divRHSByDiag -+ * - updateRHS -+ * - triSolveMicroKernel -+ ************************************/ -+ /** -+ * aux_loadRHS -+ * -+ * 2-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ * for(startK = 0; startK < endK; startK++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS( -+ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { -+ constexpr int64_t counterReverse = endM * endK - counter; -+ constexpr int64_t startM = counterReverse / (endK); -+ constexpr int64_t startK = counterReverse % endK; -+ -+ constexpr int64_t packetIndex = startM * endK + startK; -+ constexpr int64_t startM_ = isFWDSolve ? 
startM : -startM; -+ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; -+ EIGEN_IF_CONSTEXPR(krem) { -+ RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex], remMask(rem)); -+ } -+ else { -+ RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex]); -+ } -+ aux_loadRHS(B_arr, LDB, RHSInPacket, rem); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS( -+ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { -+ EIGEN_UNUSED_VARIABLE(B_arr); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(RHSInPacket); -+ EIGEN_UNUSED_VARIABLE(rem); -+ } -+ -+ /** -+ * aux_storeRHS -+ * -+ * 2-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ * for(startK = 0; startK < endK; startK++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS( -+ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { -+ constexpr int64_t counterReverse = endM * endK - counter; -+ constexpr int64_t startM = counterReverse / (endK); -+ constexpr int64_t startK = counterReverse % endK; -+ -+ constexpr int64_t packetIndex = startM * endK + startK; -+ constexpr int64_t startM_ = isFWDSolve ? startM : -startM; -+ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; -+ EIGEN_IF_CONSTEXPR(krem) { -+ pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask(rem)); -+ } -+ else { -+ pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]); -+ } -+ aux_storeRHS(B_arr, LDB, RHSInPacket, rem); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS( -+ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { -+ EIGEN_UNUSED_VARIABLE(B_arr); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(RHSInPacket); -+ EIGEN_UNUSED_VARIABLE(rem); -+ } -+ -+ /** -+ * aux_divRHSByDiag -+ * -+ * currM may be -1, (currM >=0) in enable_if checks for this -+ * -+ * 1-D unroll -+ * for(startK = 0; startK < endK; startK++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag( -+ PacketBlock &RHSInPacket, PacketBlock &AInPacket) { -+ constexpr int64_t counterReverse = endK - counter; -+ constexpr int64_t startK = counterReverse; -+ -+ constexpr int64_t packetIndex = currM * endK + startK; -+ RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]); -+ aux_divRHSByDiag(RHSInPacket, AInPacket); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && currM >= 0)> aux_divRHSByDiag( -+ PacketBlock &RHSInPacket, PacketBlock &AInPacket) { -+ EIGEN_UNUSED_VARIABLE(RHSInPacket); -+ EIGEN_UNUSED_VARIABLE(AInPacket); -+ } -+ -+ /** -+ * aux_updateRHS -+ * -+ * 2-D unroll -+ * for(startM = initM; startM < endM; startM++) -+ * for(startK = 0; startK < endK; startK++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS( -+ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ constexpr int64_t counterReverse = (endM - initM) * endK - counter; -+ constexpr int64_t startM = initM + counterReverse / (endK); -+ constexpr int64_t startK = counterReverse % endK; -+ -+ // For each row of A, first update all corresponding RHS -+ constexpr int64_t packetIndex = startM * endK + startK; -+ EIGEN_IF_CONSTEXPR(currentM > 0) { -+ RHSInPacket.packet[packetIndex] = -+ pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK], -+ 
RHSInPacket.packet[packetIndex]); -+ } -+ -+ EIGEN_IF_CONSTEXPR(startK == endK - 1) { -+ // Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}. -+ EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) { -+ // If diagonal is not unit, we broadcast reciprocals of diagonals AinPacket.packet[currentM]. -+ // This will be used in divRHSByDiag -+ EIGEN_IF_CONSTEXPR(isFWDSolve) -+ AInPacket.packet[currentM] = pset1(Scalar(1) / A_arr[idA(currentM, currentM, LDA)]); -+ else AInPacket.packet[currentM] = pset1(Scalar(1) / A_arr[idA(-currentM, -currentM, LDA)]); -+ } -+ else { -+ // Broadcast next off diagonal element of A -+ EIGEN_IF_CONSTEXPR(isFWDSolve) -+ AInPacket.packet[startM] = pset1(A_arr[idA(startM, currentM, LDA)]); -+ else AInPacket.packet[startM] = pset1(A_arr[idA(-startM, -currentM, LDA)]); -+ } -+ } -+ -+ aux_updateRHS( -+ A_arr, LDA, RHSInPacket, AInPacket); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS( -+ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ EIGEN_UNUSED_VARIABLE(A_arr); -+ EIGEN_UNUSED_VARIABLE(LDA); -+ EIGEN_UNUSED_VARIABLE(RHSInPacket); -+ EIGEN_UNUSED_VARIABLE(AInPacket); -+ } -+ -+ /** -+ * aux_triSolverMicroKernel -+ * -+ * 1-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel( -+ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ constexpr int64_t counterReverse = endM - counter; -+ constexpr int64_t startM = counterReverse; -+ -+ constexpr int64_t currentM = startM; -+ // Divides the right-hand side in row startM, by digonal value of A -+ // broadcasted to AInPacket.packet[startM-1] in the previous iteration. -+ // -+ // Without "if constexpr" the compiler instantiates the case <-1, numK> -+ // this is handled with enable_if to prevent out-of-bound warnings -+ // from the compiler -+ EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0) -+ trsm::template divRHSByDiag(RHSInPacket, AInPacket); -+ -+ // After division, the rhs corresponding to subsequent rows of A can be partially updated -+ // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed) -+ // to be used in the next iteration. -+ trsm::template updateRHS(A_arr, LDA, RHSInPacket, -+ AInPacket); -+ -+ // Handle division for the RHS corresponding to the final row of A. 
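-+ // (updateRHS above has already broadcast this row's reciprocal, but there is
-+ // no later iteration left to consume it, so the division is issued here.)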
-+ EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1) -+ trsm::template divRHSByDiag(RHSInPacket, AInPacket); -+ -+ aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, -+ AInPacket); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel( -+ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ EIGEN_UNUSED_VARIABLE(A_arr); -+ EIGEN_UNUSED_VARIABLE(LDA); -+ EIGEN_UNUSED_VARIABLE(RHSInPacket); -+ EIGEN_UNUSED_VARIABLE(AInPacket); -+ } -+ -+ /******************************************************** -+ * Wrappers for aux_XXXX to hide counter parameter -+ ********************************************************/ -+ -+ /** -+ * Load endMxendK block of B to RHSInPacket -+ * Masked loads are used for cases where endK is not a multiple of PacketSize -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB, -+ PacketBlock &RHSInPacket, int64_t rem = 0) { -+ aux_loadRHS(B_arr, LDB, RHSInPacket, rem); -+ } -+ -+ /** -+ * Load endMxendK block of B to RHSInPacket -+ * Masked loads are used for cases where endK is not a multiple of PacketSize -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB, -+ PacketBlock &RHSInPacket, int64_t rem = 0) { -+ aux_storeRHS(B_arr, LDB, RHSInPacket, rem); -+ } -+ -+ /** -+ * Only used if Triangular matrix has non-unit diagonal values -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ aux_divRHSByDiag(RHSInPacket, AInPacket); -+ } -+ -+ /** -+ * Update right-hand sides (stored in avx registers) -+ * Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket. -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA, -+ PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ aux_updateRHS( -+ A_arr, LDA, RHSInPacket, AInPacket); -+ } -+ -+ /** -+ * endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW -+ * numK: number of avx registers to use for each row of B (ex fp32: 48 rhs => 3 avx reg used). 1 <= endK <= 3. -+ * isFWDSolve: true => forward substitution, false => backwards substitution -+ * isUnitDiag: true => triangular matrix has unit diagonal. 
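-+ *
-+ * In scalar form (forward solve, non-unit diagonal) this computes:
-+ *   for (m = 0; m < endM; m++) {
-+ *     x[m] /= A(m, m);
-+ *     for (i = m + 1; i < endM; i++) x[i] -= A(i, m) * x[m];
-+ *   }
-+ * where each x[i] stands for numK packets of right-hand sides.
-+ * Ex: fp32 with endM = 8 and numK = 3 solves an 8x48 block of B entirely in registers.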
-+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA, -+ PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ static_assert(numK >= 1 && numK <= 3, "numK out of range"); -+ aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, AInPacket); -+ } -+}; -+ -+/** -+ * Unrolls for gemm kernel -+ * -+ * isAdd: true => C += A*B, false => C -= A*B -+ */ -+template -+class gemm { -+ public: -+ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; -+ static constexpr int64_t PacketSize = packet_traits::size; -+ -+ /*********************************** -+ * Auxillary Functions for: -+ * - setzero -+ * - updateC -+ * - storeC -+ * - startLoadB -+ * - triSolveMicroKernel -+ ************************************/ -+ -+ /** -+ * aux_setzero -+ * -+ * 2-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ * for(startN = 0; startN < endN; startN++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero( -+ PacketBlock &zmm) { -+ constexpr int64_t counterReverse = endM * endN - counter; -+ constexpr int64_t startM = counterReverse / (endN); -+ constexpr int64_t startN = counterReverse % endN; -+ -+ zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]); -+ aux_setzero(zmm); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero( -+ PacketBlock &zmm) { -+ EIGEN_UNUSED_VARIABLE(zmm); -+ } -+ -+ /** -+ * aux_updateC -+ * -+ * 2-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ * for(startN = 0; startN < endN; startN++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC( -+ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ constexpr int64_t counterReverse = endM * endN - counter; -+ constexpr int64_t startM = counterReverse / (endN); -+ constexpr int64_t startN = counterReverse % endN; -+ -+ EIGEN_IF_CONSTEXPR(rem) -+ zmm.packet[startN * endM + startM] = -+ padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize], remMask(rem_)), -+ zmm.packet[startN * endM + startM], remMask(rem_)); -+ else zmm.packet[startN * endM + startM] = -+ padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]); -+ aux_updateC(C_arr, LDC, zmm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC( -+ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(C_arr); -+ EIGEN_UNUSED_VARIABLE(LDC); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ EIGEN_UNUSED_VARIABLE(rem_); -+ } -+ -+ /** -+ * aux_storeC -+ * -+ * 2-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ * for(startN = 0; startN < endN; startN++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC( -+ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ constexpr int64_t counterReverse = endM * endN - counter; -+ constexpr int64_t startM = counterReverse / (endN); -+ constexpr int64_t startN = counterReverse % endN; -+ -+ EIGEN_IF_CONSTEXPR(rem) -+ pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM], -+ remMask(rem_)); -+ else pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]); -+ aux_storeC(C_arr, LDC, zmm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC( -+ Scalar *C_arr, int64_t LDC, 
PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(C_arr); -+ EIGEN_UNUSED_VARIABLE(LDC); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ EIGEN_UNUSED_VARIABLE(rem_); -+ } -+ -+ /** -+ * aux_startLoadB -+ * -+ * 1-D unroll -+ * for(startL = 0; startL < endL; startL++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB( -+ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ constexpr int64_t counterReverse = endL - counter; -+ constexpr int64_t startL = counterReverse; -+ -+ EIGEN_IF_CONSTEXPR(rem) -+ zmm.packet[unrollM * unrollN + startL] = -+ ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask(rem_)); -+ else zmm.packet[unrollM * unrollN + startL] = -+ ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]); -+ -+ aux_startLoadB(B_t, LDB, zmm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB( -+ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_t); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ EIGEN_UNUSED_VARIABLE(rem_); -+ } -+ -+ /** -+ * aux_startBCastA -+ * -+ * 1-D unroll -+ * for(startB = 0; startB < endB; startB++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA( -+ Scalar *A_t, int64_t LDA, PacketBlock &zmm) { -+ constexpr int64_t counterReverse = endB - counter; -+ constexpr int64_t startB = counterReverse; -+ -+ zmm.packet[unrollM * unrollN + numLoad + startB] = pload1(&A_t[idA(startB, 0, LDA)]); -+ -+ aux_startBCastA(A_t, LDA, zmm); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA( -+ Scalar *A_t, int64_t LDA, PacketBlock &zmm) { -+ EIGEN_UNUSED_VARIABLE(A_t); -+ EIGEN_UNUSED_VARIABLE(LDA); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ } -+ -+ /** -+ * aux_loadB -+ * currK: current K -+ * -+ * 1-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( -+ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ if ((numLoad / endM + currK < unrollK)) { -+ constexpr int64_t counterReverse = endM - counter; -+ constexpr int64_t startM = counterReverse; -+ -+ EIGEN_IF_CONSTEXPR(rem) { -+ zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = -+ ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask(rem_)); -+ } -+ else { -+ zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = -+ ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]); -+ } -+ -+ aux_loadB(B_t, LDB, zmm, rem_); -+ } -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( -+ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_t); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ EIGEN_UNUSED_VARIABLE(rem_); -+ } -+ -+ /** -+ * aux_microKernel -+ * -+ * 3-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ * for(startN = 0; startN < endN; startN++) -+ * for(startK = 0; startK < endK; startK++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel( -+ Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ constexpr int64_t counterReverse = endM * endN * endK - counter; -+ constexpr int startK = 
counterReverse / (endM * endN); -+ constexpr int startN = (counterReverse / (endM)) % endN; -+ constexpr int startM = counterReverse % endM; -+ -+ EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) { -+ gemm::template startLoadB(B_t, LDB, zmm, rem_); -+ gemm::template startBCastA(A_t, LDA, zmm); -+ } -+ -+ { -+ // Interleave FMA and Bcast -+ EIGEN_IF_CONSTEXPR(isAdd) { -+ zmm.packet[startN * endM + startM] = -+ pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast], -+ zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]); -+ } -+ else { -+ zmm.packet[startN * endM + startM] = -+ pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast], -+ zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]); -+ } -+ // Bcast -+ EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) { -+ zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1(&A_t[idA( -+ (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]); -+ } -+ } -+ -+ // We have updated all accumlators, time to load next set of B's -+ EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) { -+ gemm::template loadB(B_t, LDB, zmm, rem_); -+ } -+ aux_microKernel(B_t, A_t, LDB, LDA, zmm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel( -+ Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_t); -+ EIGEN_UNUSED_VARIABLE(A_t); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(LDA); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ EIGEN_UNUSED_VARIABLE(rem_); -+ } -+ -+ /******************************************************** -+ * Wrappers for aux_XXXX to hide counter parameter -+ ********************************************************/ -+ -+ template -+ static EIGEN_ALWAYS_INLINE void setzero(PacketBlock &zmm) { -+ aux_setzero(zmm); -+ } -+ -+ /** -+ * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load. -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC, -+ PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ aux_updateC(C_arr, LDC, zmm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC, -+ PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ aux_storeC(C_arr, LDC, zmm, rem_); -+ } -+ -+ /** -+ * Use numLoad registers for loading B at start of microKernel -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB, -+ PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ aux_startLoadB(B_t, LDB, zmm, rem_); -+ } -+ -+ /** -+ * Use numBCast registers for broadcasting A at start of microKernel -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA, -+ PacketBlock &zmm) { -+ aux_startBCastA(A_t, LDA, zmm); -+ } -+ -+ /** -+ * Loads next set of B into vector registers between each K unroll. -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB, -+ PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ aux_loadB(B_t, LDB, zmm, rem_); -+ } -+ -+ /** -+ * Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3} to compute C -= A*B. -+ * A matrix can be row/col-major. 
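-+ * (col-major A is supported via the isARowMajor template parameter, which is
-+ *  passed to the idA index helper to select the appropriate stride into A.)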
B matrix is assumed row-major. -+ * -+ * isARowMajor: is A row major -+ * endM: Number registers per row -+ * endN: Number of rows -+ * endK: Loop unroll for K. -+ * numLoad: Number of registers for loading B. -+ * numBCast: Number of registers for broadcasting A. -+ * -+ * Ex: microkernel: 8x48 unroll (24 accumulators), k unrolled 4 times, -+ * 6 register for loading B, 2 for broadcasting A. -+ * -+ * Note: Ideally the microkernel should not have any register spilling. -+ * The avx instruction counts should be: -+ * - endK*endN vbroadcasts{s,d} -+ * - endK*endM vmovup{s,d} -+ * - endK*endN*endM FMAs -+ * -+ * From testing, there are no register spills with clang. There are register spills with GNU, which -+ * causes a performance hit. -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, -+ PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ aux_microKernel(B_t, A_t, LDB, LDA, zmm, -+ rem_); -+ } -+}; -+} // namespace unrolls -+ -+#endif // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H From ed73d9b8940e27d24c6134685db9678f43d4e1f7 Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Thu, 22 Feb 2024 14:11:38 +0100 Subject: [PATCH 09/12] Update easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb Co-authored-by: Jasper Grimm <65227842+jfgrimm@users.noreply.github.com> --- .../easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb index febaae194bb..fe4dde38022 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb @@ -42,6 +42,12 @@ local_tfrt_commit = '4665f7483063a16b6113a05eb45f98103cc1d611' local_repo_opt = '--bazel_options="--override_repository=xla=%%(builddir)s/xla-%s" ' % local_xla_commit local_repo_opt += '--bazel_options="--override_repository=runtime=%%(builddir)s/tf_runtime-%s" ' % local_xla_commit +# deliberately not testing in parallel, as that results in (additional) failing tests; +# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing, +# see https://github.com/google/jax/issues/7323 and +# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst; +# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs; +# use NVIDIA_TF32_OVERRIDE=0 to avoid losing numerical precision by disabling TF32 Tensor Cores; local_test = "NVIDIA_TF32_OVERRIDE=0 CUDA_VISIBLE_DEVICES=0 " local_test += "XLA_PYTHON_CLIENT_ALLOCATOR=platform " local_test += "JAX_ENABLE_X64=true pytest -vv tests " From dfe5f9771a698e7ad24f607ef72818dc385c8079 Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Thu, 22 Feb 2024 14:11:47 +0100 Subject: [PATCH 10/12] Update easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb Co-authored-by: Jasper Grimm <65227842+jfgrimm@users.noreply.github.com> --- easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb | 2 -- 1 file changed, 2 deletions(-) diff --git a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb index 822c1e6757d..6ca1c6e4ad1 100644 --- a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb +++
b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb @@ -28,8 +28,6 @@ dependencies = [ use_pip = True -default_easyblock = 'PythonPackage' - exts_list = [ ('opt_einsum', '3.3.0', { 'checksums': ['59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549'], From ce191f708ad1e202651ef333061443a6d72ad40f Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Thu, 22 Feb 2024 14:12:01 +0100 Subject: [PATCH 11/12] Update easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb Co-authored-by: Jasper Grimm <65227842+jfgrimm@users.noreply.github.com> --- .../easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb | 2 -- 1 file changed, 2 deletions(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb index fe4dde38022..2b43165bc0f 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.24-foss-2023a-CUDA-12.1.1.eb @@ -110,13 +110,11 @@ components = [ exts_list = [ (name, version, { - # 'patches': ['jax-0.4.24_cusparse.patch'], 'runtest': local_test, 'source_tmpl': '%(name)s-v%(version)s.tar.gz', 'source_urls': ['https://github.com/google/jax/archive/'], 'checksums': [ {'jax-v0.4.24.tar.gz': '6e52d8b547624bd70d423e6bf85f4fcd47336b529f1a4f6a94fac3096017a694'}, - # {'jax-0.4.24_cusparse.patch': 'a7d61e412ef3e77b2d4a6ba3c98e5c2c571523ccf46325ea180725f77a9dfb36'}, ], }), ] From 0b7f53a1b463d97bfb7dac9a13d2633dda3c79a1 Mon Sep 17 00:00:00 2001 From: thoffman Date: Wed, 28 Feb 2024 10:21:57 +0100 Subject: [PATCH 12/12] update comment in ml_dtypes-0.3.2_EigenAvx512.patch --- .../easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb | 7 ++++--- .../m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch | 4 +++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb index 6ca1c6e4ad1..df6bd3d5134 100644 --- a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb +++ b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb @@ -28,6 +28,8 @@ dependencies = [ use_pip = True +default_easyblock = 'PythonPackage' + exts_list = [ ('opt_einsum', '3.3.0', { 'checksums': ['59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549'], @@ -39,10 +41,9 @@ exts_list = [ 'patches': [('ml_dtypes-0.3.2_EigenAvx512.patch', 1)], 'checksums': [ {'ml_dtypes-0.3.2.tar.gz': '533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967'}, - {'ml_dtypes-0.3.2_EigenAvx512.patch': - 'fef229a24515b9c03be0d2e932c499965212e3a03ae3ede5d037874f88f93c46'}, + {'ml_dtypes-0.3.2_EigenAvx512.patch': '197b05b0b7f611749824369f026099f6a172f9e8eab6ebb6504a16573746c892'}, ], - }) + }), ] sanity_pip_check = True diff --git a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch index 79da0f03f1d..42ea0606391 100644 --- a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch +++ b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch @@ -1,4 +1,6 @@ -# ml_dtype 0.3.2 provides Eigen commit 7bf2968. src/Core/arch/AVX512/TrsmUnrolls.inc is missing. +# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2024/01 +# ml_dtype 0.3.2 ships a copy of Eigen commit 7bf2968 (https://gitlab.com/libeigen/eigen/-/commit/7bf2968). 
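+# (the bundled copy lives under third_party/eigen/ in the ml_dtypes sources)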
+# This copy is missing the file src/Core/arch/AVX512/TrsmUnrolls.inc, which is added by the present patch. diff -ru --new-file old/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc new/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc --- old/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc 1970-01-01 01:00:00.000000000 +0100 +++ new/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc 2024-02-14 10:32:25.492978066 +0100
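
For reference, every aux_/wrapper pair in the TrsmUnrolls.inc code above uses the same compile-time unrolling idiom: a counter template parameter is decremented through recursive instantiation, std::enable_if_t on the return type selects between the working overload and an empty terminator, and the forward loop index is recovered as end - counter. The following minimal, self-contained sketch illustrates that idiom; the name unroll and the printf body are illustrative only and are not part of Eigen.

#include <cstdint>
#include <cstdio>
#include <type_traits>

// Terminating overload: selected once counter has reached 0.
template <int64_t endK, int64_t counter>
std::enable_if_t<(counter <= 0)> unroll() {}

// Working overload: recovers the forward index as endK - counter, performs
// one "iteration", then recurses with counter - 1. Overload selection is
// pure SFINAE on the return type, as in the aux_loadB / aux_storeB pairs.
template <int64_t endK, int64_t counter>
std::enable_if_t<(counter > 0)> unroll() {
  constexpr int64_t startK = endK - counter;
  std::printf("unrolled iteration %d\n", static_cast<int>(startK));
  unroll<endK, counter - 1>();
}

int main() {
  unroll<3, 3>();  // behaves like: for (startK = 0; startK < 3; startK++)
  return 0;
}

With optimizations enabled the recursion flattens into straight-line code, which is why the kernels above can keep every packet index constexpr and never touch a runtime loop counter.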