diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config index 2583e0d1b2ead..b862dec5e1c87 100644 --- a/.pipelines/nuget_config/x64/packages.config +++ b/.pipelines/nuget_config/x64/packages.config @@ -1,6 +1,6 @@  - + diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config index 5ca659941c159..c348dd3e9cdad 100644 --- a/.pipelines/nuget_config/x86/packages.config +++ b/.pipelines/nuget_config/x86/packages.config @@ -1,6 +1,6 @@  - + diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 092afa15df4df..84cc529467b05 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.17.0 +1.18.0 diff --git a/build_arm64x.bat b/build_arm64x.bat index fbcdd373086a9..1ed268ae94a43 100644 --- a/build_arm64x.bat +++ b/build_arm64x.bat @@ -5,7 +5,6 @@ setlocal set PATH=C:\Program Files\Git\usr\bin;%PATH% -set LINK_REPRO_NAME=/mylink.rsp rem Requires a Python install to be available in your PATH python "%~dp0\tools\ci_build\build.py" --arm64 --buildasx --build_dir "%~dp0\build\arm64-x" %* diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index bc96218dac79e..7d7304630c00e 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -79,6 +79,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF) option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF) +option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF) option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." 
OFF) option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) option(onnxruntime_USE_COREML "Build with CoreML support" OFF) @@ -87,7 +88,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) -option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON) +option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) @@ -909,6 +910,10 @@ function(onnxruntime_set_compile_flags target_name) target_compile_definitions(${target_name} PRIVATE USE_CUTLASS) endif() + if(USE_NEURAL_SPEED) + target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED) + endif() + set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON) if (onnxruntime_USE_CUDA) # Suppress a "conversion_function_not_usable" warning in gsl/span @@ -1193,14 +1198,10 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() -set(USE_JBLAS FALSE) -if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD) - if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") - add_compile_definitions(MLAS_JBLAS) - set(USE_JBLAS TRUE) - elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") - add_compile_definitions(MLAS_JBLAS) - set(USE_JBLAS TRUE) +if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD) + include(neural_speed) + if (USE_NEURAL_SPEED) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla) endif() endif() diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 30d8cbf78fb1a..2c7bf9f1c2f5c 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -123,6 +123,11 @@ if (onnxruntime_DISABLE_RTTI) add_compile_options("$<$:/GR->" "$<$:/we4541>") else() add_compile_options("$<$:-fno-rtti>") + if (onnxruntime_USE_WEBNN) + # Avoid unboundTypeError for WebNN EP since unbound type names are illegal with RTTI disabled + # in Embind API, relevant issue: https://github.com/emscripten-core/emscripten/issues/7001 + add_compile_options("$<$:-DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0>") + endif() endif() else() #MSVC RTTI flag /GR is not added to CMAKE_CXX_FLAGS by default. But, anyway VC++2019 treats "/GR" default on. 
diff --git a/cmake/deps.txt b/cmake/deps.txt index ff07803013071..fda27e5e93797 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -54,4 +54,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c -composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 +composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 \ No newline at end of file diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake index dfd9ad120eb98..ae7e6d3801a64 100644 --- a/cmake/external/dml.cmake +++ b/cmake/external/dml.cmake @@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML) set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config) set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config) get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE) - set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.0) + set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.1) # Restore nuget packages, which will pull down the DirectML redist package. add_custom_command( diff --git a/cmake/external/neural_speed.cmake b/cmake/external/neural_speed.cmake new file mode 100644 index 0000000000000..e66e2acfb209a --- /dev/null +++ b/cmake/external/neural_speed.cmake @@ -0,0 +1,18 @@ +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") + set(USE_NEURAL_SPEED TRUE) +elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") + set(USE_NEURAL_SPEED TRUE) +endif() + +if(USE_NEURAL_SPEED) + FetchContent_Declare( + neural_speed + URL https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip + URL_HASH SHA1=65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 + ) + set(BTLA_USE_OPENMP OFF) + FetchContent_MakeAvailable(neural_speed) + if(NOT neural_speed_POPULATED) + FetchContent_Populate(neural_speed) + endif() +endif() diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index c79bb87fd7f5d..403b4b2c4107a 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -112,10 +112,10 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) # This part of code is only for users' convenience. The code couldn't handle all cases. Users always can manually # download protoc from Protobuf's Github release page and pass the local path to the ONNX_CUSTOM_PROTOC_EXECUTABLE # variable. - if (APPLE) + if (CMAKE_HOST_APPLE) # Using CMAKE_CROSSCOMPILING is not recommended for Apple target devices. # https://cmake.org/cmake/help/v3.26/variable/CMAKE_CROSSCOMPILING.html - # To keep it simple, just download and use the universal protoc binary for Apple builds. + # To keep it simple, just download and use the universal protoc binary for all Apple host builds. 
FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal}) FetchContent_Populate(protoc_binary) if(protoc_binary_SOURCE_DIR) @@ -123,7 +123,7 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() - elseif(CMAKE_CROSSCOMPILING) + elseif (CMAKE_CROSSCOMPILING) message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64") diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index b995b27123218..17de2aa4aaea6 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -57,15 +57,6 @@ endif() set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) -function(add_jblas) - add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas) - target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/jblas_gemm.cpp - ) - set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF) -endfunction() - #TODO: set MASM flags properly function(setup_mlas_source_for_windows) @@ -364,19 +355,23 @@ else() ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S + ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/pooling_fp16.cpp ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) @@ -622,10 +617,6 @@ else() target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) endif() -if(USE_JBLAS) - add_jblas() -endif() - foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET}) diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index f60faa4d39116..b81a5c79ac0cc 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -60,6 +60,15 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS) "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc" ) endif() + set(onnxruntime_cpu_neural_speed_srcs + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h" + 
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.h" + ) + if(NOT USE_NEURAL_SPEED) + list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs ${onnxruntime_cpu_neural_speed_srcs}) + endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) @@ -144,6 +153,12 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL) target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical") endif() +if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + if(USE_NEURAL_SPEED) + onnxruntime_add_include_to_target(onnxruntime_providers neural_speed::bestla) + endif() +endif() + if (MSVC) target_compile_options(onnxruntime_providers PRIVATE "/bigobj") # if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 84d1376f99d5e..9887d615c92d7 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -1,10 +1,25 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" - ) + + if (onnxruntime_CUDA_MINIMAL) + file(GLOB onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.cc" + ) + # Remove pch files + list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs + "${ONNXRUNTIME_ROOT}/core/providers/cuda/integer_gemm.cc" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/triton_kernel.h" + ) + else() + file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" + ) + endif() # Remove pch files list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h" @@ -16,11 +31,16 @@ "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) - file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" - "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh" - ) + + if (onnxruntime_CUDA_MINIMAL) + set(onnxruntime_providers_cuda_shared_srcs "") + else() + file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" + "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh" + ) + endif() source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) @@ -156,10 +176,15 @@ endif() add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 
safeint_interface) - if(onnxruntime_CUDNN_HOME) - target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) - target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) + if(onnxruntime_CUDA_MINIMAL) + target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL) + target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + else() + target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + if(onnxruntime_CUDNN_HOME) + target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) + target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) + endif() endif() if (onnxruntime_USE_TRITON_KERNEL) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 858583e64e9df..546d50c1ca2d3 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -268,7 +268,10 @@ else() endif() if (onnxruntime_USE_WEBNN) - set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT") + set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT") + if (onnxruntime_DISABLE_RTTI) + set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " -fno-rtti -DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0") + endif() endif() # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 68a399f8b9671..7fe16f4156ef2 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -65,10 +65,10 @@ static NativeTrainingMethods() DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi)); // TODO: Make this save the pointer, and not copy the whole structure across - api_ = (OrtApi)OrtGetApi(17 /*ORT_API_VERSION*/); + api_ = (OrtApi)OrtGetApi(18 /*ORT_API_VERSION*/); OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi)); - trainingApiPtr = OrtGetTrainingApi(17 /*ORT_API_VERSION*/); + trainingApiPtr = OrtGetTrainingApi(18 /*ORT_API_VERSION*/); if (trainingApiPtr != IntPtr.Zero) { trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi)); diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 45c0e6f822ce9..22e82443167f6 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -3031,6 +3031,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
+unidirectional : int
+Whether every token can only attend to previous tokens. Default value is 0.
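The fragment above documents a new `unidirectional` attribute but does not show which com.microsoft operator it belongs to. As a hedged illustration only, the sketch below attaches such an integer attribute with `onnx.helper`, using the long-standing `com.microsoft.Attention` operator (which also exposes `unidirectional`) as a stand-in; the input names are the usual Attention inputs and are assumptions, not values taken from this diff.

```python
# Hedged sketch: setting an integer attribute such as `unidirectional` on a
# com.microsoft contrib node. "Attention" is a stand-in operator here; the diff
# fragment above does not name the operator whose documentation it updates.
from onnx import helper

attention = helper.make_node(
    "Attention",
    inputs=["input", "weights", "bias", "mask_index"],
    outputs=["output"],
    domain="com.microsoft",
    num_heads=8,
    unidirectional=1,  # tokens may only attend to earlier positions
)
print(attention)
```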
#### Inputs (1 - 8) @@ -5021,6 +5023,10 @@ This version of the operator has been available since version 1 of the 'com.micr
interleaved : int
Rotate using interleaved pattern. Default value is 0 (False).
+num_heads : int
+Number of attention heads. Default value is 0. Must be used with rotary_embedding_dim.
+rotary_embedding_dim : int
+Rotary embedding dimension. Default value is 0.
scale : float
Custom scale will be used if specified. Default value is 1.0
@@ -5033,9 +5039,9 @@ This version of the operator has been available since version 1 of the 'com.micr
position_ids : M
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
cos_cache : T
-2D tensor with shape (max_sequence_length, head_size / 2).
+2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
sin_cache : T
-2D tensor with shape (max_sequence_length, head_size / 2).
+2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
#### Outputs @@ -5048,7 +5054,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-T : tensor(float), tensor(float16)
+T : tensor(float), tensor(float16), tensor(bfloat16)
Constrain input and output types to float tensors.
M : tensor(int64)
Constrain input and output types to integer tensors
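The RotaryEmbedding changes above add `num_heads` and `rotary_embedding_dim` attributes, allow `cos_cache`/`sin_cache` of shape (max_sequence_length, rotary_embedding_dim / 2), and extend the type constraint with bfloat16. A minimal sketch of how a node using the new attributes might be assembled follows; the concrete shapes, dtype, and tensor names are illustrative assumptions, not values from the diff.

```python
# Hedged sketch of a com.microsoft RotaryEmbedding node using the attributes
# documented above. Shapes and dtype are illustrative assumptions.
import numpy as np
from onnx import helper

num_heads, rotary_dim, max_seq_len = 8, 32, 2048

node = helper.make_node(
    "RotaryEmbedding",
    inputs=["input", "position_ids", "cos_cache", "sin_cache"],
    outputs=["output"],
    domain="com.microsoft",
    interleaved=0,
    num_heads=num_heads,              # per the new doc text, used together with rotary_embedding_dim
    rotary_embedding_dim=rotary_dim,
)

# With rotary_embedding_dim set, the caches follow the revised shape
# (max_sequence_length, rotary_embedding_dim / 2); float16 is one of the allowed T types.
cos_cache = np.zeros((max_seq_len, rotary_dim // 2), dtype=np.float16)
sin_cache = np.zeros((max_seq_len, rotary_dim // 2), dtype=np.float16)
```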
diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index bede16204d420..91057d3dfb120 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -293,6 +293,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e export ORTMODULE_MEMORY_OPT_LEVEL=0 ``` +### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT + +- **Feature Area**: *ORTMODULE/Optimizations* +- **Description**: By default, the memory-efficient gradient management is turned off. The gradient after it is computed in ONNX Runtime, will trigger the corresponding parameter's backward function through `PythonOpGrad` operator. This would help release the gradient buffer managed in ONNX Runtime, which originally is released once all backward computation finishes. + + ```bash + export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1 # Enable + export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable + ``` + ### 2.2 Memory Optimization Q: *Want to run a bigger batch size?* diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 394bd7ad2abae..9a2a7ac89bbb3 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -868,7 +868,7 @@ Do not modify directly.* |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -922,10 +922,12 @@ Do not modify directly.* |BitwiseNot|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Cast|*in* input:**T1**
*out* output:**T2**|13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| |Celu|*in* X:**T**
*out* Y:**T**|12+|**T** = tensor(float), tensor(float16)| @@ -952,7 +954,8 @@ Do not modify directly.* |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|||13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -961,7 +964,8 @@ Do not modify directly.* |DynamicQuantizeLinear|*in* x:**T1**
*out* y:**T2**
*out* y_scale:**tensor(float)**
*out* y_zero_point:**T2**|11+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(float), tensor(float16)| |Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| -|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||11+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||7+|**T** = tensor(float), tensor(float16)
**T1** = tensor(bool)| |Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| @@ -1004,7 +1008,8 @@ Do not modify directly.* |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|19+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||14+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1099,7 +1104,8 @@ Do not modify directly.* |||7+|**T** = tensor(float), tensor(float16)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| |QLinearMatMul|*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| +|||13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |||10+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(float), tensor(float16)| |||7+|**T** = tensor(float), tensor(float16)| @@ -1150,7 +1156,8 @@ Do not modify directly.* |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| |||13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| -|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| @@ -1178,7 +1185,8 @@ Do not modify directly.* |SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Shape|*in* data:**T**
*out* shape:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| @@ -1188,7 +1196,8 @@ Do not modify directly.* |||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)| |Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| -|Size|*in* data:**T**
*out* size:**T1**|13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| diff --git a/docs/python/README.rst b/docs/python/README.rst index 32bb3729e01d0..bbc8571fe3f17 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime CreatePreferredAllocators() { return std::vector(); }; + /** + * Get the array of pointers for EPContext nodes + * EP needs to implement this if has the requirement to generate the context cache model. Otherwise leave it. + * Default return an empty vector if not provided by the Execution Provider + */ + virtual const InlinedVector GetEpContextNodes() const { + return InlinedVector(); + } + private: const std::string type_; diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 9416fad5f1448..1370f5c4c5e10 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -16,9 +16,10 @@ #include "core/providers/custom_op_context.h" #include #include +#ifndef USE_CUDA_MINIMAL #include #include - +#endif namespace Ort { namespace Custom { diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 60196d0c80cbb..32a9f06464ace 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -11,6 +11,8 @@ /// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions. /// struct OrtTensorRTProviderOptionsV2 { + OrtTensorRTProviderOptionsV2& operator=(const OrtTensorRTProviderOptionsV2& other); // copy assignment operator + int device_id{0}; // cuda device id. int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream. void* user_compute_stream{nullptr}; // user specified CUDA compute stream. @@ -46,8 +48,26 @@ struct OrtTensorRTProviderOptionsV2 { const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT - int trt_dump_ep_context_model{0}; // Dump EP context node model - int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data - int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute - const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix + + /* + * Please note that there are rules for using following context model related provider options: + * + * 1. In the case of dumping the context model and loading the context model, + * for security reason, TRT EP doesn't allow the "ep_cache_context" node attribute of EP context node to be + * the absolute path or relative path that is outside of context model directory. + * It means engine cache needs to be in the same directory or sub-directory of context model. + * + * 2. In the case of dumping the context model, the engine cache path will be changed to the relative path of context model directory. 
+ * For example: + * If "trt_dump_ep_context_model" is enabled and "trt_engine_cache_enable" is enabled, + * if "trt_ep_context_file_path" is "./context_model_dir", + * - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir" + * - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir" + * + */ + int trt_dump_ep_context_model{0}; // Dump EP context node model + const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path. + int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data + + const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix }; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b321b2b2bac27..101a578ec3e1d 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -38,7 +38,7 @@ * * This value is used by some API functions to behave as this version of the header expects. */ -#define ORT_API_VERSION 17 +#define ORT_API_VERSION 18 #ifdef __cplusplus extern "C" { @@ -3608,6 +3608,14 @@ struct OrtApi { * - "1": Faster preparation time, less optimal graph. * - "2": Longer preparation time, more optimal graph. * - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details. + * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown). + * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options: + * - "0": Default (none). + * - "68" + * - "69" + * - "73" + * - "75" + * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index df79cb6e5b21b..b282438795eb5 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -236,7 +236,7 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = "session.optimized_model_external_initializers_min_size_in_bytes"; -// Enable EP context feature to dump the partitioned graph which include the EP context into Onnx file. +// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file. // The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead. // "0": disable. (default) // "1": enable. @@ -249,4 +249,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p // Flag to specify whether to dump the EP context into the Onnx model. // "0": dump the EP context into separate file, keep the file name in the Onnx model. // "1": dump the EP context into the Onnx model. (default). 
-static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; \ No newline at end of file +static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; + +// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul. +// Option values: +// - "0": Gemm FastMath mode is not enabled. [DEFAULT] +// - "1": Gemm FastMath mode is enabled. +static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 69ccb954e8afe..1c21387b50455 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -7,6 +7,7 @@ import java.lang.reflect.Array; import java.nio.Buffer; import java.util.Arrays; +import java.util.stream.Collectors; /** Describes an {@link OnnxTensor}, including it's size, shape and element type. */ public class TensorInfo implements ValueInfo { @@ -159,6 +160,12 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { /** The shape of the tensor. */ final long[] shape; + /** The names of the unbound dimensions. */ + final String[] dimensionNames; + + /** If there are non-empty dimension names */ + private final boolean hasNames; + /** The Java type of this tensor. */ public final OnnxJavaType type; @@ -177,6 +184,9 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { */ TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) { this.shape = shape; + this.dimensionNames = new String[shape.length]; + Arrays.fill(dimensionNames, ""); + this.hasNames = false; this.type = type; this.onnxType = onnxType; this.numElements = elementCount(shape); @@ -188,10 +198,20 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { *

Called from JNI. * * @param shape The tensor shape. + * @param names The dimension names. * @param typeInt The native type int. */ - TensorInfo(long[] shape, int typeInt) { + TensorInfo(long[] shape, String[] names, int typeInt) { this.shape = shape; + this.dimensionNames = names; + boolean hasNames = false; + for (String s : names) { + if (!s.isEmpty()) { + hasNames = true; + break; + } + } + this.hasNames = hasNames; this.onnxType = OnnxTensorType.mapFromInt(typeInt); this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType); this.numElements = elementCount(shape); @@ -206,15 +226,42 @@ public long[] getShape() { return Arrays.copyOf(shape, shape.length); } + /** + * Get a copy of the tensor's named dimensions. + * + * @return A copof the tensor's named dimensions. + */ + public String[] getDimensionNames() { + return Arrays.copyOf(dimensionNames, dimensionNames.length); + } + @Override public String toString() { - return "TensorInfo(javaType=" - + type.toString() - + ",onnxType=" - + onnxType.toString() - + ",shape=" - + Arrays.toString(shape) - + ")"; + String output = + "TensorInfo(javaType=" + + type.toString() + + ",onnxType=" + + onnxType.toString() + + ",shape=" + + Arrays.toString(shape); + if (hasNames) { + output = + output + + ",dimNames=[" + + Arrays.stream(dimensionNames) + .map( + a -> { + if (a.isEmpty()) { + return "\"\""; + } else { + return a; + } + }) + .collect(Collectors.joining(",")) + + "]"; + } + output = output + ")"; + return output; } /** diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 879ba8a310618..7b26291581395 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -342,7 +342,6 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT if (code != ORT_OK) { return NULL; } - //printf("numDim %d\n",numDim); int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim); code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim)); if (code != ORT_OK) { @@ -358,12 +357,31 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT free(dimensions); dimensions = NULL; + // Create the string array for the names. + const char** dimensionNames = (const char**) malloc(sizeof(char*)*numDim); + if (dimensionNames == NULL) { + throwOrtException(jniEnv, 1, "Not enough memory"); + return NULL; + } + code = checkOrtStatus(jniEnv, api, api->GetSymbolicDimensions(info, dimensionNames, numDim)); + if (code != ORT_OK) { + // extraction failed, exception has been thrown, return to Java. 
+ free(dimensionNames); + return NULL; + } + jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String"); + jobjectArray names = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numDim), stringClazz, NULL); + for (size_t i = 0; i < numDim; i++) { + jobject javaName = (*jniEnv)->NewStringUTF(jniEnv, dimensionNames[i]); + (*jniEnv)->SetObjectArrayElement(jniEnv, names, safecast_size_t_to_jsize(i), javaName); + } + free(dimensionNames); + // Create the TensorInfo object static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo"; jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName); - jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([JI)V"); - //printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor); - jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, onnxTypeInt); + jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([J[Ljava/lang/String;I)V"); + jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, names, onnxTypeInt); return tensorInfo; } diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index f6f9da1829402..7fef2dc784b7b 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -590,6 +590,12 @@ public void testSymbolicDimensionAssignment() throws OrtException { Map infoMap = session.getInputInfo(); TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo(); assertArrayEquals(new long[] {-1, 2}, aInfo.shape); + assertEquals(2, aInfo.dimensionNames.length); + assertEquals("n", aInfo.dimensionNames[0]); + assertEquals("", aInfo.dimensionNames[1]); + TensorInfo bInfo = (TensorInfo) infoMap.get("B").getInfo(); + assertEquals(1, bInfo.dimensionNames.length); + assertEquals("m", bInfo.dimensionNames[0]); } } // Check that when the options are assigned it overrides the symbolic dimension diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts index 96c2361cceabe..40f970ddf02ae 100644 --- a/js/common/lib/version.ts +++ b/js/common/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.17.0'; +export const version = '1.18.0'; diff --git a/js/common/package-lock.json b/js/common/package-lock.json index 84f6dba83fa59..a5ada877b916a 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-common", - "version": "1.17.0", + "version": "1.18.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-common", - "version": "1.17.0", + "version": "1.18.0", "license": "MIT", "devDependencies": { "typedoc": "^0.23.22" diff --git a/js/common/package.json b/js/common/package.json index beab7d29be263..64ab2736adbe3 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -2,7 +2,7 @@ "license": "MIT", "type": "module", "name": "onnxruntime-common", - "version": "1.17.0", + "version": "1.18.0", "repository": { "url": "https://github.com/Microsoft/onnxruntime.git", "type": "git" diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts index 96c2361cceabe..40f970ddf02ae 100644 --- a/js/node/lib/version.ts +++ b/js/node/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.17.0'; +export const version = '1.18.0'; diff --git a/js/node/package-lock.json b/js/node/package-lock.json index 542eebe746d59..2d7c39c86097f 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-node", - "version": "1.17.0", + "version": "1.18.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-node", - "version": "1.17.0", + "version": "1.18.0", "license": "MIT", "os": [ "win32", @@ -27,7 +27,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.17.0", + "version": "1.18.0", "license": "MIT", "devDependencies": { "typedoc": "^0.23.22" diff --git a/js/node/package.json b/js/node/package.json index 8e591d8f46b9d..026840742e29e 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -13,7 +13,7 @@ 3 ] }, - "version": "1.17.0", + "version": "1.18.0", "dependencies": { "onnxruntime-common": "file:../common" }, diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts index 96c2361cceabe..40f970ddf02ae 100644 --- a/js/react_native/lib/version.ts +++ b/js/react_native/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.17.0'; +export const version = '1.18.0'; diff --git a/js/react_native/package.json b/js/react_native/package.json index 39e6cb08bb06a..47324a76fe55f 100644 --- a/js/react_native/package.json +++ b/js/react_native/package.json @@ -36,7 +36,7 @@ "registry": "https://registry.npmjs.org/" }, "source": "lib/index", - "version": "1.17.0", + "version": "1.18.0", "main": "dist/commonjs/index", "homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md", "files": [ diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index ff9be7fbe3a5b..4dca90d7415cf 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -5254,7 +5254,7 @@ onetime@^5.1.0, onetime@^5.1.2: mimic-fn "^2.1.0" "onnxruntime-common@file:../common": - version "1.17.0" + version "1.18.0" open@^6.2.0: version "6.4.0" diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 2f510308d9306..2557971eb4ded 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -52,6 +52,7 @@ Do not modify directly.* | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | | GreaterOrEqual | ai.onnx(12-15,16+) | | +| HardSigmoid | ai.onnx(6+) | | | If | ai.onnx(1-10,11-12,13-18,19+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(17+) | | diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index d9f63fec9c492..31ecffb07e40c 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -31,6 +31,12 @@ export const initializeFlags = (): void => { } if (typeof env.wasm.numThreads !== 'number' || !Number.isInteger(env.wasm.numThreads) || env.wasm.numThreads <= 0) { + // Web: when crossOriginIsolated is false, SharedArrayBuffer is not available so WebAssembly threads will not work. + // Node.js: onnxruntime-web does not support multi-threads in Node.js. + if ((typeof self !== 'undefined' && !self.crossOriginIsolated) || + (typeof process !== 'undefined' && process.versions && process.versions.node)) { + env.wasm.numThreads = 1; + } const numCpuLogicalCores = typeof navigator === 'undefined' ? 
cpus().length : navigator.hardwareConcurrency; env.wasm.numThreads = Math.min(4, Math.ceil((numCpuLogicalCores || 1) / 2)); } diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts index 96c2361cceabe..40f970ddf02ae 100644 --- a/js/web/lib/version.ts +++ b/js/web/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.17.0'; +export const version = '1.18.0'; diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 9d4d5875310b7..68054210e79a7 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -182,6 +182,10 @@ export interface OrtWasmModule extends EmscriptenModule { jsepCreateDownloader: (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; + /** + * [exported from js_internal_api.js] Called when InferenceSession.run started. + */ + jsepOnRunStart: () => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 2956ec1cad4da..afef7042a4280 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -208,7 +208,7 @@ export class WebGpuBackend { Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); - // init queryType, which is necessary for createKernel + // init queryType, which is necessary for InferenceSession.create this.setQueryType(); } @@ -223,8 +223,6 @@ export class WebGpuBackend { if (!this.commandEncoder) { this.commandEncoder = this.device.createCommandEncoder(); - // refresh queryType, as sometimes we only need to enable query for a specific run - this.setQueryType(); if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { this.querySet = this.device.createQuerySet({ type: 'timestamp', @@ -639,6 +637,7 @@ export class WebGpuBackend { return createView(data.buffer, type); }; } + // #endregion writeTimestamp(index: number): void { if (this.queryType !== 'inside-passes') { return; @@ -657,5 +656,7 @@ export class WebGpuBackend { } } } - // #endregion + onRunStart(): void { + this.setQueryType(); + } } diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 90e02da986b8f..cc504093ca0d7 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -82,6 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], + ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]], ['InstanceNormalization', [instanceNorm]], ['LayerNormalization', [layerNorm]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 30754c84413b7..a0d4021516bf7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -100,8 +100,8 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt ${calculateAlpha} ${(() => { if (c != null) { - return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += uniforms.beta * ${ - c.getByOffset('cOffset')};`; + return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += ${ + dataType}(uniforms.beta) * 
${c.getByOffset('cOffset')};`; } return ''; })()} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index a25e7fe4229b4..82311d72e58b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -242,6 +242,26 @@ export const sigmoid = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`)); }; +export interface HardSigmoidAttributes extends AttributeWithCacheKey { + readonly alpha: number; + readonly beta: number; +} + +export const parseHardSigmoidAttributes = (attributes: Record): HardSigmoidAttributes => + createAttributeWithCacheKey(attributes as { + alpha: number; + beta: number; + }); + +export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'HardSigmoid', + a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${ + attributes.beta})))`, + undefined, attributes.cacheKey)); +}; + export const sin = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin')); }; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5821fac3c468f..8768643fa7257 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -488,8 +488,8 @@ export const run = async( } } + wasm.jsepOnRunStart?.(); let errorCode: number; - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index 81508a253ce8b..9b9334c93b78c 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -28,13 +28,34 @@ let initialized = false; let initializing = false; let aborted = false; -const isMultiThreadSupported = (): boolean => { - try { - // If 'SharedArrayBuffer' is not available, WebAssembly threads will not work. - if (typeof SharedArrayBuffer === 'undefined') { - return false; +const isMultiThreadSupported = (numThreads: number): boolean => { + // WebAssembly threads are set to 1 (single thread). + if (numThreads === 1) { + return false; + } + + // If 'SharedArrayBuffer' is not available, WebAssembly threads will not work. + if (typeof SharedArrayBuffer === 'undefined') { + if (typeof self !== 'undefined' && !self.crossOriginIsolated) { + // eslint-disable-next-line no-console + console.warn( + 'env.wasm.numThreads is set to ' + numThreads + + ', but this will not work unless you enable crossOriginIsolated mode. ' + + 'See https://web.dev/cross-origin-isolation-guide/ for more info.'); } + return false; + } + + // onnxruntime-web does not support multi-threads in Node.js. + if (typeof process !== 'undefined' && process.versions && process.versions.node) { + // eslint-disable-next-line no-console + console.warn( + 'env.wasm.numThreads is set to ' + numThreads + + ', however, currently onnxruntime-web does not support multi-threads in Node.js. ' + + 'Please consider using onnxruntime-node for performance critical scenarios.'); + } + try { // Test for transferability of SABs (for browsers. 
needed for Firefox) // https://groups.google.com/forum/#!msg/mozilla.dev.platform/IHkBZlHETpA/dwsMNchWEQAJ if (typeof MessageChannel !== 'undefined') { @@ -106,7 +127,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise const numThreads = flags.numThreads!; const simd = flags.simd!; - const useThreads = numThreads > 1 && isMultiThreadSupported(); + const useThreads = isMultiThreadSupported(numThreads); const useSimd = simd && isSimdSupported(); const wasmPaths = flags.wasmPaths; diff --git a/js/web/lib/wasm/wasm-utils-load-file.ts b/js/web/lib/wasm/wasm-utils-load-file.ts index abe480a43c790..c6cdba2320bde 100644 --- a/js/web/lib/wasm/wasm-utils-load-file.ts +++ b/js/web/lib/wasm/wasm-utils-load-file.ts @@ -47,9 +47,19 @@ export const loadFile = async(file: string|Blob|ArrayBufferLike|Uint8Array): Pro } const reader = response.body.getReader(); - // use WebAssembly Memory to allocate larger ArrayBuffer - const pages = Math.ceil(fileSize / 65536); - const buffer = new WebAssembly.Memory({initial: pages, maximum: pages}).buffer; + let buffer; + try { + // try to create ArrayBuffer directly + buffer = new ArrayBuffer(fileSize); + } catch (e) { + if (e instanceof RangeError) { + // use WebAssembly Memory to allocate larger ArrayBuffer + const pages = Math.ceil(fileSize / 65536); + buffer = new WebAssembly.Memory({initial: pages, maximum: pages}).buffer; + } else { + throw e; + } + } let offset = 0; // eslint-disable-next-line no-constant-condition diff --git a/js/web/package-lock.json b/js/web/package-lock.json index cd71c20ba4d2f..41c44aaa2679b 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-web", - "version": "1.17.0", + "version": "1.18.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-web", - "version": "1.17.0", + "version": "1.18.0", "license": "MIT", "dependencies": { "flatbuffers": "^1.12.0", @@ -28,7 +28,7 @@ "@webgpu/types": "^0.1.38", "base64-js": "^1.5.1", "chai": "^4.3.7", - "electron": "^23.1.2", + "electron": "^28.1.4", "globby": "^13.1.3", "karma": "^6.4.1", "karma-browserstack-launcher": "^1.6.0", @@ -49,7 +49,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.17.0", + "version": "1.18.0", "license": "MIT", "devDependencies": { "typedoc": "^0.23.22" @@ -862,9 +862,9 @@ } }, "node_modules/cross-spawn/node_modules/semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true, "bin": { "semver": "bin/semver" @@ -1042,14 +1042,14 @@ "dev": true }, "node_modules/electron": { - "version": "23.3.13", - "resolved": "https://registry.npmjs.org/electron/-/electron-23.3.13.tgz", - "integrity": "sha512-BaXtHEb+KYKLouUXlUVDa/lj9pj4F5kiE0kwFdJV84Y2EU7euIDgPthfKtchhr5MVHmjtavRMIV/zAwEiSQ9rQ==", + "version": "28.1.4", + "resolved": "https://registry.npmjs.org/electron/-/electron-28.1.4.tgz", + "integrity": "sha512-WE6go611KOhtH6efRPMnVC7FE7DCKnQ3ZyHFeI1DbaCy8OU4UjZ8/CZGcuZmZgRdxSBEHoHdgaJkWRHZzF0FOg==", "dev": true, "hasInstallScript": true, "dependencies": { "@electron/get": "^2.0.0", - "@types/node": "^16.11.26", + "@types/node": "^18.11.18", "extract-zip": "^2.0.1" }, "bin": { 
@@ -1059,12 +1059,6 @@ "node": ">= 12.20.55" } }, - "node_modules/electron/node_modules/@types/node": { - "version": "16.18.14", - "resolved": "https://registry.npmjs.org/@types/node/-/node-16.18.14.tgz", - "integrity": "sha512-wvzClDGQXOCVNU4APPopC2KtMYukaF1MN/W3xAmslx22Z4/IF1/izDMekuyoUlwfnDHYCIZGaj7jMwnJKBTxKw==", - "dev": true - }, "node_modules/emoji-regex": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", @@ -1432,9 +1426,9 @@ } }, "node_modules/get-func-name": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.0.tgz", - "integrity": "sha512-Hm0ixYtaSZ/V7C8FJrtZIuBBI+iSgL+1Aq82zSu8VQNB4S3Gk8e7Qs3VwBDJAhmRZcFqkl3tQu36g/Foh5I5ig==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.2.tgz", + "integrity": "sha512-8vXOvuE167CtIc3OyItco7N/dpRtBbYOsPsXCz7X/PMnlGjYjSGuZJgM1Y7mmew7BKf9BqvLX2tnOVy1BBUsxQ==", "dev": true, "engines": { "node": "*" @@ -1542,9 +1536,9 @@ } }, "node_modules/global-agent/node_modules/semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "optional": true, "dependencies": { @@ -2908,9 +2902,9 @@ "dev": true }, "node_modules/semver": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", "dev": true, "bin": { "semver": "bin/semver.js" @@ -4203,9 +4197,9 @@ }, "dependencies": { "semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true } } @@ -4339,22 +4333,14 @@ "dev": true }, "electron": { - "version": "23.3.13", - "resolved": "https://registry.npmjs.org/electron/-/electron-23.3.13.tgz", - "integrity": "sha512-BaXtHEb+KYKLouUXlUVDa/lj9pj4F5kiE0kwFdJV84Y2EU7euIDgPthfKtchhr5MVHmjtavRMIV/zAwEiSQ9rQ==", + "version": "28.1.4", + "resolved": "https://registry.npmjs.org/electron/-/electron-28.1.4.tgz", + "integrity": "sha512-WE6go611KOhtH6efRPMnVC7FE7DCKnQ3ZyHFeI1DbaCy8OU4UjZ8/CZGcuZmZgRdxSBEHoHdgaJkWRHZzF0FOg==", "dev": true, "requires": { "@electron/get": "^2.0.0", - "@types/node": "^16.11.26", + "@types/node": "^18.11.18", "extract-zip": "^2.0.1" - }, - "dependencies": { - "@types/node": { - "version": "16.18.14", - "resolved": "https://registry.npmjs.org/@types/node/-/node-16.18.14.tgz", - "integrity": "sha512-wvzClDGQXOCVNU4APPopC2KtMYukaF1MN/W3xAmslx22Z4/IF1/izDMekuyoUlwfnDHYCIZGaj7jMwnJKBTxKw==", - "dev": true - } } }, "emoji-regex": { @@ -4657,9 +4643,9 @@ "dev": true }, "get-func-name": { - "version": "2.0.0", - "resolved": 
"https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.0.tgz", - "integrity": "sha512-Hm0ixYtaSZ/V7C8FJrtZIuBBI+iSgL+1Aq82zSu8VQNB4S3Gk8e7Qs3VwBDJAhmRZcFqkl3tQu36g/Foh5I5ig==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.2.tgz", + "integrity": "sha512-8vXOvuE167CtIc3OyItco7N/dpRtBbYOsPsXCz7X/PMnlGjYjSGuZJgM1Y7mmew7BKf9BqvLX2tnOVy1BBUsxQ==", "dev": true }, "get-intrinsic": { @@ -4742,9 +4728,9 @@ }, "dependencies": { "semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "optional": true, "requires": { @@ -5780,9 +5766,9 @@ "dev": true }, "semver": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz", - "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==", + "version": "6.3.1", + "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz", + "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==", "dev": true }, "semver-compare": { diff --git a/js/web/package.json b/js/web/package.json index 7ffc9ba16aaa9..a502c2b6b032d 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -8,7 +8,7 @@ "type": "git" }, "author": "fs-eire", - "version": "1.17.0", + "version": "1.18.0", "jsdelivr": "dist/ort.min.js", "dependencies": { "flatbuffers": "^1.12.0", @@ -47,7 +47,7 @@ "@webgpu/types": "^0.1.38", "base64-js": "^1.5.1", "chai": "^4.3.7", - "electron": "^23.1.2", + "electron": "^28.1.4", "globby": "^13.1.3", "karma": "^6.4.1", "karma-browserstack-launcher": "^1.6.0", diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 033b3b3f4b0f5..373b3c645df57 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -597,9 +597,9 @@ // // "test_hardmax_example", // // "test_hardmax_negative_axis", // // "test_hardmax_one_hot", - // // "test_hardsigmoid_default", - // // "test_hardsigmoid_example", - // // "test_hardsigmoid", + "test_hardsigmoid_default", + "test_hardsigmoid_example", + "test_hardsigmoid", // // "test_hardswish_expanded", // // "test_hardswish", "test_if", diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 57219c50f39aa..c3699f0fb33ad 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -7,7 +7,7 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_ or the `Github project `_. """ -__version__ = "1.17.0" +__version__ = "1.18.0" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 4711ccf487cc8..768676259aa14 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -211,6 +211,12 @@ Status Attention::Compute(OpKernelContext* context) const { relative_position_bias, ¶meters)); + if (parameters.do_rotary) { + ORT_NOT_IMPLEMENTED( + "Rotary embedding is not supported in Attention CPU kernel. 
\ + Please fuse the model with MHA + RotaryEmbedding."); + } + const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int input_hidden_size = parameters.input_hidden_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 694c40bf3eda6..eb25d0fd7cc1e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -40,6 +40,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i num_heads_ = static_cast(num_heads); mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; } // Reshape Q/K/V from BxSxD to BxSxNxH @@ -283,8 +284,9 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { nullptr, ¶meters, num_heads_, - scale, mask_filter_value_, + scale, + is_unidirectional_, past_present_share_buffer, false)); diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h index 4c86b777e9842..fb7da78a5c0a5 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h @@ -18,6 +18,7 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase { protected: int num_heads_; // number of attention heads float mask_filter_value_; + bool is_unidirectional_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 00e82c9844b3d..c91f5b601b4e9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -25,6 +25,7 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, float scale, + bool is_unidirectional, bool past_present_share_buffer, bool dmmha_packing) { // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None @@ -315,7 +316,7 @@ Status CheckInputs(const T* query, output_parameters->head_size = hidden_size / num_heads; output_parameters->v_head_size = v_hidden_size / num_heads; output_parameters->num_heads = num_heads; - output_parameters->is_unidirectional = false; + output_parameters->is_unidirectional = is_unidirectional; output_parameters->past_present_share_buffer = past_present_share_buffer; output_parameters->mask_filter_value = mask_filter_value; output_parameters->mask_type = mask_type; @@ -342,6 +343,7 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, float scale, + bool is_unidirectional, bool past_present_share_buffer, bool dmmha_packing, int max_threads_per_block) { @@ -350,8 +352,8 @@ Status CheckInputs(const T* query, } return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, - past_seq_len, parameters, num_heads, mask_filter_value, scale, past_present_share_buffer, - dmmha_packing); + past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional, + past_present_share_buffer, dmmha_packing); } } // namespace multihead_attention_helper diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index 47f462d75fcc4..aa8b5b5f608fa 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ 
b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -27,7 +27,13 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { scale = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); + + if (rotary_embedding_dim > 0) { + ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified"); + } } template @@ -42,6 +48,8 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { position_ids, cos_cache, sin_cache, + num_heads, + rotary_embedding_dim, ¶meters)); Tensor* output = context->Output(0, input->Shape()); @@ -59,61 +67,66 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int num_heads = parameters.num_heads; + const int n_heads = parameters.num_heads; const int head_size = parameters.head_size; const int position_ids_format = parameters.position_ids_format; - const int half_head_size = head_size / 2; + const int rotary_emb_dim = parameters.rotary_embedding_dim; + const int half_rotary_emb_dim = rotary_emb_dim / 2; + // Default input tensor shape is [batch, seq_len, hidden_size] int head_stride = head_size; - int seq_stride = num_heads * head_stride; + int seq_stride = n_heads * head_stride; int batch_stride = sequence_length * seq_stride; if (parameters.transposed) { - // Transposed input tensor shape is [batch, num_heads, seq_len, head_size] + // Transposed input tensor shape is [batch, n_heads, seq_len, head_size] seq_stride = head_size; head_stride = sequence_length * seq_stride; - batch_stride = num_heads * head_stride; + batch_stride = n_heads * head_stride; } AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); auto* tp = context->GetOperatorThreadPool(); - const int loop_len = batch_size * sequence_length * num_heads; - const double cost = static_cast(head_size); + const int loop_len = batch_size * sequence_length * n_heads; + const double cost = static_cast(rotary_emb_dim); ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { - const int b = static_cast((ptr / num_heads) / sequence_length); - const int s = static_cast((ptr / num_heads) % sequence_length); - const int n = static_cast(ptr % num_heads); + const int b = static_cast((ptr / n_heads) / sequence_length); + const int s = static_cast((ptr / n_heads) % sequence_length); + const int n = static_cast(ptr % n_heads); const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; const T* input_data = input_src + block_offset; T* output_data = output_dest + block_offset; - // Cache is (M, H/2) + // Cache is (M, H/2) or (M, rotary_embedding_dim/2) const int position_id = (position_ids_format == 0) ? 
static_cast(pos_ids_data[0]) + s : static_cast(pos_ids_data[b * sequence_length + s]); - const int cache_offset = position_id * half_head_size; + const int cache_offset = position_id * half_rotary_emb_dim; const T* cos_data = cos_cache_data + cache_offset; const T* sin_data = sin_cache_data + cache_offset; int cache_idx = 0; T sign = 0; int j = 0; - for (int i = 0; i < head_size; i++) { + for (int i = 0; i < rotary_emb_dim; i++) { if (interleaved) { - cache_idx = (i / 2) % half_head_size; + cache_idx = (i / 2) % half_rotary_emb_dim; sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign } else { - cache_idx = i % half_head_size; - sign = (i < half_head_size) ? static_cast(-1) : static_cast(1); - j = (i + half_head_size) % head_size; + cache_idx = i % half_rotary_emb_dim; + sign = (i < half_rotary_emb_dim) ? static_cast(-1) : static_cast(1); + j = (i + half_rotary_emb_dim) % rotary_emb_dim; } output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; } + for (int i = rotary_emb_dim; i < head_size; i++) { + output_data[i] = input_data[i]; + } } }); diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h index be834a66cdc69..4e32424a22b6c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h @@ -16,6 +16,8 @@ class RotaryEmbedding final : public OpKernel { protected: float scale; + int num_heads; + int rotary_embedding_dim; bool interleaved; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h index 7b2e8289f7b06..dcbb36d1c4a3c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -11,14 +11,15 @@ namespace rotary_embedding_helper { // Parameters deduced from node attributes and inputs/outputs. struct RotaryParameters { - int batch_size; // Batch size used by input - int sequence_length; // Sequence length used by input - int hidden_size; // Hidden size used by input - int head_size; // Head size used by cos/sin cache * 2 - int num_heads; // num_heads = hidden_size / head_size - int max_sequence_length; // Sequence length used by cos/sin cache - int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) - bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) + int batch_size; // Batch size used by input + int sequence_length; // Sequence length used by input + int hidden_size; // Hidden size used by input + int head_size; // Head size + int rotary_embedding_dim; // Rotary embedding dimension. 
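The rotation loop above computes out[i] = x[i]*cos[c] + sign*x[j]*sin[c] for the first rotary_embedding_dim elements of each head and copies the tail [rotary_embedding_dim, head_size) through unchanged. A standalone sketch of the non-interleaved case for a single head at one position (illustrative names, no bounds checking):

// cos_data/sin_data point at the cache row for this position (length rotary_dim / 2).
void RotateOneHead(const float* x, float* out, int head_size, int rotary_dim,
                   const float* cos_data, const float* sin_data) {
  const int half = rotary_dim / 2;
  for (int i = 0; i < rotary_dim; ++i) {
    const int cache_idx = i % half;
    const float sign = (i < half) ? -1.0f : 1.0f;
    const int j = (i + half) % rotary_dim;  // the partner element of the rotation pair
    out[i] = x[i] * cos_data[cache_idx] + sign * x[j] * sin_data[cache_idx];
  }
  for (int i = rotary_dim; i < head_size; ++i) {
    out[i] = x[i];  // unrotated remainder when rotary_dim < head_size
  }
}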
+ int num_heads; // num_heads = hidden_size / head_size + int max_sequence_length; // Sequence length used by cos/sin cache + int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) + bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) }; template @@ -26,11 +27,13 @@ Status CheckInputs(const T* input, const T* position_ids, const T* cos_cache, const T* sin_cache, + int num_heads, + int rotary_embedding_dim, void* parameters) { // input : (batch_size, sequence_length, hidden_size) // position ids : (1) or (batch_size, sequence_length) - // cos cache : (max_sequence_length, head_size / 2) - // sin cache : (max_sequence_length, head_size / 2) + // cos cache : (max_sequence_length, rotary_embedding_dim / 2) + // sin cache : (max_sequence_length, rotary_embedding_dim / 2) // Check input const auto& input_dims = input->Shape().GetDims(); @@ -60,6 +63,12 @@ Status CheckInputs(const T* input, "the same shape"); } + // Check num_heads and rotary_embedding_dim + if (rotary_embedding_dim > 0 && num_heads == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ", + "specified"); + } + // Get attributes from inputs int batch_size = static_cast(input_dims[0]); int sequence_length = static_cast(input_dims[1]); @@ -73,8 +82,13 @@ Status CheckInputs(const T* input, transposed = true; } int max_sequence_length = static_cast(cos_cache_dims[0]); - int head_size = static_cast(cos_cache_dims[1]) * 2; - int num_heads = hidden_size / head_size; + int head_size = rotary_embedding_dim == 0 ? static_cast(cos_cache_dims[1]) * 2 + : static_cast(hidden_size / num_heads); + if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ", + "head_size"); + } + int position_ids_format = -1; // Check position_ids input shapes @@ -91,23 +105,15 @@ Status CheckInputs(const T* input, } else { position_ids_format = 0; } + // Check cos_cache input shapes if (max_sequence_length != static_cast(cos_cache_dims[0])) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", "max_sequence_length, got ", cos_cache_dims[0]); } - if ((head_size / 2) != static_cast(cos_cache_dims[1])) { + if ((head_size / 2) != static_cast(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast(cos_cache_dims[1]))) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", - "head_size / 2, got ", cos_cache_dims[1]); - } - // Check sin_cache input shapes - if (max_sequence_length != static_cast(sin_cache_dims[0])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ", - "max_sequence_length, got ", sin_cache_dims[0]); - } - if ((head_size / 2) != static_cast(sin_cache_dims[1])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ", - "head_size / 2, got ", sin_cache_dims[1]); + "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]); } // Set rotary parameters @@ -117,10 +123,11 @@ Status CheckInputs(const T* input, output_parameters->sequence_length = sequence_length; output_parameters->hidden_size = hidden_size; output_parameters->head_size = head_size; - output_parameters->num_heads = num_heads; + output_parameters->num_heads = 
num_heads > 0 ? num_heads : static_cast(hidden_size / head_size); output_parameters->max_sequence_length = max_sequence_length; output_parameters->position_ids_format = position_ids_format; output_parameters->transposed = transposed; + output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size; } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 406c73c95d444..72948c74d7877 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -9,6 +9,9 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" +#ifdef ORT_NEURAL_SPEED +#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" +#endif namespace onnxruntime { namespace contrib { @@ -24,15 +27,17 @@ class MatMulNBits final : public OpKernel { accuracy_level_{info.GetAttr("accuracy_level")} { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); - is_asym_ = info.GetInputCount() >= 4; +#ifdef ORT_NEURAL_SPEED const Tensor* tensor_B = nullptr; const Tensor* tensor_scale = nullptr; const Tensor* tensor_zero_point = nullptr; bool B_constant = info.TryGetConstantInput(1, &tensor_B); bool scale_constant = info.TryGetConstantInput(2, &tensor_scale); bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point); + is_asym_ = info.GetInputCount() >= 4; all_constant_ = B_constant && scale_constant; all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_; +#endif } Status Compute(OpKernelContext* context) const override; @@ -53,30 +58,34 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; size_t packed_b_size_{0}; +#ifdef ORT_NEURAL_SPEED bool is_asym_{false}; bool all_constant_{false}; +#endif }; Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; +#ifdef ORT_NEURAL_SPEED if (!all_constant_) { return Status::OK(); } - -#if defined(MLAS_JBLAS) - - auto compt_type = static_cast(accuracy_level_); MLAS_THREADPOOL* pool = NULL; + if (nbits_ != 4) { + return Status::OK(); + } + auto comp_type = static_cast(accuracy_level_); + auto nbits = static_cast(nbits_); if (input_idx == 1) { - packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast(nbits_), is_asym_, compt_type); + packed_b_size_ = NSNBitsGemmPackBSize(N_, K_, block_size_, nbits, is_asym_, comp_type); if (packed_b_size_ == 0) return Status::OK(); auto qptr = tensor.Data(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); std::memset(packed_b_.get(), 0, packed_b_size_); - MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, false, compt_type, pool); + NSNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, false, + comp_type, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -85,8 +94,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } if (input_idx == 2 && packed_b_ != nullptr) { auto sptr = tensor.Data(); - 
MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, !is_asym_, compt_type, pool); + NSNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, !is_asym_, + comp_type, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -95,8 +104,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } if (input_idx == 3 && packed_b_ != nullptr) { auto zptr = tensor.Data(); - MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast(nbits_), - is_asym_, is_asym_, compt_type, pool); + NSNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, nbits, is_asym_, is_asym_, + comp_type, pool); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -104,7 +113,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } -#else // defined(MLAS_JBLAS) +#else // defined(ORT_NEURAL_SPEED) if (input_idx == 1) { packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_); @@ -119,7 +128,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } -#endif // defined(MLAS_JBLAS) +#endif // defined(ORT_NEURAL_SPEED) return Status::OK(); } @@ -127,9 +136,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; - -#if defined(MLAS_JBLAS) - +#ifdef ORT_NEURAL_SPEED // Pack three tensors into one buffer if (input_idx == 1) { used_shared_buffers = true; @@ -144,14 +151,14 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep packed_b_ = std::move(prepacked_buffers[0]); } -#else // defined(MLAS_JBLAS) +#else // defined(ORT_NEURAL_SPEED) if (input_idx == 1) { used_shared_buffers = true; packed_b_ = std::move(prepacked_buffers[0]); } -#endif // defined(MLAS_JBLAS) +#endif // defined(ORT_NEURAL_SPEED) return Status::OK(); } @@ -160,9 +167,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); - -#if defined(MLAS_JBLAS) - +#ifdef ORT_NEURAL_SPEED if (packed_b_.get()) { TensorShape b_shape({static_cast(N_), static_cast(K_)}); @@ -181,7 +186,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t N = static_cast(helper.N()); const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - std::vector gemm_params(max_len); + std::vector gemm_params(max_len); AllocatorPtr allocator; auto status = ctx->GetTempSpaceAllocator(&allocator); ORT_RETURN_IF_ERROR(status); @@ -192,15 +197,14 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { gemm_params[i].C = y_data + helper.OutputOffsets()[i]; gemm_params[i].ldc = N; } - auto ws_size = MlasSQNBitsGemmBatchPackedBWorkspaceSize(M, N, K, max_len, gemm_params.data()); + auto ws_size = NSSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); // workspace for activation process(dynamic quantization and others) auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size); - MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), - thread_pool); + 
NSSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), thread_pool); return Status::OK(); } -#endif // defined(MLAS_JBLAS) +#endif // defined(ORT_NEURAL_SPEED) const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h new file mode 100644 index 0000000000000..864abffd131fe --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h @@ -0,0 +1,45 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +--*/ + +#pragma once + +#include "contrib_ops/cpu/quantization/neural_speed_wrapper.h" + +namespace bestla { + +using tAVX512F = gemm::SCoreRowNAvx512f<48, 8>; +using tAMX_BF16 = gemm::HCoreRowNAmxbf16<64, 16>; +using tAVX512_FP16 = gemm::HCoreRowNAvx512fp16<96, 8>; +using tAVX_VNNI = gemm::ICoreRowNAvxvnni<24, 4>; +using tAVX512_VNNI = gemm::ICoreRowNAvx512vnni<48, 8>; +using tAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>; +using tAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>; +using tAVX2 = gemm::SCoreRowNAvx2<24, 4>; +using tAVX_VNNI_KBlock = gemm::ICoreRowNAvxvnniKBlock<24, 2>; +using tAVX512_VNNI_KBlock = gemm::ICoreRowNAvx512vnniKBlock<48, 4>; +using tAMX_INT8_US_KBlock = gemm::ICoreRowNAmxint8KBlock<48, 16>; +using tAMX_INT8_SS_KBlock = gemm::ICoreRowNAmxint8SSKBlock<48, 16>; + +template +using tWeiNInt = prologue_b::gemm::WeightKBlockNInteger; +template +using tWeiNFloat = prologue_b::gemm::WeightKBlockNFloat; + +class ORTThreading : public parallel::IThreading { + public: + explicit ORTThreading(void* tp); + void parallel_for(const parallel::thread_func& func) const override; + void set_threads(int nthreads) override { + (void)(nthreads); + assert(0); + } + void sync() const override { assert(0); } + void* mTp; +}; + +} // namespace bestla diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc new file mode 100644 index 0000000000000..73aaa4ae61a6e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc @@ -0,0 +1,438 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + neural_speed_gemm.cpp + +Abstract: + + GEMM template combinations of neural_speed. 
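For orientation, MatMulNBits::PrePack above drives this module in three steps: size the packed blob once, then fold the quantized data, scales, and (optionally) zero points into it, passing last_call on the final tensor so the pack epilogue can run. A condensed sketch under the assumption of 4-bit weights, with illustrative buffer names and no error handling:

#include <cstdint>
#include <vector>
#include "contrib_ops/cpu/quantization/neural_speed_gemm.h"

// Condensed sketch of the PrePack sequence; q_data/scales/zero_points are the
// constant initializer buffers, thread_pool may be null.
std::vector<int8_t> PackNBitsWeight(size_t N, size_t K, size_t block_size, bool is_asym,
                                    NS_SQNBIT_COMPUTE_TYPE comp_type,
                                    const uint8_t* q_data, const float* scales,
                                    const uint8_t* zero_points, void* thread_pool) {
  size_t packed_size = NSNBitsGemmPackBSize(N, K, block_size, /*nbits*/ 4, is_asym, comp_type);
  std::vector<int8_t> packed(packed_size, 0);
  if (packed_size == 0) return packed;  // shape/compute type not handled by neural_speed
  // quantized weight data (input 1)
  NSNBitsGemmPackB(packed.data(), q_data, nullptr, nullptr, N, K, /*ldb*/ K, block_size,
                   4, is_asym, /*last_call*/ false, comp_type, thread_pool);
  // scales (input 2); this is the last tensor when the weight is symmetric
  NSNBitsGemmPackB(packed.data(), nullptr, scales, nullptr, N, K, K, block_size,
                   4, is_asym, /*last_call*/ !is_asym, comp_type, thread_pool);
  if (is_asym) {
    // zero points (input 3) complete the blob
    NSNBitsGemmPackB(packed.data(), nullptr, nullptr, zero_points, N, K, K, block_size,
                     4, is_asym, /*last_call*/ true, comp_type, thread_pool);
  }
  return packed;
}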
+--*/ + +#include "contrib_ops/cpu/quantization/neural_speed_defs.h" +#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" +#include "core/platform/threadpool.h" + +using ThreadPool = onnxruntime::concurrency::ThreadPool; + +namespace bestla { + +ORTThreading::ORTThreading(void* tp) + : IThreading(ThreadPool::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) {} + +void ORTThreading::parallel_for(const parallel::thread_func& func) const { + ThreadPool::TrySimpleParallelFor(reinterpret_cast(mTp), mThreadNum, + [&](ptrdiff_t tid) { func(static_cast(tid)); }); +} + +template +static void NSSQ4GemmCompF32(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, + parallel::IThreading* th) { + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); + if (M <= 16) { + using Parallel = parallel::gemm::SchedulerKBlock; + using Launcher = + wrapper::gemm::LauncherKBlock; + static Launcher kernel; + auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); + if (B->IsAsym()) { + reduceA.assign(WorkSpace); + ORTThreading single(nullptr); + kernel.mProA.reduce({A, lda_, &reduceA}, M_, K_, B->mBlockSize, &single); + } + typename Launcher::Param args{gp, + {A, lda_, &reduceA}, + {B}, + {B->template SPtr(), B->SDtype(), B->CStep(), B->template ZPtr(), + reduceA.template RPtr(), reduceA.lda}, + {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); + } else { + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + static Launcher kernel; + typename Launcher::Param args{gp, {A, lda_}, {B}, {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); + } +} + +template +static void NSSQ4GemmCompInt8(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, + parallel::IThreading* th) { + using Parallel = parallel::gemm::SchedulerKBlockS; + using Launcher = + wrapper::gemm::LauncherIntKBlock; + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + static Launcher kernel; + auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->IsAsym()); + quanA.assign(WorkSpace); + if (M <= 16) { + ORTThreading single(nullptr); + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); + } else { + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); + } + utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); + typename Launcher::Param args{gp, {A, lda_, &quanA}, {B}, {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); +} + +template +static size_t NSSQ4GemmCompF32WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { + auto M_ = static_cast(M); + auto K_ = static_cast(K); + (void)(A); + (void)(N); + (void)(C); + (void)(lda); + (void)(ldc); + if (M <= 16) { + using ProA = prologue_a::gemm::ActivationKBlockBaseF32; + static ProA proA; + if (B->IsAsym()) { + auto reduceA = proA.createStorage(M_, K_, B->mBlockSize); + return reduceA.mSize; + } + return 0; + } else { + // using ProA = prologue_a::gemm::ActivationBase; + return 0; + } +} + +template +static size_t NSSQ4GemmCompInt8WorkspaceSize(size_t M, size_t N, size_t K, const float* A, 
size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { + (void)(N); + (void)(lda); + (void)(ldc); + (void)(A); + (void)(C); + using ProA = prologue_a::gemm::ActivationF32KBlockQuantize; + static ProA proA; + auto quanA = + proA.createStorage(static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->IsAsym()); + return quanA.mSize; +} + +} // namespace bestla + +using namespace bestla; + +static bool NSSQ4GemmBatchDriver(size_t M, size_t N, size_t K, size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace, + void* ThreadPool) { + GetCPUDevice(); + bestla::ORTThreading orth(ThreadPool); + bool processed = true; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = bestla::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto kptr = reinterpret_cast(ptr); + auto BlkSize = kptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == bestla::tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); + } else if (NTile == bestla::tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, + DataParams[i].ldc, WorkSpace, &orth); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == bestla::tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && + BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, + &orth); + } else if (NTile == bestla::tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, + &orth); + } else if (NTile == bestla::tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && + BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); + } + } + } + } else { + processed = false; + break; + } + } + return processed; +} + +static size_t NSSQ4GemmBatchWorkspaceSize(size_t M, size_t N, size_t K, size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { + GetCPUDevice(); + size_t size = 0; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto kptr = reinterpret_cast(ptr); + auto NTile = + gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + auto BlkSize = kptr->mBlockSize; + if (btype == 
gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc), + size); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } + } + } + } + } + return size; +} + +template +static size_t NSQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) { + static T proB; + auto stor = proB.createStorage(static_cast(N), static_cast(K), static_cast(block_size), + BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, isAsym); + // TODO(Yu) support more scale dtype + return stor.mSize; +} + +static bool NSQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); + auto uptr = std::unique_ptr(ptr); + ORTThreading orth(ThreadPool); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto ldb_ = static_cast(ldb); + GetCPUDevice(); + if (ptr) { + auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto wptr = reinterpret_cast(ptr); + auto BlkSize = wptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, 
FpData, ldb_, &orth); + } + } + } + return true; + } + return false; +} + +template +static void NSQ4GemmPackBImpl(void* PackedBuf, size_t BlkSize, const uint8_t* QData, const float* Scale, + const uint8_t* Zp, size_t N, size_t K, bool IsAsym, bool lastCall, size_t ldb, + void* ThreadPool) { + static T proB; + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto stor = proB.createStorage(N_, K_, static_cast(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, + BTLA_DTYPE::BF16, IsAsym); + stor.assign(reinterpret_cast(PackedBuf)); + ORTThreading orth(ThreadPool); + proB.packNbitsWeightQ4(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); + if (lastCall) { + proB.reduceWeight(&stor, &orth); + } +} + +static size_t NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, NS_SQNBIT_COMPUTE_TYPE CompType) { + GetCPUDevice(); + if (K % BlkSize != 0) { + return 0; + } + // from low precision to high precision + switch (CompType) { + case NSCompInt8: + if (!isAsym) { // asym int8 is not optimized, so fall through to others. + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + } + [[fallthrough]]; + case NSCompBf16: + case NSCompFp16: + case NSCompFp32: + case NSCompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + [[fallthrough]]; + default: + return 0; + } +} + +static bool NSQ4GemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, + size_t K, size_t ldb, size_t BlkSize, bool isAsym, bool lastCall, + NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { + GetCPUDevice(); + // explicit statement fall through. + switch (CompType) { + case NSCompInt8: + if (!isAsym) { // asym int8 is not optimized, so fall through to others. 
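Each dispatch branch above and below pairs a CPUID check with a block-size divisibility check against the chosen core's KTILE; a core is eligible only when both hold. The same rule, condensed into an illustrative helper that is not part of the sources:

// A core template is usable only if the ISA is present and the quantization
// block size is a whole number of the core's K tiles.
template <typename CoreT>
static bool NSCoreEligible(bool isa_supported, size_t blk_size) {
  return isa_supported && (blk_size % CoreT::KTILE == 0);
}
// e.g. NSCoreEligible<tAVX512_VNNI_KBlock>(_cd->AVX512_VNNI(), BlkSize)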
+ if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, + K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + } + [[fallthrough]]; + case NSCompBf16: + case NSCompFp16: + case NSCompFp32: + case NSCompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, + lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, + ldb, ThreadPool); + return true; + } + [[fallthrough]]; + default: + return false; + } +} + +size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, + NS_SQNBIT_COMPUTE_TYPE CompType) { + if (nbits == 4) { + auto jsize = NSQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); + if (jsize) { + return jsize; + } + } + return 0; +} + +void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, + size_t ldb, size_t BlkSize, int nbits, bool isAsym, bool lastCall, + NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { + if (nbits == 4) { + if (NSQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { + return; + } + } +} + +void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + if (NSQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { + return; + } +} + +size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + return NSSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); +} + +void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, + void* ThreadPool) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + if (NSSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { + // PackedWeight is created by bestla + return; + } +} diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h new file mode 100644 index 0000000000000..ebcb3027a209f --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h @@ -0,0 +1,129 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + neural_speed_gemm.h + +Abstract: + + Prepack-weight GEMM APIs of neural_speed. 
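At inference time the caller fills one NS_SQNBITS_GEMM_DATA_PACKED_PARAMS per batch, sizes the shared workspace, then runs the batched GEMM, mirroring MatMulNBits::Compute above. A minimal single-batch sketch with illustrative buffers:

#include <cstdint>
#include <vector>
#include "contrib_ops/cpu/quantization/neural_speed_gemm.h"

// Minimal single-batch sketch: C[M,N] = A[M,K] * packed_B (no error handling).
void RunPackedGemm(size_t M, size_t N, size_t K, const float* A, const void* packed_b,
                   float* C, void* thread_pool) {
  NS_SQNBITS_GEMM_DATA_PACKED_PARAMS params;
  params.A = A;
  params.B = packed_b;  // blob produced by NSNBitsGemmPackB
  params.C = C;
  params.lda = K;
  params.ldc = N;
  const size_t ws_size = NSSQNBitsGemmBatchWorkspaceSize(M, N, K, /*BatchN*/ 1, &params);
  std::vector<int8_t> workspace(ws_size);  // holds dynamically quantized activations etc.
  NSSQNBitsGemmBatchPackedB(M, N, K, /*BatchN*/ 1, &params, workspace.data(), thread_pool);
}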
+--*/ + +#pragma once + +#include +#include + +/** + * @brief Define compute types of block quantization + */ +enum NS_SQNBIT_COMPUTE_TYPE { + NSCompUndef = 0, /*!< undef */ + NSCompFp32 = 1, /*!< input fp32, accumulator fp32 */ + NSCompFp16 = 2, /*!< input fp16, accumulator fp16 */ + NSCompBf16 = 3, /*!< input bf16, accumulator fp32 */ + NSCompInt8 = 4 /*!< input int8, accumulator int32 */ +}; + +/** + * @brief Data parameters for NBits GEMM routine + * C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * All except C are [in] parameters + */ +struct NS_SQNBITS_GEMM_DATA_PACKED_PARAMS { + const float* A = nullptr; /**< address of A (float32 matrix)*/ + const void* B = nullptr; /**< address of B (packed nbits blob)*/ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldc = 0; /**< leading dimension of C*/ +}; + +/** + * @brief Compute the byte size of the parameter combination + * + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @return size of the packing buffer, 0 if the operation is not yet supported. + */ +size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t block_size, int nbits, bool is_asym, + NS_SQNBIT_COMPUTE_TYPE comp_type); + +/** + * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. + * + * @param PackedBuf packed data buffer + * @param QData quantized data buffer + * @param Scale scale pointer + * @param Zp zero point pointer + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization (default 4) + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor + * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where + * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up + * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale + * (is_asym is false) and Zp(is_asym is true). + * @param thread_pool + */ +void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, + size_t ldb, size_t block_size, int nbits, bool is_asym, bool last_call, + NS_SQNBIT_COMPUTE_TYPE comp_type, void* thread_pool); + +/** + * @brief Unpack and dequantize to fp32 + * + * @param FpData unpacked float32 data + * @param PackedBuf quantized and packed data + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. 
+ * @param ldb leading dimension of B + * @param thread_pool + */ +void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* thread_pool); + +/** + * @brief Get the workspace size required by computation. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @return Workspace size in bytes + */ +size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams); + +/** + * @brief Batched GEMM: C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] WorkSpace temporary buffer + * @param[in] ThreadPool + * @return + */ +void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, + void* ThreadPool = nullptr); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h new file mode 100644 index 0000000000000..d3902f9bd68c7 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h @@ -0,0 +1,39 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
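For orientation, here is a minimal sketch of how a caller could drive the prepack APIs declared above for a 4-bit symmetric weight. The buffer names, the ldb value and the error handling are illustrative only; the call signatures follow the declarations in neural_speed_gemm.h.

#include <cstdint>
#include <vector>
#include "contrib_ops/cpu/quantization/neural_speed_gemm.h"  // header added in this change (path shortened)

// Packs a 4-bit symmetric weight once, then runs a single-batch GEMM against it.
void RunPackedQ4Gemm(const float* A, const uint8_t* QData, const float* Scale,
                     size_t M, size_t N, size_t K, size_t BlkSize, void* ThreadPool) {
  // 1. Query the packed-weight size for nbits=4, symmetric, fp32 compute.
  const size_t packed_size =
      NSNBitsGemmPackBSize(N, K, BlkSize, /*nbits*/ 4, /*is_asym*/ false, NSCompFp32);
  if (packed_size == 0) return;  // combination not supported on this CPU

  // 2. Pack QData and Scale into one blob. last_call is true because Scale is the
  //    final tensor handed over for a symmetric weight (see the comment above).
  std::vector<uint8_t> packed(packed_size);
  NSNBitsGemmPackB(packed.data(), QData, Scale, /*Zp*/ nullptr, N, K, /*ldb*/ N, BlkSize,
                   /*nbits*/ 4, /*is_asym*/ false, /*last_call*/ true, NSCompFp32, ThreadPool);

  // 3. Run the batched GEMM (BatchN = 1 here) against the packed weight.
  std::vector<float> C(M * N);
  NS_SQNBITS_GEMM_DATA_PACKED_PARAMS params;
  params.A = A;
  params.B = packed.data();
  params.C = C.data();
  params.lda = K;
  params.ldc = N;

  std::vector<uint8_t> workspace(NSSQNBitsGemmBatchWorkspaceSize(M, N, K, /*BatchN*/ 1, &params));
  NSSQNBitsGemmBatchPackedB(M, N, K, /*BatchN*/ 1, &params, workspace.data(), ThreadPool);
}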
+// +//----------------------------------------------------------------------------- +#pragma once +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wsign-compare" +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-value" +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma GCC diagnostic ignored "-Wuninitialized" +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" +#pragma GCC diagnostic ignored "-Wunused-but-set-parameter" + +#elif defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4457) +#pragma warning(disable : 4189) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4267) +#pragma warning(disable : 4702) +#endif + +#include "bestla/bestla_prologue_a.h" +#include "bestla/bestla_wrapper.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index 56d950ca2f41e..dc72a038c3d58 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -397,12 +397,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch output_sequences_scores); // Output per token scores - if (output_scores) { - gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source = beam_state.scores; - assert(target.size() == source.size()); - ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); - } + gsl::span per_token_scores = beam_state.scores; + this->beam_scorer_->OutputScores(per_token_scores, output_scores); return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 94547887d3a90..cd891a9508019 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -404,12 +404,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches output_sequences_scores); // Output per token scores - if (output_scores) { - gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source = beam_state.scores; - assert(target.size() == source.size()); - ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); - } + gsl::span per_token_scores = beam_state.scores; + this->beam_scorer_->OutputScores(per_token_scores, output_scores); return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 91b93a125ad7a..4d6643c68a98b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -500,12 +500,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe output_sequences_scores); // Output per token scores - if (output_scores) { - gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source 
= beam_state.scores; - assert(target.size() == source.size()); - ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); - } + gsl::span per_token_scores = beam_state.scores; + this->beam_scorer_->OutputScores(per_token_scores, output_scores); return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index 7e2e5b2129221..0eccbe26605f5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -50,11 +50,12 @@ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_length) con return beams_.back().score < current_score; } +template void BeamHypotheses::Output( int top_k, int max_length, - gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) - gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty + gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) + gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty { // Copy the top_k beams into the sequences ORT_ENFORCE(top_k <= beams_used_); @@ -67,7 +68,7 @@ void BeamHypotheses::Output( gsl::copy(item.hypothesis, target); if (!sequences_scores.empty()) - sequences_scores[index] = item.score; + sequences_scores[index] = (T)item.score; } } @@ -181,21 +182,21 @@ void BeamSearchScorer::Process(ISequences& sequences, } } -void BeamSearchScorer::Finalize(ISequences& sequences, - gsl::span& final_beam_scores, - Tensor* output_sequences, - Tensor* output_sequence_scores) { - ORT_ENFORCE(output_sequences != nullptr); - +template +void OutputSequenceScores(BeamSearchScorer* scorer, + ISequences& sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { // Finalize all open beam hypotheses and add to generated hypotheses. - for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) { - BeamHypotheses& beam_hyp = beam_hyps_[batch_index]; + for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) { + BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index]; if (beam_hyp.done_) { continue; } - for (size_t beam_index = 0; beam_index < num_beams_; beam_index++) { - size_t batch_beam_index = batch_index * num_beams_ + beam_index; + for (size_t beam_index = 0; beam_index < scorer->num_beams_; beam_index++) { + size_t batch_beam_index = batch_index * scorer->num_beams_ + beam_index; float final_score = final_beam_scores[batch_beam_index]; auto final_tokens = sequences.GetSequence(narrow(batch_beam_index)); beam_hyp.Add(final_tokens, final_score); @@ -206,26 +207,59 @@ void BeamSearchScorer::Finalize(ISequences& sequences, gsl::span output = output_sequences->MutableDataAsSpan(); // Fill output sequences with pad token ID so that we do not need append it later. - std::fill_n(output.data(), output.size(), pad_token_id_); + std::fill_n(output.data(), output.size(), scorer->pad_token_id_); // Score of each sequence, with shape (batch_size * num_return_sequences). - gsl::span sequence_scores; + gsl::span sequence_scores; if (output_sequence_scores) { - sequence_scores = output_sequence_scores->MutableDataAsSpan(); + sequence_scores = output_sequence_scores->MutableDataAsSpan(); } // Select the best hypotheses according to number of sequences to return. 
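The refactor above follows a simple pattern: inspect the requested element type of the output tensor once, then delegate to a templated helper that performs the element-wise conversion. A standalone illustration of that pattern, with float and double standing in for the float and MLFloat16 cases handled in the real code:

#include <cstddef>
#include <iostream>
#include <vector>

enum class ElementType { kFloat, kDouble };  // stand-in for the tensor's runtime dtype

template <typename T>
void CopyScores(const std::vector<float>& src, void* dst_raw) {
  T* dst = static_cast<T*>(dst_raw);
  for (std::size_t i = 0; i < src.size(); ++i) {
    dst[i] = static_cast<T>(src[i]);  // per-element conversion, as in OutputScores above
  }
}

void OutputScores(const std::vector<float>& final_scores, ElementType type, void* dst) {
  if (type == ElementType::kFloat) {
    CopyScores<float>(final_scores, dst);
  } else {
    CopyScores<double>(final_scores, dst);  // the MLFloat16 output takes this branch in ORT
  }
}

int main() {
  const std::vector<float> scores{0.25f, -1.5f};
  std::vector<double> out(scores.size());
  OutputScores(scores, ElementType::kDouble, out.data());
  std::cout << out[0] << " " << out[1] << "\n";  // prints: 0.25 -1.5
}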
- for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) { - BeamHypotheses& beam_hyp = beam_hyps_[batch_index]; + for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) { + BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index]; - auto batch_output = output.subspan(batch_index * num_return_sequences_ * max_length_, - num_return_sequences_ * max_length_); - gsl::span sequence_scores_buffer; + auto batch_output = output.subspan(batch_index * scorer->num_return_sequences_ * scorer->max_length_, + scorer->num_return_sequences_ * scorer->max_length_); + gsl::span sequence_scores_buffer; if (!sequence_scores.empty()) - sequence_scores_buffer = sequence_scores.subspan(batch_index * num_return_sequences_, num_return_sequences_); + sequence_scores_buffer = sequence_scores.subspan(batch_index * scorer->num_return_sequences_, scorer->num_return_sequences_); + + beam_hyp.template Output(narrow(scorer->num_return_sequences_), narrow(scorer->max_length_), batch_output, + sequence_scores_buffer); + } +} + +void BeamSearchScorer::Finalize(ISequences& sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { + ORT_ENFORCE(output_sequences != nullptr); - beam_hyp.Output(narrow(num_return_sequences_), narrow(max_length_), batch_output, - sequence_scores_buffer); + if (output_sequence_scores == nullptr || output_sequence_scores->IsDataType()) { + OutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); + } else { + ORT_ENFORCE(output_sequence_scores->IsDataType()); + OutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); + } +} + +void BeamSearchScorer::OutputScores(gsl::span& final_scores, Tensor* output_scores) { + if (output_scores) { + if (output_scores->IsDataType()) { + gsl::span target = output_scores->MutableDataAsSpan(); + ORT_ENFORCE(target.size() == final_scores.size()); + std::copy_n(final_scores.data(), final_scores.size(), target.data()); + } else { + ORT_ENFORCE(output_scores->IsDataType()); + gsl::span target = output_scores->MutableDataAsSpan(); + ORT_ENFORCE(target.size() == final_scores.size()); + const float* src = final_scores.data(); + MLFloat16* dst = target.data(); + for (size_t i = 0; i < target.size(); i++) { + dst[i] = MLFloat16(src[i]); + } + } } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h index 94b6d340d9f4a..dc92e8038a68e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -35,10 +35,11 @@ struct BeamHypotheses { bool CanImprove(float best_sum_logprobs, int current_length) const; // Output results - void Output(int top_k, // number of sequences to return - int max_length, // max sequence length - gsl::span& sequences, // buffer with pad token, shape (num_return_sequences, max_length) - gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) + template + void Output(int top_k, // number of sequences to return + int max_length, // max sequence length + gsl::span& sequences, // buffer with pad token, shape (num_return_sequences, max_length) + gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) gsl::span beams_; // Beam width sized array of hypotheses, sorted by highest scoring int beams_used_; // Number of elements used in beams_ @@ 
-60,13 +61,14 @@ struct BeamSearchScorer : IBeamScorer { Tensor* output_sequences, Tensor* output_sequence_scores) override; + void OutputScores(gsl::span& final_scores, Tensor* output_scores) override; + bool IsDone() const override { return not_done_count_ == 0; } gsl::span GetNextScores() override { return next_beam_scores_; } gsl::span GetNextTokens() override { return next_beam_tokens_; } gsl::span GetNextIndicesCPU() override { return next_beam_indices_; } - private: size_t batch_size_; size_t num_beams_; size_t max_length_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index f6faf2e325f8f..cb62e2f7bf4da 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -120,6 +120,9 @@ struct IBeamScorer { Tensor* output_sequences, Tensor* output_sequence_scores) = 0; + virtual void OutputScores(gsl::span& final_scores, + Tensor* output_scores) = 0; + virtual bool IsDone() const = 0; // GPU version will return false here, as it asynchronously queues up the event virtual bool IsDoneLater() const { return false; } // GPU version waits for the asynchous result to complete here diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index ebd66d8c6528e..f978f50c6851f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -44,6 +44,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. 
Consider using Attention or GQA instead."); disable_fused_self_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); @@ -105,6 +107,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { num_heads_, mask_filter_value_, scale_, + is_unidirectional_, false, // past_present_share_buffer false, // dmmha_packing device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index c162f7133cc1c..86a32c92ce003 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -25,6 +25,7 @@ class MultiHeadAttention final : public CudaKernel { int num_heads_; // number of attention heads float mask_filter_value_; float scale_; + bool is_unidirectional_; bool disable_fused_self_attention_; bool enable_trt_flash_attention_; bool disable_fused_cross_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index 2d12e975d88d7..9de7ba3885c3c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -29,10 +29,13 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { scale = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); } @@ -48,6 +51,8 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { position_ids, cos_cache, sin_cache, + num_heads, + rotary_embedding_dim, ¶meters)); Tensor* output = context->Output(0, input->Shape()); @@ -71,6 +76,7 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length, parameters.num_heads, parameters.head_size, + parameters.rotary_embedding_dim, parameters.max_sequence_length, parameters.position_ids_format, interleaved, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h index 6dab2ad56749e..d52f61d670444 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h @@ -19,6 +19,8 @@ class RotaryEmbedding final : public CudaKernel { protected: float scale; + int num_heads; + int rotary_embedding_dim; bool interleaved; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index e1b83bd8caf54..c6637041f05bd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -26,6 +26,7 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int position_ids_format, const bool interleaved, const int batch_stride, @@ -33,24 +34,33 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int head_stride) { // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length // Use .x in innermost loop to access global memory efficiently - + 
const int b = blockIdx.z; const int s = blockIdx.y; const int n = blockIdx.x; const int i = threadIdx.x; + if (i >= head_size) { + return; + } + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; const T* input_data = input + block_offset; T* output_data = output + block_offset; + if (i >= rotary_embedding_dim) { + output_data[i] = input_data[i]; + return; + } + // Cache is (M, H/2) - const int half_head_size = head_size / 2; + const int half_rotary_embedding_dim = rotary_embedding_dim / 2; const int position_id = (position_ids_format == 0) ? \ static_cast(position_ids[0]) + s \ : static_cast(position_ids[b * sequence_length + s]); - const int cache_offset = position_id * half_head_size; + const int cache_offset = position_id * half_rotary_embedding_dim; const T* cos_data = cos_cache + cache_offset; const T* sin_data = sin_cache + cache_offset; @@ -58,13 +68,13 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH T sign = 0; int j = 0; if (interleaved) { - cache_idx = (i / 2) % half_head_size; + cache_idx = (i / 2) % half_rotary_embedding_dim; sign = (i % 2 == 0) ? -1 : 1; j = (i % 2 == 0) ? i+1 : i-1; // i - sign } else { - cache_idx = i % half_head_size; - sign = (i < half_head_size) ? -1 : 1; - j = (i + half_head_size) % head_size; + cache_idx = i % half_rotary_embedding_dim; + sign = (i < half_rotary_embedding_dim) ? -1 : 1; + j = (i + half_rotary_embedding_dim) % rotary_embedding_dim; } output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; } @@ -82,20 +92,23 @@ Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, const int max_threads_per_block, const bool transposed) { - - constexpr int smem_size = 0; - const dim3 grid(num_heads, sequence_length, batch_size); - const dim3 block(head_size, 1, 1); - // Note: Current implementation assumes head_size <= max_threads_per_block // because head_size is currently large for LLaMA-2. For smaller head_size // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. 
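As a cross-check on the index math above, the per-head rotation with a partial rotary dimension can be written as a scalar CPU reference: elements at or beyond rotary_dim are copied through, and only the leading rotary_dim elements are rotated against the (position, rotary_dim / 2) cos/sin rows. This mirrors the kernel and is not part of the change itself.

// Rotates one head vector of length head_size; cos_row/sin_row point at the
// cache row for this token's position (rotary_dim / 2 entries each).
void RotaryRotateHeadReference(const float* in, float* out, int head_size, int rotary_dim,
                               const float* cos_row, const float* sin_row, bool interleaved) {
  const int half = rotary_dim / 2;
  for (int i = 0; i < head_size; ++i) {
    if (i >= rotary_dim) {  // outside the rotary sub-dimension: pass through
      out[i] = in[i];
      continue;
    }
    int cache_idx = 0;
    int j = 0;
    float sign = 0.0f;
    if (interleaved) {
      cache_idx = (i / 2) % half;
      sign = (i % 2 == 0) ? -1.0f : 1.0f;
      j = (i % 2 == 0) ? i + 1 : i - 1;  // i - sign
    } else {
      cache_idx = i % half;
      sign = (i < half) ? -1.0f : 1.0f;
      j = (i + half) % rotary_dim;
    }
    out[i] = in[i] * cos_row[cache_idx] + sign * in[j] * sin_row[cache_idx];
  }
}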
+ ORT_ENFORCE(head_size <= max_threads_per_block, + "Rotary embedding dim must be <= max_threads_per_block"); + + int tpb = (head_size + 31)/32*32; + + const dim3 block(tpb); + const dim3 grid(num_heads, sequence_length, batch_size); // Default input tensor shape is [batch, seq, hidden_size] int head_stride = head_size; @@ -109,10 +122,9 @@ Status LaunchRotaryEmbeddingKernel( } assert(head_size <= max_threads_per_block); - RotaryEmbeddingBSNH<<>>( - output, input, cos_cache, sin_cache, position_ids, - sequence_length, num_heads, head_size, position_ids_format, interleaved, - batch_stride, seq_stride, head_stride + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, sequence_length, num_heads, head_size, + rotary_embedding_dim, position_ids_format, interleaved, batch_stride, seq_stride, head_stride ); return CUDA_CALL(cudaGetLastError()); @@ -129,6 +141,7 @@ template Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, @@ -146,6 +159,25 @@ template Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block, + const bool transposed); + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + BFloat16* output, + const BFloat16* input, + const int64_t* position_ids, + const BFloat16* cos_cache, + const BFloat16* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index ee1ccc43dcbff..36300fe7a660f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -21,6 +21,7 @@ Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 34b44694a5fcc..fa73950c9c6f5 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -98,6 +98,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); @@ -299,6 
+300,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index dbd7fb010462d..a39abefed9cd0 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -307,12 +307,13 @@ __device__ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_ return beams_[beams_count_ - 1].score < current_score; } +template __device__ void BeamHypotheses::Output( int top_k, int max_length, int pad_token_id, int32_t* sequences, // buffer of shape (num_return_sequences, max_length) - float* sequences_scores) // buffer of shape (num_return_sequences) or empty + T* sequences_scores) // buffer of shape (num_return_sequences) or empty { // Copy the top_k beams into the sequences for (int index = 0; index < top_k; index++) { @@ -327,7 +328,7 @@ __device__ void BeamHypotheses::Output( target[i] = pad_token_id; if (sequences_scores) - sequences_scores[index] = item.score; + sequences_scores[index] = (T)item.score; } } @@ -501,13 +502,14 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp next_beam_tokens.data()); } +template __global__ void BeamSearchScorer_Finalize(BeamScorerState& state, const int32_t* sequences_buffer, int sequence_length, BeamHypotheses* beam_hyps_, const float* final_beam_scores, int32_t* output, - float* sequence_scores) { + T* sequence_scores) { int batch_index = blockIdx.x * blockDim.x + threadIdx.x; if (batch_index >= state.batch_size_) return; @@ -534,6 +536,7 @@ __global__ void BeamSearchScorer_Finalize(BeamScorerState& state, sequence_scores ? 
sequence_scores + batch_index * state.num_return_sequences_ : nullptr); } +template void LaunchBeamSearchScorer_Finalize(int batch_size, BeamScorerState& state, gsl::span sequences, @@ -541,7 +544,7 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, gsl::span beam_hyps, gsl::span final_beam_scores, gsl::span output, - gsl::span sequence_scores, + gsl::span sequence_scores, cudaStream_t stream) { BeamSearchScorer_Finalize<<<1, batch_size, 0, stream>>>(state, sequences.data(), @@ -552,6 +555,58 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, sequence_scores.data()); } +template void LaunchBeamSearchScorer_Finalize( + int batch_size, + BeamScorerState& state, + gsl::span sequences, + int sequence_length, + gsl::span beam_hyps, + gsl::span final_beam_scores, + gsl::span output, + gsl::span sequence_scores, + cudaStream_t stream); + +template void LaunchBeamSearchScorer_Finalize<__half>( + int batch_size, + BeamScorerState& state, + gsl::span sequences, + int sequence_length, + gsl::span beam_hyps, + gsl::span final_beam_scores, + gsl::span output, + gsl::span<__half> sequence_scores, + cudaStream_t stream); + +template +__global__ void FloatConvertAndCopyKernel(const float* src, T* dst, size_t total_elements) { + int64_t index = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (index < total_elements) { + dst[index] = (T)src[index]; + } +} + +template +void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream) { + ORT_ENFORCE(final_scores.size() == output_scores.size()); + constexpr unsigned ThreadPerBlock = 256; + unsigned num_blocks = (unsigned)((final_scores.size() + (ThreadPerBlock - 1))/ ThreadPerBlock); + + typedef typename ToCudaType::MappedType CudaT; + + FloatConvertAndCopyKernel<<>>( + final_scores.data(), (CudaT*)output_scores.data(), final_scores.size()); +} + +template void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream); + +template void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream); + __global__ void AddProbsKernel(float* log_probs, float* cum_log_probs, const int vocab_size, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 5ed5949196b29..281cb6c725975 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -65,11 +65,12 @@ struct BeamHypotheses { __device__ bool CanImprove(float best_sum_logprobs, int current_length) const; // Output results - __device__ void Output(int top_k, // number of sequences to return - int max_length, // max sequence length - int pad_token_id, // pad token - int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length) - float* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) + template + __device__ void Output(int top_k, // number of sequences to return + int max_length, // max sequence length + int pad_token_id, // pad token + int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length) + T* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) }; struct BeamScorerState { @@ -110,6 +111,7 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp gsl::span next_beam_indices, cudaStream_t stream); +template void LaunchBeamSearchScorer_Finalize(int 
batch_size, BeamScorerState& state, gsl::span sequences, @@ -117,9 +119,14 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, gsl::span beam_hyps_, gsl::span final_beam_scores, gsl::span output, - gsl::span sequence_scores, + gsl::span sequence_scores, cudaStream_t stream); +template +void LaunchBeamSearchScoreCopy(gsl::span final_scores, + gsl::span output_scores, + cudaStream_t stream); + void LaunchNextTokenKernel(const int64_t* next_token_indices, int32_t* next_indices, int32_t* next_tokens, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 380d561bbb23c..bba30805ae1be 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -620,6 +620,8 @@ struct CudaBeamSearchScorer : transformers::IBeamScorer { Tensor* output_sequences, Tensor* output_sequence_scores) override; + void OutputScores(gsl::span& final_scores, Tensor* output_scores) override; + bool IsDone() const override { return false; } // For CUDA we speculatively run the next step while we wait for the GPU to report status. We use 'IsDoneLater()' for this bool IsDoneLater() const override; @@ -632,7 +634,6 @@ struct CudaBeamSearchScorer : transformers::IBeamScorer { } gsl::span GetNextIndicesGPU() override { return next_beam_indices_; } - private: mutable cuda::AutoDestoryCudaEvent event_process_complete_; IAllocatorUniquePtr state_cpu_; IAllocatorUniquePtr state_gpu_; @@ -743,22 +744,58 @@ bool CudaBeamSearchScorer::IsDoneLater() const { return state_cpu_->not_done_count_ == 0; } +template +void CudaOutputSequenceScores(CudaBeamSearchScorer* scorer, + transformers::ISequences& sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { + // Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length). + gsl::span output{output_sequences->MutableData(), static_cast(output_sequences->Shape().Size())}; + + // Score of each sequence, with shape (batch_size * num_return_sequences). + using CudaT = typename ToCudaType::MappedType; + gsl::span sequence_scores; + if (output_sequence_scores) { + sequence_scores = gsl::span{(CudaT*)output_sequence_scores->MutableData(), static_cast(output_sequence_scores->Shape().Size())}; + } + + cuda::LaunchBeamSearchScorer_Finalize(scorer->state_cpu_->batch_size_, + *scorer->state_gpu_, + sequences.GetCurrentDeviceSequences(), + sequences.GetSequenceLength(), + scorer->beam_hyps_, + final_beam_scores, + output, + sequence_scores, + scorer->stream_); +} + void CudaBeamSearchScorer::Finalize(transformers::ISequences& sequences, gsl::span& final_beam_scores, Tensor* output_sequences, Tensor* output_sequence_scores) { ORT_ENFORCE(output_sequences != nullptr); - // Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length). - gsl::span output{output_sequences->MutableData(), static_cast(output_sequences->Shape().Size())}; - - // Score of each sequence, with shape (batch_size * num_return_sequences). 
- gsl::span sequence_scores; - if (output_sequence_scores) { - sequence_scores = gsl::span{output_sequence_scores->MutableData(), static_cast(output_sequence_scores->Shape().Size())}; + if (output_sequence_scores == nullptr || output_sequence_scores->IsDataType()) { + CudaOutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); + } else { + ORT_ENFORCE(output_sequence_scores->IsDataType()); + CudaOutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); } +} - cuda::LaunchBeamSearchScorer_Finalize(state_cpu_->batch_size_, *state_gpu_, sequences.GetCurrentDeviceSequences(), sequences.GetSequenceLength(), beam_hyps_, final_beam_scores, output, sequence_scores, stream_); +void CudaBeamSearchScorer::OutputScores(gsl::span& final_scores, Tensor* output_scores) { + if (output_scores) { + if (output_scores->IsDataType()) { + gsl::span target(output_scores->MutableData(), output_scores->Shape().Size()); + cuda::LaunchBeamSearchScoreCopy(final_scores, target, stream_); + } else { + ORT_ENFORCE(output_scores->IsDataType()); + gsl::span target(output_scores->MutableData(), output_scores->Shape().Size()); + cuda::LaunchBeamSearchScoreCopy(final_scores, target, stream_); + } + } } std::unique_ptr CreateBeamScorer(const transformers::IGenerationParameters& parameters, diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index fcf9c2b03dea5..711fd595e90fd 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -30,6 +30,10 @@ #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #endif // ARM #endif // Linux @@ -148,6 +152,7 @@ void CPUIDInfo::ArmLinuxInit() { has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); @@ -177,6 +182,7 @@ void CPUIDInfo::ArmLinuxInit() { has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); #endif } @@ -278,6 +284,7 @@ void CPUIDInfo::ArmWindowsInit() { /* TODO: implement them when hw+sw is available for testing these features */ has_arm_neon_i8mm_ = false; has_arm_sve_i8mm_ = false; + has_arm_neon_bf16_ = false; } #endif /* (arm or arm64) and windows */ diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index a15c75104b83a..2f8041e39f680 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -30,6 +30,7 @@ class CPUIDInfo { bool HasArmNeonDot() const { return has_arm_neon_dot_; } bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } uint32_t GetCurrentCoreIdx() const; @@ -125,6 +126,7 @@ class CPUIDInfo { bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; #ifdef CPUIDINFO_ARCH_X86 diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index d9c49dc6bea1d..8c08152986cf6 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ 
b/onnxruntime/core/framework/execution_frame.cc @@ -223,7 +223,8 @@ void IExecutionFrame::Init(gsl::span feed_mlvalue_idxs, gsl::span& initializers, const std::function& is_initializer_sparse_func, gsl::span fetches) { - ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size()); + ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size(), "Get feed size: ", feeds.size(), " but expected feed size: ", + feed_mlvalue_idxs.size()); ORT_ENFORCE(fetches.empty() || fetches.size() == fetch_mlvalue_idxs_.size()); // Need this for sparse conversions in host memory diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index e4fe0c7564548..07b465c80745a 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -16,6 +16,7 @@ #include "core/graph/function_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" +#include "core/session/onnxruntime_session_options_config_keys.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -634,6 +635,100 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide return Status::OK(); } +static Status CreateEpContextModel(const ExecutionProviders& execution_providers, + const Graph& graph, + const std::string& ep_context_path, + const logging::Logger& logger) { + InlinedVector all_ep_context_nodes; + for (const auto& ep : execution_providers) { + const InlinedVector ep_context_nodes = ep->GetEpContextNodes(); + all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end()); + } + + auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { + for (auto& node : all_ep_context_nodes) { + if (node_name == node->Name()) { + return std::make_pair(true, node); + } + } + return std::make_pair(false, static_cast(nullptr)); + }; + + onnxruntime::PathString context_cache_path; + PathString model_pathstring = graph.ModelPath().ToPathString(); + if (all_ep_context_nodes.size() > 0) { + if (!ep_context_path.empty()) { + context_cache_path = ToPathString(ep_context_path); + } else if (!model_pathstring.empty()) { + context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + } + + { +#ifdef _WIN32 + std::wifstream fs(context_cache_path); +#else + std::ifstream fs(context_cache_path); +#endif + ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); + } + + Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + graph.DomainToVersionMap(), {}, logger); + auto& ep_graph = ep_context_model.MainGraph(); + ep_graph.SetDescription(graph.Description()); + + // Set inputs outputs explicitly to make sure the order is same as the user model. 
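For context, the context-cache path handled here is driven by two session configuration keys that Partition() reads later in this change (kOrtSessionOptionEpContextEnable and kOrtSessionOptionEpContextFilePath). A rough sketch of how a user would opt in through the C++ API; the model and output file names are examples only:

#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"

Ort::Session CreateSessionWithEpContextDump(Ort::Env& env) {
  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");                 // enable EP context model generation
  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "model_ctx.onnx");  // optional; defaults to <model>_ctx.onnx
  return Ort::Session(env, ORT_TSTR("model.onnx"), so);                     // partitioning writes the context model
}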
+ auto inputs = graph.GetInputs(); + auto outputs = graph.GetOutputs(); + + InlinedVector ep_graph_inputs; + ep_graph_inputs.reserve(inputs.size()); + for (auto& input : inputs) { + auto input_arg = graph.GetNodeArg(input->Name()); + auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); + ep_graph_inputs.push_back(&ep_graph_input_arg); + } + + InlinedVector ep_graph_outputs; + ep_graph_outputs.reserve(outputs.size()); + for (auto& output : outputs) { + auto output_arg = graph.GetNodeArg(output->Name()); + auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + ep_graph_outputs.push_back(&ep_graph_output_arg); + } + + ep_graph.SetInputs(ep_graph_inputs); + ep_graph.SetOutputs(ep_graph_outputs); + + for (const auto& node : graph.Nodes()) { + // the fused node and EPContext node has same node name + auto ep_context_node = get_ep_context_node(node.Name()); + // Use EpContext node created by the EPs if name matched, otherwise use node from original model + if (ep_context_node.first) { + ep_graph.AddNode(*ep_context_node.second); + } else { + ep_graph.AddNode(node); + } + } + + // handle initializers + for (const auto& input : graph.GetInputsIncludingInitializers()) { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (graph.GetInitializedTensor(input->Name(), initializer)) { + // There initializer could have duplicates so make sure we only add once + const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; + if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) { + ep_graph.AddInitializedTensor(*initializer); + } + } + } + + ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); + } + + return Status::OK(); +} + static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager) { @@ -840,6 +935,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, + const ConfigOptions& config_options, + const logging::Logger& logger, Mode mode, const layout_transformation::DebugGraphFn& debug_graph_fn) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. @@ -886,7 +983,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, #if !defined(ORT_MINIMAL_BUILD) ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_)); + + bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + if (ep_context_enabled) { + ORT_RETURN_IF_ERROR(CreateEpContextModel(providers_, graph, ep_context_path, logger)); + } #else + ORT_UNUSED_PARAMETER(config_options); + ORT_UNUSED_PARAMETER(logger); return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! 
defined(ORT_MINIMAL_BUILD) } else { diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 4fc85c2588260..d1ef193cf1520 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -13,6 +13,7 @@ namespace onnxruntime { class ExecutionProviders; class KernelRegistryManager; class Model; +struct ConfigOptions; class GraphPartitioner { public: @@ -31,6 +32,8 @@ class GraphPartitioner { // Run partitioning. Status Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, + const ConfigOptions& config_options, + const logging::Logger& logger, Mode mode = Mode::kNormal, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index df8d0a59cb033..7f34647f1faef 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -927,6 +927,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("unidirectional", + "Whether every token can only attend to previous tokens. Default value is 0.", + AttributeProto::INT, + static_cast(0)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)", @@ -1145,6 +1149,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Rotate using interleaved pattern. Default value is 0 (False).", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("rotary_embedding_dim", + "Rotary embedding dimension. Default value is 0.", + AttributeProto::INT, + OPTIONAL_VALUE) + .Attr("num_heads", + "Number of attention heads. Default value is 0. 
Must use with rotary_embedding_dim", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "input", "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)", @@ -1155,17 +1167,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "M") .Input(2, "cos_cache", - "2D tensor with shape (max_sequence_length, head_size / 2).", + "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)", "T") .Input(3, "sin_cache", - "2D tensor with shape (max_sequence_length, head_size / 2).", + "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)", "T") .Output(0, "output", "tensor with same shape as input.", "T") - .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); @@ -1285,7 +1297,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Output(3, "input_skip_bias_sum", "Sum of the input and skip inputs (and bias if it exists) with shape (batch_size, sequence_length, hidden_size).", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.") .TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + .TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference)); ONNX_MS_OPERATOR_SET_SCHEMA( SkipSimplifiedLayerNormalization, 1, @@ -1334,7 +1346,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.") .TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.") - .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + .TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference)); constexpr const char* NGramRepeatBlock_ver1_doc = R"DOC( Enforce no repetition of n-grams. Scores are set to `-inf` for tokens that form a repeated n-gram if added to the back of the input_ids. 
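To make the new RotaryEmbedding attributes concrete: leaving rotary_embedding_dim at 0 keeps the previous behaviour (the whole head is rotated and the caches stay (max_sequence_length, head_size / 2)), while a non-zero value rotates only that leading slice and, per the schema text, must be paired with num_heads. A small helper capturing that contract; this is a sketch of the schema rules, not library code.

#include <stdexcept>

// Returns the expected second dimension of cos_cache / sin_cache.
int RotaryCacheWidth(int head_size, int rotary_embedding_dim, int num_heads) {
  if (rotary_embedding_dim == 0) {
    return head_size / 2;  // full-head rotation, the pre-existing layout
  }
  if (num_heads == 0) {
    throw std::invalid_argument("num_heads must be set when rotary_embedding_dim is used");
  }
  return rotary_embedding_dim / 2;  // partial rotation of the leading slice
}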
diff --git a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc index eeef20e9dff5e..8b1812f62be25 100644 --- a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc +++ b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc @@ -114,6 +114,45 @@ void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& c } } +void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx) { + propagateShapeAndTypeFromFirstInput(ctx); + + auto stash_type = ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + if (ctx.getNumOutputs() > 1) { + auto output_type = ctx.getOutputType(1); + output_type->mutable_tensor_type()->set_elem_type(static_cast(stash_type)); + } + if (ctx.getNumOutputs() > 2) { + auto output_type = ctx.getOutputType(2); + output_type->mutable_tensor_type()->set_elem_type(static_cast(stash_type)); + } + if (ctx.getNumOutputs() > 3) { + propagateElemTypeFromInputToOutput(ctx, 0, 3); + } + if (!hasNInputShapes(ctx, 1)) { + return; + } + auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); + int64_t input_ndim = input_shape.dim_size(); + int axis = static_cast(input_ndim - 1); + + if (ctx.getNumOutputs() > 1) { + auto mean_shape = ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape(); + mean_shape->CopyFrom(input_shape); + mean_shape->mutable_dim(axis)->set_dim_value(1); + } + + if (ctx.getNumOutputs() > 2) { + auto inv_std_dev_shape = ctx.getOutputType(2)->mutable_tensor_type()->mutable_shape(); + inv_std_dev_shape->CopyFrom(input_shape); + inv_std_dev_shape->mutable_dim(axis)->set_dim_value(1); + } + + if (ctx.getNumOutputs() > 3) { + propagateShapeFromInputToOutput(ctx, 0, 3); + } +} + // Shape inference for Attention and QAttention void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index) { // Input 0, 1, 2 are input, weights and bias. diff --git a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h index 93cf5b304f653..6eb06af15309c 100644 --- a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h +++ b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.h @@ -13,5 +13,6 @@ namespace onnxruntime { namespace contrib { void AttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_input_index); void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx); +void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx); } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index bdd4dba521eba..ce7838556fbf0 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1614,6 +1614,119 @@ MlasHalfGemmConvertPackB( void* PackedB ); +#if defined(__aarch64__) && defined(__linux__) +/** + * @brief Whether current CPU supports Bfloat16(bf16) acceleration. + */ +bool MLASCALL +MlasBf16AccelerationSupported(); + +/** + * @brief Interface for bf16 gemm post processors. + * + * Example implementation of this interface includes activations, + * conversion from single precision to precision, etc. + * + * SBGEMM is computed tile by tile. When a tile of result matrix + * is produced, the method Process() is called to process this tile. 
+ * Parameters of this method describe the location and shape of the + * tile. + */ +class MLAS_SBGEMM_POSTPROCESSOR +{ + public: + virtual void Process(float*, /**< the address of matrix to process */ + size_t, /**< the start row index of matrix */ + size_t, /**< the start col index of matrix */ + size_t, /**< the element count per row to process */ + size_t, /**< the element count per col to process */ + size_t /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_SBGEMM_POSTPROCESSOR() {} +}; + +/** + * @brief bfloat16 precision activation functions, with optional sum tensor. + * Supplied sum tensor must be the same layout as the GEMM output tensor. + * And the supplied sum tensor will be added to the tensor before activation. + */ +class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR +{ + public: + MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr) + : Activation_(Activation), SumBuf_(SumBuf) + { + } + + void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc) + const override; + + private: + const MLAS_ACTIVATION& Activation_; + const float* SumBuf_; +}; + +/** + * @brief Data parameters for bfloat16 precision GEMM routine + * All except C are [in] parameters + */ +struct MLAS_SBGEMM_DATA_PARAMS { + const void* A = nullptr; /**< address of A */ + const void* B = nullptr; /**< address of B */ + const float* Bias = nullptr; /**< address of Bias, vector size N */ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ + size_t ldc = 0; /**< leading dimension of C*/ + const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; + bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ + bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ +}; + +/** + * @brief Bfloat16 precision Batched GEMM: C = A * B + Bias + * Either B can be either fp32 or bf16 + * + * Note: We only support uniform batching, so shapes and types of the + * input must be same across all parameter blocks. 
+ * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool + * @return + */ +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr); + +/** + * @brief For bfloat16 precision GEMM, returns size of the + * packing buffer needed for right hand side + * @param[in] N Number of columns + * @param[in] K Number of rows + * @return size of the packing buffer, + * 0 if operation not supported + */ +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K); + +/** + * @brief For bfloat16 precision GEMM, convert the float matrix B + * to blfoat16 precision and pack it into a packing buffer + * + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix + */ +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); +#endif + /** * @brief Indirect Depthwise convolution for fp16 * @param Input Supplies the indirect buffer for NHWC input diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index bc0bfc92c85a0..047011e70bd4d 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -183,133 +183,3 @@ MlasSQNBitGemmPackQuantBData( void* PackedQuantBData, MLAS_THREADPOOL* ThreadPool = nullptr ); - -/** - * @brief Data parameters for NBits GEMM routine - * C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob - * All except C are [in] parameters - */ -struct MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS { - const float* A = nullptr; /**< address of A (float32 matrix)*/ - const void* B = nullptr; /**< address of B (packed nbits blob)*/ - float* C = nullptr; /**< address of result matrix */ - size_t lda = 0; /**< leading dimension of A */ - size_t ldc = 0; /**< leading dimension of C*/ -}; - -/** - * @brief Compute the byte size of the parameter combination - * - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @return size of the packing buffer, 0 if the operation is not yet supported. - */ -size_t MLASCALL -MlasNBitsGemmPackBSize( - size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE comp_type -); - -/** - * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. - * - * @param PackedBuf packed data buffer - * @param QData quantized data buffer - * @param Scale scale pointer - * @param Zp zero point pointer - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. 
- * @param ldb leading dimension of B - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization (default 4) - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor - * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where - * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up - * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale - * (is_asym is false) and Zp(is_asym is true). - * @param thread_pool - */ -void MLASCALL -MlasNBitsGemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t block_size, - int nbits, - bool is_asym, - bool last_call, - MLAS_SQNBIT_COMPUTE_TYPE comp_type, - MLAS_THREADPOOL* thread_pool -); - -/** - * @brief Unpack and dequantize to fp32 - * - * @param FpData unpacked float32 data - * @param PackedBuf quantized and packed data - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - * @param thread_pool - */ -void MLASCALL -MlasNBitsGemmUnPackB( - float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool -); - -/** - * @brief Get the workspace size required by computation. - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @return Workspace size in bytes - */ -size_t MLASCALL -MlasSQNBitsGemmBatchPackedBWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -); - -/** - * @brief Batched GEMM: C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] WorkSpace temporary buffer - * @param[in] ThreadPool - * @return - */ -void MLASCALL -MlasSQNBitsGemmBatchPackedB( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - void* WorkSpace, - MLAS_THREADPOOL* ThreadPool = nullptr -); diff --git a/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S new file mode 100644 index 0000000000000..e424c30515e9f --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S @@ -0,0 +1,907 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + SbgemmKernelNeon.s + +Abstract: + + This module implements the kernels for the bfloat16 half precision matrix/matrix + multiply operation (SBGEMM). 
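Pulling the new mlas.h declarations together, a typical SBGEMM call site on an AArch64 Linux build with bf16 support packs B once and then issues fp32-in/fp32-out batched GEMMs against the packed buffer. The sizes and leading dimensions below are illustrative; the signatures follow the declarations above.

#include <vector>
#include "mlas.h"

void SBGemmExample(const float* A, const float* B, float* C, size_t M, size_t N, size_t K) {
  if (!MlasBf16AccelerationSupported()) return;  // no bf16 MMLA support on this CPU

  // Convert B to bfloat16 and pack it once; the packed buffer can be reused.
  std::vector<uint8_t> packed_b(MlasSBGemmPackBSize(N, K));
  MlasSBGemmConvertPackB(N, K, B, /*ldb*/ N, packed_b.data());

  MLAS_SBGEMM_DATA_PARAMS params;
  params.A = A;                // fp32 A, converted to bf16 internally
  params.AIsfp32 = true;
  params.B = packed_b.data();  // pre-packed bf16 B
  params.BIsfp32 = false;
  params.ldb = 0;              // 0 signals that B is pre-packed
  params.C = C;
  params.lda = K;
  params.ldc = N;

  MlasSBGemmBatch(M, N, K, /*BatchN*/ 1, &params);  // thread pool defaults to nullptr
}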
+ +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the sbgemm kernel. d8-d15, x19-x30 need save +// + .equ .LMlasSbgemmKernel_backup_x19_x20, 0 + .equ .LMlasSbgemmKernel_backup_x21_x22, 16 + .equ .LMlasSbgemmKernel_backup_x23_x24, 32 + .equ .LMlasSbgemmKernel_backup_x25_x26, 48 + .equ .LMlasSbgemmKernel_backup_x27_x28, 64 + .equ .LMlasSbgemmKernel_backup_d8_d9, 80 + .equ .LMlasSbgemmKernel_backup_d10_d11, 96 + .equ .LMlasSbgemmKernel_backup_d12_d13, 112 + .equ .LMlasSbgemmKernel_backup_d14_d15, 128 + .equ .LMlasSbgemmKernel_SavedRegisters, 144 + .equ .LMlasSbgemmKernel_SavedRegisters_Neg, -144 + + +// +// ClearRowAccumulators +// +// Generates the code to clear the accumulators for a single row of the output +// block. +// + + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + mov v\Vec1Reg\().16b,v0.16b +.if \Columns\() > 2 + mov v\Vec2Reg\().16b,v1.16b +.endif +.if \Columns\() > 4 + mov v\Vec3Reg\().16b,v2.16b +.endif +.if \Columns\() > 6 + mov v\Vec4Reg\().16b,v3.16b +.endif + + .endm + +// +// InitBlockAccumulators +// +// Generates the code to init the accumulators for a single row of the output +// block. +// + + .macro InitBlockAccumulators Mode, Columns, Rows + + //check if the Bias != nullptr + cbz x8,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd + + ld1 {v14.4s},[x8],#16 // load Bias[0] + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast Bias + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + ld1 {v15.4s},[x8],#16 // load Bias[4] + dup v4.4s,v15.s[0] // broadcast Bias + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + b .L\Mode\().PopulateAccumulators\Columns\().x\Rows\() + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd: + eor v0.16b,v0.16b,v0.16b // No bias, reset regs + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b + +.L\Mode\().PopulateAccumulators\Columns\().x\Rows\(): + InitRowAccumulators \Columns\(),16,17,18,19 +.if \Rows\() > 2 + InitRowAccumulators \Columns\(),20,21,22,23 +.endif +.if \Rows\() > 4 + InitRowAccumulators \Columns\(),24,25,26,27 +.endif +.if \Rows\() > 6 + InitRowAccumulators \Columns\(),28,29,30,31 +.endif + + .endm + +// LoadMatrixAElementsBy8 +// +// Generates the code to load 4 or 8 elements from matrix A. +// + .macro LoadMatrixAElementsBy8 Rows + + ldr q8,[x0],#16 + bfcvtn v8.4h, v8.4s +.if \Rows\() > 1 + ldr q1,[x10],#16 + bfcvtn2 v8.8h, v1.4s +.endif + +.if \Rows\() > 2 + ldr q9,[x11],#16 + bfcvtn v9.4h, v9.4s +.endif +.if \Rows\() > 3 + ldr q1,[x12],#16 + bfcvtn2 v9.8h, v1.4s +.endif + +.if \Rows\() > 4 + ldr q10,[x20],#16 + bfcvtn v10.4h, v10.4s +.endif +.if \Rows\() > 5 + ldr q1,[x21],#16 + bfcvtn2 v10.8h, v1.4s +.endif + +.if \Rows\() > 6 + ldr q11,[x22],#16 + bfcvtn v11.4h, v11.4s +.endif +.if \Rows\() > 7 + ldr q1,[x23],#16 + bfcvtn2 v11.8h, v1.4s +.endif + + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. 
+// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + bfmmla v\Vec1Reg\().4s, \MatrixAReg\().8h, v4.8h +.if \Columns\() > 2 + bfmmla v\Vec2Reg\().4s, \MatrixAReg\().8h, v5.8h +.endif +.if \Columns\() > 4 + bfmmla v\Vec3Reg\().4s, \MatrixAReg\().8h, v6.8h +.endif +.if \Columns\() > 6 + bfmmla v\Vec4Reg\().4s, \MatrixAReg\().8h, v7.8h +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(),\Columns\(),\Rows\() + + add x10,x0,x6,lsl #2 // compute matrix A plus 1 row +.if \Rows\() > 2 + add x11,x10,x6,lsl #2 // compute matrix A plus 2 rows + add x12,x11,x6,lsl #2 // compute matrix A plus 3 rows +.endif +.if \Rows\() > 4 + add x20,x12,x6,lsl #2 // compute matrix A plus 4 rows + add x21,x20,x6,lsl #2 // compute matrix A plus 5 rows +.endif +.if \Rows\() > 6 + add x22,x21,x6,lsl #2 // compute matrix A plus 6 rows + add x23,x22,x6,lsl #2 // compute matrix A plus 7 rows +.endif + sub x9,x3,#4 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#4 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#4 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. 
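// Note: the accumulators arrive as 2x2 bfmmla tiles with lanes ordered {r0c0, r0c1, r1c0, r1c1}.
// The uzp1/uzp2 pairs in the macros below de-interleave them: uzp1 gathers the low 64-bit lanes
// (row r's columns) and uzp2 the high ones (row r+1's columns), producing plain row-major data
// before it is added to or stored into C.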
+// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s 
+ + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + fadd v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + fadd v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + 
mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x26 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),8,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),4,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. 
+ +Arguments: + + A (x0) - Supplies the address of matrix A. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasSbgemmCopyPackB or MlasSbgemmTransposePackB. + + C (x2) - Supplies the address of matrix C. + + CountK (x3) - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda (x6) - Supplies the first dimension of matrix A. + + ldc (x7) - Supplies the first dimension of matrix C. + + Bias - Supplies the address of Bias Vector [1xn] + + +Return Value: + + Returns the number of rows handled. + +--*/ + .macro SbgemmKernelNeonFunction Mode + + FUNCTION_ENTRY MlasSbgemmKernel\Mode\() + + ldr x8, [sp, #0] //Bias vector + + stp x19, x20, [sp, #.LMlasSbgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + + add x13,x2,x7,lsl #2 // compute matrix C plus 1 row + add x14,x13,x7,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x7,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x7,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x7,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x7,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x7,lsl #2 // compute matrix C plus 7 rows + + mov x26,x0 // save matrix A +// +// Process 8 rows of the matrices. +// + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasSbgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + SbgemmKernelNeonFunction Zero + SbgemmKernelNeonFunction Add diff --git a/onnxruntime/core/mlas/lib/jblas_defs.h b/onnxruntime/core/mlas/lib/jblas_defs.h deleted file mode 100644 index 9cd1711a3ffd2..0000000000000 --- a/onnxruntime/core/mlas/lib/jblas_defs.h +++ /dev/null @@ -1,73 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. 
All rights reserved. - -Licensed under the MIT License. - ---*/ - -#pragma once - -#include "jblas/jit_blas_prologue_b.h" -#include "jblas/jit_blas_wrapper.h" - -namespace jblas -{ - -/* -Name conversion explaination: -Fp32: comp type, determined by GemmCore, can be any jblas::gemm::SCorexxx(float GemmCore) -S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(also support other integer and float weight -classes) -F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and -jblas::epilogue::gemm::AccumulatorWriteBackFp32. - -Tips: jblas::epilogue::gemm::CompFp32BlockEpilogue is a fixed class for all fp32 accumulator GemmCores. -*/ -template -using tLauncher_Fp32_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock< - GemmCore_T::ISA, - GemmCore_T, - jblas::prologue_a::gemm::ActivationKBlockBaseF32, - jblas::prologue_b::gemm::WeightKBlockS4, - jblas::epilogue::gemm::CompFp32BlockEpilogue, - jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - -/* -Name conversion explaination: -Int8: comp type, determined by GemmCore, can be any jblas::gemm::ICorexxx(integer GemmCore) -S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(support integer weight classes only) -F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and -jblas::epilogue::gemm::AccumulatorWriteBackFp32. - -Tips: jblas::epilogue::gemm::CompInt8BlockEpilogue is a fixed class for all int32 accumulator GemmCores. -*/ -template -using tLauncher_Int8_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock< - GemmCore_T::ISA, - GemmCore_T, - jblas::prologue_a::gemm::ActivationF32KBlockQuantize, - jblas::prologue_b::gemm::WeightKBlockS4, - jblas::epilogue::gemm::CompInt8BlockEpilogue, - jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - -using tAVX512F = jblas::gemm::SCoreRowNAvx512f<48, 8>; -using tAMX_BF16 = jblas::gemm::HCoreRowNAmxbf16<64, 16>; -using tAVX512_FP16 = jblas::gemm::HCoreRowNAvx512fp16<96, 8>; -using tAVX_VNNI = jblas::gemm::ICoreRowNAvxvnni<48, 2>; // TODO(Yu) use 24x4 for higher efficiency -using tAVX512_VNNI = jblas::gemm::ICoreRowNAvx512vnni<48, 8>; -using tAMX_INT8_US = jblas::gemm::ICoreRowNAmxint8<64, 16>; -using tAMX_INT8_SS = jblas::gemm::ICoreRowNAmxint8SS<64, 16>; -using tAVX2 = jblas::gemm::SCoreRowNAvx2<48, 2>; // TODO(Yu) use 24x4 for higher efficiency - -class ORTThreading : public jblas::parallel::IThreading -{ - public: - ORTThreading(void* tp); - void parallel_for(const jblas::parallel::thread_func& func) override; - void set_threads(int nthreads) override { assert(0); } - void sync() override { assert(0); } - void* mTp; -}; - -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.cpp b/onnxruntime/core/mlas/lib/jblas_gemm.cpp deleted file mode 100644 index f3cae3186c28e..0000000000000 --- a/onnxruntime/core/mlas/lib/jblas_gemm.cpp +++ /dev/null @@ -1,534 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - jblas_gemm.cpp - -Abstract: - - Currently only support Q4 gemm. 
---*/ - -#include "jblas_gemm.h" - -#include "jblas_defs.h" -#include "mlasi.h" - -using namespace jblas; - -jblas::ORTThreading::ORTThreading(void* tp) - : IThreading(MLAS_THREADPOOL::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) -{ -} - -void -jblas::ORTThreading::parallel_for(const jblas::parallel::thread_func& func) -{ - MlasTrySimpleParallel(reinterpret_cast(mTp), mThreadNum, [&](ptrdiff_t tid) { - func(static_cast(tid)); - }); -} - -template -static void -JblasSQ4GemmCompF32( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc, - int8_t* WorkSpace, - jblas::parallel::IThreading* th -) -{ - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - if (M <= 16) { - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = tLauncher_Fp32_S4_F32F32; - static Launcher kernel; - auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); - if (B->mIsAsym) { - reduceA.assign(WorkSpace); - ORTThreading single(nullptr); - kernel.mProA.reduce({A, lda_}, &reduceA, M_, K_, &single); - } - typename Launcher::BEpiParam blkargs{ - B->template SPtr(), B->mScaT, B->mCStep, B->template ZPtr(), - reduceA.template get(), reduceA.lda}; - - typename Launcher::Param args{M_, N_, K_, B->mBlockSize, {A, lda_}, {B}, blkargs, {C, ldc_}}; - jblas::parallel::GemmKBlockRun(kernel, args, th); - } else { - using Parallel = jblas::parallel::gemm::SchedulerBase; - using Launcher = jblas::wrapper::gemm::LauncherBase< - GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase, - jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - static Launcher kernel; - - typename Launcher::Param args{M_, N_, K_, {A, lda_}, {B}, {C, ldc_}}; - jblas::parallel::GemmBaseRun(kernel, args, th); - } -} - -template -static void -JblasSQ4GemmCompInt8( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc, - int8_t* WorkSpace, - jblas::parallel::IThreading* th -) -{ - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = tLauncher_Int8_S4_F32F32; - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - static Launcher kernel; - auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->mIsAsym); - quanA.assign(WorkSpace); - if (M <= 16) { - ORTThreading single(nullptr); - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); - } else { - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); - } - typename Launcher::Param args{ - M_, - N_, - K_, - B->mBlockSize, - {A, lda_, &quanA}, - {B}, - {B->template SPtr(), B->mScaT, B->mCStep, quanA.template SPtr(), quanA.mCStep, - quanA.template ZPtr(), B->template RPtr(), B->mRedT, B->template ZPtr(), - quanA.template RPtr(), B->mBlockSize}, - {C, ldc_}}; - jblas::parallel::GemmKBlockRun(kernel, args, th); -} - -bool -JblasSQ4GemmBatchDriver( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - int8_t* WorkSpace, - MLAS_THREADPOOL* ThreadPool -) -{ - GetCPUDevice(); - ORTThreading orth(ThreadPool); - bool processed = true; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = 
jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { - auto kptr = reinterpret_cast(ptr); - auto coretype = ptr->mCoreId; - auto NTile = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT - ); - auto CType = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT - ); - if (CType == uint32_t(gemm::CompType::COMP_FP32)) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - JblasSQ4GemmCompF32( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - JblasSQ4GemmCompF32( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { - if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { - if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { - JblasSQ4GemmCompInt8( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, - WorkSpace, &orth - ); - } - } - } - } else { - processed = false; - break; - } - } - return processed; -} - -template -static size_t -JblasSQ4GemmCompF32WorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc -) -{ - auto M_ = static_cast(M); - auto K_ = static_cast(K); - (void)(N); - (void)(lda); - (void)(ldc); - if (M <= 16) { - using Launcher = tLauncher_Fp32_S4_F32F32; - static Launcher kernel; - if (B->mIsAsym) { - auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); - return reduceA.mSize; - } - return 0; - } else { - using Launcher = jblas::wrapper::gemm::LauncherBase< - GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase, - jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; - static Launcher kernel; - return 0; - } - return 0; -} - -template -static size_t -JblasSQ4GemmCompInt8WorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const float* A, - const size_t lda, - jblas::storage::gemm::StorageWeightKBlockS4* B, - float* C, - const size_t ldc -) -{ - using Parallel = jblas::parallel::gemm::SchedulerKBlock; - using Launcher = tLauncher_Int8_S4_F32F32; - static Launcher kernel; - (void)(N); - (void)(lda); - (void)(ldc); - auto quanA = kernel.mProA.createStorage( - static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->mIsAsym - ); - return quanA.mSize; -} - -size_t -JblasSQ4GemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const 
MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -) -{ - GetCPUDevice(); - size_t size = 0; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { - auto kptr = reinterpret_cast(ptr); - auto coretype = ptr->mCoreId; - auto NTile = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT - ); - auto CType = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT - ); - if (CType == uint32_t(gemm::CompType::COMP_FP32)) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - size = std::max( - JblasSQ4GemmCompF32WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - size = std::max( - JblasSQ4GemmCompF32WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { - if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { - if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { - size = std::max( - JblasSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc - ), - size - ); - } - } - } - } - } - return size; -} - -template -static size_t -JblasQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) -{ - static T launcher; - auto stor = launcher.mProB.createStorage( - static_cast(N), static_cast(K), static_cast(block_size), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, - JBLAS_DTYPE::BF16, isAsym - ); - // TODO(Yu) support more scale dtype - return stor.mSize; -} - -size_t -JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType) -{ - GetCPUDevice(); - if (K % BlkSize != 0) { - return 0; - } - // from low precision to high precision - switch (CompType) { - case CompInt8: - if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - case CompBf16: - case CompFp16: - case CompFp32: - case CompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - return JblasQ4BuSize>(BlkSize, N, K, isAsym); - } - break; - default: - return 0; - } - return 0; -} - -template -static void 
-JblasQ4GemmPackBImpl( - void* PackedBuf, - size_t BlkSize, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - bool IsAsym, - bool lastCall, - size_t ldb, - MLAS_THREADPOOL* ThreadPool -) -{ - static T JblasKernel; - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto stor = JblasKernel.mProB.createStorage( - N_, K_, static_cast(BlkSize), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, IsAsym - ); - stor.assign(reinterpret_cast(PackedBuf)); - ORTThreading orth(ThreadPool); - JblasKernel.mProB.packNbitsWeight(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); - if (lastCall) { - JblasKernel.mProB.reduceWeight(&stor, &orth); - } -} - -bool -JblasQ4GemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -) -{ - GetCPUDevice(); - // explicit statement fall through. - switch (CompType) { - case CompInt8: - if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - case CompBf16: - case CompFp16: - case CompFp32: - case CompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - JblasQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool - ); - return true; - } - default: - return false; - } - return false; -} - -bool -JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) -{ - auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); - auto uptr = std::unique_ptr(ptr); - ORTThreading orth(ThreadPool); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto ldb_ = static_cast(ldb); - GetCPUDevice(); - if (ptr) { - if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { - auto NTile = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT - ); - auto CType = jblas::gemm::CoreAttr::get_mask_val( - ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT - ); - if (CType == uint32_t(jblas::gemm::CompType::COMP_FP32)) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } - } - if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_US_INT32)) { - if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } else if (NTile == 
tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } - } - if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_SS_INT32)) { - if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { - static jblas::prologue_b::gemm::WeightKBlockS4 proB; - proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); - } - } - } - return true; - } - return false; -} diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.h b/onnxruntime/core/mlas/lib/jblas_gemm.h deleted file mode 100644 index 044dc5e849a0a..0000000000000 --- a/onnxruntime/core/mlas/lib/jblas_gemm.h +++ /dev/null @@ -1,61 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - jblas_gemm.h - -Abstract: - - Currently only support Q4 gemm. ---*/ - -#pragma once - -#include "mlas_qnbit.h" - -size_t -JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType); - -bool -JblasQ4GemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -); - -bool -JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb - , MLAS_THREADPOOL* ThreadPool); - -bool -JblasSQ4GemmBatchDriver( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - int8_t* WorkSpace, - MLAS_THREADPOOL* ThreadPool -); - -size_t -JblasSQ4GemmBatchWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 7bb8b17031a84..624eb913d5c9e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -193,6 +193,8 @@ class MLASCPUIDInfo bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } + private: MLASCPUIDInfo(); @@ -200,6 +202,7 @@ class MLASCPUIDInfo bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -357,6 +360,20 @@ size_t #else +#if defined(__aarch64__) && defined(__linux__) +typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( + const float* A, + const bfloat16_t* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + const float* Bias +); +#endif + typedef size_t (MLASCALL MLAS_GEMM_FLOAT_KERNEL)( @@ -727,6 +744,10 @@ extern "C" { #else MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; +#if defined(__aarch64__) && defined(__linux__) + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; +#endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; #endif @@ -856,6 +877,10 @@ extern "C" { #define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 +#if defined(__aarch64__) && defined(__linux__) +#define 
MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#endif + // // Single-threaded single precision matrix/matrix multiply operation. // diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 1310ed3f384b9..de092f7d1d350 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -60,6 +60,10 @@ MLASCPUIDInfo::MLASCPUIDInfo() #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #if defined(BUILD_MLAS_NO_ONNXRUNTIME) MLASCPUIDInfo::MLASCPUIDInfo() { @@ -70,6 +74,8 @@ MLASCPUIDInfo::MLASCPUIDInfo() has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); } #endif diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h new file mode 100644 index 0000000000000..de7fd72fad45a --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -0,0 +1,399 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm.h + +Abstract: + + This module defines the set of template functions to implement bfloat16 + precision matrix/matrix multiply operation (SBGEMM). + + To implement a new kernel, template functions below need to be specialized: + MlasSBGemmConvertPackB + MlasSBGemmPackedBOffset + MlasSBGemmPackedBLeadingDim + MlasSBGemmKernel + + MlasSBGemmOperation is the shared kernel driver. + + A kernel type should define the following constants: + bool PackNeeded; Whether B needs to be packed + size_t KernelMaxM; Max # rows the vectorized kernel can process + size_t PackedK; Packed alignment on the K dim (power of 2) + size_t PackedN; Packed alignment on the n dim (power of 2) + MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#pragma once + +#include +#include + +#include "mlasi.h" + +/** + * @brief Define the default striding parameters for + * the bfloat16 precision gemm operation + */ +struct MLAS_SBGEMM_STRIDES { + size_t M; + size_t N; + size_t K; +}; + +/** + * @brief Convert fp32 matrix B to bf16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B in fp32 + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack + */ +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Find the location of PackedB[StartK, StartN] + * + * @tparam KernelType + * @param PackedB + * @param DimN Total columns of the packing buffer + * @param DimK Total rows of the packing buffer + * @param StartN + * @param StartK + * @return Address of PackedB[StartK, StartN] + */ +template +MLAS_FORCEINLINE const bfloat16_t* +MlasSBGemmPackedBOffset( + const bfloat16_t* PackedB, size_t DimN, size_t DimK, size_t StartN, size_t StartK +) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return PackedB + StartK * DimN + StartN; +} + +/** + * @brief leading dimension of the packed B buffer + * Related to how B is packed + * @tparam KernelType + * @param DimN + * @param DimK + * @return leading dimension of the packed B buffer + */ 
+template +MLAS_FORCEINLINE size_t +MlasSBGemmPackedBLeadingDim(size_t DimN, size_t DimK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return DimN; +} + +template +void +MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK, const float* A, const size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode); + +template +MLAS_FORCEINLINE void +MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t PackedStrideN = Strides.N; + size_t PackedStrideK = Strides.K; + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + const size_t SliceStartN = RangeStartN + n; + CountN = std::min(RangeCountN - n, PackedStrideN); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + bool ZeroMode = (k == 0); + CountK = std::min(K - k, PackedStrideK); + + const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN; + float* c = C + n; + const float* pbias = ((nullptr == Bias) ? nullptr : Bias + RangeStartN + n); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, pb, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor) + ->Process(C + n, M, SliceStartN, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + // + // Compute the strides to step through slices of the input matrices. + // + // Expand the N stride if K is small or expand the K stride if N is small + // for better utilization of the B panel. Avoid changing the K stride if + // the A panel needs to be used for transposing. + // + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t StrideN = Strides.N; + size_t StrideK = Strides.K; + + if (N >= K) { + while (StrideK / 2 >= K) { + StrideN *= 2; + StrideK /= 2; + } + } else { + while (StrideN > 16 && StrideN / 2 >= N) { + StrideK *= 2; + StrideN /= 2; + } + } + + constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * sizeof(bfloat16_t)); + MlasThreadedBufAlloc(packBSize); + uint8_t* p = ThreadedBufHolder.get(); + auto* PanelB = reinterpret_cast(p); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < N; n += CountN) { + CountN = std::min(N - n, StrideN); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + CountK = std::min(K - k, StrideK); + + // + // Copy a panel of matrix B to a local packed buffer. + // + MlasSBGemmConvertPackB(PanelB, B + n + k * ldb, ldb, CountN, CountK); + + auto* c = C + n; + const float* pbias = + ((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart + + bool ZeroMode = (k == 0); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? 
pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor)->Process(C + n, M, N, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId) +{ + const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; + const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; + + // + // Partition the operation along the M dimension. + // + size_t RangeStartM; + size_t RangeCountM; + + MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); + + // + // Partition the operation along the N dimension. + // + size_t RangeStartN; + size_t RangeCountN; + + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, &RangeCountN); + + RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + RangeCountN = std::min(N - RangeStartN, RangeCountN); + + // + // Dispatch the partitioned operation. + // + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + const float* A = (const float*)DataParams->A + RangeStartM * lda; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const float* bias = DataParams->Bias; + + if (!DataParams->BIsfp32) { + MlasSBGemmPackedOperation( + RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A, + lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor + ); + } else { + const size_t ldb = DataParams->ldb; + const float* B = (const float*)DataParams->B + RangeStartN; + MlasSBGemmNonPackedOperation(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor); + } +} + +// +// dispatch structure. +// +typedef void(MLAS_SBGEMM_OPERATION)(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId); + +typedef void(MLAS_SBGEMM_CONVERTPACKB_ROUTINE)( + bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Hardware dependent dispatch for half precision GEMM + */ +struct MLAS_SBGEMM_DISPATCH { + MLAS_SBGEMM_OPERATION* Operation; /**< HalfGemm driver */ + MLAS_SBGEMM_CONVERTPACKB_ROUTINE* ConvertPackBRoutine; /**< Convert and pack function for B */ + size_t PackedK; + size_t PackedN; + size_t StrideM; + size_t BufOverRead; +}; + +extern const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon; + +MLAS_FORCEINLINE +const MLAS_SBGEMM_DISPATCH* +MlasSBGemmGetDispatch() +{ +#if defined(MLAS_TARGET_ARM64) + return &MlasSBGemmDispatchNeon; +#else + std::cerr << "SBGemm Kernel is supported only on ARM64 platform."; + exit(1); +#endif +} + +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K) +{ + // + // Compute the number of bytes required to hold the packed buffer. 
+ // + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return 0; + + const auto padding = dispatch->BufOverRead; + const auto PackedK = dispatch->PackedK; + const auto PackedN = dispatch->PackedN; + + const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); + const size_t AlignedN = (N + PackedN - 1) & ~(PackedN - 1); + const size_t BytesRequired = AlignedN * AlignedK * sizeof(bfloat16_t) + padding; + const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); + const size_t AlignedBytesRequired = + (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + + return AlignedBytesRequired; +} + +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + dispatch->ConvertPackBRoutine((bfloat16_t*)PackedB, B, ldb, N, K); +} + +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* Data, MLAS_THREADPOOL* ThreadPool) +{ + const MLAS_SBGEMM_DISPATCH* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + MLAS_SBGEMM_OPERATION* operation = dispatch->Operation; + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. + // + + const double Complexity = double(M) * double(N) * double(K); + + ptrdiff_t TargetThreadCount; + + if (Complexity < double(MLAS_SBGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { + TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; + } else { + TargetThreadCount = GetMlasPlatform().MaximumThreadCount; + } + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + // + // Segment the operation across multiple threads. + // + // N.B. Currently, the operation is segmented as a 1D partition, which + // works okay for operations involving skinny matrices. + // + ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchN - 1) / BatchN; + ptrdiff_t ThreadCountM; + ptrdiff_t ThreadCountN; + + if (N > M) { + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + if (size_t(ThreadsPerGemm) > BlockedN) { + ThreadsPerGemm = ptrdiff_t(BlockedN); + } + + ThreadCountM = 1; + ThreadCountN = ThreadsPerGemm; + + } else { + if (size_t(ThreadsPerGemm) > M) { + ThreadsPerGemm = ptrdiff_t(M); + } + + ThreadCountM = ThreadsPerGemm; + ThreadCountN = 1; + } + + MlasTrySimpleParallel( + ThreadPool, ThreadsPerGemm * static_cast(BatchN), [=](ptrdiff_t tid) { + ptrdiff_t GemmIdx = tid / ThreadsPerGemm; + ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; + operation(ThreadCountM, ThreadCountN, M, N, K, &(Data[GemmIdx]), ThreadIdx); + } + ); +} +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..a6a73996c548b --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp @@ -0,0 +1,362 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. 
+ +Module Name: + + sbgemm_kernel_neon.cpp + +Abstract: + + This module implements bfloat16 precision GEMM kernel for neon. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#include "arm_neon.h" +#include "mlasi.h" +#include "sbgemm.h" + +struct MLAS_SBGEMM_KERNEL_NEON { + static constexpr bool PackNeeded = true; + static constexpr size_t KernelMaxM = 8; // max # rows the vectorized kernel can process + static constexpr size_t PackedK = 4; + static constexpr size_t PackedN = MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + static constexpr MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; // M:N:K +}; + +bool MLASCALL +MlasBf16AccelerationSupported() +{ +#if defined(MLAS_TARGET_ARM64) + return MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_BF16(); +#else + return false; +#endif +} + +/* + This routine converts fp32 to bf16 and copies elements from the source + matrix to the destination packed buffer. + + 4x2 elements from the source matrix are unrolled to be physically + contiguous for better locality inside the SBGEMM kernels. The remaining + rows and columns are padded to 4 and 2 alignment. +*/ +MLAS_FORCEINLINE +void +MlasSBGemmConvertCopyPackB(bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK) +{ + // + // Copy data from matrix B into the destination buffer 4x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const float* b = B; + int y = static_cast(CountK); + + while (y > 0) { + MLAS_FLOAT32X4 t0_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t0_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_h = MlasZeroFloat32x4(); + + if (y >= 4) { + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + t3_l = MlasLoadFloat32x4(&b[ldb * 3]); + t3_h = MlasLoadFloat32x4(&b[ldb * 3 + 4]); + } else { + switch (y) { + case 3: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + break; + case 2: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + break; + case 1: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + break; + } + } + + float32x4x2_t z0_l = vzipq_f32(t0_l, t2_l); + float32x4x2_t z1_l = vzipq_f32(t1_l, t3_l); + float32x4x2_t o0_l = vzipq_f32(z0_l.val[0], z1_l.val[0]); + float32x4x2_t o1_l = vzipq_f32(z0_l.val[1], z1_l.val[1]); + t0_l = o0_l.val[0]; + t1_l = o0_l.val[1]; + t2_l = o1_l.val[0]; + t3_l = o1_l.val[1]; + + bfloat16x8_t t0t1_l_4h = vcvtq_low_bf16_f32(t0_l); + bfloat16x8_t t0t1_l_8h = vcvtq_high_bf16_f32(t0t1_l_4h, t1_l); + + bfloat16x8_t t2t3_l_4h = vcvtq_low_bf16_f32(t2_l); + bfloat16x8_t t2t3_l_8h = vcvtq_high_bf16_f32(t2t3_l_4h, t3_l); + + vst1q_bf16(&D[0], t0t1_l_8h); + vst1q_bf16(&D[8], t2t3_l_8h); + + float32x4x2_t z0_h = vzipq_f32(t0_h, t2_h); + float32x4x2_t z1_h = vzipq_f32(t1_h, t3_h); + float32x4x2_t o0_h = vzipq_f32(z0_h.val[0], z1_h.val[0]); + float32x4x2_t o1_h = 
vzipq_f32(z0_h.val[1], z1_h.val[1]); + t0_h = o0_h.val[0]; + t1_h = o0_h.val[1]; + t2_h = o1_h.val[0]; + t3_h = o1_h.val[1]; + + bfloat16x8_t t0t1_h_4h = vcvtq_low_bf16_f32(t0_h); + bfloat16x8_t t0t1_h_8h = vcvtq_high_bf16_f32(t0t1_h_4h, t1_h); + + bfloat16x8_t t2t3_h_4h = vcvtq_low_bf16_f32(t2_h); + bfloat16x8_t t2t3_h_8h = vcvtq_high_bf16_f32(t2t3_h_4h, t3_h); + + vst1q_bf16(&D[16], t0t1_h_8h); + vst1q_bf16(&D[24], t2t3_h_8h); + + D += 32; + b += ldb * 4; + y -= 4; + }; + B += 8; + CountN -= 8; + } + + // + // Special case the handling of the remaining columns less than 8 elements + // wide. + // + if (CountN > 0) { + int y = static_cast(CountK); + while (y > 0) { + const float* b = B; + size_t b_inc = 0; + if ((CountN & 4) != 0) { + MLAS_FLOAT32X4 t0 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3 = MlasZeroFloat32x4(); + if (y >= 4) { + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + t3 = MlasLoadFloat32x4(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + break; + case 2: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + break; + case 1: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + break; + } + } + + float32x4x2_t z0 = vzipq_f32(t0, t2); + float32x4x2_t z1 = vzipq_f32(t1, t3); + float32x4x2_t o0 = vzipq_f32(z0.val[0], z1.val[0]); + float32x4x2_t o1 = vzipq_f32(z0.val[1], z1.val[1]); + + t0 = o0.val[0]; + t1 = o0.val[1]; + t2 = o1.val[0]; + t3 = o1.val[1]; + + bfloat16x8_t t0t1_4h = vcvtq_low_bf16_f32(t0); + bfloat16x8_t t0t1_8h = vcvtq_high_bf16_f32(t0t1_4h, t1); + + bfloat16x8_t t2t3_4h = vcvtq_low_bf16_f32(t2); + bfloat16x8_t t2t3_8h = vcvtq_high_bf16_f32(t2t3_4h, t3); + + vst1q_bf16(&D[0], t0t1_8h); + vst1q_bf16(&D[8], t2t3_8h); + + D += 16; + b += 4; + b_inc += 4; + } + + if ((CountN & 2) != 0) { + float32x2_t t0 = {0x0, 0x0}; + float32x2_t t1 = {0x0, 0x0}; + float32x2_t t2 = {0x0, 0x0}; + float32x2_t t3 = {0x0, 0x0}; + + if (y >= 4) { + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + t3 = vld1_f32(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + break; + case 2: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + break; + case 1: + t0 = vld1_f32(&b[ldb * 0]); + break; + } + } + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 2; + b_inc += 2; + } + if ((CountN & 1) != 0) { + float a = 0.0f; + float b = 0.0f; + float c = 0.0f; + float d = 0.0f; + + if (y >= 4) { + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + d = *(float*)(&B[ldb * 3 + b_inc]); + } else { + switch (y) { + case 3: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + break; + case 2: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + 
break; + case 1: + a = *(float*)(&B[ldb * 0 + b_inc]); + break; + } + } + + float32x2_t t0 = {a, 0x0}; + float32x2_t t1 = {b, 0x0}; + float32x2_t t2 = {c, 0x0}; + float32x2_t t3 = {d, 0x0}; + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 1; + b_inc += 1; + } + B += 4 * ldb; + y -= 4; + } + } +} + +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + const auto PackedN = dispatch->PackedN; + + const size_t AlignedN = (CountN + PackedN - 1) & ~(PackedN - 1); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t K_block_size; + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + + for (size_t k = 0; k < CountK; k += K_block_size) { + K_block_size = std::min(CountK - k, Strides.K); + + MlasSBGemmConvertCopyPackB((bfloat16_t*)PackedB, B + k * ldb, ldb, CountN, K_block_size); + PackedB = (bfloat16_t*)PackedB + AlignedN * K_block_size; + } +} + +template <> +MLAS_FORCEINLINE void +MlasSBGemmKernel(size_t CountM, size_t CountN, size_t CountK, const float* A, size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode) +{ + while (CountM > 0) { + size_t RowsHandled; + if (ZeroMode) { + RowsHandled = MlasSbgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } else { + RowsHandled = MlasSbgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } + C += ldc * RowsHandled; + A += lda * RowsHandled; + CountM -= RowsHandled; + } +} + +const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon = { + MlasSBGemmOperation, + MlasSBGemmConvertPackB, + MLAS_SBGEMM_KERNEL_NEON::PackedK, + MLAS_SBGEMM_KERNEL_NEON::PackedN, + MLAS_SBGEMM_KERNEL_NEON::KernelMaxM, + 32 // kernel may read beyond buffer end by 32 bytes +}; +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 7d877848017fe..0d8a5692359a6 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -19,10 +19,6 @@ Module Name: #include -#ifdef MLAS_JBLAS -#include "jblas_gemm.h" -#endif - namespace { @@ -694,127 +690,3 @@ MlasSQNBitGemmBatch( ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); }); } - -size_t MLASCALL -MlasNBitsGemmPackBSize( - size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType -) -{ -#ifdef MLAS_JBLAS - if (nbits == 4) { - auto jsize = JblasQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); - if (jsize) { - return jsize; - } - } -#endif - (void)(N); - (void)(K); - (void)(BlkSize); - (void)(nbits); - (void)(isAsym); - (void)(CompType); - return 0; -} - -void MLASCALL -MlasNBitsGemmPackB( - void* PackedBuf, - const uint8_t* QData, - const float* Scale, - const uint8_t* Zp, - size_t N, - size_t K, - size_t ldb, - size_t BlkSize, - int nbits, - bool isAsym, - bool lastCall, - MLAS_SQNBIT_COMPUTE_TYPE CompType, - MLAS_THREADPOOL* ThreadPool -) -{ -#ifdef 
MLAS_JBLAS - if (nbits == 4) { - if (JblasQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { - return; - } - } -#endif - (void)(PackedBuf); - (void)(QData); - (void)(Scale); - (void)(Zp); - (void)(N); - (void)(K); - (void)(ldb); - (void)(BlkSize); - (void)(nbits); - (void)(isAsym); - (void)(lastCall); - (void)(CompType); - (void)(ThreadPool); -} - -void MLASCALL -MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) -{ -#ifdef MLAS_JBLAS - if (JblasQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { - return; - } -#endif - (void)(FpData); - (void)(PackedBuf); - (void)(N); - (void)(K); - (void)(ldb); - (void)(ThreadPool); -} - -size_t MLASCALL -MlasSQNBitsGemmBatchPackedBWorkspaceSize( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams -) -{ -#ifdef MLAS_JBLAS - return JblasSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); -#endif - (void)(M); - (void)(N); - (void)(K); - (void)(BatchN); - (void)(DataParams); - return 0; -} - -void MLASCALL -MlasSQNBitsGemmBatchPackedB( - const size_t M, - const size_t N, - const size_t K, - const size_t BatchN, - const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, - void* WorkSpace, - MLAS_THREADPOOL* ThreadPool -) -{ - GetMlasPlatform(); -#ifdef MLAS_JBLAS - if (JblasSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { - // PackedWeight is created by jblas - return; - } -#endif - (void)(M); - (void)(N); - (void)(K); - (void)(BatchN); - (void)(DataParams); - (void)(WorkSpace); - (void)(ThreadPool); -} diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format deleted file mode 100644 index 84b876706161d..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format +++ /dev/null @@ -1,7 +0,0 @@ -Language: Cpp -BasedOnStyle: Google -DerivePointerAlignment: false -ColumnLimit: 120 -SpaceBeforeParens: ControlStatements -SpaceBeforeRangeBasedForLoopColon: true -SortIncludes: false diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt deleted file mode 100644 index 5d9c5edf45a96..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt +++ /dev/null @@ -1,33 +0,0 @@ -cmake_minimum_required(VERSION 3.5) - -project(jblas LANGUAGES CXX VERSION 0.1.0) - -file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp) -file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp) - -add_library(${PROJECT_NAME} INTERFACE) -add_library(${PROJECT_NAME}::${PROJECT_NAME} ALIAS ${PROJECT_NAME}) - -target_include_directories( - ${PROJECT_NAME} INTERFACE - "$" - "$" -) - -if(WIN32) - target_compile_definitions(${PROJECT_NAME} INTERFACE _CRT_SECURE_NO_WARNINGS NOMINMAX) - target_compile_options(${PROJECT_NAME} INTERFACE /wd4068 /wd4849 /wd6262 /wd4702 /wd4100) - #4068 ignore unroll and GCC flags - #4849 ignore collapse - #6262 ignore stack too large - #4702 unreachable code(false warning on constexpr condition) - #4100 unreferenced formal parameter - - target_link_options(${PROJECT_NAME} INTERFACE /STACK:3145728) #Stack requires up to L2 cache size -endif(WIN32) - - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - -target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17) diff --git 
a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h deleted file mode 100644 index 143adb771760b..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h +++ /dev/null @@ -1,303 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include -#include -#include "xbyak/xbyak.h" -#include "xbyak/xbyak_util.h" - -#define OFFSET(field) offsetof(params, field) - -namespace jblas { - -namespace xbyak { -class JitBase : protected Xbyak::CodeGenerator { - protected: - JitBase(size_t size = 16 * 1024) : CodeGenerator(size) {} - - void load32(const Xbyak::Reg64& reg, const Xbyak::Address& addr) { - xor_(reg, reg); - mov(reg.cvt32(), addr); - } - - void vreg_push(const Xbyak::Reg64& baseaddr) { -#ifdef _WIN32 - for (int i = 0; i < 10; i++) { - movaps(xword[baseaddr + i * 16], Xbyak::Xmm(6 + i)); - } -#endif - } - - void vreg_pop(const Xbyak::Reg64& baseaddr) { -#ifdef _WIN32 - for (int i = 0; i < 10; i++) { - movaps(Xbyak::Xmm(6 + i), xword[baseaddr + i * 16]); - } -#endif - } - - void padto_le(const Xbyak::Reg64& _src, int padding) { - // _src=_src/padding*padding - if (padding == 1) { - return; - } - for (int i = 1; i < 16; i++) { - if ((1 << i) == padding) { - shr(_src, i); - shl(_src, i); - return; - } - } - assert(0); - } - - void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total, - const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { - inLocalLabel(); - lea(_tmp, _total); - sub(_tmp, _pos); - cmp(_tmp, N); - jb(".maskflag"); - cmp(_tmp, 0); - jl(".zeroflag"); - uint64_t allmask = (static_cast(1) << N) - 1; - if (N == 64) { - allmask = static_cast(-1); - } - mov(_tmp, allmask); - kmovq(_msk, _tmp); - jmp(".maskend"); - L(".maskflag"); - mov(_tmp1, 1); - shlx(_tmp1, _tmp1, _tmp); - sub(_tmp1, 1); - kmovq(_msk, _tmp1); - jmp(".maskend"); - L(".zeroflag"); - mov(_tmp1, 0); - kmovq(_msk, _tmp1); - L(".maskend"); - outLocalLabel(); - } - void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Reg64& _total, - const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { - generate_Nbitsmask(_msk, _pos, ptr[_total], _tmp, _tmp1, N); - } -}; - -class JitAvx : protected JitBase { - protected: - static int constexpr VBits = 256; - static int constexpr VecBytes = VBits / 8; - static int constexpr RegCount = 16; - typedef Xbyak::Ymm vreg_t; -}; - -class JitAvx2 : protected JitAvx { - protected: - static int constexpr VBits = 256; - typedef Xbyak::Ymm vreg_t; - void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); } - - void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) { - vpmovzxwd(dst, addr); - vpslld(dst, dst, 16); - } -}; - -class JitAvx512f : protected JitAvx2 { - protected: - static int constexpr VBits = 512; - static int constexpr VecBytes = VBits / 8; - static int constexpr RegCount = 32; - 
typedef Xbyak::Zmm vreg_t; - - void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); } - - void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) { - vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]); - vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]); - vshuff32x4(src_2regs[0], tmp_2reg[0], tmp_2reg[1], 0 | (1 << 2) | (0 << 4) | (1 << 6)); - vshuff32x4(src_2regs[0], src_2regs[0], src_2regs[0], 0 | (2 << 2) | (1 << 4) | (3 << 6)); - vshuff32x4(src_2regs[1], tmp_2reg[0], tmp_2reg[1], 2 | (3 << 2) | (2 << 4) | (3 << 6)); - vshuff32x4(src_2regs[1], src_2regs[1], src_2regs[1], 0 | (2 << 2) | (1 << 4) | (3 << 6)); - } - - void transpose16x16_4B(Xbyak::Zmm* src, Xbyak::Zmm* tmp, const int N = 16) { - for (int i = 0; i < 8; ++i) { - vpunpckldq(tmp[2 * i + 0], src[2 * i], src[2 * i + 1]); - vpunpckhdq(tmp[2 * i + 1], src[2 * i], src[2 * i + 1]); - } - - for (int i = 0; i < 4; ++i) { - vpunpcklqdq(src[4 * i + 0], tmp[4 * i + 0], tmp[4 * i + 2]); - vpunpckhqdq(src[4 * i + 1], tmp[4 * i + 0], tmp[4 * i + 2]); - vpunpcklqdq(src[4 * i + 2], tmp[4 * i + 1], tmp[4 * i + 3]); - vpunpckhqdq(src[4 * i + 3], tmp[4 * i + 1], tmp[4 * i + 3]); - } - - for (int i = 0; i < 2; ++i) { - vshufi32x4(tmp[8 * i + 0], src[8 * i + 0], src[8 * i + 4], 0x88); - vshufi32x4(tmp[8 * i + 1], src[8 * i + 1], src[8 * i + 5], 0x88); - vshufi32x4(tmp[8 * i + 2], src[8 * i + 2], src[8 * i + 6], 0x88); - vshufi32x4(tmp[8 * i + 3], src[8 * i + 3], src[8 * i + 7], 0x88); - vshufi32x4(tmp[8 * i + 4], src[8 * i + 0], src[8 * i + 4], 0xdd); - vshufi32x4(tmp[8 * i + 5], src[8 * i + 1], src[8 * i + 5], 0xdd); - vshufi32x4(tmp[8 * i + 6], src[8 * i + 2], src[8 * i + 6], 0xdd); - vshufi32x4(tmp[8 * i + 7], src[8 * i + 3], src[8 * i + 7], 0xdd); - } - - // last step and move out - for (int i = 0; i < N; ++i) { - vshufi32x4(src[i], tmp[i % 8], tmp[8 + i % 8], i < 8 ? 
0x88 : 0xdd); - } - } - - void interleave_4rows_6regs(Xbyak::Zmm* src_4regs, Xbyak::Zmm* tmp_regs, const Xbyak::Opmask* masks) { - vpunpcklbw(tmp_regs[0], src_4regs[0], src_4regs[1]); - vpunpckhbw(tmp_regs[1], src_4regs[0], src_4regs[1]); - vpunpcklbw(tmp_regs[2], src_4regs[2], src_4regs[3]); - vpunpckhbw(tmp_regs[3], src_4regs[2], src_4regs[3]); - - vpunpcklwd(tmp_regs[4], tmp_regs[0], tmp_regs[2]); - vpunpckhwd(tmp_regs[5], tmp_regs[0], tmp_regs[2]); - vpunpcklwd(tmp_regs[0], tmp_regs[1], tmp_regs[3]); - vpunpckhwd(tmp_regs[2], tmp_regs[1], tmp_regs[3]); - vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (4 << 4) | 4); - vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (4 << 4) | 4); - vmovups(src_4regs[0], tmp_regs[1]); - vshuff32x4(src_4regs[0] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); - vmovups(src_4regs[1], tmp_regs[3]); - vshuff32x4(src_4regs[1] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); - vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (14 << 4) | 14); - vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (14 << 4) | 14); - vmovups(src_4regs[2], tmp_regs[1]); - vshuff32x4(src_4regs[2] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); - vmovups(src_4regs[3], tmp_regs[3]); - vshuff32x4(src_4regs[3] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); - } - - void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { - vpsrld(_fp32, _fp32, 16); - vpmovdw(_bf16, _fp32); - } - - void loadbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Address& addr) { - vpmovzxwd(dst, addr); - vpslld(dst, dst, 16); - } - - void broadcastbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Reg64& tmp, const Xbyak::Address& addr) { - mov(tmp.cvt16(), addr); - shl(tmp.cvt32(), 16); - vpbroadcastd(dst, tmp.cvt32()); - } - - void store_fp32_bf16(const Xbyak::Zmm& _fp32, const Xbyak::Address& _add) { - auto bf16 = Xbyak::Ymm(_fp32.getIdx()); - cvt_fp32_bf16(bf16, _fp32); - vmovups(_add, bf16); - } -}; - -class JitAvx512_bf16 : protected JitAvx512f {}; - -class JitAvx512_fp16 : protected JitAvx512f {}; - -class JitAvx512vnni : protected JitAvx512f { - protected: - void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { - vpdpbusds(x1, x2, op, Xbyak::EvexEncoding); - } -}; - -class JitAvxvnni : protected JitAvx2 { - protected: - void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { - vpdpbusds(x1, x2, op, Xbyak::VexEncoding); - } -}; - -class JitAmxtile : protected JitAvx512f { - public: - struct alignas(64) tileconfig_t { - uint8_t palette_id; - uint8_t reserved[15]; - uint16_t colb[16]; - uint8_t rows[16]; - }; - static int constexpr TileCount = 8; - - typedef long long (*configure_t)(void*); - - static void generate_config(Xbyak::CodeGenerator* g) { - Xbyak::util::StackFrame st(g, 1, 0, 0); - auto& parambase = st.p[0]; - g->ldtilecfg(g->ptr[parambase]); - } - - static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, - int CNum) { - // Filling tile configure structure. Could be done offline. - tc.palette_id = 1; - // Configure C tiles - int t = 0; - for (; t < CNum; ++t) { - tc.rows[t] = static_cast(TILE_M); - tc.colb[t] = static_cast(TILE_N * 4); - } - // Configure A tiles - for (; t < CNum + ANum; ++t) { - tc.rows[t] = static_cast(TILE_M); - tc.colb[t] = static_cast(TILE_K * elesize); - } - // Configure B tile. B effectively has 64 rows and 16 columns. 
- int kpack = 4 / elesize; - for (; t < CNum + ANum + BNum; ++t) { - tc.rows[t] = static_cast(TILE_K / kpack); - tc.colb[t] = static_cast(TILE_N * 4); - } - } -}; - -class JitAmxbf16 : protected JitAmxtile { - protected: - void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { vcvtneps2bf16(_bf16, _fp32); } -}; - -class JitAmxint8 : protected JitAmxtile { - protected: - template - void _tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3); -}; -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbssd(x1, x2, x3); -} -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbsud(x1, x2, x3); -} -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbusd(x1, x2, x3); -} -template <> -inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { - tdpbuud(x1, x2, x3); -} -} // namespace xbyak -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h deleted file mode 100644 index 8ecf3535c17f4..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
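For reference while reading the removed jit_base.h above, here is a scalar sketch (illustrative only, not code from the tree) of what three of its JIT helpers compute: padto_le rounds a value down to a power-of-two multiple, generate_Nbitsmask builds a mask with min(N, total - pos) low bits set (zero once pos passes total), and JitAvx512f::cvt_fp32_bf16 / loadbf16_f32 convert between fp32 and bf16 by dropping or restoring the low 16 mantissa bits. Note that the AMX variant (JitAmxbf16) overrides the downconvert with vcvtneps2bf16, which rounds to nearest-even instead of truncating.

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    // padto_le: round src down to a multiple of padding (padding is a power of two).
    static uint64_t PadToLE(uint64_t src, uint64_t padding)
    {
        return src & ~(padding - 1);
    }

    // generate_Nbitsmask: mask of min(N, total - pos) low bits, or 0 when pos >= total.
    static uint64_t NBitsMask(uint64_t pos, uint64_t total, unsigned N)
    {
        if (pos >= total) return 0;
        uint64_t bits = (total - pos < N) ? (total - pos) : N;
        return bits >= 64 ? ~0ull : (1ull << bits) - 1;
    }

    // cvt_fp32_bf16: vpsrld 16 + vpmovdw, i.e. truncate the low mantissa bits.
    static uint16_t CvtFp32ToBf16Truncate(float v)
    {
        uint32_t bits;
        std::memcpy(&bits, &v, sizeof(bits));
        return static_cast<uint16_t>(bits >> 16);
    }

    // loadbf16_f32: vpmovzxwd + vpslld 16, i.e. place the bf16 bits in the high half.
    static float LoadBf16AsFp32(uint16_t v)
    {
        uint32_t bits = static_cast<uint32_t>(v) << 16;
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }

    int main()
    {
        assert(PadToLE(70, 16) == 64);
        assert(NBitsMask(48, 53, 16) == 0x1f);  // 5 columns left -> 5 low bits
        assert(LoadBf16AsFp32(CvtFp32ToBf16Truncate(1.0f)) == 1.0f);
        return 0;
    }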
-#pragma once -#include -enum JBLAS_CODE { - JblasSuccess = 0, - JblasInvalidParam = 1, - JblasInvalidISA = 2, - JblasRuntimeError = 4, - JblasNotSupport = 8, -}; -enum JBLAS_ISA : uint32_t { - JblasNoSIMD = 0, - JblasAVX, - JblasAVX2, - JblasAVX_VNNI, - JblasAVX512F, - JblasAVX512_VNNI, - JblasAMX_BF16, - JblasAMX_INT8, - JblasAVX512_FP16, - JblasAVX512_BF16, -}; -enum class JBLAS_DTYPE : uint32_t { - EleBitsMask = 0xff, - EleBitsUndef = 0, - EleBits4 = 4, - EleBits8 = 8, - EleBits16 = 16, - EleBits32 = 32, - EleBits64 = 64, - TypeMask = 0xff00, - TypeFloat = 0 << 8, - TypeInt = 1 << 8, - SubTypeMask = 0xff0000, - SubType0 = 0 << 16, - SubType1 = 1 << 16, - SubType2 = 2 << 16, - F64 = EleBits64 | TypeFloat, - F32 = EleBits32 | TypeFloat, - F16 = EleBits16 | TypeFloat, - BF16 = EleBits16 | TypeFloat | SubType1, - F8_E4M3 = EleBits8 | TypeFloat, - F8_E5M2 = EleBits8 | TypeFloat | SubType1, - F8_E3M4 = EleBits8 | TypeFloat | SubType2, - S8 = EleBits8 | TypeInt, - U8 = EleBits8 | TypeInt | SubType1, - S4_CLIP = EleBits4 | TypeInt, - S4_FULLRANGE = EleBits4 | TypeInt | SubType1, - F4_E2M1 = EleBits4 | TypeFloat, - F4_BNB = EleBits4 | TypeFloat | SubType1, - F4_NF4 = EleBits4 | TypeFloat | SubType2, - S32 = EleBits32 | TypeInt, - U32 = EleBits32 | TypeInt | SubType1, -}; - -enum JBLAS_LAYOUT { JblasRowMajor = 101, JblasColMajor = 102 }; -enum JBLAS_TRANSPOSE { - JblasNoTrans = 111, - JblasTrans = 112, - JblasConjTrans = 113, -}; -enum JBLAS_ELTWISEOP { - GELU, - SWISH, - TANH, - EXP, - LOW_PRECISION_EXP, - RELU, - LINEAR, -}; - -enum class JBLAS_PROLOGUEB_IDS : uint32_t { - Undef = (uint32_t)-1, - Begin = 0, - NormalBegin = Begin, - WeightPack = NormalBegin, - NormalEnd, - KBlockBegin = NormalEnd, - WeightKBlockS8 = KBlockBegin, - WeightKBlockS4, - WeightKBlockF4, - KBlockEnd, - End, -}; diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h deleted file mode 100644 index 5cac1080bc610..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
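The JBLAS_DTYPE enum removed above packs three fields into one 32-bit value: the low byte holds the element bit width, the second byte the type class (float vs. int), and the third byte a subtype used to distinguish, e.g., BF16 from F16. A small sketch of decoding that layout with the masks from the enum (constants re-declared here purely for illustration):

    #include <cassert>
    #include <cstdint>

    // Masks mirroring JBLAS_DTYPE::EleBitsMask / TypeMask / TypeFloat / TypeInt / SubType1.
    static constexpr uint32_t kEleBitsMask = 0xff;
    static constexpr uint32_t kTypeMask = 0xff00;
    static constexpr uint32_t kTypeFloat = 0u << 8;
    static constexpr uint32_t kTypeInt = 1u << 8;
    static constexpr uint32_t kSubType1 = 1u << 16;

    static uint32_t EleBits(uint32_t dtype) { return dtype & kEleBitsMask; }
    static bool IsFloat(uint32_t dtype) { return (dtype & kTypeMask) == kTypeFloat; }

    int main()
    {
        constexpr uint32_t BF16 = 16u | kTypeFloat | kSubType1;  // EleBits16 | TypeFloat | SubType1
        constexpr uint32_t S8 = 8u | kTypeInt;                   // EleBits8 | TypeInt
        assert(EleBits(BF16) == 16 && IsFloat(BF16));
        assert(EleBits(S8) == 8 && !IsFloat(S8));
        return 0;
    }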
-#pragma once -#include "jit_blas.h" -#include "xbyak/xbyak_util.h" - -namespace jblas { - -namespace device { - -struct X64_ISA { - int64_t MMX : 1; // 0 - int64_t SSE : 1; // 1 - int64_t SSE2 : 1; // 2 - int64_t SSE3 : 1; // 3 - int64_t SSSE3 : 1; // 4 - int64_t SSE41 : 1; // 5 - int64_t SSE42 : 1; // 6 - int64_t AVX : 1; // 7 - int64_t F16C : 1; // 8 - int64_t FMA : 1; // 9 - int64_t AVX2 : 1; // 10 - int64_t AVX_VNNI : 1; // 11 - int64_t AVX_VNNI_INT8 : 1; // 12 - int64_t AVX_NE_CONVERT : 1; // 13 - int64_t AVX_IFMA : 1; // 14 - int64_t AVX512F : 1; // 15 - int64_t AVX512BW : 1; // 16 - int64_t AVX512CD : 1; // 17 - int64_t AVX512DQ : 1; // 18 - int64_t AVX512ER : 1; // 19 - int64_t AVX512IFMA52 : 1; // 20 - int64_t AVX512PF : 1; // 21 - int64_t AVX512VL : 1; // 22 - int64_t AVX512VPOPCNTDQ : 1; // 23 - int64_t AVX512_4FMAPS : 1; // 24 - int64_t AVX512_4VNNIW : 1; // 25 - int64_t AVX512_BF16 : 1; // 26 - int64_t AVX512_BITALG : 1; // 27 - int64_t AVX512_VBMI : 1; // 28 - int64_t AVX512_VBMI2 : 1; // 29 - int64_t AVX512_VNNI : 1; // 30 - int64_t AVX512_VP2INTERSECT : 1; // 31 - int64_t AVX512_FP16 : 1; // 32 - int64_t AMX_TILE : 1; // 33 - int64_t AMX_BF16 : 1; // 34 - int64_t AMX_INT8 : 1; // 35 - int64_t AMX_FP16 : 1; // 36 - int64_t AMX_COMPLEX : 1; // 37 - int64_t reserved : (64 - 38); -}; - -class AVX2_Default { - public: - static constexpr bool MMX = 1; - static constexpr bool SSE = 1; - static constexpr bool SSE2 = 1; - static constexpr bool SSE3 = 1; - static constexpr bool SSSE3 = 1; - static constexpr bool SSE41 = 1; - static constexpr bool SSE42 = 1; - static constexpr bool AVX = 1; - static constexpr bool F16C = 1; - static constexpr bool FMA = 1; - static constexpr bool AVX2 = 1; - static constexpr bool AVX_VNNI = 0; - static constexpr bool AVX_VNNI_INT8 = 0; - static constexpr bool AVX_NE_CONVERT = 0; - static constexpr bool AVX_IFMA = 0; - static constexpr bool AVX512F = 0; - static constexpr bool AVX512BW = 0; - static constexpr bool AVX512CD = 0; - static constexpr bool AVX512DQ = 0; - static constexpr bool AVX512ER = 0; - static constexpr bool AVX512IFMA52 = 0; - static constexpr bool AVX512PF = 0; - static constexpr bool AVX512VL = 0; - static constexpr bool AVX512VPOPCNTDQ = 0; - static constexpr bool AVX512_4FMAPS = 0; - static constexpr bool AVX512_4VNNIW = 0; - static constexpr bool AVX512_BF16 = 0; - static constexpr bool AVX512_BITALG = 0; - static constexpr bool AVX512_VBMI = 0; - static constexpr bool AVX512_VBMI2 = 0; - static constexpr bool AVX512_VNNI = 0; - static constexpr bool AVX512_VP2INTERSECT = 0; - static constexpr bool AVX512_FP16 = 0; - static constexpr bool AMX_TILE = 0; - static constexpr bool AMX_BF16 = 0; - static constexpr bool AMX_INT8 = 0; - static constexpr bool AMX_FP16 = 0; - static constexpr bool AMX_COMPLEX = 0; -}; - -class AVX512_VNNI_Default { - public: - static constexpr bool MMX = 1; - static constexpr bool SSE = 1; - static constexpr bool SSE2 = 1; - static constexpr bool SSE3 = 1; - static constexpr bool SSSE3 = 1; - static constexpr bool SSE41 = 1; - static constexpr bool SSE42 = 1; - static constexpr bool AVX = 1; - static constexpr bool F16C = 1; - static constexpr bool FMA = 1; - static constexpr bool AVX2 = 1; - static constexpr bool AVX_VNNI = 0; - static constexpr bool AVX_VNNI_INT8 = 0; - static constexpr bool AVX_NE_CONVERT = 0; - static constexpr bool AVX_IFMA = 0; - static constexpr bool AVX512F = 1; - static constexpr bool AVX512BW = 1; - static constexpr bool AVX512CD = 1; - static constexpr bool AVX512DQ = 1; - 
static constexpr bool AVX512ER = 0; - static constexpr bool AVX512IFMA52 = 0; - static constexpr bool AVX512PF = 0; - static constexpr bool AVX512VL = 1; - static constexpr bool AVX512VPOPCNTDQ = 0; - static constexpr bool AVX512_4FMAPS = 0; - static constexpr bool AVX512_4VNNIW = 0; - static constexpr bool AVX512_BF16 = 0; - static constexpr bool AVX512_BITALG = 0; - static constexpr bool AVX512_VBMI = 0; - static constexpr bool AVX512_VBMI2 = 0; - static constexpr bool AVX512_VNNI = 1; - static constexpr bool AVX512_VP2INTERSECT = 0; - static constexpr bool AVX512_FP16 = 0; - static constexpr bool AMX_TILE = 0; - static constexpr bool AMX_BF16 = 0; - static constexpr bool AMX_INT8 = 0; - static constexpr bool AMX_FP16 = 0; - static constexpr bool AMX_COMPLEX = 0; -}; - -class SapphireRapids { - public: - static constexpr bool MMX = 1; - static constexpr bool SSE = 1; - static constexpr bool SSE2 = 1; - static constexpr bool SSE3 = 1; - static constexpr bool SSSE3 = 1; - static constexpr bool SSE41 = 1; - static constexpr bool SSE42 = 1; - static constexpr bool AVX = 1; - static constexpr bool F16C = 1; - static constexpr bool FMA = 1; - static constexpr bool AVX2 = 1; - static constexpr bool AVX_VNNI = 0; - static constexpr bool AVX_VNNI_INT8 = 0; - static constexpr bool AVX_NE_CONVERT = 0; - static constexpr bool AVX_IFMA = 0; - static constexpr bool AVX512F = 1; - static constexpr bool AVX512BW = 1; - static constexpr bool AVX512CD = 1; - static constexpr bool AVX512DQ = 1; - static constexpr bool AVX512ER = 0; - static constexpr bool AVX512IFMA52 = 0; - static constexpr bool AVX512PF = 0; - static constexpr bool AVX512VL = 1; - static constexpr bool AVX512VPOPCNTDQ = 0; - static constexpr bool AVX512_4FMAPS = 0; - static constexpr bool AVX512_4VNNIW = 0; - static constexpr bool AVX512_BF16 = 0; - static constexpr bool AVX512_BITALG = 0; - static constexpr bool AVX512_VBMI = 0; - static constexpr bool AVX512_VBMI2 = 0; - static constexpr bool AVX512_VNNI = 1; - static constexpr bool AVX512_VP2INTERSECT = 0; - static constexpr bool AVX512_FP16 = 0; - static constexpr bool AMX_TILE = 1; - static constexpr bool AMX_BF16 = 1; - static constexpr bool AMX_INT8 = 1; - static constexpr bool AMX_FP16 = 0; - static constexpr bool AMX_COMPLEX = 0; -}; - -template -class isa_base { - public: - static bool constexpr avx = ISA_T >= JblasAVX; - static bool constexpr avx2 = ISA_T >= JblasAVX2; - static bool constexpr avx512f = ISA_T >= JblasAVX512F; - static bool constexpr avx512_vnni = ISA_T >= JblasAVX512_VNNI; - static bool constexpr avx512_fp16 = ISA_T >= JblasAVX512_FP16; - static bool constexpr amx_bf16 = ISA_T >= JblasAMX_BF16; - static bool constexpr amx_int8 = ISA_T >= JblasAMX_INT8; -}; - -class CpuDevice { - public: - inline void setThreads(int _nth) { - if (_nth <= 0) { - numthreads = numcores; - } else { - numthreads = std::min(numcores, _nth); - } - } - inline int getThreads() { return numthreads; } - inline int getCores() { return numcores; } - inline uint32_t getL2CacheSize() { return L2Cache; } - inline uint32_t getL1CacheSize() { return L1Cache; } - inline bool AVX() { return mHasAVX; } - inline bool AVX2() { return mHasAVX2; } - inline bool AVX_VNNI() { return mHasAVX_VNNI; } - inline bool AVX512F() { return mHasAVX512F; } - inline bool AVX512_VNNI() { return mHasAVX512_VNNI; } - inline bool AMX_INT8() { return mHasAMX_INT8; } - inline bool AMX_BF16() { return mHasAMX_BF16; } - inline bool AVX512_BF16() { return mHasAVX512_BF16; } - inline bool AVX512_FP16() { return 
mHasAVX512_FP16; } -#define ADD_FLAG(isa) mHas##isa = _cpu.has(_cpu.t##isa) - CpuDevice() { - static Xbyak::util::Cpu _cpu; - L1Cache = _cpu.getDataCacheSize(0); - L2Cache = _cpu.getDataCacheSize(1); - ADD_FLAG(AVX); - ADD_FLAG(AVX2); - ADD_FLAG(AVX512F); - ADD_FLAG(AVX512_VNNI); - ADD_FLAG(AVX_VNNI); - ADD_FLAG(AMX_BF16); - ADD_FLAG(AMX_INT8); - ADD_FLAG(AVX512_BF16); - ADD_FLAG(AVX512_FP16); - numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); - numthreads = numcores; - } - - static CpuDevice* getInstance() { - static CpuDevice instance; - return &instance; - } - - void print() { - printf( - "AVX:%d AVX2:%d AVX512F:%d AVX_VNNI:%d AVX512_VNNI:%d AMX_INT8:%d AMX_BF16:%d AVX512_BF16:%d AVX512_FP16:%d\n", - mHasAVX, mHasAVX2, mHasAVX512F, mHasAVX_VNNI, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512_BF16, - mHasAVX512_FP16); - } -#undef ADD_FLAG - - protected: - uint32_t L2Cache, L1Cache; - bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512_BF16, - mHasAVX512_FP16; - int numcores; - int numthreads; -}; - -#define GetCPUDevice() auto _cd = jblas::device::CpuDevice::getInstance(); - -class CpuBase { - public: - CpuBase() { - GetCPUDevice(); - mL2Cache = _cd->getL2CacheSize(); - mL1Cache = _cd->getL1CacheSize(); - mNumThreads = _cd->getThreads(); - } - size_t mL2Cache, mL1Cache; - int mNumThreads; -}; -} // namespace device -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h deleted file mode 100644 index ceb7a545092d8..0000000000000 --- a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h +++ /dev/null @@ -1,329 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include "jit_base.h" -#include "jit_blas.h" -#include "jit_blas_utils.h" -#include "kernel_wrapper.h" - -namespace jblas { -namespace epilogue { -namespace gemm { - -template -class AccumulatorWriteBack { - public: - using SType = _SRC_T; - using DType = _DST_T; - struct Param { - DType* C; - int ldc; - void* elt_const_v; - }; - - template - JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize, Eltops... 
ops) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - bool constexpr Valid = !std::is_same::value || std::is_same::value; - static_assert(Valid, "fp32 to bf16 conversion only."); - if constexpr (std::is_same::value) { - return kernel::wrapper::Memcpy2DFp32CvtBf16::template forward( - const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); - } else if constexpr (std::is_same, std::tuple>::value) { - return kernel::wrapper::Memcpy2DFp16CvtFp32::template forward( - const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false); - } else if constexpr (sizeof(SType) == sizeof(DType)) { - return kernel::wrapper::Memcpy2D::template forward(cacheptr, cptr, M, N, cachestep, - _param.ldc, _param.elt_const_v, ops...); - } else { - assert(false); - } - } -}; - -template -class CustomAccumulatorWriteBackWithEltop { - public: - struct Param { - _DST_T* C; - int ldc; - void* elt_const_v; - }; - JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - if constexpr (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) { - return kernel::wrapper::Memcpy2D::template forward1(cacheptr, cptr, M, N, cachestep, - _param.ldc, _param.elt_const_v); - } else { - assert(false); - } - } -}; -template -using AccumulatorWriteBackFp32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackInt32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackBf16 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp16 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack; -template -using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack; - -template -using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop; - -template -using AccumulatorWriteBackWithSwishFp32 = CustomAccumulatorWriteBackWithEltop; - -template -class AlphaBetaProcessFp32 { - public: - struct Param { - float *C, *D; - int ldc, ldd; - float alpha, beta; - }; - - JBLAS_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto DOffset = M_offset * _param.ldd + N_offset; - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - auto dptr = _param.D + DOffset; - return kernel::wrapper::AlphaBetaF32F32::template forward(_param.alpha, cacheptr, cachestep, _param.beta, - dptr, _param.ldd, cptr, _param.ldc, M, N); - } -}; - -template -class CompFp32BlockEpilogue { - public: - struct Param { - void* scales; - JBLAS_DTYPE scaledtype; - int ldsb; - int8_t* zps = nullptr; - float* reduce = nullptr; - int ldra; - }; - JBLAS_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, - const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, - size_t cachesize) { - auto ret = JblasNotSupport; - if (_param.scaledtype == JBLAS_DTYPE::F32) { - ret = kernel::wrapper::CompFp32BlockScale::template forward( - reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, - cachestep, M, N); - assert(ret == JblasSuccess); - if (_param.zps != nullptr) { - ret = 
kernel::wrapper::RemoveZeroPointBias::forward_wei( - dstptr, cachestep, M, N, _param.zps + K_offset * _param.ldsb + N_offset, - reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, _param.ldra, - _param.reduce + M_offset * _param.ldra + K_offset); - } - assert(ret == JblasSuccess); - return ret; - } else if (_param.scaledtype == JBLAS_DTYPE::BF16) { - ret = kernel::wrapper::CompFp32BlockScale::template forward( - reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, - cachestep, M, N); - assert(_param.zps == nullptr); - assert(ret == JblasSuccess); - return ret; - } - return JblasNotSupport; - } -}; - -template -class DequantInt32ToFp32 { - public: - struct Param { - float* C; - int ldc; - int ldsa; - float* scalesA; - float* scalesB; - }; - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - return kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, - _param.scalesA + M_offset * _param.ldsa, _param.ldsa, - _param.scalesB + N_offset); - } -}; - -template -class CompInt8BlockEpilogue { - public: - struct Param { - void* scalesB; - JBLAS_DTYPE scaleBdtype; - int ldsb; - float* scalesA; - int ldsa; - // optional if A asym - uint8_t* zpA = nullptr; - void* reduceB = nullptr; - JBLAS_DTYPE reduceBdtype = JBLAS_DTYPE::F32; - // optional if B asym - int8_t* zpB = nullptr; - float* reduceA = nullptr; - int K = 1; - }; - JBLAS_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, - const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, - size_t cachesize) { - JBLAS_CODE ret = JblasNotSupport; - float* scab = nullptr; - size_t ScaleBTmpSize = N * sizeof(float); - size_t ReduceBTmpSize = N * sizeof(float); - assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize)); - if (_param.scaleBdtype == JBLAS_DTYPE::BF16) { - auto scache = reinterpret_cast(tmpcache); - ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( - reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N, - false); - assert(ret == JblasSuccess); - scab = scache; - } else if (_param.scaleBdtype == JBLAS_DTYPE::F32) { - scab = reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb; - } - float* redb = nullptr; - if (_param.reduceB) { - if (_param.reduceBdtype == JBLAS_DTYPE::BF16) { - auto rcache = reinterpret_cast(reinterpret_cast(tmpcache) + ScaleBTmpSize); - ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( - reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N, - false); - assert(ret == JblasSuccess); - redb = rcache; - } else if (_param.reduceBdtype == JBLAS_DTYPE::F32) { - redb = reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb; - } - } - ret = kernel::wrapper::DequanS32Fp32::template forward( - srcptr, cachestep, reinterpret_cast(const_cast(srcptr)), cachestep, M, N, - _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab); - assert(ret == JblasSuccess); - ret = kernel::wrapper::AccumulateFp32::template forward(reinterpret_cast(srcptr), cachestep, - dstptr, cachestep, M, N); - assert(ret == JblasSuccess); - - if (_param.zpA == nullptr) { - if (_param.zpB == nullptr) { - return ret; - } else { - 
ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( - dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, scab, _param.ldsa, - _param.reduceA + M_offset * _param.ldsa + K_offset); - } - } else { - if (_param.zpB == nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( - dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, - _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, redb); - } else { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( - dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, - _param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, scab, - _param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa + K_offset, redb); - } - } - return ret; - } -}; - -template -class ZpDequantInt32ToFp32 { - public: - struct Param { - // necessary - float* C; - int ldc; - int ldsa; - float* scalesA; - float* scalesB; - // optional if A asym - uint8_t* zpA = nullptr; - float* reduceB = nullptr; - // optional if B asym - int8_t* zpB = nullptr; - float* reduceA = nullptr; - int K = 1; - }; - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - auto ret = kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, - _param.scalesA + M_offset * _param.ldsa, - _param.ldsa, _param.scalesB + N_offset); - if (ret != JblasSuccess) { - return ret; - } - if (_param.zpA == nullptr && _param.zpB == nullptr) { - return ret; - } else if (_param.zpA != nullptr && _param.zpB == nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( - cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa, - _param.ldsa, _param.reduceB + N_offset); - } else if (_param.zpA == nullptr && _param.zpB != nullptr) { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( - cptr, _param.ldc, M, N, _param.zpB + N_offset, _param.scalesB + N_offset, _param.ldsa, - _param.reduceA + M_offset * _param.ldsa); - } else { - ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( - cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.zpB + N_offset, - _param.scalesA + M_offset * _param.ldsa, _param.scalesB + N_offset, _param.ldsa, _param.K, - _param.reduceA + M_offset * _param.ldsa, _param.reduceB + N_offset); - } - return ret; - } -}; - -template -class AlphaBetaProcessS32U8 { - public: - struct Param { - uint8_t* C; - int ldc; - float alpha; - float scaleAcc, scaleC; - int zpC; - }; - - JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, - const int N, const Param& _param, void* tmpcache, size_t cachesize) { - auto COffset = M_offset * _param.ldc + N_offset; - auto cptr = _param.C + COffset; - return kernel::wrapper::QuanOutS32U32::template forward(_param.alpha, cacheptr, cachestep, cptr, _param.ldc, - M, N, _param.scaleAcc, _param.scaleC, _param.zpC); - } -}; - -} // namespace gemm -} // namespace epilogue -} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h deleted file mode 100644 index 364da9223940f..0000000000000 --- 
a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h +++ /dev/null @@ -1,2699 +0,0 @@ -// Copyright (c) 2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once -#include - -#include "jit_blas_utils.h" -#include "jit_base.h" - -namespace jblas { -namespace gemm { -enum class CompType : uint32_t { - COMP_FP32 = 0, - COMP_BF16_FP32 = 1, - COMP_FP16_FP16 = 2, - COMP_INT_START = 3, - COMP_INT8_US_INT32 = COMP_INT_START, - COMP_INT8_UU_INT32 = 4, - COMP_INT8_SS_INT32 = 5, - COMP_INT8_SU_INT32 = 6, - COMP_INT16_SS_INT32 = 7, - COMP_INT8_US_FP32 = 8, - COMP_INT8_UU_FP32 = 9, - COMP_INT8_SS_FP32 = 10, - COMP_INT8_SU_FP32 = 11, -}; - -class CoreAttr { - public: - // INT32=LSB|**8bits:NTile**||**8bits:PackRow**||**8bits:CompType**||**8bits:Reserve**| - static uint32_t constexpr NTILE_MASK = 0xff, NTILE_SHIFT = 0, PACKROW_MASK = 0xff00, PACKROW_SHIFT = 8, - COMP_MASK = 0xff0000, COMP_SHIFT = 16, ISA_MASK = 0xff000000, ISA_SHIFT = 24; - - static inline uint32_t get_mask_val(uint32_t raw, uint32_t mask, uint32_t shift) { return (raw & mask) >> shift; } - static constexpr uint32_t make_core_id(uint32_t NTile, uint32_t PackRow, uint32_t CompType, uint32_t ISA) { - return (NTile << NTILE_SHIFT) | (PackRow << PACKROW_SHIFT) | (CompType << COMP_SHIFT) | (ISA << ISA_SHIFT); - } - - static void parse_id(uint32_t id, uint32_t* vals) { - vals[0] = get_mask_val(id, NTILE_MASK, NTILE_SHIFT); - vals[1] = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); - vals[2] = get_mask_val(id, COMP_MASK, COMP_SHIFT); - vals[3] = get_mask_val(id, ISA_MASK, ISA_SHIFT); - } - - static const char* to_str(uint32_t id) { - static char tmp[128]; - uint32_t vals[4]; - parse_id(id, vals); - sprintf(tmp, "N%d_PACK%d_COMP%d_ISA%d", vals[0], vals[1], vals[2], vals[3]); - return tmp; - } - - static inline size_t get_bsize(uint32_t id) { - auto packrow = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); - return size_t(4 / packrow); - } -}; - -namespace code { - -template -class Avx2N8P1 : protected jblas::xbyak::JitAvx2 { - public: - static int constexpr RegLen = 8, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX2; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; - typedef float AType; - typedef float BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - 
add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { - public: - static int constexpr RegLen = 16, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; - typedef float AType; - typedef float BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * 
BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512fp16N32P1 : protected jblas::xbyak::JitAvx512_fp16 { - public: - static int constexpr RegLen = 32, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_FP16; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP16_FP16; - typedef utils::fp16 AType; - typedef utils::fp16 BType; - typedef utils::fp16 CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - 
add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastw(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vpbroadcastw(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512bf16N16P2 : protected jblas::xbyak::JitAvx512_bf16 { - public: - static int constexpr RegLen = 16, PackRow = 2; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 2; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_BF16; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; - typedef utils::bf16 AType; - typedef utils::bf16 BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - 
add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { - public: - static int constexpr RegLen = 16, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; - typedef uint8_t AType; - typedef int8_t BType; - typedef int32_t CType; - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - - protected: - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - 
cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _kunroll) { - for (int kk = 0; kk < _kunroll; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { - public: - static int constexpr RegLen = 8, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX_VNNI; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; - typedef uint8_t AType; - typedef int8_t BType; - typedef int32_t CType; - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - private: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - protected: - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - 
add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _kunroll) { - for (int kk = 0; kk < _kunroll; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Amxbf16N16P2 : protected jblas::xbyak::JitAmxbf16 { - public: - static int constexpr RegLen = 16, PackRow = 2; - static_assert(_NTILE % RegLen == 0); - static_assert(_MTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
1 : _MTILE / RegLen; - static_assert(NRegs * MRegs + 2 <= TileCount); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 32; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_BF16; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; - typedef utils::bf16 AType; - typedef utils::bf16 BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - void* workspace; - }; - typedef long long (*func_t)(params*); - - int TmpRegCount = RegCount; - int TmpReg = 0; - int CTileCount = 0, ATileCount = 0, BTileCount = 0; - int CTile = 0, ATile = 0, BTile = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_tmp3; - Xbyak::Reg64 reg_ret = rax; - - void assign_regs() { - CTileCount = NRegs * MRegs; - auto tile_re = TileCount - CTileCount; - if (tile_re - 1 >= NRegs) { - BTileCount = NRegs; - ATileCount = tile_re - BTileCount; - } else if (tile_re - 1 >= MRegs) { - ATileCount = MRegs; - BTileCount = tile_re - ATileCount; - } else { - ATileCount = 1; - BTileCount = tile_re - ATileCount; - } - CTile = 0; - ATile = CTile + CTileCount; - BTile = ATile + ATileCount; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_tmp3 = st.t[10]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - 
generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int kunrll) { - auto& reg_Bstride = reg_tmp1; - mov(reg_Bstride, NTILE * 4); - int mtiles = _mtile / RegLen; - - for (int kk = 0; kk < kunrll; kk++) { - auto& reg_Atmp = reg_tmp2; - if (mtiles == 1) { - reg_Atmp = reg_matAptr; - } else { - mov(reg_Atmp, reg_matAptr); - } - if (BTileCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - } - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } else { - if (ATileCount == mtiles) { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - for (int mm = 0; mm < mtiles; mm++) { - tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); - } - } - } else { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < CTileCount; i++) { - tilezero(Xbyak::Tmm(CTile + i)); - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - int mtnum = _mtile / 16; - for (int mm = 0; mm < mtnum; mm++) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); - } - if (mm != mtnum - 1) { - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - } - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_tmp, dword[parambase + OFFSET(workspace)]); - mov(reg_tmp1, NTILE * 4); - for (int mm = 0; mm < MRegs; mm++) { - for (int i = 0; i < NRegs; i++) { - tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); - } - } - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - int zunroll = TmpRegCount / NRegs; - for (int i = 0; i < _mtile; i += zunroll) { - int 
m_re = utils::remainsize(i, _mtile, zunroll); - for (int im = 0; im < m_re; im++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - } - outLocalLabel(); - } -}; - -template -class Amxint8N16P4 : protected jblas::xbyak::JitAmxint8 { - public: - static int constexpr RegLen = 16, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static_assert(_MTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen; - static_assert(NRegs * MRegs + 2 <= TileCount); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 64; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_INT8; - static uint32_t constexpr COMPUTE = - (uint32_t)(std::is_same_v - ? std::is_same_v ? CompType::COMP_INT8_SS_INT32 : CompType::COMP_INT8_SU_INT32 - : std::is_same_v ? CompType::COMP_INT8_US_INT32 - : CompType::COMP_INT8_UU_INT32); - using AType = AT; - using BType = BT; - typedef int32_t CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - void* workspace; - }; - typedef long long (*func_t)(params*); - - int TmpRegCount = RegCount; - int TmpReg = 0; - int CTileCount = 0, ATileCount = 0, BTileCount = 0; - int CTile = 0, ATile = 0, BTile = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_tmp3; - Xbyak::Reg64 reg_ret = rax; - - void assign_regs() { - CTileCount = NRegs * MRegs; - auto tile_re = TileCount - CTileCount; - if (tile_re - 1 >= NRegs) { - BTileCount = NRegs; - ATileCount = tile_re - BTileCount; - } else if (tile_re - 1 >= MRegs) { - ATileCount = MRegs; - BTileCount = tile_re - ATileCount; - } else { - ATileCount = 1; - BTileCount = tile_re - ATileCount; - } - CTile = 0; - ATile = CTile + CTileCount; - BTile = ATile + ATileCount; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_tmp3 = st.t[10]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, 
ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int kunrll) { - auto& reg_Bstride = reg_tmp1; - mov(reg_Bstride, NTILE * 4); - int mtiles = _mtile / RegLen; - - for (int kk = 0; kk < kunrll; kk++) { - auto& reg_Atmp = reg_tmp2; - if (mtiles == 1) { - reg_Atmp = reg_matAptr; - } else { - mov(reg_Atmp, reg_matAptr); - } - if (BTileCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - } - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } else { - if (ATileCount == mtiles) { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - for (int mm = 0; mm < mtiles; mm++) { - _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); - } - } - } else { - for (int mm = 0; mm < mtiles; mm++) { - tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); - _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); - } - if (mm != mtiles - 1) { - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); - } - } - } - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < CTileCount; i++) { - tilezero(Xbyak::Tmm(CTile + i)); - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - int mtnum = _mtile / 16; - for (int mm = 0; mm < mtnum; mm++) { - for (int i = 0; i < NRegs; i++) { - tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), 
ptr[reg_matCptr + reg_cstride + i * 64]); - } - if (mm != mtnum - 1) { - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); - } - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_tmp, dword[parambase + OFFSET(workspace)]); - mov(reg_tmp1, NTILE * 4); - for (int mm = 0; mm < MRegs; mm++) { - for (int i = 0; i < NRegs; i++) { - tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); - } - } - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - int zunroll = TmpRegCount / NRegs; - for (int i = 0; i < _mtile; i += zunroll) { - int m_re = utils::remainsize(i, _mtile, zunroll); - for (int im = 0; im < m_re; im++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - } - outLocalLabel(); - } -}; -template -using Amxint8N16P4US = Amxint8N16P4; - -template -using Amxint8N16P4SS = Amxint8N16P4; - -class AmxConfigure : protected jblas::xbyak::JitAmxtile { - public: - typedef long long (*func_t)(tileconfig_t*); - - static void configure(int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, int CNum) { - static AmxConfigure code; - tileconfig_t cfg; - std::memset(&cfg, 0, sizeof(cfg)); - configure_tiles(cfg, TILE_M, TILE_N, TILE_K, elesize, ANum, BNum, CNum); - code.mKernel(&cfg); - } - - protected: - AmxConfigure() { - generate_config(this); - mKernel = getCode(); - } - - func_t mKernel = nullptr; -}; - -namespace kblock { -// optimize for kblock gemm, each block size in k dimension has dequant operation -// all accumulators use fp32 dtype. -template -class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { - public: - static int constexpr RegLen = 16, PackRow = 1; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; - typedef float AType; - typedef float BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - int k; - int n; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_ret = rax; - Xbyak::Opmask msk_wr = k1; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = RegCount - ARegCount - CRegCount; - if (BRegCount < NRegs) { - BRegCount = 0; - ARegCount = BRegCount + 1; - } - if (BRegCount > NRegs) { - BRegCount = NRegs; - } - CReg = 0; - BReg = CReg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg <= RegCount); - TmpRegCount = RegCount - TmpReg; - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - mov(reg_tmp, reg_ksize); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kloop", T_NEAR); - L(".unkloop"); - generate_fma(_mtile, KUNROLL); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_iterk, KUNROLL * KTILE); - cmp(reg_iterk, reg_tmp); // k iteration variable - jb(".unkloop"); - cmp(reg_tmp, reg_ksize); - jge(".kend", T_NEAR); - L(".kloop"); - generate_fma(_mtile, 1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * 
BKStepSize); - add(reg_iterk, 1 * KTILE); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - L(".kend"); - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile) { - for (int kk = 0; kk < _ktile; kk++) { - lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); - if (BRegCount == NRegs) { - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } else if (BRegCount == 0) { - for (int mm = 0; mm < _mtile; mm += ARegCount) { - int mm_re = utils::remainsize(mm, _mtile, ARegCount); - for (int imm = 0; imm < mm_re; imm++) { - vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), - ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - } - } - } else { - assert(0); - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -template -class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { - public: - static int constexpr RegLen = 16, PackRow = 4; - static_assert(_NTILE % RegLen == 0); - static int constexpr NRegs = _NTILE / RegLen; - static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1 - NRegs) / (NRegs * 2) : _MTILE; - static_assert(NRegs * MRegs <= RegCount - 1); - static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; - static int constexpr KUNROLL = 2; - static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; - static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_FP32; - typedef uint8_t AType; - typedef int8_t BType; - typedef float CType; - - struct params { - AType* matA; - int astride; - BType* matB; - int bstride; - CType* matC; - int cstride; - uint8_t* zpA; - float* scaleA; - int ldsa; - float* scaleB; - float* reduceB; - int ldsb; - int k; - int n; - int kblock; - int init; - }; - typedef long long (*func_t)(params*); - - int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; - int CReg = 0, CF32Reg = 0, BReg = 0, AReg = 0, TmpReg = 0; - static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); - static int constexpr AKStepSize = KTILE * sizeof(AType); - - void generate_code(int _mtile) { - assign_regs(); - reset(); - generate_mtile(_mtile); - ready(); - mKernel = getCode(); - } - func_t mKernel = nullptr; - - protected: - Xbyak::Reg64 parambase; - Xbyak::Reg64 reg_matAptr; - Xbyak::Reg64 reg_matBptr; - Xbyak::Reg64 reg_matCptr; - Xbyak::Reg64 reg_ksize; - Xbyak::Reg64 reg_nsize; - Xbyak::Reg64 reg_cstride; - Xbyak::Reg64 reg_astride; - Xbyak::Reg64 reg_iterk; - Xbyak::Reg64 reg_iterkb; - Xbyak::Reg64 reg_itern; - Xbyak::Reg64 reg_tmp; - Xbyak::Reg64 reg_tmp1; - Xbyak::Reg64 reg_tmp2; - Xbyak::Reg64 reg_tmp3; - Xbyak::Reg64 reg_tmp4; - Xbyak::Reg64 reg_ret = rax; - - void assign_regs() { - CRegCount = MRegs * NRegs; - ARegCount = 1; - BRegCount = NRegs; - CReg = 0; - CF32Reg = CReg + CRegCount; - BReg = CF32Reg + CRegCount; - AReg = BReg + BRegCount; - TmpReg = AReg + ARegCount; - assert(TmpReg < RegCount); - TmpRegCount = RegCount - TmpReg; - assert(TmpRegCount >= 1); - } - - void generate_mtile(int _mtile) { - inLocalLabel(); // use local label for multiple instance - Xbyak::util::StackFrame st(this, 1, 13, 16 * 10); - parambase = st.p[0]; - reg_matAptr = st.t[0]; - reg_matBptr = st.t[1]; - reg_matCptr = st.t[0]; - reg_ksize = st.t[2]; - reg_astride = st.t[3]; - reg_cstride = st.t[3]; - reg_iterk = st.t[4]; - reg_iterkb = st.t[12]; - reg_tmp = st.t[5]; - reg_tmp1 = st.t[6]; - reg_tmp2 = st.t[7]; - reg_tmp3 = st.t[10]; - reg_tmp4 = st.t[11]; - reg_nsize = st.t[8]; - reg_itern = st.t[9]; - reg_ret = rax; - - vreg_push(rsp); - - load32(reg_ksize, ptr[parambase + OFFSET(k)]); - load32(reg_nsize, ptr[parambase + OFFSET(n)]); - xor_(reg_itern, reg_itern); - L(".nloop"); - init_regs(_mtile); - mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); - load32(reg_astride, ptr[parambase + OFFSET(astride)]); - mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); - load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); - imul(reg_tmp, reg_itern); - lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); - xor_(reg_iterk, reg_iterk); - generate_kloop(_mtile); - write_back(_mtile); - add(reg_itern, NTILE); - cmp(reg_itern, reg_nsize); - jb(".nloop"); - mov(reg_ret, 0); - vreg_pop(rsp); - - outLocalLabel(); // end of local label - } - - void generate_kloop(int _mtile) { - inLocalLabel(); - xor_(reg_iterkb, reg_iterkb); - L(".kloop"); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j)); - } - } - xor_(reg_tmp2, reg_tmp2); - load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]); - mov(reg_tmp, 
reg_tmp3); - padto_le(reg_tmp, KUNROLL * KTILE); - cmp(reg_tmp, 0); - jz(".kbloop", T_NEAR); - L(".unkbloop"); - generate_fma(_mtile, KUNROLL, reg_tmp1); - add(reg_matAptr, KUNROLL * AKStepSize); - add(reg_matBptr, KUNROLL * BKStepSize); - add(reg_tmp2, KUNROLL * KTILE); - cmp(reg_tmp2, reg_tmp); - jb(".unkbloop"); - cmp(reg_tmp, reg_tmp3); - jge(".kend", T_NEAR); - L(".kbloop"); - generate_fma(_mtile, 1, reg_tmp1); - add(reg_matAptr, 1 * AKStepSize); - add(reg_matBptr, 1 * BKStepSize); - add(reg_tmp2, 1 * KTILE); - cmp(reg_tmp2, reg_tmp3); - jb(".kbloop"); - L(".kend"); - add(reg_iterk, reg_tmp2); - generate_f32_accumulate(_mtile); - generate_zp_correction(_mtile); - inc(reg_iterkb); - cmp(reg_iterk, reg_ksize); // k iteration variable - jb(".kloop"); - - outLocalLabel(); - } - - void generate_fma(int _mtile, int _ktile, Xbyak::Reg64& tmp) { - for (int kk = 0; kk < _ktile; kk++) { - lea(tmp, ptr[reg_matAptr + kk * AKStepSize]); - for (int i = 0; i < NRegs; i++) { - vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); - add(reg_tmp1, reg_astride); - for (int i = 0; i < NRegs; i++) { - vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); - } - } - } - } - - void init_regs(int _mtile) { - inLocalLabel(); - load32(reg_tmp, ptr[parambase + OFFSET(init)]); - cmp(reg_tmp, 0); - je(".read", T_NEAR); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vxor(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j)); - } - } - jmp(".end", T_NEAR); - L(".read"); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(vreg_t(CF32Reg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); - } - add(reg_matCptr, reg_cstride); - } - L(".end"); - outLocalLabel(); - } - - void generate_f32_accumulate(int _mtile) { - load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]); - imul(reg_tmp, reg_iterkb); - mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); - - mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]); - lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(float)]); - load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]); - for (int i = 0; i < NRegs; i++) { - vmovups(Xbyak::Zmm(BReg + i), ptr[reg_tmp2 + i * VecBytes]); - } - for (int mm = 0; mm < _mtile; mm++) { - vbroadcastss(Xbyak::Zmm(TmpReg), ptr[reg_tmp]); - lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); - for (int i = 0; i < NRegs; i++) { - vcvtdq2ps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); - vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(TmpReg), Xbyak::Zmm(BReg + i)); - vmulps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(AReg)); - vaddps(Xbyak::Zmm(CF32Reg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); - } - } - } - - void generate_zp_correction(int _mtile) { - load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]); - imul(reg_tmp1, reg_iterkb); - mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]); - lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); - auto& reg_redB = reg_tmp2; - - mov(reg_tmp, ptr[parambase + OFFSET(zpA)]); - lea(reg_tmp, ptr[reg_tmp + reg_iterkb * 
sizeof(AType)]); - auto& reg_zpA = reg_tmp; - - mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]); - lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]); - auto& reg_scaleA = reg_tmp1; - - load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]); - auto& reg_ldsa = reg_tmp3; - for (int i = 0; i < NRegs; i++) { - vmovups(Xbyak::Zmm(BReg + i), ptr[reg_redB + i * VecBytes]); - } - - for (int i = 0; i < _mtile; i++) { - vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); - vpmovzxbd(Xbyak::Zmm(AReg), Xbyak::Xmm(AReg)); - vcvtdq2ps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg)); - vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg), zword_b[reg_scaleA]); - for (int j = 0; j < NRegs; j++) { - vmulps(Xbyak::Zmm(CReg + j), Xbyak::Zmm(AReg), Xbyak::Zmm(BReg + j)); - vsubps(Xbyak::Zmm(CF32Reg + i * NRegs + j), Xbyak::Zmm(CReg + j)); - } - lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); - lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); - } - } - - void write_back(int _mtile) { - inLocalLabel(); - mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); - load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); - lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); - for (int i = 0; i < _mtile; i++) { - for (int j = 0; j < NRegs; j++) { - vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CF32Reg + i * NRegs + j)); - } - add(reg_matCptr, reg_cstride); - } - outLocalLabel(); - } -}; - -} // namespace kblock -} // namespace code -template