From 4dff154f51d8e1fa4db63729a5c0796494886d6c Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Tue, 19 Dec 2023 09:18:00 -0800 Subject: [PATCH 01/19] Fix nightly pipeline failure (#18867) ### Description Fixes a failure in the ortmodule nightly pipeline. ### Motivation and Context --- .../ortmodule/stage1/requirements_torch_nightly/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt index fc8e542cb9833..0cd5e5c5d5c46 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_nightly/requirements.txt @@ -1,4 +1,5 @@ scikit-learn packaging==21.3 transformers==v4.30.0 +accelerate==0.20.1 wget From 5f00bc99311a5d4eb7b0269c1f1cf1a7db1a1f9a Mon Sep 17 00:00:00 2001 From: luoyu-intel Date: Wed, 20 Dec 2023 01:36:31 +0800 Subject: [PATCH 02/19] Integrate high-performance x64 gemm library to MLAS (#17669) ### Description Improve MLAS to support high-performance x64 INT4 kernels ### Motivation and Context 1. improve LLM inference performance on Intel CPUs. 2. support more 4bit quantization types: nf4, fp4 3. support dynamic block size: block size aligned with kernel's tiling size(e.g. 4 for VNNI kernel), per channel on N dimension 4. support most Intel ISAs: avx2, avx_vnni, avx512f, avx512_vnni, amx_bf16, amx_int8, avx512_fp16 5. support MatMulNBits' data format ### Tasks - [x] support block_size: 32, 128, -1(per channel) - [x] get weight pack size without memory allocation - [x] use ort's thread pool for parallelism - [x] support ISAs: avx2, avx512f, avx_vnni, avx512_vnni, amx_int8 ### Benchmark Ubuntu 20.22 + Intel(R) Xeon(R) Platinum 8480+ 56 cores Benchmark | Time | CPU | Iterations -- | -- | -- | -- Q4GEMM_Jblas/Q4G32SymInt8/M:1/N:4096/K:4096/Threads:56/real_time | 47613 | 47401 | 12970 Q4GEMM_Jblas/Q4G32SymInt8/M:1024/N:4096/K:4096/Threads:56/real_time | 6347792 | 6317562 | 109 Q4GEMM_Jblas/Q4G32SymInt8/M:2048/N:4096/K:4096/Threads:56/real_time | 11814014 | 11757847 | 59 Q4GEMM_Jblas/Q4G128SymInt8/M:1/N:4096/K:4096/Threads:56/real_time | 50222 | 50031 | 13759 Q4GEMM_Jblas/Q4G128SymInt8/M:1024/N:4096/K:4096/Threads:56/real_time | 2038222 | 2028743 | 341 Q4GEMM_Jblas/Q4G128SymInt8/M:2048/N:4096/K:4096/Threads:56/real_time | 3792832 | 3774485 | 191 Q4GEMM_Jblas/Q4GPerNSymInt8/M:1/N:4096/K:4096/Threads:56/real_time | 58717 | 58501 | 11467 Q4GEMM_Jblas/Q4GPerNSymInt8/M:1024/N:4096/K:4096/Threads:56/real_time | 1360846 | 1354598 | 543 Q4GEMM_Jblas/Q4GPerNSymInt8/M:2048/N:4096/K:4096/Threads:56/real_time | 2564232 | 2551365 | 266 Q4GEMM_Jblas/Q4G32SymFp32/M:1/N:4096/K:4096/Threads:56/real_time | 57929 | 57694 | 12047 Q4GEMM_Jblas/Q4G32SymFp32/M:1024/N:4096/K:4096/Threads:56/real_time | 5495330 | 5465810 | 126 Q4GEMM_Jblas/Q4G32SymFp32/M:2048/N:4096/K:4096/Threads:56/real_time | 10676240 | 10617817 | 66 Q4GEMM_Jblas/Q4G128SymFp32/M:1/N:4096/K:4096/Threads:56/real_time | 68305 | 68047 | 10026 Q4GEMM_Jblas/Q4G128SymFp32/M:1024/N:4096/K:4096/Threads:56/real_time | 5504862 | 5476215 | 126 Q4GEMM_Jblas/Q4G128SymFp32/M:2048/N:4096/K:4096/Threads:56/real_time | 11758623 | 11697337 | 66 Q4GEMM_Jblas/Q4GPerNSymFp32/M:1/N:4096/K:4096/Threads:56/real_time | 67713 | 67451 | 10298 Q4GEMM_Jblas/Q4GPerNSymFp32/M:1024/N:4096/K:4096/Threads:56/real_time | 5508325 | 5480237 | 126 Q4GEMM_Jblas/Q4GPerNSymFp32/M:2048/N:4096/K:4096/Threads:56/real_time | 10738528 | 10681656 | 64 Q4GEMM_Jblas/Q4G32AsymFp32/M:1/N:4096/K:4096/Threads:56/real_time | 60708 | 60486 | 11321 Q4GEMM_Jblas/Q4G32AsymFp32/M:1024/N:4096/K:4096/Threads:56/real_time | 5523784 | 5495736 | 126 Q4GEMM_Jblas/Q4G32AsymFp32/M:2048/N:4096/K:4096/Threads:56/real_time | 10829633 | 10772161 | 67 Reference: Benchmark | Time | CPU | Iterations -- | -- | -- | -- Q4GEMM/Q4Sym/M:1/N:4096/K:4096/Threads:56/real_time | 53088 | 52911 | 13364 Q4GEMM/Q4Sym/M:1024/N:4096/K:4096/Threads:56/real_time | 6268981 | 6230335 | 110 Q4GEMM/Q4Sym/M:2048/N:4096/K:4096/Threads:56/real_time | 11701237 | 11632339 | 59 Win11+12900K 8 cores: Benchmark | Time | CPU | Iterations -- | -- | -- | -- Q4GEMM_Jblas/Q4G32SymInt8/M:1/N:4096/K:4096/Threads:8/real_time | 215976 | 211295 | 2884 Q4GEMM_Jblas/Q4G32SymInt8/M:1024/N:4096/K:4096/Threads:8/real_time | 60960590 | 60937500 | 10 Q4GEMM_Jblas/Q4G32SymInt8/M:2048/N:4096/K:4096/Threads:8/real_time | 1.18E+08 | 1.19E+08 | 5 Q4GEMM_Jblas/Q4G32SymInt8/M:1/N:11008/K:4096/Threads:8/real_time | 470377 | 453059 | 1414 Q4GEMM_Jblas/Q4G32SymInt8/M:1024/N:11008/K:4096/Threads:8/real_time | 1.54E+08 | 1.53E+08 | 5 Q4GEMM_Jblas/Q4G32SymInt8/M:2048/N:11008/K:4096/Threads:8/real_time | 3.18E+08 | 3.13E+08 | 2 Q4GEMM_Jblas/Q4G32SymInt8/M:1/N:4096/K:11008/Threads:8/real_time | 569072 | 559398 | 1229 Q4GEMM_Jblas/Q4G32SymInt8/M:1024/N:4096/K:11008/Threads:8/real_time | 1.54E+08 | 1.52E+08 | 4 Q4GEMM_Jblas/Q4G32SymInt8/M:2048/N:4096/K:11008/Threads:8/real_time | 3.22E+08 | 3.28E+08 | 2 Q4GEMM_Jblas/Q4G32SymInt8/M:1/N:11008/K:11008/Threads:8/real_time | 1486055 | 1473325 | 403 Q4GEMM_Jblas/Q4G32SymInt8/M:1024/N:11008/K:11008/Threads:8/real_time | 4.14E+08 | 4.14E+08 | 2 Q4GEMM_Jblas/Q4G32SymInt8/M:2048/N:11008/K:11008/Threads:8/real_time | 8.88E+08 | 8.59E+08 | 1 --------- Signed-off-by: Mengni Wang Co-authored-by: Mengni Wang --- cmake/CMakeLists.txt | 12 + cmake/onnxruntime_mlas.cmake | 16 +- docs/ContribOperators.md | 2 + .../cpu/quantization/matmul_nbits.cc | 134 +- .../core/graph/contrib_ops/contrib_defs.cc | 7 + onnxruntime/core/mlas/inc/mlas_qnbit.h | 141 + onnxruntime/core/mlas/lib/jblas_defs.h | 73 + onnxruntime/core/mlas/lib/jblas_gemm.cpp | 534 ++ onnxruntime/core/mlas/lib/jblas_gemm.h | 61 + onnxruntime/core/mlas/lib/mlasi.h | 2 + onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 127 + .../core/mlas/lib/x86_64/jblas/.clang-format | 7 + .../core/mlas/lib/x86_64/jblas/CMakeLists.txt | 33 + .../mlas/lib/x86_64/jblas/jblas/jit_base.h | 303 ++ .../mlas/lib/x86_64/jblas/jblas/jit_blas.h | 96 + .../lib/x86_64/jblas/jblas/jit_blas_device.h | 277 + .../x86_64/jblas/jblas/jit_blas_epilogue.h | 329 ++ .../lib/x86_64/jblas/jblas/jit_blas_gemm.h | 2699 ++++++++++ .../x86_64/jblas/jblas/jit_blas_parallel.h | 678 +++ .../x86_64/jblas/jblas/jit_blas_prologue_a.h | 214 + .../x86_64/jblas/jblas/jit_blas_prologue_b.h | 892 ++++ .../lib/x86_64/jblas/jblas/jit_blas_storage.h | 665 +++ .../lib/x86_64/jblas/jblas/jit_blas_utils.h | 638 +++ .../lib/x86_64/jblas/jblas/jit_blas_wrapper.h | 281 + .../mlas/lib/x86_64/jblas/jblas/kernel_avx2.h | 874 +++ .../x86_64/jblas/jblas/kernel_avx512_bf16.h | 92 + .../lib/x86_64/jblas/jblas/kernel_avx512f.h | 1966 +++++++ .../mlas/lib/x86_64/jblas/jblas/kernel_jit.h | 1375 +++++ .../x86_64/jblas/jblas/kernel_jit_injector.h | 930 ++++ .../mlas/lib/x86_64/jblas/jblas/kernel_ref.h | 1039 ++++ .../lib/x86_64/jblas/jblas/kernel_wrapper.h | 702 +++ .../mlas/lib/x86_64/jblas/jblas/xbyak/xbyak.h | 3313 ++++++++++++ .../x86_64/jblas/jblas/xbyak/xbyak_bin2hex.h | 271 + .../x86_64/jblas/jblas/xbyak/xbyak_mnemonic.h | 4728 +++++++++++++++++ .../lib/x86_64/jblas/jblas/xbyak/xbyak_util.h | 1160 ++++ .../test/contrib_ops/matmul_4bits_test.cc | 187 +- .../test/mlas/bench/bench_sqnbitgemm.cpp | 54 + 37 files changed, 24902 insertions(+), 10 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/jblas_defs.h create mode 100644 onnxruntime/core/mlas/lib/jblas_gemm.cpp create mode 100644 onnxruntime/core/mlas/lib/jblas_gemm.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_parallel.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_prologue_a.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_prologue_b.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_storage.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_utils.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_wrapper.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_avx2.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_avx512_bf16.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_avx512f.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_jit.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_jit_injector.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_ref.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/kernel_wrapper.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/xbyak/xbyak.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/xbyak/xbyak_bin2hex.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/xbyak/xbyak_mnemonic.h create mode 100644 onnxruntime/core/mlas/lib/x86_64/jblas/jblas/xbyak/xbyak_util.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 7494035e4784e..23ded3bfc1e68 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -87,6 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) +option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) @@ -1166,6 +1167,17 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() +set(USE_JBLAS FALSE) +if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD) + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") + add_compile_definitions(MLAS_JBLAS) + set(USE_JBLAS TRUE) + elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") + add_compile_definitions(MLAS_JBLAS) + set(USE_JBLAS TRUE) + endif() +endif() + # TVM EP if (onnxruntime_USE_TVM) if (NOT TARGET tvm) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 26e4380af4c23..bee83ff07c74b 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -45,6 +45,15 @@ endif() set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) +function(add_jblas) + add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas) + target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/jblas_gemm.cpp + ) + set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF) +endfunction() + #TODO: set MASM flags properly function(setup_mlas_source_for_windows) @@ -200,7 +209,6 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/q4gemm_avx512.cpp ) endif() - else() target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp @@ -566,7 +574,7 @@ else() ) set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") - endif() + endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs}) @@ -604,6 +612,10 @@ else() target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) endif() +if(USE_JBLAS) + add_jblas() +endif() + foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR}) onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET}) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e5b43ddba8cc7..131db5d8d9b37 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2824,6 +2824,8 @@ This version of the operator has been available since version 1 of the 'com.micr
size of each input feature
N : int (required)
size of each output feature
+
accuracy_level : int
+
The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) (default unset). It is used to control how input A is quantized or downcast internally while doing computation, for example: 0 means input A will not be quantized or downcast while doing computation. 4 means input A can be quantized with the same block_size to int8 internally from type T1.
bits : int (required)
number of bits used for weight quantization (default 4)
block_size : int (required)
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 320a05bb97dac..b060d500c6484 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -20,30 +20,158 @@ class MatMulNBits final : public OpKernel { K_{narrow(info.GetAttr("K"))}, N_{narrow(info.GetAttr("N"))}, block_size_{narrow(info.GetAttr("block_size"))}, - nbits_{narrow(info.GetAttr("bits"))} { + nbits_{narrow(info.GetAttr("bits"))}, + accuracy_level_{info.GetAttr("accuracy_level")} { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + is_asym_ = info.GetInputCount() >= 4; + const Tensor* tensor_B = nullptr; + const Tensor* tensor_scale = nullptr; + const Tensor* tensor_zero_point = nullptr; + bool B_constant = info.TryGetConstantInput(1, &tensor_B); + bool scale_constant = info.TryGetConstantInput(2, &tensor_scale); + bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point); + all_constant_ = B_constant && scale_constant; + all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_; } Status Compute(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, + /*out*/ bool& used_shared_buffers) override; + private: const size_t K_; const size_t N_; const size_t block_size_; const size_t nbits_; + const int64_t accuracy_level_; const bool column_wise_quant_{true}; + IAllocatorUniquePtr packed_b_; + size_t packed_b_size_{0}; + bool is_asym_{false}; + bool all_constant_{false}; }; +Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + if (!all_constant_) { + return Status::OK(); + } + auto compt_type = static_cast(accuracy_level_); + MLAS_THREADPOOL* pool = NULL; + if (input_idx == 1) { + packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast(nbits_), is_asym_, compt_type); + if (packed_b_size_ == 0) return Status::OK(); + auto qptr = tensor.Data(); + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + if (packed_b_ == nullptr) { + return Status::OK(); + } + std::memset(packed_b_.get(), 0, packed_b_size_); + MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), + is_asym_, false, compt_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + if (input_idx == 2 && packed_b_ != nullptr) { + auto sptr = tensor.Data(); + MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), + is_asym_, !is_asym_, compt_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + if (input_idx == 3 && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast(nbits_), + is_asym_, is_asym_, compt_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + + return Status::OK(); +} + +Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + // Pack three tensors into one buffer + if (input_idx == 1) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + if (input_idx == 2) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + if (input_idx == 3) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + return Status::OK(); +} + Status MatMulNBits::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); const Tensor* a = ctx->Input(0); + const auto* a_data = a->Data(); + + if (packed_b_.get()) { + TensorShape b_shape({static_cast(N_), static_cast(K_)}); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(false); + std::vector gemm_params(max_len); + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + for (size_t i = 0; i < max_len; i++) { + gemm_params[i].A = a_data + helper.LeftOffsets()[i]; + gemm_params[i].lda = lda; + gemm_params[i].B = packed_b_.get(); + gemm_params[i].C = y_data + helper.OutputOffsets()[i]; + gemm_params[i].ldc = N; + } + auto ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); + // workspace for activation process(dynamic quantization and others) + auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size); + MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), + thread_pool); + return Status::OK(); + } + const Tensor* b = ctx->Input(1); const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); - - const auto* a_data = a->Data(); const uint8_t* b_data = b->Data(); const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 26fca454c96f0..54eb43753931a 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3359,6 +3359,13 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored .Attr("N", "size of each output feature", AttributeProto::INT) .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT) .Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Attr("accuracy_level", + "The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) " + "(default unset). It is used to control how input A is quantized or downcast internally while " + "doing computation, for example: 0 means input A will not be quantized or downcast while doing " + "computation. 4 means input A can be quantized with the same block_size to int8 internally from " + "type T1.", + AttributeProto::INT, static_cast(0)) .Input(0, "A", "The input tensor, not quantized", "T1") .Input(1, "B", "1-dimensional data blob", "T2") .Input(2, "scales", "quantization scale", "T1") diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 9620dd42d1da9..1e83dd1cec400 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -77,3 +77,144 @@ MlasIsSQNBitGemmAvailable( size_t BlkBitWidth, size_t BlkLen ); + +/** + * @brief Define compute types of block quantization + */ +typedef enum { + CompUndef = 0, /*!< undef */ + CompFp32 = 1, /*!< input fp32, accumulator fp32 */ + CompFp16 = 2, /*!< input fp16, accumulator fp16 */ + CompBf16 = 3, /*!< input bf16, accumulator fp32 */ + CompInt8 = 4 /*!< input int8, accumulator int32 */ +} MLAS_SQNBIT_COMPUTE_TYPE; + +/** + * @brief Data parameters for NBits GEMM routine + * C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * All except C are [in] parameters + */ +struct MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS { + const float* A = nullptr; /**< address of A (float32 matrix)*/ + const void* B = nullptr; /**< address of B (packed nbits blob)*/ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldc = 0; /**< leading dimension of C*/ +}; + +/** + * @brief Compute the byte size of the parameter combination + * + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @return size of the packing buffer, 0 if the operation is not yet supported. + */ +size_t MLASCALL +MlasNBitsGemmPackBSize( + size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE comp_type +); + +/** + * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. + * + * @param PackedBuf packed data buffer + * @param QData quantized data buffer + * @param Scale scale pointer + * @param Zp zero point pointer + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization (default 4) + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor + * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where + * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up + * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale + * (is_asym is false) and Zp(is_asym is true). + * @param thread_pool + */ +void MLASCALL +MlasNBitsGemmPackB( + void* PackedBuf, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + size_t ldb, + size_t block_size, + int nbits, + bool is_asym, + bool last_call, + MLAS_SQNBIT_COMPUTE_TYPE comp_type, + MLAS_THREADPOOL* thread_pool +); + +/** + * @brief Unpack and dequantize to fp32 + * + * @param FpData unpacked float32 data + * @param PackedBuf quantized and packed data + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param thread_pool + */ +void MLASCALL +MlasNBitsGemmUnPackB( + float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool +); + +/** + * @brief Get the workspace size required by computation. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @return Workspace size in bytes + */ +size_t MLASCALL +MlasSQNBitsGemmBatchWorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams +); + +/** + * @brief Batched GEMM: C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] WorkSpace temporary buffer + * @param[in] ThreadPool + * @return + */ +void MLASCALL +MlasSQNBitsGemmBatchPackedB( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, + void* WorkSpace, + MLAS_THREADPOOL* ThreadPool = nullptr +); diff --git a/onnxruntime/core/mlas/lib/jblas_defs.h b/onnxruntime/core/mlas/lib/jblas_defs.h new file mode 100644 index 0000000000000..9cd1711a3ffd2 --- /dev/null +++ b/onnxruntime/core/mlas/lib/jblas_defs.h @@ -0,0 +1,73 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +--*/ + +#pragma once + +#include "jblas/jit_blas_prologue_b.h" +#include "jblas/jit_blas_wrapper.h" + +namespace jblas +{ + +/* +Name conversion explaination: +Fp32: comp type, determined by GemmCore, can be any jblas::gemm::SCorexxx(float GemmCore) +S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(also support other integer and float weight +classes) +F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and +jblas::epilogue::gemm::AccumulatorWriteBackFp32. + +Tips: jblas::epilogue::gemm::CompFp32BlockEpilogue is a fixed class for all fp32 accumulator GemmCores. +*/ +template +using tLauncher_Fp32_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock< + GemmCore_T::ISA, + GemmCore_T, + jblas::prologue_a::gemm::ActivationKBlockBaseF32, + jblas::prologue_b::gemm::WeightKBlockS4, + jblas::epilogue::gemm::CompFp32BlockEpilogue, + jblas::epilogue::gemm::AccumulatorWriteBackFp32>; + +/* +Name conversion explaination: +Int8: comp type, determined by GemmCore, can be any jblas::gemm::ICorexxx(integer GemmCore) +S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(support integer weight classes only) +F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and +jblas::epilogue::gemm::AccumulatorWriteBackFp32. + +Tips: jblas::epilogue::gemm::CompInt8BlockEpilogue is a fixed class for all int32 accumulator GemmCores. +*/ +template +using tLauncher_Int8_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock< + GemmCore_T::ISA, + GemmCore_T, + jblas::prologue_a::gemm::ActivationF32KBlockQuantize, + jblas::prologue_b::gemm::WeightKBlockS4, + jblas::epilogue::gemm::CompInt8BlockEpilogue, + jblas::epilogue::gemm::AccumulatorWriteBackFp32>; + +using tAVX512F = jblas::gemm::SCoreRowNAvx512f<48, 8>; +using tAMX_BF16 = jblas::gemm::HCoreRowNAmxbf16<64, 16>; +using tAVX512_FP16 = jblas::gemm::HCoreRowNAvx512fp16<96, 8>; +using tAVX_VNNI = jblas::gemm::ICoreRowNAvxvnni<48, 2>; // TODO(Yu) use 24x4 for higher efficiency +using tAVX512_VNNI = jblas::gemm::ICoreRowNAvx512vnni<48, 8>; +using tAMX_INT8_US = jblas::gemm::ICoreRowNAmxint8<64, 16>; +using tAMX_INT8_SS = jblas::gemm::ICoreRowNAmxint8SS<64, 16>; +using tAVX2 = jblas::gemm::SCoreRowNAvx2<48, 2>; // TODO(Yu) use 24x4 for higher efficiency + +class ORTThreading : public jblas::parallel::IThreading +{ + public: + ORTThreading(void* tp); + void parallel_for(const jblas::parallel::thread_func& func) override; + void set_threads(int nthreads) override { assert(0); } + void sync() override { assert(0); } + void* mTp; +}; + +} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.cpp b/onnxruntime/core/mlas/lib/jblas_gemm.cpp new file mode 100644 index 0000000000000..f3cae3186c28e --- /dev/null +++ b/onnxruntime/core/mlas/lib/jblas_gemm.cpp @@ -0,0 +1,534 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + jblas_gemm.cpp + +Abstract: + + Currently only support Q4 gemm. +--*/ + +#include "jblas_gemm.h" + +#include "jblas_defs.h" +#include "mlasi.h" + +using namespace jblas; + +jblas::ORTThreading::ORTThreading(void* tp) + : IThreading(MLAS_THREADPOOL::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) +{ +} + +void +jblas::ORTThreading::parallel_for(const jblas::parallel::thread_func& func) +{ + MlasTrySimpleParallel(reinterpret_cast(mTp), mThreadNum, [&](ptrdiff_t tid) { + func(static_cast(tid)); + }); +} + +template +static void +JblasSQ4GemmCompF32( + const size_t M, + const size_t N, + const size_t K, + const float* A, + const size_t lda, + jblas::storage::gemm::StorageWeightKBlockS4* B, + float* C, + const size_t ldc, + int8_t* WorkSpace, + jblas::parallel::IThreading* th +) +{ + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + if (M <= 16) { + using Parallel = jblas::parallel::gemm::SchedulerKBlock; + using Launcher = tLauncher_Fp32_S4_F32F32; + static Launcher kernel; + auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); + if (B->mIsAsym) { + reduceA.assign(WorkSpace); + ORTThreading single(nullptr); + kernel.mProA.reduce({A, lda_}, &reduceA, M_, K_, &single); + } + typename Launcher::BEpiParam blkargs{ + B->template SPtr(), B->mScaT, B->mCStep, B->template ZPtr(), + reduceA.template get(), reduceA.lda}; + + typename Launcher::Param args{M_, N_, K_, B->mBlockSize, {A, lda_}, {B}, blkargs, {C, ldc_}}; + jblas::parallel::GemmKBlockRun(kernel, args, th); + } else { + using Parallel = jblas::parallel::gemm::SchedulerBase; + using Launcher = jblas::wrapper::gemm::LauncherBase< + GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase, + jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; + static Launcher kernel; + + typename Launcher::Param args{M_, N_, K_, {A, lda_}, {B}, {C, ldc_}}; + jblas::parallel::GemmBaseRun(kernel, args, th); + } +} + +template +static void +JblasSQ4GemmCompInt8( + const size_t M, + const size_t N, + const size_t K, + const float* A, + const size_t lda, + jblas::storage::gemm::StorageWeightKBlockS4* B, + float* C, + const size_t ldc, + int8_t* WorkSpace, + jblas::parallel::IThreading* th +) +{ + using Parallel = jblas::parallel::gemm::SchedulerKBlock; + using Launcher = tLauncher_Int8_S4_F32F32; + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + static Launcher kernel; + auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->mIsAsym); + quanA.assign(WorkSpace); + if (M <= 16) { + ORTThreading single(nullptr); + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); + } else { + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); + } + typename Launcher::Param args{ + M_, + N_, + K_, + B->mBlockSize, + {A, lda_, &quanA}, + {B}, + {B->template SPtr(), B->mScaT, B->mCStep, quanA.template SPtr(), quanA.mCStep, + quanA.template ZPtr(), B->template RPtr(), B->mRedT, B->template ZPtr(), + quanA.template RPtr(), B->mBlockSize}, + {C, ldc_}}; + jblas::parallel::GemmKBlockRun(kernel, args, th); +} + +bool +JblasSQ4GemmBatchDriver( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, + int8_t* WorkSpace, + MLAS_THREADPOOL* ThreadPool +) +{ + GetCPUDevice(); + ORTThreading orth(ThreadPool); + bool processed = true; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { + auto kptr = reinterpret_cast(ptr); + auto coretype = ptr->mCoreId; + auto NTile = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT + ); + auto CType = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT + ); + if (CType == uint32_t(gemm::CompType::COMP_FP32)) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { + JblasSQ4GemmCompF32( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { + JblasSQ4GemmCompF32( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } + } + if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { + if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { + JblasSQ4GemmCompInt8( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { + JblasSQ4GemmCompInt8( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { + JblasSQ4GemmCompInt8( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } + } + if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { + if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { + JblasSQ4GemmCompInt8( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } + } + } + } else { + processed = false; + break; + } + } + return processed; +} + +template +static size_t +JblasSQ4GemmCompF32WorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const float* A, + const size_t lda, + jblas::storage::gemm::StorageWeightKBlockS4* B, + float* C, + const size_t ldc +) +{ + auto M_ = static_cast(M); + auto K_ = static_cast(K); + (void)(N); + (void)(lda); + (void)(ldc); + if (M <= 16) { + using Launcher = tLauncher_Fp32_S4_F32F32; + static Launcher kernel; + if (B->mIsAsym) { + auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); + return reduceA.mSize; + } + return 0; + } else { + using Launcher = jblas::wrapper::gemm::LauncherBase< + GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase, + jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; + static Launcher kernel; + return 0; + } + return 0; +} + +template +static size_t +JblasSQ4GemmCompInt8WorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const float* A, + const size_t lda, + jblas::storage::gemm::StorageWeightKBlockS4* B, + float* C, + const size_t ldc +) +{ + using Parallel = jblas::parallel::gemm::SchedulerKBlock; + using Launcher = tLauncher_Int8_S4_F32F32; + static Launcher kernel; + (void)(N); + (void)(lda); + (void)(ldc); + auto quanA = kernel.mProA.createStorage( + static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->mIsAsym + ); + return quanA.mSize; +} + +size_t +JblasSQ4GemmBatchWorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams +) +{ + GetCPUDevice(); + size_t size = 0; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { + auto kptr = reinterpret_cast(ptr); + auto coretype = ptr->mCoreId; + auto NTile = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT + ); + auto CType = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT + ); + if (CType == uint32_t(gemm::CompType::COMP_FP32)) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { + size = std::max( + JblasSQ4GemmCompF32WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { + size = std::max( + JblasSQ4GemmCompF32WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } + } + if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { + if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { + size = std::max( + JblasSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { + size = std::max( + JblasSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { + size = std::max( + JblasSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } + } + if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { + if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { + size = std::max( + JblasSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } + } + } + } + } + return size; +} + +template +static size_t +JblasQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) +{ + static T launcher; + auto stor = launcher.mProB.createStorage( + static_cast(N), static_cast(K), static_cast(block_size), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, + JBLAS_DTYPE::BF16, isAsym + ); + // TODO(Yu) support more scale dtype + return stor.mSize; +} + +size_t +JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType) +{ + GetCPUDevice(); + if (K % BlkSize != 0) { + return 0; + } + // from low precision to high precision + switch (CompType) { + case CompInt8: + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { + return JblasQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { + return JblasQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { + return JblasQ4BuSize>(BlkSize, N, K, isAsym); + } + case CompBf16: + case CompFp16: + case CompFp32: + case CompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + return JblasQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + return JblasQ4BuSize>(BlkSize, N, K, isAsym); + } + break; + default: + return 0; + } + return 0; +} + +template +static void +JblasQ4GemmPackBImpl( + void* PackedBuf, + size_t BlkSize, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + bool IsAsym, + bool lastCall, + size_t ldb, + MLAS_THREADPOOL* ThreadPool +) +{ + static T JblasKernel; + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto stor = JblasKernel.mProB.createStorage( + N_, K_, static_cast(BlkSize), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, IsAsym + ); + stor.assign(reinterpret_cast(PackedBuf)); + ORTThreading orth(ThreadPool); + JblasKernel.mProB.packNbitsWeight(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); + if (lastCall) { + JblasKernel.mProB.reduceWeight(&stor, &orth); + } +} + +bool +JblasQ4GemmPackB( + void* PackedBuf, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + size_t ldb, + size_t BlkSize, + bool isAsym, + bool lastCall, + MLAS_SQNBIT_COMPUTE_TYPE CompType, + MLAS_THREADPOOL* ThreadPool +) +{ + GetCPUDevice(); + // explicit statement fall through. + switch (CompType) { + case CompInt8: + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { + JblasQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { + JblasQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { + JblasQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + case CompBf16: + case CompFp16: + case CompFp32: + case CompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + JblasQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + JblasQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + default: + return false; + } + return false; +} + +bool +JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) +{ + auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); + auto uptr = std::unique_ptr(ptr); + ORTThreading orth(ThreadPool); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto ldb_ = static_cast(ldb); + GetCPUDevice(); + if (ptr) { + if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { + auto NTile = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT + ); + auto CType = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT + ); + if (CType == uint32_t(jblas::gemm::CompType::COMP_FP32)) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { + static jblas::prologue_b::gemm::WeightKBlockS4 proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { + static jblas::prologue_b::gemm::WeightKBlockS4 proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } + } + if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_US_INT32)) { + if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { + static jblas::prologue_b::gemm::WeightKBlockS4 proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { + static jblas::prologue_b::gemm::WeightKBlockS4 proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { + static jblas::prologue_b::gemm::WeightKBlockS4 proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } + } + if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_SS_INT32)) { + if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { + static jblas::prologue_b::gemm::WeightKBlockS4 proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } + } + } + return true; + } + return false; +} diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.h b/onnxruntime/core/mlas/lib/jblas_gemm.h new file mode 100644 index 0000000000000..044dc5e849a0a --- /dev/null +++ b/onnxruntime/core/mlas/lib/jblas_gemm.h @@ -0,0 +1,61 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + jblas_gemm.h + +Abstract: + + Currently only support Q4 gemm. +--*/ + +#pragma once + +#include "mlas_qnbit.h" + +size_t +JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType); + +bool +JblasQ4GemmPackB( + void* PackedBuf, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + size_t ldb, + size_t BlkSize, + bool isAsym, + bool lastCall, + MLAS_SQNBIT_COMPUTE_TYPE CompType, + MLAS_THREADPOOL* ThreadPool +); + +bool +JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb + , MLAS_THREADPOOL* ThreadPool); + +bool +JblasSQ4GemmBatchDriver( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, + int8_t* WorkSpace, + MLAS_THREADPOOL* ThreadPool +); + +size_t +JblasSQ4GemmBatchWorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams +); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 7bda1bb504173..7bb8b17031a84 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -50,7 +50,9 @@ Module Name: #include #endif #if defined(__x86_64__) || defined(__i386__) +#if !defined(signature_VORTEX_ebx) && !defined(signature_NEXGEN_ebx) && !defined(signature_AMD_ebx)//workaround for Bug 96238 - [i386] cpuid.h header needs include guards #include +#endif #if defined(__GNUC__) && __GNUC__ >= 12 #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // GCC 12 warns about uninitialized variables in immintrin.h. diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index f964b1affec31..7f1d1b084aec0 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -15,6 +15,9 @@ Module Name: --*/ #include "sqnbitgemm.h" +#ifdef MLAS_JBLAS +#include "jblas_gemm.h" +#endif namespace { @@ -142,3 +145,127 @@ MlasIsSQNBitGemmAvailable( return true; } + +size_t MLASCALL +MlasNBitsGemmPackBSize( + size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType +) +{ +#ifdef MLAS_JBLAS + if (nbits == 4) { + auto jsize = JblasQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); + if (jsize) { + return jsize; + } + } +#endif + (void)(N); + (void)(K); + (void)(BlkSize); + (void)(nbits); + (void)(isAsym); + (void)(CompType); + return 0; +} + +void MLASCALL +MlasNBitsGemmPackB( + void* PackedBuf, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + size_t ldb, + size_t BlkSize, + int nbits, + bool isAsym, + bool lastCall, + MLAS_SQNBIT_COMPUTE_TYPE CompType, + MLAS_THREADPOOL* ThreadPool +) +{ +#ifdef MLAS_JBLAS + if (nbits == 4) { + if (JblasQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { + return; + } + } +#endif + (void)(PackedBuf); + (void)(QData); + (void)(Scale); + (void)(Zp); + (void)(N); + (void)(K); + (void)(ldb); + (void)(BlkSize); + (void)(nbits); + (void)(isAsym); + (void)(lastCall); + (void)(CompType); + (void)(ThreadPool); +} + +void MLASCALL +MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) +{ +#ifdef MLAS_JBLAS + if (JblasQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { + return; + } +#endif + (void)(FpData); + (void)(PackedBuf); + (void)(N); + (void)(K); + (void)(ldb); + (void)(ThreadPool); +} + +size_t MLASCALL +MlasSQNBitsGemmBatchWorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams +) +{ +#ifdef MLAS_JBLAS + return JblasSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); +#endif + (void)(M); + (void)(N); + (void)(K); + (void)(BatchN); + (void)(DataParams); + return 0; +} + +void MLASCALL +MlasSQNBitsGemmBatchPackedB( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, + void* WorkSpace, + MLAS_THREADPOOL* ThreadPool +) +{ + GetMlasPlatform(); +#ifdef MLAS_JBLAS + if (JblasSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { + // PackedWeight is created by jblas + return; + } +#endif + (void)(M); + (void)(N); + (void)(K); + (void)(BatchN); + (void)(DataParams); + (void)(WorkSpace); + (void)(ThreadPool); +} diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format new file mode 100644 index 0000000000000..84b876706161d --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format @@ -0,0 +1,7 @@ +Language: Cpp +BasedOnStyle: Google +DerivePointerAlignment: false +ColumnLimit: 120 +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SortIncludes: false diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt new file mode 100644 index 0000000000000..5d9c5edf45a96 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt @@ -0,0 +1,33 @@ +cmake_minimum_required(VERSION 3.5) + +project(jblas LANGUAGES CXX VERSION 0.1.0) + +file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp) +file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp) + +add_library(${PROJECT_NAME} INTERFACE) +add_library(${PROJECT_NAME}::${PROJECT_NAME} ALIAS ${PROJECT_NAME}) + +target_include_directories( + ${PROJECT_NAME} INTERFACE + "$" + "$" +) + +if(WIN32) + target_compile_definitions(${PROJECT_NAME} INTERFACE _CRT_SECURE_NO_WARNINGS NOMINMAX) + target_compile_options(${PROJECT_NAME} INTERFACE /wd4068 /wd4849 /wd6262 /wd4702 /wd4100) + #4068 ignore unroll and GCC flags + #4849 ignore collapse + #6262 ignore stack too large + #4702 unreachable code(false warning on constexpr condition) + #4100 unreferenced formal parameter + + target_link_options(${PROJECT_NAME} INTERFACE /STACK:3145728) #Stack requires up to L2 cache size +endif(WIN32) + + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17) diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h new file mode 100644 index 0000000000000..143adb771760b --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h @@ -0,0 +1,303 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include + +#include +#include +#include "xbyak/xbyak.h" +#include "xbyak/xbyak_util.h" + +#define OFFSET(field) offsetof(params, field) + +namespace jblas { + +namespace xbyak { +class JitBase : protected Xbyak::CodeGenerator { + protected: + JitBase(size_t size = 16 * 1024) : CodeGenerator(size) {} + + void load32(const Xbyak::Reg64& reg, const Xbyak::Address& addr) { + xor_(reg, reg); + mov(reg.cvt32(), addr); + } + + void vreg_push(const Xbyak::Reg64& baseaddr) { +#ifdef _WIN32 + for (int i = 0; i < 10; i++) { + movaps(xword[baseaddr + i * 16], Xbyak::Xmm(6 + i)); + } +#endif + } + + void vreg_pop(const Xbyak::Reg64& baseaddr) { +#ifdef _WIN32 + for (int i = 0; i < 10; i++) { + movaps(Xbyak::Xmm(6 + i), xword[baseaddr + i * 16]); + } +#endif + } + + void padto_le(const Xbyak::Reg64& _src, int padding) { + // _src=_src/padding*padding + if (padding == 1) { + return; + } + for (int i = 1; i < 16; i++) { + if ((1 << i) == padding) { + shr(_src, i); + shl(_src, i); + return; + } + } + assert(0); + } + + void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total, + const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { + inLocalLabel(); + lea(_tmp, _total); + sub(_tmp, _pos); + cmp(_tmp, N); + jb(".maskflag"); + cmp(_tmp, 0); + jl(".zeroflag"); + uint64_t allmask = (static_cast(1) << N) - 1; + if (N == 64) { + allmask = static_cast(-1); + } + mov(_tmp, allmask); + kmovq(_msk, _tmp); + jmp(".maskend"); + L(".maskflag"); + mov(_tmp1, 1); + shlx(_tmp1, _tmp1, _tmp); + sub(_tmp1, 1); + kmovq(_msk, _tmp1); + jmp(".maskend"); + L(".zeroflag"); + mov(_tmp1, 0); + kmovq(_msk, _tmp1); + L(".maskend"); + outLocalLabel(); + } + void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Reg64& _total, + const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { + generate_Nbitsmask(_msk, _pos, ptr[_total], _tmp, _tmp1, N); + } +}; + +class JitAvx : protected JitBase { + protected: + static int constexpr VBits = 256; + static int constexpr VecBytes = VBits / 8; + static int constexpr RegCount = 16; + typedef Xbyak::Ymm vreg_t; +}; + +class JitAvx2 : protected JitAvx { + protected: + static int constexpr VBits = 256; + typedef Xbyak::Ymm vreg_t; + void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); } + + void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) { + vpmovzxwd(dst, addr); + vpslld(dst, dst, 16); + } +}; + +class JitAvx512f : protected JitAvx2 { + protected: + static int constexpr VBits = 512; + static int constexpr VecBytes = VBits / 8; + static int constexpr RegCount = 32; + typedef Xbyak::Zmm vreg_t; + + void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); } + + void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) { + vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]); + vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]); + vshuff32x4(src_2regs[0], tmp_2reg[0], tmp_2reg[1], 0 | (1 << 2) | (0 << 4) | (1 << 6)); + vshuff32x4(src_2regs[0], src_2regs[0], src_2regs[0], 0 | (2 << 2) | (1 << 4) | (3 << 6)); + vshuff32x4(src_2regs[1], tmp_2reg[0], tmp_2reg[1], 2 | (3 << 2) | (2 << 4) | (3 << 6)); + vshuff32x4(src_2regs[1], src_2regs[1], src_2regs[1], 0 | (2 << 2) | (1 << 4) | (3 << 6)); + } + + void transpose16x16_4B(Xbyak::Zmm* src, Xbyak::Zmm* tmp, const int N = 16) { + for (int i = 0; i < 8; ++i) { + vpunpckldq(tmp[2 * i + 0], src[2 * i], src[2 * i + 1]); + vpunpckhdq(tmp[2 * i + 1], src[2 * i], src[2 * i + 1]); + } + + for (int i = 0; i < 4; ++i) { + vpunpcklqdq(src[4 * i + 0], tmp[4 * i + 0], tmp[4 * i + 2]); + vpunpckhqdq(src[4 * i + 1], tmp[4 * i + 0], tmp[4 * i + 2]); + vpunpcklqdq(src[4 * i + 2], tmp[4 * i + 1], tmp[4 * i + 3]); + vpunpckhqdq(src[4 * i + 3], tmp[4 * i + 1], tmp[4 * i + 3]); + } + + for (int i = 0; i < 2; ++i) { + vshufi32x4(tmp[8 * i + 0], src[8 * i + 0], src[8 * i + 4], 0x88); + vshufi32x4(tmp[8 * i + 1], src[8 * i + 1], src[8 * i + 5], 0x88); + vshufi32x4(tmp[8 * i + 2], src[8 * i + 2], src[8 * i + 6], 0x88); + vshufi32x4(tmp[8 * i + 3], src[8 * i + 3], src[8 * i + 7], 0x88); + vshufi32x4(tmp[8 * i + 4], src[8 * i + 0], src[8 * i + 4], 0xdd); + vshufi32x4(tmp[8 * i + 5], src[8 * i + 1], src[8 * i + 5], 0xdd); + vshufi32x4(tmp[8 * i + 6], src[8 * i + 2], src[8 * i + 6], 0xdd); + vshufi32x4(tmp[8 * i + 7], src[8 * i + 3], src[8 * i + 7], 0xdd); + } + + // last step and move out + for (int i = 0; i < N; ++i) { + vshufi32x4(src[i], tmp[i % 8], tmp[8 + i % 8], i < 8 ? 0x88 : 0xdd); + } + } + + void interleave_4rows_6regs(Xbyak::Zmm* src_4regs, Xbyak::Zmm* tmp_regs, const Xbyak::Opmask* masks) { + vpunpcklbw(tmp_regs[0], src_4regs[0], src_4regs[1]); + vpunpckhbw(tmp_regs[1], src_4regs[0], src_4regs[1]); + vpunpcklbw(tmp_regs[2], src_4regs[2], src_4regs[3]); + vpunpckhbw(tmp_regs[3], src_4regs[2], src_4regs[3]); + + vpunpcklwd(tmp_regs[4], tmp_regs[0], tmp_regs[2]); + vpunpckhwd(tmp_regs[5], tmp_regs[0], tmp_regs[2]); + vpunpcklwd(tmp_regs[0], tmp_regs[1], tmp_regs[3]); + vpunpckhwd(tmp_regs[2], tmp_regs[1], tmp_regs[3]); + vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (4 << 4) | 4); + vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (4 << 4) | 4); + vmovups(src_4regs[0], tmp_regs[1]); + vshuff32x4(src_4regs[0] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); + vmovups(src_4regs[1], tmp_regs[3]); + vshuff32x4(src_4regs[1] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); + vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (14 << 4) | 14); + vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (14 << 4) | 14); + vmovups(src_4regs[2], tmp_regs[1]); + vshuff32x4(src_4regs[2] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); + vmovups(src_4regs[3], tmp_regs[3]); + vshuff32x4(src_4regs[3] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); + } + + void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { + vpsrld(_fp32, _fp32, 16); + vpmovdw(_bf16, _fp32); + } + + void loadbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Address& addr) { + vpmovzxwd(dst, addr); + vpslld(dst, dst, 16); + } + + void broadcastbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Reg64& tmp, const Xbyak::Address& addr) { + mov(tmp.cvt16(), addr); + shl(tmp.cvt32(), 16); + vpbroadcastd(dst, tmp.cvt32()); + } + + void store_fp32_bf16(const Xbyak::Zmm& _fp32, const Xbyak::Address& _add) { + auto bf16 = Xbyak::Ymm(_fp32.getIdx()); + cvt_fp32_bf16(bf16, _fp32); + vmovups(_add, bf16); + } +}; + +class JitAvx512_bf16 : protected JitAvx512f {}; + +class JitAvx512_fp16 : protected JitAvx512f {}; + +class JitAvx512vnni : protected JitAvx512f { + protected: + void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { + vpdpbusds(x1, x2, op, Xbyak::EvexEncoding); + } +}; + +class JitAvxvnni : protected JitAvx2 { + protected: + void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { + vpdpbusds(x1, x2, op, Xbyak::VexEncoding); + } +}; + +class JitAmxtile : protected JitAvx512f { + public: + struct alignas(64) tileconfig_t { + uint8_t palette_id; + uint8_t reserved[15]; + uint16_t colb[16]; + uint8_t rows[16]; + }; + static int constexpr TileCount = 8; + + typedef long long (*configure_t)(void*); + + static void generate_config(Xbyak::CodeGenerator* g) { + Xbyak::util::StackFrame st(g, 1, 0, 0); + auto& parambase = st.p[0]; + g->ldtilecfg(g->ptr[parambase]); + } + + static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, + int CNum) { + // Filling tile configure structure. Could be done offline. + tc.palette_id = 1; + // Configure C tiles + int t = 0; + for (; t < CNum; ++t) { + tc.rows[t] = static_cast(TILE_M); + tc.colb[t] = static_cast(TILE_N * 4); + } + // Configure A tiles + for (; t < CNum + ANum; ++t) { + tc.rows[t] = static_cast(TILE_M); + tc.colb[t] = static_cast(TILE_K * elesize); + } + // Configure B tile. B effectively has 64 rows and 16 columns. + int kpack = 4 / elesize; + for (; t < CNum + ANum + BNum; ++t) { + tc.rows[t] = static_cast(TILE_K / kpack); + tc.colb[t] = static_cast(TILE_N * 4); + } + } +}; + +class JitAmxbf16 : protected JitAmxtile { + protected: + void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { vcvtneps2bf16(_bf16, _fp32); } +}; + +class JitAmxint8 : protected JitAmxtile { + protected: + template + void _tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3); +}; +template <> +inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { + tdpbssd(x1, x2, x3); +} +template <> +inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { + tdpbsud(x1, x2, x3); +} +template <> +inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { + tdpbusd(x1, x2, x3); +} +template <> +inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { + tdpbuud(x1, x2, x3); +} +} // namespace xbyak +} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h new file mode 100644 index 0000000000000..8ecf3535c17f4 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h @@ -0,0 +1,96 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +enum JBLAS_CODE { + JblasSuccess = 0, + JblasInvalidParam = 1, + JblasInvalidISA = 2, + JblasRuntimeError = 4, + JblasNotSupport = 8, +}; +enum JBLAS_ISA : uint32_t { + JblasNoSIMD = 0, + JblasAVX, + JblasAVX2, + JblasAVX_VNNI, + JblasAVX512F, + JblasAVX512_VNNI, + JblasAMX_BF16, + JblasAMX_INT8, + JblasAVX512_FP16, + JblasAVX512_BF16, +}; +enum class JBLAS_DTYPE : uint32_t { + EleBitsMask = 0xff, + EleBitsUndef = 0, + EleBits4 = 4, + EleBits8 = 8, + EleBits16 = 16, + EleBits32 = 32, + EleBits64 = 64, + TypeMask = 0xff00, + TypeFloat = 0 << 8, + TypeInt = 1 << 8, + SubTypeMask = 0xff0000, + SubType0 = 0 << 16, + SubType1 = 1 << 16, + SubType2 = 2 << 16, + F64 = EleBits64 | TypeFloat, + F32 = EleBits32 | TypeFloat, + F16 = EleBits16 | TypeFloat, + BF16 = EleBits16 | TypeFloat | SubType1, + F8_E4M3 = EleBits8 | TypeFloat, + F8_E5M2 = EleBits8 | TypeFloat | SubType1, + F8_E3M4 = EleBits8 | TypeFloat | SubType2, + S8 = EleBits8 | TypeInt, + U8 = EleBits8 | TypeInt | SubType1, + S4_CLIP = EleBits4 | TypeInt, + S4_FULLRANGE = EleBits4 | TypeInt | SubType1, + F4_E2M1 = EleBits4 | TypeFloat, + F4_BNB = EleBits4 | TypeFloat | SubType1, + F4_NF4 = EleBits4 | TypeFloat | SubType2, + S32 = EleBits32 | TypeInt, + U32 = EleBits32 | TypeInt | SubType1, +}; + +enum JBLAS_LAYOUT { JblasRowMajor = 101, JblasColMajor = 102 }; +enum JBLAS_TRANSPOSE { + JblasNoTrans = 111, + JblasTrans = 112, + JblasConjTrans = 113, +}; +enum JBLAS_ELTWISEOP { + GELU, + SWISH, + TANH, + EXP, + LOW_PRECISION_EXP, + RELU, + LINEAR, +}; + +enum class JBLAS_PROLOGUEB_IDS : uint32_t { + Undef = (uint32_t)-1, + Begin = 0, + NormalBegin = Begin, + WeightPack = NormalBegin, + NormalEnd, + KBlockBegin = NormalEnd, + WeightKBlockS8 = KBlockBegin, + WeightKBlockS4, + WeightKBlockF4, + KBlockEnd, + End, +}; diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h new file mode 100644 index 0000000000000..5cac1080bc610 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h @@ -0,0 +1,277 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "jit_blas.h" +#include "xbyak/xbyak_util.h" + +namespace jblas { + +namespace device { + +struct X64_ISA { + int64_t MMX : 1; // 0 + int64_t SSE : 1; // 1 + int64_t SSE2 : 1; // 2 + int64_t SSE3 : 1; // 3 + int64_t SSSE3 : 1; // 4 + int64_t SSE41 : 1; // 5 + int64_t SSE42 : 1; // 6 + int64_t AVX : 1; // 7 + int64_t F16C : 1; // 8 + int64_t FMA : 1; // 9 + int64_t AVX2 : 1; // 10 + int64_t AVX_VNNI : 1; // 11 + int64_t AVX_VNNI_INT8 : 1; // 12 + int64_t AVX_NE_CONVERT : 1; // 13 + int64_t AVX_IFMA : 1; // 14 + int64_t AVX512F : 1; // 15 + int64_t AVX512BW : 1; // 16 + int64_t AVX512CD : 1; // 17 + int64_t AVX512DQ : 1; // 18 + int64_t AVX512ER : 1; // 19 + int64_t AVX512IFMA52 : 1; // 20 + int64_t AVX512PF : 1; // 21 + int64_t AVX512VL : 1; // 22 + int64_t AVX512VPOPCNTDQ : 1; // 23 + int64_t AVX512_4FMAPS : 1; // 24 + int64_t AVX512_4VNNIW : 1; // 25 + int64_t AVX512_BF16 : 1; // 26 + int64_t AVX512_BITALG : 1; // 27 + int64_t AVX512_VBMI : 1; // 28 + int64_t AVX512_VBMI2 : 1; // 29 + int64_t AVX512_VNNI : 1; // 30 + int64_t AVX512_VP2INTERSECT : 1; // 31 + int64_t AVX512_FP16 : 1; // 32 + int64_t AMX_TILE : 1; // 33 + int64_t AMX_BF16 : 1; // 34 + int64_t AMX_INT8 : 1; // 35 + int64_t AMX_FP16 : 1; // 36 + int64_t AMX_COMPLEX : 1; // 37 + int64_t reserved : (64 - 38); +}; + +class AVX2_Default { + public: + static constexpr bool MMX = 1; + static constexpr bool SSE = 1; + static constexpr bool SSE2 = 1; + static constexpr bool SSE3 = 1; + static constexpr bool SSSE3 = 1; + static constexpr bool SSE41 = 1; + static constexpr bool SSE42 = 1; + static constexpr bool AVX = 1; + static constexpr bool F16C = 1; + static constexpr bool FMA = 1; + static constexpr bool AVX2 = 1; + static constexpr bool AVX_VNNI = 0; + static constexpr bool AVX_VNNI_INT8 = 0; + static constexpr bool AVX_NE_CONVERT = 0; + static constexpr bool AVX_IFMA = 0; + static constexpr bool AVX512F = 0; + static constexpr bool AVX512BW = 0; + static constexpr bool AVX512CD = 0; + static constexpr bool AVX512DQ = 0; + static constexpr bool AVX512ER = 0; + static constexpr bool AVX512IFMA52 = 0; + static constexpr bool AVX512PF = 0; + static constexpr bool AVX512VL = 0; + static constexpr bool AVX512VPOPCNTDQ = 0; + static constexpr bool AVX512_4FMAPS = 0; + static constexpr bool AVX512_4VNNIW = 0; + static constexpr bool AVX512_BF16 = 0; + static constexpr bool AVX512_BITALG = 0; + static constexpr bool AVX512_VBMI = 0; + static constexpr bool AVX512_VBMI2 = 0; + static constexpr bool AVX512_VNNI = 0; + static constexpr bool AVX512_VP2INTERSECT = 0; + static constexpr bool AVX512_FP16 = 0; + static constexpr bool AMX_TILE = 0; + static constexpr bool AMX_BF16 = 0; + static constexpr bool AMX_INT8 = 0; + static constexpr bool AMX_FP16 = 0; + static constexpr bool AMX_COMPLEX = 0; +}; + +class AVX512_VNNI_Default { + public: + static constexpr bool MMX = 1; + static constexpr bool SSE = 1; + static constexpr bool SSE2 = 1; + static constexpr bool SSE3 = 1; + static constexpr bool SSSE3 = 1; + static constexpr bool SSE41 = 1; + static constexpr bool SSE42 = 1; + static constexpr bool AVX = 1; + static constexpr bool F16C = 1; + static constexpr bool FMA = 1; + static constexpr bool AVX2 = 1; + static constexpr bool AVX_VNNI = 0; + static constexpr bool AVX_VNNI_INT8 = 0; + static constexpr bool AVX_NE_CONVERT = 0; + static constexpr bool AVX_IFMA = 0; + static constexpr bool AVX512F = 1; + static constexpr bool AVX512BW = 1; + static constexpr bool AVX512CD = 1; + static constexpr bool AVX512DQ = 1; + static constexpr bool AVX512ER = 0; + static constexpr bool AVX512IFMA52 = 0; + static constexpr bool AVX512PF = 0; + static constexpr bool AVX512VL = 1; + static constexpr bool AVX512VPOPCNTDQ = 0; + static constexpr bool AVX512_4FMAPS = 0; + static constexpr bool AVX512_4VNNIW = 0; + static constexpr bool AVX512_BF16 = 0; + static constexpr bool AVX512_BITALG = 0; + static constexpr bool AVX512_VBMI = 0; + static constexpr bool AVX512_VBMI2 = 0; + static constexpr bool AVX512_VNNI = 1; + static constexpr bool AVX512_VP2INTERSECT = 0; + static constexpr bool AVX512_FP16 = 0; + static constexpr bool AMX_TILE = 0; + static constexpr bool AMX_BF16 = 0; + static constexpr bool AMX_INT8 = 0; + static constexpr bool AMX_FP16 = 0; + static constexpr bool AMX_COMPLEX = 0; +}; + +class SapphireRapids { + public: + static constexpr bool MMX = 1; + static constexpr bool SSE = 1; + static constexpr bool SSE2 = 1; + static constexpr bool SSE3 = 1; + static constexpr bool SSSE3 = 1; + static constexpr bool SSE41 = 1; + static constexpr bool SSE42 = 1; + static constexpr bool AVX = 1; + static constexpr bool F16C = 1; + static constexpr bool FMA = 1; + static constexpr bool AVX2 = 1; + static constexpr bool AVX_VNNI = 0; + static constexpr bool AVX_VNNI_INT8 = 0; + static constexpr bool AVX_NE_CONVERT = 0; + static constexpr bool AVX_IFMA = 0; + static constexpr bool AVX512F = 1; + static constexpr bool AVX512BW = 1; + static constexpr bool AVX512CD = 1; + static constexpr bool AVX512DQ = 1; + static constexpr bool AVX512ER = 0; + static constexpr bool AVX512IFMA52 = 0; + static constexpr bool AVX512PF = 0; + static constexpr bool AVX512VL = 1; + static constexpr bool AVX512VPOPCNTDQ = 0; + static constexpr bool AVX512_4FMAPS = 0; + static constexpr bool AVX512_4VNNIW = 0; + static constexpr bool AVX512_BF16 = 0; + static constexpr bool AVX512_BITALG = 0; + static constexpr bool AVX512_VBMI = 0; + static constexpr bool AVX512_VBMI2 = 0; + static constexpr bool AVX512_VNNI = 1; + static constexpr bool AVX512_VP2INTERSECT = 0; + static constexpr bool AVX512_FP16 = 0; + static constexpr bool AMX_TILE = 1; + static constexpr bool AMX_BF16 = 1; + static constexpr bool AMX_INT8 = 1; + static constexpr bool AMX_FP16 = 0; + static constexpr bool AMX_COMPLEX = 0; +}; + +template +class isa_base { + public: + static bool constexpr avx = ISA_T >= JblasAVX; + static bool constexpr avx2 = ISA_T >= JblasAVX2; + static bool constexpr avx512f = ISA_T >= JblasAVX512F; + static bool constexpr avx512_vnni = ISA_T >= JblasAVX512_VNNI; + static bool constexpr avx512_fp16 = ISA_T >= JblasAVX512_FP16; + static bool constexpr amx_bf16 = ISA_T >= JblasAMX_BF16; + static bool constexpr amx_int8 = ISA_T >= JblasAMX_INT8; +}; + +class CpuDevice { + public: + inline void setThreads(int _nth) { + if (_nth <= 0) { + numthreads = numcores; + } else { + numthreads = std::min(numcores, _nth); + } + } + inline int getThreads() { return numthreads; } + inline int getCores() { return numcores; } + inline uint32_t getL2CacheSize() { return L2Cache; } + inline uint32_t getL1CacheSize() { return L1Cache; } + inline bool AVX() { return mHasAVX; } + inline bool AVX2() { return mHasAVX2; } + inline bool AVX_VNNI() { return mHasAVX_VNNI; } + inline bool AVX512F() { return mHasAVX512F; } + inline bool AVX512_VNNI() { return mHasAVX512_VNNI; } + inline bool AMX_INT8() { return mHasAMX_INT8; } + inline bool AMX_BF16() { return mHasAMX_BF16; } + inline bool AVX512_BF16() { return mHasAVX512_BF16; } + inline bool AVX512_FP16() { return mHasAVX512_FP16; } +#define ADD_FLAG(isa) mHas##isa = _cpu.has(_cpu.t##isa) + CpuDevice() { + static Xbyak::util::Cpu _cpu; + L1Cache = _cpu.getDataCacheSize(0); + L2Cache = _cpu.getDataCacheSize(1); + ADD_FLAG(AVX); + ADD_FLAG(AVX2); + ADD_FLAG(AVX512F); + ADD_FLAG(AVX512_VNNI); + ADD_FLAG(AVX_VNNI); + ADD_FLAG(AMX_BF16); + ADD_FLAG(AMX_INT8); + ADD_FLAG(AVX512_BF16); + ADD_FLAG(AVX512_FP16); + numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); + numthreads = numcores; + } + + static CpuDevice* getInstance() { + static CpuDevice instance; + return &instance; + } + + void print() { + printf( + "AVX:%d AVX2:%d AVX512F:%d AVX_VNNI:%d AVX512_VNNI:%d AMX_INT8:%d AMX_BF16:%d AVX512_BF16:%d AVX512_FP16:%d\n", + mHasAVX, mHasAVX2, mHasAVX512F, mHasAVX_VNNI, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512_BF16, + mHasAVX512_FP16); + } +#undef ADD_FLAG + + protected: + uint32_t L2Cache, L1Cache; + bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512_BF16, + mHasAVX512_FP16; + int numcores; + int numthreads; +}; + +#define GetCPUDevice() auto _cd = jblas::device::CpuDevice::getInstance(); + +class CpuBase { + public: + CpuBase() { + GetCPUDevice(); + mL2Cache = _cd->getL2CacheSize(); + mL1Cache = _cd->getL1CacheSize(); + mNumThreads = _cd->getThreads(); + } + size_t mL2Cache, mL1Cache; + int mNumThreads; +}; +} // namespace device +} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h new file mode 100644 index 0000000000000..ceb7a545092d8 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h @@ -0,0 +1,329 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include + +#include "jit_base.h" +#include "jit_blas.h" +#include "jit_blas_utils.h" +#include "kernel_wrapper.h" + +namespace jblas { +namespace epilogue { +namespace gemm { + +template +class AccumulatorWriteBack { + public: + using SType = _SRC_T; + using DType = _DST_T; + struct Param { + DType* C; + int ldc; + void* elt_const_v; + }; + + template + JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize, Eltops... ops) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + bool constexpr Valid = !std::is_same::value || std::is_same::value; + static_assert(Valid, "fp32 to bf16 conversion only."); + if constexpr (std::is_same::value) { + return kernel::wrapper::Memcpy2DFp32CvtBf16::template forward( + const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); + } else if constexpr (std::is_same, std::tuple>::value) { + return kernel::wrapper::Memcpy2DFp16CvtFp32::template forward( + const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); + } else if constexpr (sizeof(SType) == sizeof(DType)) { + return kernel::wrapper::Memcpy2D::template forward(cacheptr, cptr, M, N, cachestep, + _param.ldc, _param.elt_const_v, ops...); + } else { + assert(false); + } + } +}; + +template +class CustomAccumulatorWriteBackWithEltop { + public: + struct Param { + _DST_T* C; + int ldc; + void* elt_const_v; + }; + JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + if constexpr (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) { + return kernel::wrapper::Memcpy2D::template forward1(cacheptr, cptr, M, N, cachestep, + _param.ldc, _param.elt_const_v); + } else { + assert(false); + } + } +}; +template +using AccumulatorWriteBackFp32 = AccumulatorWriteBack; +template +using AccumulatorWriteBackInt32 = AccumulatorWriteBack; +template +using AccumulatorWriteBackBf16 = AccumulatorWriteBack; +template +using AccumulatorWriteBackFp16 = AccumulatorWriteBack; +template +using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack; +template +using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack; + +template +using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop; + +template +using AccumulatorWriteBackWithSwishFp32 = CustomAccumulatorWriteBackWithEltop; + +template +class AlphaBetaProcessFp32 { + public: + struct Param { + float *C, *D; + int ldc, ldd; + float alpha, beta; + }; + + JBLAS_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto DOffset = M_offset * _param.ldd + N_offset; + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + auto dptr = _param.D + DOffset; + return kernel::wrapper::AlphaBetaF32F32::template forward(_param.alpha, cacheptr, cachestep, _param.beta, + dptr, _param.ldd, cptr, _param.ldc, M, N); + } +}; + +template +class CompFp32BlockEpilogue { + public: + struct Param { + void* scales; + JBLAS_DTYPE scaledtype; + int ldsb; + int8_t* zps = nullptr; + float* reduce = nullptr; + int ldra; + }; + JBLAS_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, + const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, + size_t cachesize) { + auto ret = JblasNotSupport; + if (_param.scaledtype == JBLAS_DTYPE::F32) { + ret = kernel::wrapper::CompFp32BlockScale::template forward( + reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, + cachestep, M, N); + assert(ret == JblasSuccess); + if (_param.zps != nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::forward_wei( + dstptr, cachestep, M, N, _param.zps + K_offset * _param.ldsb + N_offset, + reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, _param.ldra, + _param.reduce + M_offset * _param.ldra + K_offset); + } + assert(ret == JblasSuccess); + return ret; + } else if (_param.scaledtype == JBLAS_DTYPE::BF16) { + ret = kernel::wrapper::CompFp32BlockScale::template forward( + reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, + cachestep, M, N); + assert(_param.zps == nullptr); + assert(ret == JblasSuccess); + return ret; + } + return JblasNotSupport; + } +}; + +template +class DequantInt32ToFp32 { + public: + struct Param { + float* C; + int ldc; + int ldsa; + float* scalesA; + float* scalesB; + }; + JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + return kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, + _param.scalesA + M_offset * _param.ldsa, _param.ldsa, + _param.scalesB + N_offset); + } +}; + +template +class CompInt8BlockEpilogue { + public: + struct Param { + void* scalesB; + JBLAS_DTYPE scaleBdtype; + int ldsb; + float* scalesA; + int ldsa; + // optional if A asym + uint8_t* zpA = nullptr; + void* reduceB = nullptr; + JBLAS_DTYPE reduceBdtype = JBLAS_DTYPE::F32; + // optional if B asym + int8_t* zpB = nullptr; + float* reduceA = nullptr; + int K = 1; + }; + JBLAS_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, + const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, + size_t cachesize) { + JBLAS_CODE ret = JblasNotSupport; + float* scab = nullptr; + size_t ScaleBTmpSize = N * sizeof(float); + size_t ReduceBTmpSize = N * sizeof(float); + assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize)); + if (_param.scaleBdtype == JBLAS_DTYPE::BF16) { + auto scache = reinterpret_cast(tmpcache); + ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( + reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N, + false); + assert(ret == JblasSuccess); + scab = scache; + } else if (_param.scaleBdtype == JBLAS_DTYPE::F32) { + scab = reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb; + } + float* redb = nullptr; + if (_param.reduceB) { + if (_param.reduceBdtype == JBLAS_DTYPE::BF16) { + auto rcache = reinterpret_cast(reinterpret_cast(tmpcache) + ScaleBTmpSize); + ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( + reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N, + false); + assert(ret == JblasSuccess); + redb = rcache; + } else if (_param.reduceBdtype == JBLAS_DTYPE::F32) { + redb = reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb; + } + } + ret = kernel::wrapper::DequanS32Fp32::template forward( + srcptr, cachestep, reinterpret_cast(const_cast(srcptr)), cachestep, M, N, + _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab); + assert(ret == JblasSuccess); + ret = kernel::wrapper::AccumulateFp32::template forward(reinterpret_cast(srcptr), cachestep, + dstptr, cachestep, M, N); + assert(ret == JblasSuccess); + + if (_param.zpA == nullptr) { + if (_param.zpB == nullptr) { + return ret; + } else { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( + dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, scab, _param.ldsa, + _param.reduceA + M_offset * _param.ldsa + K_offset); + } + } else { + if (_param.zpB == nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( + dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, + _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, redb); + } else { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( + dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, + _param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, scab, + _param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa + K_offset, redb); + } + } + return ret; + } +}; + +template +class ZpDequantInt32ToFp32 { + public: + struct Param { + // necessary + float* C; + int ldc; + int ldsa; + float* scalesA; + float* scalesB; + // optional if A asym + uint8_t* zpA = nullptr; + float* reduceB = nullptr; + // optional if B asym + int8_t* zpB = nullptr; + float* reduceA = nullptr; + int K = 1; + }; + JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + auto ret = kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, + _param.scalesA + M_offset * _param.ldsa, + _param.ldsa, _param.scalesB + N_offset); + if (ret != JblasSuccess) { + return ret; + } + if (_param.zpA == nullptr && _param.zpB == nullptr) { + return ret; + } else if (_param.zpA != nullptr && _param.zpB == nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( + cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa, + _param.ldsa, _param.reduceB + N_offset); + } else if (_param.zpA == nullptr && _param.zpB != nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( + cptr, _param.ldc, M, N, _param.zpB + N_offset, _param.scalesB + N_offset, _param.ldsa, + _param.reduceA + M_offset * _param.ldsa); + } else { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( + cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.zpB + N_offset, + _param.scalesA + M_offset * _param.ldsa, _param.scalesB + N_offset, _param.ldsa, _param.K, + _param.reduceA + M_offset * _param.ldsa, _param.reduceB + N_offset); + } + return ret; + } +}; + +template +class AlphaBetaProcessS32U8 { + public: + struct Param { + uint8_t* C; + int ldc; + float alpha; + float scaleAcc, scaleC; + int zpC; + }; + + JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + return kernel::wrapper::QuanOutS32U32::template forward(_param.alpha, cacheptr, cachestep, cptr, _param.ldc, + M, N, _param.scaleAcc, _param.scaleC, _param.zpC); + } +}; + +} // namespace gemm +} // namespace epilogue +} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h new file mode 100644 index 0000000000000..364da9223940f --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h @@ -0,0 +1,2699 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include + +#include "jit_blas_utils.h" +#include "jit_base.h" + +namespace jblas { +namespace gemm { +enum class CompType : uint32_t { + COMP_FP32 = 0, + COMP_BF16_FP32 = 1, + COMP_FP16_FP16 = 2, + COMP_INT_START = 3, + COMP_INT8_US_INT32 = COMP_INT_START, + COMP_INT8_UU_INT32 = 4, + COMP_INT8_SS_INT32 = 5, + COMP_INT8_SU_INT32 = 6, + COMP_INT16_SS_INT32 = 7, + COMP_INT8_US_FP32 = 8, + COMP_INT8_UU_FP32 = 9, + COMP_INT8_SS_FP32 = 10, + COMP_INT8_SU_FP32 = 11, +}; + +class CoreAttr { + public: + // INT32=LSB|**8bits:NTile**||**8bits:PackRow**||**8bits:CompType**||**8bits:Reserve**| + static uint32_t constexpr NTILE_MASK = 0xff, NTILE_SHIFT = 0, PACKROW_MASK = 0xff00, PACKROW_SHIFT = 8, + COMP_MASK = 0xff0000, COMP_SHIFT = 16, ISA_MASK = 0xff000000, ISA_SHIFT = 24; + + static inline uint32_t get_mask_val(uint32_t raw, uint32_t mask, uint32_t shift) { return (raw & mask) >> shift; } + static constexpr uint32_t make_core_id(uint32_t NTile, uint32_t PackRow, uint32_t CompType, uint32_t ISA) { + return (NTile << NTILE_SHIFT) | (PackRow << PACKROW_SHIFT) | (CompType << COMP_SHIFT) | (ISA << ISA_SHIFT); + } + + static void parse_id(uint32_t id, uint32_t* vals) { + vals[0] = get_mask_val(id, NTILE_MASK, NTILE_SHIFT); + vals[1] = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); + vals[2] = get_mask_val(id, COMP_MASK, COMP_SHIFT); + vals[3] = get_mask_val(id, ISA_MASK, ISA_SHIFT); + } + + static const char* to_str(uint32_t id) { + static char tmp[128]; + uint32_t vals[4]; + parse_id(id, vals); + sprintf(tmp, "N%d_PACK%d_COMP%d_ISA%d", vals[0], vals[1], vals[2], vals[3]); + return tmp; + } + + static inline size_t get_bsize(uint32_t id) { + auto packrow = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); + return size_t(4 / packrow); + } +}; + +namespace code { + +template +class Avx2N8P1 : protected jblas::xbyak::JitAvx2 { + public: + static int constexpr RegLen = 8, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX2; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; + typedef float AType; + typedef float BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { + public: + static int constexpr RegLen = 16, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; + typedef float AType; + typedef float BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512fp16N32P1 : protected jblas::xbyak::JitAvx512_fp16 { + public: + static int constexpr RegLen = 32, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_FP16; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP16_FP16; + typedef utils::fp16 AType; + typedef utils::fp16 BType; + typedef utils::fp16 CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastw(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastw(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512bf16N16P2 : protected jblas::xbyak::JitAvx512_bf16 { + public: + static int constexpr RegLen = 16, PackRow = 2; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 2; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_BF16; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; + typedef utils::bf16 AType; + typedef utils::bf16 BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { + public: + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; + typedef uint8_t AType; + typedef int8_t BType; + typedef int32_t CType; + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + + protected: + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _kunroll) { + for (int kk = 0; kk < _kunroll; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { + public: + static int constexpr RegLen = 8, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX_VNNI; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; + typedef uint8_t AType; + typedef int8_t BType; + typedef int32_t CType; + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + protected: + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _kunroll) { + for (int kk = 0; kk < _kunroll; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Amxbf16N16P2 : protected jblas::xbyak::JitAmxbf16 { + public: + static int constexpr RegLen = 16, PackRow = 2; + static_assert(_NTILE % RegLen == 0); + static_assert(_MTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen; + static_assert(NRegs * MRegs + 2 <= TileCount); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 32; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_BF16; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; + typedef utils::bf16 AType; + typedef utils::bf16 BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + void* workspace; + }; + typedef long long (*func_t)(params*); + + int TmpRegCount = RegCount; + int TmpReg = 0; + int CTileCount = 0, ATileCount = 0, BTileCount = 0; + int CTile = 0, ATile = 0, BTile = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CTileCount = NRegs * MRegs; + auto tile_re = TileCount - CTileCount; + if (tile_re - 1 >= NRegs) { + BTileCount = NRegs; + ATileCount = tile_re - BTileCount; + } else if (tile_re - 1 >= MRegs) { + ATileCount = MRegs; + BTileCount = tile_re - ATileCount; + } else { + ATileCount = 1; + BTileCount = tile_re - ATileCount; + } + CTile = 0; + ATile = CTile + CTileCount; + BTile = ATile + ATileCount; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int kunrll) { + auto& reg_Bstride = reg_tmp1; + mov(reg_Bstride, NTILE * 4); + int mtiles = _mtile / RegLen; + + for (int kk = 0; kk < kunrll; kk++) { + auto& reg_Atmp = reg_tmp2; + if (mtiles == 1) { + reg_Atmp = reg_matAptr; + } else { + mov(reg_Atmp, reg_matAptr); + } + if (BTileCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + } + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } else { + if (ATileCount == mtiles) { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + for (int mm = 0; mm < mtiles; mm++) { + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); + } + } + } else { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < CTileCount; i++) { + tilezero(Xbyak::Tmm(CTile + i)); + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + int mtnum = _mtile / 16; + for (int mm = 0; mm < mtnum; mm++) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); + } + if (mm != mtnum - 1) { + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + } + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_tmp, dword[parambase + OFFSET(workspace)]); + mov(reg_tmp1, NTILE * 4); + for (int mm = 0; mm < MRegs; mm++) { + for (int i = 0; i < NRegs; i++) { + tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); + } + } + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + int zunroll = TmpRegCount / NRegs; + for (int i = 0; i < _mtile; i += zunroll) { + int m_re = utils::remainsize(i, _mtile, zunroll); + for (int im = 0; im < m_re; im++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + } + outLocalLabel(); + } +}; + +template +class Amxint8N16P4 : protected jblas::xbyak::JitAmxint8 { + public: + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static_assert(_MTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen; + static_assert(NRegs * MRegs + 2 <= TileCount); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 64; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_INT8; + static uint32_t constexpr COMPUTE = + (uint32_t)(std::is_same_v + ? std::is_same_v ? CompType::COMP_INT8_SS_INT32 : CompType::COMP_INT8_SU_INT32 + : std::is_same_v ? CompType::COMP_INT8_US_INT32 + : CompType::COMP_INT8_UU_INT32); + using AType = AT; + using BType = BT; + typedef int32_t CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + void* workspace; + }; + typedef long long (*func_t)(params*); + + int TmpRegCount = RegCount; + int TmpReg = 0; + int CTileCount = 0, ATileCount = 0, BTileCount = 0; + int CTile = 0, ATile = 0, BTile = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CTileCount = NRegs * MRegs; + auto tile_re = TileCount - CTileCount; + if (tile_re - 1 >= NRegs) { + BTileCount = NRegs; + ATileCount = tile_re - BTileCount; + } else if (tile_re - 1 >= MRegs) { + ATileCount = MRegs; + BTileCount = tile_re - ATileCount; + } else { + ATileCount = 1; + BTileCount = tile_re - ATileCount; + } + CTile = 0; + ATile = CTile + CTileCount; + BTile = ATile + ATileCount; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int kunrll) { + auto& reg_Bstride = reg_tmp1; + mov(reg_Bstride, NTILE * 4); + int mtiles = _mtile / RegLen; + + for (int kk = 0; kk < kunrll; kk++) { + auto& reg_Atmp = reg_tmp2; + if (mtiles == 1) { + reg_Atmp = reg_matAptr; + } else { + mov(reg_Atmp, reg_matAptr); + } + if (BTileCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + } + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } else { + if (ATileCount == mtiles) { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + for (int mm = 0; mm < mtiles; mm++) { + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); + } + } + } else { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < CTileCount; i++) { + tilezero(Xbyak::Tmm(CTile + i)); + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + int mtnum = _mtile / 16; + for (int mm = 0; mm < mtnum; mm++) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); + } + if (mm != mtnum - 1) { + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + } + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_tmp, dword[parambase + OFFSET(workspace)]); + mov(reg_tmp1, NTILE * 4); + for (int mm = 0; mm < MRegs; mm++) { + for (int i = 0; i < NRegs; i++) { + tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); + } + } + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + int zunroll = TmpRegCount / NRegs; + for (int i = 0; i < _mtile; i += zunroll) { + int m_re = utils::remainsize(i, _mtile, zunroll); + for (int im = 0; im < m_re; im++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + } + outLocalLabel(); + } +}; +template +using Amxint8N16P4US = Amxint8N16P4; + +template +using Amxint8N16P4SS = Amxint8N16P4; + +class AmxConfigure : protected jblas::xbyak::JitAmxtile { + public: + typedef long long (*func_t)(tileconfig_t*); + + static void configure(int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, int CNum) { + static AmxConfigure code; + tileconfig_t cfg; + std::memset(&cfg, 0, sizeof(cfg)); + configure_tiles(cfg, TILE_M, TILE_N, TILE_K, elesize, ANum, BNum, CNum); + code.mKernel(&cfg); + } + + protected: + AmxConfigure() { + generate_config(this); + mKernel = getCode(); + } + + func_t mKernel = nullptr; +}; + +namespace kblock { +// optimize for kblock gemm, each block size in k dimension has dequant operation +// all accumulators use fp32 dtype. +template +class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { + public: + static int constexpr RegLen = 16, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; + typedef float AType; + typedef float BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { + public: + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1 - NRegs) / (NRegs * 2) : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_FP32; + typedef uint8_t AType; + typedef int8_t BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + uint8_t* zpA; + float* scaleA; + int ldsa; + float* scaleB; + float* reduceB; + int ldsb; + int k; + int n; + int kblock; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, CF32Reg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_iterkb; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_tmp4; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = NRegs; + CReg = 0; + CF32Reg = CReg + CRegCount; + BReg = CF32Reg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg < RegCount); + TmpRegCount = RegCount - TmpReg; + assert(TmpRegCount >= 1); + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 13, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_iterkb = st.t[12]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_tmp4 = st.t[11]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + xor_(reg_iterkb, reg_iterkb); + L(".kloop"); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j)); + } + } + xor_(reg_tmp2, reg_tmp2); + load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]); + mov(reg_tmp, reg_tmp3); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kbloop", T_NEAR); + L(".unkbloop"); + generate_fma(_mtile, KUNROLL, reg_tmp1); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_tmp2, KUNROLL * KTILE); + cmp(reg_tmp2, reg_tmp); + jb(".unkbloop"); + cmp(reg_tmp, reg_tmp3); + jge(".kend", T_NEAR); + L(".kbloop"); + generate_fma(_mtile, 1, reg_tmp1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_tmp2, 1 * KTILE); + cmp(reg_tmp2, reg_tmp3); + jb(".kbloop"); + L(".kend"); + add(reg_iterk, reg_tmp2); + generate_f32_accumulate(_mtile); + generate_zp_correction(_mtile); + inc(reg_iterkb); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile, Xbyak::Reg64& tmp) { + for (int kk = 0; kk < _ktile; kk++) { + lea(tmp, ptr[reg_matAptr + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CF32Reg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void generate_f32_accumulate(int _mtile) { + load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + + mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(float)]); + load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]); + for (int i = 0; i < NRegs; i++) { + vmovups(Xbyak::Zmm(BReg + i), ptr[reg_tmp2 + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(Xbyak::Zmm(TmpReg), ptr[reg_tmp]); + lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); + for (int i = 0; i < NRegs; i++) { + vcvtdq2ps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); + vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(TmpReg), Xbyak::Zmm(BReg + i)); + vmulps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(AReg)); + vaddps(Xbyak::Zmm(CF32Reg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); + } + } + } + + void generate_zp_correction(int _mtile) { + load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp1, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + auto& reg_redB = reg_tmp2; + + mov(reg_tmp, ptr[parambase + OFFSET(zpA)]); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(AType)]); + auto& reg_zpA = reg_tmp; + + mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]); + auto& reg_scaleA = reg_tmp1; + + load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]); + auto& reg_ldsa = reg_tmp3; + for (int i = 0; i < NRegs; i++) { + vmovups(Xbyak::Zmm(BReg + i), ptr[reg_redB + i * VecBytes]); + } + + for (int i = 0; i < _mtile; i++) { + vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); + vpmovzxbd(Xbyak::Zmm(AReg), Xbyak::Xmm(AReg)); + vcvtdq2ps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg)); + vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg), zword_b[reg_scaleA]); + for (int j = 0; j < NRegs; j++) { + vmulps(Xbyak::Zmm(CReg + j), Xbyak::Zmm(AReg), Xbyak::Zmm(BReg + j)); + vsubps(Xbyak::Zmm(CF32Reg + i * NRegs + j), Xbyak::Zmm(CReg + j)); + } + lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); + lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); + } + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CF32Reg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +} // namespace kblock +} // namespace code +template