diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 4a4e286071ff5..ce8fb3160954e 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -7,7 +7,7 @@ jobs: triage: runs-on: ubuntu-latest steps: - - uses: github/issue-labeler@v3.2 + - uses: github/issue-labeler@v3.3 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" configuration-path: .github/labeler.yml diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 9ea9bda7e7c53..fff50d6481a05 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -25,7 +25,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up JDK 11 - uses: actions/setup-java@v3 + uses: actions/setup-java@v4 with: java-version: '11' distribution: 'adopt' diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml index ba8bfd718abfa..d85978568e6c4 100644 --- a/.github/workflows/publish-js-apidocs.yml +++ b/.github/workflows/publish-js-apidocs.yml @@ -25,7 +25,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Setup Node.js - uses: actions/setup-node@v3 + uses: actions/setup-node@v4 with: node-version: 18 - name: Generate JS docs diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index 3a780f87d2300..c03abe0be9783 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -26,7 +26,7 @@ jobs: python-version: '3.11.x' architecture: 'x64' - - uses: actions/setup-node@v3 + - uses: actions/setup-node@v4 with: node-version: 18 diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 7494035e4784e..23ded3bfc1e68 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -87,6 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) +option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) @@ -1166,6 +1167,17 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() +set(USE_JBLAS FALSE) +if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD) + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") + add_compile_definitions(MLAS_JBLAS) + set(USE_JBLAS TRUE) + elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") + add_compile_definitions(MLAS_JBLAS) + set(USE_JBLAS TRUE) + endif() +endif() + # TVM EP if (onnxruntime_USE_TVM) if (NOT TARGET tvm) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 26e4380af4c23..bee83ff07c74b 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -45,6 +45,15 @@ endif() set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) +function(add_jblas) + add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas) + target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/jblas_gemm.cpp + ) + set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF) +endfunction() + #TODO: set MASM flags properly function(setup_mlas_source_for_windows) @@ -200,7 +209,6 @@ 
function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/q4gemm_avx512.cpp ) endif() - else() target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp @@ -566,7 +574,7 @@ else() ) set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") - endif() + endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs}) @@ -604,6 +612,10 @@ else() target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) endif() +if(USE_JBLAS) + add_jblas() +endif() + foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR}) onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET}) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e5b43ddba8cc7..131db5d8d9b37 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2824,6 +2824,8 @@ This version of the operator has been available since version 1 of the 'com.micr
size of each input feature
N : int (required)
size of each output feature
+
accuracy_level : int
+
The minimum accuracy level of input A; one of 0 (unset), 1 (fp32), 2 (fp16), 3 (bf16), or 4 (int8) (default: unset). It controls how input A is quantized or downcast internally during computation: for example, 0 means input A is neither quantized nor downcast, while 4 means input A may be quantized internally from type T1 to int8 with the same block_size (see the sketch below).<br/>
bits : int (required)
number of bits used for weight quantization (default 4)
block_size : int (required)
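Since the accuracy_level wording above is dense, here is a minimal TypeScript sketch of the documented value mapping. The enum and helper names are illustrative only (the operator attribute is a plain integer); they are not part of the ONNX Runtime API.

```ts
// Illustrative names only: MatMulNBits takes accuracy_level as a plain int.
enum MatMulNBitsAccuracyLevel {
  Unset = 0,  // input A is used as-is (type T1): no quantization or downcast
  Fp32 = 1,   // computation may run input A at fp32 accuracy
  Fp16 = 2,   // computation may downcast input A to fp16
  Bf16 = 3,   // computation may downcast input A to bf16
  Int8 = 4,   // input A may be block-quantized to int8 (same block_size)
}

// Hypothetical validator reflecting the documented range of the attribute.
function checkAccuracyLevel(level: number): void {
  if (!Number.isInteger(level) || level < MatMulNBitsAccuracyLevel.Unset ||
      level > MatMulNBitsAccuracyLevel.Int8) {
    throw new RangeError(`accuracy_level must be an integer in [0, 4], got ${level}`);
  }
}
```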
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index edf249a816923..1ce9b3254d91f 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -80,7 +80,8 @@ Do not modify directly.* |Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| -|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|||[17, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[1, 10]|**T** = tensor(double), tensor(float)| diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index c41700453a73b..dbd5ad41255fa 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3593,17 +3593,11 @@ struct OrtApi { * * QNN supported keys: * "backend_path": file path to QNN backend library. - * "qnn_context_cache_enable": 1 to enable QNN graph creation from cached QNN context file. If it's enabled: QNN EP will - * load from cached QNN context binary if it exist. It will generate a context binary file if it's not exist - * "qnn_context_cache_path": explicitly provide the QNN context cache file. Default to model_file.onnx.bin if not provided. * "profiling_level": QNN profiling level, options: "off", "basic", "detailed". Default to off. * "rpc_control_latency": QNN RPC control latency. * "vtcm_mb": QNN VTCM size in MB. default to 0(not set). * "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance", * "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default". - * "qnn_context_embed_mode", 1 means dump the QNN context binary into node attribute EPContext->ep_cache_context in the ONNX skeleton model. - * 0 means dump the QNN context binary into separate bin file and set the path to EPContext->ep_cache_context. - * The path is relative path to the ONNX skeleton model file. * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and * may alter model/EP partitioning. Use only for debugging. diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index a94973b2cc5d7..df79cb6e5b21b 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -235,3 +235,18 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil // Use this config to control the minimum size of the initializer when externalizing it during serialization static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = "session.optimized_model_external_initializers_min_size_in_bytes"; + +// Enable EP context feature to dump the partitioned graph which includes the EP context into an Onnx file. +// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compilation overhead. +// "0": disable. (default) +// "1": enable. +static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable"; + +// Specify the file path for the Onnx model which has EP context. +// Defaults to original_file_name_ctx.onnx if not specified. +static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path"; + +// Flag to specify whether to dump the EP context into the Onnx model. +// "0": dump the EP context into a separate file, keeping the file name in the Onnx model. +// "1": dump the EP context into the Onnx model. (default).
+static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; \ No newline at end of file diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts index e129c6971a85c..3e1e833addb91 100644 --- a/js/common/lib/backend-impl.ts +++ b/js/common/lib/backend-impl.ts @@ -82,7 +82,7 @@ export const resolveBackend = async(backendHints: readonly string[]): Promise; + init(backendName: string): Promise; createInferenceSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise; diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index 5f5ad49a2dea8..e8eb0e9babf5a 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -20,7 +20,7 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { } async dispose(): Promise { - return Promise.resolve(); + this.#inferenceSession.dispose(); } readonly inputNames: string[]; diff --git a/js/node/lib/binding.ts b/js/node/lib/binding.ts index 8a0ce89abfa64..54b5767139904 100644 --- a/js/node/lib/binding.ts +++ b/js/node/lib/binding.ts @@ -28,6 +28,8 @@ export declare namespace Binding { readonly outputNames: string[]; run(feeds: FeedsType, fetches: FetchesType, options: RunOptions): ReturnType; + + dispose(): void; } export interface InferenceSessionConstructor { diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc index c409fdc8895f7..1bbb6df1ce1c8 100644 --- a/js/node/src/inference_session_wrap.cc +++ b/js/node/src/inference_session_wrap.cc @@ -31,6 +31,7 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { Napi::Function func = DefineClass( env, "InferenceSession", {InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), InstanceMethod("run", &InferenceSessionWrap::Run), + InstanceMethod("dispose", &InferenceSessionWrap::Dispose), InstanceAccessor("inputNames", &InferenceSessionWrap::GetInputNames, nullptr, napi_default, nullptr), InstanceAccessor("outputNames", &InferenceSessionWrap::GetOutputNames, nullptr, napi_default, nullptr)}); @@ -45,7 +46,7 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { } InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo &info) - : Napi::ObjectWrap(info), initialized_(false), session_(nullptr), + : Napi::ObjectWrap(info), initialized_(false), disposed_(false), session_(nullptr), defaultRunOptions_(nullptr) {} Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) { @@ -53,6 +54,7 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) { Napi::HandleScope scope(env); ORT_NAPI_THROW_ERROR_IF(this->initialized_, env, "Model already loaded. 
Cannot load model multiple times."); + ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); size_t argsLength = info.Length(); ORT_NAPI_THROW_TYPEERROR_IF(argsLength == 0, env, "Expect argument: model file path or buffer."); @@ -129,6 +131,7 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) { Napi::Value InferenceSessionWrap::GetInputNames(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); + ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); Napi::EscapableHandleScope scope(env); return scope.Escape(CreateNapiArrayFrom(env, inputNames_)); @@ -137,6 +140,7 @@ Napi::Value InferenceSessionWrap::GetInputNames(const Napi::CallbackInfo &info) Napi::Value InferenceSessionWrap::GetOutputNames(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); + ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); Napi::EscapableHandleScope scope(env); return scope.Escape(CreateNapiArrayFrom(env, outputNames_)); @@ -145,6 +149,7 @@ Napi::Value InferenceSessionWrap::GetOutputNames(const Napi::CallbackInfo &info) Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); + ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); ORT_NAPI_THROW_TYPEERROR_IF(info.Length() < 2, env, "Expect argument: inputs(feed) and outputs(fetch)."); ORT_NAPI_THROW_TYPEERROR_IF(!info[0].IsObject() || !info[1].IsObject(), env, "Expect inputs(feed) and outputs(fetch) to be objects."); @@ -209,6 +214,18 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) { } } +Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo &info) { + Napi::Env env = info.Env(); + ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized."); + ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed."); + + this->defaultRunOptions_.reset(nullptr); + this->session_.reset(nullptr); + + this->disposed_ = true; + return env.Undefined(); +} + Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo &info) { Napi::Env env = info.Env(); Napi::EscapableHandleScope scope(env); diff --git a/js/node/src/inference_session_wrap.h b/js/node/src/inference_session_wrap.h index 9eee45b72dcb1..1e789c4814cd6 100644 --- a/js/node/src/inference_session_wrap.h +++ b/js/node/src/inference_session_wrap.h @@ -55,6 +55,14 @@ class InferenceSessionWrap : public Napi::ObjectWrap { */ Napi::Value Run(const Napi::CallbackInfo &info); + /** + * [sync] dispose the session. 
+ * @param nothing + * @returns nothing + * @throw nothing + */ + Napi::Value Dispose(const Napi::CallbackInfo &info); + // private members // persistent constructor @@ -62,6 +70,7 @@ class InferenceSessionWrap : public Napi::ObjectWrap { // session objects bool initialized_; + bool disposed_; std::unique_ptr session_; std::unique_ptr defaultRunOptions_; diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 78edcc90f55f9..2d123cdb71290 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -4,7 +4,7 @@ import {cpus} from 'node:os'; import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; -import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper'; +import {initializeOrtEp, initializeWebAssemblyAndOrtRuntime} from './wasm/proxy-wrapper'; import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference'; /** @@ -33,12 +33,23 @@ export const initializeFlags = (): void => { }; export class OnnxruntimeWebAssemblyBackend implements Backend { - async init(): Promise { + /** + * This function initializes the WebAssembly backend. + * + * This function will be called only once for each backend name. It will be called the first time when + * `ort.InferenceSession.create()` is called with a registered backend name. + * + * @param backendName - the registered backend name. + */ + async init(backendName: string): Promise { // populate wasm flags initializeFlags(); // init wasm - await initializeWebAssemblyInstance(); + await initializeWebAssemblyAndOrtRuntime(); + + // perform EP-specific initialization + await initializeOrtEp(backendName); } createInferenceSessionHandler(path: string, options?: InferenceSession.SessionOptions): Promise; diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 6060271ced156..499327741c82b 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -21,7 +21,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { if (!BUILD_DEFS.DISABLE_WASM) { const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend : require('./backend-wasm-training').wasmBackend; - if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) { + if (!BUILD_DEFS.DISABLE_WEBGPU) { registerBackend('webgpu', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 4f4a06c37a94f..6c3d22352772e 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -144,17 +144,7 @@ export class WebGpuBackend { */ sessionExternalDataMapping: Map> = new Map(); - async initialize(env: Env): Promise { - if (!navigator.gpu) { - // WebGPU is not available. 
- throw new Error('WebGpuBackend: WebGPU is not available.'); - } - - const adapter = await navigator.gpu.requestAdapter(); - if (!adapter) { - throw new Error('WebGpuBackend: Failed to get GPU adapter.'); - } - + async initialize(env: Env, adapter: GPUAdapter): Promise { this.env = env; const requiredFeatures: GPUFeatureName[] = []; const deviceDescriptor: GPUDeviceDescriptor = { diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index e6db631c44eea..cad1e87b24a51 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -130,64 +130,76 @@ class ComputeContextImpl implements ComputeContext { } } -export const init = async(module: OrtWasmModule, env: Env): Promise => { - const init = module.jsepInit; - if (init && navigator.gpu) { - if (!env.wasm.simd) { - throw new Error( - 'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using WebGPU EP'); - } - const backend = new WebGpuBackend(); - await backend.initialize(env); - - init( - // backend - backend, - - // jsepAlloc() - (size: number) => backend.alloc(size), - - // jsepFree() - (ptr: number) => backend.free(ptr), - - // jsepCopy(src, dst, size, isSourceGpu) - (src: number, dst: number, size: number, isSourceGpu = false) => { - if (isSourceGpu) { - LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`); - backend.memcpy(src, dst); - } else { - LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`); - const data = module.HEAPU8.subarray(src, src + size); - backend.upload(dst, data); - } - }, - - // jsepCopyAsync(src, dst, size) - async(gpuDataId: number, dataOffset: number, size: number): - Promise => { - LOG_DEBUG( - 'verbose', - () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`); - - await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size)); - }, - - // jsepCreateKernel - (name: string, kernel: number, attribute: unknown) => backend.createKernel( - name, kernel, attribute, - env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`), - - // jsepReleaseKernel - (kernel: number) => backend.releaseKernel(kernel), - - // jsepRun - (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { - LOG_DEBUG( - 'verbose', - () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ - contextDataOffset}`); - const context = new ComputeContextImpl(module, backend, contextDataOffset); - return backend.computeKernel(kernel, context, errors); - }); +/** + * Initialize JSEP with WebGPU backend. + * + * This function will be called only once after the WebAssembly module is loaded and initialized ("_OrtInit" is called). + * This function expects: + * - WebGPU is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false). + * - WebGPU is available in current environment. (a valid GPUAdapter is passed in) + * If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate + * 'webgpu' backend. + * + * @param module - the ORT WebAssembly module + * @param env - the ORT environment variable (ort.env) + * @param gpuAdapter - the pre-created GPU adapter + */ +export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapter): Promise => { + const jsepInit = module.jsepInit; + if (!jsepInit) { + throw new Error('Failed to initialize JSEP. 
The WebAssembly module is not built with JSEP support.'); } + + const backend = new WebGpuBackend(); + await backend.initialize(env, gpuAdapter); + + jsepInit( + // backend + backend, + + // jsepAlloc() + (size: number) => backend.alloc(size), + + // jsepFree() + (ptr: number) => backend.free(ptr), + + // jsepCopy(src, dst, size, isSourceGpu) + (src: number, dst: number, size: number, isSourceGpu = false) => { + if (isSourceGpu) { + LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`); + backend.memcpy(src, dst); + } else { + LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`); + const data = module.HEAPU8.subarray(src, src + size); + backend.upload(dst, data); + } + }, + + // jsepCopyAsync(src, dst, size) + async(gpuDataId: number, dataOffset: number, size: number): + Promise => { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`); + + await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size)); + }, + + // jsepCreateKernel + (name: string, kernel: number, attribute: unknown) => backend.createKernel( + name, kernel, attribute, + env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`), + + // jsepReleaseKernel + (kernel: number) => backend.releaseKernel(kernel), + + // jsepRun + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ + contextDataOffset}`); + const context = new ComputeContextImpl(module, backend, contextDataOffset); + return backend.computeKernel(kernel, context, errors); + }); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 5fffa2f266603..0eb0d40a3ea5e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -772,14 +772,14 @@ class ShaderHelperImpl implements ShaderHelper { const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1; const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3, @builtin(local_invocation_id) local_id : vec3` : - `@builtin(local_invocation_index) local_index : u32, + `@builtin(local_invocation_index) local_idx : u32, @builtin(workgroup_id) workgroup_id : vec3, @builtin(num_workgroups) num_workgroups : vec3`; const globalIdxDefinition = is1DimensionDispatch ? 
- 'let global_idx = global_id.x;' : + 'let global_idx = global_id.x; let local_idx = local_id.x;' : `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${ - workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_index;`; + workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_idx;`; return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ}) fn main(${paramList}) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 6e9dee41ce488..1c5d28e4b8e3f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -97,8 +97,8 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - let m = global_id.x / N; - let n = global_id.x % N; + let m = global_idx / N; + let n = global_idx % N; var value = ${dataType}(0); for (var k: u32 = 0u; k<${K}u; k++) { @@ -107,7 +107,7 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt ${calculateAlpha} ${calculateC} - output[global_id.x] = value; + output[global_idx] = value; }`; return { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts index 1365d1e9a12a4..7c440cbffea7b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -141,7 +141,6 @@ export const createReduceSharedProgramInfo = return ((a - 1u) / b + 1u); } ${shaderHelper.mainStart(workgroupSize)} - let local_idx = local_id.x; let outputIndex = global_idx / ${workgroupSize}; let offset = outputIndex * uniforms.reduceSize; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 378a7e738dac9..324dc3af1a710 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -73,8 +73,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut } ${shaderHelper.registerUniform('packedCols', 'i32').declareVariables(x, output)} ${shaderHelper.mainStart()} - let gindex = i32(global_id.x); - let lindex = i32(local_id.x); + let gindex = i32(global_idx); + let lindex = i32(local_idx); const wg = ${WG}; let row = gindex / wg; let cols = uniforms.packedCols; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 51114d8a99dd1..a25e7fe4229b4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -125,8 +125,8 @@ export interface ClipAttributes extends AttributeWithCacheKey { } const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { - const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; - const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP; + const min = (inputs.length >= 2 && inputs[1].data !== 0) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; + const max = (inputs.length >= 3 && inputs[2].data !== 0) ? 
inputs[2].getFloat32Array()[0] : MAX_CLIP; return createAttributeWithCacheKey({min, max}); }; diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index efeb086256cf3..02246c9ee4767 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -3,6 +3,9 @@ import type {Env, InferenceSession, Tensor} from 'onnxruntime-common'; +/** + * Among all the tensor locations, only 'cpu' is serializable. + */ export type SerializableTensorMetadata = [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu']; @@ -12,15 +15,28 @@ export type GpuBufferMetadata = { dispose?: () => void; }; +/** + * Tensors on location 'cpu-pinned' and 'gpu-buffer' are not serializable. + */ export type UnserializableTensorMetadata = [dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']| [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned']; +/** + * Tensor metadata is a tuple of [dataType, dims, data, location], where + * - dataType: tensor data type + * - dims: tensor dimensions + * - data: tensor data, which can be one of the following depending on the location: + * - cpu: Uint8Array + * - cpu-pinned: Uint8Array + * - gpu-buffer: GpuBufferMetadata + * - location: tensor data location + */ export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata; export type SerializableSessionMetadata = [sessionHandle: number, inputNames: string[], outputNames: string[]]; -export type SerializableModeldata = [modelDataOffset: number, modelDataLength: number]; +export type SerializableInternalBuffer = [bufferOffset: number, bufferLength: number]; interface MessageError { err?: string; @@ -28,35 +44,32 @@ interface MessageError { interface MessageInitWasm extends MessageError { type: 'init-wasm'; - in ?: Env.WebAssemblyFlags; -} - -interface MessageInitOrt extends MessageError { - type: 'init-ort'; in ?: Env; + out?: never; } -interface MessageCreateSessionAllocate extends MessageError { - type: 'create_allocate'; - in ?: {model: Uint8Array}; - out?: SerializableModeldata; +interface MessageInitEp extends MessageError { + type: 'init-ep'; + in ?: {env: Env; epName: string}; + out?: never; } -interface MessageCreateSessionFinalize extends MessageError { - type: 'create_finalize'; - in ?: {modeldata: SerializableModeldata; options?: InferenceSession.SessionOptions}; - out?: SerializableSessionMetadata; +interface MessageCopyFromExternalBuffer extends MessageError { + type: 'copy-from'; + in ?: {buffer: Uint8Array}; + out?: SerializableInternalBuffer; } interface MessageCreateSession extends MessageError { type: 'create'; - in ?: {model: Uint8Array; options?: InferenceSession.SessionOptions}; + in ?: {model: SerializableInternalBuffer|Uint8Array; options?: InferenceSession.SessionOptions}; out?: SerializableSessionMetadata; } interface MessageReleaseSession extends MessageError { type: 'release'; in ?: number; + out?: never; } interface MessageRun extends MessageError { @@ -71,12 +84,8 @@ interface MessageRun extends MessageError { interface MesssageEndProfiling extends MessageError { type: 'end-profiling'; in ?: number; + out?: never; } -interface MessageIsOrtEnvInitialized extends MessageError { - type: 'is-ort-env-initialized'; - out?: boolean; -} - -export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize| - 
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized; +export type OrtWasmMessage = MessageInitWasm|MessageInitEp|MessageCopyFromExternalBuffer|MessageCreateSession| + MessageReleaseSession|MessageRun|MesssageEndProfiling; diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index 1cb6d9e391e4f..4df524cdcfb22 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -36,104 +36,82 @@ declare global { } import {OrtWasmMessage, SerializableTensorMetadata} from '../proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl'; +import {createSession, copyFromExternalBuffer, endProfiling, extractTransferableBuffers, initEp, initRuntime, releaseSession, run} from '../wasm-core-impl'; import {initializeWebAssembly} from '../wasm-factory'; self.onmessage = (ev: MessageEvent): void => { - switch (ev.data.type) { - case 'init-wasm': - try { - initializeWebAssembly(ev.data.in!) + const {type, in : message} = ev.data; + try { + switch (type) { + case 'init-wasm': + initializeWebAssembly(message!.wasm) .then( - () => postMessage({type: 'init-wasm'} as OrtWasmMessage), - err => postMessage({type: 'init-wasm', err} as OrtWasmMessage)); - } catch (err) { - postMessage({type: 'init-wasm', err} as OrtWasmMessage); - } - break; - case 'init-ort': - try { - initRuntime(ev.data.in!).then(() => postMessage({type: 'init-ort'} as OrtWasmMessage), err => postMessage({ - type: 'init-ort', - err - } as OrtWasmMessage)); - } catch (err) { - postMessage({type: 'init-ort', err} as OrtWasmMessage); - } - break; - case 'create_allocate': - try { - const {model} = ev.data.in!; - const modeldata = createSessionAllocate(model); - postMessage({type: 'create_allocate', out: modeldata} as OrtWasmMessage); - } catch (err) { - postMessage({type: 'create_allocate', err} as OrtWasmMessage); + () => { + initRuntime(message!).then( + () => { + postMessage({type}); + }, + err => { + postMessage({type, err}); + }); + }, + err => { + postMessage({type, err}); + }); + break; + case 'init-ep': { + const {epName, env} = message!; + initEp(env, epName) + .then( + () => { + postMessage({type}); + }, + err => { + postMessage({type, err}); + }); + break; } - break; - case 'create_finalize': - try { - const {modeldata, options} = ev.data.in!; - const sessionMetadata = createSessionFinalize(modeldata, options); - postMessage({type: 'create_finalize', out: sessionMetadata} as OrtWasmMessage); - } catch (err) { - postMessage({type: 'create_finalize', err} as OrtWasmMessage); + case 'copy-from': { + const {buffer} = message!; + const bufferData = copyFromExternalBuffer(buffer); + postMessage({type, out: bufferData} as OrtWasmMessage); + break; } - break; - case 'create': - try { - const {model, options} = ev.data.in!; + case 'create': { + const {model, options} = message!; const sessionMetadata = createSession(model, options); - postMessage({type: 'create', out: sessionMetadata} as OrtWasmMessage); - } catch (err) { - postMessage({type: 'create', err} as OrtWasmMessage); + postMessage({type, out: sessionMetadata} as OrtWasmMessage); + break; } - break; - case 'release': - try { - releaseSession(ev.data.in!); - postMessage({type: 'release'} as OrtWasmMessage); - } catch (err) { - postMessage({type: 'release', err} as OrtWasmMessage); - } - break; - case 'run': - try { - const 
{sessionId, inputIndices, inputs, outputIndices, options} = ev.data.in!; + case 'release': + releaseSession(message!); + postMessage({type}); + break; + case 'run': { + const {sessionId, inputIndices, inputs, outputIndices, options} = message!; run(sessionId, inputIndices, inputs, outputIndices, new Array(outputIndices.length).fill(null), options) .then( outputs => { if (outputs.some(o => o[3] !== 'cpu')) { - postMessage({type: 'run', err: 'Proxy does not support non-cpu tensor location.'}); + postMessage({type, err: 'Proxy does not support non-cpu tensor location.'}); } else { postMessage( - {type: 'run', out: outputs} as OrtWasmMessage, + {type, out: outputs} as OrtWasmMessage, extractTransferableBuffers(outputs as SerializableTensorMetadata[])); } }, err => { - postMessage({type: 'run', err} as OrtWasmMessage); + postMessage({type, err}); }); - } catch (err) { - postMessage({type: 'run', err} as OrtWasmMessage); - } - break; - case 'end-profiling': - try { - const handler = ev.data.in!; - endProfiling(handler); - postMessage({type: 'end-profiling'} as OrtWasmMessage); - } catch (err) { - postMessage({type: 'end-profiling', err} as OrtWasmMessage); - } - break; - case 'is-ort-env-initialized': - try { - const ortEnvInitialized = isOrtEnvInitialized(); - postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage); - } catch (err) { - postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage); + break; } - break; - default: + case 'end-profiling': + endProfiling(message!); + postMessage({type}); + break; + default: + } + } catch (err) { + postMessage({type, err} as OrtWasmMessage); } }; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 069a1fa452dbc..86017a4ec6904 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {Env, env, InferenceSession} from 'onnxruntime-common'; +import {env, InferenceSession} from 'onnxruntime-common'; -import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; +import {OrtWasmMessage, SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import * as core from './wasm-core-impl'; import {initializeWebAssembly} from './wasm-factory'; @@ -13,18 +13,18 @@ let initializing = false; let initialized = false; let aborted = false; -// resolve; reject -type PromiseCallbacks = [(result: T) => void, (reason: unknown) => void]; - +type PromiseCallbacks = [resolve: (result: T) => void, reject: (reason: unknown) => void]; let initWasmCallbacks: PromiseCallbacks; -let initOrtCallbacks: PromiseCallbacks; -const createSessionAllocateCallbacks: Array> = []; -const createSessionFinalizeCallbacks: Array> = []; -const createSessionCallbacks: Array> = []; -const releaseSessionCallbacks: Array> = []; -const runCallbacks: Array> = []; -const endProfilingCallbacks: Array> = []; -const isOrtEnvInitializedCallbacks: Array> = []; +const queuedCallbacks: Map>> = new Map(); + +const enqueueCallbacks = (type: OrtWasmMessage['type'], callbacks: PromiseCallbacks): void => { + const queue = queuedCallbacks.get(type); + if (queue) { + queue.push(callbacks); + } else { + queuedCallbacks.set(type, [callbacks]); + } +}; const ensureWorker = (): void => { if (initializing || !initialized || aborted || !proxyWorker) { @@ -44,82 +44,40 @@ const onProxyWorkerMessage = (ev: MessageEvent): void => { initWasmCallbacks[0](); } break; - case 'init-ort': - if (ev.data.err) { - initOrtCallbacks[1](ev.data.err); - } else { - initOrtCallbacks[0](); - } - break; - case 'create_allocate': - if (ev.data.err) { - createSessionAllocateCallbacks.shift()![1](ev.data.err); - } else { - createSessionAllocateCallbacks.shift()![0](ev.data.out!); - } - break; - case 'create_finalize': - if (ev.data.err) { - createSessionFinalizeCallbacks.shift()![1](ev.data.err); - } else { - createSessionFinalizeCallbacks.shift()![0](ev.data.out!); - } - break; + case 'init-ep': + case 'copy-from': case 'create': - if (ev.data.err) { - createSessionCallbacks.shift()![1](ev.data.err); - } else { - createSessionCallbacks.shift()![0](ev.data.out!); - } - break; case 'release': - if (ev.data.err) { - releaseSessionCallbacks.shift()![1](ev.data.err); - } else { - releaseSessionCallbacks.shift()![0](); - } - break; case 'run': + case 'end-profiling': { + const callbacks = queuedCallbacks.get(ev.data.type)!; if (ev.data.err) { - runCallbacks.shift()![1](ev.data.err); - } else { - runCallbacks.shift()![0](ev.data.out!); - } - break; - case 'end-profiling': - if (ev.data.err) { - endProfilingCallbacks.shift()![1](ev.data.err); - } else { - endProfilingCallbacks.shift()![0](); - } - break; - case 'is-ort-env-initialized': - if (ev.data.err) { - isOrtEnvInitializedCallbacks.shift()![1](ev.data.err); + callbacks.shift()![1](ev.data.err); } else { - isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!); + callbacks.shift()![0](ev.data.out!); } break; + } default: } }; const scriptSrc = typeof document !== 'undefined' ? 
(document?.currentScript as HTMLScriptElement)?.src : undefined; -export const initializeWebAssemblyInstance = async(): Promise => { - if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { - if (initialized) { - return; - } - if (initializing) { - throw new Error('multiple calls to \'initWasm()\' detected.'); - } - if (aborted) { - throw new Error('previous call to \'initWasm()\' failed.'); - } +export const initializeWebAssemblyAndOrtRuntime = async(): Promise => { + if (initialized) { + return; + } + if (initializing) { + throw new Error('multiple calls to \'initWasm()\' detected.'); + } + if (aborted) { + throw new Error('previous call to \'initWasm()\' failed.'); + } - initializing = true; + initializing = true; + if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { // overwrite wasm filepaths if (env.wasm.wasmPaths === undefined) { if (scriptSrc && scriptSrc.indexOf('blob:') !== 0) { @@ -142,78 +100,78 @@ export const initializeWebAssemblyInstance = async(): Promise => { proxyWorker.onmessage = onProxyWorkerMessage; URL.revokeObjectURL(workerUrl); initWasmCallbacks = [resolve, reject]; - const message: OrtWasmMessage = {type: 'init-wasm', in : env.wasm}; + const message: OrtWasmMessage = {type: 'init-wasm', in : env}; proxyWorker.postMessage(message); }); } else { - return initializeWebAssembly(env.wasm); + try { + await initializeWebAssembly(env.wasm); + await core.initRuntime(env); + initialized = true; + } catch (e) { + aborted = true; + throw e; + } finally { + initializing = false; + } } }; -export const initializeRuntime = async(env: Env): Promise => { +export const initializeOrtEp = async(epName: string): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { ensureWorker(); return new Promise((resolve, reject) => { - initOrtCallbacks = [resolve, reject]; - const message: OrtWasmMessage = {type: 'init-ort', in : env}; + enqueueCallbacks('init-ep', [resolve, reject]); + const message: OrtWasmMessage = {type: 'init-ep', in : {epName, env}}; proxyWorker!.postMessage(message); }); } else { - await core.initRuntime(env); + await core.initEp(env, epName); } }; -export const createSessionAllocate = async(model: Uint8Array): Promise => { +export const copyFromExternalBuffer = async(buffer: Uint8Array): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { ensureWorker(); - return new Promise((resolve, reject) => { - createSessionAllocateCallbacks.push([resolve, reject]); - const message: OrtWasmMessage = {type: 'create_allocate', in : {model}}; - proxyWorker!.postMessage(message, [model.buffer]); + return new Promise((resolve, reject) => { + enqueueCallbacks('copy-from', [resolve, reject]); + const message: OrtWasmMessage = {type: 'copy-from', in : {buffer}}; + proxyWorker!.postMessage(message, [buffer.buffer]); }); } else { - return core.createSessionAllocate(model); + return core.copyFromExternalBuffer(buffer); } }; -export const createSessionFinalize = async(modeldata: SerializableModeldata, options?: InferenceSession.SessionOptions): - Promise => { - if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { - ensureWorker(); - return new Promise((resolve, reject) => { - createSessionFinalizeCallbacks.push([resolve, reject]); - const message: OrtWasmMessage = {type: 'create_finalize', in : {modeldata, options}}; - proxyWorker!.postMessage(message); - }); - } else { - return core.createSessionFinalize(modeldata, options); - } - }; - export const createSession = - async(model: Uint8Array, options?: InferenceSession.SessionOptions): Promise => { - if 
(!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { - // check unsupported options - if (options?.preferredOutputLocation) { - throw new Error('session option "preferredOutputLocation" is not supported for proxy.'); - } - ensureWorker(); - return new Promise((resolve, reject) => { - createSessionCallbacks.push([resolve, reject]); - const message: OrtWasmMessage = {type: 'create', in : {model, options}}; - proxyWorker!.postMessage(message, [model.buffer]); - }); - } else { - return core.createSession(model, options); - } -}; + async(model: SerializableInternalBuffer|Uint8Array, options?: InferenceSession.SessionOptions): + Promise => { + if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + // check unsupported options + if (options?.preferredOutputLocation) { + throw new Error('session option "preferredOutputLocation" is not supported for proxy.'); + } + ensureWorker(); + return new Promise((resolve, reject) => { + enqueueCallbacks('create', [resolve, reject]); + const message: OrtWasmMessage = {type: 'create', in : {model, options}}; + const transferable: Transferable[] = []; + if (model instanceof Uint8Array) { + transferable.push(model.buffer); + } + proxyWorker!.postMessage(message, transferable); + }); + } else { + return core.createSession(model, options); + } + }; export const releaseSession = async(sessionId: number): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { ensureWorker(); return new Promise((resolve, reject) => { - releaseSessionCallbacks.push([resolve, reject]); + enqueueCallbacks('release', [resolve, reject]); const message: OrtWasmMessage = {type: 'release', in : sessionId}; proxyWorker!.postMessage(message); }); @@ -236,7 +194,7 @@ export const run = async( } ensureWorker(); return new Promise((resolve, reject) => { - runCallbacks.push([resolve, reject]); + enqueueCallbacks('run', [resolve, reject]); const serializableInputs = inputs as SerializableTensorMetadata[]; // every input is on CPU. const message: OrtWasmMessage = {type: 'run', in : {sessionId, inputIndices, inputs: serializableInputs, outputIndices, options}}; @@ -251,7 +209,7 @@ export const endProfiling = async(sessionId: number): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { ensureWorker(); return new Promise((resolve, reject) => { - endProfilingCallbacks.push([resolve, reject]); + enqueueCallbacks('end-profiling', [resolve, reject]); const message: OrtWasmMessage = {type: 'end-profiling', in : sessionId}; proxyWorker!.postMessage(message); }); @@ -259,16 +217,3 @@ export const endProfiling = async(sessionId: number): Promise => { core.endProfiling(sessionId); } }; - -export const isOrtEnvInitialized = async(): Promise => { - if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { - ensureWorker(); - return new Promise((resolve, reject) => { - isOrtEnvInitializedCallbacks.push([resolve, reject]); - const message: OrtWasmMessage = {type: 'is-ort-env-initialized'}; - proxyWorker!.postMessage(message); - }); - } else { - return core.isOrtEnvInitialized(); - } -}; diff --git a/js/web/lib/wasm/session-handler-inference.ts b/js/web/lib/wasm/session-handler-inference.ts index 3ca34d957c572..b62287483208a 100644 --- a/js/web/lib/wasm/session-handler-inference.ts +++ b/js/web/lib/wasm/session-handler-inference.ts @@ -2,14 +2,12 @@ // Licensed under the MIT License. 
import {readFile} from 'node:fs/promises'; -import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; +import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, TensorMetadata} from './proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper'; +import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages'; +import {copyFromExternalBuffer, createSession, endProfiling, releaseSession, run} from './proxy-wrapper'; import {isGpuBufferSupportedType} from './wasm-common'; -let runtimeInitializationPromise: Promise|undefined; - export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { switch (tensor.location) { case 'cpu': @@ -44,7 +42,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan inputNames: string[]; outputNames: string[]; - async createSessionAllocate(path: string): Promise { + async fetchModelAndCopyToWasmMemory(path: string): Promise { // fetch model from url and move to wasm heap. The ArrayBuffer that held the http // response is freed once we return const response = await fetch(path); @@ -52,33 +50,26 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan throw new Error(`failed to load model: ${path}`); } const arrayBuffer = await response.arrayBuffer(); - return createSessionAllocate(new Uint8Array(arrayBuffer)); + return copyFromExternalBuffer(new Uint8Array(arrayBuffer)); } async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { - if (!(await isOrtEnvInitialized())) { - if (!runtimeInitializationPromise) { - runtimeInitializationPromise = initializeRuntime(env); - } - await runtimeInitializationPromise; - runtimeInitializationPromise = undefined; - } + let model: Parameters[0]; if (typeof pathOrBuffer === 'string') { if (typeof process !== 'undefined' && process.versions && process.versions.node) { // node - const model = await readFile(pathOrBuffer); - [this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options); + model = await readFile(pathOrBuffer); } else { // browser - // fetch model and move to wasm heap. - const modelData: SerializableModeldata = await this.createSessionAllocate(pathOrBuffer); - // create the session - [this.sessionId, this.inputNames, this.outputNames] = await createSessionFinalize(modelData, options); + // fetch model and copy to wasm heap. + model = await this.fetchModelAndCopyToWasmMemory(pathOrBuffer); } } else { - [this.sessionId, this.inputNames, this.outputNames] = await createSession(pathOrBuffer, options); + model = pathOrBuffer; } + + [this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options); } async dispose(): Promise { diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts index 71815f21e650a..e35759192fe3c 100644 --- a/js/web/lib/wasm/session-handler-training.ts +++ b/js/web/lib/wasm/session-handler-training.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; +import {InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common'; -import {SerializableModeldata, TensorMetadata} from './proxy-messages'; +import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages'; import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; -import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; +import {copyFromExternalBuffer} from './wasm-core-impl'; import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, lazyResetGrad, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl'; export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { @@ -18,7 +18,7 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes evalInputNames: string[] = []; evalOutputNames: string[] = []; - async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { + async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { let buffer: Uint8Array; if (typeof uriOrBuffer === 'string') { const response = await fetch(uriOrBuffer); @@ -27,21 +27,18 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes } else { buffer = uriOrBuffer; } - return createSessionAllocate(buffer); + return copyFromExternalBuffer(buffer); } async createTrainingSession( checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, options: InferenceSession.SessionOptions) { - if (!isOrtEnvInitialized()) { - await initRuntime(env); - } - const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); - const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); + const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); + const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer); // 0 is supposed to be the nullptr - let evalModelData: SerializableModeldata = [0, 0]; - let optimizerModelData: SerializableModeldata = [0, 0]; + let evalModelData: SerializableInternalBuffer = [0, 0]; + let optimizerModelData: SerializableInternalBuffer = [0, 0]; if (evalModelUriOrBuffer !== '') { evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 3aacf8f4d90e0..a9dfd9218bb6f 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -3,37 +3,60 @@ import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; +import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, 
tensorTypeToTypedArrayConstructor} from './wasm-common'; import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; -let ortEnvInitialized = false; +// #region Initializations /** - * get the input/output count of the session. - * @param sessionHandle the handle representing the session. should be non-zero. - * @returns a tuple including 2 numbers, representing the input count and output count. + * There are 4 different "initialization" steps for ORT. They happen in different places and at different times. + * + * 1. JavaScript initialization for onnxruntime-common and onnxruntime-web. + * This is the first initialization step. In this step, onnxruntime-web calls onnxruntime-common's registerBackend() + * function multiple times to register all the available backends. The backend registration is very fast. It only + * registers the backend name with the uninitialized backend object. No heavy initialization is done in this step. + * Refer to web/lib/index.ts for the backend registration. + * + * 2. WebAssembly artifact initialization. + * This happens when any registered wasm backend is used for the first time (i.e. `ort.InferenceSession.create()` or + * `ort.TrainingSession.create()` is called). In this step, onnxruntime-web does the following: + * - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled. + * - perform feature detection, locate the correct WebAssembly artifact path and call the Emscripten-generated + * JavaScript code to initialize the WebAssembly runtime. + * - if proxy is enabled, this step happens in the proxy worker using message 'init-wasm'. + * - downloading the 'ort-wasm{...}.wasm' file is done in this step. + * - if multi-thread is enabled, one or more web workers will be created to initialize the PThread threadpool. + * + * 3. ORT environment initialization. + * This happens after step 2. In this step, onnxruntime-web performs ONNX Runtime environment initialization. + * Function `_OrtInit()` is called in this step. + * - if proxy is enabled, this step happens in the proxy worker using message 'init-ort'. + * - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step. + * + * 4. Session initialization. + * This happens when `ort.InferenceSession.create()` or `ort.TrainingSession.create()` is called. Unlike the first 3 + * steps (they are only called once), this step will be done for each session. In this step, onnxruntime-web does the + * following: + * If the parameter is a URL: + * - download the model data from the URL. + * - copy the model data to the WASM heap. (proxy: 'copy-from') + * - dereference the model buffer. This step allows the original ArrayBuffer to be garbage collected. + * - call `_OrtCreateSession()` to create the session. (proxy: 'create') + * + * If the parameter is a Uint8Array object: + * - copy the model data to the WASM heap. (proxy: 'copy-from') + * - call `_OrtCreateSession()` to create the session. 
(proxy: 'create') + * + * */ -const getSessionInputOutputCount = (sessionHandle: number): [number, number] => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - try { - const dataOffset = wasm.stackAlloc(8); - const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + 4); - if (errorCode !== 0) { - checkLastError('Can\'t get session input/output count.'); - } - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; - } finally { - wasm.stackRestore(stack); - } -}; /** * initialize ORT environment. + * * @param numThreads SetGlobalIntraOpNumThreads(numThreads) * @param loggingLevel CreateEnv(static_cast(logging_level)) */ @@ -51,18 +74,41 @@ const initOrt = (numThreads: number, loggingLevel: number): void => { export const initRuntime = async(env: Env): Promise => { // init ORT initOrt(env.wasm.numThreads!, logLevelStringToEnum(env.logLevel)); +}; + +/** + * perform EP specific initialization. + * + * @param env + * @param epName + */ +export const initEp = async(env: Env, epName: string): Promise => { + if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') { + // perform WebGPU availability check + if (typeof navigator === 'undefined' || !navigator.gpu) { + throw new Error('WebGPU is not supported in current environment'); + } + const adapter = await navigator.gpu.requestAdapter(); + if (!adapter) { + throw new Error( + 'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); + } + + if (!env.wasm.simd) { + throw new Error( + 'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP'); + } - if (!BUILD_DEFS.DISABLE_WEBGPU) { // init JSEP if available // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires const initJsep = require('./jsep/init').init; - await initJsep(getInstance(), env); + await initJsep(getInstance(), env, adapter); } - - ortEnvInitialized = true; }; +// #endregion Initializations + /** * valid data locations for input/output tensors. */ @@ -97,13 +143,33 @@ type SessionMetadata = [ const activeSessions = new Map(); -export const isOrtEnvInitialized = (): boolean => ortEnvInitialized; +/** + * get the input/output count of the session. + * @param sessionHandle the handle representing the session. should be non-zero. + * @returns a tuple including 2 numbers, representing the input count and output count. + */ +const getSessionInputOutputCount = (sessionHandle: number): [number, number] => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + try { + const dataOffset = wasm.stackAlloc(8); + const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + 4); + if (errorCode !== 0) { + checkLastError('Can\'t get session input/output count.'); + } + return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + } finally { + wasm.stackRestore(stack); + } +}; /** - * allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession. + * allocate the memory and memcpy the external buffer. + * + * @param model - the external buffer containing the model data. Must not be the same buffer as the WASM heap. 
* @returns a 2-element tuple - the pointer and size of the allocated buffer */ -export const createSessionAllocate = (model: Uint8Array): [number, number] => { +export const copyFromExternalBuffer = (model: Uint8Array): [number, number] => { const wasm = getInstance(); const modelDataOffset = wasm._malloc(model.byteLength); if (modelDataOffset === 0) { @@ -114,15 +180,30 @@ export const createSessionAllocate = (model: Uint8Array): [number, number] => { }; /** - * create an inference session using the prepared buffer containing the model data. - * @param modelData a 2-elements tuple containing the pointer and size of the model data buffer. + * create an inference session from a model data buffer. + * + * @param modelData - either a Uint8Array object representing the model data, or a 2-element tuple containing the + * pointer and size of the model data buffer. * @param options an optional session options object. * @returns a 3-element tuple containing [session handle, input names, output names] */ -export const createSessionFinalize = - (modelData: SerializableModeldata, options?: InferenceSession.SessionOptions): SerializableSessionMetadata => { +export const createSession = + (modelData: Uint8Array|SerializableInternalBuffer, + options?: InferenceSession.SessionOptions): SerializableSessionMetadata => { + let modelDataOffset: number, modelDataLength: number; const wasm = getInstance(); + if (Array.isArray(modelData)) { + // if model data is an array, it must be a 2-element tuple containing the pointer and size of the model data + [modelDataOffset, modelDataLength] = modelData; + } else if (modelData.buffer === wasm.HEAPU8.buffer) { + // if model data uses the same buffer as the WASM heap, we don't need to copy it. + [modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength]; + } else { + // otherwise, copy the model data to the WASM heap. + [modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData); + } + let sessionHandle = 0; let sessionOptionsHandle = 0; let ioBindingHandle = 0; @@ -133,7 +214,7 @@ export const createSessionFinalize = try { [sessionOptionsHandle, allocs] = setSessionOptions(options); - sessionHandle = wasm._OrtCreateSession(modelData[0], modelData[1], sessionOptionsHandle); + sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); if (sessionHandle === 0) { checkLastError('Can\'t create a session.'); } @@ -201,7 +282,7 @@ export const createSessionFinalize = } throw e; } finally { - wasm._free(modelData[0]); + wasm._free(modelDataOffset); if (sessionOptionsHandle !== 0) { wasm._OrtReleaseSessionOptions(sessionOptionsHandle); } @@ -209,17 +290,6 @@ export const createSessionFinalize = } }; - -/** - * create an instance of InferenceSession. - * @returns the metadata of InferenceSession. 0-value handle for failure.
- */ -export const createSession = - (model: Uint8Array, options?: InferenceSession.SessionOptions): SerializableSessionMetadata => { - const modelData: SerializableModeldata = createSessionAllocate(model); - return createSessionFinalize(modelData, options); - }; - export const releaseSession = (sessionId: number): void => { const wasm = getInstance(); const session = activeSessions.get(sessionId); diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 0cc28188a6093..c65178e2358d2 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -3,7 +3,7 @@ import {InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, TensorMetadata} from './proxy-messages'; +import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; @@ -32,7 +32,7 @@ const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero } }; -export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { +export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => { const wasm = getInstance(); const [checkpointDataOffset, checkpointDataLength] = checkpointData; @@ -108,8 +108,8 @@ export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: }; export const createTrainingSessionHandle = - (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata, - optimizerModelData: SerializableModeldata, options: InferenceSession.SessionOptions): number => { + (checkpointHandle: number, trainModelData: SerializableInternalBuffer, evalModelData: SerializableInternalBuffer, + optimizerModelData: SerializableInternalBuffer, options: InferenceSession.SessionOptions): number => { const wasm = getInstance(); let trainingSessionHandle = 0; diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 29acc07e118f9..5e9b0910a2c68 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -850,7 +850,7 @@ export class ProtoOpTestContext { this.backendHint = test.backend!; this.ioBindingMode = test.ioBinding; - this.loadedData = onnx.ModelProto.encode(model).finish(); + this.loadedData = onnx.ModelProto.encode(model).finish().slice(); // in debug mode, open a new tab in browser for the generated onnx model. 
if (ort.env.debug) { diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 320a05bb97dac..b060d500c6484 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -20,30 +20,158 @@ class MatMulNBits final : public OpKernel { K_{narrow<size_t>(info.GetAttr<int64_t>("K"))}, N_{narrow<size_t>(info.GetAttr<int64_t>("N"))}, block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))}, - nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))} { + nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))}, + accuracy_level_{info.GetAttr<int64_t>("accuracy_level")} { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + is_asym_ = info.GetInputCount() >= 4; + const Tensor* tensor_B = nullptr; + const Tensor* tensor_scale = nullptr; + const Tensor* tensor_zero_point = nullptr; + bool B_constant = info.TryGetConstantInput(1, &tensor_B); + bool scale_constant = info.TryGetConstantInput(2, &tensor_scale); + bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point); + all_constant_ = B_constant && scale_constant; + all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_; } Status Compute(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx, + /*out*/ bool& used_shared_buffers) override; + private: const size_t K_; const size_t N_; const size_t block_size_; const size_t nbits_; + const int64_t accuracy_level_; const bool column_wise_quant_{true}; + IAllocatorUniquePtr<void> packed_b_; + size_t packed_b_size_{0}; + bool is_asym_{false}; + bool all_constant_{false}; }; +Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + if (!all_constant_) { + return Status::OK(); + } + auto compt_type = static_cast<MLAS_SQNBIT_COMPUTE_TYPE>(accuracy_level_); + MLAS_THREADPOOL* pool = NULL; + if (input_idx == 1) { + packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast<int>(nbits_), is_asym_, compt_type); + if (packed_b_size_ == 0) return Status::OK(); + auto qptr = tensor.Data<uint8_t>(); + packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true); + if (packed_b_ == nullptr) { + return Status::OK(); + } + std::memset(packed_b_.get(), 0, packed_b_size_); + MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_), + is_asym_, false, compt_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + if (input_idx == 2 && packed_b_ != nullptr) { + auto sptr = tensor.Data<float>(); + MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast<int>(nbits_), + is_asym_, !is_asym_, compt_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + if (input_idx == 3 && packed_b_ != nullptr) { + auto zptr = tensor.Data<uint8_t>(); + MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast<int>(nbits_), + is_asym_, is_asym_, compt_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + + return Status::OK(); +} + +Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + // Pack three tensors into one buffer + if (input_idx == 1) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + if (input_idx == 2) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + if (input_idx == 3) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + return Status::OK(); +} + Status MatMulNBits::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); const Tensor* a = ctx->Input<Tensor>(0); + const auto* a_data = a->Data<float>(); + + if (packed_b_.get()) { + TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)}); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) return Status::OK(); + + auto* y_data = y->MutableData<float>(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast<size_t>(helper.M()); + const size_t N = static_cast<size_t>(helper.N()); + const size_t K = static_cast<size_t>(helper.K()); + const size_t lda = helper.Lda(false); + std::vector<MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS> gemm_params(max_len); + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + for (size_t i = 0; i < max_len; i++) { + gemm_params[i].A = a_data + helper.LeftOffsets()[i]; + gemm_params[i].lda = lda; + gemm_params[i].B = packed_b_.get(); + gemm_params[i].C = y_data + helper.OutputOffsets()[i]; + gemm_params[i].ldc = N; + } + auto ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); + // workspace for activation process (dynamic quantization and others) + auto ws_ptr = IAllocator::MakeUniquePtr<int8_t>(allocator, ws_size); + MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), + thread_pool); + return Status::OK(); + } + const Tensor* b = ctx->Input<Tensor>(1); const Tensor* scales = ctx->Input<Tensor>(2); const Tensor* zero_points = ctx->Input<Tensor>(3); - - const auto* a_data = a->Data<float>(); const uint8_t* b_data = b->Data<uint8_t>(); const auto* scales_data = scales->Data<float>(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data<uint8_t>(); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 26fca454c96f0..54eb43753931a 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3359,6 +3359,13 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored .Attr("N", "size of each output feature", AttributeProto::INT) .Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT) .Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Attr("accuracy_level", + "The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) " + "(default unset).
It is used to control how input A is quantized or downcast internally while " + "doing computation, for example: 0 means input A will not be quantized or downcast while doing " + "computation. 4 means input A can be quantized with the same block_size to int8 internally from " + "type T1.", + AttributeProto::INT, static_cast<int64_t>(0)) .Input(0, "A", "The input tensor, not quantized", "T1") .Input(1, "B", "1-dimensional data blob", "T2") .Input(2, "scales", "quantization scale", "T1") diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 9620dd42d1da9..1e83dd1cec400 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -77,3 +77,144 @@ MlasIsSQNBitGemmAvailable( size_t BlkBitWidth, size_t BlkLen ); + +/** + * @brief Define compute types of block quantization + */ +typedef enum { + CompUndef = 0, /*!< undef */ + CompFp32 = 1, /*!< input fp32, accumulator fp32 */ + CompFp16 = 2, /*!< input fp16, accumulator fp16 */ + CompBf16 = 3, /*!< input bf16, accumulator fp32 */ + CompInt8 = 4 /*!< input int8, accumulator int32 */ +} MLAS_SQNBIT_COMPUTE_TYPE; + +/** + * @brief Data parameters for NBits GEMM routine + * C = A * B + * A and C must be float32 matrices + * B must be a packed nbits blob + * All except C are [in] parameters + */ +struct MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS { + const float* A = nullptr; /**< address of A (float32 matrix)*/ + const void* B = nullptr; /**< address of B (packed nbits blob)*/ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldc = 0; /**< leading dimension of C*/ +}; + +/** + * @brief Compute the byte size of the packing buffer for the given parameter combination + * + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @return size of the packing buffer, 0 if the operation is not yet supported. + */ +size_t MLASCALL +MlasNBitsGemmPackBSize( + size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE comp_type +); + +/** + * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. + * + * @param PackedBuf packed data buffer + * @param QData quantized data buffer + * @param Scale scale pointer + * @param Zp zero point pointer + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization (default 4) + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query the input tensors + * one by one: QData, Scale, Zp (if is_asym is true). But the kernel prefers to pack all tensors into one blob so that + * they can share common attributes such as block_size. Meanwhile, the kernel has some pre-computations to speed up + * inference which require all blob data to be ready.
So, you need to set this flag to true when passing Scale + * (is_asym is false) and Zp (is_asym is true). + * @param thread_pool + */ +void MLASCALL +MlasNBitsGemmPackB( + void* PackedBuf, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + size_t ldb, + size_t block_size, + int nbits, + bool is_asym, + bool last_call, + MLAS_SQNBIT_COMPUTE_TYPE comp_type, + MLAS_THREADPOOL* thread_pool +); + +/** + * @brief Unpack and dequantize to fp32 + * + * @param FpData unpacked float32 data + * @param PackedBuf quantized and packed data + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param thread_pool + */ +void MLASCALL +MlasNBitsGemmUnPackB( + float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool +); + +/** + * @brief Get the workspace size required by computation. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @return Workspace size in bytes + */ +size_t MLASCALL +MlasSQNBitsGemmBatchWorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams +); + +/** + * @brief Batched GEMM: C = A * B + * A and C must be float32 matrices + * B must be a packed nbits blob + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] WorkSpace temporary buffer + * @param[in] ThreadPool + */ +void MLASCALL +MlasSQNBitsGemmBatchPackedB( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, + void* WorkSpace, + MLAS_THREADPOOL* ThreadPool = nullptr +); diff --git a/onnxruntime/core/mlas/lib/jblas_defs.h b/onnxruntime/core/mlas/lib/jblas_defs.h new file mode 100644 index 0000000000000..9cd1711a3ffd2 --- /dev/null +++ b/onnxruntime/core/mlas/lib/jblas_defs.h @@ -0,0 +1,73 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +--*/ + +#pragma once + +#include "jblas/jit_blas_prologue_b.h" +#include "jblas/jit_blas_wrapper.h" + +namespace jblas +{ + +/* +Name conversion explanation: +Fp32: comp type, determined by GemmCore, can be any jblas::gemm::SCorexxx (float GemmCore) +S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4 (which also supports other integer and float weight +classes) +F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and +jblas::epilogue::gemm::AccumulatorWriteBackFp32. + +Tip: jblas::epilogue::gemm::CompFp32BlockEpilogue is a fixed class for all fp32 accumulator GemmCores.
+*/ +template <class GemmCore_T> +using tLauncher_Fp32_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock< + GemmCore_T::ISA, + GemmCore_T, + jblas::prologue_a::gemm::ActivationKBlockBaseF32, + jblas::prologue_b::gemm::WeightKBlockS4, + jblas::epilogue::gemm::CompFp32BlockEpilogue, + jblas::epilogue::gemm::AccumulatorWriteBackFp32>; + +/* +Name conversion explanation: +Int8: comp type, determined by GemmCore, can be any jblas::gemm::ICorexxx (integer GemmCore) +S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4 (supports integer weight classes only) +F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and +jblas::epilogue::gemm::AccumulatorWriteBackFp32. + +Tip: jblas::epilogue::gemm::CompInt8BlockEpilogue is a fixed class for all int32 accumulator GemmCores. +*/ +template <class GemmCore_T> +using tLauncher_Int8_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock< + GemmCore_T::ISA, + GemmCore_T, + jblas::prologue_a::gemm::ActivationF32KBlockQuantize, + jblas::prologue_b::gemm::WeightKBlockS4, + jblas::epilogue::gemm::CompInt8BlockEpilogue, + jblas::epilogue::gemm::AccumulatorWriteBackFp32>; + +using tAVX512F = jblas::gemm::SCoreRowNAvx512f<48, 8>; +using tAMX_BF16 = jblas::gemm::HCoreRowNAmxbf16<64, 16>; +using tAVX512_FP16 = jblas::gemm::HCoreRowNAvx512fp16<96, 8>; +using tAVX_VNNI = jblas::gemm::ICoreRowNAvxvnni<48, 2>; // TODO(Yu) use 24x4 for higher efficiency +using tAVX512_VNNI = jblas::gemm::ICoreRowNAvx512vnni<48, 8>; +using tAMX_INT8_US = jblas::gemm::ICoreRowNAmxint8<64, 16>; +using tAMX_INT8_SS = jblas::gemm::ICoreRowNAmxint8SS<64, 16>; +using tAVX2 = jblas::gemm::SCoreRowNAvx2<48, 2>; // TODO(Yu) use 24x4 for higher efficiency + +class ORTThreading : public jblas::parallel::IThreading +{ + public: + ORTThreading(void* tp); + void parallel_for(const jblas::parallel::thread_func& func) override; + void set_threads(int nthreads) override { assert(0); } + void sync() override { assert(0); } + void* mTp; +}; + +} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.cpp b/onnxruntime/core/mlas/lib/jblas_gemm.cpp new file mode 100644 index 0000000000000..f3cae3186c28e --- /dev/null +++ b/onnxruntime/core/mlas/lib/jblas_gemm.cpp @@ -0,0 +1,534 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + jblas_gemm.cpp + +Abstract: + + Currently only supports Q4 GEMM.
+--*/ + +#include "jblas_gemm.h" + +#include "jblas_defs.h" +#include "mlasi.h" + +using namespace jblas; + +jblas::ORTThreading::ORTThreading(void* tp) + : IThreading(MLAS_THREADPOOL::DegreeOfParallelism(reinterpret_cast<MLAS_THREADPOOL*>(tp))), mTp(tp) +{ +} + +void +jblas::ORTThreading::parallel_for(const jblas::parallel::thread_func& func) +{ + MlasTrySimpleParallel(reinterpret_cast<MLAS_THREADPOOL*>(mTp), mThreadNum, [&](ptrdiff_t tid) { + func(static_cast<int>(tid)); + }); +} + +template <class GemmCore_T> +static void +JblasSQ4GemmCompF32( + const size_t M, + const size_t N, + const size_t K, + const float* A, + const size_t lda, + jblas::storage::gemm::StorageWeightKBlockS4* B, + float* C, + const size_t ldc, + int8_t* WorkSpace, + jblas::parallel::IThreading* th +) +{ + auto M_ = static_cast<int>(M); + auto N_ = static_cast<int>(N); + auto K_ = static_cast<int>(K); + auto lda_ = static_cast<int>(lda); + auto ldc_ = static_cast<int>(ldc); + if (M <= 16) { + using Parallel = jblas::parallel::gemm::SchedulerKBlock<GemmCore_T>; + using Launcher = tLauncher_Fp32_S4_F32F32<GemmCore_T>; + static Launcher kernel; + auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); + if (B->mIsAsym) { + reduceA.assign(WorkSpace); + ORTThreading single(nullptr); + kernel.mProA.reduce({A, lda_}, &reduceA, M_, K_, &single); + } + typename Launcher::BEpiParam blkargs{ + B->template SPtr<float>(), B->mScaT, B->mCStep, B->template ZPtr<int8_t>(), + reduceA.template get<float>(), reduceA.lda}; + + typename Launcher::Param args{M_, N_, K_, B->mBlockSize, {A, lda_}, {B}, blkargs, {C, ldc_}}; + jblas::parallel::GemmKBlockRun<Parallel>(kernel, args, th); + } else { + using Parallel = jblas::parallel::gemm::SchedulerBase<GemmCore_T>; + using Launcher = jblas::wrapper::gemm::LauncherBase< + GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase, + jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; + static Launcher kernel; + + typename Launcher::Param args{M_, N_, K_, {A, lda_}, {B}, {C, ldc_}}; + jblas::parallel::GemmBaseRun<Parallel>(kernel, args, th); + } +} + +template <class GemmCore_T> +static void +JblasSQ4GemmCompInt8( + const size_t M, + const size_t N, + const size_t K, + const float* A, + const size_t lda, + jblas::storage::gemm::StorageWeightKBlockS4* B, + float* C, + const size_t ldc, + int8_t* WorkSpace, + jblas::parallel::IThreading* th +) +{ + using Parallel = jblas::parallel::gemm::SchedulerKBlock<GemmCore_T>; + using Launcher = tLauncher_Int8_S4_F32F32<GemmCore_T>; + auto M_ = static_cast<int>(M); + auto N_ = static_cast<int>(N); + auto K_ = static_cast<int>(K); + auto lda_ = static_cast<int>(lda); + auto ldc_ = static_cast<int>(ldc); + static Launcher kernel; + auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->mIsAsym); + quanA.assign(WorkSpace); + if (M <= 16) { + ORTThreading single(nullptr); + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); + } else { + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); + } + typename Launcher::Param args{ + M_, + N_, + K_, + B->mBlockSize, + {A, lda_, &quanA}, + {B}, + {B->template SPtr<float>(), B->mScaT, B->mCStep, quanA.template SPtr<float>(), quanA.mCStep, + quanA.template ZPtr<uint8_t>(), B->template RPtr<float>(), B->mRedT, B->template ZPtr<int8_t>(), + quanA.template RPtr<float>(), B->mBlockSize}, + {C, ldc_}}; + jblas::parallel::GemmKBlockRun<Parallel>(kernel, args, th); +} + +bool +JblasSQ4GemmBatchDriver( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, + int8_t* WorkSpace, + MLAS_THREADPOOL* ThreadPool +) +{ + GetCPUDevice(); + ORTThreading orth(ThreadPool); + bool processed = true; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr<jblas::storage::gemm::IWeightBase>(ptr); + if (ptr) { + if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { + auto kptr = reinterpret_cast<jblas::storage::gemm::StorageWeightKBlockS4*>(ptr); + auto coretype = ptr->mCoreId; + auto NTile = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT + ); + auto CType = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT + ); + if (CType == uint32_t(gemm::CompType::COMP_FP32)) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { + JblasSQ4GemmCompF32<tAVX512F>( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { + JblasSQ4GemmCompF32<tAVX2>( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } + } + if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { + if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { + JblasSQ4GemmCompInt8<tAMX_INT8_US>( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { + JblasSQ4GemmCompInt8<tAVX512_VNNI>( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { + JblasSQ4GemmCompInt8<tAVX_VNNI>( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } + } + if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { + if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { + JblasSQ4GemmCompInt8<tAMX_INT8_SS>( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc, + WorkSpace, &orth + ); + } + } + } + } else { + processed = false; + break; + } + } + return processed; +} + +template <class GemmCore_T> +static size_t +JblasSQ4GemmCompF32WorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const float* A, + const size_t lda, + jblas::storage::gemm::StorageWeightKBlockS4* B, + float* C, + const size_t ldc +) +{ + auto M_ = static_cast<int>(M); + auto K_ = static_cast<int>(K); + (void)(N); + (void)(lda); + (void)(ldc); + if (M <= 16) { + using Launcher = tLauncher_Fp32_S4_F32F32<GemmCore_T>; + static Launcher kernel; + if (B->mIsAsym) { + auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); + return reduceA.mSize; + } + return 0; + } else { + using Launcher = jblas::wrapper::gemm::LauncherBase< + GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase, + jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; + static Launcher kernel; + return 0; + } + return 0; +} + +template <class GemmCore_T> +static size_t +JblasSQ4GemmCompInt8WorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const float* A, + const size_t lda, + jblas::storage::gemm::StorageWeightKBlockS4* B, + float* C, + const size_t ldc +) +{ + using Parallel = jblas::parallel::gemm::SchedulerKBlock<GemmCore_T>; + using Launcher = tLauncher_Int8_S4_F32F32<GemmCore_T>; + static Launcher kernel; + (void)(N); + (void)(lda); + (void)(ldc); + auto quanA = kernel.mProA.createStorage( + static_cast<int>(M), static_cast<int>(K), static_cast<int>(B->mBlockSize), B->mIsAsym + ); + return quanA.mSize; +} + +size_t +JblasSQ4GemmBatchWorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const
MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams +) +{ + GetCPUDevice(); + size_t size = 0; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr<jblas::storage::gemm::IWeightBase>(ptr); + if (ptr) { + if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { + auto kptr = reinterpret_cast<jblas::storage::gemm::StorageWeightKBlockS4*>(ptr); + auto coretype = ptr->mCoreId; + auto NTile = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT + ); + auto CType = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT + ); + if (CType == uint32_t(gemm::CompType::COMP_FP32)) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { + size = std::max( + JblasSQ4GemmCompF32WorkspaceSize<tAVX512F>( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { + size = std::max( + JblasSQ4GemmCompF32WorkspaceSize<tAVX2>( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } + } + if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) { + if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { + size = std::max( + JblasSQ4GemmCompInt8WorkspaceSize<tAMX_INT8_US>( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { + size = std::max( + JblasSQ4GemmCompInt8WorkspaceSize<tAVX512_VNNI>( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { + size = std::max( + JblasSQ4GemmCompInt8WorkspaceSize<tAVX_VNNI>( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } + } + if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) { + if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { + size = std::max( + JblasSQ4GemmCompInt8WorkspaceSize<tAMX_INT8_SS>( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc + ), + size + ); + } + } + } + } + } + return size; +} + +template <typename T> +static size_t +JblasQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) +{ + static T launcher; + auto stor = launcher.mProB.createStorage( + static_cast<int>(N), static_cast<int>(K), static_cast<int>(block_size), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, + JBLAS_DTYPE::BF16, isAsym + ); + // TODO(Yu) support more scale dtype + return stor.mSize; +} + +size_t +JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType) +{ + GetCPUDevice(); + if (K % BlkSize != 0) { + return 0; + } + // from low precision to high precision + switch (CompType) { + case CompInt8: + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { + return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAMX_INT8_SS>>(BlkSize, N, K, isAsym); + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { + return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAVX512_VNNI>>(BlkSize, N, K, isAsym); + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { + return JblasQ4BuSize<tLauncher_Int8_S4_F32F32<tAVX_VNNI>>(BlkSize, N, K, isAsym); + } + case CompBf16: + case CompFp16: + case CompFp32: + case CompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + return JblasQ4BuSize<tLauncher_Fp32_S4_F32F32<tAVX512F>>(BlkSize, N, K, isAsym); + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + return JblasQ4BuSize<tLauncher_Fp32_S4_F32F32<tAVX2>>(BlkSize, N, K, isAsym); + } + break; + default: + return 0; + } + return 0; +} + +template <typename T> +static void +JblasQ4GemmPackBImpl( + void* PackedBuf, + size_t BlkSize, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + bool IsAsym, + bool lastCall, + size_t ldb, + MLAS_THREADPOOL* ThreadPool +) +{ + static T JblasKernel; + auto N_ = static_cast<int>(N); + auto K_ = static_cast<int>(K); + auto stor = JblasKernel.mProB.createStorage( + N_, K_, static_cast<int>(BlkSize), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, IsAsym + ); + stor.assign(reinterpret_cast<int8_t*>(PackedBuf)); + ORTThreading orth(ThreadPool); + JblasKernel.mProB.packNbitsWeight(N_, K_, IsAsym, QData, static_cast<int>(ldb), Scale, Zp, &stor, &orth); + if (lastCall) { + JblasKernel.mProB.reduceWeight(&stor, &orth); + } +} + +bool +JblasQ4GemmPackB( + void* PackedBuf, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + size_t ldb, + size_t BlkSize, + bool isAsym, + bool lastCall, + MLAS_SQNBIT_COMPUTE_TYPE CompType, + MLAS_THREADPOOL* ThreadPool +) +{ + GetCPUDevice(); + // explicit fall-through between cases. + switch (CompType) { + case CompInt8: + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) { + JblasQ4GemmPackBImpl<tLauncher_Int8_S4_F32F32<tAMX_INT8_SS>>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) { + JblasQ4GemmPackBImpl<tLauncher_Int8_S4_F32F32<tAVX512_VNNI>>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) { + JblasQ4GemmPackBImpl<tLauncher_Int8_S4_F32F32<tAVX_VNNI>>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + case CompBf16: + case CompFp16: + case CompFp32: + case CompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + JblasQ4GemmPackBImpl<tLauncher_Fp32_S4_F32F32<tAVX512F>>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + JblasQ4GemmPackBImpl<tLauncher_Fp32_S4_F32F32<tAVX2>>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool + ); + return true; + } + default: + return false; + } + return false; +} + +bool +JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) +{ + auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); + auto uptr = std::unique_ptr<jblas::storage::gemm::IWeightBase>(ptr); + ORTThreading orth(ThreadPool); + auto N_ = static_cast<int>(N); + auto K_ = static_cast<int>(K); + auto ldb_ = static_cast<int>(ldb); + GetCPUDevice(); + if (ptr) { + if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) { + auto NTile = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT + ); + auto CType = jblas::gemm::CoreAttr::get_mask_val( + ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT + ); + if (CType == uint32_t(jblas::gemm::CompType::COMP_FP32)) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F()) { + static jblas::prologue_b::gemm::WeightKBlockS4<tAVX512F, tAVX512F::ISA> proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } else if (NTile == tAVX2::NTILE && _cd->AVX2()) { + static jblas::prologue_b::gemm::WeightKBlockS4<tAVX2, tAVX2::ISA> proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } + } + if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_US_INT32)) { + if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) { + static jblas::prologue_b::gemm::WeightKBlockS4<tAMX_INT8_US, tAMX_INT8_US::ISA> proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) { + static jblas::prologue_b::gemm::WeightKBlockS4<tAVX512_VNNI, tAVX512_VNNI::ISA> proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) { + static jblas::prologue_b::gemm::WeightKBlockS4<tAVX_VNNI, tAVX_VNNI::ISA> proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } + } + if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_SS_INT32)) { + if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) { + static jblas::prologue_b::gemm::WeightKBlockS4<tAMX_INT8_SS, tAMX_INT8_SS::ISA> proB; + proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth); + } + } + } + return true; + } + return false; +} diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.h b/onnxruntime/core/mlas/lib/jblas_gemm.h new file mode 100644 index 0000000000000..044dc5e849a0a --- /dev/null +++ b/onnxruntime/core/mlas/lib/jblas_gemm.h @@ -0,0 +1,61 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + jblas_gemm.h + +Abstract: + + Currently only supports Q4 GEMM. +--*/ + +#pragma once + +#include "mlas_qnbit.h" + +size_t +JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType); + +bool +JblasQ4GemmPackB( + void* PackedBuf, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + size_t ldb, + size_t BlkSize, + bool isAsym, + bool lastCall, + MLAS_SQNBIT_COMPUTE_TYPE CompType, + MLAS_THREADPOOL* ThreadPool +); + +bool +JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool); + +bool +JblasSQ4GemmBatchDriver( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, + int8_t* WorkSpace, + MLAS_THREADPOOL* ThreadPool +); + +size_t +JblasSQ4GemmBatchWorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams +); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 7bda1bb504173..7bb8b17031a84 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -50,7 +50,9 @@ Module Name: #include #endif #if defined(__x86_64__) || defined(__i386__) +#if !defined(signature_VORTEX_ebx) && !defined(signature_NEXGEN_ebx) && !defined(signature_AMD_ebx)//workaround for Bug 96238 - [i386] cpuid.h header needs include guards #include <cpuid.h> +#endif #if defined(__GNUC__) && __GNUC__ >= 12 #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // GCC 12 warns about uninitialized variables in immintrin.h.
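Taken together, the new MLAS entry points declared above follow a pack-once, run-many pattern that MatMulNBits::PrePack and Compute exercise. Below is a minimal end-to-end sketch of the intended call sequence; the function name RunQ4Gemm and the qdata/scales buffers are illustrative, and only the symmetric (no zero point) case is covered:

    #include <vector>
    #include "mlas_qnbit.h"

    // Sketch: pack 4-bit quantized weights once, then run the batched GEMM.
    // Mirrors MatMulNBits::PrePack/Compute; error handling is omitted.
    void RunQ4Gemm(const float* A, float* C, const uint8_t* qdata, const float* scales,
                   size_t M, size_t N, size_t K, size_t block_size, MLAS_THREADPOOL* pool) {
      const int nbits = 4;
      const bool is_asym = false;  // symmetric quantization: no zero-point tensor
      const size_t packed_size = MlasNBitsGemmPackBSize(N, K, block_size, nbits, is_asym, CompInt8);
      if (packed_size == 0) return;  // parameter combination not supported on this CPU
      std::vector<int8_t> packed_b(packed_size);
      // Feed the tensors one by one, as OpKernel::PrePack does: QData first,
      // then Scale with last_call=true (no Zp call because is_asym is false).
      MlasNBitsGemmPackB(packed_b.data(), qdata, nullptr, nullptr, N, K, K, block_size,
                         nbits, is_asym, /*last_call=*/false, CompInt8, pool);
      MlasNBitsGemmPackB(packed_b.data(), nullptr, scales, nullptr, N, K, K, block_size,
                         nbits, is_asym, /*last_call=*/true, CompInt8, pool);
      // One parameter block per batch entry; a single GEMM here.
      MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS params;
      params.A = A;
      params.lda = K;
      params.B = packed_b.data();
      params.C = C;
      params.ldc = N;
      const size_t ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, /*BatchN=*/1, &params);
      std::vector<int8_t> workspace(ws_size);  // holds the dynamically quantized activation
      MlasSQNBitsGemmBatchPackedB(M, N, K, /*BatchN=*/1, &params, workspace.data(), pool);
    }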
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index f964b1affec31..7f1d1b084aec0 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -15,6 +15,9 @@ Module Name: --*/ #include "sqnbitgemm.h" +#ifdef MLAS_JBLAS +#include "jblas_gemm.h" +#endif namespace { @@ -142,3 +145,127 @@ MlasIsSQNBitGemmAvailable( return true; } + +size_t MLASCALL +MlasNBitsGemmPackBSize( + size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType +) +{ +#ifdef MLAS_JBLAS + if (nbits == 4) { + auto jsize = JblasQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); + if (jsize) { + return jsize; + } + } +#endif + (void)(N); + (void)(K); + (void)(BlkSize); + (void)(nbits); + (void)(isAsym); + (void)(CompType); + return 0; +} + +void MLASCALL +MlasNBitsGemmPackB( + void* PackedBuf, + const uint8_t* QData, + const float* Scale, + const uint8_t* Zp, + size_t N, + size_t K, + size_t ldb, + size_t BlkSize, + int nbits, + bool isAsym, + bool lastCall, + MLAS_SQNBIT_COMPUTE_TYPE CompType, + MLAS_THREADPOOL* ThreadPool +) +{ +#ifdef MLAS_JBLAS + if (nbits == 4) { + if (JblasQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { + return; + } + } +#endif + (void)(PackedBuf); + (void)(QData); + (void)(Scale); + (void)(Zp); + (void)(N); + (void)(K); + (void)(ldb); + (void)(BlkSize); + (void)(nbits); + (void)(isAsym); + (void)(lastCall); + (void)(CompType); + (void)(ThreadPool); +} + +void MLASCALL +MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool) +{ +#ifdef MLAS_JBLAS + if (JblasQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { + return; + } +#endif + (void)(FpData); + (void)(PackedBuf); + (void)(N); + (void)(K); + (void)(ldb); + (void)(ThreadPool); +} + +size_t MLASCALL +MlasSQNBitsGemmBatchWorkspaceSize( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams +) +{ +#ifdef MLAS_JBLAS + return JblasSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); +#endif + (void)(M); + (void)(N); + (void)(K); + (void)(BatchN); + (void)(DataParams); + return 0; +} + +void MLASCALL +MlasSQNBitsGemmBatchPackedB( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, + void* WorkSpace, + MLAS_THREADPOOL* ThreadPool +) +{ + GetMlasPlatform(); +#ifdef MLAS_JBLAS + if (JblasSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast<int8_t*>(WorkSpace), ThreadPool)) { + // PackedWeight is created by jblas + return; + } +#endif + (void)(M); + (void)(N); + (void)(K); + (void)(BatchN); + (void)(DataParams); + (void)(WorkSpace); + (void)(ThreadPool); +} diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format new file mode 100644 index 0000000000000..84b876706161d --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format @@ -0,0 +1,7 @@ +Language: Cpp +BasedOnStyle: Google +DerivePointerAlignment: false +ColumnLimit: 120 +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SortIncludes: false diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt new file mode 100644 index 0000000000000..5d9c5edf45a96 --- /dev/null +++
b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt @@ -0,0 +1,33 @@ +cmake_minimum_required(VERSION 3.5) + +project(jblas LANGUAGES CXX VERSION 0.1.0) + +file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp) +file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp) + +add_library(${PROJECT_NAME} INTERFACE) +add_library(${PROJECT_NAME}::${PROJECT_NAME} ALIAS ${PROJECT_NAME}) + +target_include_directories( + ${PROJECT_NAME} INTERFACE + "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>" + "$<INSTALL_INTERFACE:include>" +) + +if(WIN32) + target_compile_definitions(${PROJECT_NAME} INTERFACE _CRT_SECURE_NO_WARNINGS NOMINMAX) + target_compile_options(${PROJECT_NAME} INTERFACE /wd4068 /wd4849 /wd6262 /wd4702 /wd4100) + #4068 ignore unroll and GCC flags + #4849 ignore collapse + #6262 ignore stack too large + #4702 unreachable code(false warning on constexpr condition) + #4100 unreferenced formal parameter + + target_link_options(${PROJECT_NAME} INTERFACE /STACK:3145728) #Stack requires up to L2 cache size +endif(WIN32) + + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17) diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h new file mode 100644 index 0000000000000..143adb771760b --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h @@ -0,0 +1,303 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
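jit_base.h, which follows, collects the Xbyak helpers shared by the JIT kernels. The densest of them, generate_Nbitsmask, builds the AVX-512 opmask for a possibly partial tile row; a scalar model of the same computation, with illustrative names, reads:

    #include <cstdint>

    // Scalar model of generate_Nbitsmask: select min(N, total - pos) lanes,
    // all N lanes inside the body, none once pos runs past total.
    uint64_t NBitsMask(int64_t pos, int64_t total, int N) {
      const int64_t remain = total - pos;
      if (remain <= 0) return 0;                                  // past the end
      if (remain >= N) return N == 64 ? ~0ULL : (1ULL << N) - 1;  // full vector
      return (1ULL << remain) - 1;                                // ragged tail
    }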
+#pragma once +#include + +#include +#include +#include "xbyak/xbyak.h" +#include "xbyak/xbyak_util.h" + +#define OFFSET(field) offsetof(params, field) + +namespace jblas { + +namespace xbyak { +class JitBase : protected Xbyak::CodeGenerator { + protected: + JitBase(size_t size = 16 * 1024) : CodeGenerator(size) {} + + void load32(const Xbyak::Reg64& reg, const Xbyak::Address& addr) { + xor_(reg, reg); + mov(reg.cvt32(), addr); + } + + void vreg_push(const Xbyak::Reg64& baseaddr) { +#ifdef _WIN32 + for (int i = 0; i < 10; i++) { + movaps(xword[baseaddr + i * 16], Xbyak::Xmm(6 + i)); + } +#endif + } + + void vreg_pop(const Xbyak::Reg64& baseaddr) { +#ifdef _WIN32 + for (int i = 0; i < 10; i++) { + movaps(Xbyak::Xmm(6 + i), xword[baseaddr + i * 16]); + } +#endif + } + + void padto_le(const Xbyak::Reg64& _src, int padding) { + // _src=_src/padding*padding + if (padding == 1) { + return; + } + for (int i = 1; i < 16; i++) { + if ((1 << i) == padding) { + shr(_src, i); + shl(_src, i); + return; + } + } + assert(0); + } + + void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total, + const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { + inLocalLabel(); + lea(_tmp, _total); + sub(_tmp, _pos); + cmp(_tmp, N); + jb(".maskflag"); + cmp(_tmp, 0); + jl(".zeroflag"); + uint64_t allmask = (static_cast<uint64_t>(1) << N) - 1; + if (N == 64) { + allmask = static_cast<uint64_t>(-1); + } + mov(_tmp, allmask); + kmovq(_msk, _tmp); + jmp(".maskend"); + L(".maskflag"); + mov(_tmp1, 1); + shlx(_tmp1, _tmp1, _tmp); + sub(_tmp1, 1); + kmovq(_msk, _tmp1); + jmp(".maskend"); + L(".zeroflag"); + mov(_tmp1, 0); + kmovq(_msk, _tmp1); + L(".maskend"); + outLocalLabel(); + } + void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Reg64& _total, + const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) { + generate_Nbitsmask(_msk, _pos, ptr[_total], _tmp, _tmp1, N); + } +}; + +class JitAvx : protected JitBase { + protected: + static int constexpr VBits = 256; + static int constexpr VecBytes = VBits / 8; + static int constexpr RegCount = 16; + typedef Xbyak::Ymm vreg_t; +}; + +class JitAvx2 : protected JitAvx { + protected: + static int constexpr VBits = 256; + typedef Xbyak::Ymm vreg_t; + void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); } + + void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) { + vpmovzxwd(dst, addr); + vpslld(dst, dst, 16); + } +}; + +class JitAvx512f : protected JitAvx2 { + protected: + static int constexpr VBits = 512; + static int constexpr VecBytes = VBits / 8; + static int constexpr RegCount = 32; + typedef Xbyak::Zmm vreg_t; + + void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); } + + void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) { + vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]); + vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]); + vshuff32x4(src_2regs[0], tmp_2reg[0], tmp_2reg[1], 0 | (1 << 2) | (0 << 4) | (1 << 6)); + vshuff32x4(src_2regs[0], src_2regs[0], src_2regs[0], 0 | (2 << 2) | (1 << 4) | (3 << 6)); + vshuff32x4(src_2regs[1], tmp_2reg[0], tmp_2reg[1], 2 | (3 << 2) | (2 << 4) | (3 << 6)); + vshuff32x4(src_2regs[1], src_2regs[1], src_2regs[1], 0 | (2 << 2) | (1 << 4) | (3 << 6)); + } + + void transpose16x16_4B(Xbyak::Zmm* src, Xbyak::Zmm* tmp, const int N = 16) { + for (int i = 0; i < 8; ++i) { + vpunpckldq(tmp[2 * i + 0], src[2 * i], src[2 * i + 1]); + vpunpckhdq(tmp[2 * i
+ 1], src[2 * i], src[2 * i + 1]); + } + + for (int i = 0; i < 4; ++i) { + vpunpcklqdq(src[4 * i + 0], tmp[4 * i + 0], tmp[4 * i + 2]); + vpunpckhqdq(src[4 * i + 1], tmp[4 * i + 0], tmp[4 * i + 2]); + vpunpcklqdq(src[4 * i + 2], tmp[4 * i + 1], tmp[4 * i + 3]); + vpunpckhqdq(src[4 * i + 3], tmp[4 * i + 1], tmp[4 * i + 3]); + } + + for (int i = 0; i < 2; ++i) { + vshufi32x4(tmp[8 * i + 0], src[8 * i + 0], src[8 * i + 4], 0x88); + vshufi32x4(tmp[8 * i + 1], src[8 * i + 1], src[8 * i + 5], 0x88); + vshufi32x4(tmp[8 * i + 2], src[8 * i + 2], src[8 * i + 6], 0x88); + vshufi32x4(tmp[8 * i + 3], src[8 * i + 3], src[8 * i + 7], 0x88); + vshufi32x4(tmp[8 * i + 4], src[8 * i + 0], src[8 * i + 4], 0xdd); + vshufi32x4(tmp[8 * i + 5], src[8 * i + 1], src[8 * i + 5], 0xdd); + vshufi32x4(tmp[8 * i + 6], src[8 * i + 2], src[8 * i + 6], 0xdd); + vshufi32x4(tmp[8 * i + 7], src[8 * i + 3], src[8 * i + 7], 0xdd); + } + + // last step and move out + for (int i = 0; i < N; ++i) { + vshufi32x4(src[i], tmp[i % 8], tmp[8 + i % 8], i < 8 ? 0x88 : 0xdd); + } + } + + void interleave_4rows_6regs(Xbyak::Zmm* src_4regs, Xbyak::Zmm* tmp_regs, const Xbyak::Opmask* masks) { + vpunpcklbw(tmp_regs[0], src_4regs[0], src_4regs[1]); + vpunpckhbw(tmp_regs[1], src_4regs[0], src_4regs[1]); + vpunpcklbw(tmp_regs[2], src_4regs[2], src_4regs[3]); + vpunpckhbw(tmp_regs[3], src_4regs[2], src_4regs[3]); + + vpunpcklwd(tmp_regs[4], tmp_regs[0], tmp_regs[2]); + vpunpckhwd(tmp_regs[5], tmp_regs[0], tmp_regs[2]); + vpunpcklwd(tmp_regs[0], tmp_regs[1], tmp_regs[3]); + vpunpckhwd(tmp_regs[2], tmp_regs[1], tmp_regs[3]); + vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (4 << 4) | 4); + vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (4 << 4) | 4); + vmovups(src_4regs[0], tmp_regs[1]); + vshuff32x4(src_4regs[0] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); + vmovups(src_4regs[1], tmp_regs[3]); + vshuff32x4(src_4regs[1] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); + vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (14 << 4) | 14); + vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (14 << 4) | 14); + vmovups(src_4regs[2], tmp_regs[1]); + vshuff32x4(src_4regs[2] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6)); + vmovups(src_4regs[3], tmp_regs[3]); + vshuff32x4(src_4regs[3] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6)); + } + + void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { + vpsrld(_fp32, _fp32, 16); + vpmovdw(_bf16, _fp32); + } + + void loadbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Address& addr) { + vpmovzxwd(dst, addr); + vpslld(dst, dst, 16); + } + + void broadcastbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Reg64& tmp, const Xbyak::Address& addr) { + mov(tmp.cvt16(), addr); + shl(tmp.cvt32(), 16); + vpbroadcastd(dst, tmp.cvt32()); + } + + void store_fp32_bf16(const Xbyak::Zmm& _fp32, const Xbyak::Address& _add) { + auto bf16 = Xbyak::Ymm(_fp32.getIdx()); + cvt_fp32_bf16(bf16, _fp32); + vmovups(_add, bf16); + } +}; + +class JitAvx512_bf16 : protected JitAvx512f {}; + +class JitAvx512_fp16 : protected JitAvx512f {}; + +class JitAvx512vnni : protected JitAvx512f { + protected: + void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { + vpdpbusds(x1, x2, op, Xbyak::EvexEncoding); + } +}; + +class JitAvxvnni : protected JitAvx2 { + protected: + void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) { + vpdpbusds(x1, x2, op, 
Xbyak::VexEncoding); + } +}; + +class JitAmxtile : protected JitAvx512f { + public: + struct alignas(64) tileconfig_t { + uint8_t palette_id; + uint8_t reserved[15]; + uint16_t colb[16]; + uint8_t rows[16]; + }; + static int constexpr TileCount = 8; + + typedef long long (*configure_t)(void*); + + static void generate_config(Xbyak::CodeGenerator* g) { + Xbyak::util::StackFrame st(g, 1, 0, 0); + auto& parambase = st.p[0]; + g->ldtilecfg(g->ptr[parambase]); + } + + static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, + int CNum) { + // Filling tile configure structure. Could be done offline. + tc.palette_id = 1; + // Configure C tiles + int t = 0; + for (; t < CNum; ++t) { + tc.rows[t] = static_cast(TILE_M); + tc.colb[t] = static_cast(TILE_N * 4); + } + // Configure A tiles + for (; t < CNum + ANum; ++t) { + tc.rows[t] = static_cast(TILE_M); + tc.colb[t] = static_cast(TILE_K * elesize); + } + // Configure B tile. B effectively has 64 rows and 16 columns. + int kpack = 4 / elesize; + for (; t < CNum + ANum + BNum; ++t) { + tc.rows[t] = static_cast(TILE_K / kpack); + tc.colb[t] = static_cast(TILE_N * 4); + } + } +}; + +class JitAmxbf16 : protected JitAmxtile { + protected: + void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { vcvtneps2bf16(_bf16, _fp32); } +}; + +class JitAmxint8 : protected JitAmxtile { + protected: + template + void _tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3); +}; +template <> +inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { + tdpbssd(x1, x2, x3); +} +template <> +inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { + tdpbsud(x1, x2, x3); +} +template <> +inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { + tdpbusd(x1, x2, x3); +} +template <> +inline void JitAmxint8::_tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) { + tdpbuud(x1, x2, x3); +} +} // namespace xbyak +} // namespace jblas diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h new file mode 100644 index 0000000000000..8ecf3535c17f4 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h @@ -0,0 +1,96 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
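jit_blas.h, which follows, packs each JBLAS_DTYPE value from three bit-fields: element width in the low byte (EleBitsMask), type class in the second byte (TypeMask), and a subtype in the third (SubTypeMask). A worked example of the composition, checkable against the enum values defined just below:

    #include <cstdint>

    // BF16 = EleBits16 | TypeFloat | SubType1 = 16 | (0 << 8) | (1 << 16) = 0x10010;
    // F16 shares the width and type class and differs only in the subtype field.
    static_assert((0x10010u & 0xffu) == 16u, "BF16 elements are 16 bits wide");
    static_assert((0x10010u & 0xff00u) == 0u, "BF16 is in the float type class");
    static_assert((0x10010u & 0xff0000u) == 0x10000u, "SubType1 separates BF16 from F16");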
+#pragma once +#include +enum JBLAS_CODE { + JblasSuccess = 0, + JblasInvalidParam = 1, + JblasInvalidISA = 2, + JblasRuntimeError = 4, + JblasNotSupport = 8, +}; +enum JBLAS_ISA : uint32_t { + JblasNoSIMD = 0, + JblasAVX, + JblasAVX2, + JblasAVX_VNNI, + JblasAVX512F, + JblasAVX512_VNNI, + JblasAMX_BF16, + JblasAMX_INT8, + JblasAVX512_FP16, + JblasAVX512_BF16, +}; +enum class JBLAS_DTYPE : uint32_t { + EleBitsMask = 0xff, + EleBitsUndef = 0, + EleBits4 = 4, + EleBits8 = 8, + EleBits16 = 16, + EleBits32 = 32, + EleBits64 = 64, + TypeMask = 0xff00, + TypeFloat = 0 << 8, + TypeInt = 1 << 8, + SubTypeMask = 0xff0000, + SubType0 = 0 << 16, + SubType1 = 1 << 16, + SubType2 = 2 << 16, + F64 = EleBits64 | TypeFloat, + F32 = EleBits32 | TypeFloat, + F16 = EleBits16 | TypeFloat, + BF16 = EleBits16 | TypeFloat | SubType1, + F8_E4M3 = EleBits8 | TypeFloat, + F8_E5M2 = EleBits8 | TypeFloat | SubType1, + F8_E3M4 = EleBits8 | TypeFloat | SubType2, + S8 = EleBits8 | TypeInt, + U8 = EleBits8 | TypeInt | SubType1, + S4_CLIP = EleBits4 | TypeInt, + S4_FULLRANGE = EleBits4 | TypeInt | SubType1, + F4_E2M1 = EleBits4 | TypeFloat, + F4_BNB = EleBits4 | TypeFloat | SubType1, + F4_NF4 = EleBits4 | TypeFloat | SubType2, + S32 = EleBits32 | TypeInt, + U32 = EleBits32 | TypeInt | SubType1, +}; + +enum JBLAS_LAYOUT { JblasRowMajor = 101, JblasColMajor = 102 }; +enum JBLAS_TRANSPOSE { + JblasNoTrans = 111, + JblasTrans = 112, + JblasConjTrans = 113, +}; +enum JBLAS_ELTWISEOP { + GELU, + SWISH, + TANH, + EXP, + LOW_PRECISION_EXP, + RELU, + LINEAR, +}; + +enum class JBLAS_PROLOGUEB_IDS : uint32_t { + Undef = (uint32_t)-1, + Begin = 0, + NormalBegin = Begin, + WeightPack = NormalBegin, + NormalEnd, + KBlockBegin = NormalEnd, + WeightKBlockS8 = KBlockBegin, + WeightKBlockS4, + WeightKBlockF4, + KBlockEnd, + End, +}; diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h new file mode 100644 index 0000000000000..5cac1080bc610 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h @@ -0,0 +1,277 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
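jit_blas_device.h, which follows, centralizes CPUID feature detection in CpuDevice; JblasSQ4GemmBatchDriver above consults it to pick a GemmCore at runtime. The dispatch idiom looks like this (a sketch; PickGemmCore and the returned labels are illustrative):

    #include "jblas/jit_blas_device.h"

    // Runtime ISA dispatch, mirroring the NTile/CompType checks in jblas_gemm.cpp:
    // prefer the AMX and VNNI integer cores, fall back to the AVX512F/AVX2 float cores.
    const char* PickGemmCore() {
      GetCPUDevice();  // expands to: auto _cd = jblas::device::CpuDevice::getInstance();
      if (_cd->AMX_INT8()) return "tAMX_INT8_SS";
      if (_cd->AVX512_VNNI()) return "tAVX512_VNNI";
      if (_cd->AVX_VNNI()) return "tAVX_VNNI";
      if (_cd->AVX512F()) return "tAVX512F";
      return "tAVX2";
    }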
+#pragma once +#include "jit_blas.h" +#include "xbyak/xbyak_util.h" + +namespace jblas { + +namespace device { + +struct X64_ISA { + int64_t MMX : 1; // 0 + int64_t SSE : 1; // 1 + int64_t SSE2 : 1; // 2 + int64_t SSE3 : 1; // 3 + int64_t SSSE3 : 1; // 4 + int64_t SSE41 : 1; // 5 + int64_t SSE42 : 1; // 6 + int64_t AVX : 1; // 7 + int64_t F16C : 1; // 8 + int64_t FMA : 1; // 9 + int64_t AVX2 : 1; // 10 + int64_t AVX_VNNI : 1; // 11 + int64_t AVX_VNNI_INT8 : 1; // 12 + int64_t AVX_NE_CONVERT : 1; // 13 + int64_t AVX_IFMA : 1; // 14 + int64_t AVX512F : 1; // 15 + int64_t AVX512BW : 1; // 16 + int64_t AVX512CD : 1; // 17 + int64_t AVX512DQ : 1; // 18 + int64_t AVX512ER : 1; // 19 + int64_t AVX512IFMA52 : 1; // 20 + int64_t AVX512PF : 1; // 21 + int64_t AVX512VL : 1; // 22 + int64_t AVX512VPOPCNTDQ : 1; // 23 + int64_t AVX512_4FMAPS : 1; // 24 + int64_t AVX512_4VNNIW : 1; // 25 + int64_t AVX512_BF16 : 1; // 26 + int64_t AVX512_BITALG : 1; // 27 + int64_t AVX512_VBMI : 1; // 28 + int64_t AVX512_VBMI2 : 1; // 29 + int64_t AVX512_VNNI : 1; // 30 + int64_t AVX512_VP2INTERSECT : 1; // 31 + int64_t AVX512_FP16 : 1; // 32 + int64_t AMX_TILE : 1; // 33 + int64_t AMX_BF16 : 1; // 34 + int64_t AMX_INT8 : 1; // 35 + int64_t AMX_FP16 : 1; // 36 + int64_t AMX_COMPLEX : 1; // 37 + int64_t reserved : (64 - 38); +}; + +class AVX2_Default { + public: + static constexpr bool MMX = 1; + static constexpr bool SSE = 1; + static constexpr bool SSE2 = 1; + static constexpr bool SSE3 = 1; + static constexpr bool SSSE3 = 1; + static constexpr bool SSE41 = 1; + static constexpr bool SSE42 = 1; + static constexpr bool AVX = 1; + static constexpr bool F16C = 1; + static constexpr bool FMA = 1; + static constexpr bool AVX2 = 1; + static constexpr bool AVX_VNNI = 0; + static constexpr bool AVX_VNNI_INT8 = 0; + static constexpr bool AVX_NE_CONVERT = 0; + static constexpr bool AVX_IFMA = 0; + static constexpr bool AVX512F = 0; + static constexpr bool AVX512BW = 0; + static constexpr bool AVX512CD = 0; + static constexpr bool AVX512DQ = 0; + static constexpr bool AVX512ER = 0; + static constexpr bool AVX512IFMA52 = 0; + static constexpr bool AVX512PF = 0; + static constexpr bool AVX512VL = 0; + static constexpr bool AVX512VPOPCNTDQ = 0; + static constexpr bool AVX512_4FMAPS = 0; + static constexpr bool AVX512_4VNNIW = 0; + static constexpr bool AVX512_BF16 = 0; + static constexpr bool AVX512_BITALG = 0; + static constexpr bool AVX512_VBMI = 0; + static constexpr bool AVX512_VBMI2 = 0; + static constexpr bool AVX512_VNNI = 0; + static constexpr bool AVX512_VP2INTERSECT = 0; + static constexpr bool AVX512_FP16 = 0; + static constexpr bool AMX_TILE = 0; + static constexpr bool AMX_BF16 = 0; + static constexpr bool AMX_INT8 = 0; + static constexpr bool AMX_FP16 = 0; + static constexpr bool AMX_COMPLEX = 0; +}; + +class AVX512_VNNI_Default { + public: + static constexpr bool MMX = 1; + static constexpr bool SSE = 1; + static constexpr bool SSE2 = 1; + static constexpr bool SSE3 = 1; + static constexpr bool SSSE3 = 1; + static constexpr bool SSE41 = 1; + static constexpr bool SSE42 = 1; + static constexpr bool AVX = 1; + static constexpr bool F16C = 1; + static constexpr bool FMA = 1; + static constexpr bool AVX2 = 1; + static constexpr bool AVX_VNNI = 0; + static constexpr bool AVX_VNNI_INT8 = 0; + static constexpr bool AVX_NE_CONVERT = 0; + static constexpr bool AVX_IFMA = 0; + static constexpr bool AVX512F = 1; + static constexpr bool AVX512BW = 1; + static constexpr bool AVX512CD = 1; + static constexpr bool AVX512DQ = 1; + 
static constexpr bool AVX512ER = 0;
+  static constexpr bool AVX512IFMA52 = 0;
+  static constexpr bool AVX512PF = 0;
+  static constexpr bool AVX512VL = 1;
+  static constexpr bool AVX512VPOPCNTDQ = 0;
+  static constexpr bool AVX512_4FMAPS = 0;
+  static constexpr bool AVX512_4VNNIW = 0;
+  static constexpr bool AVX512_BF16 = 0;
+  static constexpr bool AVX512_BITALG = 0;
+  static constexpr bool AVX512_VBMI = 0;
+  static constexpr bool AVX512_VBMI2 = 0;
+  static constexpr bool AVX512_VNNI = 1;
+  static constexpr bool AVX512_VP2INTERSECT = 0;
+  static constexpr bool AVX512_FP16 = 0;
+  static constexpr bool AMX_TILE = 0;
+  static constexpr bool AMX_BF16 = 0;
+  static constexpr bool AMX_INT8 = 0;
+  static constexpr bool AMX_FP16 = 0;
+  static constexpr bool AMX_COMPLEX = 0;
+};
+
+class SapphireRapids {
+ public:
+  static constexpr bool MMX = 1;
+  static constexpr bool SSE = 1;
+  static constexpr bool SSE2 = 1;
+  static constexpr bool SSE3 = 1;
+  static constexpr bool SSSE3 = 1;
+  static constexpr bool SSE41 = 1;
+  static constexpr bool SSE42 = 1;
+  static constexpr bool AVX = 1;
+  static constexpr bool F16C = 1;
+  static constexpr bool FMA = 1;
+  static constexpr bool AVX2 = 1;
+  static constexpr bool AVX_VNNI = 0;
+  static constexpr bool AVX_VNNI_INT8 = 0;
+  static constexpr bool AVX_NE_CONVERT = 0;
+  static constexpr bool AVX_IFMA = 0;
+  static constexpr bool AVX512F = 1;
+  static constexpr bool AVX512BW = 1;
+  static constexpr bool AVX512CD = 1;
+  static constexpr bool AVX512DQ = 1;
+  static constexpr bool AVX512ER = 0;
+  static constexpr bool AVX512IFMA52 = 0;
+  static constexpr bool AVX512PF = 0;
+  static constexpr bool AVX512VL = 1;
+  static constexpr bool AVX512VPOPCNTDQ = 0;
+  static constexpr bool AVX512_4FMAPS = 0;
+  static constexpr bool AVX512_4VNNIW = 0;
+  static constexpr bool AVX512_BF16 = 0;
+  static constexpr bool AVX512_BITALG = 0;
+  static constexpr bool AVX512_VBMI = 0;
+  static constexpr bool AVX512_VBMI2 = 0;
+  static constexpr bool AVX512_VNNI = 1;
+  static constexpr bool AVX512_VP2INTERSECT = 0;
+  static constexpr bool AVX512_FP16 = 0;
+  static constexpr bool AMX_TILE = 1;
+  static constexpr bool AMX_BF16 = 1;
+  static constexpr bool AMX_INT8 = 1;
+  static constexpr bool AMX_FP16 = 0;
+  static constexpr bool AMX_COMPLEX = 0;
+};
+
+template <JBLAS_ISA ISA_T>
+class isa_base {
+ public:
+  static bool constexpr avx = ISA_T >= JblasAVX;
+  static bool constexpr avx2 = ISA_T >= JblasAVX2;
+  static bool constexpr avx512f = ISA_T >= JblasAVX512F;
+  static bool constexpr avx512_vnni = ISA_T >= JblasAVX512_VNNI;
+  static bool constexpr avx512_fp16 = ISA_T >= JblasAVX512_FP16;
+  static bool constexpr amx_bf16 = ISA_T >= JblasAMX_BF16;
+  static bool constexpr amx_int8 = ISA_T >= JblasAMX_INT8;
+};
+
+class CpuDevice {
+ public:
+  inline void setThreads(int _nth) {
+    if (_nth <= 0) {
+      numthreads = numcores;
+    } else {
+      numthreads = std::min(numcores, _nth);
+    }
+  }
+  inline int getThreads() { return numthreads; }
+  inline int getCores() { return numcores; }
+  inline uint32_t getL2CacheSize() { return L2Cache; }
+  inline uint32_t getL1CacheSize() { return L1Cache; }
+  inline bool AVX() { return mHasAVX; }
+  inline bool AVX2() { return mHasAVX2; }
+  inline bool AVX_VNNI() { return mHasAVX_VNNI; }
+  inline bool AVX512F() { return mHasAVX512F; }
+  inline bool AVX512_VNNI() { return mHasAVX512_VNNI; }
+  inline bool AMX_INT8() { return mHasAMX_INT8; }
+  inline bool AMX_BF16() { return mHasAMX_BF16; }
+  inline bool AVX512_BF16() { return mHasAVX512_BF16; }
+  inline bool AVX512_FP16() { return mHasAVX512_FP16; }
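+
+  // The mHas* flags are populated once in the constructor below; ADD_FLAG maps
+  // each member to the matching Xbyak::util::Cpu feature bit (tAVX, tAVX2, ...).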
+#define ADD_FLAG(isa) mHas##isa = _cpu.has(_cpu.t##isa)
+  CpuDevice() {
+    static Xbyak::util::Cpu _cpu;
+    L1Cache = _cpu.getDataCacheSize(0);
+    L2Cache = _cpu.getDataCacheSize(1);
+    ADD_FLAG(AVX);
+    ADD_FLAG(AVX2);
+    ADD_FLAG(AVX512F);
+    ADD_FLAG(AVX512_VNNI);
+    ADD_FLAG(AVX_VNNI);
+    ADD_FLAG(AMX_BF16);
+    ADD_FLAG(AMX_INT8);
+    ADD_FLAG(AVX512_BF16);
+    ADD_FLAG(AVX512_FP16);
+    numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
+    numthreads = numcores;
+  }
+
+  static CpuDevice* getInstance() {
+    static CpuDevice instance;
+    return &instance;
+  }
+
+  void print() {
+    printf(
+        "AVX:%d AVX2:%d AVX512F:%d AVX_VNNI:%d AVX512_VNNI:%d AMX_INT8:%d AMX_BF16:%d AVX512_BF16:%d AVX512_FP16:%d\n",
+        mHasAVX, mHasAVX2, mHasAVX512F, mHasAVX_VNNI, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512_BF16,
+        mHasAVX512_FP16);
+  }
+#undef ADD_FLAG
+
+ protected:
+  uint32_t L2Cache, L1Cache;
+  bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512_BF16,
+      mHasAVX512_FP16;
+  int numcores;
+  int numthreads;
+};
+
+#define GetCPUDevice() auto _cd = jblas::device::CpuDevice::getInstance();
+
+class CpuBase {
+ public:
+  CpuBase() {
+    GetCPUDevice();
+    mL2Cache = _cd->getL2CacheSize();
+    mL1Cache = _cd->getL1CacheSize();
+    mNumThreads = _cd->getThreads();
+  }
+  size_t mL2Cache, mL1Cache;
+  int mNumThreads;
+};
+}  // namespace device
+}  // namespace jblas
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h
new file mode 100644
index 0000000000000..ceb7a545092d8
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h
@@ -0,0 +1,329 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include <tuple>
+
+#include "jit_base.h"
+#include "jit_blas.h"
+#include "jit_blas_utils.h"
+#include "kernel_wrapper.h"
+
+namespace jblas {
+namespace epilogue {
+namespace gemm {
+
+template <JBLAS_ISA ISA_T, typename _SRC_T, typename _DST_T>
+class AccumulatorWriteBack {
+ public:
+  using SType = _SRC_T;
+  using DType = _DST_T;
+  struct Param {
+    DType* C;
+    int ldc;
+    void* elt_const_v;
+  };
+
+  template <typename... Eltops>
+  JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+                     const int N, const Param& _param, void* tmpcache, size_t cachesize, Eltops... ops) {
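+    // Copy the M x N accumulator tile in `cacheptr` out to C at (M_offset, N_offset),
+    // converting element types on the way: fp32->bf16, fp16->fp32, or a same-width
+    // copy with optional element-wise ops; any other combination asserts.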
+    auto COffset = M_offset * _param.ldc + N_offset;
+    auto cptr = _param.C + COffset;
+    bool constexpr Valid = !std::is_same<DType, utils::bf16>::value || std::is_same<SType, float>::value;
+    static_assert(Valid, "fp32 to bf16 conversion only.");
+    if constexpr (std::is_same<DType, utils::bf16>::value) {
+      return kernel::wrapper::Memcpy2DFp32CvtBf16::template forward<ISA_T>(
+          const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false);
+    } else if constexpr (std::is_same<std::tuple<SType, DType>, std::tuple<utils::fp16, float>>::value) {
+      return kernel::wrapper::Memcpy2DFp16CvtFp32::template forward<ISA_T>(
+          const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false);
+    } else if constexpr (sizeof(SType) == sizeof(DType)) {
+      return kernel::wrapper::Memcpy2D::template forward<ISA_T, SType, DType>(cacheptr, cptr, M, N, cachestep,
+                                                                              _param.ldc, _param.elt_const_v, ops...);
+    } else {
+      assert(false);
+    }
+  }
+};
+
+template <JBLAS_ISA ISA_T, typename _SRC_T, typename _DST_T, JBLAS_ELTWISEOP _OP>
+class CustomAccumulatorWriteBackWithEltop {
+ public:
+  struct Param {
+    _DST_T* C;
+    int ldc;
+    void* elt_const_v;
+  };
+  JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+                     const int N, const Param& _param, void* tmpcache, size_t cachesize) {
+    auto COffset = M_offset * _param.ldc + N_offset;
+    auto cptr = _param.C + COffset;
+    if constexpr (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) {
+      return kernel::wrapper::Memcpy2D::template forward1<ISA_T, float, float, _OP>(cacheptr, cptr, M, N, cachestep,
+                                                                                    _param.ldc, _param.elt_const_v);
+    } else {
+      assert(false);
+    }
+  }
+};
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackFp32 = AccumulatorWriteBack<ISA_T, float, float>;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackInt32 = AccumulatorWriteBack<ISA_T, int, int>;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackBf16 = AccumulatorWriteBack<ISA_T, utils::bf16, utils::bf16>;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackFp16 = AccumulatorWriteBack<ISA_T, utils::fp16, utils::fp16>;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack<ISA_T, utils::fp16, float>;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack<ISA_T, float, utils::bf16>;
+
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop<ISA_T, float, float, GELU>;
+
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackWithSwishFp32 = CustomAccumulatorWriteBackWithEltop<ISA_T, float, float, SWISH>;
+
+template <JBLAS_ISA ISA_T>
+class AlphaBetaProcessFp32 {
+ public:
+  struct Param {
+    float *C, *D;
+    int ldc, ldd;
+    float alpha, beta;
+  };
+
+  JBLAS_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+                     const int N, const Param& _param, void* tmpcache, size_t cachesize) {
+    auto DOffset = M_offset * _param.ldd + N_offset;
+    auto COffset = M_offset * _param.ldc + N_offset;
+    auto cptr = _param.C + COffset;
+    auto dptr = _param.D + DOffset;
+    return kernel::wrapper::AlphaBetaF32F32::template forward<ISA_T>(_param.alpha, cacheptr, cachestep, _param.beta,
+                                                                     dptr, _param.ldd, cptr, _param.ldc, M, N);
+  }
+};
+
+template <JBLAS_ISA ISA_T>
+class CompFp32BlockEpilogue {
+ public:
+  struct Param {
+    void* scales;
+    JBLAS_DTYPE scaledtype;
+    int ldsb;
+    int8_t* zps = nullptr;
+    float* reduce = nullptr;
+    int ldra;
+  };
+  JBLAS_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset,
+                     const int K_offset, const int M, const int N, const Param& _param, void* tmpcache,
+                     size_t cachesize) {
+    auto ret = JblasNotSupport;
+    if (_param.scaledtype == JBLAS_DTYPE::F32) {
+      ret = kernel::wrapper::CompFp32BlockScale::template forward<ISA_T>(
+          reinterpret_cast<float*>(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr,
+          cachestep, M, N);
+      assert(ret == JblasSuccess);
+      if (_param.zps != nullptr) {
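+        // Asymmetric weight quantization: remove the zero-point contribution
+        // using the precomputed per-block reduction of A (_param.reduce).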
+        ret = kernel::wrapper::RemoveZeroPointBias::forward_wei<ISA_T>(
+            dstptr, cachestep, M, N, _param.zps + K_offset * _param.ldsb + N_offset,
+            reinterpret_cast<float*>(_param.scales) + K_offset * _param.ldsb + N_offset, _param.ldra,
+            _param.reduce + M_offset * _param.ldra + K_offset);
+      }
+      assert(ret == JblasSuccess);
+      return ret;
+    } else if (_param.scaledtype == JBLAS_DTYPE::BF16) {
+      ret = kernel::wrapper::CompFp32BlockScale::template forward<ISA_T>(
+          reinterpret_cast<utils::bf16*>(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr,
+          cachestep, M, N);
+      assert(_param.zps == nullptr);
+      assert(ret == JblasSuccess);
+      return ret;
+    }
+    return JblasNotSupport;
+  }
+};
+
+template <JBLAS_ISA ISA_T>
+class DequantInt32ToFp32 {
+ public:
+  struct Param {
+    float* C;
+    int ldc;
+    int ldsa;
+    float* scalesA;
+    float* scalesB;
+  };
+  JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+                     const int N, const Param& _param, void* tmpcache, size_t cachesize) {
+    auto COffset = M_offset * _param.ldc + N_offset;
+    auto cptr = _param.C + COffset;
+    return kernel::wrapper::DequanS32Fp32::template forward<ISA_T>(cacheptr, cachestep, cptr, _param.ldc, M, N,
+                                                                   _param.scalesA + M_offset * _param.ldsa, _param.ldsa,
+                                                                   _param.scalesB + N_offset);
+  }
+};
+
+template <JBLAS_ISA ISA_T>
+class CompInt8BlockEpilogue {
+ public:
+  struct Param {
+    void* scalesB;
+    JBLAS_DTYPE scaleBdtype;
+    int ldsb;
+    float* scalesA;
+    int ldsa;
+    // optional if A asym
+    uint8_t* zpA = nullptr;
+    void* reduceB = nullptr;
+    JBLAS_DTYPE reduceBdtype = JBLAS_DTYPE::F32;
+    // optional if B asym
+    int8_t* zpB = nullptr;
+    float* reduceA = nullptr;
+    int K = 1;
+  };
+  JBLAS_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset,
+                     const int K_offset, const int M, const int N, const Param& _param, void* tmpcache,
+                     size_t cachesize) {
+    JBLAS_CODE ret = JblasNotSupport;
+    float* scab = nullptr;
+    size_t ScaleBTmpSize = N * sizeof(float);
+    size_t ReduceBTmpSize = N * sizeof(float);
+    assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize));
+    if (_param.scaleBdtype == JBLAS_DTYPE::BF16) {
+      auto scache = reinterpret_cast<float*>(tmpcache);
+      ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward<ISA_T>(
+          reinterpret_cast<utils::bf16*>(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N,
+          false);
+      assert(ret == JblasSuccess);
+      scab = scache;
+    } else if (_param.scaleBdtype == JBLAS_DTYPE::F32) {
+      scab = reinterpret_cast<float*>(_param.scalesB) + N_offset + K_offset * _param.ldsb;
+    }
+    float* redb = nullptr;
+    if (_param.reduceB) {
+      if (_param.reduceBdtype == JBLAS_DTYPE::BF16) {
+        auto rcache = reinterpret_cast<float*>(reinterpret_cast<char*>(tmpcache) + ScaleBTmpSize);
+        ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward<ISA_T>(
+            reinterpret_cast<utils::bf16*>(_param.reduceB) + N_offset + K_offset * _param.ldsb, rcache, 1, N, N, N,
+            false);
+        assert(ret == JblasSuccess);
+        redb = rcache;
+      } else if (_param.reduceBdtype == JBLAS_DTYPE::F32) {
+        redb = reinterpret_cast<float*>(_param.reduceB) + N_offset + K_offset * _param.ldsb;
+      }
+    }
+    ret = kernel::wrapper::DequanS32Fp32::template forward<ISA_T>(
+        srcptr, cachestep, reinterpret_cast<float*>(const_cast<int32_t*>(srcptr)), cachestep, M, N,
+        _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab);
+    assert(ret == JblasSuccess);
+    ret = kernel::wrapper::AccumulateFp32::template forward<ISA_T>(reinterpret_cast<const float*>(srcptr), cachestep,
+                                                                   dstptr, cachestep, M, N);
+    assert(ret == JblasSuccess);
+
+    if (_param.zpA == nullptr) {
+      if (_param.zpB == nullptr) {
+        return ret;
+      } else {
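+        // Only B carries zero-points: fold them out of the accumulator using
+        // the cached row sums of A (_param.reduceA).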
+        ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei<ISA_T>(
+            dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, scab, _param.ldsa,
+            _param.reduceA + M_offset * _param.ldsa + K_offset);
+      }
+    } else {
+      if (_param.zpB == nullptr) {
+        ret = kernel::wrapper::RemoveZeroPointBias::template forward_act<ISA_T>(
+            dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset,
+            _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, redb);
+      } else {
+        ret = kernel::wrapper::RemoveZeroPointBias::template forward_both<ISA_T>(
+            dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset,
+            _param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, scab,
+            _param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa + K_offset, redb);
+      }
+    }
+    return ret;
+  }
+};
+
+template <JBLAS_ISA ISA_T>
+class ZpDequantInt32ToFp32 {
+ public:
+  struct Param {
+    // necessary
+    float* C;
+    int ldc;
+    int ldsa;
+    float* scalesA;
+    float* scalesB;
+    // optional if A asym
+    uint8_t* zpA = nullptr;
+    float* reduceB = nullptr;
+    // optional if B asym
+    int8_t* zpB = nullptr;
+    float* reduceA = nullptr;
+    int K = 1;
+  };
+  JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+                     const int N, const Param& _param, void* tmpcache, size_t cachesize) {
+    auto COffset = M_offset * _param.ldc + N_offset;
+    auto cptr = _param.C + COffset;
+    auto ret = kernel::wrapper::DequanS32Fp32::template forward<ISA_T>(cacheptr, cachestep, cptr, _param.ldc, M, N,
+                                                                       _param.scalesA + M_offset * _param.ldsa,
+                                                                       _param.ldsa, _param.scalesB + N_offset);
+    if (ret != JblasSuccess) {
+      return ret;
+    }
+    if (_param.zpA == nullptr && _param.zpB == nullptr) {
+      return ret;
+    } else if (_param.zpA != nullptr && _param.zpB == nullptr) {
+      ret = kernel::wrapper::RemoveZeroPointBias::template forward_act<ISA_T>(
+          cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa,
+          _param.ldsa, _param.reduceB + N_offset);
+    } else if (_param.zpA == nullptr && _param.zpB != nullptr) {
+      ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei<ISA_T>(
+          cptr, _param.ldc, M, N, _param.zpB + N_offset, _param.scalesB + N_offset, _param.ldsa,
+          _param.reduceA + M_offset * _param.ldsa);
+    } else {
+      ret = kernel::wrapper::RemoveZeroPointBias::template forward_both<ISA_T>(
+          cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.zpB + N_offset,
+          _param.scalesA + M_offset * _param.ldsa, _param.scalesB + N_offset, _param.ldsa, _param.K,
+          _param.reduceA + M_offset * _param.ldsa, _param.reduceB + N_offset);
+    }
+    return ret;
+  }
+};
+
+template <JBLAS_ISA ISA_T>
+class AlphaBetaProcessS32U8 {
+ public:
+  struct Param {
+    uint8_t* C;
+    int ldc;
+    float alpha;
+    float scaleAcc, scaleC;
+    int zpC;
+  };
+
+  JBLAS_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+                     const int N, const Param& _param, void* tmpcache, size_t cachesize) {
+    auto COffset = M_offset * _param.ldc + N_offset;
+    auto cptr = _param.C + COffset;
+    return kernel::wrapper::QuanOutS32U32::template forward<ISA_T>(_param.alpha, cacheptr, cachestep, cptr, _param.ldc,
+                                                                   M, N, _param.scaleAcc, _param.scaleC, _param.zpC);
+  }
+};
+
+}  // namespace gemm
+}  // namespace epilogue
+}  // namespace jblas
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h
new file mode 100644
index 0000000000000..364da9223940f
--- /dev/null
+++ 
b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_gemm.h @@ -0,0 +1,2699 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include + +#include "jit_blas_utils.h" +#include "jit_base.h" + +namespace jblas { +namespace gemm { +enum class CompType : uint32_t { + COMP_FP32 = 0, + COMP_BF16_FP32 = 1, + COMP_FP16_FP16 = 2, + COMP_INT_START = 3, + COMP_INT8_US_INT32 = COMP_INT_START, + COMP_INT8_UU_INT32 = 4, + COMP_INT8_SS_INT32 = 5, + COMP_INT8_SU_INT32 = 6, + COMP_INT16_SS_INT32 = 7, + COMP_INT8_US_FP32 = 8, + COMP_INT8_UU_FP32 = 9, + COMP_INT8_SS_FP32 = 10, + COMP_INT8_SU_FP32 = 11, +}; + +class CoreAttr { + public: + // INT32=LSB|**8bits:NTile**||**8bits:PackRow**||**8bits:CompType**||**8bits:Reserve**| + static uint32_t constexpr NTILE_MASK = 0xff, NTILE_SHIFT = 0, PACKROW_MASK = 0xff00, PACKROW_SHIFT = 8, + COMP_MASK = 0xff0000, COMP_SHIFT = 16, ISA_MASK = 0xff000000, ISA_SHIFT = 24; + + static inline uint32_t get_mask_val(uint32_t raw, uint32_t mask, uint32_t shift) { return (raw & mask) >> shift; } + static constexpr uint32_t make_core_id(uint32_t NTile, uint32_t PackRow, uint32_t CompType, uint32_t ISA) { + return (NTile << NTILE_SHIFT) | (PackRow << PACKROW_SHIFT) | (CompType << COMP_SHIFT) | (ISA << ISA_SHIFT); + } + + static void parse_id(uint32_t id, uint32_t* vals) { + vals[0] = get_mask_val(id, NTILE_MASK, NTILE_SHIFT); + vals[1] = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); + vals[2] = get_mask_val(id, COMP_MASK, COMP_SHIFT); + vals[3] = get_mask_val(id, ISA_MASK, ISA_SHIFT); + } + + static const char* to_str(uint32_t id) { + static char tmp[128]; + uint32_t vals[4]; + parse_id(id, vals); + sprintf(tmp, "N%d_PACK%d_COMP%d_ISA%d", vals[0], vals[1], vals[2], vals[3]); + return tmp; + } + + static inline size_t get_bsize(uint32_t id) { + auto packrow = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); + return size_t(4 / packrow); + } +}; + +namespace code { + +template +class Avx2N8P1 : protected jblas::xbyak::JitAvx2 { + public: + static int constexpr RegLen = 8, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX2; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; + typedef float AType; + typedef float BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + 
add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f { + public: + static int constexpr RegLen = 16, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32; + typedef float AType; + typedef float BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * 
BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512fp16N32P1 : protected jblas::xbyak::JitAvx512_fp16 { + public: + static int constexpr RegLen = 32, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_FP16; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP16_FP16; + typedef utils::fp16 AType; + typedef utils::fp16 BType; + typedef utils::fp16 CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + 
add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastw(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastw(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512bf16N16P2 : protected jblas::xbyak::JitAvx512_bf16 { + public: + static int constexpr RegLen = 16, PackRow = 2; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 2; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_BF16; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; + typedef utils::bf16 AType; + typedef utils::bf16 BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + 
add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni { + public: + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; + typedef uint8_t AType; + typedef int8_t BType; + typedef int32_t CType; + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + + protected: + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + 
cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _kunroll) { + for (int kk = 0; kk < _kunroll; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class AvxvnniN8P4 : protected jblas::xbyak::JitAvxvnni { + public: + static int constexpr RegLen = 8, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX_VNNI; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_INT32; + typedef uint8_t AType; + typedef int8_t BType; + typedef int32_t CType; + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + protected: + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + 
add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _kunroll) { + for (int kk = 0; kk < _kunroll; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Amxbf16N16P2 : protected jblas::xbyak::JitAmxbf16 { + public: + static int constexpr RegLen = 16, PackRow = 2; + static_assert(_NTILE % RegLen == 0); + static_assert(_MTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
1 : _MTILE / RegLen; + static_assert(NRegs * MRegs + 2 <= TileCount); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 32; + static int constexpr KUNROLL = 2; + static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_BF16; + static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_BF16_FP32; + typedef utils::bf16 AType; + typedef utils::bf16 BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + void* workspace; + }; + typedef long long (*func_t)(params*); + + int TmpRegCount = RegCount; + int TmpReg = 0; + int CTileCount = 0, ATileCount = 0, BTileCount = 0; + int CTile = 0, ATile = 0, BTile = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CTileCount = NRegs * MRegs; + auto tile_re = TileCount - CTileCount; + if (tile_re - 1 >= NRegs) { + BTileCount = NRegs; + ATileCount = tile_re - BTileCount; + } else if (tile_re - 1 >= MRegs) { + ATileCount = MRegs; + BTileCount = tile_re - ATileCount; + } else { + ATileCount = 1; + BTileCount = tile_re - ATileCount; + } + CTile = 0; + ATile = CTile + CTileCount; + BTile = ATile + ATileCount; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + 
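+    // Tail loop: one KTILE step (K advances by 32 bf16 elements) per iteration.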
generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int kunrll) { + auto& reg_Bstride = reg_tmp1; + mov(reg_Bstride, NTILE * 4); + int mtiles = _mtile / RegLen; + + for (int kk = 0; kk < kunrll; kk++) { + auto& reg_Atmp = reg_tmp2; + if (mtiles == 1) { + reg_Atmp = reg_matAptr; + } else { + mov(reg_Atmp, reg_matAptr); + } + if (BTileCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + } + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } else { + if (ATileCount == mtiles) { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + for (int mm = 0; mm < mtiles; mm++) { + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); + } + } + } else { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < CTileCount; i++) { + tilezero(Xbyak::Tmm(CTile + i)); + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + int mtnum = _mtile / 16; + for (int mm = 0; mm < mtnum; mm++) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); + } + if (mm != mtnum - 1) { + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + } + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_tmp, dword[parambase + OFFSET(workspace)]); + mov(reg_tmp1, NTILE * 4); + for (int mm = 0; mm < MRegs; mm++) { + for (int i = 0; i < NRegs; i++) { + tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); + } + } + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + int zunroll = TmpRegCount / NRegs; + for (int i = 0; i < _mtile; i += zunroll) { + int 
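+
+// Usage sketch (inferred from this file, not a statement of the public API):
+// each GEMM codegen class here follows the same protocol. The caller
+// instantiates the class, calls generate_code(_mtile) once to JIT a kernel
+// specialized for _mtile rows, then invokes mKernel(&p) with a filled-in
+// params struct. params::init != 0 zeroes the C accumulators; init == 0
+// reloads C from matC and accumulates onto it. For the AMX classes,
+// params::workspace must provide at least MTILE * NTILE * sizeof(CType)
+// bytes: write_back() spills the C tiles there with tilestored before
+// streaming them to matC with vmovups.
+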
+template <typename AT, typename BT, int _NTILE, int _MTILE = 0>
+class Amxint8N16P4 : protected jblas::xbyak::JitAmxint8 {
+ public:
+  static int constexpr RegLen = 16, PackRow = 4;
+  static_assert(_NTILE % RegLen == 0);
+  static_assert(_MTILE % RegLen == 0);
+  static int constexpr NRegs = _NTILE / RegLen;
+  static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen;
+  static_assert(NRegs * MRegs + 2 <= TileCount);
+  static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 64;
+  static int constexpr KUNROLL = 2;
+  static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAMX_INT8;
+  static uint32_t constexpr COMPUTE =
+      (uint32_t)(std::is_same_v<AT, int8_t>
+                     ? std::is_same_v<BT, int8_t> ? CompType::COMP_INT8_SS_INT32 : CompType::COMP_INT8_SU_INT32
+                 : std::is_same_v<BT, int8_t> ? CompType::COMP_INT8_US_INT32
+                                              : CompType::COMP_INT8_UU_INT32);
+  using AType = AT;
+  using BType = BT;
+  typedef int32_t CType;
+
+  struct params {
+    AType* matA;
+    int astride;
+    BType* matB;
+    int bstride;
+    CType* matC;
+    int cstride;
+    int k;
+    int n;
+    int init;
+    void* workspace;
+  };
+  typedef long long (*func_t)(params*);
+
+  int TmpRegCount = RegCount;
+  int TmpReg = 0;
+  int CTileCount = 0, ATileCount = 0, BTileCount = 0;
+  int CTile = 0, ATile = 0, BTile = 0;
+  static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+  static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+  void generate_code(int _mtile) {
+    assign_regs();
+    reset();
+    generate_mtile(_mtile);
+    ready();
+    mKernel = getCode<func_t>();
+  }
+  func_t mKernel = nullptr;
+
+ protected:
+  Xbyak::Reg64 parambase;
+  Xbyak::Reg64 reg_matAptr;
+  Xbyak::Reg64 reg_matBptr;
+  Xbyak::Reg64 reg_matCptr;
+  Xbyak::Reg64 reg_ksize;
+  Xbyak::Reg64 reg_nsize;
+  Xbyak::Reg64 reg_cstride;
+  Xbyak::Reg64 reg_astride;
+  Xbyak::Reg64 reg_iterk;
+  Xbyak::Reg64 reg_itern;
+  Xbyak::Reg64 reg_tmp;
+  Xbyak::Reg64 reg_tmp1;
+  Xbyak::Reg64 reg_tmp2;
+  Xbyak::Reg64 reg_tmp3;
+  Xbyak::Reg64 reg_ret = rax;
+
+  void assign_regs() {
+    CTileCount = NRegs * MRegs;
+    auto tile_re = TileCount - CTileCount;
+    if (tile_re - 1 >= NRegs) {
+      BTileCount = NRegs;
+      ATileCount = tile_re - BTileCount;
+    } else if (tile_re - 1 >= MRegs) {
+      ATileCount = MRegs;
+      BTileCount = tile_re - ATileCount;
+    } else {
+      ATileCount = 1;
+      BTileCount = tile_re - ATileCount;
+    }
+    CTile = 0;
+    ATile = CTile + CTileCount;
+    BTile = ATile + ATileCount;
+  }
+
+  void generate_mtile(int _mtile) {
+    inLocalLabel();  // use local labels for multiple instances
+    Xbyak::util::StackFrame st(this, 1, 11, 16 * 10);
+    parambase = st.p[0];
+    reg_matAptr = st.t[0];
+    reg_matBptr = st.t[1];
+    reg_matCptr = st.t[0];
+    reg_ksize = st.t[2];
+    reg_astride = st.t[3];
+    reg_cstride = st.t[3];
+    reg_iterk = st.t[4];
+    reg_tmp = st.t[5];
+    reg_tmp1 = st.t[6];
+    reg_tmp2 = st.t[7];
+    reg_tmp3 = st.t[10];
+    reg_nsize = st.t[8];
+    reg_itern = st.t[9];
+    reg_ret = rax;
+
+    vreg_push(rsp);
+
+    load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+    load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+    xor_(reg_itern, reg_itern);
+    L(".nloop");
+    init_regs(_mtile);
+    mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+    load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+    mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+    load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+    imul(reg_tmp, reg_itern);
+    lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+    xor_(reg_iterk, reg_iterk);
+    generate_kloop(_mtile);
+    write_back(_mtile);
+    add(reg_itern, NTILE);
+    cmp(reg_itern, reg_nsize);
+    jb(".nloop");
+    mov(reg_ret, 0);
+    vreg_pop(rsp);
+
+    outLocalLabel();  // end of local labels
+  }
+
+  void generate_kloop(int _mtile) {
+    inLocalLabel();
+    mov(reg_tmp, reg_ksize);
+    padto_le(reg_tmp, KUNROLL * KTILE);
+    cmp(reg_tmp, 0);
+    jz(".kloop", T_NEAR);
+    L(".unkloop");
+    generate_fma(_mtile, KUNROLL);
+    add(reg_matAptr, KUNROLL * AKStepSize);
+    add(reg_matBptr, KUNROLL * BKStepSize);
+    add(reg_iterk, KUNROLL * KTILE);
+    cmp(reg_iterk, reg_tmp);  // k iteration variable
+    jb(".unkloop");
+    cmp(reg_tmp, reg_ksize);
+    jge(".kend", T_NEAR);
+    L(".kloop");
+    generate_fma(_mtile, 1);
+    add(reg_matAptr, 1 * AKStepSize);
+    add(reg_matBptr, 1 * BKStepSize);
+    add(reg_iterk, 1 * KTILE);
+    cmp(reg_iterk, reg_ksize);  // k iteration variable
+    jb(".kloop");
+    L(".kend");
+    outLocalLabel();
+  }
+
+  void generate_fma(int _mtile, int kunroll) {
+    auto& reg_Bstride = reg_tmp1;
+    mov(reg_Bstride, NTILE * 4);
+    int mtiles = _mtile / RegLen;
+
+    for (int kk = 0; kk < kunroll; kk++) {
+      auto& reg_Atmp = reg_tmp2;
+      if (mtiles == 1) {
+        reg_Atmp = reg_matAptr;
+      } else {
+        mov(reg_Atmp, reg_matAptr);
+      }
+      if (BTileCount == NRegs) {
+        for (int i = 0; i < NRegs; i++) {
+          tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
+        }
+        for (int mm = 0; mm < mtiles; mm++) {
+          tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
+          for (int i = 0; i < NRegs; i++) {
+            _tdpb<AT, BT>(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i));
+          }
+          if (mm != mtiles - 1) {
+            lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+            lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+          }
+        }
+      } else {
+        if (ATileCount == mtiles) {
+          for (int mm = 0; mm < mtiles; mm++) {
+            tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
+            if (mm != mtiles - 1) {
+              lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+              lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+            }
+          }
+          for (int i = 0; i < NRegs; i++) {
+            tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
+            for (int mm = 0; mm < mtiles; mm++) {
+              _tdpb<AT, BT>(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile));
+            }
+          }
+        } else {
+          for (int mm = 0; mm < mtiles; mm++) {
+            tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]);
+            for (int i = 0; i < NRegs; i++) {
+              tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]);
+              _tdpb<AT, BT>(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile));
+            }
+            if (mm != mtiles - 1) {
+              lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+              lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]);
+            }
+          }
+        }
+      }
+    }
+  }
+
+  void init_regs(int _mtile) {
+    inLocalLabel();
+    load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+    cmp(reg_tmp, 0);
+    je(".read", T_NEAR);
+    for (int i = 0; i < CTileCount; i++) {
+      tilezero(Xbyak::Tmm(CTile + i));
+    }
+    jmp(".end", T_NEAR);
+    L(".read");
+    mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+    lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+    load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+    int mtnum = _mtile / 16;
+    for (int mm = 0; mm < mtnum; mm++) {
+      for (int i = 0; i < NRegs; i++) {
+        tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]);
+      }
+      if (mm != mtnum - 1) {
+        lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]);
+        lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]);
+      }
+    }
+    L(".end");
+    outLocalLabel();
+  }
+
+  void write_back(int _mtile) {
+    inLocalLabel();
+    mov(reg_tmp, ptr[parambase + OFFSET(workspace)]);
+    mov(reg_tmp1, NTILE * 4);
+    for (int mm = 0; mm < MRegs; mm++) {
+      for (int i = 0; i < NRegs; i++) {
+        tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i));
+      }
+    }
+    mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+    load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+    lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+    int zunroll = TmpRegCount / NRegs;
+    for (int i = 0; i < _mtile; i += zunroll) {
+      int m_re = utils::remainsize(i, _mtile, zunroll);
+      for (int im = 0; im < m_re; im++) {
+        for (int j = 0; j < NRegs; j++) {
+          vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]);
+          vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j));
+        }
+        add(reg_matCptr, reg_cstride);
+      }
+    }
+    outLocalLabel();
+  }
+};
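+
+// Note (inferred): COMPUTE encodes the signedness of the A/B operands. For
+// example, AT = uint8_t with BT = int8_t yields COMP_INT8_US_INT32, i.e. the
+// unsigned-times-signed int8 dot product with int32 accumulation that the AMX
+// tdpbusd instruction implements; _tdpb<AT, BT> is expected to select the
+// matching tdpb*d variant.
+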
+template <int _NTILE, int _MTILE = 0>
+using Amxint8N16P4US = Amxint8N16P4<uint8_t, int8_t, _NTILE, _MTILE>;
+
+template <int _NTILE, int _MTILE = 0>
+using Amxint8N16P4SS = Amxint8N16P4<int8_t, int8_t, _NTILE, _MTILE>;
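+
+// Note (inferred): the US alias fixes A to uint8_t and B to int8_t (the usual
+// activation x weight combination); SS keeps both operands signed int8.
+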
+class AmxConfigure : protected jblas::xbyak::JitAmxtile {
+ public:
+  typedef long long (*func_t)(tileconfig_t*);
+
+  static void configure(int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, int CNum) {
+    static AmxConfigure code;
+    tileconfig_t cfg;
+    std::memset(&cfg, 0, sizeof(cfg));
+    configure_tiles(cfg, TILE_M, TILE_N, TILE_K, elesize, ANum, BNum, CNum);
+    code.mKernel(&cfg);
+  }
+
+ protected:
+  AmxConfigure() {
+    generate_config(this);
+    mKernel = getCode<func_t>();
+  }
+
+  func_t mKernel = nullptr;
+};
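+
+// Note (inferred): AmxConfigure JITs the AMX tile-palette setup, presumably an
+// ldtilecfg over the tileconfig_t filled in by configure_tiles(), and caches
+// the generated routine in a function-local static. Because tile configuration
+// is per-thread state, each worker thread would call something like
+//   AmxConfigure::configure(16, 16, KTILE, sizeof(BType), ATileCount,
+//                           BTileCount, CTileCount);
+// before running any of the tileloadd/tdp* kernels above.
+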
+namespace kblock {
+// optimized for kblock gemm: each block along the k dimension has its own
+// dequant operation; all accumulators use fp32 dtype.
+template <int _NTILE, int _MTILE = 0>
+class Avx512fN16P1 : protected jblas::xbyak::JitAvx512f {
+ public:
+  static int constexpr RegLen = 16, PackRow = 1;
+  static_assert(_NTILE % RegLen == 0);
+  static int constexpr NRegs = _NTILE / RegLen;
+  static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1) / NRegs : _MTILE;
+  static_assert(NRegs * MRegs <= RegCount - 1);
+  static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1;
+  static int constexpr KUNROLL = 2;
+  static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512F;
+  static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_FP32;
+  typedef float AType;
+  typedef float BType;
+  typedef float CType;
+
+  struct params {
+    AType* matA;
+    int astride;
+    BType* matB;
+    int bstride;
+    CType* matC;
+    int cstride;
+    int k;
+    int n;
+    int init;
+  };
+  typedef long long (*func_t)(params*);
+
+  int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0;
+  int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0;
+  static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+  static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+  void generate_code(int _mtile) {
+    assign_regs();
+    reset();
+    generate_mtile(_mtile);
+    ready();
+    mKernel = getCode<func_t>();
+  }
+  func_t mKernel = nullptr;
+
+ protected:
+  Xbyak::Reg64 parambase;
+  Xbyak::Reg64 reg_matAptr;
+  Xbyak::Reg64 reg_matBptr;
+  Xbyak::Reg64 reg_matCptr;
+  Xbyak::Reg64 reg_ksize;
+  Xbyak::Reg64 reg_nsize;
+  Xbyak::Reg64 reg_cstride;
+  Xbyak::Reg64 reg_astride;
+  Xbyak::Reg64 reg_iterk;
+  Xbyak::Reg64 reg_itern;
+  Xbyak::Reg64 reg_tmp;
+  Xbyak::Reg64 reg_tmp1;
+  Xbyak::Reg64 reg_tmp2;
+  Xbyak::Reg64 reg_ret = rax;
+  Xbyak::Opmask msk_wr = k1;
+
+  void assign_regs() {
+    CRegCount = MRegs * NRegs;
+    ARegCount = 1;
+    BRegCount = RegCount - ARegCount - CRegCount;
+    if (BRegCount < NRegs) {
+      BRegCount = 0;
+      ARegCount = BRegCount + 1;
+    }
+    if (BRegCount > NRegs) {
+      BRegCount = NRegs;
+    }
+    CReg = 0;
+    BReg = CReg + CRegCount;
+    AReg = BReg + BRegCount;
+    TmpReg = AReg + ARegCount;
+    assert(TmpReg <= RegCount);
+    TmpRegCount = RegCount - TmpReg;
+  }
+
+  void generate_mtile(int _mtile) {
+    inLocalLabel();  // use local labels for multiple instances
+    Xbyak::util::StackFrame st(this, 1, 10, 16 * 10);
+    parambase = st.p[0];
+    reg_matAptr = st.t[0];
+    reg_matBptr = st.t[1];
+    reg_matCptr = st.t[0];
+    reg_ksize = st.t[2];
+    reg_astride = st.t[3];
+    reg_cstride = st.t[3];
+    reg_iterk = st.t[4];
+    reg_tmp = st.t[5];
+    reg_tmp1 = st.t[6];
+    reg_tmp2 = st.t[7];
+    reg_nsize = st.t[8];
+    reg_itern = st.t[9];
+    reg_ret = rax;
+
+    vreg_push(rsp);
+
+    load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+    load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+    xor_(reg_itern, reg_itern);
+    L(".nloop");
+    init_regs(_mtile);
+    mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+    load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+    mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+    load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+    imul(reg_tmp, reg_itern);
+    lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+    xor_(reg_iterk, reg_iterk);
+    generate_kloop(_mtile);
+    write_back(_mtile);
+    add(reg_itern, NTILE);
+    cmp(reg_itern, reg_nsize);
+    jb(".nloop");
+    mov(reg_ret, 0);
+    vreg_pop(rsp);
+
+    outLocalLabel();  // end of local labels
+  }
+
+  void generate_kloop(int _mtile) {
+    inLocalLabel();
+    mov(reg_tmp, reg_ksize);
+    padto_le(reg_tmp, KUNROLL * KTILE);
+    cmp(reg_tmp, 0);
+    jz(".kloop", T_NEAR);
+    L(".unkloop");
+    generate_fma(_mtile, KUNROLL);
+    add(reg_matAptr, KUNROLL * AKStepSize);
+    add(reg_matBptr, KUNROLL * BKStepSize);
+    add(reg_iterk, KUNROLL * KTILE);
+    cmp(reg_iterk, reg_tmp);  // k iteration variable
+    jb(".unkloop");
+    cmp(reg_tmp, reg_ksize);
+    jge(".kend", T_NEAR);
+    L(".kloop");
+    generate_fma(_mtile, 1);
+    add(reg_matAptr, 1 * AKStepSize);
+    add(reg_matBptr, 1 * BKStepSize);
+    add(reg_iterk, 1 * KTILE);
+    cmp(reg_iterk, reg_ksize);  // k iteration variable
+    jb(".kloop");
+    L(".kend");
+    outLocalLabel();
+  }
+
+  void generate_fma(int _mtile, int _ktile) {
+    for (int kk = 0; kk < _ktile; kk++) {
+      lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]);
+      if (BRegCount == NRegs) {
+        for (int i = 0; i < NRegs; i++) {
+          vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+        }
+        for (int mm = 0; mm < _mtile; mm++) {
+          vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]);
+          add(reg_tmp1, reg_astride);
+          for (int i = 0; i < NRegs; i++) {
+            vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i));
+          }
+        }
+      } else if (BRegCount == 0) {
+        for (int mm = 0; mm < _mtile; mm += ARegCount) {
+          int mm_re = utils::remainsize(mm, _mtile, ARegCount);
+          for (int imm = 0; imm < mm_re; imm++) {
+            vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]);
+            add(reg_tmp1, reg_astride);
+            for (int i = 0; i < NRegs; i++) {
+              vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm),
+                          ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+            }
+          }
+        }
+      } else {
+        assert(0);
+      }
+    }
+  }
+
+  void init_regs(int _mtile) {
+    inLocalLabel();
+    load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+    cmp(reg_tmp, 0);
+    je(".read", T_NEAR);
+    for (int i = 0; i < _mtile; i++) {
+      for (int j = 0; j < NRegs; j++) {
+        vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j));
+      }
+    }
+    jmp(".end", T_NEAR);
+    L(".read");
+    mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+    lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+    load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+    for (int i = 0; i < _mtile; i++) {
+      for (int j = 0; j < NRegs; j++) {
+        vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]);
+      }
+      add(reg_matCptr, reg_cstride);
+    }
+    L(".end");
+    outLocalLabel();
+  }
+
+  void write_back(int _mtile) {
+    inLocalLabel();
+    mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+    load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+    lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+    for (int i = 0; i < _mtile; i++) {
+      for (int j = 0; j < NRegs; j++) {
+        vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j));
+      }
+      add(reg_matCptr, reg_cstride);
+    }
+    outLocalLabel();
+  }
+};
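+
+// Worked example (assuming RegCount == 32 zmm registers and a hypothetical
+// _NTILE = 48 instantiation): NRegs = 3, so the default _MTILE = 0 gives
+// MRegs = (32 - 1) / 3 = 10, a 10 x 48 fp32 accumulator block occupying 30
+// C registers. assign_regs() then finds only one register left for B, falls
+// back to BRegCount = 0, and generate_fma() takes the memory-operand FMA path.
+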
+template <int _NTILE, int _MTILE = 0>
+class Avx512vnniN16P4 : protected jblas::xbyak::JitAvx512vnni {
+ public:
+  static int constexpr RegLen = 16, PackRow = 4;
+  static_assert(_NTILE % RegLen == 0);
+  static int constexpr NRegs = _NTILE / RegLen;
+  static int constexpr MRegs = _MTILE == 0 ? (RegCount - 1 - NRegs) / (NRegs * 2) : _MTILE;
+  static_assert(NRegs * MRegs <= RegCount - 1);
+  static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4;
+  static int constexpr KUNROLL = 2;
+  static uint32_t constexpr ISA = (uint32_t)JBLAS_ISA::JblasAVX512_VNNI;
+  static uint32_t constexpr COMPUTE = (uint32_t)CompType::COMP_INT8_US_FP32;
+  typedef uint8_t AType;
+  typedef int8_t BType;
+  typedef float CType;
+
+  struct params {
+    AType* matA;
+    int astride;
+    BType* matB;
+    int bstride;
+    CType* matC;
+    int cstride;
+    uint8_t* zpA;
+    float* scaleA;
+    int ldsa;
+    float* scaleB;
+    float* reduceB;
+    int ldsb;
+    int k;
+    int n;
+    int kblock;
+    int init;
+  };
+  typedef long long (*func_t)(params*);
+
+  int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0;
+  int CReg = 0, CF32Reg = 0, BReg = 0, AReg = 0, TmpReg = 0;
+  static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType);
+  static int constexpr AKStepSize = KTILE * sizeof(AType);
+
+  void generate_code(int _mtile) {
+    assign_regs();
+    reset();
+    generate_mtile(_mtile);
+    ready();
+    mKernel = getCode<func_t>();
+  }
+  func_t mKernel = nullptr;
+
+ protected:
+  Xbyak::Reg64 parambase;
+  Xbyak::Reg64 reg_matAptr;
+  Xbyak::Reg64 reg_matBptr;
+  Xbyak::Reg64 reg_matCptr;
+  Xbyak::Reg64 reg_ksize;
+  Xbyak::Reg64 reg_nsize;
+  Xbyak::Reg64 reg_cstride;
+  Xbyak::Reg64 reg_astride;
+  Xbyak::Reg64 reg_iterk;
+  Xbyak::Reg64 reg_iterkb;
+  Xbyak::Reg64 reg_itern;
+  Xbyak::Reg64 reg_tmp;
+  Xbyak::Reg64 reg_tmp1;
+  Xbyak::Reg64 reg_tmp2;
+  Xbyak::Reg64 reg_tmp3;
+  Xbyak::Reg64 reg_tmp4;
+  Xbyak::Reg64 reg_ret = rax;
+
+  void assign_regs() {
+    CRegCount = MRegs * NRegs;
+    ARegCount = 1;
+    BRegCount = NRegs;
+    CReg = 0;
+    CF32Reg = CReg + CRegCount;
+    BReg = CF32Reg + CRegCount;
+    AReg = BReg + BRegCount;
+    TmpReg = AReg + ARegCount;
+    assert(TmpReg < RegCount);
+    TmpRegCount = RegCount - TmpReg;
+    assert(TmpRegCount >= 1);
+  }
+
+  void generate_mtile(int _mtile) {
+    inLocalLabel();  // use local labels for multiple instances
+    Xbyak::util::StackFrame st(this, 1, 13, 16 * 10);
+    parambase = st.p[0];
+    reg_matAptr = st.t[0];
+    reg_matBptr = st.t[1];
+    reg_matCptr = st.t[0];
+    reg_ksize = st.t[2];
+    reg_astride = st.t[3];
+    reg_cstride = st.t[3];
+    reg_iterk = st.t[4];
+    reg_iterkb = st.t[12];
+    reg_tmp = st.t[5];
+    reg_tmp1 = st.t[6];
+    reg_tmp2 = st.t[7];
+    reg_tmp3 = st.t[10];
+    reg_tmp4 = st.t[11];
+    reg_nsize = st.t[8];
+    reg_itern = st.t[9];
+    reg_ret = rax;
+
+    vreg_push(rsp);
+
+    load32(reg_ksize, ptr[parambase + OFFSET(k)]);
+    load32(reg_nsize, ptr[parambase + OFFSET(n)]);
+    xor_(reg_itern, reg_itern);
+    L(".nloop");
+    init_regs(_mtile);
+    mov(reg_matAptr, ptr[parambase + OFFSET(matA)]);
+    load32(reg_astride, ptr[parambase + OFFSET(astride)]);
+    mov(reg_matBptr, ptr[parambase + OFFSET(matB)]);
+    load32(reg_tmp, ptr[parambase + OFFSET(bstride)]);
+    imul(reg_tmp, reg_itern);
+    lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]);
+    xor_(reg_iterk, reg_iterk);
+    generate_kloop(_mtile);
+    write_back(_mtile);
+    add(reg_itern, NTILE);
+    cmp(reg_itern, reg_nsize);
+    jb(".nloop");
+    mov(reg_ret, 0);
+    vreg_pop(rsp);
+
+    outLocalLabel();  // end of local labels
+  }
+
+  void generate_kloop(int _mtile) {
+    inLocalLabel();
+    xor_(reg_iterkb, reg_iterkb);
+    L(".kloop");
+    for (int i = 0; i < _mtile; i++) {
+      for (int j = 0; j < NRegs; j++) {
+        vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j));
+      }
+    }
+    xor_(reg_tmp2, reg_tmp2);
+    load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]);
+    mov(reg_tmp, reg_tmp3);
+    padto_le(reg_tmp, KUNROLL * KTILE);
+    cmp(reg_tmp, 0);
+    jz(".kbloop", T_NEAR);
+    L(".unkbloop");
+    generate_fma(_mtile, KUNROLL, reg_tmp1);
+    add(reg_matAptr, KUNROLL * AKStepSize);
+    add(reg_matBptr, KUNROLL * BKStepSize);
+    add(reg_tmp2, KUNROLL * KTILE);
+    cmp(reg_tmp2, reg_tmp);
+    jb(".unkbloop");
+    cmp(reg_tmp, reg_tmp3);
+    jge(".kend", T_NEAR);
+    L(".kbloop");
+    generate_fma(_mtile, 1, reg_tmp1);
+    add(reg_matAptr, 1 * AKStepSize);
+    add(reg_matBptr, 1 * BKStepSize);
+    add(reg_tmp2, 1 * KTILE);
+    cmp(reg_tmp2, reg_tmp3);
+    jb(".kbloop");
+    L(".kend");
+    add(reg_iterk, reg_tmp2);
+    generate_f32_accumulate(_mtile);
+    generate_zp_correction(_mtile);
+    inc(reg_iterkb);
+    cmp(reg_iterk, reg_ksize);  // k iteration variable
+    jb(".kloop");
+
+    outLocalLabel();
+  }
+
+  void generate_fma(int _mtile, int _ktile, Xbyak::Reg64& tmp) {
+    for (int kk = 0; kk < _ktile; kk++) {
+      lea(tmp, ptr[reg_matAptr + kk * AKStepSize]);
+      for (int i = 0; i < NRegs; i++) {
+        vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]);
+      }
+      for (int mm = 0; mm < _mtile; mm++) {
+        vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]);
+        add(reg_tmp1, reg_astride);
+        for (int i = 0; i < NRegs; i++) {
+          vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i));
+        }
+      }
+    }
+  }
+
+  void init_regs(int _mtile) {
+    inLocalLabel();
+    load32(reg_tmp, ptr[parambase + OFFSET(init)]);
+    cmp(reg_tmp, 0);
+    je(".read", T_NEAR);
+    for (int i = 0; i < _mtile; i++) {
+      for (int j = 0; j < NRegs; j++) {
+        vxor(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j));
+      }
+    }
+    jmp(".end", T_NEAR);
+    L(".read");
+    mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+    lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+    load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+    for (int i = 0; i < _mtile; i++) {
+      for (int j = 0; j < NRegs; j++) {
+        vmovups(vreg_t(CF32Reg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]);
+      }
+      add(reg_matCptr, reg_cstride);
+    }
+    L(".end");
+    outLocalLabel();
+  }
+
+  void generate_f32_accumulate(int _mtile) {
+    load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]);
+    imul(reg_tmp, reg_iterkb);
+    mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]);
+    lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]);
+    lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]);
+
+    mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]);
+    lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(float)]);
+    load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]);
+    for (int i = 0; i < NRegs; i++) {
+      vmovups(Xbyak::Zmm(BReg + i), ptr[reg_tmp2 + i * VecBytes]);
+    }
+    for (int mm = 0; mm < _mtile; mm++) {
+      vbroadcastss(Xbyak::Zmm(TmpReg), ptr[reg_tmp]);
+      lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]);
+      for (int i = 0; i < NRegs; i++) {
+        vcvtdq2ps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i));
+        vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(TmpReg), Xbyak::Zmm(BReg + i));
+        vmulps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(AReg));
+        vaddps(Xbyak::Zmm(CF32Reg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i));
+      }
+    }
+  }
+
+  void generate_zp_correction(int _mtile) {
+    load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]);
+    imul(reg_tmp1, reg_iterkb);
+    mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]);
+    lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]);
+    lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]);
+    auto& reg_redB = reg_tmp2;
+
+    mov(reg_tmp, ptr[parambase + OFFSET(zpA)]);
+    lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(AType)]);
+    auto& reg_zpA = reg_tmp;
+
+    mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]);
+    lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]);
+    auto& reg_scaleA = reg_tmp1;
+
+    load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]);
+    auto& reg_ldsa = reg_tmp3;
+    for (int i = 0; i < NRegs; i++) {
+      vmovups(Xbyak::Zmm(BReg + i), ptr[reg_redB + i * VecBytes]);
+    }
+
+    for (int i = 0; i < _mtile; i++) {
+      vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]);
+      vpmovzxbd(Xbyak::Zmm(AReg), Xbyak::Xmm(AReg));
+      vcvtdq2ps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg));
+      vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg), zword_b[reg_scaleA]);
+      for (int j = 0; j < NRegs; j++) {
+        vmulps(Xbyak::Zmm(CReg + j), Xbyak::Zmm(AReg), Xbyak::Zmm(BReg + j));
+        vsubps(Xbyak::Zmm(CF32Reg + i * NRegs + j), Xbyak::Zmm(CReg + j));
+      }
+      lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]);
+      lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]);
+    }
+  }
+
+  void write_back(int _mtile) {
+    inLocalLabel();
+    mov(reg_matCptr, ptr[parambase + OFFSET(matC)]);
+    load32(reg_cstride, ptr[parambase + OFFSET(cstride)]);
+    lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]);
+    for (int i = 0; i < _mtile; i++) {
+      for (int j = 0; j < NRegs; j++) {
+        vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CF32Reg + i * NRegs + j));
+      }
+      add(reg_matCptr, reg_cstride);
+    }
+    outLocalLabel();
+  }
+};
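+
+// Note (inferred): per kblock, the raw int8 dot products are accumulated as
+// int32 in CReg; generate_f32_accumulate() then folds them into the fp32
+// accumulators as CF32 += float(Cint32) * scaleA[m] * scaleB[n], and
+// generate_zp_correction() subtracts float(zpA[m]) * scaleA[m] * reduceB[n],
+// the zero-point contribution of the unsigned A operand, before the next
+// kblock starts.
+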
+} // namespace kblock
+} // namespace code
+template