From 8e4e1842aaf822cd13857c152451a3145270d842 Mon Sep 17 00:00:00 2001 From: thoffman Date: Thu, 14 Mar 2024 16:17:35 +0100 Subject: [PATCH 01/21] {tools}[foss/2023a] jax v0.4.25, ml_dtypes v0.3.2 w/ CUDA 12.1.1 --- .../jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb | 113 ++ .../m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb | 51 + .../ml_dtypes-0.3.2_EigenAvx512.patch | 1219 +++++++++++++++++ 3 files changed, 1383 insertions(+) create mode 100644 easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb create mode 100644 easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb create mode 100644 easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb new file mode 100644 index 00000000000..6ccd8a9d010 --- /dev/null +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb @@ -0,0 +1,113 @@ +# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild +# Author: Denis Kristak +# Updated by: Alex Domingo (Vrije Universiteit Brussel) +# Updated by: Thomas Hoffmann (EMBL Heidelberg) +easyblock = 'PythonBundle' + +name = 'jax' +version = '0.4.25' +versionsuffix = '-CUDA-%(cudaver)s' + +homepage = 'https://pypi.python.org/pypi/jax' +description = """Composable transformations of Python+NumPy programs: +differentiate, vectorize, JIT to GPU/TPU, and more""" + +toolchain = {'name': 'foss', 'version': '2023a'} +cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"] + +builddependencies = [ + ('Bazel', '6.3.1'), + ('pytest-xdist', '3.3.1'), + # git 2.x required to fetch repository 'io_bazel_rules_docker' + ('git', '2.41.0', '-nodocs'), + ('matplotlib', '3.7.2'), # required for tests/lobpcg_test.py + ('poetry', '1.5.1'), +] + +dependencies = [ + ('CUDA', '12.1.1', '', SYSTEM), + ('cuDNN', '8.9.2.26', versionsuffix, SYSTEM), + ('NCCL', '2.18.3', versionsuffix), + ('Python', '3.11.3'), + ('SciPy-bundle', '2023.07'), + ('flatbuffers-python', '23.5.26'), + ('zlib', '1.2.13'), + ('ml_dtypes', '0.3.2'), +] + +# downloading xla and other tarballs to avoid that Bazel downloads it during the build +# note: this *must* be the exact same commit as used in third_party/{xla,"other"}/workspace.bzl +local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4' +local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' +local_repo_opt = '--bazel_options="--override_repository=xla=%%(builddir)s/xla-%s" ' % local_xla_commit +local_repo_opt += '--bazel_options="--override_repository=runtime=%%(builddir)s/tf_runtime-%s" ' % local_xla_commit + +# deliberately not testing in parallel, as that results in (additional) failing tests; +# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing, +# see https://github.com/google/jax/issues/7323 and +# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst; +# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs; +# use NVIDIA_TF32_OVERRIDE=0 to avoid lossing numerical precision by disabling TF32 Tensor Cores; +local_test = "NVIDIA_TF32_OVERRIDE=0 CUDA_VISIBLE_DEVICES=0 " +local_test += "XLA_PYTHON_CLIENT_ALLOCATOR=platform " +local_test += "JAX_ENABLE_X64=true pytest -vv tests " + +use_pip = True + +default_easyblock = 'PythonPackage' +default_component_specs = { + 'sources': [SOURCE_TAR_GZ], + 'source_urls': [PYPI_SOURCE], + 'start_dir': '%(name)s-%(version)s', + 'use_pip': True, + 'sanity_pip_check': True, + 'download_dep_fail': True, +} + +components = [ + ('absl-py', '2.1.0', { + 'options': {'modulename': 'absl'}, + 'checksums': ['7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff'], + }), + ('jaxlib', version, { + 'sources': [ + '%(name)s-v%(version)s.tar.gz', + { + 'download_filename': '%s.tar.gz' % local_xla_commit, + 'filename': 'xla-%s.tar.gz' % local_xla_commit, + }, + { + 'download_filename': '%s.tar.gz' % local_tfrt_commit, + 'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit, + }, + ], + 'source_urls': [ + 'https://github.com/google/jax/archive/', + 'https://github.com/tensorflow/runtime/archive', + 'https://github.com/openxla/xla/archive' + ], + 'checksums': [ + {'jaxlib-v0.4.25.tar.gz': + 'fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8'}, + {'xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4.tar.gz': + '8a59b9af7d0850059d7043f7043c780066d61538f3af536e8a10d3d717f35089'}, + {'tf_runtime-0aeefb1660d7e37964b2bb71b1f518096bda9a25.tar.gz': + 'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'}, + ], + 'start_dir': 'jax-jaxlib-v%(version)s', + 'buildopts': local_repo_opt + }), +] + +exts_list = [ + (name, version, { + 'runtest': "NVIDIA_TF32_OVERRIDE=0 CUDA_VISIBLE_DEVICES=0 XLA_PYTHON_CLIENT_ALLOCATOR=platform JAX_ENABLE_X64=true pytest -vv tests ", + 'source_tmpl': '%(name)s-v%(version)s.tar.gz', + 'source_urls': ['https://github.com/google/jax/archive/'], + 'checksums': ['8b30af49688c0c13b82c6f5ce992727c00b5fc6d04a4c6962012f4246fa664eb'], + }), +] + +sanity_pip_check = True + +moduleclass = 'tools' diff --git a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb new file mode 100644 index 00000000000..df6bd3d5134 --- /dev/null +++ b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb @@ -0,0 +1,51 @@ +# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2024/02 +easyblock = 'PythonBundle' + +name = 'ml_dtypes' +version = '0.3.2' + +homepage = 'https://github.com/jax-ml/ml_dtypes' +description = """ +ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used +in machine learning libraries, including: + +bfloat16: an alternative to the standard float16 format +float8_*: several experimental 8-bit floating point representations including: +float8_e4m3b11fnuz +float8_e4m3fn +float8_e4m3fnuz +float8_e5m2 +float8_e5m2fnuz +""" + +toolchain = {'name': 'foss', 'version': '2023a'} + +dependencies = [ + ('Python', '3.11.3'), + ('SciPy-bundle', '2023.07'), +] + + +use_pip = True + +default_easyblock = 'PythonPackage' + +exts_list = [ + ('opt_einsum', '3.3.0', { + 'checksums': ['59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549'], + }), + ('etils', '1.6.0', { + 'checksums': ['c635fbd02a79fed4ad76825d31306b581d22b40671721daa8bc279cf6333e48a'], + }), + (name, version, { + 'patches': [('ml_dtypes-0.3.2_EigenAvx512.patch', 1)], + 'checksums': [ + {'ml_dtypes-0.3.2.tar.gz': '533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967'}, + {'ml_dtypes-0.3.2_EigenAvx512.patch': '197b05b0b7f611749824369f026099f6a172f9e8eab6ebb6504a16573746c892'}, + ], + }), +] + +sanity_pip_check = True + +moduleclass = 'tools' diff --git a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch new file mode 100644 index 00000000000..42ea0606391 --- /dev/null +++ b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch @@ -0,0 +1,1219 @@ +# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2024/01 +# ml_dtype 0.3.2 ships a copy of Eigen commit 7bf2968 (https://gitlab.com/libeigen/eigen/-/commit/7bf2968). +# This copy is missing the file src/Core/arch/AVX512/TrsmUnrolls.inc, which is added by the present patch. +diff -ru --new-file old/third_party_ori/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc new/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +--- old/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc 1970-01-01 01:00:00.000000000 +0100 ++++ new/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc 2024-02-14 10:32:25.492978066 +0100 +@@ -0,0 +1,1212 @@ ++// This file is part of Eigen, a lightweight C++ template library ++// for linear algebra. ++// ++// Copyright (C) 2022 Intel Corporation ++// ++// This Source Code Form is subject to the terms of the Mozilla ++// Public License v. 2.0. If a copy of the MPL was not distributed ++// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. ++ ++#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H ++#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H ++ ++template ++EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) { ++ EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j; ++ else return i + j * LDA; ++} ++ ++/** ++ * This namespace contains various classes used to generate compile-time unrolls which are ++ * used throughout the trsm/gemm kernels. The unrolls are characterized as for-loops (1-D), nested ++ * for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion ++ * ++ * Example, the 2-D for-loop is unrolled recursively by first flattening to a 1-D loop. ++ * ++ * for(startI = 0; startI < endI; startI++) for(startC = 0; startC < endI*endJ; startC++) ++ * for(startJ = 0; startJ < endJ; startJ++) ----> startI = (startC)/(endJ) ++ * func(startI,startJ) startJ = (startC)%(endJ) ++ * func(...) ++ * ++ * The 1-D loop can be unrolled recursively by using enable_if and defining an auxillary function ++ * with a template parameter used as a counter. ++ * ++ * template ++ * std::enable_if_t<(counter <= 0)> <---- tail case. ++ * aux_func {} ++ * ++ * template ++ * std::enable_if_t<(counter > 0)> <---- actual for-loop ++ * aux_func { ++ * startC = endI*endJ - counter ++ * startI = (startC)/(endJ) ++ * startJ = (startC)%(endJ) ++ * func(startI, startJ) ++ * aux_func() ++ * } ++ * ++ * Note: Additional wrapper functions are provided for aux_func which hides the counter template ++ * parameter since counter usually depends on endI, endJ, etc... ++ * ++ * Conventions: ++ * 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++)) ++ * ++ * 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for ++ * handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW) ++ */ ++namespace unrolls { ++ ++template ++EIGEN_ALWAYS_INLINE auto remMask(int64_t m) { ++ EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); } ++ else EIGEN_IF_CONSTEXPR(N == 8) { ++ return 0xFF >> (8 - m); ++ } ++ else EIGEN_IF_CONSTEXPR(N == 4) { ++ return 0x0F >> (4 - m); ++ } ++ return 0; ++} ++ ++template ++EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel); ++ ++template <> ++EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel) { ++ __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]); ++ __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]); ++ __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]); ++ __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]); ++ __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]); ++ __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]); ++ __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]); ++ __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]); ++ ++ kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); ++ kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); ++ kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); ++ kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); ++ kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6))); ++ kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6))); ++ kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); ++ kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); ++ ++ T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E)); ++ T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0); ++ T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E)); ++ T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]); ++ T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E)); ++ T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1); ++ T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E)); ++ T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]); ++ T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E)); ++ T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2); ++ T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E)); ++ T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]); ++ T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E)); ++ T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3); ++ T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E)); ++ T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]); ++ ++ kernel.packet[0] = T0; ++ kernel.packet[1] = T1; ++ kernel.packet[2] = T2; ++ kernel.packet[3] = T3; ++ kernel.packet[4] = T4; ++ kernel.packet[5] = T5; ++ kernel.packet[6] = T6; ++ kernel.packet[7] = T7; ++} ++ ++template <> ++EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel) { ++ ptranspose(kernel); ++} ++ ++/*** ++ * Unrolls for tranposed C stores ++ */ ++template ++class trans { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - storeC ++ *********************************** ++ */ ++ ++ /** ++ * aux_storeC ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN++) ++ * ++ * (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC ++ * ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) { ++ EIGEN_IF_CONSTEXPR(remM) { ++ pstoreu( ++ C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), ++ preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]), ++ remMask(remM_)), ++ remMask(remM_)); ++ } ++ else { ++ pstoreu(C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN), ++ preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]))); ++ } ++ } ++ else { // This block is only needed for fp32 case ++ // Reinterpret as __m512 for _mm512_shuffle_f32x4 ++ vecFullFloat zmm2vecFullFloat = preinterpret( ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]); ++ // Swap lower and upper half of avx register. ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] = ++ preinterpret(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110)); ++ ++ EIGEN_IF_CONSTEXPR(remM) { ++ pstoreu( ++ C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), ++ preinterpret( ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])), ++ remMask(remM_)); ++ } ++ else { ++ pstoreu( ++ C_arr + LDC * startN, ++ padd(ploadu((const Scalar *)C_arr + LDC * startN), ++ preinterpret( ++ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]))); ++ } ++ } ++ aux_storeC(C_arr, LDC, zmm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && endN <= PacketSize)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(C_arr); ++ EIGEN_UNUSED_VARIABLE(LDC); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC, ++ PacketBlock &zmm, ++ int64_t remM_ = 0) { ++ aux_storeC(C_arr, LDC, zmm, remM_); ++ } ++ ++ /** ++ * Transposes LxunrollN row major block of matrices stored EIGEN_AVX_MAX_NUM_ACC zmm registers to ++ * "unrollN"xL ymm registers to be stored col-major into C. ++ * ++ * For 8x48, the 8x48 block (row-major) is stored in zmm as follows: ++ * ++ * row0: zmm0 zmm1 zmm2 ++ * row1: zmm3 zmm4 zmm5 ++ * . ++ * . ++ * row7: zmm21 zmm22 zmm23 ++ * ++ * For 8x32, the 8x32 block (row-major) is stored in zmm as follows: ++ * ++ * row0: zmm0 zmm1 ++ * row1: zmm2 zmm3 ++ * . ++ * . ++ * row7: zmm14 zmm15 ++ * ++ * ++ * In general we will have {1,2,3} groups of avx registers each of size ++ * EIGEN_AVX_MAX_NUM_ROW. packetIndexOffset is used to select which "block" of ++ * avx registers are being transposed. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void transpose(PacketBlock &zmm) { ++ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted ++ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. ++ constexpr int64_t zmmStride = unrollN / PacketSize; ++ PacketBlock r; ++ r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0]; ++ r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1]; ++ r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2]; ++ r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3]; ++ r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4]; ++ r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5]; ++ r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6]; ++ r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7]; ++ trans8x8blocks(r); ++ zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0]; ++ zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1]; ++ zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2]; ++ zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3]; ++ zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4]; ++ zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5]; ++ zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6]; ++ zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7]; ++ } ++}; ++ ++/** ++ * Unrolls for copyBToRowMajor ++ * ++ * Idea: ++ * 1) Load a block of right-hand sides to registers (using loadB). ++ * 2) Convert the block from column-major to row-major (transposeLxL) ++ * 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false). ++ * ++ * We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are ++ * used as temps for transposing. ++ * ++ * Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks ++ * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm). ++ */ ++template ++class transB { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - loadB ++ * - storeB ++ * - loadBBlock ++ * - storeBBlock ++ *********************************** ++ */ ++ ++ /** ++ * aux_loadB ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(remM) { ++ ymm.packet[packetIndexOffset + startN] = ++ ploadu((const Scalar *)&B_arr[startN * LDB], remMask(remM_)); ++ } ++ else ymm.packet[packetIndexOffset + startN] = ploadu((const Scalar *)&B_arr[startN * LDB]); ++ ++ aux_loadB(B_arr, LDB, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ /** ++ * aux_storeB ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(remK || remM) { ++ pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN], ++ remMask(rem_)); ++ } ++ else { ++ pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]); ++ } ++ ++ aux_storeB(B_arr, LDB, ymm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB( ++ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_loadBBlock ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ transB::template loadB(&B_temp[startN], LDB_, ymm); ++ aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(B_temp); ++ EIGEN_UNUSED_VARIABLE(LDB_); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ /** ++ * aux_storeBBlock ++ * ++ * 1-D unroll ++ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ constexpr int64_t counterReverse = endN - counter; ++ constexpr int64_t startN = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(toTemp) { ++ transB::template storeB(&B_temp[startN], LDB_, ymm, remK_); ++ } ++ else { ++ transB::template storeB(&B_arr[0 + startN * LDB], LDB, ++ ymm, remM_); ++ } ++ aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock( ++ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, int64_t remM_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(B_temp); ++ EIGEN_UNUSED_VARIABLE(LDB_); ++ EIGEN_UNUSED_VARIABLE(ymm); ++ EIGEN_UNUSED_VARIABLE(remM_); ++ } ++ ++ /******************************************************** ++ * Wrappers for aux_XXXX to hide counter parameter ++ ********************************************************/ ++ ++ template ++ static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ aux_loadB(B_arr, LDB, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB, ++ PacketBlock &ymm, ++ int64_t rem_ = 0) { ++ aux_storeB(B_arr, LDB, ymm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB(&B_arr[0], LDB, ymm, remM_); } ++ else { ++ aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock &ymm) { ++ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted ++ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. ++ PacketBlock r; ++ r.packet[0] = ymm.packet[packetIndexOffset + 0]; ++ r.packet[1] = ymm.packet[packetIndexOffset + 1]; ++ r.packet[2] = ymm.packet[packetIndexOffset + 2]; ++ r.packet[3] = ymm.packet[packetIndexOffset + 3]; ++ r.packet[4] = ymm.packet[packetIndexOffset + 4]; ++ r.packet[5] = ymm.packet[packetIndexOffset + 5]; ++ r.packet[6] = ymm.packet[packetIndexOffset + 6]; ++ r.packet[7] = ymm.packet[packetIndexOffset + 7]; ++ ptranspose(r); ++ ymm.packet[packetIndexOffset + 0] = r.packet[0]; ++ ymm.packet[packetIndexOffset + 1] = r.packet[1]; ++ ymm.packet[packetIndexOffset + 2] = r.packet[2]; ++ ymm.packet[packetIndexOffset + 3] = r.packet[3]; ++ ymm.packet[packetIndexOffset + 4] = r.packet[4]; ++ ymm.packet[packetIndexOffset + 5] = r.packet[5]; ++ ymm.packet[packetIndexOffset + 6] = r.packet[6]; ++ ymm.packet[packetIndexOffset + 7] = r.packet[7]; ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, ++ PacketBlock &ymm, ++ int64_t remM_ = 0) { ++ constexpr int64_t U3 = PacketSize * 3; ++ constexpr int64_t U2 = PacketSize * 2; ++ constexpr int64_t U1 = PacketSize * 1; ++ /** ++ * Unrolls needed for each case: ++ * - AVX512 fp32 48 32 16 8 4 2 1 ++ * - AVX512 fp64 24 16 8 4 2 1 ++ * ++ * For fp32 L and U1 are 1:2 so for U3/U2 cases the loads/stores need to be split up. ++ */ ++ EIGEN_IF_CONSTEXPR(unrollN == U3) { ++ // load LxU3 B col major, transpose LxU3 row major ++ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3); ++ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ ++ EIGEN_IF_CONSTEXPR(maxUBlock < U3) { ++ transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, ++ ymm, remM_); ++ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, ++ ymm, remM_); ++ } ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == U2) { ++ // load LxU2 B col major, transpose LxU2 row major ++ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2); ++ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); ++ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ ++ EIGEN_IF_CONSTEXPR(maxUBlock < U2) { ++ transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, ++ &B_temp[maxUBlock], LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, ++ &B_temp[maxUBlock], LDB_, ymm, remM_); ++ } ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == U1) { ++ // load LxU1 B col major, transpose LxU1 row major ++ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); } ++ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) { ++ // load Lx4 B col major, transpose Lx4 row major ++ transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) { ++ // load Lx4 B col major, transpose Lx4 row major ++ transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == 2) { ++ // load Lx2 B col major, transpose Lx2 row major ++ transB::template loadBBlock<2, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ else EIGEN_IF_CONSTEXPR(unrollN == 1) { ++ // load Lx1 B col major, transpose Lx1 row major ++ transB::template loadBBlock<1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ transB::template transposeLxL<0>(ymm); ++ transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_); ++ } ++ } ++}; ++ ++/** ++ * Unrolls for triSolveKernel ++ * ++ * Idea: ++ * 1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS). ++ * 2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix) ++ * stored in AInPacket (using triSolveMicroKernel). ++ * 3) Store final results (in avx registers) back into memory (using storeRHS). ++ * ++ * RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most ++ * EIGEN_AVX_MAX_NUM_ROW registers. ++ */ ++template ++class trsm { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - loadRHS ++ * - storeRHS ++ * - divRHSByDiag ++ * - updateRHS ++ * - triSolveMicroKernel ++ ************************************/ ++ /** ++ * aux_loadRHS ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ constexpr int64_t counterReverse = endM * endK - counter; ++ constexpr int64_t startM = counterReverse / (endK); ++ constexpr int64_t startK = counterReverse % endK; ++ ++ constexpr int64_t packetIndex = startM * endK + startK; ++ constexpr int64_t startM_ = isFWDSolve ? startM : -startM; ++ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; ++ EIGEN_IF_CONSTEXPR(krem) { ++ RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex], remMask(rem)); ++ } ++ else { ++ RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex]); ++ } ++ aux_loadRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(rem); ++ } ++ ++ /** ++ * aux_storeRHS ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ constexpr int64_t counterReverse = endM * endK - counter; ++ constexpr int64_t startM = counterReverse / (endK); ++ constexpr int64_t startK = counterReverse % endK; ++ ++ constexpr int64_t packetIndex = startM * endK + startK; ++ constexpr int64_t startM_ = isFWDSolve ? startM : -startM; ++ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; ++ EIGEN_IF_CONSTEXPR(krem) { ++ pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask(rem)); ++ } ++ else { ++ pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]); ++ } ++ aux_storeRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS( ++ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { ++ EIGEN_UNUSED_VARIABLE(B_arr); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(rem); ++ } ++ ++ /** ++ * aux_divRHSByDiag ++ * ++ * currM may be -1, (currM >=0) in enable_if checks for this ++ * ++ * 1-D unroll ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag( ++ PacketBlock &RHSInPacket, PacketBlock &AInPacket) { ++ constexpr int64_t counterReverse = endK - counter; ++ constexpr int64_t startK = counterReverse; ++ ++ constexpr int64_t packetIndex = currM * endK + startK; ++ RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]); ++ aux_divRHSByDiag(RHSInPacket, AInPacket); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && currM >= 0)> aux_divRHSByDiag( ++ PacketBlock &RHSInPacket, PacketBlock &AInPacket) { ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(AInPacket); ++ } ++ ++ /** ++ * aux_updateRHS ++ * ++ * 2-D unroll ++ * for(startM = initM; startM < endM; startM++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS( ++ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ constexpr int64_t counterReverse = (endM - initM) * endK - counter; ++ constexpr int64_t startM = initM + counterReverse / (endK); ++ constexpr int64_t startK = counterReverse % endK; ++ ++ // For each row of A, first update all corresponding RHS ++ constexpr int64_t packetIndex = startM * endK + startK; ++ EIGEN_IF_CONSTEXPR(currentM > 0) { ++ RHSInPacket.packet[packetIndex] = ++ pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK], ++ RHSInPacket.packet[packetIndex]); ++ } ++ ++ EIGEN_IF_CONSTEXPR(startK == endK - 1) { ++ // Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}. ++ EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) { ++ // If diagonal is not unit, we broadcast reciprocals of diagonals AinPacket.packet[currentM]. ++ // This will be used in divRHSByDiag ++ EIGEN_IF_CONSTEXPR(isFWDSolve) ++ AInPacket.packet[currentM] = pset1(Scalar(1) / A_arr[idA(currentM, currentM, LDA)]); ++ else AInPacket.packet[currentM] = pset1(Scalar(1) / A_arr[idA(-currentM, -currentM, LDA)]); ++ } ++ else { ++ // Broadcast next off diagonal element of A ++ EIGEN_IF_CONSTEXPR(isFWDSolve) ++ AInPacket.packet[startM] = pset1(A_arr[idA(startM, currentM, LDA)]); ++ else AInPacket.packet[startM] = pset1(A_arr[idA(-startM, -currentM, LDA)]); ++ } ++ } ++ ++ aux_updateRHS( ++ A_arr, LDA, RHSInPacket, AInPacket); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS( ++ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ EIGEN_UNUSED_VARIABLE(A_arr); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(AInPacket); ++ } ++ ++ /** ++ * aux_triSolverMicroKernel ++ * ++ * 1-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel( ++ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ constexpr int64_t counterReverse = endM - counter; ++ constexpr int64_t startM = counterReverse; ++ ++ constexpr int64_t currentM = startM; ++ // Divides the right-hand side in row startM, by digonal value of A ++ // broadcasted to AInPacket.packet[startM-1] in the previous iteration. ++ // ++ // Without "if constexpr" the compiler instantiates the case <-1, numK> ++ // this is handled with enable_if to prevent out-of-bound warnings ++ // from the compiler ++ EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0) ++ trsm::template divRHSByDiag(RHSInPacket, AInPacket); ++ ++ // After division, the rhs corresponding to subsequent rows of A can be partially updated ++ // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed) ++ // to be used in the next iteration. ++ trsm::template updateRHS(A_arr, LDA, RHSInPacket, ++ AInPacket); ++ ++ // Handle division for the RHS corresponding to the final row of A. ++ EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1) ++ trsm::template divRHSByDiag(RHSInPacket, AInPacket); ++ ++ aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, ++ AInPacket); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel( ++ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ EIGEN_UNUSED_VARIABLE(A_arr); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(RHSInPacket); ++ EIGEN_UNUSED_VARIABLE(AInPacket); ++ } ++ ++ /******************************************************** ++ * Wrappers for aux_XXXX to hide counter parameter ++ ********************************************************/ ++ ++ /** ++ * Load endMxendK block of B to RHSInPacket ++ * Masked loads are used for cases where endK is not a multiple of PacketSize ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB, ++ PacketBlock &RHSInPacket, int64_t rem = 0) { ++ aux_loadRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ /** ++ * Load endMxendK block of B to RHSInPacket ++ * Masked loads are used for cases where endK is not a multiple of PacketSize ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB, ++ PacketBlock &RHSInPacket, int64_t rem = 0) { ++ aux_storeRHS(B_arr, LDB, RHSInPacket, rem); ++ } ++ ++ /** ++ * Only used if Triangular matrix has non-unit diagonal values ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ aux_divRHSByDiag(RHSInPacket, AInPacket); ++ } ++ ++ /** ++ * Update right-hand sides (stored in avx registers) ++ * Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket. ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA, ++ PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ aux_updateRHS( ++ A_arr, LDA, RHSInPacket, AInPacket); ++ } ++ ++ /** ++ * endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW ++ * numK: number of avx registers to use for each row of B (ex fp32: 48 rhs => 3 avx reg used). 1 <= endK <= 3. ++ * isFWDSolve: true => forward substitution, false => backwards substitution ++ * isUnitDiag: true => triangular matrix has unit diagonal. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA, ++ PacketBlock &RHSInPacket, ++ PacketBlock &AInPacket) { ++ static_assert(numK >= 1 && numK <= 3, "numK out of range"); ++ aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, AInPacket); ++ } ++}; ++ ++/** ++ * Unrolls for gemm kernel ++ * ++ * isAdd: true => C += A*B, false => C -= A*B ++ */ ++template ++class gemm { ++ public: ++ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; ++ static constexpr int64_t PacketSize = packet_traits::size; ++ ++ /*********************************** ++ * Auxillary Functions for: ++ * - setzero ++ * - updateC ++ * - storeC ++ * - startLoadB ++ * - triSolveMicroKernel ++ ************************************/ ++ ++ /** ++ * aux_setzero ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero( ++ PacketBlock &zmm) { ++ constexpr int64_t counterReverse = endM * endN - counter; ++ constexpr int64_t startM = counterReverse / (endN); ++ constexpr int64_t startN = counterReverse % endN; ++ ++ zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]); ++ aux_setzero(zmm); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero( ++ PacketBlock &zmm) { ++ EIGEN_UNUSED_VARIABLE(zmm); ++ } ++ ++ /** ++ * aux_updateC ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endM * endN - counter; ++ constexpr int64_t startM = counterReverse / (endN); ++ constexpr int64_t startN = counterReverse % endN; ++ ++ EIGEN_IF_CONSTEXPR(rem) ++ zmm.packet[startN * endM + startM] = ++ padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize], remMask(rem_)), ++ zmm.packet[startN * endM + startM], remMask(rem_)); ++ else zmm.packet[startN * endM + startM] = ++ padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]); ++ aux_updateC(C_arr, LDC, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(C_arr); ++ EIGEN_UNUSED_VARIABLE(LDC); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_storeC ++ * ++ * 2-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endM * endN - counter; ++ constexpr int64_t startM = counterReverse / (endN); ++ constexpr int64_t startN = counterReverse % endN; ++ ++ EIGEN_IF_CONSTEXPR(rem) ++ pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM], ++ remMask(rem_)); ++ else pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]); ++ aux_storeC(C_arr, LDC, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC( ++ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(C_arr); ++ EIGEN_UNUSED_VARIABLE(LDC); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_startLoadB ++ * ++ * 1-D unroll ++ * for(startL = 0; startL < endL; startL++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endL - counter; ++ constexpr int64_t startL = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(rem) ++ zmm.packet[unrollM * unrollN + startL] = ++ ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask(rem_)); ++ else zmm.packet[unrollM * unrollN + startL] = ++ ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]); ++ ++ aux_startLoadB(B_t, LDB, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_t); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_startBCastA ++ * ++ * 1-D unroll ++ * for(startB = 0; startB < endB; startB++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA( ++ Scalar *A_t, int64_t LDA, PacketBlock &zmm) { ++ constexpr int64_t counterReverse = endB - counter; ++ constexpr int64_t startB = counterReverse; ++ ++ zmm.packet[unrollM * unrollN + numLoad + startB] = pload1(&A_t[idA(startB, 0, LDA)]); ++ ++ aux_startBCastA(A_t, LDA, zmm); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA( ++ Scalar *A_t, int64_t LDA, PacketBlock &zmm) { ++ EIGEN_UNUSED_VARIABLE(A_t); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ } ++ ++ /** ++ * aux_loadB ++ * currK: current K ++ * ++ * 1-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ if ((numLoad / endM + currK < unrollK)) { ++ constexpr int64_t counterReverse = endM - counter; ++ constexpr int64_t startM = counterReverse; ++ ++ EIGEN_IF_CONSTEXPR(rem) { ++ zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = ++ ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask(rem_)); ++ } ++ else { ++ zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = ++ ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]); ++ } ++ ++ aux_loadB(B_t, LDB, zmm, rem_); ++ } ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( ++ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_t); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /** ++ * aux_microKernel ++ * ++ * 3-D unroll ++ * for(startM = 0; startM < endM; startM++) ++ * for(startN = 0; startN < endN; startN++) ++ * for(startK = 0; startK < endK; startK++) ++ **/ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel( ++ Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ constexpr int64_t counterReverse = endM * endN * endK - counter; ++ constexpr int startK = counterReverse / (endM * endN); ++ constexpr int startN = (counterReverse / (endM)) % endN; ++ constexpr int startM = counterReverse % endM; ++ ++ EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) { ++ gemm::template startLoadB(B_t, LDB, zmm, rem_); ++ gemm::template startBCastA(A_t, LDA, zmm); ++ } ++ ++ { ++ // Interleave FMA and Bcast ++ EIGEN_IF_CONSTEXPR(isAdd) { ++ zmm.packet[startN * endM + startM] = ++ pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast], ++ zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]); ++ } ++ else { ++ zmm.packet[startN * endM + startM] = ++ pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast], ++ zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]); ++ } ++ // Bcast ++ EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) { ++ zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1(&A_t[idA( ++ (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]); ++ } ++ } ++ ++ // We have updated all accumlators, time to load next set of B's ++ EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) { ++ gemm::template loadB(B_t, LDB, zmm, rem_); ++ } ++ aux_microKernel(B_t, A_t, LDB, LDA, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel( ++ Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(B_t); ++ EIGEN_UNUSED_VARIABLE(A_t); ++ EIGEN_UNUSED_VARIABLE(LDB); ++ EIGEN_UNUSED_VARIABLE(LDA); ++ EIGEN_UNUSED_VARIABLE(zmm); ++ EIGEN_UNUSED_VARIABLE(rem_); ++ } ++ ++ /******************************************************** ++ * Wrappers for aux_XXXX to hide counter parameter ++ ********************************************************/ ++ ++ template ++ static EIGEN_ALWAYS_INLINE void setzero(PacketBlock &zmm) { ++ aux_setzero(zmm); ++ } ++ ++ /** ++ * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_updateC(C_arr, LDC, zmm, rem_); ++ } ++ ++ template ++ static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_storeC(C_arr, LDC, zmm, rem_); ++ } ++ ++ /** ++ * Use numLoad registers for loading B at start of microKernel ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_startLoadB(B_t, LDB, zmm, rem_); ++ } ++ ++ /** ++ * Use numBCast registers for broadcasting A at start of microKernel ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA, ++ PacketBlock &zmm) { ++ aux_startBCastA(A_t, LDA, zmm); ++ } ++ ++ /** ++ * Loads next set of B into vector registers between each K unroll. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_loadB(B_t, LDB, zmm, rem_); ++ } ++ ++ /** ++ * Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3} to compute C -= A*B. ++ * A matrix can be row/col-major. B matrix is assumed row-major. ++ * ++ * isARowMajor: is A row major ++ * endM: Number registers per row ++ * endN: Number of rows ++ * endK: Loop unroll for K. ++ * numLoad: Number of registers for loading B. ++ * numBCast: Number of registers for broadcasting A. ++ * ++ * Ex: microkernel: 8x48 unroll (24 accumulators), k unrolled 4 times, ++ * 6 register for loading B, 2 for broadcasting A. ++ * ++ * Note: Ideally the microkernel should not have any register spilling. ++ * The avx instruction counts should be: ++ * - endK*endN vbroadcasts{s,d} ++ * - endK*endM vmovup{s,d} ++ * - endK*endN*endM FMAs ++ * ++ * From testing, there are no register spills with clang. There are register spills with GNU, which ++ * causes a performance hit. ++ */ ++ template ++ static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, ++ PacketBlock &zmm, ++ int64_t rem_ = 0) { ++ EIGEN_UNUSED_VARIABLE(rem_); ++ aux_microKernel(B_t, A_t, LDB, LDA, zmm, ++ rem_); ++ } ++}; ++} // namespace unrolls ++ ++#endif // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H From 94e95676fbd53ec3bc86ad475b10c372bf827e34 Mon Sep 17 00:00:00 2001 From: thoffman Date: Thu, 14 Mar 2024 16:35:24 +0100 Subject: [PATCH 02/21] fix style --- .../easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb index 6ccd8a9d010..e221fab5480 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb @@ -37,7 +37,7 @@ dependencies = [ # downloading xla and other tarballs to avoid that Bazel downloads it during the build # note: this *must* be the exact same commit as used in third_party/{xla,"other"}/workspace.bzl -local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4' +local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4' local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' local_repo_opt = '--bazel_options="--override_repository=xla=%%(builddir)s/xla-%s" ' % local_xla_commit local_repo_opt += '--bazel_options="--override_repository=runtime=%%(builddir)s/tf_runtime-%s" ' % local_xla_commit @@ -101,7 +101,7 @@ components = [ exts_list = [ (name, version, { - 'runtest': "NVIDIA_TF32_OVERRIDE=0 CUDA_VISIBLE_DEVICES=0 XLA_PYTHON_CLIENT_ALLOCATOR=platform JAX_ENABLE_X64=true pytest -vv tests ", + 'runtest': local_test, 'source_tmpl': '%(name)s-v%(version)s.tar.gz', 'source_urls': ['https://github.com/google/jax/archive/'], 'checksums': ['8b30af49688c0c13b82c6f5ce992727c00b5fc6d04a4c6962012f4246fa664eb'], From 1827967b80ac7ae4de9b84085e7f3aeca32197ee Mon Sep 17 00:00:00 2001 From: thoffman Date: Fri, 15 Mar 2024 12:09:47 +0100 Subject: [PATCH 03/21] fix test_no_log_spam --- .../j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb | 7 ++++++- .../jax-0.4.25_fix_env_test_no_log_spam.patch | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 easybuild/easyconfigs/j/jax/jax-0.4.25_fix_env_test_no_log_spam.patch diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb index e221fab5480..b91854890a8 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb @@ -101,10 +101,15 @@ components = [ exts_list = [ (name, version, { + 'patches': ['jax-0.4.25_fix_env_test_no_log_spam.patch'], 'runtest': local_test, 'source_tmpl': '%(name)s-v%(version)s.tar.gz', 'source_urls': ['https://github.com/google/jax/archive/'], - 'checksums': ['8b30af49688c0c13b82c6f5ce992727c00b5fc6d04a4c6962012f4246fa664eb'], + 'checksums': [ + {'jax-v0.4.25.tar.gz': '8b30af49688c0c13b82c6f5ce992727c00b5fc6d04a4c6962012f4246fa664eb'}, + {'jax-0.4.25_fix_env_test_no_log_spam.patch': + 'a5d4493c69814833615a2914f8481738fa438bacba9a1852601f7d19506371b9'}, + ], }), ] diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25_fix_env_test_no_log_spam.patch b/easybuild/easyconfigs/j/jax/jax-0.4.25_fix_env_test_no_log_spam.patch new file mode 100644 index 00000000000..e679bb8e274 --- /dev/null +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25_fix_env_test_no_log_spam.patch @@ -0,0 +1,18 @@ +# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2024/03 +# avoid overriding LD_LIBRARY_PATH, which would lead to test error: error while loading shared libraries: libpython3.11.so.1.0: cannot open shared object file: No such file or directory' +diff -ru jax-jax-v0.4.25/tests/logging_test.py jax-jax-v0.4.25_fix_env_test_no_log_spam/tests/logging_test.py +--- jax-jax-v0.4.25/tests/logging_test.py 2024-02-24 19:25:17.000000000 +0100 ++++ jax-jax-v0.4.25_fix_env_test_no_log_spam/tests/logging_test.py 2024-03-15 12:00:34.133022613 +0100 +@@ -72,8 +72,11 @@ + python = sys.executable + assert "python" in python + # Make sure C++ logging is at default level for the test process. ++ import os ++ tmp_env=os.environ.copy() ++ tmp_env['TF_CPP_MIN_LOG_LEVEL']=1 + proc = subprocess.run([python, "-c", program], capture_output=True, +- env={"TF_CPP_MIN_LOG_LEVEL": "1"}) ++ env=tmp_env) + + lines = proc.stdout.split(b"\n") + lines.extend(proc.stderr.split(b"\n")) From 021b248b6970c4c259dd06d1214f0315d5e7ace6 Mon Sep 17 00:00:00 2001 From: thoffman Date: Fri, 15 Mar 2024 16:26:04 +0100 Subject: [PATCH 04/21] update patch --- .../easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb | 2 +- .../easyconfigs/j/jax/jax-0.4.25_fix_env_test_no_log_spam.patch | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb index b91854890a8..6a2cd618907 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb @@ -108,7 +108,7 @@ exts_list = [ 'checksums': [ {'jax-v0.4.25.tar.gz': '8b30af49688c0c13b82c6f5ce992727c00b5fc6d04a4c6962012f4246fa664eb'}, {'jax-0.4.25_fix_env_test_no_log_spam.patch': - 'a5d4493c69814833615a2914f8481738fa438bacba9a1852601f7d19506371b9'}, + 'a18b5f147569d9ad41025124333a0f04fd0d0e0f9e4309658d7f6b9b838e2e2a'}, ], }), ] diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25_fix_env_test_no_log_spam.patch b/easybuild/easyconfigs/j/jax/jax-0.4.25_fix_env_test_no_log_spam.patch index e679bb8e274..ad919608437 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25_fix_env_test_no_log_spam.patch +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25_fix_env_test_no_log_spam.patch @@ -9,7 +9,7 @@ diff -ru jax-jax-v0.4.25/tests/logging_test.py jax-jax-v0.4.25_fix_env_test_no_l # Make sure C++ logging is at default level for the test process. + import os + tmp_env=os.environ.copy() -+ tmp_env['TF_CPP_MIN_LOG_LEVEL']=1 ++ tmp_env['TF_CPP_MIN_LOG_LEVEL']='1' proc = subprocess.run([python, "-c", program], capture_output=True, - env={"TF_CPP_MIN_LOG_LEVEL": "1"}) + env=tmp_env) From a90bc884cb81ec2c2840119323735cfee6875578 Mon Sep 17 00:00:00 2001 From: thoffman Date: Tue, 19 Mar 2024 12:41:56 +0100 Subject: [PATCH 05/21] isolate some tests --- .../jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb index 6a2cd618907..8d78d89a36d 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb @@ -42,15 +42,31 @@ local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' local_repo_opt = '--bazel_options="--override_repository=xla=%%(builddir)s/xla-%s" ' % local_xla_commit local_repo_opt += '--bazel_options="--override_repository=runtime=%%(builddir)s/tf_runtime-%s" ' % local_xla_commit + +# Some tests require an isolated run: +local_isolated_tests = [ + 'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp', + 'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc', + 'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' + + '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32', +] # deliberately not testing in parallel, as that results in (additional) failing tests; # use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing, # see https://github.com/google/jax/issues/7323 and # https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst; # use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs; # use NVIDIA_TF32_OVERRIDE=0 to avoid lossing numerical precision by disabling TF32 Tensor Cores; -local_test = "NVIDIA_TF32_OVERRIDE=0 CUDA_VISIBLE_DEVICES=0 " -local_test += "XLA_PYTHON_CLIENT_ALLOCATOR=platform " -local_test += "JAX_ENABLE_X64=true pytest -vv tests " +local_test_exports = [ + "NVIDIA_TF32_OVERRIDE=0", + "CUDA_VISIBLE_DEVICES=0", + "XLA_PYTHON_CLIENT_ALLOCATOR=platform", + "JAX_ENABLE_X64=true", +] +local_test = ''.join(['export %s;' % x for x in local_test_exports]) +# run all tests at once except for local_isolated_tests: +local_test += "pytest -vv tests %s &&" % ' '.join(['--deselect %s' % x for x in local_isolated_tests]) +# run remaining local_isolated_tests separately: +local_test += '&&'.join(['pytest -vv %s' % x for x in local_isolated_tests]) use_pip = True From 8c5d1f3195b536e5bd262fa0cdb7cb510cc0d050 Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Sat, 23 Mar 2024 09:39:04 +0100 Subject: [PATCH 06/21] Update easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb Co-authored-by: Alexander Grund --- .../easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb index 8d78d89a36d..9b36e57d663 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb @@ -111,7 +111,8 @@ components = [ 'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'}, ], 'start_dir': 'jax-jaxlib-v%(version)s', - 'buildopts': local_repo_opt + # Avoid warning (treated as error) in upb/table.c + 'buildopts': local_repo_opt + ' --bazel_options="--copt=-Wno-maybe-uninitialized"' }), ] From 1839f6d53b4c351b982397f7f18f15eaf0d7ec36 Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Wed, 27 Mar 2024 14:25:50 +0100 Subject: [PATCH 07/21] Update easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb Co-authored-by: Alexander Grund --- .../easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb index 9b36e57d663..75762ea7d12 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb @@ -64,9 +64,9 @@ local_test_exports = [ ] local_test = ''.join(['export %s;' % x for x in local_test_exports]) # run all tests at once except for local_isolated_tests: -local_test += "pytest -vv tests %s &&" % ' '.join(['--deselect %s' % x for x in local_isolated_tests]) +local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests]) # run remaining local_isolated_tests separately: -local_test += '&&'.join(['pytest -vv %s' % x for x in local_isolated_tests]) +local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests]) use_pip = True From d80e8fed9c77ab63727f9357cdb135f474402b98 Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Wed, 27 Mar 2024 14:32:45 +0100 Subject: [PATCH 08/21] Update and rename jax-0.4.25-foss-2023a-CUDA-12.1.1.eb to jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb typo, zlib dependency order, foss->gfbf --- ...-CUDA-12.1.1.eb => jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb} | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) rename easybuild/easyconfigs/j/jax/{jax-0.4.25-foss-2023a-CUDA-12.1.1.eb => jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb} (97%) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb similarity index 97% rename from easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb rename to easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb index 75762ea7d12..551ff2a73a3 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-foss-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb @@ -12,7 +12,7 @@ homepage = 'https://pypi.python.org/pypi/jax' description = """Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more""" -toolchain = {'name': 'foss', 'version': '2023a'} +toolchain = {'name': 'gfbf', 'version': '2023a'} cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"] builddependencies = [ @@ -28,10 +28,10 @@ dependencies = [ ('CUDA', '12.1.1', '', SYSTEM), ('cuDNN', '8.9.2.26', versionsuffix, SYSTEM), ('NCCL', '2.18.3', versionsuffix), + ('zlib', '1.2.13'), ('Python', '3.11.3'), ('SciPy-bundle', '2023.07'), ('flatbuffers-python', '23.5.26'), - ('zlib', '1.2.13'), ('ml_dtypes', '0.3.2'), ] @@ -55,7 +55,7 @@ local_isolated_tests = [ # see https://github.com/google/jax/issues/7323 and # https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst; # use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs; -# use NVIDIA_TF32_OVERRIDE=0 to avoid lossing numerical precision by disabling TF32 Tensor Cores; +# use NVIDIA_TF32_OVERRIDE=0 to avoid loosing numerical precision by disabling TF32 Tensor Cores; local_test_exports = [ "NVIDIA_TF32_OVERRIDE=0", "CUDA_VISIBLE_DEVICES=0", From a025bfbf67a4be17a38b51a7287e337df5d58482 Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Wed, 27 Mar 2024 14:34:01 +0100 Subject: [PATCH 09/21] Update and rename ml_dtypes-0.3.2-foss-2023a.eb to ml_dtypes-0.3.2-gfbf-2023a.eb foss->gfbf --- ...dtypes-0.3.2-foss-2023a.eb => ml_dtypes-0.3.2-gfbf-2023a.eb} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename easybuild/easyconfigs/m/ml_dtypes/{ml_dtypes-0.3.2-foss-2023a.eb => ml_dtypes-0.3.2-gfbf-2023a.eb} (96%) diff --git a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-gfbf-2023a.eb similarity index 96% rename from easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb rename to easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-gfbf-2023a.eb index df6bd3d5134..9c3a18bfdb5 100644 --- a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-foss-2023a.eb +++ b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-gfbf-2023a.eb @@ -18,7 +18,7 @@ float8_e5m2 float8_e5m2fnuz """ -toolchain = {'name': 'foss', 'version': '2023a'} +toolchain = {'name': 'gfbf', 'version': '2023a'} dependencies = [ ('Python', '3.11.3'), From 9291f70d116b5c287df260941ce0bb39ffe20751 Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Mon, 8 Apr 2024 17:32:19 +0200 Subject: [PATCH 10/21] Update jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb add tests/lax_numpy_test.py::NumpyUfuncTests::testUfuncInputTypes763 to isolated tests --- easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb | 1 + 1 file changed, 1 insertion(+) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb index 551ff2a73a3..6b5b930524e 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb @@ -49,6 +49,7 @@ local_isolated_tests = [ 'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc', 'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' + '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32', + 'tests/lax_numpy_test.py::NumpyUfuncTests::testUfuncInputTypes763' ] # deliberately not testing in parallel, as that results in (additional) failing tests; # use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing, From 07bb12db55c15482dc123803270ddf24a51d27ff Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Thu, 11 Apr 2024 15:41:26 +0200 Subject: [PATCH 11/21] Update easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb Co-authored-by: Alexander Grund --- .../easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb | 2 ++ 1 file changed, 2 insertions(+) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb index 6b5b930524e..23d51500542 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb @@ -41,6 +41,8 @@ local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4' local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' local_repo_opt = '--bazel_options="--override_repository=xla=%%(builddir)s/xla-%s" ' % local_xla_commit local_repo_opt += '--bazel_options="--override_repository=runtime=%%(builddir)s/tf_runtime-%s" ' % local_xla_commit +local_repo_opt += '-bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" ' +local_repo_opt += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include"' # Some tests require an isolated run: From 7809bd3377fe8d875d0983f77015191ea6c8425e Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Thu, 11 Apr 2024 18:16:42 +0200 Subject: [PATCH 12/21] Fix usage of system Pybind11 --- .../j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb | 7 ++++--- .../jax-0.4.25_fix-pybind11-systemlib.patch | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 easybuild/easyconfigs/j/jax/jax-0.4.25_fix-pybind11-systemlib.patch diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb index 23d51500542..04db9ad33e1 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb @@ -41,8 +41,8 @@ local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4' local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' local_repo_opt = '--bazel_options="--override_repository=xla=%%(builddir)s/xla-%s" ' % local_xla_commit local_repo_opt += '--bazel_options="--override_repository=runtime=%%(builddir)s/tf_runtime-%s" ' % local_xla_commit -local_repo_opt += '-bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" ' -local_repo_opt += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include"' +local_repo_opt += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" ' +local_repo_opt += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" ' # Some tests require an isolated run: @@ -51,7 +51,6 @@ local_isolated_tests = [ 'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc', 'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' + '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32', - 'tests/lax_numpy_test.py::NumpyUfuncTests::testUfuncInputTypes763' ] # deliberately not testing in parallel, as that results in (additional) failing tests; # use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing, @@ -105,6 +104,7 @@ components = [ 'https://github.com/tensorflow/runtime/archive', 'https://github.com/openxla/xla/archive' ], + 'patches': [('jax-0.4.25_fix-pybind11-systemlib.patch', '../xla-' + local_xla_commit)], 'checksums': [ {'jaxlib-v0.4.25.tar.gz': 'fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8'}, @@ -112,6 +112,7 @@ components = [ '8a59b9af7d0850059d7043f7043c780066d61538f3af536e8a10d3d717f35089'}, {'tf_runtime-0aeefb1660d7e37964b2bb71b1f518096bda9a25.tar.gz': 'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'}, + {'jax-0.4.25_fix-pybind11-systemlib.patch': '4cdc97ce05b708b16e161082428e5413c89c2852edb4262cd19d16618ddad9b6'}, ], 'start_dir': 'jax-jaxlib-v%(version)s', # Avoid warning (treated as error) in upb/table.c diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25_fix-pybind11-systemlib.patch b/easybuild/easyconfigs/j/jax/jax-0.4.25_fix-pybind11-systemlib.patch new file mode 100644 index 00000000000..620acb11d0a --- /dev/null +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25_fix-pybind11-systemlib.patch @@ -0,0 +1,18 @@ +Add missing value for System Pybind11 Bazel config + +Author: Alexander Grund (TU Dresden) + +--- xla-orig/third_party/tsl/third_party/systemlibs/pybind11.BUILD ++++ xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4/third_party/tsl/third_party/systemlibs/pybind11.BUILD +@@ -6,3 +6,10 @@ + "@tsl//third_party/python_runtime:headers", + ], + ) ++ ++# Needed by pybind11_bazel. ++config_setting( ++ name = "osx", ++ constraint_values = ["@platforms//os:osx"], ++) ++ + From 65e87a6dd94b4bf1b14b3a1479db9846ad93dd40 Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Fri, 12 Apr 2024 11:53:34 +0200 Subject: [PATCH 13/21] Update jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb fix checksum --- .../easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb index 04db9ad33e1..af2e1b2ec5c 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb @@ -112,7 +112,7 @@ components = [ '8a59b9af7d0850059d7043f7043c780066d61538f3af536e8a10d3d717f35089'}, {'tf_runtime-0aeefb1660d7e37964b2bb71b1f518096bda9a25.tar.gz': 'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'}, - {'jax-0.4.25_fix-pybind11-systemlib.patch': '4cdc97ce05b708b16e161082428e5413c89c2852edb4262cd19d16618ddad9b6'}, + {'jax-0.4.25_fix-pybind11-systemlib.patch': 'ec93de5628e4d40d3378b92784f7d1e5b0b43bd207a86badeffd44a42e0b1d47'}, ], 'start_dir': 'jax-jaxlib-v%(version)s', # Avoid warning (treated as error) in upb/table.c From c5f0711b146fa13de8768ef07a27e915be23b1ef Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Fri, 12 Apr 2024 12:18:50 +0200 Subject: [PATCH 14/21] Update jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb fix style; add PyBind11 builddep --- .../easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb index af2e1b2ec5c..d7ed71d94d6 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb @@ -22,6 +22,7 @@ builddependencies = [ ('git', '2.41.0', '-nodocs'), ('matplotlib', '3.7.2'), # required for tests/lobpcg_test.py ('poetry', '1.5.1'), + ('pybind11', '2.11.1'), ] dependencies = [ @@ -112,7 +113,8 @@ components = [ '8a59b9af7d0850059d7043f7043c780066d61538f3af536e8a10d3d717f35089'}, {'tf_runtime-0aeefb1660d7e37964b2bb71b1f518096bda9a25.tar.gz': 'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'}, - {'jax-0.4.25_fix-pybind11-systemlib.patch': 'ec93de5628e4d40d3378b92784f7d1e5b0b43bd207a86badeffd44a42e0b1d47'}, + {'jax-0.4.25_fix-pybind11-systemlib.patch': + 'ec93de5628e4d40d3378b92784f7d1e5b0b43bd207a86badeffd44a42e0b1d47'}, ], 'start_dir': 'jax-jaxlib-v%(version)s', # Avoid warning (treated as error) in upb/table.c From cf6f043765f231e0a8bfb62b306f74259c5e60ae Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Mon, 29 Apr 2024 12:59:07 +0200 Subject: [PATCH 15/21] Update easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb Co-authored-by: Alexander Grund --- .../easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb index d7ed71d94d6..fcc807a7055 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb @@ -41,7 +41,7 @@ dependencies = [ local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4' local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' local_repo_opt = '--bazel_options="--override_repository=xla=%%(builddir)s/xla-%s" ' % local_xla_commit -local_repo_opt += '--bazel_options="--override_repository=runtime=%%(builddir)s/tf_runtime-%s" ' % local_xla_commit +local_repo_opt += '--bazel_options="--override_repository=tf_runtime=%%(builddir)s/runtime-%s" ' % local_tfrt_commit local_repo_opt += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" ' local_repo_opt += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" ' From 2e9329d58d1b2b309400c03612b918787540ceb4 Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Wed, 8 May 2024 16:57:35 +0200 Subject: [PATCH 16/21] Use Bazel --distdir Don't unpack the archives just put them into a folder Allows to verify that we used the right ones as Bazel checks the name and checksum. --- .../jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb | 11 +++-- .../jax-0.4.25_fix-pybind11-systemlib.patch | 46 +++++++++++++------ 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb index fcc807a7055..68ab99486f2 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb @@ -40,8 +40,9 @@ dependencies = [ # note: this *must* be the exact same commit as used in third_party/{xla,"other"}/workspace.bzl local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4' local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' -local_repo_opt = '--bazel_options="--override_repository=xla=%%(builddir)s/xla-%s" ' % local_xla_commit -local_repo_opt += '--bazel_options="--override_repository=tf_runtime=%%(builddir)s/runtime-%s" ' % local_tfrt_commit + +local_extract_cmd = 'cp %s %(builddir)s/archives' +local_repo_opt = '--bazel_options="--distdir=%(builddir)s/archives" ' local_repo_opt += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" ' local_repo_opt += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" ' @@ -94,10 +95,12 @@ components = [ { 'download_filename': '%s.tar.gz' % local_xla_commit, 'filename': 'xla-%s.tar.gz' % local_xla_commit, + 'extract_cmd': local_extract_cmd, }, { 'download_filename': '%s.tar.gz' % local_tfrt_commit, 'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit, + 'extract_cmd': local_extract_cmd, }, ], 'source_urls': [ @@ -105,7 +108,7 @@ components = [ 'https://github.com/tensorflow/runtime/archive', 'https://github.com/openxla/xla/archive' ], - 'patches': [('jax-0.4.25_fix-pybind11-systemlib.patch', '../xla-' + local_xla_commit)], + 'patches': ['jax-0.4.25_fix-pybind11-systemlib.patch'], 'checksums': [ {'jaxlib-v0.4.25.tar.gz': 'fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8'}, @@ -114,7 +117,7 @@ components = [ {'tf_runtime-0aeefb1660d7e37964b2bb71b1f518096bda9a25.tar.gz': 'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'}, {'jax-0.4.25_fix-pybind11-systemlib.patch': - 'ec93de5628e4d40d3378b92784f7d1e5b0b43bd207a86badeffd44a42e0b1d47'}, + 'daad5b726d1a138431b05eb60ecf4c89c7b5148eb939721800bdf43d804ca033'}, ], 'start_dir': 'jax-jaxlib-v%(version)s', # Avoid warning (treated as error) in upb/table.c diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25_fix-pybind11-systemlib.patch b/easybuild/easyconfigs/j/jax/jax-0.4.25_fix-pybind11-systemlib.patch index 620acb11d0a..c404ee6917f 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25_fix-pybind11-systemlib.patch +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25_fix-pybind11-systemlib.patch @@ -2,17 +2,37 @@ Add missing value for System Pybind11 Bazel config Author: Alexander Grund (TU Dresden) ---- xla-orig/third_party/tsl/third_party/systemlibs/pybind11.BUILD -+++ xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4/third_party/tsl/third_party/systemlibs/pybind11.BUILD -@@ -6,3 +6,10 @@ - "@tsl//third_party/python_runtime:headers", - ], - ) -+ -+# Needed by pybind11_bazel. -+config_setting( -+ name = "osx", -+ constraint_values = ["@platforms//os:osx"], -+) -+ +diff --git a/third_party/xla/fix-pybind11-systemlib.patch b/third_party/xla/fix-pybind11-systemlib.patch +new file mode 100644 +index 000000000..68bd2063d +--- /dev/null ++++ b/third_party/xla/fix-pybind11-systemlib.patch +@@ -0,0 +1,13 @@ ++--- xla-orig/third_party/tsl/third_party/systemlibs/pybind11.BUILD +++++ xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4/third_party/tsl/third_party/systemlibs/pybind11.BUILD ++@@ -6,3 +6,10 @@ ++ "@tsl//third_party/python_runtime:headers", ++ ], ++ ) +++ +++# Needed by pybind11_bazel. +++config_setting( +++ name = "osx", +++ constraint_values = ["@platforms//os:osx"], +++) +++ +diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl +index ebc8d9838..125e1c173 100644 +--- a/third_party/xla/workspace.bzl ++++ b/third_party/xla/workspace.bzl +@@ -29,6 +29,9 @@ def repo(): + sha256 = XLA_SHA256, + strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), + urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), ++ patch_file = [ ++ "//third_party/xla:fix-pybind11-systemlib.patch", ++ ], + ) + + # For development, one often wants to make changes to the TF repository as well From a643eebd36dfc0e40282c56bb1f4cb1fd1126e1d Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Wed, 22 May 2024 13:49:20 +0200 Subject: [PATCH 17/21] Update easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb Co-authored-by: Kenneth Hoste --- .../easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb index 68ab99486f2..33b8d5546fa 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb @@ -41,7 +41,7 @@ dependencies = [ local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4' local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' -local_extract_cmd = 'cp %s %(builddir)s/archives' +local_extract_cmd = 'cp %s %%(builddir)s/archives' local_repo_opt = '--bazel_options="--distdir=%(builddir)s/archives" ' local_repo_opt += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" ' local_repo_opt += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" ' From a045edf6a863721ff8d9cc7e2a0b7db0326dd015 Mon Sep 17 00:00:00 2001 From: Kenneth Hoste Date: Thu, 30 May 2024 14:33:50 +0200 Subject: [PATCH 18/21] revert workaround for framework bug in extract command in easyconfig for jax 0.4.25 --- .../easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb index 33b8d5546fa..68ab99486f2 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb @@ -41,7 +41,7 @@ dependencies = [ local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4' local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' -local_extract_cmd = 'cp %s %%(builddir)s/archives' +local_extract_cmd = 'cp %s %(builddir)s/archives' local_repo_opt = '--bazel_options="--distdir=%(builddir)s/archives" ' local_repo_opt += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" ' local_repo_opt += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" ' From 2a1b972fdf200e6c2b9ba86f579e8c42327c438e Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Mon, 3 Jun 2024 17:36:42 +0200 Subject: [PATCH 19/21] Delete easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch --- .../ml_dtypes-0.3.2_EigenAvx512.patch | 1219 ----------------- 1 file changed, 1219 deletions(-) delete mode 100644 easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch diff --git a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch deleted file mode 100644 index 42ea0606391..00000000000 --- a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2_EigenAvx512.patch +++ /dev/null @@ -1,1219 +0,0 @@ -# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2024/01 -# ml_dtype 0.3.2 ships a copy of Eigen commit 7bf2968 (https://gitlab.com/libeigen/eigen/-/commit/7bf2968). -# This copy is missing the file src/Core/arch/AVX512/TrsmUnrolls.inc, which is added by the present patch. -diff -ru --new-file old/third_party_ori/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc new/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc ---- old/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc 1970-01-01 01:00:00.000000000 +0100 -+++ new/third_party/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc 2024-02-14 10:32:25.492978066 +0100 -@@ -0,0 +1,1212 @@ -+// This file is part of Eigen, a lightweight C++ template library -+// for linear algebra. -+// -+// Copyright (C) 2022 Intel Corporation -+// -+// This Source Code Form is subject to the terms of the Mozilla -+// Public License v. 2.0. If a copy of the MPL was not distributed -+// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. -+ -+#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H -+#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H -+ -+template -+EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) { -+ EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j; -+ else return i + j * LDA; -+} -+ -+/** -+ * This namespace contains various classes used to generate compile-time unrolls which are -+ * used throughout the trsm/gemm kernels. The unrolls are characterized as for-loops (1-D), nested -+ * for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion -+ * -+ * Example, the 2-D for-loop is unrolled recursively by first flattening to a 1-D loop. -+ * -+ * for(startI = 0; startI < endI; startI++) for(startC = 0; startC < endI*endJ; startC++) -+ * for(startJ = 0; startJ < endJ; startJ++) ----> startI = (startC)/(endJ) -+ * func(startI,startJ) startJ = (startC)%(endJ) -+ * func(...) -+ * -+ * The 1-D loop can be unrolled recursively by using enable_if and defining an auxillary function -+ * with a template parameter used as a counter. -+ * -+ * template -+ * std::enable_if_t<(counter <= 0)> <---- tail case. -+ * aux_func {} -+ * -+ * template -+ * std::enable_if_t<(counter > 0)> <---- actual for-loop -+ * aux_func { -+ * startC = endI*endJ - counter -+ * startI = (startC)/(endJ) -+ * startJ = (startC)%(endJ) -+ * func(startI, startJ) -+ * aux_func() -+ * } -+ * -+ * Note: Additional wrapper functions are provided for aux_func which hides the counter template -+ * parameter since counter usually depends on endI, endJ, etc... -+ * -+ * Conventions: -+ * 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++)) -+ * -+ * 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for -+ * handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW) -+ */ -+namespace unrolls { -+ -+template -+EIGEN_ALWAYS_INLINE auto remMask(int64_t m) { -+ EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); } -+ else EIGEN_IF_CONSTEXPR(N == 8) { -+ return 0xFF >> (8 - m); -+ } -+ else EIGEN_IF_CONSTEXPR(N == 4) { -+ return 0x0F >> (4 - m); -+ } -+ return 0; -+} -+ -+template -+EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel); -+ -+template <> -+EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel) { -+ __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]); -+ __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]); -+ __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]); -+ __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]); -+ __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]); -+ __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]); -+ __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]); -+ __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]); -+ -+ kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); -+ kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2))); -+ kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); -+ kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3))); -+ kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6))); -+ kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6))); -+ kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); -+ kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7))); -+ -+ T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E)); -+ T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0); -+ T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E)); -+ T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]); -+ T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E)); -+ T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1); -+ T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E)); -+ T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]); -+ T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E)); -+ T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2); -+ T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E)); -+ T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]); -+ T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E)); -+ T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3); -+ T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E)); -+ T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]); -+ -+ kernel.packet[0] = T0; -+ kernel.packet[1] = T1; -+ kernel.packet[2] = T2; -+ kernel.packet[3] = T3; -+ kernel.packet[4] = T4; -+ kernel.packet[5] = T5; -+ kernel.packet[6] = T6; -+ kernel.packet[7] = T7; -+} -+ -+template <> -+EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock &kernel) { -+ ptranspose(kernel); -+} -+ -+/*** -+ * Unrolls for tranposed C stores -+ */ -+template -+class trans { -+ public: -+ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; -+ using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; -+ static constexpr int64_t PacketSize = packet_traits::size; -+ -+ /*********************************** -+ * Auxillary Functions for: -+ * - storeC -+ *********************************** -+ */ -+ -+ /** -+ * aux_storeC -+ * -+ * 1-D unroll -+ * for(startN = 0; startN < endN; startN++) -+ * -+ * (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC -+ * -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC( -+ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { -+ constexpr int64_t counterReverse = endN - counter; -+ constexpr int64_t startN = counterReverse; -+ -+ EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) { -+ EIGEN_IF_CONSTEXPR(remM) { -+ pstoreu( -+ C_arr + LDC * startN, -+ padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), -+ preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]), -+ remMask(remM_)), -+ remMask(remM_)); -+ } -+ else { -+ pstoreu(C_arr + LDC * startN, -+ padd(ploadu((const Scalar *)C_arr + LDC * startN), -+ preinterpret(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]))); -+ } -+ } -+ else { // This block is only needed for fp32 case -+ // Reinterpret as __m512 for _mm512_shuffle_f32x4 -+ vecFullFloat zmm2vecFullFloat = preinterpret( -+ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]); -+ // Swap lower and upper half of avx register. -+ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] = -+ preinterpret(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110)); -+ -+ EIGEN_IF_CONSTEXPR(remM) { -+ pstoreu( -+ C_arr + LDC * startN, -+ padd(ploadu((const Scalar *)C_arr + LDC * startN, remMask(remM_)), -+ preinterpret( -+ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])), -+ remMask(remM_)); -+ } -+ else { -+ pstoreu( -+ C_arr + LDC * startN, -+ padd(ploadu((const Scalar *)C_arr + LDC * startN), -+ preinterpret( -+ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]))); -+ } -+ } -+ aux_storeC(C_arr, LDC, zmm, remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && endN <= PacketSize)> aux_storeC( -+ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t remM_ = 0) { -+ EIGEN_UNUSED_VARIABLE(C_arr); -+ EIGEN_UNUSED_VARIABLE(LDC); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ EIGEN_UNUSED_VARIABLE(remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC, -+ PacketBlock &zmm, -+ int64_t remM_ = 0) { -+ aux_storeC(C_arr, LDC, zmm, remM_); -+ } -+ -+ /** -+ * Transposes LxunrollN row major block of matrices stored EIGEN_AVX_MAX_NUM_ACC zmm registers to -+ * "unrollN"xL ymm registers to be stored col-major into C. -+ * -+ * For 8x48, the 8x48 block (row-major) is stored in zmm as follows: -+ * -+ * row0: zmm0 zmm1 zmm2 -+ * row1: zmm3 zmm4 zmm5 -+ * . -+ * . -+ * row7: zmm21 zmm22 zmm23 -+ * -+ * For 8x32, the 8x32 block (row-major) is stored in zmm as follows: -+ * -+ * row0: zmm0 zmm1 -+ * row1: zmm2 zmm3 -+ * . -+ * . -+ * row7: zmm14 zmm15 -+ * -+ * -+ * In general we will have {1,2,3} groups of avx registers each of size -+ * EIGEN_AVX_MAX_NUM_ROW. packetIndexOffset is used to select which "block" of -+ * avx registers are being transposed. -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void transpose(PacketBlock &zmm) { -+ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted -+ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. -+ constexpr int64_t zmmStride = unrollN / PacketSize; -+ PacketBlock r; -+ r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0]; -+ r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1]; -+ r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2]; -+ r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3]; -+ r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4]; -+ r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5]; -+ r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6]; -+ r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7]; -+ trans8x8blocks(r); -+ zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0]; -+ zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1]; -+ zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2]; -+ zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3]; -+ zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4]; -+ zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5]; -+ zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6]; -+ zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7]; -+ } -+}; -+ -+/** -+ * Unrolls for copyBToRowMajor -+ * -+ * Idea: -+ * 1) Load a block of right-hand sides to registers (using loadB). -+ * 2) Convert the block from column-major to row-major (transposeLxL) -+ * 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false). -+ * -+ * We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are -+ * used as temps for transposing. -+ * -+ * Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks -+ * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm). -+ */ -+template -+class transB { -+ public: -+ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; -+ using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; -+ static constexpr int64_t PacketSize = packet_traits::size; -+ -+ /*********************************** -+ * Auxillary Functions for: -+ * - loadB -+ * - storeB -+ * - loadBBlock -+ * - storeBBlock -+ *********************************** -+ */ -+ -+ /** -+ * aux_loadB -+ * -+ * 1-D unroll -+ * for(startN = 0; startN < endN; startN++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( -+ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, -+ int64_t remM_ = 0) { -+ constexpr int64_t counterReverse = endN - counter; -+ constexpr int64_t startN = counterReverse; -+ -+ EIGEN_IF_CONSTEXPR(remM) { -+ ymm.packet[packetIndexOffset + startN] = -+ ploadu((const Scalar *)&B_arr[startN * LDB], remMask(remM_)); -+ } -+ else ymm.packet[packetIndexOffset + startN] = ploadu((const Scalar *)&B_arr[startN * LDB]); -+ -+ aux_loadB(B_arr, LDB, ymm, remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( -+ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, -+ int64_t remM_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_arr); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(ymm); -+ EIGEN_UNUSED_VARIABLE(remM_); -+ } -+ -+ /** -+ * aux_storeB -+ * -+ * 1-D unroll -+ * for(startN = 0; startN < endN; startN++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB( -+ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { -+ constexpr int64_t counterReverse = endN - counter; -+ constexpr int64_t startN = counterReverse; -+ -+ EIGEN_IF_CONSTEXPR(remK || remM) { -+ pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN], -+ remMask(rem_)); -+ } -+ else { -+ pstoreu(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]); -+ } -+ -+ aux_storeB(B_arr, LDB, ymm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB( -+ Scalar *B_arr, int64_t LDB, PacketBlock &ymm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_arr); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(ymm); -+ EIGEN_UNUSED_VARIABLE(rem_); -+ } -+ -+ /** -+ * aux_loadBBlock -+ * -+ * 1-D unroll -+ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock( -+ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, -+ PacketBlock &ymm, int64_t remM_ = 0) { -+ constexpr int64_t counterReverse = endN - counter; -+ constexpr int64_t startN = counterReverse; -+ transB::template loadB(&B_temp[startN], LDB_, ymm); -+ aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock( -+ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, -+ PacketBlock &ymm, int64_t remM_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_arr); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(B_temp); -+ EIGEN_UNUSED_VARIABLE(LDB_); -+ EIGEN_UNUSED_VARIABLE(ymm); -+ EIGEN_UNUSED_VARIABLE(remM_); -+ } -+ -+ /** -+ * aux_storeBBlock -+ * -+ * 1-D unroll -+ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock( -+ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, -+ PacketBlock &ymm, int64_t remM_ = 0) { -+ constexpr int64_t counterReverse = endN - counter; -+ constexpr int64_t startN = counterReverse; -+ -+ EIGEN_IF_CONSTEXPR(toTemp) { -+ transB::template storeB(&B_temp[startN], LDB_, ymm, remK_); -+ } -+ else { -+ transB::template storeB(&B_arr[0 + startN * LDB], LDB, -+ ymm, remM_); -+ } -+ aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock( -+ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, -+ PacketBlock &ymm, int64_t remM_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_arr); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(B_temp); -+ EIGEN_UNUSED_VARIABLE(LDB_); -+ EIGEN_UNUSED_VARIABLE(ymm); -+ EIGEN_UNUSED_VARIABLE(remM_); -+ } -+ -+ /******************************************************** -+ * Wrappers for aux_XXXX to hide counter parameter -+ ********************************************************/ -+ -+ template -+ static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB, -+ PacketBlock &ymm, -+ int64_t remM_ = 0) { -+ aux_loadB(B_arr, LDB, ymm, remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB, -+ PacketBlock &ymm, -+ int64_t rem_ = 0) { -+ aux_storeB(B_arr, LDB, ymm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, -+ PacketBlock &ymm, -+ int64_t remM_ = 0) { -+ EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB(&B_arr[0], LDB, ymm, remM_); } -+ else { -+ aux_loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, -+ PacketBlock &ymm, -+ int64_t remM_ = 0) { -+ aux_storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock &ymm) { -+ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted -+ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. -+ PacketBlock r; -+ r.packet[0] = ymm.packet[packetIndexOffset + 0]; -+ r.packet[1] = ymm.packet[packetIndexOffset + 1]; -+ r.packet[2] = ymm.packet[packetIndexOffset + 2]; -+ r.packet[3] = ymm.packet[packetIndexOffset + 3]; -+ r.packet[4] = ymm.packet[packetIndexOffset + 4]; -+ r.packet[5] = ymm.packet[packetIndexOffset + 5]; -+ r.packet[6] = ymm.packet[packetIndexOffset + 6]; -+ r.packet[7] = ymm.packet[packetIndexOffset + 7]; -+ ptranspose(r); -+ ymm.packet[packetIndexOffset + 0] = r.packet[0]; -+ ymm.packet[packetIndexOffset + 1] = r.packet[1]; -+ ymm.packet[packetIndexOffset + 2] = r.packet[2]; -+ ymm.packet[packetIndexOffset + 3] = r.packet[3]; -+ ymm.packet[packetIndexOffset + 4] = r.packet[4]; -+ ymm.packet[packetIndexOffset + 5] = r.packet[5]; -+ ymm.packet[packetIndexOffset + 6] = r.packet[6]; -+ ymm.packet[packetIndexOffset + 7] = r.packet[7]; -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, -+ PacketBlock &ymm, -+ int64_t remM_ = 0) { -+ constexpr int64_t U3 = PacketSize * 3; -+ constexpr int64_t U2 = PacketSize * 2; -+ constexpr int64_t U1 = PacketSize * 1; -+ /** -+ * Unrolls needed for each case: -+ * - AVX512 fp32 48 32 16 8 4 2 1 -+ * - AVX512 fp64 24 16 8 4 2 1 -+ * -+ * For fp32 L and U1 are 1:2 so for U3/U2 cases the loads/stores need to be split up. -+ */ -+ EIGEN_IF_CONSTEXPR(unrollN == U3) { -+ // load LxU3 B col major, transpose LxU3 row major -+ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3); -+ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ -+ EIGEN_IF_CONSTEXPR(maxUBlock < U3) { -+ transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, -+ ymm, remM_); -+ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_, -+ ymm, remM_); -+ } -+ } -+ else EIGEN_IF_CONSTEXPR(unrollN == U2) { -+ // load LxU2 B col major, transpose LxU2 row major -+ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2); -+ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm); -+ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ -+ EIGEN_IF_CONSTEXPR(maxUBlock < U2) { -+ transB::template loadBBlock(&B_arr[maxUBlock * LDB], LDB, -+ &B_temp[maxUBlock], LDB_, ymm, remM_); -+ transB::template transposeLxL<0>(ymm); -+ transB::template storeBBlock(&B_arr[maxUBlock * LDB], LDB, -+ &B_temp[maxUBlock], LDB_, ymm, remM_); -+ } -+ } -+ else EIGEN_IF_CONSTEXPR(unrollN == U1) { -+ // load LxU1 B col major, transpose LxU1 row major -+ transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0>(ymm); -+ EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); } -+ transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) { -+ // load Lx4 B col major, transpose Lx4 row major -+ transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0>(ymm); -+ transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) { -+ // load Lx4 B col major, transpose Lx4 row major -+ transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0>(ymm); -+ transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ else EIGEN_IF_CONSTEXPR(unrollN == 2) { -+ // load Lx2 B col major, transpose Lx2 row major -+ transB::template loadBBlock<2, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0>(ymm); -+ transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ else EIGEN_IF_CONSTEXPR(unrollN == 1) { -+ // load Lx1 B col major, transpose Lx1 row major -+ transB::template loadBBlock<1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ transB::template transposeLxL<0>(ymm); -+ transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_); -+ } -+ } -+}; -+ -+/** -+ * Unrolls for triSolveKernel -+ * -+ * Idea: -+ * 1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS). -+ * 2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix) -+ * stored in AInPacket (using triSolveMicroKernel). -+ * 3) Store final results (in avx registers) back into memory (using storeRHS). -+ * -+ * RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most -+ * EIGEN_AVX_MAX_NUM_ROW registers. -+ */ -+template -+class trsm { -+ public: -+ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; -+ static constexpr int64_t PacketSize = packet_traits::size; -+ -+ /*********************************** -+ * Auxillary Functions for: -+ * - loadRHS -+ * - storeRHS -+ * - divRHSByDiag -+ * - updateRHS -+ * - triSolveMicroKernel -+ ************************************/ -+ /** -+ * aux_loadRHS -+ * -+ * 2-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ * for(startK = 0; startK < endK; startK++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS( -+ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { -+ constexpr int64_t counterReverse = endM * endK - counter; -+ constexpr int64_t startM = counterReverse / (endK); -+ constexpr int64_t startK = counterReverse % endK; -+ -+ constexpr int64_t packetIndex = startM * endK + startK; -+ constexpr int64_t startM_ = isFWDSolve ? startM : -startM; -+ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; -+ EIGEN_IF_CONSTEXPR(krem) { -+ RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex], remMask(rem)); -+ } -+ else { -+ RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex]); -+ } -+ aux_loadRHS(B_arr, LDB, RHSInPacket, rem); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS( -+ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { -+ EIGEN_UNUSED_VARIABLE(B_arr); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(RHSInPacket); -+ EIGEN_UNUSED_VARIABLE(rem); -+ } -+ -+ /** -+ * aux_storeRHS -+ * -+ * 2-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ * for(startK = 0; startK < endK; startK++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS( -+ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { -+ constexpr int64_t counterReverse = endM * endK - counter; -+ constexpr int64_t startM = counterReverse / (endK); -+ constexpr int64_t startK = counterReverse % endK; -+ -+ constexpr int64_t packetIndex = startM * endK + startK; -+ constexpr int64_t startM_ = isFWDSolve ? startM : -startM; -+ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB; -+ EIGEN_IF_CONSTEXPR(krem) { -+ pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask(rem)); -+ } -+ else { -+ pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]); -+ } -+ aux_storeRHS(B_arr, LDB, RHSInPacket, rem); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS( -+ Scalar *B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { -+ EIGEN_UNUSED_VARIABLE(B_arr); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(RHSInPacket); -+ EIGEN_UNUSED_VARIABLE(rem); -+ } -+ -+ /** -+ * aux_divRHSByDiag -+ * -+ * currM may be -1, (currM >=0) in enable_if checks for this -+ * -+ * 1-D unroll -+ * for(startK = 0; startK < endK; startK++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag( -+ PacketBlock &RHSInPacket, PacketBlock &AInPacket) { -+ constexpr int64_t counterReverse = endK - counter; -+ constexpr int64_t startK = counterReverse; -+ -+ constexpr int64_t packetIndex = currM * endK + startK; -+ RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]); -+ aux_divRHSByDiag(RHSInPacket, AInPacket); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && currM >= 0)> aux_divRHSByDiag( -+ PacketBlock &RHSInPacket, PacketBlock &AInPacket) { -+ EIGEN_UNUSED_VARIABLE(RHSInPacket); -+ EIGEN_UNUSED_VARIABLE(AInPacket); -+ } -+ -+ /** -+ * aux_updateRHS -+ * -+ * 2-D unroll -+ * for(startM = initM; startM < endM; startM++) -+ * for(startK = 0; startK < endK; startK++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS( -+ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ constexpr int64_t counterReverse = (endM - initM) * endK - counter; -+ constexpr int64_t startM = initM + counterReverse / (endK); -+ constexpr int64_t startK = counterReverse % endK; -+ -+ // For each row of A, first update all corresponding RHS -+ constexpr int64_t packetIndex = startM * endK + startK; -+ EIGEN_IF_CONSTEXPR(currentM > 0) { -+ RHSInPacket.packet[packetIndex] = -+ pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK], -+ RHSInPacket.packet[packetIndex]); -+ } -+ -+ EIGEN_IF_CONSTEXPR(startK == endK - 1) { -+ // Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}. -+ EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) { -+ // If diagonal is not unit, we broadcast reciprocals of diagonals AinPacket.packet[currentM]. -+ // This will be used in divRHSByDiag -+ EIGEN_IF_CONSTEXPR(isFWDSolve) -+ AInPacket.packet[currentM] = pset1(Scalar(1) / A_arr[idA(currentM, currentM, LDA)]); -+ else AInPacket.packet[currentM] = pset1(Scalar(1) / A_arr[idA(-currentM, -currentM, LDA)]); -+ } -+ else { -+ // Broadcast next off diagonal element of A -+ EIGEN_IF_CONSTEXPR(isFWDSolve) -+ AInPacket.packet[startM] = pset1(A_arr[idA(startM, currentM, LDA)]); -+ else AInPacket.packet[startM] = pset1(A_arr[idA(-startM, -currentM, LDA)]); -+ } -+ } -+ -+ aux_updateRHS( -+ A_arr, LDA, RHSInPacket, AInPacket); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS( -+ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ EIGEN_UNUSED_VARIABLE(A_arr); -+ EIGEN_UNUSED_VARIABLE(LDA); -+ EIGEN_UNUSED_VARIABLE(RHSInPacket); -+ EIGEN_UNUSED_VARIABLE(AInPacket); -+ } -+ -+ /** -+ * aux_triSolverMicroKernel -+ * -+ * 1-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel( -+ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ constexpr int64_t counterReverse = endM - counter; -+ constexpr int64_t startM = counterReverse; -+ -+ constexpr int64_t currentM = startM; -+ // Divides the right-hand side in row startM, by digonal value of A -+ // broadcasted to AInPacket.packet[startM-1] in the previous iteration. -+ // -+ // Without "if constexpr" the compiler instantiates the case <-1, numK> -+ // this is handled with enable_if to prevent out-of-bound warnings -+ // from the compiler -+ EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0) -+ trsm::template divRHSByDiag(RHSInPacket, AInPacket); -+ -+ // After division, the rhs corresponding to subsequent rows of A can be partially updated -+ // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed) -+ // to be used in the next iteration. -+ trsm::template updateRHS(A_arr, LDA, RHSInPacket, -+ AInPacket); -+ -+ // Handle division for the RHS corresponding to the final row of A. -+ EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1) -+ trsm::template divRHSByDiag(RHSInPacket, AInPacket); -+ -+ aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, -+ AInPacket); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel( -+ Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ EIGEN_UNUSED_VARIABLE(A_arr); -+ EIGEN_UNUSED_VARIABLE(LDA); -+ EIGEN_UNUSED_VARIABLE(RHSInPacket); -+ EIGEN_UNUSED_VARIABLE(AInPacket); -+ } -+ -+ /******************************************************** -+ * Wrappers for aux_XXXX to hide counter parameter -+ ********************************************************/ -+ -+ /** -+ * Load endMxendK block of B to RHSInPacket -+ * Masked loads are used for cases where endK is not a multiple of PacketSize -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB, -+ PacketBlock &RHSInPacket, int64_t rem = 0) { -+ aux_loadRHS(B_arr, LDB, RHSInPacket, rem); -+ } -+ -+ /** -+ * Load endMxendK block of B to RHSInPacket -+ * Masked loads are used for cases where endK is not a multiple of PacketSize -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB, -+ PacketBlock &RHSInPacket, int64_t rem = 0) { -+ aux_storeRHS(B_arr, LDB, RHSInPacket, rem); -+ } -+ -+ /** -+ * Only used if Triangular matrix has non-unit diagonal values -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ aux_divRHSByDiag(RHSInPacket, AInPacket); -+ } -+ -+ /** -+ * Update right-hand sides (stored in avx registers) -+ * Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket. -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA, -+ PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ aux_updateRHS( -+ A_arr, LDA, RHSInPacket, AInPacket); -+ } -+ -+ /** -+ * endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW -+ * numK: number of avx registers to use for each row of B (ex fp32: 48 rhs => 3 avx reg used). 1 <= endK <= 3. -+ * isFWDSolve: true => forward substitution, false => backwards substitution -+ * isUnitDiag: true => triangular matrix has unit diagonal. -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA, -+ PacketBlock &RHSInPacket, -+ PacketBlock &AInPacket) { -+ static_assert(numK >= 1 && numK <= 3, "numK out of range"); -+ aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, AInPacket); -+ } -+}; -+ -+/** -+ * Unrolls for gemm kernel -+ * -+ * isAdd: true => C += A*B, false => C -= A*B -+ */ -+template -+class gemm { -+ public: -+ using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; -+ static constexpr int64_t PacketSize = packet_traits::size; -+ -+ /*********************************** -+ * Auxillary Functions for: -+ * - setzero -+ * - updateC -+ * - storeC -+ * - startLoadB -+ * - triSolveMicroKernel -+ ************************************/ -+ -+ /** -+ * aux_setzero -+ * -+ * 2-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ * for(startN = 0; startN < endN; startN++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero( -+ PacketBlock &zmm) { -+ constexpr int64_t counterReverse = endM * endN - counter; -+ constexpr int64_t startM = counterReverse / (endN); -+ constexpr int64_t startN = counterReverse % endN; -+ -+ zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]); -+ aux_setzero(zmm); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero( -+ PacketBlock &zmm) { -+ EIGEN_UNUSED_VARIABLE(zmm); -+ } -+ -+ /** -+ * aux_updateC -+ * -+ * 2-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ * for(startN = 0; startN < endN; startN++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC( -+ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ constexpr int64_t counterReverse = endM * endN - counter; -+ constexpr int64_t startM = counterReverse / (endN); -+ constexpr int64_t startN = counterReverse % endN; -+ -+ EIGEN_IF_CONSTEXPR(rem) -+ zmm.packet[startN * endM + startM] = -+ padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize], remMask(rem_)), -+ zmm.packet[startN * endM + startM], remMask(rem_)); -+ else zmm.packet[startN * endM + startM] = -+ padd(ploadu(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]); -+ aux_updateC(C_arr, LDC, zmm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC( -+ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(C_arr); -+ EIGEN_UNUSED_VARIABLE(LDC); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ EIGEN_UNUSED_VARIABLE(rem_); -+ } -+ -+ /** -+ * aux_storeC -+ * -+ * 2-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ * for(startN = 0; startN < endN; startN++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC( -+ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ constexpr int64_t counterReverse = endM * endN - counter; -+ constexpr int64_t startM = counterReverse / (endN); -+ constexpr int64_t startN = counterReverse % endN; -+ -+ EIGEN_IF_CONSTEXPR(rem) -+ pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM], -+ remMask(rem_)); -+ else pstoreu(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]); -+ aux_storeC(C_arr, LDC, zmm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC( -+ Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(C_arr); -+ EIGEN_UNUSED_VARIABLE(LDC); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ EIGEN_UNUSED_VARIABLE(rem_); -+ } -+ -+ /** -+ * aux_startLoadB -+ * -+ * 1-D unroll -+ * for(startL = 0; startL < endL; startL++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB( -+ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ constexpr int64_t counterReverse = endL - counter; -+ constexpr int64_t startL = counterReverse; -+ -+ EIGEN_IF_CONSTEXPR(rem) -+ zmm.packet[unrollM * unrollN + startL] = -+ ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask(rem_)); -+ else zmm.packet[unrollM * unrollN + startL] = -+ ploadu(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]); -+ -+ aux_startLoadB(B_t, LDB, zmm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB( -+ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_t); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ EIGEN_UNUSED_VARIABLE(rem_); -+ } -+ -+ /** -+ * aux_startBCastA -+ * -+ * 1-D unroll -+ * for(startB = 0; startB < endB; startB++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA( -+ Scalar *A_t, int64_t LDA, PacketBlock &zmm) { -+ constexpr int64_t counterReverse = endB - counter; -+ constexpr int64_t startB = counterReverse; -+ -+ zmm.packet[unrollM * unrollN + numLoad + startB] = pload1(&A_t[idA(startB, 0, LDA)]); -+ -+ aux_startBCastA(A_t, LDA, zmm); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA( -+ Scalar *A_t, int64_t LDA, PacketBlock &zmm) { -+ EIGEN_UNUSED_VARIABLE(A_t); -+ EIGEN_UNUSED_VARIABLE(LDA); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ } -+ -+ /** -+ * aux_loadB -+ * currK: current K -+ * -+ * 1-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB( -+ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ if ((numLoad / endM + currK < unrollK)) { -+ constexpr int64_t counterReverse = endM - counter; -+ constexpr int64_t startM = counterReverse; -+ -+ EIGEN_IF_CONSTEXPR(rem) { -+ zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = -+ ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask(rem_)); -+ } -+ else { -+ zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] = -+ ploadu(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]); -+ } -+ -+ aux_loadB(B_t, LDB, zmm, rem_); -+ } -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB( -+ Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_t); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ EIGEN_UNUSED_VARIABLE(rem_); -+ } -+ -+ /** -+ * aux_microKernel -+ * -+ * 3-D unroll -+ * for(startM = 0; startM < endM; startM++) -+ * for(startN = 0; startN < endN; startN++) -+ * for(startK = 0; startK < endK; startK++) -+ **/ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel( -+ Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ constexpr int64_t counterReverse = endM * endN * endK - counter; -+ constexpr int startK = counterReverse / (endM * endN); -+ constexpr int startN = (counterReverse / (endM)) % endN; -+ constexpr int startM = counterReverse % endM; -+ -+ EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) { -+ gemm::template startLoadB(B_t, LDB, zmm, rem_); -+ gemm::template startBCastA(A_t, LDA, zmm); -+ } -+ -+ { -+ // Interleave FMA and Bcast -+ EIGEN_IF_CONSTEXPR(isAdd) { -+ zmm.packet[startN * endM + startM] = -+ pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast], -+ zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]); -+ } -+ else { -+ zmm.packet[startN * endM + startM] = -+ pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast], -+ zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]); -+ } -+ // Bcast -+ EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) { -+ zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1(&A_t[idA( -+ (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]); -+ } -+ } -+ -+ // We have updated all accumlators, time to load next set of B's -+ EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) { -+ gemm::template loadB(B_t, LDB, zmm, rem_); -+ } -+ aux_microKernel(B_t, A_t, LDB, LDA, zmm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel( -+ Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(B_t); -+ EIGEN_UNUSED_VARIABLE(A_t); -+ EIGEN_UNUSED_VARIABLE(LDB); -+ EIGEN_UNUSED_VARIABLE(LDA); -+ EIGEN_UNUSED_VARIABLE(zmm); -+ EIGEN_UNUSED_VARIABLE(rem_); -+ } -+ -+ /******************************************************** -+ * Wrappers for aux_XXXX to hide counter parameter -+ ********************************************************/ -+ -+ template -+ static EIGEN_ALWAYS_INLINE void setzero(PacketBlock &zmm) { -+ aux_setzero(zmm); -+ } -+ -+ /** -+ * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load. -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC, -+ PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ aux_updateC(C_arr, LDC, zmm, rem_); -+ } -+ -+ template -+ static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC, -+ PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ aux_storeC(C_arr, LDC, zmm, rem_); -+ } -+ -+ /** -+ * Use numLoad registers for loading B at start of microKernel -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB, -+ PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ aux_startLoadB(B_t, LDB, zmm, rem_); -+ } -+ -+ /** -+ * Use numBCast registers for broadcasting A at start of microKernel -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA, -+ PacketBlock &zmm) { -+ aux_startBCastA(A_t, LDA, zmm); -+ } -+ -+ /** -+ * Loads next set of B into vector registers between each K unroll. -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB, -+ PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ aux_loadB(B_t, LDB, zmm, rem_); -+ } -+ -+ /** -+ * Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3} to compute C -= A*B. -+ * A matrix can be row/col-major. B matrix is assumed row-major. -+ * -+ * isARowMajor: is A row major -+ * endM: Number registers per row -+ * endN: Number of rows -+ * endK: Loop unroll for K. -+ * numLoad: Number of registers for loading B. -+ * numBCast: Number of registers for broadcasting A. -+ * -+ * Ex: microkernel: 8x48 unroll (24 accumulators), k unrolled 4 times, -+ * 6 register for loading B, 2 for broadcasting A. -+ * -+ * Note: Ideally the microkernel should not have any register spilling. -+ * The avx instruction counts should be: -+ * - endK*endN vbroadcasts{s,d} -+ * - endK*endM vmovup{s,d} -+ * - endK*endN*endM FMAs -+ * -+ * From testing, there are no register spills with clang. There are register spills with GNU, which -+ * causes a performance hit. -+ */ -+ template -+ static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, -+ PacketBlock &zmm, -+ int64_t rem_ = 0) { -+ EIGEN_UNUSED_VARIABLE(rem_); -+ aux_microKernel(B_t, A_t, LDB, LDA, zmm, -+ rem_); -+ } -+}; -+} // namespace unrolls -+ -+#endif // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H From 53afba36acdc2b495bb37c089a6a3f3a8724d260 Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Mon, 3 Jun 2024 17:36:56 +0200 Subject: [PATCH 20/21] Delete easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-gfbf-2023a.eb --- .../m/ml_dtypes/ml_dtypes-0.3.2-gfbf-2023a.eb | 51 ------------------- 1 file changed, 51 deletions(-) delete mode 100644 easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-gfbf-2023a.eb diff --git a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-gfbf-2023a.eb b/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-gfbf-2023a.eb deleted file mode 100644 index 9c3a18bfdb5..00000000000 --- a/easybuild/easyconfigs/m/ml_dtypes/ml_dtypes-0.3.2-gfbf-2023a.eb +++ /dev/null @@ -1,51 +0,0 @@ -# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2024/02 -easyblock = 'PythonBundle' - -name = 'ml_dtypes' -version = '0.3.2' - -homepage = 'https://github.com/jax-ml/ml_dtypes' -description = """ -ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used -in machine learning libraries, including: - -bfloat16: an alternative to the standard float16 format -float8_*: several experimental 8-bit floating point representations including: -float8_e4m3b11fnuz -float8_e4m3fn -float8_e4m3fnuz -float8_e5m2 -float8_e5m2fnuz -""" - -toolchain = {'name': 'gfbf', 'version': '2023a'} - -dependencies = [ - ('Python', '3.11.3'), - ('SciPy-bundle', '2023.07'), -] - - -use_pip = True - -default_easyblock = 'PythonPackage' - -exts_list = [ - ('opt_einsum', '3.3.0', { - 'checksums': ['59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549'], - }), - ('etils', '1.6.0', { - 'checksums': ['c635fbd02a79fed4ad76825d31306b581d22b40671721daa8bc279cf6333e48a'], - }), - (name, version, { - 'patches': [('ml_dtypes-0.3.2_EigenAvx512.patch', 1)], - 'checksums': [ - {'ml_dtypes-0.3.2.tar.gz': '533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967'}, - {'ml_dtypes-0.3.2_EigenAvx512.patch': '197b05b0b7f611749824369f026099f6a172f9e8eab6ebb6504a16573746c892'}, - ], - }), -] - -sanity_pip_check = True - -moduleclass = 'tools' From c6518ee38c0cb655b2fc4c8c89cc083d2270bcc5 Mon Sep 17 00:00:00 2001 From: Thomas Hoffmann <81254262+ThomasHoffmann77@users.noreply.github.com> Date: Thu, 27 Jun 2024 10:07:38 +0200 Subject: [PATCH 21/21] Update jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb fix local_extract_cmd according to @akesandgren 's suggestion --- .../easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb index 68ab99486f2..c193b35c83b 100644 --- a/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb +++ b/easybuild/easyconfigs/j/jax/jax-0.4.25-gfbf-2023a-CUDA-12.1.1.eb @@ -41,7 +41,7 @@ dependencies = [ local_xla_commit = '4ccfe33c71665ddcbca5b127fefe8baa3ed632d4' local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' -local_extract_cmd = 'cp %s %(builddir)s/archives' +local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives' local_repo_opt = '--bazel_options="--distdir=%(builddir)s/archives" ' local_repo_opt += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" ' local_repo_opt += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include" '