diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index bcd0b2a92a5c3..03e3f84547a68 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -36,7 +36,7 @@ "component": { "type": "git", "git": { - "commitHash": "dcd5bd5fd593e31465af3d9ef291d26c646b0a4f", + "commitHash": "4a2c63365eff8823a5221db86ef490e828306f9d", "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" }, "comments": "abseil_cpp" @@ -192,6 +192,16 @@ "comments": "mp11" } }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c11386eb632eec7c1c2aa323142f73519f946e2a", + "repositoryUrl": "https://github.com/intel/neural-speed.git" + }, + "comments": "neural_speed" + } + }, { "component": { "type": "git", diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index bc96218dac79e..94d650f685235 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -87,7 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) -option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON) +option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) @@ -96,7 +96,6 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF) option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) -cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF) cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF) option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) @@ -706,20 +705,16 @@ if (onnxruntime_USE_CUDA) enable_language(CUDA) message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") + if (onnxruntime_DISABLE_CONTRIB_OPS) + set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) + endif() if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) - message( STATUS "Turn off cutlass since CUDA compiler version < 11.6") - set(onnxruntime_USE_CUTLASS OFF) + message( STATUS "Turn off flash attention since CUDA compiler version < 11.6") + set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() else() - set(onnxruntime_USE_CUTLASS OFF) -endif() - -if (NOT onnxruntime_USE_CUTLASS OR onnxruntime_DISABLE_CONTRIB_OPS) - if (onnxruntime_DISABLE_CONTRIB_OPS) - message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled") - else() - message( STATUS "Turn off flash attention/memory efficient attention since cutlass is not enabled") - endif() set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() @@ -905,8 +900,8 @@ function(onnxruntime_set_compile_flags target_name) target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) endif() - if (onnxruntime_USE_CUTLASS) - target_compile_definitions(${target_name} PRIVATE USE_CUTLASS) + if(USE_NEURAL_SPEED) + target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED) endif() set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON) @@ -1193,14 +1188,10 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() -set(USE_JBLAS FALSE) -if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD) - if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") - add_compile_definitions(MLAS_JBLAS) - set(USE_JBLAS TRUE) - elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") - add_compile_definitions(MLAS_JBLAS) - set(USE_JBLAS TRUE) +if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD) + include(neural_speed) + if (USE_NEURAL_SPEED) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla) endif() endif() diff --git a/cmake/deps.txt b/cmake/deps.txt index ff07803013071..ba9c2bb73cf7a 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -12,7 +12,7 @@ # NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI. # See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 # -abseil_cpp;https://github.com/abseil/abseil-cpp/archive/dcd5bd5fd593e31465af3d9ef291d26c646b0a4f.zip;6cc204586014e189f5c0fe3274f83162fa7c700c +abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240116.0.zip;bc2cec6baaad67fcb6c0c38972b687d4797927e9 cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 @@ -34,6 +34,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 +neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 @@ -54,4 +55,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c -composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 +composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 \ No newline at end of file diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index 3bcd4109e2888..57cfbee4644ef 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -19,7 +19,7 @@ if(WIN32 AND NOT Patch_FOUND) set(ABSL_ENABLE_INSTALL ON) endif() # NB! Advancing Abseil version changes its internal namespace, -# currently absl::lts_20230125 which affects abseil-cpp.natvis debugger +# currently absl::lts_20240116 which affects abseil-cpp.natvis debugger # visualization file, that must be adjusted accordingly, unless we eliminate # that namespace at build time. FetchContent_Declare( diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis index 1e5a36fb9efb9..a4fb63b6a8377 100644 --- a/cmake/external/abseil-cpp.natvis +++ b/cmake/external/abseil-cpp.natvis @@ -1,6 +1,6 @@ - + @@ -24,7 +24,7 @@ - + @@ -51,7 +51,7 @@ - + *($T1 *){value} (*($T1 *){value}) @@ -60,7 +60,7 @@ - + *($T1 *)this (*($T1 *)this) @@ -68,7 +68,7 @@ - + {value.first}, {value.second} ({value.first}, {value.second}) diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index efc708bd681c0..f04f4bec76cd5 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -1,13 +1,11 @@ -if (onnxruntime_USE_CUTLASS) - include(FetchContent) - FetchContent_Declare( - cutlass - URL ${DEP_URL_cutlass} - URL_HASH SHA1=${DEP_SHA1_cutlass} - ) +include(FetchContent) +FetchContent_Declare( + cutlass + URL ${DEP_URL_cutlass} + URL_HASH SHA1=${DEP_SHA1_cutlass} +) - FetchContent_GetProperties(cutlass) - if(NOT cutlass_POPULATED) - FetchContent_Populate(cutlass) - endif() +FetchContent_GetProperties(cutlass) +if(NOT cutlass_POPULATED) + FetchContent_Populate(cutlass) endif() diff --git a/cmake/external/neural_speed.cmake b/cmake/external/neural_speed.cmake new file mode 100644 index 0000000000000..ed711351403a7 --- /dev/null +++ b/cmake/external/neural_speed.cmake @@ -0,0 +1,15 @@ +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") + set(USE_NEURAL_SPEED TRUE) +elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") + set(USE_NEURAL_SPEED TRUE) +endif() + +if(USE_NEURAL_SPEED) + FetchContent_Declare( + neural_speed + URL ${DEP_URL_neural_speed} + URL_HASH SHA1=${DEP_SHA1_neural_speed} + ) + set(BTLA_USE_OPENMP OFF) + onnxruntime_fetchcontent_makeavailable(neural_speed) +endif() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index b995b27123218..17de2aa4aaea6 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -57,15 +57,6 @@ endif() set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) -function(add_jblas) - add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas) - target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/jblas_gemm.cpp - ) - set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF) -endfunction() - #TODO: set MASM flags properly function(setup_mlas_source_for_windows) @@ -364,19 +355,23 @@ else() ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S + ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/pooling_fp16.cpp ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) @@ -622,10 +617,6 @@ else() target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) endif() -if(USE_JBLAS) - add_jblas() -endif() - foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET}) diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index f60faa4d39116..b81a5c79ac0cc 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -60,6 +60,15 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS) "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc" ) endif() + set(onnxruntime_cpu_neural_speed_srcs + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.h" + ) + if(NOT USE_NEURAL_SPEED) + list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs ${onnxruntime_cpu_neural_speed_srcs}) + endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) @@ -144,6 +153,12 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL) target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical") endif() +if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + if(USE_NEURAL_SPEED) + onnxruntime_add_include_to_target(onnxruntime_providers neural_speed::bestla) + endif() +endif() + if (MSVC) target_compile_options(onnxruntime_providers PRIVATE "/bigobj") # if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 45c0e6f822ce9..22e82443167f6 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -3031,6 +3031,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
+
unidirectional : int
+
Whether every token can only attend to previous tokens. Default value is 0.
#### Inputs (1 - 8) @@ -5021,6 +5023,10 @@ This version of the operator has been available since version 1 of the 'com.micr
interleaved : int
Rotate using interleaved pattern. Default value is 0 (False).
+
num_heads : int
+
Number of attention heads. Default value is 0. Must use with rotary_embedding_dim
+
rotary_embedding_dim : int
+
Rotary embedding dimension. Default value is 0.
scale : float
Custom scale will be used if specified. Default value is 1.0
@@ -5033,9 +5039,9 @@ This version of the operator has been available since version 1 of the 'com.micr
position_ids : M
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
cos_cache : T
-
2D tensor with shape (max_sequence_length, head_size / 2).
+
2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
sin_cache : T
-
2D tensor with shape (max_sequence_length, head_size / 2).
+
2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
#### Outputs @@ -5048,7 +5054,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float), tensor(float16)
+
T : tensor(float), tensor(float16), tensor(bfloat16)
Constrain input and output types to float tensors.
M : tensor(int64)
Constrain input and output types to integer tensors
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 394bd7ad2abae..9ecc58bee0725 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -868,7 +868,7 @@ Do not modify directly.* |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index ea4f52f99649d..1de0217c7e1fa 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -326,6 +326,15 @@ class IExecutionProvider { */ virtual std::vector CreatePreferredAllocators() { return std::vector(); }; + /** + * Get the array of pointers for EPContext nodes + * EP needs to implement this if has the requirement to generate the context cache model. Otherwise leave it. + * Default return an empty vector if not provided by the Execution Provider + */ + virtual const InlinedVector GetEpContextNodes() const { + return InlinedVector(); + } + private: const std::string type_; diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 60196d0c80cbb..32a9f06464ace 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -11,6 +11,8 @@ /// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions. /// struct OrtTensorRTProviderOptionsV2 { + OrtTensorRTProviderOptionsV2& operator=(const OrtTensorRTProviderOptionsV2& other); // copy assignment operator + int device_id{0}; // cuda device id. int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream. void* user_compute_stream{nullptr}; // user specified CUDA compute stream. @@ -46,8 +48,26 @@ struct OrtTensorRTProviderOptionsV2 { const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT - int trt_dump_ep_context_model{0}; // Dump EP context node model - int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data - int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute - const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix + + /* + * Please note that there are rules for using following context model related provider options: + * + * 1. In the case of dumping the context model and loading the context model, + * for security reason, TRT EP doesn't allow the "ep_cache_context" node attribute of EP context node to be + * the absolute path or relative path that is outside of context model directory. + * It means engine cache needs to be in the same directory or sub-directory of context model. + * + * 2. In the case of dumping the context model, the engine cache path will be changed to the relative path of context model directory. + * For example: + * If "trt_dump_ep_context_model" is enabled and "trt_engine_cache_enable" is enabled, + * if "trt_ep_context_file_path" is "./context_model_dir", + * - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir" + * - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir" + * + */ + int trt_dump_ep_context_model{0}; // Dump EP context node model + const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path. + int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data + + const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix }; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b321b2b2bac27..64095a31ac2b8 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3608,6 +3608,14 @@ struct OrtApi { * - "1": Faster preparation time, less optimal graph. * - "2": Longer preparation time, more optimal graph. * - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details. + * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown). + * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options: + * - "0": Default (none). + * - "68" + * - "69" + * - "73" + * - "75" + * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index df79cb6e5b21b..b282438795eb5 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -236,7 +236,7 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = "session.optimized_model_external_initializers_min_size_in_bytes"; -// Enable EP context feature to dump the partitioned graph which include the EP context into Onnx file. +// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file. // The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead. // "0": disable. (default) // "1": enable. @@ -249,4 +249,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p // Flag to specify whether to dump the EP context into the Onnx model. // "0": dump the EP context into separate file, keep the file name in the Onnx model. // "1": dump the EP context into the Onnx model. (default). -static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; \ No newline at end of file +static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; + +// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul. +// Option values: +// - "0": Gemm FastMath mode is not enabled. [DEFAULT] +// - "1": Gemm FastMath mode is enabled. +static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; diff --git a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c index 9f7b8d3a3dcfc..464234c34798a 100644 --- a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c @@ -66,7 +66,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtTrainingSession_createTrainingSes } } wchar_t* optimizerStr = NULL; - if (optimizerPath == NULL) { + if (optimizerPath != NULL) { optimizerStr = copyAndPad(jniEnv, optimizerPath); if (optimizerStr == NULL) { // exception has been thrown in Java, go to cleanup and return null. diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index d9f63fec9c492..31ecffb07e40c 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -31,6 +31,12 @@ export const initializeFlags = (): void => { } if (typeof env.wasm.numThreads !== 'number' || !Number.isInteger(env.wasm.numThreads) || env.wasm.numThreads <= 0) { + // Web: when crossOriginIsolated is false, SharedArrayBuffer is not available so WebAssembly threads will not work. + // Node.js: onnxruntime-web does not support multi-threads in Node.js. + if ((typeof self !== 'undefined' && !self.crossOriginIsolated) || + (typeof process !== 'undefined' && process.versions && process.versions.node)) { + env.wasm.numThreads = 1; + } const numCpuLogicalCores = typeof navigator === 'undefined' ? cpus().length : navigator.hardwareConcurrency; env.wasm.numThreads = Math.min(4, Math.ceil((numCpuLogicalCores || 1) / 2)); } diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 81508a253ce8b..9b9334c93b78c 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -28,13 +28,34 @@ let initialized = false; let initializing = false; let aborted = false; -const isMultiThreadSupported = (): boolean => { - try { - // If 'SharedArrayBuffer' is not available, WebAssembly threads will not work. - if (typeof SharedArrayBuffer === 'undefined') { - return false; +const isMultiThreadSupported = (numThreads: number): boolean => { + // WebAssembly threads are set to 1 (single thread). + if (numThreads === 1) { + return false; + } + + // If 'SharedArrayBuffer' is not available, WebAssembly threads will not work. + if (typeof SharedArrayBuffer === 'undefined') { + if (typeof self !== 'undefined' && !self.crossOriginIsolated) { + // eslint-disable-next-line no-console + console.warn( + 'env.wasm.numThreads is set to ' + numThreads + + ', but this will not work unless you enable crossOriginIsolated mode. ' + + 'See https://web.dev/cross-origin-isolation-guide/ for more info.'); } + return false; + } + + // onnxruntime-web does not support multi-threads in Node.js. + if (typeof process !== 'undefined' && process.versions && process.versions.node) { + // eslint-disable-next-line no-console + console.warn( + 'env.wasm.numThreads is set to ' + numThreads + + ', however, currently onnxruntime-web does not support multi-threads in Node.js. ' + + 'Please consider using onnxruntime-node for performance critical scenarios.'); + } + try { // Test for transferability of SABs (for browsers. needed for Firefox) // https://groups.google.com/forum/#!msg/mozilla.dev.platform/IHkBZlHETpA/dwsMNchWEQAJ if (typeof MessageChannel !== 'undefined') { @@ -106,7 +127,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const numThreads = flags.numThreads!; const simd = flags.simd!; - const useThreads = numThreads > 1 && isMultiThreadSupported(); + const useThreads = isMultiThreadSupported(numThreads); const useSimd = simd && isSimdSupported(); const wasmPaths = flags.wasmPaths; diff --git a/js/web/lib/wasm/wasm-utils-load-file.ts b/js/web/lib/wasm/wasm-utils-load-file.ts index abe480a43c790..c6cdba2320bde 100644 --- a/js/web/lib/wasm/wasm-utils-load-file.ts +++ b/js/web/lib/wasm/wasm-utils-load-file.ts @@ -47,9 +47,19 @@ export const loadFile = async(file: string|Blob|ArrayBufferLike|Uint8Array): Pro } const reader = response.body.getReader(); - // use WebAssembly Memory to allocate larger ArrayBuffer - const pages = Math.ceil(fileSize / 65536); - const buffer = new WebAssembly.Memory({initial: pages, maximum: pages}).buffer; + let buffer; + try { + // try to create ArrayBuffer directly + buffer = new ArrayBuffer(fileSize); + } catch (e) { + if (e instanceof RangeError) { + // use WebAssembly Memory to allocate larger ArrayBuffer + const pages = Math.ceil(fileSize / 65536); + buffer = new WebAssembly.Memory({initial: pages, maximum: pages}).buffer; + } else { + throw e; + } + } let offset = 0; // eslint-disable-next-line no-constant-condition diff --git a/js/web/package-lock.json b/js/web/package-lock.json index cd71c20ba4d2f..74cd0d81a3943 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -28,7 +28,7 @@ "@webgpu/types": "^0.1.38", "base64-js": "^1.5.1", "chai": "^4.3.7", - "electron": "^23.1.2", + "electron": "^28.1.4", "globby": "^13.1.3", "karma": "^6.4.1", "karma-browserstack-launcher": "^1.6.0", @@ -862,9 +862,9 @@ } }, "node_modules/cross-spawn/node_modules/semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true, "bin": { "semver": "bin/semver" @@ -1042,14 +1042,14 @@ "dev": true }, "node_modules/electron": { - "version": "23.3.13", - "resolved": "https://registry.npmjs.org/electron/-/electron-23.3.13.tgz", - "integrity": "sha512-BaXtHEb+KYKLouUXlUVDa/lj9pj4F5kiE0kwFdJV84Y2EU7euIDgPthfKtchhr5MVHmjtavRMIV/zAwEiSQ9rQ==", + "version": "28.1.4", + "resolved": "https://registry.npmjs.org/electron/-/electron-28.1.4.tgz", + "integrity": "sha512-WE6go611KOhtH6efRPMnVC7FE7DCKnQ3ZyHFeI1DbaCy8OU4UjZ8/CZGcuZmZgRdxSBEHoHdgaJkWRHZzF0FOg==", "dev": true, "hasInstallScript": true, "dependencies": { "@electron/get": "^2.0.0", - "@types/node": "^16.11.26", + "@types/node": "^18.11.18", "extract-zip": "^2.0.1" }, "bin": { @@ -1059,12 +1059,6 @@ "node": ">= 12.20.55" } }, - "node_modules/electron/node_modules/@types/node": { - "version": "16.18.14", - "resolved": "https://registry.npmjs.org/@types/node/-/node-16.18.14.tgz", - "integrity": "sha512-wvzClDGQXOCVNU4APPopC2KtMYukaF1MN/W3xAmslx22Z4/IF1/izDMekuyoUlwfnDHYCIZGaj7jMwnJKBTxKw==", - "dev": true - }, "node_modules/emoji-regex": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", @@ -1432,9 +1426,9 @@ } }, "node_modules/get-func-name": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.0.tgz", - "integrity": "sha512-Hm0ixYtaSZ/V7C8FJrtZIuBBI+iSgL+1Aq82zSu8VQNB4S3Gk8e7Qs3VwBDJAhmRZcFqkl3tQu36g/Foh5I5ig==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.2.tgz", + "integrity": "sha512-8vXOvuE167CtIc3OyItco7N/dpRtBbYOsPsXCz7X/PMnlGjYjSGuZJgM1Y7mmew7BKf9BqvLX2tnOVy1BBUsxQ==", "dev": true, "engines": { "node": "*" @@ -1542,9 +1536,9 @@ } }, "node_modules/global-agent/node_modules/semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "optional": true, "dependencies": { @@ -2908,9 +2902,9 @@ "dev": true }, "node_modules/semver": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", "dev": true, "bin": { "semver": "bin/semver.js" @@ -4203,9 +4197,9 @@ }, "dependencies": { "semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true } } @@ -4339,22 +4333,14 @@ "dev": true }, "electron": { - "version": "23.3.13", - "resolved": "https://registry.npmjs.org/electron/-/electron-23.3.13.tgz", - "integrity": "sha512-BaXtHEb+KYKLouUXlUVDa/lj9pj4F5kiE0kwFdJV84Y2EU7euIDgPthfKtchhr5MVHmjtavRMIV/zAwEiSQ9rQ==", + "version": "28.1.4", + "resolved": "https://registry.npmjs.org/electron/-/electron-28.1.4.tgz", + "integrity": "sha512-WE6go611KOhtH6efRPMnVC7FE7DCKnQ3ZyHFeI1DbaCy8OU4UjZ8/CZGcuZmZgRdxSBEHoHdgaJkWRHZzF0FOg==", "dev": true, "requires": { "@electron/get": "^2.0.0", - "@types/node": "^16.11.26", + "@types/node": "^18.11.18", "extract-zip": "^2.0.1" - }, - "dependencies": { - "@types/node": { - "version": "16.18.14", - "resolved": "https://registry.npmjs.org/@types/node/-/node-16.18.14.tgz", - "integrity": "sha512-wvzClDGQXOCVNU4APPopC2KtMYukaF1MN/W3xAmslx22Z4/IF1/izDMekuyoUlwfnDHYCIZGaj7jMwnJKBTxKw==", - "dev": true - } } }, "emoji-regex": { @@ -4657,9 +4643,9 @@ "dev": true }, "get-func-name": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.0.tgz", - "integrity": "sha512-Hm0ixYtaSZ/V7C8FJrtZIuBBI+iSgL+1Aq82zSu8VQNB4S3Gk8e7Qs3VwBDJAhmRZcFqkl3tQu36g/Foh5I5ig==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.2.tgz", + "integrity": "sha512-8vXOvuE167CtIc3OyItco7N/dpRtBbYOsPsXCz7X/PMnlGjYjSGuZJgM1Y7mmew7BKf9BqvLX2tnOVy1BBUsxQ==", "dev": true }, "get-intrinsic": { @@ -4742,9 +4728,9 @@ }, "dependencies": { "semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "optional": true, "requires": { @@ -5780,9 +5766,9 @@ "dev": true }, "semver": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", "dev": true }, "semver-compare": { diff --git a/js/web/package.json b/js/web/package.json index 7ffc9ba16aaa9..047de382943e6 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -47,7 +47,7 @@ "@webgpu/types": "^0.1.38", "base64-js": "^1.5.1", "chai": "^4.3.7", - "electron": "^23.1.2", + "electron": "^28.1.4", "globby": "^13.1.3", "karma": "^6.4.1", "karma-browserstack-launcher": "^1.6.0", diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 4711ccf487cc8..768676259aa14 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -211,6 +211,12 @@ Status Attention::Compute(OpKernelContext* context) const { relative_position_bias, ¶meters)); + if (parameters.do_rotary) { + ORT_NOT_IMPLEMENTED( + "Rotary embedding is not supported in Attention CPU kernel. \ + Please fuse the model with MHA + RotaryEmbedding."); + } + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int input_hidden_size = parameters.input_hidden_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 694c40bf3eda6..eb25d0fd7cc1e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -40,6 +40,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i num_heads_ = static_cast(num_heads); mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; } // Reshape Q/K/V from BxSxD to BxSxNxH @@ -283,8 +284,9 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { nullptr, ¶meters, num_heads_, - scale, mask_filter_value_, + scale, + is_unidirectional_, past_present_share_buffer, false)); diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h index 4c86b777e9842..fb7da78a5c0a5 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h @@ -18,6 +18,7 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase { protected: int num_heads_; // number of attention heads float mask_filter_value_; + bool is_unidirectional_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 00e82c9844b3d..c91f5b601b4e9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -25,6 +25,7 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, float scale, + bool is_unidirectional, bool past_present_share_buffer, bool dmmha_packing) { // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None @@ -315,7 +316,7 @@ Status CheckInputs(const T* query, output_parameters->head_size = hidden_size / num_heads; output_parameters->v_head_size = v_hidden_size / num_heads; output_parameters->num_heads = num_heads; - output_parameters->is_unidirectional = false; + output_parameters->is_unidirectional = is_unidirectional; output_parameters->past_present_share_buffer = past_present_share_buffer; output_parameters->mask_filter_value = mask_filter_value; output_parameters->mask_type = mask_type; @@ -342,6 +343,7 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, float scale, + bool is_unidirectional, bool past_present_share_buffer, bool dmmha_packing, int max_threads_per_block) { @@ -350,8 +352,8 @@ Status CheckInputs(const T* query, } return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, - past_seq_len, parameters, num_heads, mask_filter_value, scale, past_present_share_buffer, - dmmha_packing); + past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional, + past_present_share_buffer, dmmha_packing); } } // namespace multihead_attention_helper diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index 47f462d75fcc4..aa8b5b5f608fa 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -27,7 +27,13 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { scale = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); + + if (rotary_embedding_dim > 0) { + ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified"); + } } template @@ -42,6 +48,8 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { position_ids, cos_cache, sin_cache, + num_heads, + rotary_embedding_dim, ¶meters)); Tensor* output = context->Output(0, input->Shape()); @@ -59,61 +67,66 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; + const int n_heads = parameters.num_heads; const int head_size = parameters.head_size; const int position_ids_format = parameters.position_ids_format; - const int half_head_size = head_size / 2; + const int rotary_emb_dim = parameters.rotary_embedding_dim; + const int half_rotary_emb_dim = rotary_emb_dim / 2; + // Default input tensor shape is [batch, seq_len, hidden_size] int head_stride = head_size; - int seq_stride = num_heads * head_stride; + int seq_stride = n_heads * head_stride; int batch_stride = sequence_length * seq_stride; if (parameters.transposed) { - // Transposed input tensor shape is [batch, num_heads, seq_len, head_size] + // Transposed input tensor shape is [batch, n_heads, seq_len, head_size] seq_stride = head_size; head_stride = sequence_length * seq_stride; - batch_stride = num_heads * head_stride; + batch_stride = n_heads * head_stride; } AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); auto* tp = context->GetOperatorThreadPool(); - const int loop_len = batch_size * sequence_length * num_heads; - const double cost = static_cast(head_size); + const int loop_len = batch_size * sequence_length * n_heads; + const double cost = static_cast(rotary_emb_dim); ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { - const int b = static_cast((ptr / num_heads) / sequence_length); - const int s = static_cast((ptr / num_heads) % sequence_length); - const int n = static_cast(ptr % num_heads); + const int b = static_cast((ptr / n_heads) / sequence_length); + const int s = static_cast((ptr / n_heads) % sequence_length); + const int n = static_cast(ptr % n_heads); const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; const T* input_data = input_src + block_offset; T* output_data = output_dest + block_offset; - // Cache is (M, H/2) + // Cache is (M, H/2) or (M, rotary_embedding_dim/2) const int position_id = (position_ids_format == 0) ? static_cast(pos_ids_data[0]) + s : static_cast(pos_ids_data[b * sequence_length + s]); - const int cache_offset = position_id * half_head_size; + const int cache_offset = position_id * half_rotary_emb_dim; const T* cos_data = cos_cache_data + cache_offset; const T* sin_data = sin_cache_data + cache_offset; int cache_idx = 0; T sign = 0; int j = 0; - for (int i = 0; i < head_size; i++) { + for (int i = 0; i < rotary_emb_dim; i++) { if (interleaved) { - cache_idx = (i / 2) % half_head_size; + cache_idx = (i / 2) % half_rotary_emb_dim; sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign } else { - cache_idx = i % half_head_size; - sign = (i < half_head_size) ? static_cast(-1) : static_cast(1); - j = (i + half_head_size) % head_size; + cache_idx = i % half_rotary_emb_dim; + sign = (i < half_rotary_emb_dim) ? static_cast(-1) : static_cast(1); + j = (i + half_rotary_emb_dim) % rotary_emb_dim; } output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; } + for (int i = rotary_emb_dim; i < head_size; i++) { + output_data[i] = input_data[i]; + } } }); diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h index be834a66cdc69..4e32424a22b6c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h @@ -16,6 +16,8 @@ class RotaryEmbedding final : public OpKernel { protected: float scale; + int num_heads; + int rotary_embedding_dim; bool interleaved; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h index 7b2e8289f7b06..dcbb36d1c4a3c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -11,14 +11,15 @@ namespace rotary_embedding_helper { // Parameters deduced from node attributes and inputs/outputs. struct RotaryParameters { - int batch_size; // Batch size used by input - int sequence_length; // Sequence length used by input - int hidden_size; // Hidden size used by input - int head_size; // Head size used by cos/sin cache * 2 - int num_heads; // num_heads = hidden_size / head_size - int max_sequence_length; // Sequence length used by cos/sin cache - int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) - bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) + int batch_size; // Batch size used by input + int sequence_length; // Sequence length used by input + int hidden_size; // Hidden size used by input + int head_size; // Head size + int rotary_embedding_dim; // Rotary embedding dimension. + int num_heads; // num_heads = hidden_size / head_size + int max_sequence_length; // Sequence length used by cos/sin cache + int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) + bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) }; template @@ -26,11 +27,13 @@ Status CheckInputs(const T* input, const T* position_ids, const T* cos_cache, const T* sin_cache, + int num_heads, + int rotary_embedding_dim, void* parameters) { // input : (batch_size, sequence_length, hidden_size) // position ids : (1) or (batch_size, sequence_length) - // cos cache : (max_sequence_length, head_size / 2) - // sin cache : (max_sequence_length, head_size / 2) + // cos cache : (max_sequence_length, rotary_embedding_dim / 2) + // sin cache : (max_sequence_length, rotary_embedding_dim / 2) // Check input const auto& input_dims = input->Shape().GetDims(); @@ -60,6 +63,12 @@ Status CheckInputs(const T* input, "the same shape"); } + // Check num_heads and rotary_embedding_dim + if (rotary_embedding_dim > 0 && num_heads == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ", + "specified"); + } + // Get attributes from inputs int batch_size = static_cast(input_dims[0]); int sequence_length = static_cast(input_dims[1]); @@ -73,8 +82,13 @@ Status CheckInputs(const T* input, transposed = true; } int max_sequence_length = static_cast(cos_cache_dims[0]); - int head_size = static_cast(cos_cache_dims[1]) * 2; - int num_heads = hidden_size / head_size; + int head_size = rotary_embedding_dim == 0 ? static_cast(cos_cache_dims[1]) * 2 + : static_cast(hidden_size / num_heads); + if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ", + "head_size"); + } + int position_ids_format = -1; // Check position_ids input shapes @@ -91,23 +105,15 @@ Status CheckInputs(const T* input, } else { position_ids_format = 0; } + // Check cos_cache input shapes if (max_sequence_length != static_cast(cos_cache_dims[0])) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", "max_sequence_length, got ", cos_cache_dims[0]); } - if ((head_size / 2) != static_cast(cos_cache_dims[1])) { + if ((head_size / 2) != static_cast(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast(cos_cache_dims[1]))) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", - "head_size / 2, got ", cos_cache_dims[1]); - } - // Check sin_cache input shapes - if (max_sequence_length != static_cast(sin_cache_dims[0])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ", - "max_sequence_length, got ", sin_cache_dims[0]); - } - if ((head_size / 2) != static_cast(sin_cache_dims[1])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ", - "head_size / 2, got ", sin_cache_dims[1]); + "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]); } // Set rotary parameters @@ -117,10 +123,11 @@ Status CheckInputs(const T* input, output_parameters->sequence_length = sequence_length; output_parameters->hidden_size = hidden_size; output_parameters->head_size = head_size; - output_parameters->num_heads = num_heads; + output_parameters->num_heads = num_heads > 0 ? num_heads : static_cast(hidden_size / head_size); output_parameters->max_sequence_length = max_sequence_length; output_parameters->position_ids_format = position_ids_format; output_parameters->transposed = transposed; + output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size; } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 406c73c95d444..72948c74d7877 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -9,6 +9,9 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" +#ifdef ORT_NEURAL_SPEED +#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" +#endif namespace onnxruntime { namespace contrib { @@ -24,15 +27,17 @@ class MatMulNBits final : public OpKernel { accuracy_level_{info.GetAttr("accuracy_level")} { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); - is_asym_ = info.GetInputCount() >= 4; +#ifdef ORT_NEURAL_SPEED const Tensor* tensor_B = nullptr; const Tensor* tensor_scale = nullptr; const Tensor* tensor_zero_point = nullptr; bool B_constant = info.TryGetConstantInput(1, &tensor_B); bool scale_constant = info.TryGetConstantInput(2, &tensor_scale); bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point); + is_asym_ = info.GetInputCount() >= 4; all_constant_ = B_constant && scale_constant; all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_; +#endif } Status Compute(OpKernelContext* context) const override; @@ -53,30 +58,34 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; size_t packed_b_size_{0}; +#ifdef ORT_NEURAL_SPEED bool is_asym_{false}; bool all_constant_{false}; +#endif }; Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; +#ifdef ORT_NEURAL_SPEED if (!all_constant_) { return Status::OK(); } - -#if defined(MLAS_JBLAS) - - auto compt_type = static_cast(accuracy_level_); MLAS_THREADPOOL* pool = NULL; + if (nbits_ != 4) { + return Status::OK(); + } + auto comp_type = static_cast(accuracy_level_); + auto nbits = static_cast(nbits_); if (input_idx == 1) { - packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast(nbits_), is_asym_, compt_type); + packed_b_size_ = NSNBitsGemmPackBSize(N_, K_, block_size_, nbits, is_asym_, comp_type); if (packed_b_size_ == 0) return Status::OK(); auto qptr = tensor.Data(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); std::memset(packed_b_.get(), 0, packed_b_size_); - MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, false, compt_type, pool); + NSNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, false, + comp_type, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -85,8 +94,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } if (input_idx == 2 && packed_b_ != nullptr) { auto sptr = tensor.Data(); - MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, !is_asym_, compt_type, pool); + NSNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, !is_asym_, + comp_type, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -95,8 +104,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } if (input_idx == 3 && packed_b_ != nullptr) { auto zptr = tensor.Data(); - MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, is_asym_, compt_type, pool); + NSNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, nbits, is_asym_, is_asym_, + comp_type, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -104,7 +113,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } -#else // defined(MLAS_JBLAS) +#else // defined(ORT_NEURAL_SPEED) if (input_idx == 1) { packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_); @@ -119,7 +128,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } -#endif // defined(MLAS_JBLAS) +#endif // defined(ORT_NEURAL_SPEED) return Status::OK(); } @@ -127,9 +136,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; - -#if defined(MLAS_JBLAS) - +#ifdef ORT_NEURAL_SPEED // Pack three tensors into one buffer if (input_idx == 1) { used_shared_buffers = true; @@ -144,14 +151,14 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep packed_b_ = std::move(prepacked_buffers[0]); } -#else // defined(MLAS_JBLAS) +#else // defined(ORT_NEURAL_SPEED) if (input_idx == 1) { used_shared_buffers = true; packed_b_ = std::move(prepacked_buffers[0]); } -#endif // defined(MLAS_JBLAS) +#endif // defined(ORT_NEURAL_SPEED) return Status::OK(); } @@ -160,9 +167,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); - -#if defined(MLAS_JBLAS) - +#ifdef ORT_NEURAL_SPEED if (packed_b_.get()) { TensorShape b_shape({static_cast(N_), static_cast(K_)}); @@ -181,7 +186,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t N = static_cast(helper.N()); const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - std::vector gemm_params(max_len); + std::vector gemm_params(max_len); AllocatorPtr allocator; auto status = ctx->GetTempSpaceAllocator(&allocator); ORT_RETURN_IF_ERROR(status); @@ -192,15 +197,14 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { gemm_params[i].C = y_data + helper.OutputOffsets()[i]; gemm_params[i].ldc = N; } - auto ws_size = MlasSQNBitsGemmBatchPackedBWorkspaceSize(M, N, K, max_len, gemm_params.data()); + auto ws_size = NSSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); // workspace for activation process(dynamic quantization and others) auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size); - MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), - thread_pool); + NSSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), thread_pool); return Status::OK(); } -#endif // defined(MLAS_JBLAS) +#endif // defined(ORT_NEURAL_SPEED) const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h new file mode 100644 index 0000000000000..864abffd131fe --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h @@ -0,0 +1,45 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +--*/ + +#pragma once + +#include "contrib_ops/cpu/quantization/neural_speed_wrapper.h" + +namespace bestla { + +using tAVX512F = gemm::SCoreRowNAvx512f<48, 8>; +using tAMX_BF16 = gemm::HCoreRowNAmxbf16<64, 16>; +using tAVX512_FP16 = gemm::HCoreRowNAvx512fp16<96, 8>; +using tAVX_VNNI = gemm::ICoreRowNAvxvnni<24, 4>; +using tAVX512_VNNI = gemm::ICoreRowNAvx512vnni<48, 8>; +using tAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>; +using tAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>; +using tAVX2 = gemm::SCoreRowNAvx2<24, 4>; +using tAVX_VNNI_KBlock = gemm::ICoreRowNAvxvnniKBlock<24, 2>; +using tAVX512_VNNI_KBlock = gemm::ICoreRowNAvx512vnniKBlock<48, 4>; +using tAMX_INT8_US_KBlock = gemm::ICoreRowNAmxint8KBlock<48, 16>; +using tAMX_INT8_SS_KBlock = gemm::ICoreRowNAmxint8SSKBlock<48, 16>; + +template +using tWeiNInt = prologue_b::gemm::WeightKBlockNInteger; +template +using tWeiNFloat = prologue_b::gemm::WeightKBlockNFloat; + +class ORTThreading : public parallel::IThreading { + public: + explicit ORTThreading(void* tp); + void parallel_for(const parallel::thread_func& func) const override; + void set_threads(int nthreads) override { + (void)(nthreads); + assert(0); + } + void sync() const override { assert(0); } + void* mTp; +}; + +} // namespace bestla diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc new file mode 100644 index 0000000000000..73aaa4ae61a6e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc @@ -0,0 +1,438 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + neural_speed_gemm.cpp + +Abstract: + + GEMM template combinations of neural_speed. +--*/ + +#include "contrib_ops/cpu/quantization/neural_speed_defs.h" +#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" +#include "core/platform/threadpool.h" + +using ThreadPool = onnxruntime::concurrency::ThreadPool; + +namespace bestla { + +ORTThreading::ORTThreading(void* tp) + : IThreading(ThreadPool::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) {} + +void ORTThreading::parallel_for(const parallel::thread_func& func) const { + ThreadPool::TrySimpleParallelFor(reinterpret_cast(mTp), mThreadNum, + [&](ptrdiff_t tid) { func(static_cast(tid)); }); +} + +template +static void NSSQ4GemmCompF32(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, + parallel::IThreading* th) { + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); + if (M <= 16) { + using Parallel = parallel::gemm::SchedulerKBlock; + using Launcher = + wrapper::gemm::LauncherKBlock; + static Launcher kernel; + auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); + if (B->IsAsym()) { + reduceA.assign(WorkSpace); + ORTThreading single(nullptr); + kernel.mProA.reduce({A, lda_, &reduceA}, M_, K_, B->mBlockSize, &single); + } + typename Launcher::Param args{gp, + {A, lda_, &reduceA}, + {B}, + {B->template SPtr(), B->SDtype(), B->CStep(), B->template ZPtr(), + reduceA.template RPtr(), reduceA.lda}, + {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); + } else { + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + static Launcher kernel; + typename Launcher::Param args{gp, {A, lda_}, {B}, {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); + } +} + +template +static void NSSQ4GemmCompInt8(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, + parallel::IThreading* th) { + using Parallel = parallel::gemm::SchedulerKBlockS; + using Launcher = + wrapper::gemm::LauncherIntKBlock; + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + static Launcher kernel; + auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->IsAsym()); + quanA.assign(WorkSpace); + if (M <= 16) { + ORTThreading single(nullptr); + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); + } else { + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); + } + utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); + typename Launcher::Param args{gp, {A, lda_, &quanA}, {B}, {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); +} + +template +static size_t NSSQ4GemmCompF32WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { + auto M_ = static_cast(M); + auto K_ = static_cast(K); + (void)(A); + (void)(N); + (void)(C); + (void)(lda); + (void)(ldc); + if (M <= 16) { + using ProA = prologue_a::gemm::ActivationKBlockBaseF32; + static ProA proA; + if (B->IsAsym()) { + auto reduceA = proA.createStorage(M_, K_, B->mBlockSize); + return reduceA.mSize; + } + return 0; + } else { + // using ProA = prologue_a::gemm::ActivationBase; + return 0; + } +} + +template +static size_t NSSQ4GemmCompInt8WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { + (void)(N); + (void)(lda); + (void)(ldc); + (void)(A); + (void)(C); + using ProA = prologue_a::gemm::ActivationF32KBlockQuantize; + static ProA proA; + auto quanA = + proA.createStorage(static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->IsAsym()); + return quanA.mSize; +} + +} // namespace bestla + +using namespace bestla; + +static bool NSSQ4GemmBatchDriver(size_t M, size_t N, size_t K, size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace, + void* ThreadPool) { + GetCPUDevice(); + bestla::ORTThreading orth(ThreadPool); + bool processed = true; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = bestla::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto kptr = reinterpret_cast(ptr); + auto BlkSize = kptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == bestla::tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); + } else if (NTile == bestla::tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, + DataParams[i].ldc, WorkSpace, &orth); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == bestla::tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && + BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, + &orth); + } else if (NTile == bestla::tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, + &orth); + } else if (NTile == bestla::tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && + BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); + } + } + } + } else { + processed = false; + break; + } + } + return processed; +} + +static size_t NSSQ4GemmBatchWorkspaceSize(size_t M, size_t N, size_t K, size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { + GetCPUDevice(); + size_t size = 0; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto kptr = reinterpret_cast(ptr); + auto NTile = + gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + auto BlkSize = kptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc), + size); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } + } + } + } + } + return size; +} + +template +static size_t NSQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) { + static T proB; + auto stor = proB.createStorage(static_cast(N), static_cast(K), static_cast(block_size), + BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, isAsym); + // TODO(Yu) support more scale dtype + return stor.mSize; +} + +static bool NSQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); + auto uptr = std::unique_ptr(ptr); + ORTThreading orth(ThreadPool); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto ldb_ = static_cast(ldb); + GetCPUDevice(); + if (ptr) { + auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto wptr = reinterpret_cast(ptr); + auto BlkSize = wptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } + } + } + return true; + } + return false; +} + +template +static void NSQ4GemmPackBImpl(void* PackedBuf, size_t BlkSize, const uint8_t* QData, const float* Scale, + const uint8_t* Zp, size_t N, size_t K, bool IsAsym, bool lastCall, size_t ldb, + void* ThreadPool) { + static T proB; + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto stor = proB.createStorage(N_, K_, static_cast(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, + BTLA_DTYPE::BF16, IsAsym); + stor.assign(reinterpret_cast(PackedBuf)); + ORTThreading orth(ThreadPool); + proB.packNbitsWeightQ4(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); + if (lastCall) { + proB.reduceWeight(&stor, &orth); + } +} + +static size_t NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, NS_SQNBIT_COMPUTE_TYPE CompType) { + GetCPUDevice(); + if (K % BlkSize != 0) { + return 0; + } + // from low precision to high precision + switch (CompType) { + case NSCompInt8: + if (!isAsym) { // asym int8 is not optimized, so fall through to others. + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + } + [[fallthrough]]; + case NSCompBf16: + case NSCompFp16: + case NSCompFp32: + case NSCompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + [[fallthrough]]; + default: + return 0; + } +} + +static bool NSQ4GemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, + size_t K, size_t ldb, size_t BlkSize, bool isAsym, bool lastCall, + NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { + GetCPUDevice(); + // explicit statement fall through. + switch (CompType) { + case NSCompInt8: + if (!isAsym) { // asym int8 is not optimized, so fall through to others. + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, + K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + } + [[fallthrough]]; + case NSCompBf16: + case NSCompFp16: + case NSCompFp32: + case NSCompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, + lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, + ldb, ThreadPool); + return true; + } + [[fallthrough]]; + default: + return false; + } +} + +size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, + NS_SQNBIT_COMPUTE_TYPE CompType) { + if (nbits == 4) { + auto jsize = NSQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); + if (jsize) { + return jsize; + } + } + return 0; +} + +void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, + size_t ldb, size_t BlkSize, int nbits, bool isAsym, bool lastCall, + NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { + if (nbits == 4) { + if (NSQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { + return; + } + } +} + +void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + if (NSQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { + return; + } +} + +size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + return NSSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); +} + +void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, + void* ThreadPool) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + if (NSSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { + // PackedWeight is created by bestla + return; + } +} diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h new file mode 100644 index 0000000000000..ebcb3027a209f --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h @@ -0,0 +1,129 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + neural_speed_gemm.h + +Abstract: + + Prepack-weight GEMM APIs of neural_speed. +--*/ + +#pragma once + +#include +#include + +/** + * @brief Define compute types of block quantization + */ +enum NS_SQNBIT_COMPUTE_TYPE { + NSCompUndef = 0, /*!< undef */ + NSCompFp32 = 1, /*!< input fp32, accumulator fp32 */ + NSCompFp16 = 2, /*!< input fp16, accumulator fp16 */ + NSCompBf16 = 3, /*!< input bf16, accumulator fp32 */ + NSCompInt8 = 4 /*!< input int8, accumulator int32 */ +}; + +/** + * @brief Data parameters for NBits GEMM routine + * C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * All except C are [in] parameters + */ +struct NS_SQNBITS_GEMM_DATA_PACKED_PARAMS { + const float* A = nullptr; /**< address of A (float32 matrix)*/ + const void* B = nullptr; /**< address of B (packed nbits blob)*/ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldc = 0; /**< leading dimension of C*/ +}; + +/** + * @brief Compute the byte size of the parameter combination + * + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @return size of the packing buffer, 0 if the operation is not yet supported. + */ +size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t block_size, int nbits, bool is_asym, + NS_SQNBIT_COMPUTE_TYPE comp_type); + +/** + * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. + * + * @param PackedBuf packed data buffer + * @param QData quantized data buffer + * @param Scale scale pointer + * @param Zp zero point pointer + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization (default 4) + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor + * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where + * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up + * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale + * (is_asym is false) and Zp(is_asym is true). + * @param thread_pool + */ +void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, + size_t ldb, size_t block_size, int nbits, bool is_asym, bool last_call, + NS_SQNBIT_COMPUTE_TYPE comp_type, void* thread_pool); + +/** + * @brief Unpack and dequantize to fp32 + * + * @param FpData unpacked float32 data + * @param PackedBuf quantized and packed data + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param thread_pool + */ +void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* thread_pool); + +/** + * @brief Get the workspace size required by computation. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @return Workspace size in bytes + */ +size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams); + +/** + * @brief Batched GEMM: C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] WorkSpace temporary buffer + * @param[in] ThreadPool + * @return + */ +void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, + void* ThreadPool = nullptr); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h new file mode 100644 index 0000000000000..d3902f9bd68c7 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h @@ -0,0 +1,39 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +//----------------------------------------------------------------------------- +#pragma once +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wsign-compare" +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-value" +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma GCC diagnostic ignored "-Wuninitialized" +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" +#pragma GCC diagnostic ignored "-Wunused-but-set-parameter" + +#elif defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4457) +#pragma warning(disable : 4189) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4267) +#pragma warning(disable : 4702) +#endif + +#include "bestla/bestla_prologue_a.h" +#include "bestla/bestla_wrapper.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index ebd66d8c6528e..f978f50c6851f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -44,6 +44,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead."); disable_fused_self_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); @@ -105,6 +107,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { num_heads_, mask_filter_value_, scale_, + is_unidirectional_, false, // past_present_share_buffer false, // dmmha_packing device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index c162f7133cc1c..86a32c92ce003 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -25,6 +25,7 @@ class MultiHeadAttention final : public CudaKernel { int num_heads_; // number of attention heads float mask_filter_value_; float scale_; + bool is_unidirectional_; bool disable_fused_self_attention_; bool enable_trt_flash_attention_; bool disable_fused_cross_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index 2d12e975d88d7..9de7ba3885c3c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -29,10 +29,13 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { scale = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); } @@ -48,6 +51,8 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { position_ids, cos_cache, sin_cache, + num_heads, + rotary_embedding_dim, ¶meters)); Tensor* output = context->Output(0, input->Shape()); @@ -71,6 +76,7 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length, parameters.num_heads, parameters.head_size, + parameters.rotary_embedding_dim, parameters.max_sequence_length, parameters.position_ids_format, interleaved, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h index 6dab2ad56749e..d52f61d670444 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h @@ -19,6 +19,8 @@ class RotaryEmbedding final : public CudaKernel { protected: float scale; + int num_heads; + int rotary_embedding_dim; bool interleaved; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index e1b83bd8caf54..c6637041f05bd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -26,6 +26,7 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int position_ids_format, const bool interleaved, const int batch_stride, @@ -33,24 +34,33 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int head_stride) { // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length // Use .x in innermost loop to access global memory efficiently - + const int b = blockIdx.z; const int s = blockIdx.y; const int n = blockIdx.x; const int i = threadIdx.x; + if (i >= head_size) { + return; + } + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; const T* input_data = input + block_offset; T* output_data = output + block_offset; + if (i >= rotary_embedding_dim) { + output_data[i] = input_data[i]; + return; + } + // Cache is (M, H/2) - const int half_head_size = head_size / 2; + const int half_rotary_embedding_dim = rotary_embedding_dim / 2; const int position_id = (position_ids_format == 0) ? \ static_cast(position_ids[0]) + s \ : static_cast(position_ids[b * sequence_length + s]); - const int cache_offset = position_id * half_head_size; + const int cache_offset = position_id * half_rotary_embedding_dim; const T* cos_data = cos_cache + cache_offset; const T* sin_data = sin_cache + cache_offset; @@ -58,13 +68,13 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH T sign = 0; int j = 0; if (interleaved) { - cache_idx = (i / 2) % half_head_size; + cache_idx = (i / 2) % half_rotary_embedding_dim; sign = (i % 2 == 0) ? -1 : 1; j = (i % 2 == 0) ? i+1 : i-1; // i - sign } else { - cache_idx = i % half_head_size; - sign = (i < half_head_size) ? -1 : 1; - j = (i + half_head_size) % head_size; + cache_idx = i % half_rotary_embedding_dim; + sign = (i < half_rotary_embedding_dim) ? -1 : 1; + j = (i + half_rotary_embedding_dim) % rotary_embedding_dim; } output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; } @@ -82,20 +92,23 @@ Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, const int max_threads_per_block, const bool transposed) { - - constexpr int smem_size = 0; - const dim3 grid(num_heads, sequence_length, batch_size); - const dim3 block(head_size, 1, 1); - // Note: Current implementation assumes head_size <= max_threads_per_block // because head_size is currently large for LLaMA-2. For smaller head_size // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. + ORT_ENFORCE(head_size <= max_threads_per_block, + "Rotary embedding dim must be <= max_threads_per_block"); + + int tpb = (head_size + 31)/32*32; + + const dim3 block(tpb); + const dim3 grid(num_heads, sequence_length, batch_size); // Default input tensor shape is [batch, seq, hidden_size] int head_stride = head_size; @@ -109,10 +122,9 @@ Status LaunchRotaryEmbeddingKernel( } assert(head_size <= max_threads_per_block); - RotaryEmbeddingBSNH<<>>( - output, input, cos_cache, sin_cache, position_ids, - sequence_length, num_heads, head_size, position_ids_format, interleaved, - batch_stride, seq_stride, head_stride + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, sequence_length, num_heads, head_size, + rotary_embedding_dim, position_ids_format, interleaved, batch_stride, seq_stride, head_stride ); return CUDA_CALL(cudaGetLastError()); @@ -129,6 +141,7 @@ template Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, @@ -146,6 +159,25 @@ template Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block, + const bool transposed); + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + BFloat16* output, + const BFloat16* input, + const int64_t* position_ids, + const BFloat16* cos_cache, + const BFloat16* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index ee1ccc43dcbff..36300fe7a660f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -21,6 +21,7 @@ Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 9b989dac9a94b..40a667ffd5d83 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" @@ -204,5 +202,3 @@ Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h index cbd483fddab78..5ea4ae59c4020 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" @@ -36,5 +34,3 @@ class ShardedMoE final : public NcclKernel, public MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 34b44694a5fcc..8f368251f12c7 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -70,10 +70,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); -#ifdef USE_CUTLASS class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE); -#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); @@ -98,6 +96,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); @@ -168,10 +167,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllR class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll); -#ifdef USE_CUTLASS class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE); -#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul); @@ -271,10 +268,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, -#ifdef USE_CUTLASS BuildKernelCreateInfo, BuildKernelCreateInfo, -#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -299,6 +294,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -375,10 +371,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, -#ifdef USE_CUTLASS BuildKernelCreateInfo, BuildKernelCreateInfo, -#endif BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h index 9b97690fe70fd..86136ea244e23 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h @@ -13,9 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifdef USE_CUTLASS - #pragma once #include @@ -52,5 +49,3 @@ inline int compute_occupancy_for_kernel() { } } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc index f0abd46572a90..adc043e5689e2 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifdef USE_CUTLASS #include "cutlass_heuristic.h" @@ -66,9 +65,9 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, } // Check that the workspace has sufficient space for this split-k factor - const int ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); - const int ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); - const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + const size_t ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); + const size_t ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); + const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; if (required_ws_bytes > workspace_bytes) { return false; @@ -128,7 +127,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector= multi_processor_count * 256 ? 1 : split_k_limit; - for (int ii = 0; ii < candidate_configs.size(); ++ii) { + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { CutlassGemmConfig candidate_config = candidate_configs[ii]; TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); int occupancy = occupancies[ii]; @@ -186,5 +185,3 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector @@ -64,5 +62,3 @@ class MoeGemmRunner { }; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu index 1d0dfe7c5a647..1d9a249db4237 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu @@ -14,12 +14,8 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - #include "moe_gemm_kernels_template.h" namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu index 7a5d97902ee8f..7b250e6ca9060 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu @@ -14,12 +14,8 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - #include "moe_gemm_kernels_template.h" namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index 3fd0fc47055a5..66950c9b65970 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -14,8 +14,6 @@ * limitations under the License. */ -#ifdef USE_CUTLASS - // Ignore CUTLASS warnings about type punning #ifdef __GNUC__ #pragma GCC diagnostic push @@ -428,5 +426,3 @@ void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, con } } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 9232e8d012933..f4f2b49032d23 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -16,8 +16,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include #include #include @@ -900,5 +898,3 @@ template void finalize_moe_routing_kernelLauncher(const half*, half*, const half cudaStream_t); } // namespace ort_fastertransformer - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index f09471de1cc2e..5cc2a3f79f003 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -16,8 +16,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "moe_gemm_kernels.h" @@ -174,6 +172,4 @@ class CutlassMoeFCRunner> { } // namespace layout } // namespace cutlass - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index 0da06192e266b..3f26a274109ad 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "moe.h" @@ -119,5 +117,3 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.h b/onnxruntime/contrib_ops/cuda/moe/moe.h index 710b914f0633d..c4d8c4dc64c57 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" @@ -26,5 +24,3 @@ class MoE final : public CudaKernel, public MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index dc8b9d57f79f6..f55a7cde2e208 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUTLASS - #pragma once #include "core/common/common.h" @@ -172,5 +170,3 @@ class MoEBase { } // namespace cuda } // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index 67384957d8dd2..d4d583906b7f4 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -89,7 +89,7 @@ __device__ __forceinline__ void Convert8xInt4To8xHalfs(uint32_t value, half2* ha asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(kOneSixteenth), "r"(kNeg64)); } -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -120,7 +120,7 @@ __device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, sums_half2[3] = sums_half2[3] + v3 * (*(reinterpret_cast(&(vec_permuted.w)))); } #else -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, half scale, uint8_t zp, const half* a, half* sums) { half2 scale_half2 = {scale, scale}; half zp_adjust = -scale * __short2half_rn(zp); half2 zp_adjust2 = {zp_adjust, zp_adjust}; @@ -144,7 +144,7 @@ __device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, } #endif -__device__ __forceinline__ float AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { +__device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, float scale, uint8_t zp, const float* a, float* sums) { float4 a_vec_0 = *(reinterpret_cast(a)); float4 a_vec_1 = *(reinterpret_cast(a + 4)); diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index 6f98312e4067d..09e7d61b71db9 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -68,6 +68,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL) != 0LL; + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; using HipT = typename ToHipType::MappedType; using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; @@ -121,8 +122,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, - &attn, - num_heads_, mask_filter_value_, scale_, + &attn, num_heads_, + mask_filter_value_, scale_, false, /*is_unidirectional_*/ past_present_share_buffer_, false, device_prop.maxThreadsPerBlock)); if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h index 84d8b76bbfebe..1d676d7a7bcac 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h @@ -25,6 +25,7 @@ class MultiHeadAttention final : public RocmKernel { float mask_filter_value_; float scale_; bool past_present_share_buffer_{false}; + bool is_unidirectional_{false}; // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index fcf9c2b03dea5..711fd595e90fd 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -30,6 +30,10 @@ #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #endif // ARM #endif // Linux @@ -148,6 +152,7 @@ void CPUIDInfo::ArmLinuxInit() { has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); @@ -177,6 +182,7 @@ void CPUIDInfo::ArmLinuxInit() { has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); #endif } @@ -278,6 +284,7 @@ void CPUIDInfo::ArmWindowsInit() { /* TODO: implement them when hw+sw is available for testing these features */ has_arm_neon_i8mm_ = false; has_arm_sve_i8mm_ = false; + has_arm_neon_bf16_ = false; } #endif /* (arm or arm64) and windows */ diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index a15c75104b83a..2f8041e39f680 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -30,6 +30,7 @@ class CPUIDInfo { bool HasArmNeonDot() const { return has_arm_neon_dot_; } bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } uint32_t GetCurrentCoreIdx() const; @@ -125,6 +126,7 @@ class CPUIDInfo { bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; #ifdef CPUIDINFO_ARCH_X86 diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index e4fe0c7564548..07b465c80745a 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -16,6 +16,7 @@ #include "core/graph/function_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" +#include "core/session/onnxruntime_session_options_config_keys.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -634,6 +635,100 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide return Status::OK(); } +static Status CreateEpContextModel(const ExecutionProviders& execution_providers, + const Graph& graph, + const std::string& ep_context_path, + const logging::Logger& logger) { + InlinedVector all_ep_context_nodes; + for (const auto& ep : execution_providers) { + const InlinedVector ep_context_nodes = ep->GetEpContextNodes(); + all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end()); + } + + auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { + for (auto& node : all_ep_context_nodes) { + if (node_name == node->Name()) { + return std::make_pair(true, node); + } + } + return std::make_pair(false, static_cast(nullptr)); + }; + + onnxruntime::PathString context_cache_path; + PathString model_pathstring = graph.ModelPath().ToPathString(); + if (all_ep_context_nodes.size() > 0) { + if (!ep_context_path.empty()) { + context_cache_path = ToPathString(ep_context_path); + } else if (!model_pathstring.empty()) { + context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + } + + { +#ifdef _WIN32 + std::wifstream fs(context_cache_path); +#else + std::ifstream fs(context_cache_path); +#endif + ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); + } + + Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + graph.DomainToVersionMap(), {}, logger); + auto& ep_graph = ep_context_model.MainGraph(); + ep_graph.SetDescription(graph.Description()); + + // Set inputs outputs explicitly to make sure the order is same as the user model. + auto inputs = graph.GetInputs(); + auto outputs = graph.GetOutputs(); + + InlinedVector ep_graph_inputs; + ep_graph_inputs.reserve(inputs.size()); + for (auto& input : inputs) { + auto input_arg = graph.GetNodeArg(input->Name()); + auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); + ep_graph_inputs.push_back(&ep_graph_input_arg); + } + + InlinedVector ep_graph_outputs; + ep_graph_outputs.reserve(outputs.size()); + for (auto& output : outputs) { + auto output_arg = graph.GetNodeArg(output->Name()); + auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + ep_graph_outputs.push_back(&ep_graph_output_arg); + } + + ep_graph.SetInputs(ep_graph_inputs); + ep_graph.SetOutputs(ep_graph_outputs); + + for (const auto& node : graph.Nodes()) { + // the fused node and EPContext node has same node name + auto ep_context_node = get_ep_context_node(node.Name()); + // Use EpContext node created by the EPs if name matched, otherwise use node from original model + if (ep_context_node.first) { + ep_graph.AddNode(*ep_context_node.second); + } else { + ep_graph.AddNode(node); + } + } + + // handle initializers + for (const auto& input : graph.GetInputsIncludingInitializers()) { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (graph.GetInitializedTensor(input->Name(), initializer)) { + // There initializer could have duplicates so make sure we only add once + const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; + if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) { + ep_graph.AddInitializedTensor(*initializer); + } + } + } + + ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); + } + + return Status::OK(); +} + static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager) { @@ -840,6 +935,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, + const ConfigOptions& config_options, + const logging::Logger& logger, Mode mode, const layout_transformation::DebugGraphFn& debug_graph_fn) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. @@ -886,7 +983,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, #if !defined(ORT_MINIMAL_BUILD) ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_)); + + bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + if (ep_context_enabled) { + ORT_RETURN_IF_ERROR(CreateEpContextModel(providers_, graph, ep_context_path, logger)); + } #else + ORT_UNUSED_PARAMETER(config_options); + ORT_UNUSED_PARAMETER(logger); return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! defined(ORT_MINIMAL_BUILD) } else { diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 4fc85c2588260..d1ef193cf1520 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -13,6 +13,7 @@ namespace onnxruntime { class ExecutionProviders; class KernelRegistryManager; class Model; +struct ConfigOptions; class GraphPartitioner { public: @@ -31,6 +32,8 @@ class GraphPartitioner { // Run partitioning. Status Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, + const ConfigOptions& config_options, + const logging::Logger& logger, Mode mode = Mode::kNormal, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 0317ffcfb0e31..7f34647f1faef 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -927,6 +927,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("unidirectional", + "Whether every token can only attend to previous tokens. Default value is 0.", + AttributeProto::INT, + static_cast(0)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)", @@ -1145,6 +1149,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Rotate using interleaved pattern. Default value is 0 (False).", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("rotary_embedding_dim", + "Rotary embedding dimension. Default value is 0.", + AttributeProto::INT, + OPTIONAL_VALUE) + .Attr("num_heads", + "Number of attention heads. Default value is 0. Must use with rotary_embedding_dim", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "input", "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)", @@ -1155,17 +1167,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "M") .Input(2, "cos_cache", - "2D tensor with shape (max_sequence_length, head_size / 2).", + "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)", "T") .Input(3, "sin_cache", - "2D tensor with shape (max_sequence_length, head_size / 2).", + "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)", "T") .Output(0, "output", "tensor with same shape as input.", "T") - .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index bdd4dba521eba..ce7838556fbf0 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1614,6 +1614,119 @@ MlasHalfGemmConvertPackB( void* PackedB ); +#if defined(__aarch64__) && defined(__linux__) +/** + * @brief Whether current CPU supports Bfloat16(bf16) acceleration. + */ +bool MLASCALL +MlasBf16AccelerationSupported(); + +/** + * @brief Interface for bf16 gemm post processors. + * + * Example implementation of this interface includes activations, + * conversion from single precision to precision, etc. + * + * SBGEMM is computed tile by tile. When a tile of result matrix + * is produced, the method Process() is called to process this tile. + * Parameters of this method describe the location and shape of the + * tile. + */ +class MLAS_SBGEMM_POSTPROCESSOR +{ + public: + virtual void Process(float*, /**< the address of matrix to process */ + size_t, /**< the start row index of matrix */ + size_t, /**< the start col index of matrix */ + size_t, /**< the element count per row to process */ + size_t, /**< the element count per col to process */ + size_t /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_SBGEMM_POSTPROCESSOR() {} +}; + +/** + * @brief bfloat16 precision activation functions, with optional sum tensor. + * Supplied sum tensor must be the same layout as the GEMM output tensor. + * And the supplied sum tensor will be added to the tensor before activation. + */ +class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR +{ + public: + MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr) + : Activation_(Activation), SumBuf_(SumBuf) + { + } + + void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc) + const override; + + private: + const MLAS_ACTIVATION& Activation_; + const float* SumBuf_; +}; + +/** + * @brief Data parameters for bfloat16 precision GEMM routine + * All except C are [in] parameters + */ +struct MLAS_SBGEMM_DATA_PARAMS { + const void* A = nullptr; /**< address of A */ + const void* B = nullptr; /**< address of B */ + const float* Bias = nullptr; /**< address of Bias, vector size N */ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ + size_t ldc = 0; /**< leading dimension of C*/ + const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; + bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ + bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ +}; + +/** + * @brief Bfloat16 precision Batched GEMM: C = A * B + Bias + * Either B can be either fp32 or bf16 + * + * Note: We only support uniform batching, so shapes and types of the + * input must be same across all parameter blocks. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool + * @return + */ +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr); + +/** + * @brief For bfloat16 precision GEMM, returns size of the + * packing buffer needed for right hand side + * @param[in] N Number of columns + * @param[in] K Number of rows + * @return size of the packing buffer, + * 0 if operation not supported + */ +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K); + +/** + * @brief For bfloat16 precision GEMM, convert the float matrix B + * to blfoat16 precision and pack it into a packing buffer + * + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix + */ +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); +#endif + /** * @brief Indirect Depthwise convolution for fp16 * @param Input Supplies the indirect buffer for NHWC input diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index bc0bfc92c85a0..047011e70bd4d 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -183,133 +183,3 @@ MlasSQNBitGemmPackQuantBData( void* PackedQuantBData, MLAS_THREADPOOL* ThreadPool = nullptr ); - -/** - * @brief Data parameters for NBits GEMM routine - * C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob - * All except C are [in] parameters - */ -struct MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS { - const float* A = nullptr; /**< address of A (float32 matrix)*/ - const void* B = nullptr; /**< address of B (packed nbits blob)*/ - float* C = nullptr; /**< address of result matrix */ - size_t lda = 0; /**< leading dimension of A */ - size_t ldc = 0; /**< leading dimension of C*/ -}; - -/** - * @brief Compute the byte size of the parameter combination - * - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @return size of the packing buffer, 0 if the operation is not yet supported. - */ -size_t MLASCALL -MlasNBitsGemmPackBSize( - size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE comp_type -); - -/** - * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. - * - * @param PackedBuf packed data buffer - * @param QData quantized data buffer - * @param Scale scale pointer - * @param Zp zero point pointer - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization (default 4) - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor - * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where - * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up - * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale - * (is_asym is false) and Zp(is_asym is true). - * @param thread_pool - */ -void MLASCALL -MlasNBitsGemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t block_size, - int nbits, - bool is_asym, - bool last_call, - MLAS_SQNBIT_COMPUTE_TYPE comp_type, - MLAS_THREADPOOL* thread_pool -); - -/** - * @brief Unpack and dequantize to fp32 - * - * @param FpData unpacked float32 data - * @param PackedBuf quantized and packed data - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - * @param thread_pool - */ -void MLASCALL -MlasNBitsGemmUnPackB( - float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool -); - -/** - * @brief Get the workspace size required by computation. - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @return Workspace size in bytes - */ -size_t MLASCALL -MlasSQNBitsGemmBatchPackedBWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -); - -/** - * @brief Batched GEMM: C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] WorkSpace temporary buffer - * @param[in] ThreadPool - * @return - */ -void MLASCALL -MlasSQNBitsGemmBatchPackedB( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - void* WorkSpace, - MLAS_THREADPOOL* ThreadPool = nullptr -); diff --git a/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S new file mode 100644 index 0000000000000..e424c30515e9f --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S @@ -0,0 +1,907 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + SbgemmKernelNeon.s + +Abstract: + + This module implements the kernels for the bfloat16 half precision matrix/matrix + multiply operation (SBGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the sbgemm kernel. d8-d15, x19-x30 need save +// + .equ .LMlasSbgemmKernel_backup_x19_x20, 0 + .equ .LMlasSbgemmKernel_backup_x21_x22, 16 + .equ .LMlasSbgemmKernel_backup_x23_x24, 32 + .equ .LMlasSbgemmKernel_backup_x25_x26, 48 + .equ .LMlasSbgemmKernel_backup_x27_x28, 64 + .equ .LMlasSbgemmKernel_backup_d8_d9, 80 + .equ .LMlasSbgemmKernel_backup_d10_d11, 96 + .equ .LMlasSbgemmKernel_backup_d12_d13, 112 + .equ .LMlasSbgemmKernel_backup_d14_d15, 128 + .equ .LMlasSbgemmKernel_SavedRegisters, 144 + .equ .LMlasSbgemmKernel_SavedRegisters_Neg, -144 + + +// +// ClearRowAccumulators +// +// Generates the code to clear the accumulators for a single row of the output +// block. +// + + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + mov v\Vec1Reg\().16b,v0.16b +.if \Columns\() > 2 + mov v\Vec2Reg\().16b,v1.16b +.endif +.if \Columns\() > 4 + mov v\Vec3Reg\().16b,v2.16b +.endif +.if \Columns\() > 6 + mov v\Vec4Reg\().16b,v3.16b +.endif + + .endm + +// +// InitBlockAccumulators +// +// Generates the code to init the accumulators for a single row of the output +// block. +// + + .macro InitBlockAccumulators Mode, Columns, Rows + + //check if the Bias != nullptr + cbz x8,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd + + ld1 {v14.4s},[x8],#16 // load Bias[0] + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast Bias + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + ld1 {v15.4s},[x8],#16 // load Bias[4] + dup v4.4s,v15.s[0] // broadcast Bias + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + b .L\Mode\().PopulateAccumulators\Columns\().x\Rows\() + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd: + eor v0.16b,v0.16b,v0.16b // No bias, reset regs + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b + +.L\Mode\().PopulateAccumulators\Columns\().x\Rows\(): + InitRowAccumulators \Columns\(),16,17,18,19 +.if \Rows\() > 2 + InitRowAccumulators \Columns\(),20,21,22,23 +.endif +.if \Rows\() > 4 + InitRowAccumulators \Columns\(),24,25,26,27 +.endif +.if \Rows\() > 6 + InitRowAccumulators \Columns\(),28,29,30,31 +.endif + + .endm + +// LoadMatrixAElementsBy8 +// +// Generates the code to load 4 or 8 elements from matrix A. +// + .macro LoadMatrixAElementsBy8 Rows + + ldr q8,[x0],#16 + bfcvtn v8.4h, v8.4s +.if \Rows\() > 1 + ldr q1,[x10],#16 + bfcvtn2 v8.8h, v1.4s +.endif + +.if \Rows\() > 2 + ldr q9,[x11],#16 + bfcvtn v9.4h, v9.4s +.endif +.if \Rows\() > 3 + ldr q1,[x12],#16 + bfcvtn2 v9.8h, v1.4s +.endif + +.if \Rows\() > 4 + ldr q10,[x20],#16 + bfcvtn v10.4h, v10.4s +.endif +.if \Rows\() > 5 + ldr q1,[x21],#16 + bfcvtn2 v10.8h, v1.4s +.endif + +.if \Rows\() > 6 + ldr q11,[x22],#16 + bfcvtn v11.4h, v11.4s +.endif +.if \Rows\() > 7 + ldr q1,[x23],#16 + bfcvtn2 v11.8h, v1.4s +.endif + + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + bfmmla v\Vec1Reg\().4s, \MatrixAReg\().8h, v4.8h +.if \Columns\() > 2 + bfmmla v\Vec2Reg\().4s, \MatrixAReg\().8h, v5.8h +.endif +.if \Columns\() > 4 + bfmmla v\Vec3Reg\().4s, \MatrixAReg\().8h, v6.8h +.endif +.if \Columns\() > 6 + bfmmla v\Vec4Reg\().4s, \MatrixAReg\().8h, v7.8h +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(),\Columns\(),\Rows\() + + add x10,x0,x6,lsl #2 // compute matrix A plus 1 row +.if \Rows\() > 2 + add x11,x10,x6,lsl #2 // compute matrix A plus 2 rows + add x12,x11,x6,lsl #2 // compute matrix A plus 3 rows +.endif +.if \Rows\() > 4 + add x20,x12,x6,lsl #2 // compute matrix A plus 4 rows + add x21,x20,x6,lsl #2 // compute matrix A plus 5 rows +.endif +.if \Rows\() > 6 + add x22,x21,x6,lsl #2 // compute matrix A plus 6 rows + add x23,x22,x6,lsl #2 // compute matrix A plus 7 rows +.endif + sub x9,x3,#4 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#4 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#4 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + fadd v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + fadd v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x26 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),8,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),4,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasSbgemmCopyPackB or MlasSbgemmTransposePackB. + + C (x2) - Supplies the address of matrix C. + + CountK (x3) - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda (x6) - Supplies the first dimension of matrix A. + + ldc (x7) - Supplies the first dimension of matrix C. + + Bias - Supplies the address of Bias Vector [1xn] + + +Return Value: + + Returns the number of rows handled. + +--*/ + .macro SbgemmKernelNeonFunction Mode + + FUNCTION_ENTRY MlasSbgemmKernel\Mode\() + + ldr x8, [sp, #0] //Bias vector + + stp x19, x20, [sp, #.LMlasSbgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + + add x13,x2,x7,lsl #2 // compute matrix C plus 1 row + add x14,x13,x7,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x7,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x7,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x7,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x7,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x7,lsl #2 // compute matrix C plus 7 rows + + mov x26,x0 // save matrix A +// +// Process 8 rows of the matrices. +// + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasSbgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + SbgemmKernelNeonFunction Zero + SbgemmKernelNeonFunction Add diff --git a/onnxruntime/core/mlas/lib/jblas_defs.h b/onnxruntime/core/mlas/lib/jblas_defs.h deleted file mode 100644 index 9cd1711a3ffd2..0000000000000 --- a/onnxruntime/core/mlas/lib/jblas_defs.h +++ /dev/null @@ -1,73 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - ---*/ - -#pragma once - -#include "jblas/jit_blas_prologue_b.h" -#include "jblas/jit_blas_wrapper.h" - -namespace jblas -{ - -/* -Name conversion explaination: -Fp32: comp type, determined by GemmCore, can be any jblas::gemm::SCorexxx(float GemmCore) -S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(also support other integer and float weight -classes) -F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and -jblas::epilogue::gemm::AccumulatorWriteBackFp32. - -Tips: jblas::epilogue::gemm::CompFp32BlockEpilogue is a fixed class for all fp32 accumulator GemmCores. -*/ -template -using tLauncher_Fp32_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock< - GemmCore_T::ISA, - GemmCore_T, - jblas::prologue_a::gemm::ActivationKBlockBaseF32, - jblas::prologue_b::gemm::WeightKBlockS4, - jblas::epilogue::gemm::CompFp32BlockEpilogue, - jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - -/* -Name conversion explaination: -Int8: comp type, determined by GemmCore, can be any jblas::gemm::ICorexxx(integer GemmCore) -S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(support integer weight classes only) -F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and -jblas::epilogue::gemm::AccumulatorWriteBackFp32. - -Tips: jblas::epilogue::gemm::CompInt8BlockEpilogue is a fixed class for all int32 accumulator GemmCores. -*/ -template -using tLauncher_Int8_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock< - GemmCore_T::ISA, - GemmCore_T, - jblas::prologue_a::gemm::ActivationF32KBlockQuantize, - jblas::prologue_b::gemm::WeightKBlockS4, - jblas::epilogue::gemm::CompInt8BlockEpilogue, - jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - -using tAVX512F = jblas::gemm::SCoreRowNAvx512f<48, 8>; -using tAMX_BF16 = jblas::gemm::HCoreRowNAmxbf16<64, 16>; -using tAVX512_FP16 = jblas::gemm::HCoreRowNAvx512fp16<96, 8>; -using tAVX_VNNI = jblas::gemm::ICoreRowNAvxvnni<48, 2>; // TODO(Yu) use 24x4 for higher efficiency -using tAVX512_VNNI = jblas::gemm::ICoreRowNAvx512vnni<48, 8>; -using tAMX_INT8_US = jblas::gemm::ICoreRowNAmxint8<64, 16>; -using tAMX_INT8_SS = jblas::gemm::ICoreRowNAmxint8SS<64, 16>; -using tAVX2 = jblas::gemm::SCoreRowNAvx2<48, 2>; // TODO(Yu) use 24x4 for higher efficiency - -class ORTThreading : public jblas::parallel::IThreading -{ - public: - ORTThreading(void* tp); - void parallel_for(const jblas::parallel::thread_func& func) override; - void set_threads(int nthreads) override { assert(0); } - void sync() override { assert(0); } - void* mTp; -}; - -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.cpp b/onnxruntime/core/mlas/lib/jblas_gemm.cpp deleted file mode 100644 index f3cae3186c28e..0000000000000 --- a/onnxruntime/core/mlas/lib/jblas_gemm.cpp +++ /dev/null @@ -1,534 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - jblas_gemm.cpp - -Abstract: - - Currently only support Q4 gemm. ---*/ - -#include "jblas_gemm.h" - -#include "jblas_defs.h" -#include "mlasi.h" - -using namespace jblas; - -jblas::ORTThreading::ORTThreading(void* tp) - : IThreading(MLAS_THREADPOOL::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) -{ -} - -void -jblas::ORTThreading::parallel_for(const jblas::parallel::thread_func& func) -{ - MlasTrySimpleParallel(reinterpret_cast(mTp), mThreadNum, [&](ptrdiff_t tid) { - func(static_cast(tid)); - }); -} - -template -static void -JblasSQ4GemmCompF32( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc, - int8_t* WorkSpace, - jblas::parallel::IThreading* th -) -{ - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - if (M <= 16) { - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = tLauncher_Fp32_S4_F32F32; - static Launcher kernel; - auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); - if (B->mIsAsym) { - reduceA.assign(WorkSpace); - ORTThreading single(nullptr); - kernel.mProA.reduce({A, lda_}, &reduceA, M_, K_, &single); - } - typename Launcher::BEpiParam blkargs{ - B->template SPtr(), B->mScaT, B->mCStep, B->template ZPtr(), - reduceA.template get(), reduceA.lda}; - - typename Launcher::Param args{M_, N_, K_, B->mBlockSize, {A, lda_}, {B}, blkargs, {C, ldc_}}; - jblas::parallel::GemmKBlockRun(kernel, args, th); - } else { - using Parallel = jblas::parallel::gemm::SchedulerBase; - using Launcher = jblas::wrapper::gemm::LauncherBase< - GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase, - jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - static Launcher kernel; - - typename Launcher::Param args{M_, N_, K_, {A, lda_}, {B}, {C, ldc_}}; - jblas::parallel::GemmBaseRun(kernel, args, th); - } -} - -template -static void -JblasSQ4GemmCompInt8( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc, - int8_t* WorkSpace, - jblas::parallel::IThreading* th -) -{ - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = tLauncher_Int8_S4_F32F32; - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - static Launcher kernel; - auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->mIsAsym); - quanA.assign(WorkSpace); - if (M <= 16) { - ORTThreading single(nullptr); - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); - } else { - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); - } - typename Launcher::Param args{ - M_, - N_, - K_, - B->mBlockSize, - {A, lda_, &quanA}, - {B}, - {B->template SPtr(), B->mScaT, B->mCStep, quanA.template SPtr(), quanA.mCStep, - quanA.template ZPtr(), B->template RPtr(), B->mRedT, B->template ZPtr(), - quanA.template RPtr(), B->mBlockSize}, - {C, ldc_}}; - jblas::parallel::GemmKBlockRun(kernel, args, th); -} - -bool -JblasSQ4GemmBatchDriver( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - int8_t* WorkSpace, - MLAS_THREADPOOL* ThreadPool -) -{ - GetCPUDevice(); - ORTThreading orth(ThreadPool); - bool processed = true; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { - auto kptr = reinterpret_cast(ptr); - auto coretype = ptr->mCoreId; - auto NTile = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT - ); - auto CType = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT - ); - if (CType == uint32_t(gemm::CompType::COMP_FP32)) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - JblasSQ4GemmCompF32( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - JblasSQ4GemmCompF32( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { - if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { - if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - } - } else { - processed = false; - break; - } - } - return processed; -} - -template -static size_t -JblasSQ4GemmCompF32WorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc -) -{ - auto M_ = static_cast(M); - auto K_ = static_cast(K); - (void)(N); - (void)(lda); - (void)(ldc); - if (M <= 16) { - using Launcher = tLauncher_Fp32_S4_F32F32; - static Launcher kernel; - if (B->mIsAsym) { - auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); - return reduceA.mSize; - } - return 0; - } else { - using Launcher = jblas::wrapper::gemm::LauncherBase< - GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase, - jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - static Launcher kernel; - return 0; - } - return 0; -} - -template -static size_t -JblasSQ4GemmCompInt8WorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc -) -{ - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = tLauncher_Int8_S4_F32F32; - static Launcher kernel; - (void)(N); - (void)(lda); - (void)(ldc); - auto quanA = kernel.mProA.createStorage( - static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->mIsAsym - ); - return quanA.mSize; -} - -size_t -JblasSQ4GemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -) -{ - GetCPUDevice(); - size_t size = 0; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { - auto kptr = reinterpret_cast(ptr); - auto coretype = ptr->mCoreId; - auto NTile = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT - ); - auto CType = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT - ); - if (CType == uint32_t(gemm::CompType::COMP_FP32)) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - size = std::max( - JblasSQ4GemmCompF32WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - size = std::max( - JblasSQ4GemmCompF32WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { - if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { - if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - } - } - } - return size; -} - -template -static size_t -JblasQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) -{ - static T launcher; - auto stor = launcher.mProB.createStorage( - static_cast(N), static_cast(K), static_cast(block_size), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, - JBLAS_DTYPE::BF16, isAsym - ); - // TODO(Yu) support more scale dtype - return stor.mSize; -} - -size_t -JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType) -{ - GetCPUDevice(); - if (K % BlkSize != 0) { - return 0; - } - // from low precision to high precision - switch (CompType) { - case CompInt8: - if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - case CompBf16: - case CompFp16: - case CompFp32: - case CompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - break; - default: - return 0; - } - return 0; -} - -template -static void -JblasQ4GemmPackBImpl( - void* PackedBuf, - size_t BlkSize, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - bool IsAsym, - bool lastCall, - size_t ldb, - MLAS_THREADPOOL* ThreadPool -) -{ - static T JblasKernel; - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto stor = JblasKernel.mProB.createStorage( - N_, K_, static_cast(BlkSize), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, IsAsym - ); - stor.assign(reinterpret_cast(PackedBuf)); - ORTThreading orth(ThreadPool); - JblasKernel.mProB.packNbitsWeight(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); - if (lastCall) { - JblasKernel.mProB.reduceWeight(&stor, &orth); - } -} - -bool -JblasQ4GemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -) -{ - GetCPUDevice(); - // explicit statement fall through. - switch (CompType) { - case CompInt8: - if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - case CompBf16: - case CompFp16: - case CompFp32: - case CompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - default: - return false; - } - return false; -} - -bool -JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) -{ - auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); - auto uptr = std::unique_ptr(ptr); - ORTThreading orth(ThreadPool); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto ldb_ = static_cast(ldb); - GetCPUDevice(); - if (ptr) { - if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { - auto NTile = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT - ); - auto CType = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT - ); - if (CType == uint32_t(jblas::gemm::CompType::COMP_FP32)) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } - } - if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_US_INT32)) { - if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } - } - if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_SS_INT32)) { - if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } - } - } - return true; - } - return false; -} diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.h b/onnxruntime/core/mlas/lib/jblas_gemm.h deleted file mode 100644 index 044dc5e849a0a..0000000000000 --- a/onnxruntime/core/mlas/lib/jblas_gemm.h +++ /dev/null @@ -1,61 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - jblas_gemm.h - -Abstract: - - Currently only support Q4 gemm. ---*/ - -#pragma once - -#include "mlas_qnbit.h" - -size_t -JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType); - -bool -JblasQ4GemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -); - -bool -JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb - , MLAS_THREADPOOL* ThreadPool); - -bool -JblasSQ4GemmBatchDriver( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - int8_t* WorkSpace, - MLAS_THREADPOOL* ThreadPool -); - -size_t -JblasSQ4GemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 7bb8b17031a84..624eb913d5c9e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -193,6 +193,8 @@ class MLASCPUIDInfo bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } + private: MLASCPUIDInfo(); @@ -200,6 +202,7 @@ class MLASCPUIDInfo bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -357,6 +360,20 @@ size_t #else +#if defined(__aarch64__) && defined(__linux__) +typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( + const float* A, + const bfloat16_t* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + const float* Bias +); +#endif + typedef size_t (MLASCALL MLAS_GEMM_FLOAT_KERNEL)( @@ -727,6 +744,10 @@ extern "C" { #else MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; +#if defined(__aarch64__) && defined(__linux__) + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; +#endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; #endif @@ -856,6 +877,10 @@ extern "C" { #define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 +#if defined(__aarch64__) && defined(__linux__) +#define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#endif + // // Single-threaded single precision matrix/matrix multiply operation. // diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 1310ed3f384b9..de092f7d1d350 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -60,6 +60,10 @@ MLASCPUIDInfo::MLASCPUIDInfo() #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #if defined(BUILD_MLAS_NO_ONNXRUNTIME) MLASCPUIDInfo::MLASCPUIDInfo() { @@ -70,6 +74,8 @@ MLASCPUIDInfo::MLASCPUIDInfo() has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); } #endif diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h new file mode 100644 index 0000000000000..de7fd72fad45a --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -0,0 +1,399 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm.h + +Abstract: + + This module defines the set of template functions to implement bfloat16 + precision matrix/matrix multiply operation (SBGEMM). + + To implement a new kernel, template functions below need to be specialized: + MlasSBGemmConvertPackB + MlasSBGemmPackedBOffset + MlasSBGemmPackedBLeadingDim + MlasSBGemmKernel + + MlasSBGemmOperation is the shared kernel driver. + + A kernel type should define the following constants: + bool PackNeeded; Whether B needs to be packed + size_t KernelMaxM; Max # rows the vectorized kernel can process + size_t PackedK; Packed alignment on the K dim (power of 2) + size_t PackedN; Packed alignment on the n dim (power of 2) + MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#pragma once + +#include +#include + +#include "mlasi.h" + +/** + * @brief Define the default striding parameters for + * the bfloat16 precision gemm operation + */ +struct MLAS_SBGEMM_STRIDES { + size_t M; + size_t N; + size_t K; +}; + +/** + * @brief Convert fp32 matrix B to bf16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B in fp32 + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack + */ +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Find the location of PackedB[StartK, StartN] + * + * @tparam KernelType + * @param PackedB + * @param DimN Total columns of the packing buffer + * @param DimK Total rows of the packing buffer + * @param StartN + * @param StartK + * @return Address of PackedB[StartK, StartN] + */ +template +MLAS_FORCEINLINE const bfloat16_t* +MlasSBGemmPackedBOffset( + const bfloat16_t* PackedB, size_t DimN, size_t DimK, size_t StartN, size_t StartK +) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return PackedB + StartK * DimN + StartN; +} + +/** + * @brief leading dimension of the packed B buffer + * Related to how B is packed + * @tparam KernelType + * @param DimN + * @param DimK + * @return leading dimension of the packed B buffer + */ +template +MLAS_FORCEINLINE size_t +MlasSBGemmPackedBLeadingDim(size_t DimN, size_t DimK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return DimN; +} + +template +void +MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK, const float* A, const size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode); + +template +MLAS_FORCEINLINE void +MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t PackedStrideN = Strides.N; + size_t PackedStrideK = Strides.K; + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + const size_t SliceStartN = RangeStartN + n; + CountN = std::min(RangeCountN - n, PackedStrideN); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + bool ZeroMode = (k == 0); + CountK = std::min(K - k, PackedStrideK); + + const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN; + float* c = C + n; + const float* pbias = ((nullptr == Bias) ? nullptr : Bias + RangeStartN + n); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, pb, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor) + ->Process(C + n, M, SliceStartN, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + // + // Compute the strides to step through slices of the input matrices. + // + // Expand the N stride if K is small or expand the K stride if N is small + // for better utilization of the B panel. Avoid changing the K stride if + // the A panel needs to be used for transposing. + // + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t StrideN = Strides.N; + size_t StrideK = Strides.K; + + if (N >= K) { + while (StrideK / 2 >= K) { + StrideN *= 2; + StrideK /= 2; + } + } else { + while (StrideN > 16 && StrideN / 2 >= N) { + StrideK *= 2; + StrideN /= 2; + } + } + + constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * sizeof(bfloat16_t)); + MlasThreadedBufAlloc(packBSize); + uint8_t* p = ThreadedBufHolder.get(); + auto* PanelB = reinterpret_cast(p); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < N; n += CountN) { + CountN = std::min(N - n, StrideN); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + CountK = std::min(K - k, StrideK); + + // + // Copy a panel of matrix B to a local packed buffer. + // + MlasSBGemmConvertPackB(PanelB, B + n + k * ldb, ldb, CountN, CountK); + + auto* c = C + n; + const float* pbias = + ((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart + + bool ZeroMode = (k == 0); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor)->Process(C + n, M, N, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId) +{ + const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; + const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; + + // + // Partition the operation along the M dimension. + // + size_t RangeStartM; + size_t RangeCountM; + + MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); + + // + // Partition the operation along the N dimension. + // + size_t RangeStartN; + size_t RangeCountN; + + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, &RangeCountN); + + RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + RangeCountN = std::min(N - RangeStartN, RangeCountN); + + // + // Dispatch the partitioned operation. + // + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + const float* A = (const float*)DataParams->A + RangeStartM * lda; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const float* bias = DataParams->Bias; + + if (!DataParams->BIsfp32) { + MlasSBGemmPackedOperation( + RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A, + lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor + ); + } else { + const size_t ldb = DataParams->ldb; + const float* B = (const float*)DataParams->B + RangeStartN; + MlasSBGemmNonPackedOperation(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor); + } +} + +// +// dispatch structure. +// +typedef void(MLAS_SBGEMM_OPERATION)(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId); + +typedef void(MLAS_SBGEMM_CONVERTPACKB_ROUTINE)( + bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Hardware dependent dispatch for half precision GEMM + */ +struct MLAS_SBGEMM_DISPATCH { + MLAS_SBGEMM_OPERATION* Operation; /**< HalfGemm driver */ + MLAS_SBGEMM_CONVERTPACKB_ROUTINE* ConvertPackBRoutine; /**< Convert and pack function for B */ + size_t PackedK; + size_t PackedN; + size_t StrideM; + size_t BufOverRead; +}; + +extern const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon; + +MLAS_FORCEINLINE +const MLAS_SBGEMM_DISPATCH* +MlasSBGemmGetDispatch() +{ +#if defined(MLAS_TARGET_ARM64) + return &MlasSBGemmDispatchNeon; +#else + std::cerr << "SBGemm Kernel is supported only on ARM64 platform."; + exit(1); +#endif +} + +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K) +{ + // + // Compute the number of bytes required to hold the packed buffer. + // + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return 0; + + const auto padding = dispatch->BufOverRead; + const auto PackedK = dispatch->PackedK; + const auto PackedN = dispatch->PackedN; + + const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); + const size_t AlignedN = (N + PackedN - 1) & ~(PackedN - 1); + const size_t BytesRequired = AlignedN * AlignedK * sizeof(bfloat16_t) + padding; + const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); + const size_t AlignedBytesRequired = + (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + + return AlignedBytesRequired; +} + +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + dispatch->ConvertPackBRoutine((bfloat16_t*)PackedB, B, ldb, N, K); +} + +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* Data, MLAS_THREADPOOL* ThreadPool) +{ + const MLAS_SBGEMM_DISPATCH* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + MLAS_SBGEMM_OPERATION* operation = dispatch->Operation; + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. + // + + const double Complexity = double(M) * double(N) * double(K); + + ptrdiff_t TargetThreadCount; + + if (Complexity < double(MLAS_SBGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { + TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; + } else { + TargetThreadCount = GetMlasPlatform().MaximumThreadCount; + } + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + // + // Segment the operation across multiple threads. + // + // N.B. Currently, the operation is segmented as a 1D partition, which + // works okay for operations involving skinny matrices. + // + ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchN - 1) / BatchN; + ptrdiff_t ThreadCountM; + ptrdiff_t ThreadCountN; + + if (N > M) { + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + if (size_t(ThreadsPerGemm) > BlockedN) { + ThreadsPerGemm = ptrdiff_t(BlockedN); + } + + ThreadCountM = 1; + ThreadCountN = ThreadsPerGemm; + + } else { + if (size_t(ThreadsPerGemm) > M) { + ThreadsPerGemm = ptrdiff_t(M); + } + + ThreadCountM = ThreadsPerGemm; + ThreadCountN = 1; + } + + MlasTrySimpleParallel( + ThreadPool, ThreadsPerGemm * static_cast(BatchN), [=](ptrdiff_t tid) { + ptrdiff_t GemmIdx = tid / ThreadsPerGemm; + ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; + operation(ThreadCountM, ThreadCountN, M, N, K, &(Data[GemmIdx]), ThreadIdx); + } + ); +} +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..a6a73996c548b --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp @@ -0,0 +1,362 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm_kernel_neon.cpp + +Abstract: + + This module implements bfloat16 precision GEMM kernel for neon. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#include "arm_neon.h" +#include "mlasi.h" +#include "sbgemm.h" + +struct MLAS_SBGEMM_KERNEL_NEON { + static constexpr bool PackNeeded = true; + static constexpr size_t KernelMaxM = 8; // max # rows the vectorized kernel can process + static constexpr size_t PackedK = 4; + static constexpr size_t PackedN = MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + static constexpr MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; // M:N:K +}; + +bool MLASCALL +MlasBf16AccelerationSupported() +{ +#if defined(MLAS_TARGET_ARM64) + return MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_BF16(); +#else + return false; +#endif +} + +/* + This routine converts fp32 to bf16 and copies elements from the source + matrix to the destination packed buffer. + + 4x2 elements from the source matrix are unrolled to be physically + contiguous for better locality inside the SBGEMM kernels. The remaining + rows and columns are padded to 4 and 2 alignment. +*/ +MLAS_FORCEINLINE +void +MlasSBGemmConvertCopyPackB(bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK) +{ + // + // Copy data from matrix B into the destination buffer 4x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const float* b = B; + int y = static_cast(CountK); + + while (y > 0) { + MLAS_FLOAT32X4 t0_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t0_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_h = MlasZeroFloat32x4(); + + if (y >= 4) { + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + t3_l = MlasLoadFloat32x4(&b[ldb * 3]); + t3_h = MlasLoadFloat32x4(&b[ldb * 3 + 4]); + } else { + switch (y) { + case 3: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + break; + case 2: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + break; + case 1: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + break; + } + } + + float32x4x2_t z0_l = vzipq_f32(t0_l, t2_l); + float32x4x2_t z1_l = vzipq_f32(t1_l, t3_l); + float32x4x2_t o0_l = vzipq_f32(z0_l.val[0], z1_l.val[0]); + float32x4x2_t o1_l = vzipq_f32(z0_l.val[1], z1_l.val[1]); + t0_l = o0_l.val[0]; + t1_l = o0_l.val[1]; + t2_l = o1_l.val[0]; + t3_l = o1_l.val[1]; + + bfloat16x8_t t0t1_l_4h = vcvtq_low_bf16_f32(t0_l); + bfloat16x8_t t0t1_l_8h = vcvtq_high_bf16_f32(t0t1_l_4h, t1_l); + + bfloat16x8_t t2t3_l_4h = vcvtq_low_bf16_f32(t2_l); + bfloat16x8_t t2t3_l_8h = vcvtq_high_bf16_f32(t2t3_l_4h, t3_l); + + vst1q_bf16(&D[0], t0t1_l_8h); + vst1q_bf16(&D[8], t2t3_l_8h); + + float32x4x2_t z0_h = vzipq_f32(t0_h, t2_h); + float32x4x2_t z1_h = vzipq_f32(t1_h, t3_h); + float32x4x2_t o0_h = vzipq_f32(z0_h.val[0], z1_h.val[0]); + float32x4x2_t o1_h = vzipq_f32(z0_h.val[1], z1_h.val[1]); + t0_h = o0_h.val[0]; + t1_h = o0_h.val[1]; + t2_h = o1_h.val[0]; + t3_h = o1_h.val[1]; + + bfloat16x8_t t0t1_h_4h = vcvtq_low_bf16_f32(t0_h); + bfloat16x8_t t0t1_h_8h = vcvtq_high_bf16_f32(t0t1_h_4h, t1_h); + + bfloat16x8_t t2t3_h_4h = vcvtq_low_bf16_f32(t2_h); + bfloat16x8_t t2t3_h_8h = vcvtq_high_bf16_f32(t2t3_h_4h, t3_h); + + vst1q_bf16(&D[16], t0t1_h_8h); + vst1q_bf16(&D[24], t2t3_h_8h); + + D += 32; + b += ldb * 4; + y -= 4; + }; + B += 8; + CountN -= 8; + } + + // + // Special case the handling of the remaining columns less than 8 elements + // wide. + // + if (CountN > 0) { + int y = static_cast(CountK); + while (y > 0) { + const float* b = B; + size_t b_inc = 0; + if ((CountN & 4) != 0) { + MLAS_FLOAT32X4 t0 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3 = MlasZeroFloat32x4(); + if (y >= 4) { + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + t3 = MlasLoadFloat32x4(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + break; + case 2: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + break; + case 1: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + break; + } + } + + float32x4x2_t z0 = vzipq_f32(t0, t2); + float32x4x2_t z1 = vzipq_f32(t1, t3); + float32x4x2_t o0 = vzipq_f32(z0.val[0], z1.val[0]); + float32x4x2_t o1 = vzipq_f32(z0.val[1], z1.val[1]); + + t0 = o0.val[0]; + t1 = o0.val[1]; + t2 = o1.val[0]; + t3 = o1.val[1]; + + bfloat16x8_t t0t1_4h = vcvtq_low_bf16_f32(t0); + bfloat16x8_t t0t1_8h = vcvtq_high_bf16_f32(t0t1_4h, t1); + + bfloat16x8_t t2t3_4h = vcvtq_low_bf16_f32(t2); + bfloat16x8_t t2t3_8h = vcvtq_high_bf16_f32(t2t3_4h, t3); + + vst1q_bf16(&D[0], t0t1_8h); + vst1q_bf16(&D[8], t2t3_8h); + + D += 16; + b += 4; + b_inc += 4; + } + + if ((CountN & 2) != 0) { + float32x2_t t0 = {0x0, 0x0}; + float32x2_t t1 = {0x0, 0x0}; + float32x2_t t2 = {0x0, 0x0}; + float32x2_t t3 = {0x0, 0x0}; + + if (y >= 4) { + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + t3 = vld1_f32(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + break; + case 2: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + break; + case 1: + t0 = vld1_f32(&b[ldb * 0]); + break; + } + } + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 2; + b_inc += 2; + } + if ((CountN & 1) != 0) { + float a = 0.0f; + float b = 0.0f; + float c = 0.0f; + float d = 0.0f; + + if (y >= 4) { + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + d = *(float*)(&B[ldb * 3 + b_inc]); + } else { + switch (y) { + case 3: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + break; + case 2: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + break; + case 1: + a = *(float*)(&B[ldb * 0 + b_inc]); + break; + } + } + + float32x2_t t0 = {a, 0x0}; + float32x2_t t1 = {b, 0x0}; + float32x2_t t2 = {c, 0x0}; + float32x2_t t3 = {d, 0x0}; + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 1; + b_inc += 1; + } + B += 4 * ldb; + y -= 4; + } + } +} + +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + const auto PackedN = dispatch->PackedN; + + const size_t AlignedN = (CountN + PackedN - 1) & ~(PackedN - 1); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t K_block_size; + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + + for (size_t k = 0; k < CountK; k += K_block_size) { + K_block_size = std::min(CountK - k, Strides.K); + + MlasSBGemmConvertCopyPackB((bfloat16_t*)PackedB, B + k * ldb, ldb, CountN, K_block_size); + PackedB = (bfloat16_t*)PackedB + AlignedN * K_block_size; + } +} + +template <> +MLAS_FORCEINLINE void +MlasSBGemmKernel(size_t CountM, size_t CountN, size_t CountK, const float* A, size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode) +{ + while (CountM > 0) { + size_t RowsHandled; + if (ZeroMode) { + RowsHandled = MlasSbgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } else { + RowsHandled = MlasSbgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } + C += ldc * RowsHandled; + A += lda * RowsHandled; + CountM -= RowsHandled; + } +} + +const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon = { + MlasSBGemmOperation, + MlasSBGemmConvertPackB, + MLAS_SBGEMM_KERNEL_NEON::PackedK, + MLAS_SBGEMM_KERNEL_NEON::PackedN, + MLAS_SBGEMM_KERNEL_NEON::KernelMaxM, + 32 // kernel may read beyond buffer end by 32 bytes +}; +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 7d877848017fe..0d8a5692359a6 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -19,10 +19,6 @@ Module Name: #include -#ifdef MLAS_JBLAS -#include "jblas_gemm.h" -#endif - namespace { @@ -694,127 +690,3 @@ MlasSQNBitGemmBatch( ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); }); } - -size_t MLASCALL -MlasNBitsGemmPackBSize( - size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType -) -{ -#ifdef MLAS_JBLAS - if (nbits == 4) { - auto jsize = JblasQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); - if (jsize) { - return jsize; - } - } -#endif - (void)(N); - (void)(K); - (void)(BlkSize); - (void)(nbits); - (void)(isAsym); - (void)(CompType); - return 0; -} - -void MLASCALL -MlasNBitsGemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - int nbits, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -) -{ -#ifdef MLAS_JBLAS - if (nbits == 4) { - if (JblasQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { - return; - } - } -#endif - (void)(PackedBuf); - (void)(QData); - (void)(Scale); - (void)(Zp); - (void)(N); - (void)(K); - (void)(ldb); - (void)(BlkSize); - (void)(nbits); - (void)(isAsym); - (void)(lastCall); - (void)(CompType); - (void)(ThreadPool); -} - -void MLASCALL -MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) -{ -#ifdef MLAS_JBLAS - if (JblasQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { - return; - } -#endif - (void)(FpData); - (void)(PackedBuf); - (void)(N); - (void)(K); - (void)(ldb); - (void)(ThreadPool); -} - -size_t MLASCALL -MlasSQNBitsGemmBatchPackedBWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -) -{ -#ifdef MLAS_JBLAS - return JblasSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); -#endif - (void)(M); - (void)(N); - (void)(K); - (void)(BatchN); - (void)(DataParams); - return 0; -} - -void MLASCALL -MlasSQNBitsGemmBatchPackedB( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - void* WorkSpace, - MLAS_THREADPOOL* ThreadPool -) -{ - GetMlasPlatform(); -#ifdef MLAS_JBLAS - if (JblasSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { - // PackedWeight is created by jblas - return; - } -#endif - (void)(M); - (void)(N); - (void)(K); - (void)(BatchN); - (void)(DataParams); - (void)(WorkSpace); - (void)(ThreadPool); -} diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format deleted file mode 100644 index 84b876706161d..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format +++ /dev/null @@ -1,7 +0,0 @@ -Language: Cpp -BasedOnStyle: Google -DerivePointerAlignment: false -ColumnLimit: 120 -SpaceBeforeParens: ControlStatements -SpaceBeforeRangeBasedForLoopColon: true -SortIncludes: false diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt deleted file mode 100644 index 5d9c5edf45a96..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt +++ /dev/null @@ -1,33 +0,0 @@ -cmake_minimum_required(VERSION 3.5) - -project(jblas LANGUAGES CXX VERSION 0.1.0) - -file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp) -file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp) - -add_library(${PROJECT_NAME} INTERFACE) -add_library(${PROJECT_NAME}::${PROJECT_NAME} ALIAS ${PROJECT_NAME}) - -target_include_directories( - ${PROJECT_NAME} INTERFACE - "$" - "$" -) - -if(WIN32) - target_compile_definitions(${PROJECT_NAME} INTERFACE _CRT_SECURE_NO_WARNINGS NOMINMAX) - target_compile_options(${PROJECT_NAME} INTERFACE /wd4068 /wd4849 /wd6262 /wd4702 /wd4100) - #4068 ignore unroll and GCC flags - #4849 ignore collapse - #6262 ignore stack too large - #4702 unreachable code(false warning on constexpr condition) - #4100 unreferenced formal parameter - - target_link_options(${PROJECT_NAME} INTERFACE /STACK:3145728) #Stack requires up to L2 cache size -endif(WIN32) - - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - -target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17) diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h deleted file mode 100644 index 143adb771760b..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h +++ /dev/null @@ -1,303 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include -#include -#include "xbyak/xbyak.h" -#include "xbyak/xbyak_util.h" - -#define OFFSET(field) offsetof(params, field) - -namespace jblas { - -namespace xbyak { -class JitBase : protected Xbyak::CodeGenerator { - protected: - JitBase(size_t size = 16 * 1024) : CodeGenerator(size) {} - - void load32(const Xbyak::Reg64& reg, const Xbyak::Address& addr) { - xor_(reg, reg); - mov(reg.cvt32(), addr); - } - - void vreg_push(const Xbyak::Reg64& baseaddr) { -#ifdef _WIN32 - for (int i = 0; i < 10; i++) { - movaps(xword[baseaddr + i * 16], Xbyak::Xmm(6 + i)); - } -#endif - } - - void vreg_pop(const Xbyak::Reg64& baseaddr) { -#ifdef _WIN32 - for (int i = 0; i < 10; i++) { - movaps(Xbyak::Xmm(6 + i), xword[baseaddr + i * 16]); - } -#endif - } - - void padto_le(const Xbyak::Reg64& _src, int padding) { - // _src=_src/padding*padding - if (padding == 1) { - return; - } - for (int i = 1; i < 16; i++) { - if ((1 << i) == padding) { - shr(_src, i); - shl(_src, i); - return; - } - } - assert(0); - } - - void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total, - const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { - inLocalLabel(); - lea(_tmp, _total); - sub(_tmp, _pos); - cmp(_tmp, N); - jb(".maskflag"); - cmp(_tmp, 0); - jl(".zeroflag"); - uint64_t allmask = (static_cast(1) << N) - 1; - if (N == 64) { - allmask = static_cast(-1); - } - mov(_tmp, allmask); - kmovq(_msk, _tmp); - jmp(".maskend"); - L(".maskflag"); - mov(_tmp1, 1); - shlx(_tmp1, _tmp1, _tmp); - sub(_tmp1, 1); - kmovq(_msk, _tmp1); - jmp(".maskend"); - L(".zeroflag"); - mov(_tmp1, 0); - kmovq(_msk, _tmp1); - L(".maskend"); - outLocalLabel(); - } - void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Reg64& _total, - const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { - generate_Nbitsmask(_msk, _pos, ptr[_total], _tmp, _tmp1, N); - } -}; - -class JitAvx : protected JitBase { - protected: - static int constexpr VBits = 256; - static int constexpr VecBytes = VBits / 8; - static int constexpr RegCount = 16; - typedef Xbyak::Ymm vreg_t; -}; - -class JitAvx2 : protected JitAvx { - protected: - static int constexpr VBits = 256; - typedef Xbyak::Ymm vreg_t; - void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); } - - void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) { - vpmovzxwd(dst, addr); - vpslld(dst, dst, 16); - } -}; - -class JitAvx512f : protected JitAvx2 { - protected: - static int constexpr VBits = 512; - static int constexpr VecBytes = VBits / 8; - static int constexpr RegCount = 32; - typedef Xbyak::Zmm vreg_t; - - void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); } - - void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) { - vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]); - vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]); - vshuff32x4(src_2regs[0], tmp_2reg[0], tmp_2reg[1], 0 | (1 << 2) | (0 << 4) | (1 << 6)); - vshuff32x4(src_2regs[0], src_2regs[0], src_2regs[0], 0 | (2 << 2) | (1 << 4) | (3 << 6)); - vshuff32x4(src_2regs[1], tmp_2reg[0], tmp_2reg[1], 2 | (3 << 2) | (2 << 4) | (3 << 6)); - vshuff32x4(src_2regs[1], src_2regs[1], src_2regs[1], 0 | (2 << 2) | (1 << 4) | (3 << 6)); - } - - void transpose16x16_4B(Xbyak::Zmm* src, Xbyak::Zmm* tmp, const int N = 16) { - for (int i = 0; i < 8; ++i) { - vpunpckldq(tmp[2 * i + 0], src[2 * i], src[2 * i + 1]); - vpunpckhdq(tmp[2 * i + 1], src[2 * i], src[2 * i + 1]); - } - - for (int i = 0; i < 4; ++i) { - vpunpcklqdq(src[4 * i + 0], tmp[4 * i + 0], tmp[4 * i + 2]); - vpunpckhqdq(src[4 * i + 1], tmp[4 * i + 0], tmp[4 * i + 2]); - vpunpcklqdq(src[4 * i + 2], tmp[4 * i + 1], tmp[4 * i + 3]); - vpunpckhqdq(src[4 * i + 3], tmp[4 * i + 1], tmp[4 * i + 3]); - } - - for (int i = 0; i < 2; ++i) { - vshufi32x4(tmp[8 * i + 0], src[8 * i + 0], src[8 * i + 4], 0x88); - vshufi32x4(tmp[8 * i + 1], src[8 * i + 1], src[8 * i + 5], 0x88); - vshufi32x4(tmp[8 * i + 2], src[8 * i + 2], src[8 * i + 6], 0x88); - vshufi32x4(tmp[8 * i + 3], src[8 * i + 3], src[8 * i + 7], 0x88); - vshufi32x4(tmp[8 * i + 4], src[8 * i + 0], src[8 * i + 4], 0xdd); - vshufi32x4(tmp[8 * i + 5], src[8 * i + 1], src[8 * i + 5], 0xdd); - vshufi32x4(tmp[8 * i + 6], src[8 * i + 2], src[8 * i + 6], 0xdd); - vshufi32x4(tmp[8 * i + 7], src[8 * i + 3], src[8 * i + 7], 0xdd); - } - - // last step and move out - for (int i = 0; i < N; ++i) { - vshufi32x4(src[i], tmp[i % 8], tmp[8 + i % 8], i < 8 ? 0x88 : 0xdd); - } - } - - void interleave_4rows_6regs(Xbyak::Zmm* src_4regs, Xbyak::Zmm* tmp_regs, const Xbyak::Opmask* masks) { - vpunpcklbw(tmp_regs[0], src_4regs[0], src_4regs[1]); - vpunpckhbw(tmp_regs[1], src_4regs[0], src_4regs[1]); - vpunpcklbw(tmp_regs[2], src_4regs[2], src_4regs[3]); - vpunpckhbw(tmp_regs[3], src_4regs[2], src_4regs[3]); - - vpunpcklwd(tmp_regs[4], tmp_regs[0], tmp_regs[2]); - vpunpckhwd(tmp_regs[5], tmp_regs[0], tmp_regs[2]); - vpunpcklwd(tmp_regs[0], tmp_regs[1], tmp_regs[3]); - vpunpckhwd(tmp_regs[2], tmp_regs[1], tmp_regs[3]); - vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (4 << 4) | 4); - vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (4 << 4) | 4); - vmovups(src_4regs[0], tmp_regs[1]); - vshuff32x4(src_4regs[0] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); - vmovups(src_4regs[1], tmp_regs[3]); - vshuff32x4(src_4regs[1] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); - vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (14 << 4) | 14); - vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (14 << 4) | 14); - vmovups(src_4regs[2], tmp_regs[1]); - vshuff32x4(src_4regs[2] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); - vmovups(src_4regs[3], tmp_regs[3]); - vshuff32x4(src_4regs[3] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); - } - - void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { - vpsrld(_fp32, _fp32, 16); - vpmovdw(_bf16, _fp32); - } - - void loadbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Address& addr) { - vpmovzxwd(dst, addr); - vpslld(dst, dst, 16); - } - - void broadcastbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Reg64& tmp, const Xbyak::Address& addr) { - mov(tmp.cvt16(), addr); - shl(tmp.cvt32(), 16); - vpbroadcastd(dst, tmp.cvt32()); - } - - void store_fp32_bf16(const Xbyak::Zmm& _fp32, const Xbyak::Address& _add) { - auto bf16 = Xbyak::Ymm(_fp32.getIdx()); - cvt_fp32_bf16(bf16, _fp32); - vmovups(_add, bf16); - } -}; - -class JitAvx512_bf16 : protected JitAvx512f {}; - -class JitAvx512_fp16 : protected JitAvx512f {}; - -class JitAvx512vnni : protected JitAvx512f { - protected: - void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { - vpdpbusds(x1, x2, op, Xbyak::EvexEncoding); - } -}; - -class JitAvxvnni : protected JitAvx2 { - protected: - void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { - vpdpbusds(x1, x2, op, Xbyak::VexEncoding); - } -}; - -class JitAmxtile : protected JitAvx512f { - public: - struct alignas(64) tileconfig_t { - uint8_t palette_id; - uint8_t reserved[15]; - uint16_t colb[16]; - uint8_t rows[16]; - }; - static int constexpr TileCount = 8; - - typedef long long (*configure_t)(void*); - - static void generate_config(Xbyak::CodeGenerator* g) { - Xbyak::util::StackFrame st(g, 1, 0, 0); - auto& parambase = st.p[0]; - g->ldtilecfg(g->ptr[parambase]); - } - - static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, - int CNum) { - // Filling tile configure structure. Could be done offline. - tc.palette_id = 1; - // Configure C tiles - int t = 0; - for (; t < CNum; ++t) { - tc.rows[t] = static_cast(TILE_M); - tc.colb[t] = static_cast(TILE_N * 4); - } - // Configure A tiles - for (; t < CNum + ANum; ++t) { - tc.rows[t] = static_cast(TILE_M); - tc.colb[t] = static_cast(TILE_K * elesize); - } - // Configure B tile. B effectively has 64 rows and 16 columns. - int kpack = 4 / elesize; - for (; t < CNum + ANum + BNum; ++t) { - tc.rows[t] = static_cast(TILE_K / kpack); - tc.colb[t] = static_cast(TILE_N * 4); - } - } -}; - -class JitAmxbf16 : protected JitAmxtile { - protected: - void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { vcvtneps2bf16(_bf16, _fp32); } -}; - -class JitAmxint8 : protected JitAmxtile { - protected: - template - void _tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3); -}; -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbssd(x1, x2, x3); -} -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbsud(x1, x2, x3); -} -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbusd(x1, x2, x3); -} -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbuud(x1, x2, x3); -} -} // namespace xbyak -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h deleted file mode 100644 index 8ecf3535c17f4..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include -enum JBLAS_CODE { - JblasSuccess = 0, - JblasInvalidParam = 1, - JblasInvalidISA = 2, - JblasRuntimeError = 4, - JblasNotSupport = 8, -}; -enum JBLAS_ISA : uint32_t { - JblasNoSIMD = 0, - JblasAVX, - JblasAVX2, - JblasAVX_VNNI, - JblasAVX512F, - JblasAVX512_VNNI, - JblasAMX_BF16, - JblasAMX_INT8, - JblasAVX512_FP16, - JblasAVX512_BF16, -}; -enum class JBLAS_DTYPE : uint32_t { - EleBitsMask = 0xff, - EleBitsUndef = 0, - EleBits4 = 4, - EleBits8 = 8, - EleBits16 = 16, - EleBits32 = 32, - EleBits64 = 64, - TypeMask = 0xff00, - TypeFloat = 0 << 8, - TypeInt = 1 << 8, - SubTypeMask = 0xff0000, - SubType0 = 0 << 16, - SubType1 = 1 << 16, - SubType2 = 2 << 16, - F64 = EleBits64 | TypeFloat, - F32 = EleBits32 | TypeFloat, - F16 = EleBits16 | TypeFloat, - BF16 = EleBits16 | TypeFloat | SubType1, - F8_E4M3 = EleBits8 | TypeFloat, - F8_E5M2 = EleBits8 | TypeFloat | SubType1, - F8_E3M4 = EleBits8 | TypeFloat | SubType2, - S8 = EleBits8 | TypeInt, - U8 = EleBits8 | TypeInt | SubType1, - S4_CLIP = EleBits4 | TypeInt, - S4_FULLRANGE = EleBits4 | TypeInt | SubType1, - F4_E2M1 = EleBits4 | TypeFloat, - F4_BNB = EleBits4 | TypeFloat | SubType1, - F4_NF4 = EleBits4 | TypeFloat | SubType2, - S32 = EleBits32 | TypeInt, - U32 = EleBits32 | TypeInt | SubType1, -}; - -enum JBLAS_LAYOUT { JblasRowMajor = 101, JblasColMajor = 102 }; -enum JBLAS_TRANSPOSE { - JblasNoTrans = 111, - JblasTrans = 112, - JblasConjTrans = 113, -}; -enum JBLAS_ELTWISEOP { - GELU, - SWISH, - TANH, - EXP, - LOW_PRECISION_EXP, - RELU, - LINEAR, -}; - -enum class JBLAS_PROLOGUEB_IDS : uint32_t { - Undef = (uint32_t)-1, - Begin = 0, - NormalBegin = Begin, - WeightPack = NormalBegin, - NormalEnd, - KBlockBegin = NormalEnd, - WeightKBlockS8 = KBlockBegin, - WeightKBlockS4, - WeightKBlockF4, - KBlockEnd, - End, -}; diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h deleted file mode 100644 index 5cac1080bc610..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include "jit_blas.h" -#include "xbyak/xbyak_util.h" - -namespace jblas { - -namespace device { - -struct X64_ISA { - int64_t MMX : 1; // 0 - int64_t SSE : 1; // 1 - int64_t SSE2 : 1; // 2 - int64_t SSE3 : 1; // 3 - int64_t SSSE3 : 1; // 4 - int64_t SSE41 : 1; // 5 - int64_t SSE42 : 1; // 6 - int64_t AVX : 1; // 7 - int64_t F16C : 1; // 8 - int64_t FMA : 1; // 9 - int64_t AVX2 : 1; // 10 - int64_t AVX_VNNI : 1; // 11 - int64_t AVX_VNNI_INT8 : 1; // 12 - int64_t AVX_NE_CONVERT : 1; // 13 - int64_t AVX_IFMA : 1; // 14 - int64_t AVX512F : 1; // 15 - int64_t AVX512BW : 1; // 16 - int64_t AVX512CD : 1; // 17 - int64_t AVX512DQ : 1; // 18 - int64_t AVX512ER : 1; // 19 - int64_t AVX512IFMA52 : 1; // 20 - int64_t AVX512PF : 1; // 21 - int64_t AVX512VL : 1; // 22 - int64_t AVX512VPOPCNTDQ : 1; // 23 - int64_t AVX512_4FMAPS : 1; // 24 - int64_t AVX512_4VNNIW : 1; // 25 - int64_t AVX512_BF16 : 1; // 26 - int64_t AVX512_BITALG : 1; // 27 - int64_t AVX512_VBMI : 1; // 28 - int64_t AVX512_VBMI2 : 1; // 29 - int64_t AVX512_VNNI : 1; // 30 - int64_t AVX512_VP2INTERSECT : 1; // 31 - int64_t AVX512_FP16 : 1; // 32 - int64_t AMX_TILE : 1; // 33 - int64_t AMX_BF16 : 1; // 34 - int64_t AMX_INT8 : 1; // 35 - int64_t AMX_FP16 : 1; // 36 - int64_t AMX_COMPLEX : 1; // 37 - int64_t reserved : (64 - 38); -}; - -class AVX2_Default { - public: - static constexpr bool MMX = 1; - static constexpr bool SSE = 1; - static constexpr bool SSE2 = 1; - static constexpr bool SSE3 = 1; - static constexpr bool SSSE3 = 1; - static constexpr bool SSE41 = 1; - static constexpr bool SSE42 = 1; - static constexpr bool AVX = 1; - static constexpr bool F16C = 1; - static constexpr bool FMA = 1; - static constexpr bool AVX2 = 1; - static constexpr bool AVX_VNNI = 0; - static constexpr bool AVX_VNNI_INT8 = 0; - static constexpr bool AVX_NE_CONVERT = 0; - static constexpr bool AVX_IFMA = 0; - static constexpr bool AVX512F = 0; - static constexpr bool AVX512BW = 0; - static constexpr bool AVX512CD = 0; - static constexpr bool AVX512DQ = 0; - static constexpr bool AVX512ER = 0; - static constexpr bool AVX512IFMA52 = 0; - static constexpr bool AVX512PF = 0; - static constexpr bool AVX512VL = 0; - static constexpr bool AVX512VPOPCNTDQ = 0; - static constexpr bool AVX512_4FMAPS = 0; - static constexpr bool AVX512_4VNNIW = 0; - static constexpr bool AVX512_BF16 = 0; - static constexpr bool AVX512_BITALG = 0; - static constexpr bool AVX512_VBMI = 0; - static constexpr bool AVX512_VBMI2 = 0; - static constexpr bool AVX512_VNNI = 0; - static constexpr bool AVX512_VP2INTERSECT = 0; - static constexpr bool AVX512_FP16 = 0; - static constexpr bool AMX_TILE = 0; - static constexpr bool AMX_BF16 = 0; - static constexpr bool AMX_INT8 = 0; - static constexpr bool AMX_FP16 = 0; - static constexpr bool AMX_COMPLEX = 0; -}; - -class AVX512_VNNI_Default { - public: - static constexpr bool MMX = 1; - static constexpr bool SSE = 1; - static constexpr bool SSE2 = 1; - static constexpr bool SSE3 = 1; - static constexpr bool SSSE3 = 1; - static constexpr bool SSE41 = 1; - static constexpr bool SSE42 = 1; - static constexpr bool AVX = 1; - static constexpr bool F16C = 1; - static constexpr bool FMA = 1; - static constexpr bool AVX2 = 1; - static constexpr bool AVX_VNNI = 0; - static constexpr bool AVX_VNNI_INT8 = 0; - static constexpr bool AVX_NE_CONVERT = 0; - static constexpr bool AVX_IFMA = 0; - static constexpr bool AVX512F = 1; - static constexpr bool AVX512BW = 1; - static constexpr bool AVX512CD = 1; - static constexpr bool AVX512DQ = 1; - static constexpr bool AVX512ER = 0; - static constexpr bool AVX512IFMA52 = 0; - static constexpr bool AVX512PF = 0; - static constexpr bool AVX512VL = 1; - static constexpr bool AVX512VPOPCNTDQ = 0; - static constexpr bool AVX512_4FMAPS = 0; - static constexpr bool AVX512_4VNNIW = 0; - static constexpr bool AVX512_BF16 = 0; - static constexpr bool AVX512_BITALG = 0; - static constexpr bool AVX512_VBMI = 0; - static constexpr bool AVX512_VBMI2 = 0; - static constexpr bool AVX512_VNNI = 1; - static constexpr bool AVX512_VP2INTERSECT = 0; - static constexpr bool AVX512_FP16 = 0; - static constexpr bool AMX_TILE = 0; - static constexpr bool AMX_BF16 = 0; - static constexpr bool AMX_INT8 = 0; - static constexpr bool AMX_FP16 = 0; - static constexpr bool AMX_COMPLEX = 0; -}; - -class SapphireRapids { - public: - static constexpr bool MMX = 1; - static constexpr bool SSE = 1; - static constexpr bool SSE2 = 1; - static constexpr bool SSE3 = 1; - static constexpr bool SSSE3 = 1; - static constexpr bool SSE41 = 1; - static constexpr bool SSE42 = 1; - static constexpr bool AVX = 1; - static constexpr bool F16C = 1; - static constexpr bool FMA = 1; - static constexpr bool AVX2 = 1; - static constexpr bool AVX_VNNI = 0; - static constexpr bool AVX_VNNI_INT8 = 0; - static constexpr bool AVX_NE_CONVERT = 0; - static constexpr bool AVX_IFMA = 0; - static constexpr bool AVX512F = 1; - static constexpr bool AVX512BW = 1; - static constexpr bool AVX512CD = 1; - static constexpr bool AVX512DQ = 1; - static constexpr bool AVX512ER = 0; - static constexpr bool AVX512IFMA52 = 0; - static constexpr bool AVX512PF = 0; - static constexpr bool AVX512VL = 1; - static constexpr bool AVX512VPOPCNTDQ = 0; - static constexpr bool AVX512_4FMAPS = 0; - static constexpr bool AVX512_4VNNIW = 0; - static constexpr bool AVX512_BF16 = 0; - static constexpr bool AVX512_BITALG = 0; - static constexpr bool AVX512_VBMI = 0; - static constexpr bool AVX512_VBMI2 = 0; - static constexpr bool AVX512_VNNI = 1; - static constexpr bool AVX512_VP2INTERSECT = 0; - static constexpr bool AVX512_FP16 = 0; - static constexpr bool AMX_TILE = 1; - static constexpr bool AMX_BF16 = 1; - static constexpr bool AMX_INT8 = 1; - static constexpr bool AMX_FP16 = 0; - static constexpr bool AMX_COMPLEX = 0; -}; - -template -class isa_base { - public: - static bool constexpr avx = ISA_T >= JblasAVX; - static bool constexpr avx2 = ISA_T >= JblasAVX2; - static bool constexpr avx512f = ISA_T >= JblasAVX512F; - static bool constexpr avx512_vnni = ISA_T >= JblasAVX512_VNNI; - static bool constexpr avx512_fp16 = ISA_T >= JblasAVX512_FP16; - static bool constexpr amx_bf16 = ISA_T >= JblasAMX_BF16; - static bool constexpr amx_int8 = ISA_T >= JblasAMX_INT8; -}; - -class CpuDevice { - public: - inline void setThreads(int _nth) { - if (_nth <= 0) { - numthreads = numcores; - } else { - numthreads = std::min(numcores, _nth); - } - } - inline int getThreads() { return numthreads; } - inline int getCores() { return numcores; } - inline uint32_t getL2CacheSize() { return L2Cache; } - inline uint32_t getL1CacheSize() { return L1Cache; } - inline bool AVX() { return mHasAVX; } - inline bool AVX2() { return mHasAVX2; } - inline bool AVX_VNNI() { return mHasAVX_VNNI; } - inline bool AVX512F() { return mHasAVX512F; } - inline bool AVX512_VNNI() { return mHasAVX512_VNNI; } - inline bool AMX_INT8() { return mHasAMX_INT8; } - inline bool AMX_BF16() { return mHasAMX_BF16; } - inline bool AVX512_BF16() { return mHasAVX512_BF16; } - inline bool AVX512_FP16() { return mHasAVX512_FP16; } -#define ADD_FLAG(isa) mHas##isa = _cpu.has(_cpu.t##isa) - CpuDevice() { - static Xbyak::util::Cpu _cpu; - L1Cache = _cpu.getDataCacheSize(0); - L2Cache = _cpu.getDataCacheSize(1); - ADD_FLAG(AVX); - ADD_FLAG(AVX2); - ADD_FLAG(AVX512F); - ADD_FLAG(AVX512_VNNI); - ADD_FLAG(AVX_VNNI); - ADD_FLAG(AMX_BF16); - ADD_FLAG(AMX_INT8); - ADD_FLAG(AVX512_BF16); - ADD_FLAG(AVX512_FP16); - numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); - numthreads = numcores; - } - - static CpuDevice* getInstance() { - static CpuDevice instance; - return &instance; - } - - void print() { - printf( - "AVX:%d AVX2:%d AVX512F:%d AVX_VNNI:%d AVX512_VNNI:%d AMX_INT8:%d AMX_BF16:%d AVX512_BF16:%d AVX512_FP16:%d\n", - mHasAVX, mHasAVX2, mHasAVX512F, mHasAVX_VNNI, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512_BF16, - mHasAVX512_FP16); - } -#undef ADD_FLAG - - protected: - uint32_t L2Cache, L1Cache; - bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512_BF16, - mHasAVX512_FP16; - int numcores; - int numthreads; -}; - -#define GetCPUDevice() auto _cd = jblas::device::CpuDevice::getInstance(); - -class CpuBase { - public: - CpuBase() { - GetCPUDevice(); - mL2Cache = _cd->getL2CacheSize(); - mL1Cache = _cd->getL1CacheSize(); - mNumThreads = _cd->getThreads(); - } - size_t mL2Cache, mL1Cache; - int mNumThreads; -}; -} // namespace device -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h deleted file mode 100644 index ceb7a545092d8..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h +++ /dev/null @@ -1,329 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include "jit_base.h" -#include "jit_blas.h" -#include "jit_blas_utils.h" -#include "kernel_wrapper.h" - -namespace jblas { -namespace epilogue { -namespace gemm { - -template -class AccumulatorWriteBack { - public: - using SType = _SRC_T; - using DType = _DST_T; - struct Param { - DType* C; - int ldc; - void* elt_const_v; - }; - - template - JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize, Eltops... ops) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - bool constexpr Valid = !std::is_same::value || std::is_same::value; - static_assert(Valid, "fp32 to bf16 conversion only."); - if constexpr (std::is_same::value) { - return kernel::wrapper::Memcpy2DFp32CvtBf16::template forward( - const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); - } else if constexpr (std::is_same, std::tuple>::value) { - return kernel::wrapper::Memcpy2DFp16CvtFp32::template forward( - const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); - } else if constexpr (sizeof(SType) == sizeof(DType)) { - return kernel::wrapper::Memcpy2D::template forward(cacheptr, cptr, M, N, cachestep, - _param.ldc, _param.elt_const_v, ops...); - } else { - assert(false); - } - } -}; - -template -class CustomAccumulatorWriteBackWithEltop { - public: - struct Param { - _DST_T* C; - int ldc; - void* elt_const_v; - }; - JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - if constexpr (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) { - return kernel::wrapper::Memcpy2D::template forward1(cacheptr, cptr, M, N, cachestep, - _param.ldc, _param.elt_const_v); - } else { - assert(false); - } - } -}; -template -using AccumulatorWriteBackFp32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackInt32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackBf16 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp16 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack; - -template -using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop; - -template -using AccumulatorWriteBackWithSwishFp32 = CustomAccumulatorWriteBackWithEltop; - -template -class AlphaBetaProcessFp32 { - public: - struct Param { - float *C, *D; - int ldc, ldd; - float alpha, beta; - }; - - JBLAS_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto DOffset = M_offset * _param.ldd + N_offset; - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - auto dptr = _param.D + DOffset; - return kernel::wrapper::AlphaBetaF32F32::template forward(_param.alpha, cacheptr, cachestep, _param.beta, - dptr, _param.ldd, cptr, _param.ldc, M, N); - } -}; - -template -class CompFp32BlockEpilogue { - public: - struct Param { - void* scales; - JBLAS_DTYPE scaledtype; - int ldsb; - int8_t* zps = nullptr; - float* reduce = nullptr; - int ldra; - }; - JBLAS_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, - const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, - size_t cachesize) { - auto ret = JblasNotSupport; - if (_param.scaledtype == JBLAS_DTYPE::F32) { - ret = kernel::wrapper::CompFp32BlockScale::template forward( - reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, - cachestep, M, N); - assert(ret == JblasSuccess); - if (_param.zps != nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::forward_wei( - dstptr, cachestep, M, N, _param.zps + K_offset * _param.ldsb + N_offset, - reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, _param.ldra, - _param.reduce + M_offset * _param.ldra + K_offset); - } - assert(ret == JblasSuccess); - return ret; - } else if (_param.scaledtype == JBLAS_DTYPE::BF16) { - ret = kernel::wrapper::CompFp32BlockScale::template forward( - reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, - cachestep, M, N); - assert(_param.zps == nullptr); - assert(ret == JblasSuccess); - return ret; - } - return JblasNotSupport; - } -}; - -template -class DequantInt32ToFp32 { - public: - struct Param { - float* C; - int ldc; - int ldsa; - float* scalesA; - float* scalesB; - }; - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - return kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, - _param.scalesA + M_offset * _param.ldsa, _param.ldsa, - _param.scalesB + N_offset); - } -}; - -template -class CompInt8BlockEpilogue { - public: - struct Param { - void* scalesB; - JBLAS_DTYPE scaleBdtype; - int ldsb; - float* scalesA; - int ldsa; - // optional if A asym - uint8_t* zpA = nullptr; - void* reduceB = nullptr; - JBLAS_DTYPE reduceBdtype = JBLAS_DTYPE::F32; - // optional if B asym - int8_t* zpB = nullptr; - float* reduceA = nullptr; - int K = 1; - }; - JBLAS_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, - const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, - size_t cachesize) { - JBLAS_CODE ret = JblasNotSupport; - float* scab = nullptr; - size_t ScaleBTmpSize = N * sizeof(float); - size_t ReduceBTmpSize = N * sizeof(float); - assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize)); - if (_param.scaleBdtype == JBLAS_DTYPE::BF16) { - auto scache = reinterpret_cast(tmpcache); - ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( - reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N, - false); - assert(ret == JblasSuccess); - scab = scache; - } else if (_param.scaleBdtype == JBLAS_DTYPE::F32) { - scab = reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb; - } - float* redb = nullptr; - if (_param.reduceB) { - if (_param.reduceBdtype == JBLAS_DTYPE::BF16) { - auto rcache = reinterpret_cast(reinterpret_cast(tmpcache) + ScaleBTmpSize); - ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( - reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N, - false); - assert(ret == JblasSuccess); - redb = rcache; - } else if (_param.reduceBdtype == JBLAS_DTYPE::F32) { - redb = reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb; - } - } - ret = kernel::wrapper::DequanS32Fp32::template forward( - srcptr, cachestep, reinterpret_cast(const_cast(srcptr)), cachestep, M, N, - _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab); - assert(ret == JblasSuccess); - ret = kernel::wrapper::AccumulateFp32::template forward(reinterpret_cast(srcptr), cachestep, - dstptr, cachestep, M, N); - assert(ret == JblasSuccess); - - if (_param.zpA == nullptr) { - if (_param.zpB == nullptr) { - return ret; - } else { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( - dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, scab, _param.ldsa, - _param.reduceA + M_offset * _param.ldsa + K_offset); - } - } else { - if (_param.zpB == nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( - dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, - _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, redb); - } else { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( - dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, - _param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, scab, - _param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa + K_offset, redb); - } - } - return ret; - } -}; - -template -class ZpDequantInt32ToFp32 { - public: - struct Param { - // necessary - float* C; - int ldc; - int ldsa; - float* scalesA; - float* scalesB; - // optional if A asym - uint8_t* zpA = nullptr; - float* reduceB = nullptr; - // optional if B asym - int8_t* zpB = nullptr; - float* reduceA = nullptr; - int K = 1; - }; - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - auto ret = kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, - _param.scalesA + M_offset * _param.ldsa, - _param.ldsa, _param.scalesB + N_offset); - if (ret != JblasSuccess) { - return ret; - } - if (_param.zpA == nullptr && _param.zpB == nullptr) { - return ret; - } else if (_param.zpA != nullptr && _param.zpB == nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( - cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa, - _param.ldsa, _param.reduceB + N_offset); - } else if (_param.zpA == nullptr && _param.zpB != nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( - cptr, _param.ldc, M, N, _param.zpB + N_offset, _param.scalesB + N_offset, _param.ldsa, - _param.reduceA + M_offset * _param.ldsa); - } else { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( - cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.zpB + N_offset, - _param.scalesA + M_offset * _param.ldsa, _param.scalesB + N_offset, _param.ldsa, _param.K, - _param.reduceA + M_offset * _param.ldsa, _param.reduceB + N_offset); - } - return ret; - } -}; - -template -class AlphaBetaProcessS32U8 { - public: - struct Param { - uint8_t* C; - int ldc; - float alpha; - float scaleAcc, scaleC; - int zpC; - }; - - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - return kernel::wrapper::QuanOutS32U32::template forward(_param.alpha, cacheptr, cachestep, cptr, _param.ldc, - M, N, _param.scaleAcc, _param.scaleC, _param.zpC); - } -}; - -} // namespace gemm -} // namespace epilogue -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h deleted file mode 100644 index 364da9223940f..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h +++ /dev/null @@ -1,2699 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include "jit_blas_utils.h" -#include "jit_base.h" - -namespace jblas { -namespace gemm { -enum class CompType : uint32_t { - COMP_FP32 = 0, - COMP_BF16_FP32 = 1, - COMP_FP16_FP16 = 2, - COMP_INT_START = 3, - COMP_INT8_US_INT32 = COMP_INT_START, - COMP_INT8_UU_INT32 = 4, - COMP_INT8_SS_INT32 = 5, - COMP_INT8_SU_INT32 = 6, - COMP_INT16_SS_INT32 = 7, - COMP_INT8_US_FP32 = 8, - COMP_INT8_UU_FP32 = 9, - COMP_INT8_SS_FP32 = 10, - COMP_INT8_SU_FP32 = 11, -}; - -class CoreAttr { - public: - // INT32=LSB|**8bits:NTile**||**8bits:PackRow**||**8bits:CompType**||**8bits:Reserve**| - static uint32_t constexpr NTILE_MASK = 0xff, NTILE_SHIFT = 0, PACKROW_MASK = 0xff00, PACKROW_SHIFT = 8, - COMP_MASK = 0xff0000, COMP_SHIFT = 16, ISA_MASK = 0xff000000, ISA_SHIFT = 24; - - static inline uint32_t get_mask_val(uint32_t raw, uint32_t mask, uint32_t shift) { return (raw & mask) >> shift; } - static constexpr uint32_t make_core_id(uint32_t NTile, uint32_t PackRow, uint32_t CompType, uint32_t ISA) { - return (NTile << NTILE_SHIFT) | (PackRow << PACKROW_SHIFT) | (CompType << COMP_SHIFT) | (ISA << ISA_SHIFT); - } - - static void parse_id(uint32_t id, uint32_t* vals) { - vals[0] = get_mask_val(id, NTILE_MASK, NTILE_SHIFT); - vals[1] = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); - vals[2] = get_mask_val(id, COMP_MASK, COMP_SHIFT); - vals[3] = get_mask_val(id, ISA_MASK, ISA_SHIFT); - } - - static const char* to_str(uint32_t id) { - static char tmp[128]; - uint32_t vals[4]; - parse_id(id, vals); - sprintf(tmp, "N%d_PACK%d_COMP%d_ISA%d", vals[0], vals[1], vals[2], vals[3]); - return tmp; - } - - static inline size_t get_bsize(uint32_t id) { - auto packrow = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); - return size_t(4 / packrow); - } -}; - -namespace code { - -template -class Avx2N8P1 : protected jblas::xbyak::JitAvx2 { - public: - static int constexpr RegLen = 8, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX2; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; - typedef float AType; - typedef float BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { - public: - static int constexpr RegLen = 16, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; - typedef float AType; - typedef float BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512fp16N32P1 : protected jblas::xbyak::JitAvx512_fp16 { - public: - static int constexpr RegLen = 32, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_FP16; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP16_FP16; - typedef utils::fp16 AType; - typedef utils::fp16 BType; - typedef utils::fp16 CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastw(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vpbroadcastw(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512bf16N16P2 : protected jblas::xbyak::JitAvx512_bf16 { - public: - static int constexpr RegLen = 16, PackRow = 2; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 2; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_BF16; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; - typedef utils::bf16 AType; - typedef utils::bf16 BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { - public: - static int constexpr RegLen = 16, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; - typedef uint8_t AType; - typedef int8_t BType; - typedef int32_t CType; - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - - protected: - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _kunroll) { - for (int kk = 0; kk < _kunroll; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { - public: - static int constexpr RegLen = 8, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX_VNNI; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; - typedef uint8_t AType; - typedef int8_t BType; - typedef int32_t CType; - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - protected: - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _kunroll) { - for (int kk = 0; kk < _kunroll; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Amxbf16N16P2 : protected jblas::xbyak::JitAmxbf16 { - public: - static int constexpr RegLen = 16, PackRow = 2; - static_assert(_NTILE % RegLen == 0); - static_assert(_MTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen; - static_assert(NRegs * MRegs + 2 <= TileCount); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 32; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_BF16; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; - typedef utils::bf16 AType; - typedef utils::bf16 BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - void* workspace; - }; - typedef long long (*func_t)(params*); - - int TmpRegCount = RegCount; - int TmpReg = 0; - int CTileCount = 0, ATileCount = 0, BTileCount = 0; - int CTile = 0, ATile = 0, BTile = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_tmp3; - Xbyak::Reg64 reg_ret = rax; - - void assign_regs() { - CTileCount = NRegs * MRegs; - auto tile_re = TileCount - CTileCount; - if (tile_re - 1 >= NRegs) { - BTileCount = NRegs; - ATileCount = tile_re - BTileCount; - } else if (tile_re - 1 >= MRegs) { - ATileCount = MRegs; - BTileCount = tile_re - ATileCount; - } else { - ATileCount = 1; - BTileCount = tile_re - ATileCount; - } - CTile = 0; - ATile = CTile + CTileCount; - BTile = ATile + ATileCount; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_tmp3 = st.t[10]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int kunrll) { - auto& reg_Bstride = reg_tmp1; - mov(reg_Bstride, NTILE * 4); - int mtiles = _mtile / RegLen; - - for (int kk = 0; kk < kunrll; kk++) { - auto& reg_Atmp = reg_tmp2; - if (mtiles == 1) { - reg_Atmp = reg_matAptr; - } else { - mov(reg_Atmp, reg_matAptr); - } - if (BTileCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - } - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } else { - if (ATileCount == mtiles) { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - for (int mm = 0; mm < mtiles; mm++) { - tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); - } - } - } else { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < CTileCount; i++) { - tilezero(Xbyak::Tmm(CTile + i)); - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - int mtnum = _mtile / 16; - for (int mm = 0; mm < mtnum; mm++) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); - } - if (mm != mtnum - 1) { - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - } - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_tmp, dword[parambase + OFFSET(workspace)]); - mov(reg_tmp1, NTILE * 4); - for (int mm = 0; mm < MRegs; mm++) { - for (int i = 0; i < NRegs; i++) { - tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); - } - } - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - int zunroll = TmpRegCount / NRegs; - for (int i = 0; i < _mtile; i += zunroll) { - int m_re = utils::remainsize(i, _mtile, zunroll); - for (int im = 0; im < m_re; im++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - } - outLocalLabel(); - } -}; - -template -class Amxint8N16P4 : protected jblas::xbyak::JitAmxint8 { - public: - static int constexpr RegLen = 16, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static_assert(_MTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen; - static_assert(NRegs * MRegs + 2 <= TileCount); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 64; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_INT8; - static uint32_t constexpr COMPUTE = - (uint32_t)(std::is_same_v - ? std::is_same_v ? CompType::COMP_INT8_SS_INT32 : CompType::COMP_INT8_SU_INT32 - : std::is_same_v ? CompType::COMP_INT8_US_INT32 - : CompType::COMP_INT8_UU_INT32); - using AType = AT; - using BType = BT; - typedef int32_t CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - void* workspace; - }; - typedef long long (*func_t)(params*); - - int TmpRegCount = RegCount; - int TmpReg = 0; - int CTileCount = 0, ATileCount = 0, BTileCount = 0; - int CTile = 0, ATile = 0, BTile = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_tmp3; - Xbyak::Reg64 reg_ret = rax; - - void assign_regs() { - CTileCount = NRegs * MRegs; - auto tile_re = TileCount - CTileCount; - if (tile_re - 1 >= NRegs) { - BTileCount = NRegs; - ATileCount = tile_re - BTileCount; - } else if (tile_re - 1 >= MRegs) { - ATileCount = MRegs; - BTileCount = tile_re - ATileCount; - } else { - ATileCount = 1; - BTileCount = tile_re - ATileCount; - } - CTile = 0; - ATile = CTile + CTileCount; - BTile = ATile + ATileCount; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_tmp3 = st.t[10]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int kunrll) { - auto& reg_Bstride = reg_tmp1; - mov(reg_Bstride, NTILE * 4); - int mtiles = _mtile / RegLen; - - for (int kk = 0; kk < kunrll; kk++) { - auto& reg_Atmp = reg_tmp2; - if (mtiles == 1) { - reg_Atmp = reg_matAptr; - } else { - mov(reg_Atmp, reg_matAptr); - } - if (BTileCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - } - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } else { - if (ATileCount == mtiles) { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - for (int mm = 0; mm < mtiles; mm++) { - _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); - } - } - } else { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < CTileCount; i++) { - tilezero(Xbyak::Tmm(CTile + i)); - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - int mtnum = _mtile / 16; - for (int mm = 0; mm < mtnum; mm++) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); - } - if (mm != mtnum - 1) { - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - } - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_tmp, dword[parambase + OFFSET(workspace)]); - mov(reg_tmp1, NTILE * 4); - for (int mm = 0; mm < MRegs; mm++) { - for (int i = 0; i < NRegs; i++) { - tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); - } - } - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - int zunroll = TmpRegCount / NRegs; - for (int i = 0; i < _mtile; i += zunroll) { - int m_re = utils::remainsize(i, _mtile, zunroll); - for (int im = 0; im < m_re; im++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - } - outLocalLabel(); - } -}; -template -using Amxint8N16P4US = Amxint8N16P4; - -template -using Amxint8N16P4SS = Amxint8N16P4; - -class AmxConfigure : protected jblas::xbyak::JitAmxtile { - public: - typedef long long (*func_t)(tileconfig_t*); - - static void configure(int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, int CNum) { - static AmxConfigure code; - tileconfig_t cfg; - std::memset(&cfg, 0, sizeof(cfg)); - configure_tiles(cfg, TILE_M, TILE_N, TILE_K, elesize, ANum, BNum, CNum); - code.mKernel(&cfg); - } - - protected: - AmxConfigure() { - generate_config(this); - mKernel = getCode(); - } - - func_t mKernel = nullptr; -}; - -namespace kblock { -// optimize for kblock gemm, each block size in k dimension has dequant operation -// all accumulators use fp32 dtype. -template -class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { - public: - static int constexpr RegLen = 16, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; - typedef float AType; - typedef float BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { - public: - static int constexpr RegLen = 16, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1 - NRegs) / (NRegs * 2) : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_FP32; - typedef uint8_t AType; - typedef int8_t BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - uint8_t* zpA; - float* scaleA; - int ldsa; - float* scaleB; - float* reduceB; - int ldsb; - int k; - int n; - int kblock; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, CF32Reg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_iterkb; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_tmp3; - Xbyak::Reg64 reg_tmp4; - Xbyak::Reg64 reg_ret = rax; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = NRegs; - CReg = 0; - CF32Reg = CReg + CRegCount; - BReg = CF32Reg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg < RegCount); - TmpRegCount = RegCount - TmpReg; - assert(TmpRegCount >= 1); - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 13, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_iterkb = st.t[12]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_tmp3 = st.t[10]; - reg_tmp4 = st.t[11]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - xor_(reg_iterkb, reg_iterkb); - L(".kloop"); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j)); - } - } - xor_(reg_tmp2, reg_tmp2); - load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]); - mov(reg_tmp, reg_tmp3); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kbloop", T_NEAR); - L(".unkbloop"); - generate_fma(_mtile, KUNROLL, reg_tmp1); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_tmp2, KUNROLL * KTILE); - cmp(reg_tmp2, reg_tmp); - jb(".unkbloop"); - cmp(reg_tmp, reg_tmp3); - jge(".kend", T_NEAR); - L(".kbloop"); - generate_fma(_mtile, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_tmp2, 1 * KTILE); - cmp(reg_tmp2, reg_tmp3); - jb(".kbloop"); - L(".kend"); - add(reg_iterk, reg_tmp2); - generate_f32_accumulate(_mtile); - generate_zp_correction(_mtile); - inc(reg_iterkb); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile, Xbyak::Reg64& tmp) { - for (int kk = 0; kk < _ktile; kk++) { - lea(tmp, ptr[reg_matAptr + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CF32Reg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void generate_f32_accumulate(int _mtile) { - load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]); - imul(reg_tmp, reg_iterkb); - mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); - - mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]); - lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(float)]); - load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]); - for (int i = 0; i < NRegs; i++) { - vmovups(Xbyak::Zmm(BReg + i), ptr[reg_tmp2 + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(Xbyak::Zmm(TmpReg), ptr[reg_tmp]); - lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); - for (int i = 0; i < NRegs; i++) { - vcvtdq2ps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); - vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(TmpReg), Xbyak::Zmm(BReg + i)); - vmulps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(AReg)); - vaddps(Xbyak::Zmm(CF32Reg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); - } - } - } - - void generate_zp_correction(int _mtile) { - load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]); - imul(reg_tmp1, reg_iterkb); - mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); - auto& reg_redB = reg_tmp2; - - mov(reg_tmp, ptr[parambase + OFFSET(zpA)]); - lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(AType)]); - auto& reg_zpA = reg_tmp; - - mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]); - lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]); - auto& reg_scaleA = reg_tmp1; - - load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]); - auto& reg_ldsa = reg_tmp3; - for (int i = 0; i < NRegs; i++) { - vmovups(Xbyak::Zmm(BReg + i), ptr[reg_redB + i * VecBytes]); - } - - for (int i = 0; i < _mtile; i++) { - vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); - vpmovzxbd(Xbyak::Zmm(AReg), Xbyak::Xmm(AReg)); - vcvtdq2ps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg)); - vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg), zword_b[reg_scaleA]); - for (int j = 0; j < NRegs; j++) { - vmulps(Xbyak::Zmm(CReg + j), Xbyak::Zmm(AReg), Xbyak::Zmm(BReg + j)); - vsubps(Xbyak::Zmm(CF32Reg + i * NRegs + j), Xbyak::Zmm(CReg + j)); - } - lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); - lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); - } - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CF32Reg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -} // namespace kblock -} // namespace code -template