diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 7494035e4784e..23ded3bfc1e68 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -87,6 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
+option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON)
option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
@@ -1166,6 +1167,17 @@ if (onnxruntime_USE_DNNL)
add_compile_definitions(DNNL_OPENMP)
endif()
+set(USE_JBLAS FALSE)
+if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD)
+ if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64")
+ add_compile_definitions(MLAS_JBLAS)
+ set(USE_JBLAS TRUE)
+ elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64")
+ add_compile_definitions(MLAS_JBLAS)
+ set(USE_JBLAS TRUE)
+ endif()
+endif()
+
# TVM EP
if (onnxruntime_USE_TVM)
if (NOT TARGET tvm)
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 26e4380af4c23..bee83ff07c74b 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -45,6 +45,15 @@ endif()
set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)
+function(add_jblas)
+ add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
+ target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
+ target_sources(onnxruntime_mlas PRIVATE
+ ${MLAS_SRC_DIR}/jblas_gemm.cpp
+ )
+  set_target_properties(onnxruntime_mlas PROPERTIES COMPILE_WARNING_AS_ERROR OFF)
+endfunction()
+
#TODO: set MASM flags properly
function(setup_mlas_source_for_windows)
@@ -200,7 +209,6 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/q4gemm_avx512.cpp
)
endif()
-
else()
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
@@ -566,7 +574,7 @@ else()
)
set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
- endif()
+ endif()
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs})
@@ -604,6 +612,10 @@ else()
target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs})
endif()
+if(USE_JBLAS)
+ add_jblas()
+endif()
+
foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR})
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index e5b43ddba8cc7..131db5d8d9b37 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2824,6 +2824,8 @@ This version of the operator has been available since version 1 of the 'com.micr
size of each input feature
N : int (required)
size of each output feature
+accuracy_level : int
+The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) (default unset). It is used to control how input A is quantized or downcast internally while doing computation, for example: 0 means input A will not be quantized or downcast while doing computation. 4 means input A can be quantized with the same block_size to int8 internally from type T1.
bits : int (required)
number of bits used for weight quantization (default 4)
block_size : int (required)
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index 320a05bb97dac..b060d500c6484 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -20,30 +20,158 @@ class MatMulNBits final : public OpKernel {
K_{narrow<size_t>(info.GetAttr<int64_t>("K"))},
N_{narrow<size_t>(info.GetAttr<int64_t>("N"))},
block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))},
- nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))} {
+ nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
+ accuracy_level_{info.GetAttr<int64_t>("accuracy_level")} {
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
+ is_asym_ = info.GetInputCount() >= 4;
+ const Tensor* tensor_B = nullptr;
+ const Tensor* tensor_scale = nullptr;
+ const Tensor* tensor_zero_point = nullptr;
+ bool B_constant = info.TryGetConstantInput(1, &tensor_B);
+ bool scale_constant = info.TryGetConstantInput(2, &tensor_scale);
+ bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point);
+ all_constant_ = B_constant && scale_constant;
+ all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_;
}
Status Compute(OpKernelContext* context) const override;
+ Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ /*out*/ bool& is_packed,
+ /*out*/ PrePackedWeights* prepacked_weights) override;
+
+ Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
+ /*out*/ bool& used_shared_buffers) override;
+
private:
const size_t K_;
const size_t N_;
const size_t block_size_;
const size_t nbits_;
+ const int64_t accuracy_level_;
const bool column_wise_quant_{true};
+ IAllocatorUniquePtr<void> packed_b_;
+ size_t packed_b_size_{0};
+ bool is_asym_{false};
+ bool all_constant_{false};
};
+Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
+ /*out*/ bool& is_packed,
+ /*out*/ PrePackedWeights* prepacked_weights) {
+ is_packed = false;
+ if (!all_constant_) {
+ return Status::OK();
+ }
+ auto compt_type = static_cast<MLAS_SQNBIT_COMPUTE_TYPE>(accuracy_level_);
+ MLAS_THREADPOOL* pool = NULL;
+ if (input_idx == 1) {
+ packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type);
+ if (packed_b_size_ == 0) return Status::OK();
+ auto qptr = tensor.Data<uint8_t>();
+ packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
+ if (packed_b_ == nullptr) {
+ return Status::OK();
+ }
+ std::memset(packed_b_.get(), 0, packed_b_size_);
+ MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
+ is_asym_, false, compt_type, pool);
+ if (prepacked_weights) {
+ prepacked_weights->buffers_.push_back(std::move(packed_b_));
+ prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
+ }
+ is_packed = true;
+ }
+ if (input_idx == 2 && packed_b_ != nullptr) {
+ auto sptr = tensor.Data<float>();
+ MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
+ is_asym_, !is_asym_, compt_type, pool);
+ if (prepacked_weights) {
+ prepacked_weights->buffers_.push_back(std::move(packed_b_));
+ prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
+ }
+ is_packed = true;
+ }
+ if (input_idx == 3 && packed_b_ != nullptr) {
+ auto zptr = tensor.Data<uint8_t>();
+ MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast<int>(nbits_),
+ is_asym_, is_asym_, compt_type, pool);
+ if (prepacked_weights) {
+ prepacked_weights->buffers_.push_back(std::move(packed_b_));
+ prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
+ }
+ is_packed = true;
+ }
+
+ return Status::OK();
+}
+
+Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
+ /*out*/ bool& used_shared_buffers) {
+ used_shared_buffers = false;
+ // Pack three tensors into one buffer
+ if (input_idx == 1) {
+ used_shared_buffers = true;
+ packed_b_ = std::move(prepacked_buffers[0]);
+ }
+ if (input_idx == 2) {
+ used_shared_buffers = true;
+ packed_b_ = std::move(prepacked_buffers[0]);
+ }
+ if (input_idx == 3) {
+ used_shared_buffers = true;
+ packed_b_ = std::move(prepacked_buffers[0]);
+ }
+ return Status::OK();
+}
+
Status MatMulNBits::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
const Tensor* a = ctx->Input<Tensor>(0);
+ const auto* a_data = a->Data<float>();
+
+ if (packed_b_.get()) {
+ TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});
+
+ MatMulComputeHelper helper;
+ ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
+
+ Tensor* y = ctx->Output(0, helper.OutputShape());
+
+ // Bail out early if the output is going to be empty
+ if (y->Shape().Size() == 0) return Status::OK();
+
+ auto* y_data = y->MutableData<float>();
+
+ const size_t max_len = helper.OutputOffsets().size();
+ const size_t M = static_cast<size_t>(helper.M());
+ const size_t N = static_cast<size_t>(helper.N());
+ const size_t K = static_cast<size_t>(helper.K());
+ const size_t lda = helper.Lda(false);
+ std::vector<MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS> gemm_params(max_len);
+ AllocatorPtr allocator;
+ auto status = ctx->GetTempSpaceAllocator(&allocator);
+ ORT_RETURN_IF_ERROR(status);
+ for (size_t i = 0; i < max_len; i++) {
+ gemm_params[i].A = a_data + helper.LeftOffsets()[i];
+ gemm_params[i].lda = lda;
+ gemm_params[i].B = packed_b_.get();
+ gemm_params[i].C = y_data + helper.OutputOffsets()[i];
+ gemm_params[i].ldc = N;
+ }
+ auto ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data());
+ // workspace for activation processing (dynamic quantization and others)
+ auto ws_ptr = IAllocator::MakeUniquePtr<int8_t>(allocator, ws_size);
+ MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(),
+ thread_pool);
+ return Status::OK();
+ }
+
const Tensor* b = ctx->Input<Tensor>(1);
const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);
-
- const auto* a_data = a->Data();
const uint8_t* b_data = b->Data<uint8_t>();
const auto* scales_data = scales->Data<float>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>();
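The three PrePack calls above arrive one tensor at a time, so the jblas blob is filled incrementally and only finalized on the last constant input. A minimal sketch of the resulting call order for an asymmetric 4-bit weight (the helper name is hypothetical and buffer allocation is elided):

    // Sketch only: mirrors MatMulNBits::PrePack for inputs 1 (B), 2 (scales), 3 (zero points).
    void PackNBitsWeight(void* packed, const uint8_t* qdata, const float* scales, const uint8_t* zp,
                         size_t N, size_t K, size_t blk, MLAS_SQNBIT_COMPUTE_TYPE comp) {
      MlasNBitsGemmPackB(packed, qdata, nullptr, nullptr, N, K, K, blk, 4, /*is_asym*/ true,
                         /*last_call*/ false, comp, nullptr);  // input 1: quantized B
      MlasNBitsGemmPackB(packed, nullptr, scales, nullptr, N, K, K, blk, 4, true,
                         /*last_call*/ false, comp, nullptr);  // input 2: scales
      MlasNBitsGemmPackB(packed, nullptr, nullptr, zp, N, K, K, blk, 4, true,
                         /*last_call*/ true, comp, nullptr);   // input 3: zero points, finalize
    }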
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 26fca454c96f0..54eb43753931a 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3359,6 +3359,13 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
.Attr("N", "size of each output feature", AttributeProto::INT)
.Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT)
.Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT)
+ .Attr("accuracy_level",
+ "The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) "
+ "(default unset). It is used to control how input A is quantized or downcast internally while "
+ "doing computation, for example: 0 means input A will not be quantized or downcast while doing "
+ "computation. 4 means input A can be quantized with the same block_size to int8 internally from "
+ "type T1.",
+ AttributeProto::INT, static_cast(0))
.Input(0, "A", "The input tensor, not quantized", "T1")
.Input(1, "B", "1-dimensional data blob", "T2")
.Input(2, "scales", "quantization scale", "T1")
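For reference, the CPU kernel in this change interprets the attribute by casting it directly onto the MLAS compute-type enum added below in mlas_qnbit.h. A one-line sketch of the mapping (the helper name is illustrative; the kernel inlines the cast):

    // 0 -> CompUndef (A stays fp32), 1 -> CompFp32, 2 -> CompFp16, 3 -> CompBf16, 4 -> CompInt8.
    MLAS_SQNBIT_COMPUTE_TYPE ToComputeType(int64_t accuracy_level) {
      return static_cast<MLAS_SQNBIT_COMPUTE_TYPE>(accuracy_level);
    }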
diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h
index 9620dd42d1da9..1e83dd1cec400 100644
--- a/onnxruntime/core/mlas/inc/mlas_qnbit.h
+++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -77,3 +77,144 @@ MlasIsSQNBitGemmAvailable(
size_t BlkBitWidth,
size_t BlkLen
);
+
+/**
+ * @brief Define compute types of block quantization
+ */
+typedef enum {
+ CompUndef = 0, /*!< undef */
+ CompFp32 = 1, /*!< input fp32, accumulator fp32 */
+ CompFp16 = 2, /*!< input fp16, accumulator fp16 */
+ CompBf16 = 3, /*!< input bf16, accumulator fp32 */
+ CompInt8 = 4 /*!< input int8, accumulator int32 */
+} MLAS_SQNBIT_COMPUTE_TYPE;
+
+/**
+ * @brief Data parameters for NBits GEMM routine
+ * C = A * B
+ * A, C must be a float32 matrix
+ * B must be a packed nbits blob
+ * All except C are [in] parameters
+ */
+struct MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS {
+ const float* A = nullptr; /**< address of A (float32 matrix)*/
+ const void* B = nullptr; /**< address of B (packed nbits blob)*/
+ float* C = nullptr; /**< address of result matrix */
+ size_t lda = 0; /**< leading dimension of A */
+ size_t ldc = 0; /**< leading dimension of C*/
+};
+
+/**
+ * @brief Compute the byte size of the parameter combination
+ *
+ * @param N the number of columns of matrix B.
+ * @param K the number of rows of matrix B.
+ * @param block_size size of the block to quantize, elements from the same block share the same
+ * scale and zero point
+ * @param nbits number of bits used for weight quantization
+ * @param is_asym flag for asymmetric quantization
+ * @param comp_type specify input data type and accumulator data type
+ * @return size of the packing buffer, 0 if the operation is not yet supported.
+ */
+size_t MLASCALL
+MlasNBitsGemmPackBSize(
+ size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE comp_type
+);
+
+/**
+ * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers.
+ *
+ * @param PackedBuf packed data buffer
+ * @param QData quantized data buffer
+ * @param Scale scale pointer
+ * @param Zp zero point pointer
+ * @param N the number of columns of matrix B.
+ * @param K the number of rows of matrix B.
+ * @param ldb leading dimension of B
+ * @param block_size size of the block to quantize, elements from the same block share the same
+ * scale and zero point
+ * @param nbits number of bits used for weight quantization (default 4)
+ * @param is_asym flag for asymmetric quantization
+ * @param comp_type specify input data type and accumulator data type
+ * @param last_call flag that triggers the epilogue of packB. OpKernel::PrePack queries the input tensors
+ * one by one: QData, Scale, then Zp (if is_asym is true). The kernel packs them into a single blob so they can
+ * share common attributes such as block_size, and it also performs pre-computations that speed up inference but
+ * require the complete blob. Set this flag to true on the final call: when passing Scale if is_asym is false, or
+ * when passing Zp if is_asym is true.
+ * @param thread_pool
+ */
+void MLASCALL
+MlasNBitsGemmPackB(
+ void* PackedBuf,
+ const uint8_t* QData,
+ const float* Scale,
+ const uint8_t* Zp,
+ size_t N,
+ size_t K,
+ size_t ldb,
+ size_t block_size,
+ int nbits,
+ bool is_asym,
+ bool last_call,
+ MLAS_SQNBIT_COMPUTE_TYPE comp_type,
+ MLAS_THREADPOOL* thread_pool
+);
+
+/**
+ * @brief Unpack and dequantize to fp32
+ *
+ * @param FpData unpacked float32 data
+ * @param PackedBuf quantized and packed data
+ * @param N the number of columns of matrix B.
+ * @param K the number of rows of matrix B.
+ * @param ldb leading dimension of B
+ * @param thread_pool
+ */
+void MLASCALL
+MlasNBitsGemmUnPackB(
+ float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool
+);
+
+/**
+ * @brief Get the workspace size required by computation.
+ *
+ * @param[in] M row size of matrix A and C
+ * @param[in] N column size of matrix B and C
+ * @param[in] K column size of matrix A and row size of matrix B
+ * @param[in] BatchN number of batches
+ * @param[inout] DataParams An array (size BatchN) of parameter blocks
+ * @return Workspace size in bytes
+ */
+size_t MLASCALL
+MlasSQNBitsGemmBatchWorkspaceSize(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
+);
+
+/**
+ * @brief Batched GEMM: C = A * B
+ * A, C must be a float32 matrix
+ * B must be a packed nbits blob
+ *
+ * @param[in] M row size of matrix A and C
+ * @param[in] N column size of matrix B and C
+ * @param[in] K column size of matrix A and row size of matrix B
+ * @param[in] BatchN number of batches
+ * @param[inout] DataParams An array (size BatchN) of parameter blocks
+ * @param[in] WorkSpace temporary buffer
+ * @param[in] ThreadPool
+ * @return
+ */
+void MLASCALL
+MlasSQNBitsGemmBatchPackedB(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
+ void* WorkSpace,
+ MLAS_THREADPOOL* ThreadPool = nullptr
+);
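Taken together, the declarations above form a pack-then-compute API. A hedged end-to-end sketch for a single symmetric 4-bit weight follows; std::vector buffers are used purely for illustration, whereas inside ORT the allocations go through IAllocator:

    #include <vector>

    bool RunQ4Gemm(const float* A, const uint8_t* qweight, const float* scales,
                   size_t M, size_t N, size_t K, size_t blk, float* C) {
      const size_t packed_size = MlasNBitsGemmPackBSize(N, K, blk, /*nbits*/ 4, /*is_asym*/ false, CompInt8);
      if (packed_size == 0) return false;  // unsupported shape/ISA or jblas not built in
      std::vector<int8_t> packed(packed_size);
      MlasNBitsGemmPackB(packed.data(), qweight, scales, nullptr, N, K, K, blk, 4,
                         /*is_asym*/ false, /*last_call*/ true, CompInt8, nullptr);

      MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS param{};
      param.A = A;
      param.lda = K;
      param.B = packed.data();
      param.C = C;
      param.ldc = N;
      std::vector<int8_t> ws(MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, /*BatchN*/ 1, &param));
      MlasSQNBitsGemmBatchPackedB(M, N, K, 1, &param, ws.data(), /*ThreadPool*/ nullptr);
      return true;
    }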
diff --git a/onnxruntime/core/mlas/lib/jblas_defs.h b/onnxruntime/core/mlas/lib/jblas_defs.h
new file mode 100644
index 0000000000000..9cd1711a3ffd2
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/jblas_defs.h
@@ -0,0 +1,73 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+--*/
+
+#pragma once
+
+#include "jblas/jit_blas_prologue_b.h"
+#include "jblas/jit_blas_wrapper.h"
+
+namespace jblas
+{
+
+/*
+Naming convention explanation:
+Fp32: comp type, determined by GemmCore, can be any jblas::gemm::SCorexxx(float GemmCore)
+S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(also support other integer and float weight
+classes)
+F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and
+jblas::epilogue::gemm::AccumulatorWriteBackFp32.
+
+Tips: jblas::epilogue::gemm::CompFp32BlockEpilogue is a fixed class for all fp32 accumulator GemmCores.
+*/
+template <class GemmCore_T>
+using tLauncher_Fp32_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock<
+ GemmCore_T::ISA,
+ GemmCore_T,
+ jblas::prologue_a::gemm::ActivationKBlockBaseF32,
+ jblas::prologue_b::gemm::WeightKBlockS4,
+ jblas::epilogue::gemm::CompFp32BlockEpilogue,
+ jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
+
+/*
+Naming convention explanation:
+Int8: comp type, determined by GemmCore, can be any jblas::gemm::ICorexxx(integer GemmCore)
+S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(support integer weight classes only)
+F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and
+jblas::epilogue::gemm::AccumulatorWriteBackFp32.
+
+Tips: jblas::epilogue::gemm::CompInt8BlockEpilogue is a fixed class for all int32 accumulator GemmCores.
+*/
+template <class GemmCore_T>
+using tLauncher_Int8_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock<
+ GemmCore_T::ISA,
+ GemmCore_T,
+ jblas::prologue_a::gemm::ActivationF32KBlockQuantize,
+ jblas::prologue_b::gemm::WeightKBlockS4,
+ jblas::epilogue::gemm::CompInt8BlockEpilogue,
+ jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
+
+using tAVX512F = jblas::gemm::SCoreRowNAvx512f<48, 8>;
+using tAMX_BF16 = jblas::gemm::HCoreRowNAmxbf16<64, 16>;
+using tAVX512_FP16 = jblas::gemm::HCoreRowNAvx512fp16<96, 8>;
+using tAVX_VNNI = jblas::gemm::ICoreRowNAvxvnni<48, 2>; // TODO(Yu) use 24x4 for higher efficiency
+using tAVX512_VNNI = jblas::gemm::ICoreRowNAvx512vnni<48, 8>;
+using tAMX_INT8_US = jblas::gemm::ICoreRowNAmxint8<64, 16>;
+using tAMX_INT8_SS = jblas::gemm::ICoreRowNAmxint8SS<64, 16>;
+using tAVX2 = jblas::gemm::SCoreRowNAvx2<48, 2>; // TODO(Yu) use 24x4 for higher efficiency
+
+class ORTThreading : public jblas::parallel::IThreading
+{
+ public:
+ ORTThreading(void* tp);
+ void parallel_for(const jblas::parallel::thread_func& func) override;
+ void set_threads(int nthreads) override { assert(0); }
+ void sync() override { assert(0); }
+ void* mTp;
+};
+
+} // namespace jblas
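To make the naming scheme concrete, these are the kinds of aliases jblas_gemm.cpp instantiates from the templates above (illustration only, the alias names are not part of the patch):

    // fp32-compute launcher on the AVX2 GemmCore, int8-compute launcher on the AVX512-VNNI GemmCore.
    using Fp32Avx2Launcher       = jblas::tLauncher_Fp32_S4_F32F32<jblas::tAVX2>;
    using Int8Avx512VnniLauncher = jblas::tLauncher_Int8_S4_F32F32<jblas::tAVX512_VNNI>;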
diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.cpp b/onnxruntime/core/mlas/lib/jblas_gemm.cpp
new file mode 100644
index 0000000000000..f3cae3186c28e
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/jblas_gemm.cpp
@@ -0,0 +1,534 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ jblas_gemm.cpp
+
+Abstract:
+
+ Currently only supports Q4 GEMM.
+--*/
+
+#include "jblas_gemm.h"
+
+#include "jblas_defs.h"
+#include "mlasi.h"
+
+using namespace jblas;
+
+jblas::ORTThreading::ORTThreading(void* tp)
+ : IThreading(MLAS_THREADPOOL::DegreeOfParallelism(reinterpret_cast<MLAS_THREADPOOL*>(tp))), mTp(tp)
+{
+}
+
+void
+jblas::ORTThreading::parallel_for(const jblas::parallel::thread_func& func)
+{
+ MlasTrySimpleParallel(reinterpret_cast<MLAS_THREADPOOL*>(mTp), mThreadNum, [&](ptrdiff_t tid) {
+ func(static_cast<int>(tid));
+ });
+}
+
+template <class GemmCore_T>
+static void
+JblasSQ4GemmCompF32(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const float* A,
+ const size_t lda,
+ jblas::storage::gemm::StorageWeightKBlockS4* B,
+ float* C,
+ const size_t ldc,
+ int8_t* WorkSpace,
+ jblas::parallel::IThreading* th
+)
+{
+ auto M_ = static_cast<int>(M);
+ auto N_ = static_cast<int>(N);
+ auto K_ = static_cast<int>(K);
+ auto lda_ = static_cast<int>(lda);
+ auto ldc_ = static_cast<int>(ldc);
+ if (M <= 16) {
+ using Parallel = jblas::parallel::gemm::SchedulerKBlock<GemmCore_T>;
+ using Launcher = tLauncher_Fp32_S4_F32F32<GemmCore_T>;
+ static Launcher kernel;
+ auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize);
+ if (B->mIsAsym) {
+ reduceA.assign(WorkSpace);
+ ORTThreading single(nullptr);
+ kernel.mProA.reduce({A, lda_}, &reduceA, M_, K_, &single);
+ }
+ typename Launcher::BEpiParam blkargs{
+ B->template SPtr<int8_t>(), B->mScaT, B->mCStep, B->template ZPtr<int8_t>(),
+ reduceA.template get<float>(), reduceA.lda};
+
+ typename Launcher::Param args{M_, N_, K_, B->mBlockSize, {A, lda_}, {B}, blkargs, {C, ldc_}};
+ jblas::parallel::GemmKBlockRun<Parallel>(kernel, args, th);
+ } else {
+ using Parallel = jblas::parallel::gemm::SchedulerBase<GemmCore_T>;
+ using Launcher = jblas::wrapper::gemm::LauncherBase<
+ GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase,
+ jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
+ static Launcher kernel;
+
+ typename Launcher::Param args{M_, N_, K_, {A, lda_}, {B}, {C, ldc_}};
+ jblas::parallel::GemmBaseRun<Parallel>(kernel, args, th);
+ }
+}
+
+template <class GemmCore_T>
+static void
+JblasSQ4GemmCompInt8(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const float* A,
+ const size_t lda,
+ jblas::storage::gemm::StorageWeightKBlockS4* B,
+ float* C,
+ const size_t ldc,
+ int8_t* WorkSpace,
+ jblas::parallel::IThreading* th
+)
+{
+ using Parallel = jblas::parallel::gemm::SchedulerKBlock<GemmCore_T>;
+ using Launcher = tLauncher_Int8_S4_F32F32<GemmCore_T>;
+ auto M_ = static_cast<int>(M);
+ auto N_ = static_cast<int>(N);
+ auto K_ = static_cast<int>(K);
+ auto lda_ = static_cast<int>(lda);
+ auto ldc_ = static_cast<int>(ldc);
+ static Launcher kernel;
+ auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->mIsAsym);
+ quanA.assign(WorkSpace);
+ if (M <= 16) {
+ ORTThreading single(nullptr);
+ kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single);
+ } else {
+ kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th);
+ }
+ typename Launcher::Param args{
+ M_,
+ N_,
+ K_,
+ B->mBlockSize,
+ {A, lda_, &quanA},
+ {B},
+ {B->template SPtr<int8_t>(), B->mScaT, B->mCStep, quanA.template SPtr<float>(), quanA.mCStep,
+ quanA.template ZPtr<uint8_t>(), B->template RPtr<float>(), B->mRedT, B->template ZPtr<int8_t>(),
+ quanA.template RPtr<float>(), B->mBlockSize},
+ {C, ldc_}};
+ jblas::parallel::GemmKBlockRun<Parallel>(kernel, args, th);
+}
+
+bool
+JblasSQ4GemmBatchDriver(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
+ int8_t* WorkSpace,
+ MLAS_THREADPOOL* ThreadPool
+)
+{
+ GetCPUDevice();
+ ORTThreading orth(ThreadPool);
+ bool processed = true;
+ for (size_t i = 0; i < BatchN; i++) {
+ auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B);
+ auto uptr = std::unique_ptr<jblas::storage::gemm::IWeightBase>(ptr);
+ if (ptr) {
+ if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) {
+ auto kptr = reinterpret_cast<jblas::storage::gemm::StorageWeightKBlockS4*>(ptr);
+ auto coretype = ptr->mCoreId;
+ auto NTile = jblas::gemm::CoreAttr::get_mask_val(
+ ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT
+ );
+ auto CType = jblas::gemm::CoreAttr::get_mask_val(
+ ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT
+ );
+ if (CType == uint32_t(gemm::CompType::COMP_FP32)) {
+ if (NTile == tAVX512F::NTILE && _cd->AVX512F()) {
+ JblasSQ4GemmCompF32<tAVX512F>(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+ WorkSpace, &orth
+ );
+ } else if (NTile == tAVX2::NTILE && _cd->AVX2()) {
+ JblasSQ4GemmCompF32<tAVX2>(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+ WorkSpace, &orth
+ );
+ }
+ }
+ if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) {
+ if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) {
+ JblasSQ4GemmCompInt8<tAMX_INT8_US>(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+ WorkSpace, &orth
+ );
+ } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) {
+ JblasSQ4GemmCompInt8<tAVX512_VNNI>(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+ WorkSpace, &orth
+ );
+ } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) {
+ JblasSQ4GemmCompInt8<tAVX_VNNI>(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+ WorkSpace, &orth
+ );
+ }
+ }
+ if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) {
+ if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) {
+ JblasSQ4GemmCompInt8<tAMX_INT8_SS>(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+ WorkSpace, &orth
+ );
+ }
+ }
+ }
+ } else {
+ processed = false;
+ break;
+ }
+ }
+ return processed;
+}
+
+template <class GemmCore_T>
+static size_t
+JblasSQ4GemmCompF32WorkspaceSize(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const float* A,
+ const size_t lda,
+ jblas::storage::gemm::StorageWeightKBlockS4* B,
+ float* C,
+ const size_t ldc
+)
+{
+ auto M_ = static_cast<int>(M);
+ auto K_ = static_cast<int>(K);
+ (void)(N);
+ (void)(lda);
+ (void)(ldc);
+ if (M <= 16) {
+ using Launcher = tLauncher_Fp32_S4_F32F32<GemmCore_T>;
+ static Launcher kernel;
+ if (B->mIsAsym) {
+ auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize);
+ return reduceA.mSize;
+ }
+ return 0;
+ } else {
+ using Launcher = jblas::wrapper::gemm::LauncherBase<
+ GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase,
+ jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
+ static Launcher kernel;
+ return 0;
+ }
+ return 0;
+}
+
+template <class GemmCore_T>
+static size_t
+JblasSQ4GemmCompInt8WorkspaceSize(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const float* A,
+ const size_t lda,
+ jblas::storage::gemm::StorageWeightKBlockS4* B,
+ float* C,
+ const size_t ldc
+)
+{
+ using Parallel = jblas::parallel::gemm::SchedulerKBlock<GemmCore_T>;
+ using Launcher = tLauncher_Int8_S4_F32F32<GemmCore_T>;
+ static Launcher kernel;
+ (void)(N);
+ (void)(lda);
+ (void)(ldc);
+ auto quanA = kernel.mProA.createStorage(
+ static_cast<int>(M), static_cast<int>(K), static_cast<int>(B->mBlockSize), B->mIsAsym
+ );
+ return quanA.mSize;
+}
+
+size_t
+JblasSQ4GemmBatchWorkspaceSize(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
+)
+{
+ GetCPUDevice();
+ size_t size = 0;
+ for (size_t i = 0; i < BatchN; i++) {
+ auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B);
+ auto uptr = std::unique_ptr<jblas::storage::gemm::IWeightBase>(ptr);
+ if (ptr) {
+ if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) {
+ auto kptr = reinterpret_cast<jblas::storage::gemm::StorageWeightKBlockS4*>(ptr);
+ auto coretype = ptr->mCoreId;
+ auto NTile = jblas::gemm::CoreAttr::get_mask_val(
+ ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT
+ );
+ auto CType = jblas::gemm::CoreAttr::get_mask_val(
+ ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT
+ );
+ if (CType == uint32_t(gemm::CompType::COMP_FP32)) {
+ if (NTile == tAVX512F::NTILE && _cd->AVX512F()) {
+ size = std::max(
+ JblasSQ4GemmCompF32WorkspaceSize<tAVX512F>(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc
+ ),
+ size
+ );
+ } else if (NTile == tAVX2::NTILE && _cd->AVX2()) {
+ size = std::max(
+ JblasSQ4GemmCompF32WorkspaceSize<tAVX2>(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc
+ ),
+ size
+ );
+ }
+ }
+ if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) {
+ if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) {
+ size = std::max(
+ JblasSQ4GemmCompInt8WorkspaceSize<tAMX_INT8_US>(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc
+ ),
+ size
+ );
+ } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) {
+ size = std::max(
+ JblasSQ4GemmCompInt8WorkspaceSize<tAVX512_VNNI>(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc
+ ),
+ size
+ );
+ } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) {
+ size = std::max(
+ JblasSQ4GemmCompInt8WorkspaceSize<tAVX_VNNI>(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc
+ ),
+ size
+ );
+ }
+ }
+ if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) {
+ if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) {
+ size = std::max(
+ JblasSQ4GemmCompInt8WorkspaceSize<tAMX_INT8_SS>(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc
+ ),
+ size
+ );
+ }
+ }
+ }
+ }
+ }
+ return size;
+}
+
+template <typename T>
+static size_t
+JblasQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym)
+{
+ static T launcher;
+ auto stor = launcher.mProB.createStorage(
+ static_cast<int>(N), static_cast<int>(K), static_cast<int>(block_size), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32,
+ JBLAS_DTYPE::BF16, isAsym
+ );
+ // TODO(Yu) support more scale dtype
+ return stor.mSize;
+}
+
+size_t
+JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType)
+{
+ GetCPUDevice();
+ if (K % BlkSize != 0) {
+ return 0;
+ }
+ // from low precision to high precision
+ switch (CompType) {
+ case CompInt8:
+ if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) {
+ return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAMX_INT8_SS>>(BlkSize, N, K, isAsym);
+ }
+ if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) {
+ return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAVX512_VNNI>>(BlkSize, N, K, isAsym);
+ }
+ if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) {
+ return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAVX_VNNI>>(BlkSize, N, K, isAsym);
+ }
+ case CompBf16:
+ case CompFp16:
+ case CompFp32:
+ case CompUndef:
+ if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
+ return JblasQ4BuSize<tLauncher_Fp32_S4_F32F32<tAVX512F>>(BlkSize, N, K, isAsym);
+ }
+ if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
+ return JblasQ4BuSize<tLauncher_Fp32_S4_F32F32<tAVX2>>(BlkSize, N, K, isAsym);
+ }
+ break;
+ default:
+ return 0;
+ }
+ return 0;
+}
+
+template <typename T>
+static void
+JblasQ4GemmPackBImpl(
+ void* PackedBuf,
+ size_t BlkSize,
+ const uint8_t* QData,
+ const float* Scale,
+ const uint8_t* Zp,
+ size_t N,
+ size_t K,
+ bool IsAsym,
+ bool lastCall,
+ size_t ldb,
+ MLAS_THREADPOOL* ThreadPool
+)
+{
+ static T JblasKernel;
+ auto N_ = static_cast<int>(N);
+ auto K_ = static_cast<int>(K);
+ auto stor = JblasKernel.mProB.createStorage(
+ N_, K_, static_cast<int>(BlkSize), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, IsAsym
+ );
+ stor.assign(reinterpret_cast<int8_t*>(PackedBuf));
+ ORTThreading orth(ThreadPool);
+ JblasKernel.mProB.packNbitsWeight(N_, K_, IsAsym, QData, static_cast<int>(ldb), Scale, Zp, &stor, &orth);
+ if (lastCall) {
+ JblasKernel.mProB.reduceWeight(&stor, &orth);
+ }
+}
+
+bool
+JblasQ4GemmPackB(
+ void* PackedBuf,
+ const uint8_t* QData,
+ const float* Scale,
+ const uint8_t* Zp,
+ size_t N,
+ size_t K,
+ size_t ldb,
+ size_t BlkSize,
+ bool isAsym,
+ bool lastCall,
+ MLAS_SQNBIT_COMPUTE_TYPE CompType,
+ MLAS_THREADPOOL* ThreadPool
+)
+{
+ GetCPUDevice();
+ // explicit statement fall through.
+ switch (CompType) {
+ case CompInt8:
+ if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) {
+ JblasQ4GemmPackBImpl<tLauncher_Int8_S4_F32F32<tAMX_INT8_SS>>(
+ PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool
+ );
+ return true;
+ }
+ if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) {
+ JblasQ4GemmPackBImpl<tLauncher_Int8_S4_F32F32<tAVX512_VNNI>>(
+ PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool
+ );
+ return true;
+ }
+ if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) {
+ JblasQ4GemmPackBImpl<tLauncher_Int8_S4_F32F32<tAVX_VNNI>>(
+ PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool
+ );
+ return true;
+ }
+ case CompBf16:
+ case CompFp16:
+ case CompFp32:
+ case CompUndef:
+ if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
+ JblasQ4GemmPackBImpl<tLauncher_Fp32_S4_F32F32<tAVX512F>>(
+ PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool
+ );
+ return true;
+ }
+ if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
+ JblasQ4GemmPackBImpl<tLauncher_Fp32_S4_F32F32<tAVX2>>(
+ PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool
+ );
+ return true;
+ }
+ default:
+ return false;
+ }
+ return false;
+}
+
+bool
+JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool)
+{
+ auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf);
+ auto uptr = std::unique_ptr<jblas::storage::gemm::IWeightBase>(ptr);
+ ORTThreading orth(ThreadPool);
+ auto N_ = static_cast<int>(N);
+ auto K_ = static_cast<int>(K);
+ auto ldb_ = static_cast<int>(ldb);
+ GetCPUDevice();
+ if (ptr) {
+ if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) {
+ auto NTile = jblas::gemm::CoreAttr::get_mask_val(
+ ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT
+ );
+ auto CType = jblas::gemm::CoreAttr::get_mask_val(
+ ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT
+ );
+ if (CType == uint32_t(jblas::gemm::CompType::COMP_FP32)) {
+ if (NTile == tAVX512F::NTILE && _cd->AVX512F()) {
+ static jblas::prologue_b::gemm::WeightKBlockS4<tAVX512F, tAVX512F::ISA> proB;
+ proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth);
+ } else if (NTile == tAVX2::NTILE && _cd->AVX2()) {
+ static jblas::prologue_b::gemm::WeightKBlockS4<tAVX2, tAVX2::ISA> proB;
+ proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth);
+ }
+ }
+ if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_US_INT32)) {
+ if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) {
+ static jblas::prologue_b::gemm::WeightKBlockS4<tAMX_INT8_US, tAMX_INT8_US::ISA> proB;
+ proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth);
+ } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) {
+ static jblas::prologue_b::gemm::WeightKBlockS4<tAVX512_VNNI, tAVX512_VNNI::ISA> proB;
+ proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth);
+ } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) {
+ static jblas::prologue_b::gemm::WeightKBlockS4<tAVX_VNNI, tAVX_VNNI::ISA> proB;
+ proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth);
+ }
+ }
+ if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_SS_INT32)) {
+ if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) {
+ static jblas::prologue_b::gemm::WeightKBlockS4<tAMX_INT8_SS, tAMX_INT8_SS::ISA> proB;
+ proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth);
+ }
+ }
+ }
+ return true;
+ }
+ return false;
+}
diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.h b/onnxruntime/core/mlas/lib/jblas_gemm.h
new file mode 100644
index 0000000000000..044dc5e849a0a
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/jblas_gemm.h
@@ -0,0 +1,61 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ jblas_gemm.h
+
+Abstract:
+
+ Currently only supports Q4 GEMM.
+--*/
+
+#pragma once
+
+#include "mlas_qnbit.h"
+
+size_t
+JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType);
+
+bool
+JblasQ4GemmPackB(
+ void* PackedBuf,
+ const uint8_t* QData,
+ const float* Scale,
+ const uint8_t* Zp,
+ size_t N,
+ size_t K,
+ size_t ldb,
+ size_t BlkSize,
+ bool isAsym,
+ bool lastCall,
+ MLAS_SQNBIT_COMPUTE_TYPE CompType,
+ MLAS_THREADPOOL* ThreadPool
+);
+
+bool
+JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb,
+                   MLAS_THREADPOOL* ThreadPool);
+
+bool
+JblasSQ4GemmBatchDriver(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
+ int8_t* WorkSpace,
+ MLAS_THREADPOOL* ThreadPool
+);
+
+size_t
+JblasSQ4GemmBatchWorkspaceSize(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
+);
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index 7bda1bb504173..7bb8b17031a84 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -50,7 +50,9 @@ Module Name:
#include
#endif
#if defined(__x86_64__) || defined(__i386__)
+#if !defined(signature_VORTEX_ebx) && !defined(signature_NEXGEN_ebx) && !defined(signature_AMD_ebx)  // workaround for Bug 96238 - [i386] cpuid.h header needs include guards
#include <cpuid.h>
+#endif
#if defined(__GNUC__) && __GNUC__ >= 12
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // GCC 12 warns about uninitialized variables in immintrin.h.
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
index f964b1affec31..7f1d1b084aec0 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
@@ -15,6 +15,9 @@ Module Name:
--*/
#include "sqnbitgemm.h"
+#ifdef MLAS_JBLAS
+#include "jblas_gemm.h"
+#endif
namespace
{
@@ -142,3 +145,127 @@ MlasIsSQNBitGemmAvailable(
return true;
}
+
+size_t MLASCALL
+MlasNBitsGemmPackBSize(
+ size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType
+)
+{
+#ifdef MLAS_JBLAS
+ if (nbits == 4) {
+ auto jsize = JblasQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType);
+ if (jsize) {
+ return jsize;
+ }
+ }
+#endif
+ (void)(N);
+ (void)(K);
+ (void)(BlkSize);
+ (void)(nbits);
+ (void)(isAsym);
+ (void)(CompType);
+ return 0;
+}
+
+void MLASCALL
+MlasNBitsGemmPackB(
+ void* PackedBuf,
+ const uint8_t* QData,
+ const float* Scale,
+ const uint8_t* Zp,
+ size_t N,
+ size_t K,
+ size_t ldb,
+ size_t BlkSize,
+ int nbits,
+ bool isAsym,
+ bool lastCall,
+ MLAS_SQNBIT_COMPUTE_TYPE CompType,
+ MLAS_THREADPOOL* ThreadPool
+)
+{
+#ifdef MLAS_JBLAS
+ if (nbits == 4) {
+ if (JblasQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) {
+ return;
+ }
+ }
+#endif
+ (void)(PackedBuf);
+ (void)(QData);
+ (void)(Scale);
+ (void)(Zp);
+ (void)(N);
+ (void)(K);
+ (void)(ldb);
+ (void)(BlkSize);
+ (void)(nbits);
+ (void)(isAsym);
+ (void)(lastCall);
+ (void)(CompType);
+ (void)(ThreadPool);
+}
+
+void MLASCALL
+MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool)
+{
+#ifdef MLAS_JBLAS
+ if (JblasQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) {
+ return;
+ }
+#endif
+ (void)(FpData);
+ (void)(PackedBuf);
+ (void)(N);
+ (void)(K);
+ (void)(ldb);
+ (void)(ThreadPool);
+}
+
+size_t MLASCALL
+MlasSQNBitsGemmBatchWorkspaceSize(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
+)
+{
+#ifdef MLAS_JBLAS
+ return JblasSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams);
+#endif
+ (void)(M);
+ (void)(N);
+ (void)(K);
+ (void)(BatchN);
+ (void)(DataParams);
+ return 0;
+}
+
+void MLASCALL
+MlasSQNBitsGemmBatchPackedB(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
+ void* WorkSpace,
+ MLAS_THREADPOOL* ThreadPool
+)
+{
+ GetMlasPlatform();
+#ifdef MLAS_JBLAS
+ if (JblasSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast<int8_t*>(WorkSpace), ThreadPool)) {
+ // PackedWeight is created by jblas
+ return;
+ }
+#endif
+ (void)(M);
+ (void)(N);
+ (void)(K);
+ (void)(BatchN);
+ (void)(DataParams);
+ (void)(WorkSpace);
+ (void)(ThreadPool);
+}
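All of these wrappers degrade gracefully when MLAS_JBLAS is not defined or when the block size and CPU features are unsupported: the size query returns 0 and the pack/compute entry points become no-ops. A small sketch of the contract the MatMulNBits kernel relies on (helper name is illustrative):

    // Returns false when jblas is compiled out, the block size is not a multiple of the
    // selected core's KTILE, or no suitable ISA is available; the caller then keeps the
    // original unpacked dequantize-and-GEMM path.
    bool CanUseJblasPackedPath(size_t N, size_t K, size_t blk, bool is_asym) {
      return MlasNBitsGemmPackBSize(N, K, blk, /*nbits*/ 4, is_asym, CompInt8) != 0;
    }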
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format
new file mode 100644
index 0000000000000..84b876706161d
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format
@@ -0,0 +1,7 @@
+Language: Cpp
+BasedOnStyle: Google
+DerivePointerAlignment: false
+ColumnLimit: 120
+SpaceBeforeParens: ControlStatements
+SpaceBeforeRangeBasedForLoopColon: true
+SortIncludes: false
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt
new file mode 100644
index 0000000000000..5d9c5edf45a96
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt
@@ -0,0 +1,33 @@
+cmake_minimum_required(VERSION 3.5)
+
+project(jblas LANGUAGES CXX VERSION 0.1.0)
+
+file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp)
+file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp)
+
+add_library(${PROJECT_NAME} INTERFACE)
+add_library(${PROJECT_NAME}::${PROJECT_NAME} ALIAS ${PROJECT_NAME})
+
+target_include_directories(
+ ${PROJECT_NAME} INTERFACE
+ "$"
+ "$"
+)
+
+if(WIN32)
+ target_compile_definitions(${PROJECT_NAME} INTERFACE _CRT_SECURE_NO_WARNINGS NOMINMAX)
+ target_compile_options(${PROJECT_NAME} INTERFACE /wd4068 /wd4849 /wd6262 /wd4702 /wd4100)
+ #4068 ignore unroll and GCC flags
+ #4849 ignore collapse
+ #6262 ignore stack too large
+ #4702 unreachable code(false warning on constexpr condition)
+ #4100 unreferenced formal parameter
+
+ target_link_options(${PROJECT_NAME} INTERFACE /STACK:3145728) #Stack requires up to L2 cache size
+endif(WIN32)
+
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17)
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h
new file mode 100644
index 0000000000000..143adb771760b
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h
@@ -0,0 +1,303 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include
+
+#include
+#include
+#include "xbyak/xbyak.h"
+#include "xbyak/xbyak_util.h"
+
+#define OFFSET(field) offsetof(params, field)
+
+namespace jblas {
+
+namespace xbyak {
+class JitBase : protected Xbyak::CodeGenerator {
+ protected:
+ JitBase(size_t size = 16 * 1024) : CodeGenerator(size) {}
+
+ void load32(const Xbyak::Reg64& reg, const Xbyak::Address& addr) {
+ xor_(reg, reg);
+ mov(reg.cvt32(), addr);
+ }
+
+ void vreg_push(const Xbyak::Reg64& baseaddr) {
+#ifdef _WIN32
+ for (int i = 0; i < 10; i++) {
+ movaps(xword[baseaddr + i * 16], Xbyak::Xmm(6 + i));
+ }
+#endif
+ }
+
+ void vreg_pop(const Xbyak::Reg64& baseaddr) {
+#ifdef _WIN32
+ for (int i = 0; i < 10; i++) {
+ movaps(Xbyak::Xmm(6 + i), xword[baseaddr + i * 16]);
+ }
+#endif
+ }
+
+ void padto_le(const Xbyak::Reg64& _src, int padding) {
+ // _src=_src/padding*padding
+ if (padding == 1) {
+ return;
+ }
+ for (int i = 1; i < 16; i++) {
+ if ((1 << i) == padding) {
+ shr(_src, i);
+ shl(_src, i);
+ return;
+ }
+ }
+ assert(0);
+ }
+
+ void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total,
+ const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) {
+ inLocalLabel();
+ lea(_tmp, _total);
+ sub(_tmp, _pos);
+ cmp(_tmp, N);
+ jb(".maskflag");
+ cmp(_tmp, 0);
+ jl(".zeroflag");
+ uint64_t allmask = (static_cast<uint64_t>(1) << N) - 1;
+ if (N == 64) {
+ allmask = static_cast<uint64_t>(-1);
+ }
+ mov(_tmp, allmask);
+ kmovq(_msk, _tmp);
+ jmp(".maskend");
+ L(".maskflag");
+ mov(_tmp1, 1);
+ shlx(_tmp1, _tmp1, _tmp);
+ sub(_tmp1, 1);
+ kmovq(_msk, _tmp1);
+ jmp(".maskend");
+ L(".zeroflag");
+ mov(_tmp1, 0);
+ kmovq(_msk, _tmp1);
+ L(".maskend");
+ outLocalLabel();
+ }
+ void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Reg64& _total,
+ const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) {
+ generate_Nbitsmask(_msk, _pos, ptr[_total], _tmp, _tmp1, N);
+ }
+};
+
+class JitAvx : protected JitBase {
+ protected:
+ static int constexpr VBits = 256;
+ static int constexpr VecBytes = VBits / 8;
+ static int constexpr RegCount = 16;
+ typedef Xbyak::Ymm vreg_t;
+};
+
+class JitAvx2 : protected JitAvx {
+ protected:
+ static int constexpr VBits = 256;
+ typedef Xbyak::Ymm vreg_t;
+ void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); }
+
+ void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) {
+ vpmovzxwd(dst, addr);
+ vpslld(dst, dst, 16);
+ }
+};
+
+class JitAvx512f : protected JitAvx2 {
+ protected:
+ static int constexpr VBits = 512;
+ static int constexpr VecBytes = VBits / 8;
+ static int constexpr RegCount = 32;
+ typedef Xbyak::Zmm vreg_t;
+
+ void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); }
+
+ void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) {
+ vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]);
+ vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]);
+ vshuff32x4(src_2regs[0], tmp_2reg[0], tmp_2reg[1], 0 | (1 << 2) | (0 << 4) | (1 << 6));
+ vshuff32x4(src_2regs[0], src_2regs[0], src_2regs[0], 0 | (2 << 2) | (1 << 4) | (3 << 6));
+ vshuff32x4(src_2regs[1], tmp_2reg[0], tmp_2reg[1], 2 | (3 << 2) | (2 << 4) | (3 << 6));
+ vshuff32x4(src_2regs[1], src_2regs[1], src_2regs[1], 0 | (2 << 2) | (1 << 4) | (3 << 6));
+ }
+
+ void transpose16x16_4B(Xbyak::Zmm* src, Xbyak::Zmm* tmp, const int N = 16) {
+ for (int i = 0; i < 8; ++i) {
+ vpunpckldq(tmp[2 * i + 0], src[2 * i], src[2 * i + 1]);
+ vpunpckhdq(tmp[2 * i + 1], src[2 * i], src[2 * i + 1]);
+ }
+
+ for (int i = 0; i < 4; ++i) {
+ vpunpcklqdq(src[4 * i + 0], tmp[4 * i + 0], tmp[4 * i + 2]);
+ vpunpckhqdq(src[4 * i + 1], tmp[4 * i + 0], tmp[4 * i + 2]);
+ vpunpcklqdq(src[4 * i + 2], tmp[4 * i + 1], tmp[4 * i + 3]);
+ vpunpckhqdq(src[4 * i + 3], tmp[4 * i + 1], tmp[4 * i + 3]);
+ }
+
+ for (int i = 0; i < 2; ++i) {
+ vshufi32x4(tmp[8 * i + 0], src[8 * i + 0], src[8 * i + 4], 0x88);
+ vshufi32x4(tmp[8 * i + 1], src[8 * i + 1], src[8 * i + 5], 0x88);
+ vshufi32x4(tmp[8 * i + 2], src[8 * i + 2], src[8 * i + 6], 0x88);
+ vshufi32x4(tmp[8 * i + 3], src[8 * i + 3], src[8 * i + 7], 0x88);
+ vshufi32x4(tmp[8 * i + 4], src[8 * i + 0], src[8 * i + 4], 0xdd);
+ vshufi32x4(tmp[8 * i + 5], src[8 * i + 1], src[8 * i + 5], 0xdd);
+ vshufi32x4(tmp[8 * i + 6], src[8 * i + 2], src[8 * i + 6], 0xdd);
+ vshufi32x4(tmp[8 * i + 7], src[8 * i + 3], src[8 * i + 7], 0xdd);
+ }
+
+ // last step and move out
+ for (int i = 0; i < N; ++i) {
+ vshufi32x4(src[i], tmp[i % 8], tmp[8 + i % 8], i < 8 ? 0x88 : 0xdd);
+ }
+ }
+
+ void interleave_4rows_6regs(Xbyak::Zmm* src_4regs, Xbyak::Zmm* tmp_regs, const Xbyak::Opmask* masks) {
+ vpunpcklbw(tmp_regs[0], src_4regs[0], src_4regs[1]);
+ vpunpckhbw(tmp_regs[1], src_4regs[0], src_4regs[1]);
+ vpunpcklbw(tmp_regs[2], src_4regs[2], src_4regs[3]);
+ vpunpckhbw(tmp_regs[3], src_4regs[2], src_4regs[3]);
+
+ vpunpcklwd(tmp_regs[4], tmp_regs[0], tmp_regs[2]);
+ vpunpckhwd(tmp_regs[5], tmp_regs[0], tmp_regs[2]);
+ vpunpcklwd(tmp_regs[0], tmp_regs[1], tmp_regs[3]);
+ vpunpckhwd(tmp_regs[2], tmp_regs[1], tmp_regs[3]);
+ vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (4 << 4) | 4);
+ vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (4 << 4) | 4);
+ vmovups(src_4regs[0], tmp_regs[1]);
+ vshuff32x4(src_4regs[0] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6));
+ vmovups(src_4regs[1], tmp_regs[3]);
+ vshuff32x4(src_4regs[1] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6));
+ vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (14 << 4) | 14);
+ vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (14 << 4) | 14);
+ vmovups(src_4regs[2], tmp_regs[1]);
+ vshuff32x4(src_4regs[2] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6));
+ vmovups(src_4regs[3], tmp_regs[3]);
+ vshuff32x4(src_4regs[3] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6));
+ }
+
+ void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) {
+ vpsrld(_fp32, _fp32, 16);
+ vpmovdw(_bf16, _fp32);
+ }
+
+ void loadbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Address& addr) {
+ vpmovzxwd(dst, addr);
+ vpslld(dst, dst, 16);
+ }
+
+ void broadcastbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Reg64& tmp, const Xbyak::Address& addr) {
+ mov(tmp.cvt16(), addr);
+ shl(tmp.cvt32(), 16);
+ vpbroadcastd(dst, tmp.cvt32());
+ }
+
+ void store_fp32_bf16(const Xbyak::Zmm& _fp32, const Xbyak::Address& _add) {
+ auto bf16 = Xbyak::Ymm(_fp32.getIdx());
+ cvt_fp32_bf16(bf16, _fp32);
+ vmovups(_add, bf16);
+ }
+};
+
+class JitAvx512_bf16 : protected JitAvx512f {};
+
+class JitAvx512_fp16 : protected JitAvx512f {};
+
+class JitAvx512vnni : protected JitAvx512f {
+ protected:
+ void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
+ vpdpbusds(x1, x2, op, Xbyak::EvexEncoding);
+ }
+};
+
+class JitAvxvnni : protected JitAvx2 {
+ protected:
+ void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
+ vpdpbusds(x1, x2, op, Xbyak::VexEncoding);
+ }
+};
+
+class JitAmxtile : protected JitAvx512f {
+ public:
+ struct alignas(64) tileconfig_t {
+ uint8_t palette_id;
+ uint8_t reserved[15];
+ uint16_t colb[16];
+ uint8_t rows[16];
+ };
+ static int constexpr TileCount = 8;
+
+ typedef long long (*configure_t)(void*);
+
+ static void generate_config(Xbyak::CodeGenerator* g) {
+ Xbyak::util::StackFrame st(g, 1, 0, 0);
+ auto& parambase = st.p[0];
+ g->ldtilecfg(g->ptr[parambase]);
+ }
+
+ static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum,
+ int CNum) {
+ // Filling tile configure structure. Could be done offline.
+ tc.palette_id = 1;
+ // Configure C tiles
+ int t = 0;
+ for (; t < CNum; ++t) {
+ tc.rows[t] = static_cast<uint8_t>(TILE_M);
+ tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
+ }
+ // Configure A tiles
+ for (; t < CNum + ANum; ++t) {
+ tc.rows[t] = static_cast<uint8_t>(TILE_M);
+ tc.colb[t] = static_cast<uint16_t>(TILE_K * elesize);
+ }
+ // Configure B tile. B effectively has 64 rows and 16 columns.
+ int kpack = 4 / elesize;
+ for (; t < CNum + ANum + BNum; ++t) {
+ tc.rows[t] = static_cast<uint8_t>(TILE_K / kpack);
+ tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
+ }
+ }
+};
+
+class JitAmxbf16 : protected JitAmxtile {
+ protected:
+ void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { vcvtneps2bf16(_bf16, _fp32); }
+};
+
+class JitAmxint8 : protected JitAmxtile {
+ protected:
+  template <typename T_SRC, typename T_WEI>
+ void _tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3);
+};
+template <>
+inline void JitAmxint8::_tdpb<int8_t, int8_t>(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) {
+ tdpbssd(x1, x2, x3);
+}
+template <>
+inline void JitAmxint8::_tdpb<int8_t, uint8_t>(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) {
+ tdpbsud(x1, x2, x3);
+}
+template <>
+inline void JitAmxint8::_tdpb<uint8_t, int8_t>(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) {
+ tdpbusd(x1, x2, x3);
+}
+template <>
+inline void JitAmxint8::_tdpb<uint8_t, uint8_t>(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) {
+ tdpbuud(x1, x2, x3);
+}
+} // namespace xbyak
+} // namespace jblas
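The AMX helpers above separate tile configuration (a plain data structure filled on the host) from the JIT stub that executes ldtilecfg. A hedged usage sketch with a typical int8 tiling (the concrete values are illustrative, not mandated by the patch):

    // 2 C tiles, 2 A tiles and 4 B tiles for a 16x16x64 int8 kernel; elesize = 1 byte.
    jblas::xbyak::JitAmxtile::tileconfig_t cfg{};
    jblas::xbyak::JitAmxtile::configure_tiles(cfg, /*TILE_M*/ 16, /*TILE_N*/ 16, /*TILE_K*/ 64,
                                              /*elesize*/ 1, /*ANum*/ 2, /*BNum*/ 4, /*CNum*/ 2);
    // cfg is then handed to the code emitted by generate_config(), which runs ldtilecfg on it.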
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h
new file mode 100644
index 0000000000000..8ecf3535c17f4
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h
@@ -0,0 +1,96 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include <stdint.h>
+enum JBLAS_CODE {
+ JblasSuccess = 0,
+ JblasInvalidParam = 1,
+ JblasInvalidISA = 2,
+ JblasRuntimeError = 4,
+ JblasNotSupport = 8,
+};
+enum JBLAS_ISA : uint32_t {
+ JblasNoSIMD = 0,
+ JblasAVX,
+ JblasAVX2,
+ JblasAVX_VNNI,
+ JblasAVX512F,
+ JblasAVX512_VNNI,
+ JblasAMX_BF16,
+ JblasAMX_INT8,
+ JblasAVX512_FP16,
+ JblasAVX512_BF16,
+};
+enum class JBLAS_DTYPE : uint32_t {
+ EleBitsMask = 0xff,
+ EleBitsUndef = 0,
+ EleBits4 = 4,
+ EleBits8 = 8,
+ EleBits16 = 16,
+ EleBits32 = 32,
+ EleBits64 = 64,
+ TypeMask = 0xff00,
+ TypeFloat = 0 << 8,
+ TypeInt = 1 << 8,
+ SubTypeMask = 0xff0000,
+ SubType0 = 0 << 16,
+ SubType1 = 1 << 16,
+ SubType2 = 2 << 16,
+ F64 = EleBits64 | TypeFloat,
+ F32 = EleBits32 | TypeFloat,
+ F16 = EleBits16 | TypeFloat,
+ BF16 = EleBits16 | TypeFloat | SubType1,
+ F8_E4M3 = EleBits8 | TypeFloat,
+ F8_E5M2 = EleBits8 | TypeFloat | SubType1,
+ F8_E3M4 = EleBits8 | TypeFloat | SubType2,
+ S8 = EleBits8 | TypeInt,
+ U8 = EleBits8 | TypeInt | SubType1,
+ S4_CLIP = EleBits4 | TypeInt,
+ S4_FULLRANGE = EleBits4 | TypeInt | SubType1,
+ F4_E2M1 = EleBits4 | TypeFloat,
+ F4_BNB = EleBits4 | TypeFloat | SubType1,
+ F4_NF4 = EleBits4 | TypeFloat | SubType2,
+ S32 = EleBits32 | TypeInt,
+ U32 = EleBits32 | TypeInt | SubType1,
+};
+
+enum JBLAS_LAYOUT { JblasRowMajor = 101, JblasColMajor = 102 };
+enum JBLAS_TRANSPOSE {
+ JblasNoTrans = 111,
+ JblasTrans = 112,
+ JblasConjTrans = 113,
+};
+enum JBLAS_ELTWISEOP {
+ GELU,
+ SWISH,
+ TANH,
+ EXP,
+ LOW_PRECISION_EXP,
+ RELU,
+ LINEAR,
+};
+
+enum class JBLAS_PROLOGUEB_IDS : uint32_t {
+ Undef = (uint32_t)-1,
+ Begin = 0,
+ NormalBegin = Begin,
+ WeightPack = NormalBegin,
+ NormalEnd,
+ KBlockBegin = NormalEnd,
+ WeightKBlockS8 = KBlockBegin,
+ WeightKBlockS4,
+ WeightKBlockF4,
+ KBlockEnd,
+ End,
+};
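JBLAS_DTYPE encodes element width, type class and sub-type in disjoint bit ranges, so properties can be recovered with the masks declared above. A small illustration (the helper name is not part of the header):

    // Extract the element width in bits from a JBLAS_DTYPE value.
    constexpr uint32_t EleBits(JBLAS_DTYPE t) {
      return static_cast<uint32_t>(t) & static_cast<uint32_t>(JBLAS_DTYPE::EleBitsMask);
    }
    static_assert(EleBits(JBLAS_DTYPE::S4_CLIP) == 4, "4-bit packed weight");
    static_assert(EleBits(JBLAS_DTYPE::BF16) == 16, "16-bit bf16 scale");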
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h
new file mode 100644
index 0000000000000..5cac1080bc610
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h
@@ -0,0 +1,277 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include "jit_blas.h"
+#include "xbyak/xbyak_util.h"
+
+namespace jblas {
+
+namespace device {
+
+struct X64_ISA {
+ int64_t MMX : 1; // 0
+ int64_t SSE : 1; // 1
+ int64_t SSE2 : 1; // 2
+ int64_t SSE3 : 1; // 3
+ int64_t SSSE3 : 1; // 4
+ int64_t SSE41 : 1; // 5
+ int64_t SSE42 : 1; // 6
+ int64_t AVX : 1; // 7
+ int64_t F16C : 1; // 8
+ int64_t FMA : 1; // 9
+ int64_t AVX2 : 1; // 10
+ int64_t AVX_VNNI : 1; // 11
+ int64_t AVX_VNNI_INT8 : 1; // 12
+ int64_t AVX_NE_CONVERT : 1; // 13
+ int64_t AVX_IFMA : 1; // 14
+ int64_t AVX512F : 1; // 15
+ int64_t AVX512BW : 1; // 16
+ int64_t AVX512CD : 1; // 17
+ int64_t AVX512DQ : 1; // 18
+ int64_t AVX512ER : 1; // 19
+ int64_t AVX512IFMA52 : 1; // 20
+ int64_t AVX512PF : 1; // 21
+ int64_t AVX512VL : 1; // 22
+ int64_t AVX512VPOPCNTDQ : 1; // 23
+ int64_t AVX512_4FMAPS : 1; // 24
+ int64_t AVX512_4VNNIW : 1; // 25
+ int64_t AVX512_BF16 : 1; // 26
+ int64_t AVX512_BITALG : 1; // 27
+ int64_t AVX512_VBMI : 1; // 28
+ int64_t AVX512_VBMI2 : 1; // 29
+ int64_t AVX512_VNNI : 1; // 30
+ int64_t AVX512_VP2INTERSECT : 1; // 31
+ int64_t AVX512_FP16 : 1; // 32
+ int64_t AMX_TILE : 1; // 33
+ int64_t AMX_BF16 : 1; // 34
+ int64_t AMX_INT8 : 1; // 35
+ int64_t AMX_FP16 : 1; // 36
+ int64_t AMX_COMPLEX : 1; // 37
+ int64_t reserved : (64 - 38);
+};
+
+class AVX2_Default {
+ public:
+ static constexpr bool MMX = 1;
+ static constexpr bool SSE = 1;
+ static constexpr bool SSE2 = 1;
+ static constexpr bool SSE3 = 1;
+ static constexpr bool SSSE3 = 1;
+ static constexpr bool SSE41 = 1;
+ static constexpr bool SSE42 = 1;
+ static constexpr bool AVX = 1;
+ static constexpr bool F16C = 1;
+ static constexpr bool FMA = 1;
+ static constexpr bool AVX2 = 1;
+ static constexpr bool AVX_VNNI = 0;
+ static constexpr bool AVX_VNNI_INT8 = 0;
+ static constexpr bool AVX_NE_CONVERT = 0;
+ static constexpr bool AVX_IFMA = 0;
+ static constexpr bool AVX512F = 0;
+ static constexpr bool AVX512BW = 0;
+ static constexpr bool AVX512CD = 0;
+ static constexpr bool AVX512DQ = 0;
+ static constexpr bool AVX512ER = 0;
+ static constexpr bool AVX512IFMA52 = 0;
+ static constexpr bool AVX512PF = 0;
+ static constexpr bool AVX512VL = 0;
+ static constexpr bool AVX512VPOPCNTDQ = 0;
+ static constexpr bool AVX512_4FMAPS = 0;
+ static constexpr bool AVX512_4VNNIW = 0;
+ static constexpr bool AVX512_BF16 = 0;
+ static constexpr bool AVX512_BITALG = 0;
+ static constexpr bool AVX512_VBMI = 0;
+ static constexpr bool AVX512_VBMI2 = 0;
+ static constexpr bool AVX512_VNNI = 0;
+ static constexpr bool AVX512_VP2INTERSECT = 0;
+ static constexpr bool AVX512_FP16 = 0;
+ static constexpr bool AMX_TILE = 0;
+ static constexpr bool AMX_BF16 = 0;
+ static constexpr bool AMX_INT8 = 0;
+ static constexpr bool AMX_FP16 = 0;
+ static constexpr bool AMX_COMPLEX = 0;
+};
+
+class AVX512_VNNI_Default {
+ public:
+ static constexpr bool MMX = 1;
+ static constexpr bool SSE = 1;
+ static constexpr bool SSE2 = 1;
+ static constexpr bool SSE3 = 1;
+ static constexpr bool SSSE3 = 1;
+ static constexpr bool SSE41 = 1;
+ static constexpr bool SSE42 = 1;
+ static constexpr bool AVX = 1;
+ static constexpr bool F16C = 1;
+ static constexpr bool FMA = 1;
+ static constexpr bool AVX2 = 1;
+ static constexpr bool AVX_VNNI = 0;
+ static constexpr bool AVX_VNNI_INT8 = 0;
+ static constexpr bool AVX_NE_CONVERT = 0;
+ static constexpr bool AVX_IFMA = 0;
+ static constexpr bool AVX512F = 1;
+ static constexpr bool AVX512BW = 1;
+ static constexpr bool AVX512CD = 1;
+ static constexpr bool AVX512DQ = 1;
+ static constexpr bool AVX512ER = 0;
+ static constexpr bool AVX512IFMA52 = 0;
+ static constexpr bool AVX512PF = 0;
+ static constexpr bool AVX512VL = 1;
+ static constexpr bool AVX512VPOPCNTDQ = 0;
+ static constexpr bool AVX512_4FMAPS = 0;
+ static constexpr bool AVX512_4VNNIW = 0;
+ static constexpr bool AVX512_BF16 = 0;
+ static constexpr bool AVX512_BITALG = 0;
+ static constexpr bool AVX512_VBMI = 0;
+ static constexpr bool AVX512_VBMI2 = 0;
+ static constexpr bool AVX512_VNNI = 1;
+ static constexpr bool AVX512_VP2INTERSECT = 0;
+ static constexpr bool AVX512_FP16 = 0;
+ static constexpr bool AMX_TILE = 0;
+ static constexpr bool AMX_BF16 = 0;
+ static constexpr bool AMX_INT8 = 0;
+ static constexpr bool AMX_FP16 = 0;
+ static constexpr bool AMX_COMPLEX = 0;
+};
+
+class SapphireRapids {
+ public:
+ static constexpr bool MMX = 1;
+ static constexpr bool SSE = 1;
+ static constexpr bool SSE2 = 1;
+ static constexpr bool SSE3 = 1;
+ static constexpr bool SSSE3 = 1;
+ static constexpr bool SSE41 = 1;
+ static constexpr bool SSE42 = 1;
+ static constexpr bool AVX = 1;
+ static constexpr bool F16C = 1;
+ static constexpr bool FMA = 1;
+ static constexpr bool AVX2 = 1;
+ static constexpr bool AVX_VNNI = 0;
+ static constexpr bool AVX_VNNI_INT8 = 0;
+ static constexpr bool AVX_NE_CONVERT = 0;
+ static constexpr bool AVX_IFMA = 0;
+ static constexpr bool AVX512F = 1;
+ static constexpr bool AVX512BW = 1;
+ static constexpr bool AVX512CD = 1;
+ static constexpr bool AVX512DQ = 1;
+ static constexpr bool AVX512ER = 0;
+ static constexpr bool AVX512IFMA52 = 0;
+ static constexpr bool AVX512PF = 0;
+ static constexpr bool AVX512VL = 1;
+ static constexpr bool AVX512VPOPCNTDQ = 0;
+ static constexpr bool AVX512_4FMAPS = 0;
+ static constexpr bool AVX512_4VNNIW = 0;
+ static constexpr bool AVX512_BF16 = 0;
+ static constexpr bool AVX512_BITALG = 0;
+ static constexpr bool AVX512_VBMI = 0;
+ static constexpr bool AVX512_VBMI2 = 0;
+ static constexpr bool AVX512_VNNI = 1;
+ static constexpr bool AVX512_VP2INTERSECT = 0;
+ static constexpr bool AVX512_FP16 = 0;
+ static constexpr bool AMX_TILE = 1;
+ static constexpr bool AMX_BF16 = 1;
+ static constexpr bool AMX_INT8 = 1;
+ static constexpr bool AMX_FP16 = 0;
+ static constexpr bool AMX_COMPLEX = 0;
+};
+
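+// Compile-time ISA capability flags: a JBLAS_ISA template value implies support for every lower ISA in the enum.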
+template <JBLAS_ISA ISA_T>
+class isa_base {
+ public:
+ static bool constexpr avx = ISA_T >= JblasAVX;
+ static bool constexpr avx2 = ISA_T >= JblasAVX2;
+ static bool constexpr avx512f = ISA_T >= JblasAVX512F;
+ static bool constexpr avx512_vnni = ISA_T >= JblasAVX512_VNNI;
+ static bool constexpr avx512_fp16 = ISA_T >= JblasAVX512_FP16;
+ static bool constexpr amx_bf16 = ISA_T >= JblasAMX_BF16;
+ static bool constexpr amx_int8 = ISA_T >= JblasAMX_INT8;
+};
+
+class CpuDevice {
+ public:
+ inline void setThreads(int _nth) {
+ if (_nth <= 0) {
+ numthreads = numcores;
+ } else {
+ numthreads = std::min(numcores, _nth);
+ }
+ }
+ inline int getThreads() { return numthreads; }
+ inline int getCores() { return numcores; }
+ inline uint32_t getL2CacheSize() { return L2Cache; }
+ inline uint32_t getL1CacheSize() { return L1Cache; }
+ inline bool AVX() { return mHasAVX; }
+ inline bool AVX2() { return mHasAVX2; }
+ inline bool AVX_VNNI() { return mHasAVX_VNNI; }
+ inline bool AVX512F() { return mHasAVX512F; }
+ inline bool AVX512_VNNI() { return mHasAVX512_VNNI; }
+ inline bool AMX_INT8() { return mHasAMX_INT8; }
+ inline bool AMX_BF16() { return mHasAMX_BF16; }
+ inline bool AVX512_BF16() { return mHasAVX512_BF16; }
+ inline bool AVX512_FP16() { return mHasAVX512_FP16; }
+#define ADD_FLAG(isa) mHas##isa = _cpu.has(_cpu.t##isa)
+ CpuDevice() {
+ static Xbyak::util::Cpu _cpu;
+ L1Cache = _cpu.getDataCacheSize(0);
+ L2Cache = _cpu.getDataCacheSize(1);
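+ // Xbyak reports data cache sizes by level index: 0 is the L1 data cache, 1 is the L2 cache.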
+ ADD_FLAG(AVX);
+ ADD_FLAG(AVX2);
+ ADD_FLAG(AVX512F);
+ ADD_FLAG(AVX512_VNNI);
+ ADD_FLAG(AVX_VNNI);
+ ADD_FLAG(AMX_BF16);
+ ADD_FLAG(AMX_INT8);
+ ADD_FLAG(AVX512_BF16);
+ ADD_FLAG(AVX512_FP16);
+ numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
+ numthreads = numcores;
+ }
+
+ static CpuDevice* getInstance() {
+ static CpuDevice instance;
+ return &instance;
+ }
+
+ void print() {
+ printf(
+ "AVX:%d AVX2:%d AVX512F:%d AVX_VNNI:%d AVX512_VNNI:%d AMX_INT8:%d AMX_BF16:%d AVX512_BF16:%d AVX512_FP16:%d\n",
+ mHasAVX, mHasAVX2, mHasAVX512F, mHasAVX_VNNI, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512_BF16,
+ mHasAVX512_FP16);
+ }
+#undef ADD_FLAG
+
+ protected:
+ uint32_t L2Cache, L1Cache;
+ bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512_BF16,
+ mHasAVX512_FP16;
+ int numcores;
+ int numthreads;
+};
+
+#define GetCPUDevice() auto _cd = jblas::device::CpuDevice::getInstance();
+
+class CpuBase {
+ public:
+ CpuBase() {
+ GetCPUDevice();
+ mL2Cache = _cd->getL2CacheSize();
+ mL1Cache = _cd->getL1CacheSize();
+ mNumThreads = _cd->getThreads();
+ }
+ size_t mL2Cache, mL1Cache;
+ int mNumThreads;
+};
+} // namespace device
+} // namespace jblas
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h
new file mode 100644
index 0000000000000..ceb7a545092d8
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h
@@ -0,0 +1,329 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include
+
+#include "jit_base.h"
+#include "jit_blas.h"
+#include "jit_blas_utils.h"
+#include "kernel_wrapper.h"
+
+namespace jblas {
+namespace epilogue {
+namespace gemm {
+
+template
+class AccumulatorWriteBack {
+ public:
+ using SType = _SRC_T;
+ using DType = _DST_T;
+ struct Param {
+ DType* C;
+ int ldc;
+ void* elt_const_v;
+ };
+
+ template <typename... Eltops>
+ JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+ const int N, const Param& _param, void* tmpcache, size_t cachesize, Eltops... ops) {
+ auto COffset = M_offset * _param.ldc + N_offset;
+ auto cptr = _param.C + COffset;
+ bool constexpr Valid = !std::is_same::value || std::is_same::value;
+ static_assert(Valid, "fp32 to bf16 conversion only.");
+ if constexpr (std::is_same::value) {
+ return kernel::wrapper::Memcpy2DFp32CvtBf16::template forward(
+ const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false);
+ } else if constexpr (std::is_same, std::tuple>::value) {
+ return kernel::wrapper::Memcpy2DFp16CvtFp32::template forward(
+ const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false);
+ } else if constexpr (sizeof(SType) == sizeof(DType)) {
+ return kernel::wrapper::Memcpy2D::template forward(cacheptr, cptr, M, N, cachestep,
+ _param.ldc, _param.elt_const_v, ops...);
+ } else {
+ assert(false);
+ }
+ }
+};
+
+template
+class CustomAccumulatorWriteBackWithEltop {
+ public:
+ struct Param {
+ _DST_T* C;
+ int ldc;
+ void* elt_const_v;
+ };
+ JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+ const int N, const Param& _param, void* tmpcache, size_t cachesize) {
+ auto COffset = M_offset * _param.ldc + N_offset;
+ auto cptr = _param.C + COffset;
+ if constexpr (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) {
+ return kernel::wrapper::Memcpy2D::template forward1(cacheptr, cptr, M, N, cachestep,
+ _param.ldc, _param.elt_const_v);
+ } else {
+ assert(false);
+ }
+ }
+};
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackFp32 = AccumulatorWriteBack;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackInt32 = AccumulatorWriteBack;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackBf16 = AccumulatorWriteBack;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackFp16 = AccumulatorWriteBack;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack;
+
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop;
+
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackWithSwishFp32 = CustomAccumulatorWriteBackWithEltop;
+
+template <JBLAS_ISA ISA_T>
+class AlphaBetaProcessFp32 {
+ public:
+ struct Param {
+ float *C, *D;
+ int ldc, ldd;
+ float alpha, beta;
+ };
+
+ JBLAS_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+ const int N, const Param& _param, void* tmpcache, size_t cachesize) {
+ auto DOffset = M_offset * _param.ldd + N_offset;
+ auto COffset = M_offset * _param.ldc + N_offset;
+ auto cptr = _param.C + COffset;
+ auto dptr = _param.D + DOffset;
+ return kernel::wrapper::AlphaBetaF32F32::template forward(_param.alpha, cacheptr, cachestep, _param.beta,
+ dptr, _param.ldd, cptr, _param.ldc, M, N);
+ }
+};
+
+template <JBLAS_ISA ISA_T>
+class CompFp32BlockEpilogue {
+ public:
+ struct Param {
+ void* scales;
+ JBLAS_DTYPE scaledtype;
+ int ldsb;
+ int8_t* zps = nullptr;
+ float* reduce = nullptr;
+ int ldra;
+ };
+ JBLAS_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset,
+ const int K_offset, const int M, const int N, const Param& _param, void* tmpcache,
+ size_t cachesize) {
+ auto ret = JblasNotSupport;
+ if (_param.scaledtype == JBLAS_DTYPE::F32) {
+ ret = kernel::wrapper::CompFp32BlockScale::template forward(
+ reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr,
+ cachestep, M, N);
+ assert(ret == JblasSuccess);
+ if (_param.zps != nullptr) {
+ ret = kernel::wrapper::RemoveZeroPointBias::forward_wei(
+ dstptr, cachestep, M, N, _param.zps + K_offset * _param.ldsb + N_offset,
+ reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, _param.ldra,
+ _param.reduce + M_offset * _param.ldra + K_offset);
+ }
+ assert(ret == JblasSuccess);
+ return ret;
+ } else if (_param.scaledtype == JBLAS_DTYPE::BF16) {
+ ret = kernel::wrapper::CompFp32BlockScale::template forward(
+ reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr,
+ cachestep, M, N);
+ assert(_param.zps == nullptr);
+ assert(ret == JblasSuccess);
+ return ret;
+ }
+ return JblasNotSupport;
+ }
+};
+
+template <JBLAS_ISA ISA_T>
+class DequantInt32ToFp32 {
+ public:
+ struct Param {
+ float* C;
+ int ldc;
+ int ldsa;
+ float* scalesA;
+ float* scalesB;
+ };
+ JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+ const int N, const Param& _param, void* tmpcache, size_t cachesize) {
+ auto COffset = M_offset * _param.ldc + N_offset;
+ auto cptr = _param.C + COffset;
+ return kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N,
+ _param.scalesA + M_offset * _param.ldsa, _param.ldsa,
+ _param.scalesB + N_offset);
+ }
+};
+
+template <JBLAS_ISA ISA_T>
+class CompInt8BlockEpilogue {
+ public:
+ struct Param {
+ void* scalesB;
+ JBLAS_DTYPE scaleBdtype;
+ int ldsb;
+ float* scalesA;
+ int ldsa;
+ // optional if A asym
+ uint8_t* zpA = nullptr;
+ void* reduceB = nullptr;
+ JBLAS_DTYPE reduceBdtype = JBLAS_DTYPE::F32;
+ // optional if B asym
+ int8_t* zpB = nullptr;
+ float* reduceA = nullptr;
+ int K = 1;
+ };
+ JBLAS_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset,
+ const int K_offset, const int M, const int N, const Param& _param, void* tmpcache,
+ size_t cachesize) {
+ JBLAS_CODE ret = JblasNotSupport;
+ float* scab = nullptr;
+ size_t ScaleBTmpSize = N * sizeof(float);
+ size_t ReduceBTmpSize = N * sizeof(float);
+ assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize));
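+ // tmpcache is used as scratch here: the first ScaleBTmpSize bytes hold B scales converted to fp32,
+ // the following ReduceBTmpSize bytes hold reduceB converted to fp32 (conversion happens only for bf16 inputs).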
+ if (_param.scaleBdtype == JBLAS_DTYPE::BF16) {
+ auto scache = reinterpret_cast(tmpcache);
+ ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward(
+ reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N,
+ false);
+ assert(ret == JblasSuccess);
+ scab = scache;
+ } else if (_param.scaleBdtype == JBLAS_DTYPE::F32) {
+ scab = reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb;
+ }
+ float* redb = nullptr;
+ if (_param.reduceB) {
+ if (_param.reduceBdtype == JBLAS_DTYPE::BF16) {
+ auto rcache = reinterpret_cast(reinterpret_cast(tmpcache) + ScaleBTmpSize);
+ ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward(
+ reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N,
+ false);
+ assert(ret == JblasSuccess);
+ redb = rcache;
+ } else if (_param.reduceBdtype == JBLAS_DTYPE::F32) {
+ redb = reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb;
+ }
+ }
+ ret = kernel::wrapper::DequanS32Fp32::template forward(
+ srcptr, cachestep, reinterpret_cast(const_cast(srcptr)), cachestep, M, N,
+ _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab);
+ assert(ret == JblasSuccess);
+ ret = kernel::wrapper::AccumulateFp32::template forward(reinterpret_cast(srcptr), cachestep,
+ dstptr, cachestep, M, N);
+ assert(ret == JblasSuccess);
+
+ if (_param.zpA == nullptr) {
+ if (_param.zpB == nullptr) {
+ return ret;
+ } else {
+ ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei(
+ dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, scab, _param.ldsa,
+ _param.reduceA + M_offset * _param.ldsa + K_offset);
+ }
+ } else {
+ if (_param.zpB == nullptr) {
+ ret = kernel::wrapper::RemoveZeroPointBias::template forward_act(
+ dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset,
+ _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, redb);
+ } else {
+ ret = kernel::wrapper::RemoveZeroPointBias::template forward_both(
+ dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset,
+ _param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, scab,
+ _param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa + K_offset, redb);
+ }
+ }
+ return ret;
+ }
+};
+
+template <JBLAS_ISA ISA_T>
+class ZpDequantInt32ToFp32 {
+ public:
+ struct Param {
+ // necessary
+ float* C;
+ int ldc;
+ int ldsa;
+ float* scalesA;
+ float* scalesB;
+ // optional if A asym
+ uint8_t* zpA = nullptr;
+ float* reduceB = nullptr;
+ // optional if B asym
+ int8_t* zpB = nullptr;
+ float* reduceA = nullptr;
+ int K = 1;
+ };
+ JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+ const int N, const Param& _param, void* tmpcache, size_t cachesize) {
+ auto COffset = M_offset * _param.ldc + N_offset;
+ auto cptr = _param.C + COffset;
+ auto ret = kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N,
+ _param.scalesA + M_offset * _param.ldsa,
+ _param.ldsa, _param.scalesB + N_offset);
+ if (ret != JblasSuccess) {
+ return ret;
+ }
+ if (_param.zpA == nullptr && _param.zpB == nullptr) {
+ return ret;
+ } else if (_param.zpA != nullptr && _param.zpB == nullptr) {
+ ret = kernel::wrapper::RemoveZeroPointBias::template forward_act(
+ cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa,
+ _param.ldsa, _param.reduceB + N_offset);
+ } else if (_param.zpA == nullptr && _param.zpB != nullptr) {
+ ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei(
+ cptr, _param.ldc, M, N, _param.zpB + N_offset, _param.scalesB + N_offset, _param.ldsa,
+ _param.reduceA + M_offset * _param.ldsa);
+ } else {
+ ret = kernel::wrapper::RemoveZeroPointBias::template forward_both(
+ cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.zpB + N_offset,
+ _param.scalesA + M_offset * _param.ldsa, _param.scalesB + N_offset, _param.ldsa, _param.K,
+ _param.reduceA + M_offset * _param.ldsa, _param.reduceB + N_offset);
+ }
+ return ret;
+ }
+};
+
+template <JBLAS_ISA ISA_T>
+class AlphaBetaProcessS32U8 {
+ public:
+ struct Param {
+ uint8_t* C;
+ int ldc;
+ float alpha;
+ float scaleAcc, scaleC;
+ int zpC;
+ };
+
+ JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+ const int N, const Param& _param, void* tmpcache, size_t cachesize) {
+ auto COffset = M_offset * _param.ldc + N_offset;
+ auto cptr = _param.C + COffset;
+ return kernel::wrapper::QuanOutS32U32::template forward(_param.alpha, cacheptr, cachestep, cptr, _param.ldc,
+ M, N, _param.scaleAcc, _param.scaleC, _param.zpC);
+ }
+};
+
+} // namespace gemm
+} // namespace epilogue
+} // namespace jblas
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h
new file mode 100644
index 0000000000000..364da9223940f
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h
@@ -0,0 +1,2699 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include
+
+#include "jit_blas_utils.h"
+#include "jit_base.h"
+
+namespace jblas {
+namespace gemm {
+enum class CompType : uint32_t {
+ COMP_FP32 = 0,
+ COMP_BF16_FP32 = 1,
+ COMP_FP16_FP16 = 2,
+ COMP_INT_START = 3,
+ COMP_INT8_US_INT32 = COMP_INT_START,
+ COMP_INT8_UU_INT32 = 4,
+ COMP_INT8_SS_INT32 = 5,
+ COMP_INT8_SU_INT32 = 6,
+ COMP_INT16_SS_INT32 = 7,
+ COMP_INT8_US_FP32 = 8,
+ COMP_INT8_UU_FP32 = 9,
+ COMP_INT8_SS_FP32 = 10,
+ COMP_INT8_SU_FP32 = 11,
+};
+
+class CoreAttr {
+ public:
+ // INT32=LSB|**8bits:NTile**||**8bits:PackRow**||**8bits:CompType**||**8bits:Reserve**|
+ static uint32_t constexpr NTILE_MASK = 0xff, NTILE_SHIFT = 0, PACKROW_MASK = 0xff00, PACKROW_SHIFT = 8,
+ COMP_MASK = 0xff0000, COMP_SHIFT = 16, ISA_MASK = 0xff000000, ISA_SHIFT = 24;
+
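+ // For example, make_core_id(24, 1, (uint32_t)CompType::COMP_FP32, (uint32_t)JBLAS_ISA::JblasAVX2) stores 24 in
+ // bits 0-7, 1 in bits 8-15, the compute type in bits 16-23 and the ISA in bits 24-31; parse_id() recovers all four.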
+ static inline uint32_t get_mask_val(uint32_t raw, uint32_t mask, uint32_t shift) { return (raw & mask) >> shift; }
+ static constexpr uint32_t make_core_id(uint32_t NTile, uint32_t PackRow, uint32_t CompType, uint32_t ISA) {
+ return (NTile << NTILE_SHIFT) | (PackRow << PACKROW_SHIFT) | (CompType << COMP_SHIFT) | (ISA << ISA_SHIFT);
+ }
+
+ static void parse_id(uint32_t id, uint32_t* vals) {
+ vals[0] = get_mask_val(id, NTILE_MASK, NTILE_SHIFT);
+ vals[1] = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT);
+ vals[2] = get_mask_val(id, COMP_MASK, COMP_SHIFT);
+ vals[3] = get_mask_val(id, ISA_MASK, ISA_SHIFT);
+ }
+
+ static const char* to_str(uint32_t id) {
+ static char tmp[128];
+ uint32_t vals[4];
+ parse_id(id, vals);
+ sprintf(tmp, "N%d_PACK%d_COMP%d_ISA%d", vals[0], vals[1], vals[2], vals[3]);
+ return tmp;
+ }
+
+ static inline size_t get_bsize(uint32_t id) {
+ auto packrow = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT);
+ return size_t(4 / packrow);
+ }
+};
+
+namespace code {
+
+template <int _NTILE, int _MTILE = 0>
+class Avx2N8P1 : protected jblas::xbyak::JitAvx2 {
+ public:
+ static int constexpr RegLen = 8, PackRow = 1;
+ static_assert(_NTILE % RegLen == 0);
+ static int constexpr NRegs = _NTILE / RegLen;
+ static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE;
+ static_assert(NRegs * MRegs <= RegCount - 1);
+ static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1;
+ static int constexpr KUNROLL = 2;
+ static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX2;
+ static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32;
+ typedef float AType;
+ typedef float BType;
+ typedef float CType;
+
+ struct params {
+ AType* matA;
+ int astride;
+ BType* matB;
+ int bstride;
+ CType* matC;
+ int cstride;
+ int k;
+ int n;
+ int init;
+ };
+ typedef long long (*func_t)(params*);
+
+ int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0;
+ int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0;
+ static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+ static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+ void generate_code(int _mtile) {
+ assign_regs();
+ reset();
+ generate_mtile(_mtile);
+ ready();
+ mKernel = getCode();
+ }
+ func_t mKernel = nullptr;
+
+ protected:
+ Xbyak::Reg64 parambase;
+ Xbyak::Reg64 reg_matAptr;
+ Xbyak::Reg64 reg_matBptr;
+ Xbyak::Reg64 reg_matCptr;
+ Xbyak::Reg64 reg_ksize;
+ Xbyak::Reg64 reg_nsize;
+ Xbyak::Reg64 reg_cstride;
+ Xbyak::Reg64 reg_astride;
+ Xbyak::Reg64 reg_iterk;
+ Xbyak::Reg64 reg_itern;
+ Xbyak::Reg64 reg_tmp;
+ Xbyak::Reg64 reg_tmp1;
+ Xbyak::Reg64 reg_tmp2;
+ Xbyak::Reg64 reg_ret = rax;
+ Xbyak::Opmask msk_wr = k1;
+
+ void assign_regs() {
+ CRegCount = MRegs * NRegs;
+ ARegCount = 1;
+ BRegCount = RegCount - ARegCount - CRegCount;
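+ // If the registers left after C and A cannot hold a full row of B tiles, set BRegCount to 0 so that
+ // generate_fma reads B directly from memory operands instead of preloading it into registers.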
+ if (BRegCount < NRegs) {
+ BRegCount = 0;
+ ARegCount = BRegCount + 1;
+ }
+ if (BRegCount > NRegs) {
+ BRegCount = NRegs;
+ }
+ CReg = 0;
+ BReg = CReg + CRegCount;
+ AReg = BReg + BRegCount;
+ TmpReg = AReg + ARegCount;
+ assert(TmpReg <= RegCount);
+ TmpRegCount = RegCount - TmpReg;
+ }
+
+ void generate_mtile(int _mtile) {
+ inLocalLabel(); // use local label for multiple instance
+ Xbyak::util::StackFrame st(this, 1, 10, 16 * 10);
+ parambase = st.p[0];
+ reg_matAptr = st.t[0];
+ reg_matBptr = st.t[1];
+ reg_matCptr = st.t[0];
+ reg_ksize = st.t[2];
+ reg_astride = st.t[3];
+ reg_cstride = st.t[3];
+ reg_iterk = st.t[4];
+ reg_tmp = st.t[5];
+ reg_tmp1 = st.t[6];
+ reg_tmp2 = st.t[7];
+ reg_nsize = st.t[8];
+ reg_itern = st.t[9];
+ reg_ret = rax;
+
+ vreg_push(rsp);
+
+ load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+ load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+ xor_(reg_itern, reg_itern);
+ L(".nloop");
+ init_regs(_mtile);
+ mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+ load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+ mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+ load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+ imul(reg_tmp, reg_itern);
+ lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+ xor_(reg_iterk, reg_iterk);
+ generate_kloop(_mtile);
+ write_back(_mtile);
+ add(reg_itern, NTILE);
+ cmp(reg_itern, reg_nsize);
+ jb(".nloop");
+ mov(reg_ret, 0);
+ vreg_pop(rsp);
+
+ outLocalLabel(); // end of local label
+ }
+
+ void generate_kloop(int _mtile) {
+ inLocalLabel();
+ mov(reg_tmp, reg_ksize);
+ padto_le(reg_tmp, KUNROLL * KTILE);
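+ // reg_tmp = k rounded down to a multiple of KUNROLL*KTILE; ".unkloop" below is unrolled by KUNROLL,
+ // and the ".kloop" tail consumes the remaining iterations one KTILE at a time.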
+ cmp(reg_tmp, 0);
+ jz(".kloop", T_NEAR);
+ L(".unkloop");
+ generate_fma(_mtile, KUNROLL);
+ add(reg_matAptr, KUNROLL * AKStepSize);
+ add(reg_matBptr, KUNROLL * BKStepSize);
+ add(reg_iterk, KUNROLL * KTILE);
+ cmp(reg_iterk, reg_tmp); // k iteration variable
+ jb(".unkloop");
+ cmp(reg_tmp, reg_ksize);
+ jge(".kend", T_NEAR);
+ L(".kloop");
+ generate_fma(_mtile, 1);
+ add(reg_matAptr, 1 * AKStepSize);
+ add(reg_matBptr, 1 * BKStepSize);
+ add(reg_iterk, 1 * KTILE);
+ cmp(reg_iterk, reg_ksize); // k iteration variable
+ jb(".kloop");
+ L(".kend");
+ outLocalLabel();
+ }
+
+ void generate_fma(int _mtile, int _ktile) {
+ for (int kk = 0; kk < _ktile; kk++) {
+ lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]);
+ if (BRegCount == NRegs) {
+ for (int i = 0; i < NRegs; i++) {
+ vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ for (int mm = 0; mm < _mtile; mm++) {
+ vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i));
+ }
+ }
+ } else if (BRegCount == 0) {
+ for (int mm = 0; mm < _mtile; mm += ARegCount) {
+ int mm_re = utils::remainsize(mm, _mtile, ARegCount);
+ for (int imm = 0; imm < mm_re; imm++) {
+ vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm),
+ ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ }
+ }
+ } else {
+ assert(0);
+ }
+ }
+ }
+
+ void init_regs(int _mtile) {
+ inLocalLabel();
+ load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+ cmp(reg_tmp, 0);
+ je(".read", T_NEAR);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j));
+ }
+ }
+ jmp(".end", T_NEAR);
+ L(".read");
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]);
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ L(".end");
+ outLocalLabel();
+ }
+
+ void write_back(int _mtile) {
+ inLocalLabel();
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j));
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ outLocalLabel();
+ }
+};
+
+template <int _NTILE, int _MTILE = 0>
+class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f {
+ public:
+ static int constexpr RegLen = 16, PackRow = 1;
+ static_assert(_NTILE % RegLen == 0);
+ static int constexpr NRegs = _NTILE / RegLen;
+ static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE;
+ static_assert(NRegs * MRegs <= RegCount - 1);
+ static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1;
+ static int constexpr KUNROLL = 2;
+ static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F;
+ static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32;
+ typedef float AType;
+ typedef float BType;
+ typedef float CType;
+
+ struct params {
+ AType* matA;
+ int astride;
+ BType* matB;
+ int bstride;
+ CType* matC;
+ int cstride;
+ int k;
+ int n;
+ int init;
+ };
+ typedef long long (*func_t)(params*);
+
+ int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0;
+ int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0;
+ static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+ static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+ void generate_code(int _mtile) {
+ assign_regs();
+ reset();
+ generate_mtile(_mtile);
+ ready();
+ mKernel = getCode();
+ }
+ func_t mKernel = nullptr;
+
+ protected:
+ Xbyak::Reg64 parambase;
+ Xbyak::Reg64 reg_matAptr;
+ Xbyak::Reg64 reg_matBptr;
+ Xbyak::Reg64 reg_matCptr;
+ Xbyak::Reg64 reg_ksize;
+ Xbyak::Reg64 reg_nsize;
+ Xbyak::Reg64 reg_cstride;
+ Xbyak::Reg64 reg_astride;
+ Xbyak::Reg64 reg_iterk;
+ Xbyak::Reg64 reg_itern;
+ Xbyak::Reg64 reg_tmp;
+ Xbyak::Reg64 reg_tmp1;
+ Xbyak::Reg64 reg_tmp2;
+ Xbyak::Reg64 reg_ret = rax;
+ Xbyak::Opmask msk_wr = k1;
+
+ void assign_regs() {
+ CRegCount = MRegs * NRegs;
+ ARegCount = 1;
+ BRegCount = RegCount - ARegCount - CRegCount;
+ if (BRegCount < NRegs) {
+ BRegCount = 0;
+ ARegCount = BRegCount + 1;
+ }
+ if (BRegCount > NRegs) {
+ BRegCount = NRegs;
+ }
+ CReg = 0;
+ BReg = CReg + CRegCount;
+ AReg = BReg + BRegCount;
+ TmpReg = AReg + ARegCount;
+ assert(TmpReg <= RegCount);
+ TmpRegCount = RegCount - TmpReg;
+ }
+
+ void generate_mtile(int _mtile) {
+ inLocalLabel(); // use local label for multiple instance
+ Xbyak::util::StackFrame st(this, 1, 10, 16 * 10);
+ parambase = st.p[0];
+ reg_matAptr = st.t[0];
+ reg_matBptr = st.t[1];
+ reg_matCptr = st.t[0];
+ reg_ksize = st.t[2];
+ reg_astride = st.t[3];
+ reg_cstride = st.t[3];
+ reg_iterk = st.t[4];
+ reg_tmp = st.t[5];
+ reg_tmp1 = st.t[6];
+ reg_tmp2 = st.t[7];
+ reg_nsize = st.t[8];
+ reg_itern = st.t[9];
+ reg_ret = rax;
+
+ vreg_push(rsp);
+
+ load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+ load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+ xor_(reg_itern, reg_itern);
+ L(".nloop");
+ init_regs(_mtile);
+ mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+ load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+ mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+ load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+ imul(reg_tmp, reg_itern);
+ lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+ xor_(reg_iterk, reg_iterk);
+ generate_kloop(_mtile);
+ write_back(_mtile);
+ add(reg_itern, NTILE);
+ cmp(reg_itern, reg_nsize);
+ jb(".nloop");
+ mov(reg_ret, 0);
+ vreg_pop(rsp);
+
+ outLocalLabel(); // end of local label
+ }
+
+ void generate_kloop(int _mtile) {
+ inLocalLabel();
+ mov(reg_tmp, reg_ksize);
+ padto_le(reg_tmp, KUNROLL * KTILE);
+ cmp(reg_tmp, 0);
+ jz(".kloop", T_NEAR);
+ L(".unkloop");
+ generate_fma(_mtile, KUNROLL);
+ add(reg_matAptr, KUNROLL * AKStepSize);
+ add(reg_matBptr, KUNROLL * BKStepSize);
+ add(reg_iterk, KUNROLL * KTILE);
+ cmp(reg_iterk, reg_tmp); // k iteration variable
+ jb(".unkloop");
+ cmp(reg_tmp, reg_ksize);
+ jge(".kend", T_NEAR);
+ L(".kloop");
+ generate_fma(_mtile, 1);
+ add(reg_matAptr, 1 * AKStepSize);
+ add(reg_matBptr, 1 * BKStepSize);
+ add(reg_iterk, 1 * KTILE);
+ cmp(reg_iterk, reg_ksize); // k iteration variable
+ jb(".kloop");
+ L(".kend");
+ outLocalLabel();
+ }
+
+ void generate_fma(int _mtile, int _ktile) {
+ for (int kk = 0; kk < _ktile; kk++) {
+ lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]);
+ if (BRegCount == NRegs) {
+ for (int i = 0; i < NRegs; i++) {
+ vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ for (int mm = 0; mm < _mtile; mm++) {
+ vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i));
+ }
+ }
+ } else if (BRegCount == 0) {
+ for (int mm = 0; mm < _mtile; mm += ARegCount) {
+ int mm_re = utils::remainsize(mm, _mtile, ARegCount);
+ for (int imm = 0; imm < mm_re; imm++) {
+ vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm),
+ ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ }
+ }
+ } else {
+ assert(0);
+ }
+ }
+ }
+
+ void init_regs(int _mtile) {
+ inLocalLabel();
+ load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+ cmp(reg_tmp, 0);
+ je(".read", T_NEAR);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j));
+ }
+ }
+ jmp(".end", T_NEAR);
+ L(".read");
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]);
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ L(".end");
+ outLocalLabel();
+ }
+
+ void write_back(int _mtile) {
+ inLocalLabel();
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j));
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ outLocalLabel();
+ }
+};
+
+template <int _NTILE, int _MTILE = 0>
+class Avx512fp16N32P1 : protected jblas::xbyak::JitAvx512_fp16 {
+ public:
+ static int constexpr RegLen = 32, PackRow = 1;
+ static_assert(_NTILE % RegLen == 0);
+ static int constexpr NRegs = _NTILE / RegLen;
+ static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE;
+ static_assert(NRegs * MRegs <= RegCount - 1);
+ static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1;
+ static int constexpr KUNROLL = 2;
+ static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_FP16;
+ static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP16_FP16;
+ typedef utils::fp16 AType;
+ typedef utils::fp16 BType;
+ typedef utils::fp16 CType;
+
+ struct params {
+ AType* matA;
+ int astride;
+ BType* matB;
+ int bstride;
+ CType* matC;
+ int cstride;
+ int k;
+ int n;
+ int init;
+ };
+ typedef long long (*func_t)(params*);
+
+ int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0;
+ int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0;
+ static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+ static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+ void generate_code(int _mtile) {
+ assign_regs();
+ reset();
+ generate_mtile(_mtile);
+ ready();
+ mKernel = getCode();
+ }
+ func_t mKernel = nullptr;
+
+ protected:
+ Xbyak::Reg64 parambase;
+ Xbyak::Reg64 reg_matAptr;
+ Xbyak::Reg64 reg_matBptr;
+ Xbyak::Reg64 reg_matCptr;
+ Xbyak::Reg64 reg_ksize;
+ Xbyak::Reg64 reg_nsize;
+ Xbyak::Reg64 reg_cstride;
+ Xbyak::Reg64 reg_astride;
+ Xbyak::Reg64 reg_iterk;
+ Xbyak::Reg64 reg_itern;
+ Xbyak::Reg64 reg_tmp;
+ Xbyak::Reg64 reg_tmp1;
+ Xbyak::Reg64 reg_tmp2;
+ Xbyak::Reg64 reg_ret = rax;
+ Xbyak::Opmask msk_wr = k1;
+
+ void assign_regs() {
+ CRegCount = MRegs * NRegs;
+ ARegCount = 1;
+ BRegCount = RegCount - ARegCount - CRegCount;
+ if (BRegCount < NRegs) {
+ BRegCount = 0;
+ ARegCount = BRegCount + 1;
+ }
+ if (BRegCount > NRegs) {
+ BRegCount = NRegs;
+ }
+ CReg = 0;
+ BReg = CReg + CRegCount;
+ AReg = BReg + BRegCount;
+ TmpReg = AReg + ARegCount;
+ assert(TmpReg <= RegCount);
+ TmpRegCount = RegCount - TmpReg;
+ }
+
+ void generate_mtile(int _mtile) {
+ inLocalLabel(); // use local label for multiple instance
+ Xbyak::util::StackFrame st(this, 1, 10, 16 * 10);
+ parambase = st.p[0];
+ reg_matAptr = st.t[0];
+ reg_matBptr = st.t[1];
+ reg_matCptr = st.t[0];
+ reg_ksize = st.t[2];
+ reg_astride = st.t[3];
+ reg_cstride = st.t[3];
+ reg_iterk = st.t[4];
+ reg_tmp = st.t[5];
+ reg_tmp1 = st.t[6];
+ reg_tmp2 = st.t[7];
+ reg_nsize = st.t[8];
+ reg_itern = st.t[9];
+ reg_ret = rax;
+
+ vreg_push(rsp);
+
+ load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+ load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+ xor_(reg_itern, reg_itern);
+ L(".nloop");
+ init_regs(_mtile);
+ mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+ load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+ mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+ load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+ imul(reg_tmp, reg_itern);
+ lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+ xor_(reg_iterk, reg_iterk);
+ generate_kloop(_mtile);
+ write_back(_mtile);
+ add(reg_itern, NTILE);
+ cmp(reg_itern, reg_nsize);
+ jb(".nloop");
+ mov(reg_ret, 0);
+ vreg_pop(rsp);
+
+ outLocalLabel(); // end of local label
+ }
+
+ void generate_kloop(int _mtile) {
+ inLocalLabel();
+ mov(reg_tmp, reg_ksize);
+ padto_le(reg_tmp, KUNROLL * KTILE);
+ cmp(reg_tmp, 0);
+ jz(".kloop", T_NEAR);
+ L(".unkloop");
+ generate_fma(_mtile, KUNROLL);
+ add(reg_matAptr, KUNROLL * AKStepSize);
+ add(reg_matBptr, KUNROLL * BKStepSize);
+ add(reg_iterk, KUNROLL * KTILE);
+ cmp(reg_iterk, reg_tmp); // k iteration variable
+ jb(".unkloop");
+ cmp(reg_tmp, reg_ksize);
+ jge(".kend", T_NEAR);
+ L(".kloop");
+ generate_fma(_mtile, 1);
+ add(reg_matAptr, 1 * AKStepSize);
+ add(reg_matBptr, 1 * BKStepSize);
+ add(reg_iterk, 1 * KTILE);
+ cmp(reg_iterk, reg_ksize); // k iteration variable
+ jb(".kloop");
+ L(".kend");
+ outLocalLabel();
+ }
+
+ void generate_fma(int _mtile, int _ktile) {
+ for (int kk = 0; kk < _ktile; kk++) {
+ lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]);
+ if (BRegCount == NRegs) {
+ for (int i = 0; i < NRegs; i++) {
+ vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ for (int mm = 0; mm < _mtile; mm++) {
+ vpbroadcastw(vreg_t(AReg), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i));
+ }
+ }
+ } else if (BRegCount == 0) {
+ for (int mm = 0; mm < _mtile; mm += ARegCount) {
+ int mm_re = utils::remainsize(mm, _mtile, ARegCount);
+ for (int imm = 0; imm < mm_re; imm++) {
+ vpbroadcastw(vreg_t(AReg + imm), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm),
+ ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ }
+ }
+ } else {
+ assert(0);
+ }
+ }
+ }
+
+ void init_regs(int _mtile) {
+ inLocalLabel();
+ load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+ cmp(reg_tmp, 0);
+ je(".read", T_NEAR);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j));
+ }
+ }
+ jmp(".end", T_NEAR);
+ L(".read");
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]);
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ L(".end");
+ outLocalLabel();
+ }
+
+ void write_back(int _mtile) {
+ inLocalLabel();
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j));
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ outLocalLabel();
+ }
+};
+
+template <int _NTILE, int _MTILE = 0>
+class Avx512bf16N16P2 : protected jblas::xbyak::JitAvx512_bf16 {
+ public:
+ static int constexpr RegLen = 16, PackRow = 2;
+ static_assert(_NTILE % RegLen == 0);
+ static int constexpr NRegs = _NTILE / RegLen;
+ static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE;
+ static_assert(NRegs * MRegs <= RegCount - 1);
+ static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 2;
+ static int constexpr KUNROLL = 2;
+ static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_BF16;
+ static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32;
+ typedef utils::bf16 AType;
+ typedef utils::bf16 BType;
+ typedef float CType;
+
+ struct params {
+ AType* matA;
+ int astride;
+ BType* matB;
+ int bstride;
+ CType* matC;
+ int cstride;
+ int k;
+ int n;
+ int init;
+ };
+ typedef long long (*func_t)(params*);
+
+ int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0;
+ int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0;
+ static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+ static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+ void generate_code(int _mtile) {
+ assign_regs();
+ reset();
+ generate_mtile(_mtile);
+ ready();
+ mKernel = getCode();
+ }
+ func_t mKernel = nullptr;
+
+ protected:
+ Xbyak::Reg64 parambase;
+ Xbyak::Reg64 reg_matAptr;
+ Xbyak::Reg64 reg_matBptr;
+ Xbyak::Reg64 reg_matCptr;
+ Xbyak::Reg64 reg_ksize;
+ Xbyak::Reg64 reg_nsize;
+ Xbyak::Reg64 reg_cstride;
+ Xbyak::Reg64 reg_astride;
+ Xbyak::Reg64 reg_iterk;
+ Xbyak::Reg64 reg_itern;
+ Xbyak::Reg64 reg_tmp;
+ Xbyak::Reg64 reg_tmp1;
+ Xbyak::Reg64 reg_tmp2;
+ Xbyak::Reg64 reg_ret = rax;
+ Xbyak::Opmask msk_wr = k1;
+
+ void assign_regs() {
+ CRegCount = MRegs * NRegs;
+ ARegCount = 1;
+ BRegCount = RegCount - ARegCount - CRegCount;
+ if (BRegCount < NRegs) {
+ BRegCount = 0;
+ ARegCount = BRegCount + 1;
+ }
+ if (BRegCount > NRegs) {
+ BRegCount = NRegs;
+ }
+ CReg = 0;
+ BReg = CReg + CRegCount;
+ AReg = BReg + BRegCount;
+ TmpReg = AReg + ARegCount;
+ assert(TmpReg <= RegCount);
+ TmpRegCount = RegCount - TmpReg;
+ }
+
+ void generate_mtile(int _mtile) {
+ inLocalLabel(); // use local label for multiple instance
+ Xbyak::util::StackFrame st(this, 1, 10, 16 * 10);
+ parambase = st.p[0];
+ reg_matAptr = st.t[0];
+ reg_matBptr = st.t[1];
+ reg_matCptr = st.t[0];
+ reg_ksize = st.t[2];
+ reg_astride = st.t[3];
+ reg_cstride = st.t[3];
+ reg_iterk = st.t[4];
+ reg_tmp = st.t[5];
+ reg_tmp1 = st.t[6];
+ reg_tmp2 = st.t[7];
+ reg_nsize = st.t[8];
+ reg_itern = st.t[9];
+ reg_ret = rax;
+
+ vreg_push(rsp);
+
+ load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+ load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+ xor_(reg_itern, reg_itern);
+ L(".nloop");
+ init_regs(_mtile);
+ mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+ load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+ mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+ load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+ imul(reg_tmp, reg_itern);
+ lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+ xor_(reg_iterk, reg_iterk);
+ generate_kloop(_mtile);
+ write_back(_mtile);
+ add(reg_itern, NTILE);
+ cmp(reg_itern, reg_nsize);
+ jb(".nloop");
+ mov(reg_ret, 0);
+ vreg_pop(rsp);
+
+ outLocalLabel(); // end of local label
+ }
+
+ void generate_kloop(int _mtile) {
+ inLocalLabel();
+ mov(reg_tmp, reg_ksize);
+ padto_le(reg_tmp, KUNROLL * KTILE);
+ cmp(reg_tmp, 0);
+ jz(".kloop", T_NEAR);
+ L(".unkloop");
+ generate_fma(_mtile, KUNROLL);
+ add(reg_matAptr, KUNROLL * AKStepSize);
+ add(reg_matBptr, KUNROLL * BKStepSize);
+ add(reg_iterk, KUNROLL * KTILE);
+ cmp(reg_iterk, reg_tmp); // k iteration variable
+ jb(".unkloop");
+ cmp(reg_tmp, reg_ksize);
+ jge(".kend", T_NEAR);
+ L(".kloop");
+ generate_fma(_mtile, 1);
+ add(reg_matAptr, 1 * AKStepSize);
+ add(reg_matBptr, 1 * BKStepSize);
+ add(reg_iterk, 1 * KTILE);
+ cmp(reg_iterk, reg_ksize); // k iteration variable
+ jb(".kloop");
+ L(".kend");
+ outLocalLabel();
+ }
+
+ void generate_fma(int _mtile, int _ktile) {
+ for (int kk = 0; kk < _ktile; kk++) {
+ lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]);
+ if (BRegCount == NRegs) {
+ for (int i = 0; i < NRegs; i++) {
+ vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ for (int mm = 0; mm < _mtile; mm++) {
+ vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i));
+ }
+ }
+ } else if (BRegCount == 0) {
+ for (int mm = 0; mm < _mtile; mm += ARegCount) {
+ int mm_re = utils::remainsize(mm, _mtile, ARegCount);
+ for (int imm = 0; imm < mm_re; imm++) {
+ vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm),
+ ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ }
+ }
+ } else {
+ assert(0);
+ }
+ }
+ }
+
+ void init_regs(int _mtile) {
+ inLocalLabel();
+ load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+ cmp(reg_tmp, 0);
+ je(".read", T_NEAR);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j));
+ }
+ }
+ jmp(".end", T_NEAR);
+ L(".read");
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]);
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ L(".end");
+ outLocalLabel();
+ }
+
+ void write_back(int _mtile) {
+ inLocalLabel();
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j));
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ outLocalLabel();
+ }
+};
+
+template <int _NTILE, int _MTILE = 0>
+class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni {
+ public:
+ static int constexpr RegLen = 16, PackRow = 4;
+ static_assert(_NTILE % RegLen == 0);
+ static int constexpr NRegs = _NTILE / RegLen;
+ static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE;
+ static_assert(NRegs * MRegs <= RegCount - 1);
+ static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4;
+ static int constexpr KUNROLL = 2;
+ static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI;
+ static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32;
+ typedef uint8_t AType;
+ typedef int8_t BType;
+ typedef int32_t CType;
+ struct params {
+ AType* matA;
+ int astride;
+ BType* matB;
+ int bstride;
+ CType* matC;
+ int cstride;
+ int k;
+ int n;
+ int init;
+ };
+ typedef long long (*func_t)(params*);
+
+ int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0;
+ int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0;
+ static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+ static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+ void generate_code(int _mtile) {
+ assign_regs();
+ reset();
+ generate_mtile(_mtile);
+ ready();
+ mKernel = getCode();
+ }
+ func_t mKernel = nullptr;
+
+ private:
+ Xbyak::Reg64 parambase;
+ Xbyak::Reg64 reg_matAptr;
+ Xbyak::Reg64 reg_matBptr;
+ Xbyak::Reg64 reg_matCptr;
+ Xbyak::Reg64 reg_ksize;
+ Xbyak::Reg64 reg_nsize;
+ Xbyak::Reg64 reg_cstride;
+ Xbyak::Reg64 reg_astride;
+ Xbyak::Reg64 reg_iterk;
+ Xbyak::Reg64 reg_itern;
+ Xbyak::Reg64 reg_tmp;
+ Xbyak::Reg64 reg_tmp1;
+ Xbyak::Reg64 reg_tmp2;
+ Xbyak::Reg64 reg_ret = rax;
+
+ protected:
+ void assign_regs() {
+ CRegCount = MRegs * NRegs;
+ ARegCount = 1;
+ BRegCount = RegCount - ARegCount - CRegCount;
+ if (BRegCount < NRegs) {
+ BRegCount = 0;
+ ARegCount = BRegCount + 1;
+ }
+ if (BRegCount > NRegs) {
+ BRegCount = NRegs;
+ }
+ CReg = 0;
+ BReg = CReg + CRegCount;
+ AReg = BReg + BRegCount;
+ TmpReg = AReg + ARegCount;
+ assert(TmpReg <= RegCount);
+ TmpRegCount = RegCount - TmpReg;
+ }
+
+ void generate_mtile(int _mtile) {
+ inLocalLabel();
+ Xbyak::util::StackFrame st(this, 1, 10, 16 * 10);
+ parambase = st.p[0];
+ reg_matAptr = st.t[0];
+ reg_matBptr = st.t[1];
+ reg_matCptr = st.t[0];
+ reg_ksize = st.t[2];
+ reg_astride = st.t[3];
+ reg_cstride = st.t[3];
+ reg_iterk = st.t[4];
+ reg_tmp = st.t[5];
+ reg_tmp1 = st.t[6];
+ reg_tmp2 = st.t[7];
+ reg_nsize = st.t[8];
+ reg_itern = st.t[9];
+ reg_ret = rax;
+
+ vreg_push(rsp);
+
+ load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+ load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+ xor_(reg_itern, reg_itern);
+ L(".nloop");
+ init_regs(_mtile);
+ mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+ load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+ mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+ load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+ imul(reg_tmp, reg_itern);
+ lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+ xor_(reg_iterk, reg_iterk);
+ generate_kloop(_mtile);
+ write_back(_mtile);
+ add(reg_itern, NTILE);
+ cmp(reg_itern, reg_nsize);
+ jb(".nloop");
+ mov(reg_ret, 0);
+ vreg_pop(rsp);
+
+ outLocalLabel(); // end of local label
+ }
+
+ void generate_kloop(int _mtile) {
+ inLocalLabel();
+ mov(reg_tmp, reg_ksize);
+ padto_le(reg_tmp, KUNROLL * KTILE);
+ cmp(reg_tmp, 0);
+ jz(".kloop", T_NEAR);
+ L(".unkloop");
+ generate_fma(_mtile, KUNROLL);
+ add(reg_matAptr, KUNROLL * AKStepSize);
+ add(reg_matBptr, KUNROLL * BKStepSize);
+ add(reg_iterk, KUNROLL * KTILE);
+ cmp(reg_iterk, reg_tmp); // k iteration variable
+ jb(".unkloop");
+ cmp(reg_tmp, reg_ksize);
+ jge(".kend", T_NEAR);
+ L(".kloop");
+ generate_fma(_mtile, 1);
+ add(reg_matAptr, 1 * AKStepSize);
+ add(reg_matBptr, 1 * BKStepSize);
+ add(reg_iterk, 1 * KTILE);
+ cmp(reg_iterk, reg_ksize); // k iteration variable
+ jb(".kloop");
+ L(".kend");
+ outLocalLabel();
+ }
+
+ void generate_fma(int _mtile, int _kunroll) {
+ for (int kk = 0; kk < _kunroll; kk++) {
+ lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]);
+ if (BRegCount == NRegs) {
+ for (int i = 0; i < NRegs; i++) {
+ vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ for (int mm = 0; mm < _mtile; mm++) {
+ vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i));
+ }
+ }
+ } else if (BRegCount == 0) {
+ for (int mm = 0; mm < _mtile; mm += ARegCount) {
+ int mm_re = utils::remainsize(mm, _mtile, ARegCount);
+ for (int imm = 0; imm < mm_re; imm++) {
+ vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm),
+ ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ }
+ }
+ } else {
+ assert(0);
+ }
+ }
+ }
+
+ void init_regs(int _mtile) {
+ inLocalLabel();
+ load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+ cmp(reg_tmp, 0);
+ je(".read", T_NEAR);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j));
+ }
+ }
+ jmp(".end", T_NEAR);
+ L(".read");
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]);
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ L(".end");
+ outLocalLabel();
+ }
+
+ void write_back(int _mtile) {
+ inLocalLabel();
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j));
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ outLocalLabel();
+ }
+};
+
+template <int _NTILE, int _MTILE = 0>
+class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni {
+ public:
+ static int constexpr RegLen = 8, PackRow = 4;
+ static_assert(_NTILE % RegLen == 0);
+ static int constexpr NRegs = _NTILE / RegLen;
+ static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE;
+ static_assert(NRegs * MRegs <= RegCount - 1);
+ static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4;
+ static int constexpr KUNROLL = 2;
+ static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX_VNNI;
+ static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32;
+ typedef uint8_t AType;
+ typedef int8_t BType;
+ typedef int32_t CType;
+ struct params {
+ AType* matA;
+ int astride;
+ BType* matB;
+ int bstride;
+ CType* matC;
+ int cstride;
+ int k;
+ int n;
+ int init;
+ };
+ typedef long long (*func_t)(params*);
+
+ int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0;
+ int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0;
+ static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+ static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+ void generate_code(int _mtile) {
+ assign_regs();
+ reset();
+ generate_mtile(_mtile);
+ ready();
+ mKernel = getCode();
+ }
+ func_t mKernel = nullptr;
+
+ private:
+ Xbyak::Reg64 parambase;
+ Xbyak::Reg64 reg_matAptr;
+ Xbyak::Reg64 reg_matBptr;
+ Xbyak::Reg64 reg_matCptr;
+ Xbyak::Reg64 reg_ksize;
+ Xbyak::Reg64 reg_nsize;
+ Xbyak::Reg64 reg_cstride;
+ Xbyak::Reg64 reg_astride;
+ Xbyak::Reg64 reg_iterk;
+ Xbyak::Reg64 reg_itern;
+ Xbyak::Reg64 reg_tmp;
+ Xbyak::Reg64 reg_tmp1;
+ Xbyak::Reg64 reg_tmp2;
+ Xbyak::Reg64 reg_ret = rax;
+ Xbyak::Opmask msk_wr = k1;
+
+ protected:
+ void assign_regs() {
+ CRegCount = MRegs * NRegs;
+ ARegCount = 1;
+ BRegCount = RegCount - ARegCount - CRegCount;
+ if (BRegCount < NRegs) {
+ BRegCount = 0;
+ ARegCount = BRegCount + 1;
+ }
+ if (BRegCount > NRegs) {
+ BRegCount = NRegs;
+ }
+ CReg = 0;
+ BReg = CReg + CRegCount;
+ AReg = BReg + BRegCount;
+ TmpReg = AReg + ARegCount;
+ assert(TmpReg <= RegCount);
+ TmpRegCount = RegCount - TmpReg;
+ }
+
+ void generate_mtile(int _mtile) {
+ inLocalLabel();
+ Xbyak::util::StackFrame st(this, 1, 10, 16 * 10);
+ parambase = st.p[0];
+ reg_matAptr = st.t[0];
+ reg_matBptr = st.t[1];
+ reg_matCptr = st.t[0];
+ reg_ksize = st.t[2];
+ reg_astride = st.t[3];
+ reg_cstride = st.t[3];
+ reg_iterk = st.t[4];
+ reg_tmp = st.t[5];
+ reg_tmp1 = st.t[6];
+ reg_tmp2 = st.t[7];
+ reg_nsize = st.t[8];
+ reg_itern = st.t[9];
+ reg_ret = rax;
+
+ vreg_push(rsp);
+
+ load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+ load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+ xor_(reg_itern, reg_itern);
+ L(".nloop");
+ init_regs(_mtile);
+ mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+ load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+ mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+ load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+ imul(reg_tmp, reg_itern);
+ lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+ xor_(reg_iterk, reg_iterk);
+ generate_kloop(_mtile);
+ write_back(_mtile);
+ add(reg_itern, NTILE);
+ cmp(reg_itern, reg_nsize);
+ jb(".nloop");
+ mov(reg_ret, 0);
+ vreg_pop(rsp);
+
+ outLocalLabel(); // end of local label
+ }
+
+ void generate_kloop(int _mtile) {
+ inLocalLabel();
+ mov(reg_tmp, reg_ksize);
+ padto_le(reg_tmp, KUNROLL * KTILE);
+ cmp(reg_tmp, 0);
+ jz(".kloop", T_NEAR);
+ L(".unkloop");
+ generate_fma(_mtile, KUNROLL);
+ add(reg_matAptr, KUNROLL * AKStepSize);
+ add(reg_matBptr, KUNROLL * BKStepSize);
+ add(reg_iterk, KUNROLL * KTILE);
+ cmp(reg_iterk, reg_tmp); // k iteration variable
+ jb(".unkloop");
+ cmp(reg_tmp, reg_ksize);
+ jge(".kend", T_NEAR);
+ L(".kloop");
+ generate_fma(_mtile, 1);
+ add(reg_matAptr, 1 * AKStepSize);
+ add(reg_matBptr, 1 * BKStepSize);
+ add(reg_iterk, 1 * KTILE);
+ cmp(reg_iterk, reg_ksize); // k iteration variable
+ jb(".kloop");
+ L(".kend");
+ outLocalLabel();
+ }
+
+ void generate_fma(int _mtile, int _kunroll) {
+ for (int kk = 0; kk < _kunroll; kk++) {
+ lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]);
+ if (BRegCount == NRegs) {
+ for (int i = 0; i < NRegs; i++) {
+ vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ for (int mm = 0; mm < _mtile; mm++) {
+ vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i));
+ }
+ }
+ } else if (BRegCount == 0) {
+ for (int mm = 0; mm < _mtile; mm += ARegCount) {
+ int mm_re = utils::remainsize(mm, _mtile, ARegCount);
+ for (int imm = 0; imm < mm_re; imm++) {
+ vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm),
+ ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ }
+ }
+ } else {
+ assert(0);
+ }
+ }
+ }
+
+ void init_regs(int _mtile) {
+ inLocalLabel();
+ load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+ cmp(reg_tmp, 0);
+ je(".read", T_NEAR);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j));
+ }
+ }
+ jmp(".end", T_NEAR);
+ L(".read");
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]);
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ L(".end");
+ outLocalLabel();
+ }
+
+ void write_back(int _mtile) {
+ inLocalLabel();
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j));
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ outLocalLabel();
+ }
+};
+
+template <int _NTILE, int _MTILE = 0>
+class Amxbf16N16P2 : protected jblas::xbyak::JitAmxbf16 {
+ public:
+ static int constexpr RegLen = 16, PackRow = 2;
+ static_assert(_NTILE % RegLen == 0);
+ static_assert(_MTILE % RegLen == 0);
+ static int constexpr NRegs = _NTILE / RegLen;
+ static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen;
+ static_assert(NRegs * MRegs + 2 <= TileCount);
+ static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 32;
+ static int constexpr KUNROLL = 2;
+ static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_BF16;
+ static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32;
+ typedef utils::bf16 AType;
+ typedef utils::bf16 BType;
+ typedef float CType;
+
+ struct params {
+ AType* matA;
+ int astride;
+ BType* matB;
+ int bstride;
+ CType* matC;
+ int cstride;
+ int k;
+ int n;
+ int init;
+ void* workspace;
+ };
+ typedef long long (*func_t)(params*);
+
+ int TmpRegCount = RegCount;
+ int TmpReg = 0;
+ int CTileCount = 0, ATileCount = 0, BTileCount = 0;
+ int CTile = 0, ATile = 0, BTile = 0;
+ static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+ static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+ void generate_code(int _mtile) {
+ assign_regs();
+ reset();
+ generate_mtile(_mtile);
+ ready();
+ mKernel = getCode();
+ }
+ func_t mKernel = nullptr;
+
+ protected:
+ Xbyak::Reg64 parambase;
+ Xbyak::Reg64 reg_matAptr;
+ Xbyak::Reg64 reg_matBptr;
+ Xbyak::Reg64 reg_matCptr;
+ Xbyak::Reg64 reg_ksize;
+ Xbyak::Reg64 reg_nsize;
+ Xbyak::Reg64 reg_cstride;
+ Xbyak::Reg64 reg_astride;
+ Xbyak::Reg64 reg_iterk;
+ Xbyak::Reg64 reg_itern;
+ Xbyak::Reg64 reg_tmp;
+ Xbyak::Reg64 reg_tmp1;
+ Xbyak::Reg64 reg_tmp2;
+ Xbyak::Reg64 reg_tmp3;
+ Xbyak::Reg64 reg_ret = rax;
+
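+ // AMX tile plan: NRegs*MRegs accumulator tiles (C) first, then prefer a full
+ // row of B tiles, else one A tile per M block, else a single shared A tile.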
+ void assign_regs() {
+ CTileCount = NRegs * MRegs;
+ auto tile_re = TileCount - CTileCount;
+ if (tile_re - 1 >= NRegs) {
+ BTileCount = NRegs;
+ ATileCount = tile_re - BTileCount;
+ } else if (tile_re - 1 >= MRegs) {
+ ATileCount = MRegs;
+ BTileCount = tile_re - ATileCount;
+ } else {
+ ATileCount = 1;
+ BTileCount = tile_re - ATileCount;
+ }
+ CTile = 0;
+ ATile = CTile + CTileCount;
+ BTile = ATile + ATileCount;
+ }
+
+ void generate_mtile(int _mtile) {
+ inLocalLabel(); // use local labels for multiple instances
+ Xbyak::util::StackFrame st(this, 1, 11, 16 * 10);
+ parambase = st.p[0];
+ reg_matAptr = st.t[0];
+ reg_matBptr = st.t[1];
+ reg_matCptr = st.t[0];
+ reg_ksize = st.t[2];
+ reg_astride = st.t[3];
+ reg_cstride = st.t[3];
+ reg_iterk = st.t[4];
+ reg_tmp = st.t[5];
+ reg_tmp1 = st.t[6];
+ reg_tmp2 = st.t[7];
+ reg_tmp3 = st.t[10];
+ reg_nsize = st.t[8];
+ reg_itern = st.t[9];
+ reg_ret = rax;
+
+ vreg_push(rsp);
+
+ load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+ load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+ xor_(reg_itern, reg_itern);
+ L(".nloop");
+ init_regs(_mtile);
+ mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+ load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+ mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+ load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+ imul(reg_tmp, reg_itern);
+ lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+ xor_(reg_iterk, reg_iterk);
+ generate_kloop(_mtile);
+ write_back(_mtile);
+ add(reg_itern, NTILE);
+ cmp(reg_itern, reg_nsize);
+ jb(".nloop");
+ mov(reg_ret, 0);
+ vreg_pop(rsp);
+
+ outLocalLabel(); // end of local label
+ }
+
+ void generate_kloop(int _mtile) {
+ inLocalLabel();
+ mov(reg_tmp, reg_ksize);
+ padto_le(reg_tmp, KUNROLL * KTILE);
+ cmp(reg_tmp, 0);
+ jz(".kloop", T_NEAR);
+ L(".unkloop");
+ generate_fma(_mtile, KUNROLL);
+ add(reg_matAptr, KUNROLL * AKStepSize);
+ add(reg_matBptr, KUNROLL * BKStepSize);
+ add(reg_iterk, KUNROLL * KTILE);
+ cmp(reg_iterk, reg_tmp); // k iteration variable
+ jb(".unkloop");
+ cmp(reg_tmp, reg_ksize);
+ jge(".kend", T_NEAR);
+ L(".kloop");
+ generate_fma(_mtile, 1);
+ add(reg_matAptr, 1 * AKStepSize);
+ add(reg_matBptr, 1 * BKStepSize);
+ add(reg_iterk, 1 * KTILE);
+ cmp(reg_iterk, reg_ksize); // k iteration variable
+ jb(".kloop");
+ L(".kend");
+ outLocalLabel();
+ }
+
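+ // AMX inner loop: each tdpbf16ps accumulates a 16x16 fp32 tile from bf16 A/B
+ // tiles; B tiles are preloaded when they all fit, otherwise A or B tiles are
+ // reloaded from memory inside the loop.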
+ void generate_fma(int _mtile, int _kunroll) {
+ auto& reg_Bstride = reg_tmp1;
+ mov(reg_Bstride, NTILE * 4);
+ int mtiles = _mtile / RegLen;
+
+ for (int kk = 0; kk < _kunroll; kk++) {
+ auto& reg_Atmp = reg_tmp2;
+ if (mtiles == 1) {
+ reg_Atmp = reg_matAptr;
+ } else {
+ mov(reg_Atmp, reg_matAptr);
+ }
+ if (BTileCount == NRegs) {
+ for (int i = 0; i < NRegs; i++) {
+ tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
+ }
+ for (int mm = 0; mm < mtiles; mm++) {
+ tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
+ for (int i = 0; i < NRegs; i++) {
+ tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i));
+ }
+ if (mm != mtiles - 1) {
+ lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+ lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+ }
+ }
+ } else {
+ if (ATileCount == mtiles) {
+ for (int mm = 0; mm < mtiles; mm++) {
+ tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
+ if (mm != mtiles - 1) {
+ lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+ lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+ }
+ }
+ for (int i = 0; i < NRegs; i++) {
+ tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
+ for (int mm = 0; mm < mtiles; mm++) {
+ tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile));
+ }
+ }
+ } else {
+ for (int mm = 0; mm < mtiles; mm++) {
+ tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
+ for (int i = 0; i < NRegs; i++) {
+ tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
+ tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile));
+ }
+ if (mm != mtiles - 1) {
+ lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+ lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ void init_regs(int _mtile) {
+ inLocalLabel();
+ load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+ cmp(reg_tmp, 0);
+ je(".read", T_NEAR);
+ for (int i = 0; i < CTileCount; i++) {
+ tilezero(Xbyak::Tmm(CTile + i));
+ }
+ jmp(".end", T_NEAR);
+ L(".read");
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ int mtnum = _mtile / 16;
+ for (int mm = 0; mm < mtnum; mm++) {
+ for (int i = 0; i < NRegs; i++) {
+ tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]);
+ }
+ if (mm != mtnum - 1) {
+ lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]);
+ lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]);
+ }
+ }
+ L(".end");
+ outLocalLabel();
+ }
+
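+ // Write-back: tilestored writes whole tiles, so spill the accumulators to the
+ // contiguous workspace first, then copy exactly _mtile rows to matC with
+ // vector loads/stores.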
+ void write_back(int _mtile) {
+ inLocalLabel();
+ mov(reg_tmp, dword[parambase + OFFSET(workspace)]);
+ mov(reg_tmp1, NTILE * 4);
+ for (int mm = 0; mm < MRegs; mm++) {
+ for (int i = 0; i < NRegs; i++) {
+ tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i));
+ }
+ }
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ int zunroll = TmpRegCount / NRegs;
+ for (int i = 0; i < _mtile; i += zunroll) {
+ int m_re = utils::remainsize(i, _mtile, zunroll);
+ for (int im = 0; im < m_re; im++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]);
+ vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j));
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ }
+ outLocalLabel();
+ }
+};
+
+template <typename AT, typename BT, int _NTILE, int _MTILE = 0>
+class Amxint8N16P4 : protected jblas::xbyak::JitAmxint8 {
+ public:
+ static int constexpr RegLen = 16, PackRow = 4;
+ static_assert(_NTILE % RegLen == 0);
+ static_assert(_MTILE % RegLen == 0);
+ static int constexpr NRegs = _NTILE / RegLen;
+ static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen;
+ static_assert(NRegs * MRegs + 2 <= TileCount);
+ static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 64;
+ static int constexpr KUNROLL = 2;
+ static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_INT8;
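+ // Select the INT8 compute type from the signedness of AType/BType
+ // (SS/SU/US/UU = signed/unsigned A followed by B).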
+ static uint32_t constexpr COMPUTE =
+ (uint32_t)(std::is_same_v<AT, int8_t>
+ ? std::is_same_v<BT, int8_t> ? CompType::COMP_INT8_SS_INT32 : CompType::COMP_INT8_SU_INT32
+ : std::is_same_v<BT, int8_t> ? CompType::COMP_INT8_US_INT32
+ : CompType::COMP_INT8_UU_INT32);
+ using AType = AT;
+ using BType = BT;
+ typedef int32_t CType;
+
+ struct params {
+ AType* matA;
+ int astride;
+ BType* matB;
+ int bstride;
+ CType* matC;
+ int cstride;
+ int k;
+ int n;
+ int init;
+ void* workspace;
+ };
+ typedef long long (*func_t)(params*);
+
+ int TmpRegCount = RegCount;
+ int TmpReg = 0;
+ int CTileCount = 0, ATileCount = 0, BTileCount = 0;
+ int CTile = 0, ATile = 0, BTile = 0;
+ static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+ static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+ void generate_code(int _mtile) {
+ assign_regs();
+ reset();
+ generate_mtile(_mtile);
+ ready();
+ mKernel = getCode();
+ }
+ func_t mKernel = nullptr;
+
+ protected:
+ Xbyak::Reg64 parambase;
+ Xbyak::Reg64 reg_matAptr;
+ Xbyak::Reg64 reg_matBptr;
+ Xbyak::Reg64 reg_matCptr;
+ Xbyak::Reg64 reg_ksize;
+ Xbyak::Reg64 reg_nsize;
+ Xbyak::Reg64 reg_cstride;
+ Xbyak::Reg64 reg_astride;
+ Xbyak::Reg64 reg_iterk;
+ Xbyak::Reg64 reg_itern;
+ Xbyak::Reg64 reg_tmp;
+ Xbyak::Reg64 reg_tmp1;
+ Xbyak::Reg64 reg_tmp2;
+ Xbyak::Reg64 reg_tmp3;
+ Xbyak::Reg64 reg_ret = rax;
+
+ void assign_regs() {
+ CTileCount = NRegs * MRegs;
+ auto tile_re = TileCount - CTileCount;
+ if (tile_re - 1 >= NRegs) {
+ BTileCount = NRegs;
+ ATileCount = tile_re - BTileCount;
+ } else if (tile_re - 1 >= MRegs) {
+ ATileCount = MRegs;
+ BTileCount = tile_re - ATileCount;
+ } else {
+ ATileCount = 1;
+ BTileCount = tile_re - ATileCount;
+ }
+ CTile = 0;
+ ATile = CTile + CTileCount;
+ BTile = ATile + ATileCount;
+ }
+
+ void generate_mtile(int _mtile) {
+ inLocalLabel(); // use local labels for multiple instances
+ Xbyak::util::StackFrame st(this, 1, 11, 16 * 10);
+ parambase = st.p[0];
+ reg_matAptr = st.t[0];
+ reg_matBptr = st.t[1];
+ reg_matCptr = st.t[0];
+ reg_ksize = st.t[2];
+ reg_astride = st.t[3];
+ reg_cstride = st.t[3];
+ reg_iterk = st.t[4];
+ reg_tmp = st.t[5];
+ reg_tmp1 = st.t[6];
+ reg_tmp2 = st.t[7];
+ reg_tmp3 = st.t[10];
+ reg_nsize = st.t[8];
+ reg_itern = st.t[9];
+ reg_ret = rax;
+
+ vreg_push(rsp);
+
+ load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+ load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+ xor_(reg_itern, reg_itern);
+ L(".nloop");
+ init_regs(_mtile);
+ mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+ load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+ mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+ load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+ imul(reg_tmp, reg_itern);
+ lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+ xor_(reg_iterk, reg_iterk);
+ generate_kloop(_mtile);
+ write_back(_mtile);
+ add(reg_itern, NTILE);
+ cmp(reg_itern, reg_nsize);
+ jb(".nloop");
+ mov(reg_ret, 0);
+ vreg_pop(rsp);
+
+ outLocalLabel(); // end of local label
+ }
+
+ void generate_kloop(int _mtile) {
+ inLocalLabel();
+ mov(reg_tmp, reg_ksize);
+ padto_le(reg_tmp, KUNROLL * KTILE);
+ cmp(reg_tmp, 0);
+ jz(".kloop", T_NEAR);
+ L(".unkloop");
+ generate_fma(_mtile, KUNROLL);
+ add(reg_matAptr, KUNROLL * AKStepSize);
+ add(reg_matBptr, KUNROLL * BKStepSize);
+ add(reg_iterk, KUNROLL * KTILE);
+ cmp(reg_iterk, reg_tmp); // k iteration variable
+ jb(".unkloop");
+ cmp(reg_tmp, reg_ksize);
+ jge(".kend", T_NEAR);
+ L(".kloop");
+ generate_fma(_mtile, 1);
+ add(reg_matAptr, 1 * AKStepSize);
+ add(reg_matBptr, 1 * BKStepSize);
+ add(reg_iterk, 1 * KTILE);
+ cmp(reg_iterk, reg_ksize); // k iteration variable
+ jb(".kloop");
+ L(".kend");
+ outLocalLabel();
+ }
+
+ void generate_fma(int _mtile, int _kunroll) {
+ auto& reg_Bstride = reg_tmp1;
+ mov(reg_Bstride, NTILE * 4);
+ int mtiles = _mtile / RegLen;
+
+ for (int kk = 0; kk < _kunroll; kk++) {
+ auto& reg_Atmp = reg_tmp2;
+ if (mtiles == 1) {
+ reg_Atmp = reg_matAptr;
+ } else {
+ mov(reg_Atmp, reg_matAptr);
+ }
+ if (BTileCount == NRegs) {
+ for (int i = 0; i < NRegs; i++) {
+ tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
+ }
+ for (int mm = 0; mm < mtiles; mm++) {
+ tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
+ for (int i = 0; i < NRegs; i++) {
+ _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i));
+ }
+ if (mm != mtiles - 1) {
+ lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+ lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+ }
+ }
+ } else {
+ if (ATileCount == mtiles) {
+ for (int mm = 0; mm < mtiles; mm++) {
+ tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
+ if (mm != mtiles - 1) {
+ lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+ lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+ }
+ }
+ for (int i = 0; i < NRegs; i++) {
+ tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
+ for (int mm = 0; mm < mtiles; mm++) {
+ _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile));
+ }
+ }
+ } else {
+ for (int mm = 0; mm < mtiles; mm++) {
+ tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
+ for (int i = 0; i < NRegs; i++) {
+ tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
+ _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile));
+ }
+ if (mm != mtiles - 1) {
+ lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+ lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ void init_regs(int _mtile) {
+ inLocalLabel();
+ load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+ cmp(reg_tmp, 0);
+ je(".read", T_NEAR);
+ for (int i = 0; i < CTileCount; i++) {
+ tilezero(Xbyak::Tmm(CTile + i));
+ }
+ jmp(".end", T_NEAR);
+ L(".read");
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ int mtnum = _mtile / 16;
+ for (int mm = 0; mm < mtnum; mm++) {
+ for (int i = 0; i < NRegs; i++) {
+ tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]);
+ }
+ if (mm != mtnum - 1) {
+ lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]);
+ lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]);
+ }
+ }
+ L(".end");
+ outLocalLabel();
+ }
+
+ void write_back(int _mtile) {
+ inLocalLabel();
+ mov(reg_tmp, dword[parambase + OFFSET(workspace)]);
+ mov(reg_tmp1, NTILE * 4);
+ for (int mm = 0; mm < MRegs; mm++) {
+ for (int i = 0; i < NRegs; i++) {
+ tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i));
+ }
+ }
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ int zunroll = TmpRegCount / NRegs;
+ for (int i = 0; i < _mtile; i += zunroll) {
+ int m_re = utils::remainsize(i, _mtile, zunroll);
+ for (int im = 0; im < m_re; im++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]);
+ vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j));
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ }
+ outLocalLabel();
+ }
+};
+template <int N, int M>
+using Amxint8N16P4US = Amxint8N16P4<uint8_t, int8_t, N, M>;
+
+template <int N, int M>
+using Amxint8N16P4SS = Amxint8N16P4<int8_t, int8_t, N, M>;
+
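+// Emits and runs a small kernel that loads the AMX tile configuration
+// (per-thread state) with the requested tile shapes before any tile
+// instructions are executed; presumably via ldtilecfg inside generate_config.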
+class AmxConfigure : protected jblas::xbyak::JitAmxtile {
+ public:
+ typedef long long (*func_t)(tileconfig_t*);
+
+ static void configure(int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, int CNum) {
+ static AmxConfigure code;
+ tileconfig_t cfg;
+ std::memset(&cfg, 0, sizeof(cfg));
+ configure_tiles(cfg, TILE_M, TILE_N, TILE_K, elesize, ANum, BNum, CNum);
+ code.mKernel(&cfg);
+ }
+
+ protected:
+ AmxConfigure() {
+ generate_config(this);
+ mKernel = getCode();
+ }
+
+ func_t mKernel = nullptr;
+};
+
+namespace kblock {
+// Optimized for kblock GEMM: each block along the k dimension has its own dequantization step.
+// All accumulators use the fp32 dtype.
+template <int _NTILE, int _MTILE = 0>
+class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f {
+ public:
+ static int constexpr RegLen = 16, PackRow = 1;
+ static_assert(_NTILE % RegLen == 0);
+ static int constexpr NRegs = _NTILE / RegLen;
+ static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE;
+ static_assert(NRegs * MRegs <= RegCount - 1);
+ static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1;
+ static int constexpr KUNROLL = 2;
+ static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F;
+ static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32;
+ typedef float AType;
+ typedef float BType;
+ typedef float CType;
+
+ struct params {
+ AType* matA;
+ int astride;
+ BType* matB;
+ int bstride;
+ CType* matC;
+ int cstride;
+ int k;
+ int n;
+ int init;
+ };
+ typedef long long (*func_t)(params*);
+
+ int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0;
+ int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0;
+ static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+ static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+ void generate_code(int _mtile) {
+ assign_regs();
+ reset();
+ generate_mtile(_mtile);
+ ready();
+ mKernel = getCode();
+ }
+ func_t mKernel = nullptr;
+
+ protected:
+ Xbyak::Reg64 parambase;
+ Xbyak::Reg64 reg_matAptr;
+ Xbyak::Reg64 reg_matBptr;
+ Xbyak::Reg64 reg_matCptr;
+ Xbyak::Reg64 reg_ksize;
+ Xbyak::Reg64 reg_nsize;
+ Xbyak::Reg64 reg_cstride;
+ Xbyak::Reg64 reg_astride;
+ Xbyak::Reg64 reg_iterk;
+ Xbyak::Reg64 reg_itern;
+ Xbyak::Reg64 reg_tmp;
+ Xbyak::Reg64 reg_tmp1;
+ Xbyak::Reg64 reg_tmp2;
+ Xbyak::Reg64 reg_ret = rax;
+ Xbyak::Opmask msk_wr = k1;
+
+ void assign_regs() {
+ CRegCount = MRegs * NRegs;
+ ARegCount = 1;
+ BRegCount = RegCount - ARegCount - CRegCount;
+ if (BRegCount < NRegs) {
+ BRegCount = 0;
+ ARegCount = BRegCount + 1;
+ }
+ if (BRegCount > NRegs) {
+ BRegCount = NRegs;
+ }
+ CReg = 0;
+ BReg = CReg + CRegCount;
+ AReg = BReg + BRegCount;
+ TmpReg = AReg + ARegCount;
+ assert(TmpReg <= RegCount);
+ TmpRegCount = RegCount - TmpReg;
+ }
+
+ void generate_mtile(int _mtile) {
+ inLocalLabel(); // use local labels for multiple instances
+ Xbyak::util::StackFrame st(this, 1, 10, 16 * 10);
+ parambase = st.p[0];
+ reg_matAptr = st.t[0];
+ reg_matBptr = st.t[1];
+ reg_matCptr = st.t[0];
+ reg_ksize = st.t[2];
+ reg_astride = st.t[3];
+ reg_cstride = st.t[3];
+ reg_iterk = st.t[4];
+ reg_tmp = st.t[5];
+ reg_tmp1 = st.t[6];
+ reg_tmp2 = st.t[7];
+ reg_nsize = st.t[8];
+ reg_itern = st.t[9];
+ reg_ret = rax;
+
+ vreg_push(rsp);
+
+ load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+ load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+ xor_(reg_itern, reg_itern);
+ L(".nloop");
+ init_regs(_mtile);
+ mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+ load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+ mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+ load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+ imul(reg_tmp, reg_itern);
+ lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+ xor_(reg_iterk, reg_iterk);
+ generate_kloop(_mtile);
+ write_back(_mtile);
+ add(reg_itern, NTILE);
+ cmp(reg_itern, reg_nsize);
+ jb(".nloop");
+ mov(reg_ret, 0);
+ vreg_pop(rsp);
+
+ outLocalLabel(); // end of local label
+ }
+
+ void generate_kloop(int _mtile) {
+ inLocalLabel();
+ mov(reg_tmp, reg_ksize);
+ padto_le(reg_tmp, KUNROLL * KTILE);
+ cmp(reg_tmp, 0);
+ jz(".kloop", T_NEAR);
+ L(".unkloop");
+ generate_fma(_mtile, KUNROLL);
+ add(reg_matAptr, KUNROLL * AKStepSize);
+ add(reg_matBptr, KUNROLL * BKStepSize);
+ add(reg_iterk, KUNROLL * KTILE);
+ cmp(reg_iterk, reg_tmp); // k iteration variable
+ jb(".unkloop");
+ cmp(reg_tmp, reg_ksize);
+ jge(".kend", T_NEAR);
+ L(".kloop");
+ generate_fma(_mtile, 1);
+ add(reg_matAptr, 1 * AKStepSize);
+ add(reg_matBptr, 1 * BKStepSize);
+ add(reg_iterk, 1 * KTILE);
+ cmp(reg_iterk, reg_ksize); // k iteration variable
+ jb(".kloop");
+ L(".kend");
+ outLocalLabel();
+ }
+
+ void generate_fma(int _mtile, int _ktile) {
+ for (int kk = 0; kk < _ktile; kk++) {
+ lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]);
+ if (BRegCount == NRegs) {
+ for (int i = 0; i < NRegs; i++) {
+ vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ for (int mm = 0; mm < _mtile; mm++) {
+ vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i));
+ }
+ }
+ } else if (BRegCount == 0) {
+ for (int mm = 0; mm < _mtile; mm += ARegCount) {
+ int mm_re = utils::remainsize(mm, _mtile, ARegCount);
+ for (int imm = 0; imm < mm_re; imm++) {
+ vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm),
+ ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ }
+ }
+ } else {
+ assert(0);
+ }
+ }
+ }
+
+ void init_regs(int _mtile) {
+ inLocalLabel();
+ load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+ cmp(reg_tmp, 0);
+ je(".read", T_NEAR);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j));
+ }
+ }
+ jmp(".end", T_NEAR);
+ L(".read");
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]);
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ L(".end");
+ outLocalLabel();
+ }
+
+ void write_back(int _mtile) {
+ inLocalLabel();
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j));
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ outLocalLabel();
+ }
+};
+
+template <int _NTILE, int _MTILE = 0>
+class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni {
+ public:
+ static int constexpr RegLen = 16, PackRow = 4;
+ static_assert(_NTILE % RegLen == 0);
+ static int constexpr NRegs = _NTILE / RegLen;
+ static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1 - NRegs) / (NRegs * 2) : _MTILE;
+ static_assert(NRegs * MRegs <= RegCount - 1);
+ static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4;
+ static int constexpr KUNROLL = 2;
+ static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI;
+ static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_FP32;
+ typedef uint8_t AType;
+ typedef int8_t BType;
+ typedef float CType;
+
+ struct params {
+ AType* matA;
+ int astride;
+ BType* matB;
+ int bstride;
+ CType* matC;
+ int cstride;
+ uint8_t* zpA;
+ float* scaleA;
+ int ldsa;
+ float* scaleB;
+ float* reduceB;
+ int ldsb;
+ int k;
+ int n;
+ int kblock;
+ int init;
+ };
+ typedef long long (*func_t)(params*);
+
+ int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0;
+ int CReg = 0, CF32Reg = 0, BReg = 0, AReg = 0, TmpReg = 0;
+ static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+ static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+ void generate_code(int _mtile) {
+ assign_regs();
+ reset();
+ generate_mtile(_mtile);
+ ready();
+ mKernel = getCode();
+ }
+ func_t mKernel = nullptr;
+
+ protected:
+ Xbyak::Reg64 parambase;
+ Xbyak::Reg64 reg_matAptr;
+ Xbyak::Reg64 reg_matBptr;
+ Xbyak::Reg64 reg_matCptr;
+ Xbyak::Reg64 reg_ksize;
+ Xbyak::Reg64 reg_nsize;
+ Xbyak::Reg64 reg_cstride;
+ Xbyak::Reg64 reg_astride;
+ Xbyak::Reg64 reg_iterk;
+ Xbyak::Reg64 reg_iterkb;
+ Xbyak::Reg64 reg_itern;
+ Xbyak::Reg64 reg_tmp;
+ Xbyak::Reg64 reg_tmp1;
+ Xbyak::Reg64 reg_tmp2;
+ Xbyak::Reg64 reg_tmp3;
+ Xbyak::Reg64 reg_tmp4;
+ Xbyak::Reg64 reg_ret = rax;
+
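+ // Register plan: int32 accumulators (CReg) and persistent fp32 accumulators
+ // (CF32Reg) each take MRegs*NRegs registers, followed by the preloaded B
+ // registers, one A broadcast register, and temporaries.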
+ void assign_regs() {
+ CRegCount = MRegs * NRegs;
+ ARegCount = 1;
+ BRegCount = NRegs;
+ CReg = 0;
+ CF32Reg = CReg + CRegCount;
+ BReg = CF32Reg + CRegCount;
+ AReg = BReg + BRegCount;
+ TmpReg = AReg + ARegCount;
+ assert(TmpReg < RegCount);
+ TmpRegCount = RegCount - TmpReg;
+ assert(TmpRegCount >= 1);
+ }
+
+ void generate_mtile(int _mtile) {
+ inLocalLabel(); // use local labels for multiple instances
+ Xbyak::util::StackFrame st(this, 1, 13, 16 * 10);
+ parambase = st.p[0];
+ reg_matAptr = st.t[0];
+ reg_matBptr = st.t[1];
+ reg_matCptr = st.t[0];
+ reg_ksize = st.t[2];
+ reg_astride = st.t[3];
+ reg_cstride = st.t[3];
+ reg_iterk = st.t[4];
+ reg_iterkb = st.t[12];
+ reg_tmp = st.t[5];
+ reg_tmp1 = st.t[6];
+ reg_tmp2 = st.t[7];
+ reg_tmp3 = st.t[10];
+ reg_tmp4 = st.t[11];
+ reg_nsize = st.t[8];
+ reg_itern = st.t[9];
+ reg_ret = rax;
+
+ vreg_push(rsp);
+
+ load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+ load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+ xor_(reg_itern, reg_itern);
+ L(".nloop");
+ init_regs(_mtile);
+ mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+ load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+ mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+ load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+ imul(reg_tmp, reg_itern);
+ lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+ xor_(reg_iterk, reg_iterk);
+ generate_kloop(_mtile);
+ write_back(_mtile);
+ add(reg_itern, NTILE);
+ cmp(reg_itern, reg_nsize);
+ jb(".nloop");
+ mov(reg_ret, 0);
+ vreg_pop(rsp);
+
+ outLocalLabel(); // end of local label
+ }
+
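+ // Per-kblock loop: zero the int32 accumulators, run the int8 dot-products for
+ // one kblock, then dequantize the partial sums into the fp32 accumulators
+ // (scale multiply plus zero-point correction) before the next kblock.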
+ void generate_kloop(int _mtile) {
+ inLocalLabel();
+ xor_(reg_iterkb, reg_iterkb);
+ L(".kloop");
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j));
+ }
+ }
+ xor_(reg_tmp2, reg_tmp2);
+ load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]);
+ mov(reg_tmp, reg_tmp3);
+ padto_le(reg_tmp, KUNROLL * KTILE);
+ cmp(reg_tmp, 0);
+ jz(".kbloop", T_NEAR);
+ L(".unkbloop");
+ generate_fma(_mtile, KUNROLL, reg_tmp1);
+ add(reg_matAptr, KUNROLL * AKStepSize);
+ add(reg_matBptr, KUNROLL * BKStepSize);
+ add(reg_tmp2, KUNROLL * KTILE);
+ cmp(reg_tmp2, reg_tmp);
+ jb(".unkbloop");
+ cmp(reg_tmp, reg_tmp3);
+ jge(".kend", T_NEAR);
+ L(".kbloop");
+ generate_fma(_mtile, 1, reg_tmp1);
+ add(reg_matAptr, 1 * AKStepSize);
+ add(reg_matBptr, 1 * BKStepSize);
+ add(reg_tmp2, 1 * KTILE);
+ cmp(reg_tmp2, reg_tmp3);
+ jb(".kbloop");
+ L(".kend");
+ add(reg_iterk, reg_tmp2);
+ generate_f32_accumulate(_mtile);
+ generate_zp_correction(_mtile);
+ inc(reg_iterkb);
+ cmp(reg_iterk, reg_ksize); // k iteration variable
+ jb(".kloop");
+
+ outLocalLabel();
+ }
+
+ void generate_fma(int _mtile, int _ktile, Xbyak::Reg64& tmp) {
+ for (int kk = 0; kk < _ktile; kk++) {
+ lea(tmp, ptr[reg_matAptr + kk * AKStepSize]);
+ for (int i = 0; i < NRegs; i++) {
+ vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+ }
+ for (int mm = 0; mm < _mtile; mm++) {
+ vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]);
+ add(reg_tmp1, reg_astride);
+ for (int i = 0; i < NRegs; i++) {
+ vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i));
+ }
+ }
+ }
+ }
+
+ void init_regs(int _mtile) {
+ inLocalLabel();
+ load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+ cmp(reg_tmp, 0);
+ je(".read", T_NEAR);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vxor(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j));
+ }
+ }
+ jmp(".end", T_NEAR);
+ L(".read");
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(vreg_t(CF32Reg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]);
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ L(".end");
+ outLocalLabel();
+ }
+
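+ // Convert the int32 partial sums of this kblock to fp32, multiply by
+ // scaleA*scaleB, and add them into the fp32 accumulators.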
+ void generate_f32_accumulate(int _mtile) {
+ load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]);
+ imul(reg_tmp, reg_iterkb);
+ mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]);
+ lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]);
+ lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]);
+
+ mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]);
+ lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(float)]);
+ load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]);
+ for (int i = 0; i < NRegs; i++) {
+ vmovups(Xbyak::Zmm(BReg + i), ptr[reg_tmp2 + i * VecBytes]);
+ }
+ for (int mm = 0; mm < _mtile; mm++) {
+ vbroadcastss(Xbyak::Zmm(TmpReg), ptr[reg_tmp]);
+ lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]);
+ for (int i = 0; i < NRegs; i++) {
+ vcvtdq2ps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i));
+ vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(TmpReg), Xbyak::Zmm(BReg + i));
+ vmulps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(AReg));
+ vaddps(Xbyak::Zmm(CF32Reg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i));
+ }
+ }
+ }
+
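+ // Zero-point correction: subtract zpA * scaleA * reduceB (the per-column
+ // reduction of B) from the fp32 accumulators for this kblock.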
+ void generate_zp_correction(int _mtile) {
+ load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]);
+ imul(reg_tmp1, reg_iterkb);
+ mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]);
+ lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]);
+ lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]);
+ auto& reg_redB = reg_tmp2;
+
+ mov(reg_tmp, ptr[parambase + OFFSET(zpA)]);
+ lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(AType)]);
+ auto& reg_zpA = reg_tmp;
+
+ mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]);
+ lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]);
+ auto& reg_scaleA = reg_tmp1;
+
+ load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]);
+ auto& reg_ldsa = reg_tmp3;
+ for (int i = 0; i < NRegs; i++) {
+ vmovups(Xbyak::Zmm(BReg + i), ptr[reg_redB + i * VecBytes]);
+ }
+
+ for (int i = 0; i < _mtile; i++) {
+ vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]);
+ vpmovzxbd(Xbyak::Zmm(AReg), Xbyak::Xmm(AReg));
+ vcvtdq2ps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg));
+ vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg), zword_b[reg_scaleA]);
+ for (int j = 0; j < NRegs; j++) {
+ vmulps(Xbyak::Zmm(CReg + j), Xbyak::Zmm(AReg), Xbyak::Zmm(BReg + j));
+ vsubps(Xbyak::Zmm(CF32Reg + i * NRegs + j), Xbyak::Zmm(CReg + j));
+ }
+ lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]);
+ lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]);
+ }
+ }
+
+ void write_back(int _mtile) {
+ inLocalLabel();
+ mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+ load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+ lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+ for (int i = 0; i < _mtile; i++) {
+ for (int j = 0; j < NRegs; j++) {
+ vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CF32Reg + i * NRegs + j));
+ }
+ add(reg_matCptr, reg_cstride);
+ }
+ outLocalLabel();
+ }
+};
+
+} // namespace kblock
+} // namespace code
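+// CoreCodeBase JIT-compiles one kernel per M height (1..MTILE) so callers can
+// dispatch on the exact number of rows left to process.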
+template <template <int, int> class CodeT, int _NTILE, int _MTILE = 0>
+class CoreCodeBase {
+ public:
+ using Code = CodeT<_NTILE, _MTILE>;
+ using AType = typename Code::AType;
+ using BType = typename Code::BType;
+ using CType = typename Code::CType;
+ static int constexpr NTILE = Code::NTILE;
+ static int constexpr MTILE = Code::MTILE;
+ static int constexpr KTILE = Code::KTILE;
+ static int constexpr PACK_ROW = Code::PackRow;
+ static int constexpr COMP = Code::COMPUTE;
+ static int constexpr PREFERRED_N = NTILE * 3;
+ static JBLAS_ISA constexpr ISA = (JBLAS_ISA)Code::ISA;
+ static uint32_t constexpr ID = CoreAttr::make_core_id(NTILE, PACK_ROW, COMP, ISA);
+ void configure() { (void)(0); }
+
+ protected:
+ CoreCodeBase() {
+ for (int i = 0; i < mCodes.size(); i++) {
+ mCodes[i].generate_code(i + 1);
+ }
+ }
+ std::array<Code, Code::MTILE> mCodes;
+};
+
+template