diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
index 4a4e286071ff5..ce8fb3160954e 100644
--- a/.github/workflows/labeler.yml
+++ b/.github/workflows/labeler.yml
@@ -7,7 +7,7 @@ jobs:
triage:
runs-on: ubuntu-latest
steps:
- - uses: github/issue-labeler@v3.2
+ - uses: github/issue-labeler@v3.3
with:
repo-token: "${{ secrets.GITHUB_TOKEN }}"
configuration-path: .github/labeler.yml
diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml
index 9ea9bda7e7c53..fff50d6481a05 100644
--- a/.github/workflows/publish-java-apidocs.yml
+++ b/.github/workflows/publish-java-apidocs.yml
@@ -25,7 +25,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Set up JDK 11
- uses: actions/setup-java@v3
+ uses: actions/setup-java@v4
with:
java-version: '11'
distribution: 'adopt'
diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml
index ba8bfd718abfa..d85978568e6c4 100644
--- a/.github/workflows/publish-js-apidocs.yml
+++ b/.github/workflows/publish-js-apidocs.yml
@@ -25,7 +25,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Setup Node.js
- uses: actions/setup-node@v3
+ uses: actions/setup-node@v4
with:
node-version: 18
- name: Generate JS docs
diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml
index 3a780f87d2300..c03abe0be9783 100644
--- a/.github/workflows/windows.yml
+++ b/.github/workflows/windows.yml
@@ -26,7 +26,7 @@ jobs:
python-version: '3.11.x'
architecture: 'x64'
- - uses: actions/setup-node@v3
+ - uses: actions/setup-node@v4
with:
node-version: 18
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 7494035e4784e..23ded3bfc1e68 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -87,6 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
+option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON)
option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
@@ -1166,6 +1167,17 @@ if (onnxruntime_USE_DNNL)
add_compile_definitions(DNNL_OPENMP)
endif()
+set(USE_JBLAS FALSE)
+if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD)
+ if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64")
+ add_compile_definitions(MLAS_JBLAS)
+ set(USE_JBLAS TRUE)
+ elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64")
+ add_compile_definitions(MLAS_JBLAS)
+ set(USE_JBLAS TRUE)
+ endif()
+endif()
+
# TVM EP
if (onnxruntime_USE_TVM)
if (NOT TARGET tvm)
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 26e4380af4c23..bee83ff07c74b 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -45,6 +45,15 @@ endif()
set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)
+function(add_jblas)
+ add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
+ target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
+ target_sources(onnxruntime_mlas PRIVATE
+ ${MLAS_SRC_DIR}/jblas_gemm.cpp
+ )
+ set_target_properties(onnxruntime_mlas PROPERTIES COMPILE_WARNING_AS_ERROR OFF)
+endfunction()
+
#TODO: set MASM flags properly
function(setup_mlas_source_for_windows)
@@ -200,7 +209,6 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/q4gemm_avx512.cpp
)
endif()
-
else()
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
@@ -566,7 +574,7 @@ else()
)
set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
- endif()
+ endif()
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs})
@@ -604,6 +612,10 @@ else()
target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs})
endif()
+if(USE_JBLAS)
+ add_jblas()
+endif()
+
foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${MLAS_SRC_DIR})
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index e5b43ddba8cc7..131db5d8d9b37 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2824,6 +2824,8 @@ This version of the operator has been available since version 1 of the 'com.micr
size of each input feature
N : int (required)
size of each output feature
+accuracy_level : int
+The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) (default unset). It is used to control how input A is quantized or downcast internally while doing computation, for example: 0 means input A will not be quantized or downcast while doing computation. 4 means input A can be quantized with the same block_size to int8 internally from type T1.
bits : int (required)
number of bits used for weight quantization (default 4)
block_size : int (required)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index edf249a816923..1ce9b3254d91f 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -80,7 +80,8 @@ Do not modify directly.*
|Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)|
|CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)|
|||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)|
-|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**
or
*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)|
+|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**
or
*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)|
+|||[17, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)|
|DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[1, 10]|**T** = tensor(double), tensor(float)|
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index c41700453a73b..dbd5ad41255fa 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3593,17 +3593,11 @@ struct OrtApi {
*
* QNN supported keys:
* "backend_path": file path to QNN backend library.
- * "qnn_context_cache_enable": 1 to enable QNN graph creation from cached QNN context file. If it's enabled: QNN EP will
- * load from cached QNN context binary if it exist. It will generate a context binary file if it's not exist
- * "qnn_context_cache_path": explicitly provide the QNN context cache file. Default to model_file.onnx.bin if not provided.
* "profiling_level": QNN profiling level, options: "off", "basic", "detailed". Default to off.
* "rpc_control_latency": QNN RPC control latency.
* "vtcm_mb": QNN VTCM size in MB. default to 0(not set).
* "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance",
* "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default".
- * "qnn_context_embed_mode", 1 means dump the QNN context binary into node attribute EPContext->ep_cache_context in the ONNX skeleton model.
- * 0 means dump the QNN context binary into separate bin file and set the path to EPContext->ep_cache_context.
- * The path is relative path to the ONNX skeleton model file.
* "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will
* dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and
* may alter model/EP partitioning. Use only for debugging.
diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index a94973b2cc5d7..df79cb6e5b21b 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -235,3 +235,18 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil
// Use this config to control the minimum size of the initializer when externalizing it during serialization
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
"session.optimized_model_external_initializers_min_size_in_bytes";
+
+// Enable EP context feature to dump the partitioned graph which include the EP context into Onnx file.
+// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
+// "0": disable. (default)
+// "1": enable.
+static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";
+
+// Specify the file path for the Onnx model which has EP context.
+// Defaults to original_file_name_ctx.onnx if not specified.
+static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";
+
+// Flag to specify whether to dump the EP context into the Onnx model.
+// "0": dump the EP context into separate file, keep the file name in the Onnx model.
+// "1": dump the EP context into the Onnx model. (default).
+static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
\ No newline at end of file
diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts
index e129c6971a85c..3e1e833addb91 100644
--- a/js/common/lib/backend-impl.ts
+++ b/js/common/lib/backend-impl.ts
@@ -82,7 +82,7 @@ export const resolveBackend = async(backendHints: readonly string[]): Promise;
+ init(backendName: string): Promise;
createInferenceSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions):
Promise;
diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts
index 5f5ad49a2dea8..e8eb0e9babf5a 100644
--- a/js/node/lib/backend.ts
+++ b/js/node/lib/backend.ts
@@ -20,7 +20,7 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler {
}
async dispose(): Promise {
- return Promise.resolve();
+ this.#inferenceSession.dispose();
}
readonly inputNames: string[];
diff --git a/js/node/lib/binding.ts b/js/node/lib/binding.ts
index 8a0ce89abfa64..54b5767139904 100644
--- a/js/node/lib/binding.ts
+++ b/js/node/lib/binding.ts
@@ -28,6 +28,8 @@ export declare namespace Binding {
readonly outputNames: string[];
run(feeds: FeedsType, fetches: FetchesType, options: RunOptions): ReturnType;
+
+ dispose(): void;
}
export interface InferenceSessionConstructor {
diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc
index c409fdc8895f7..1bbb6df1ce1c8 100644
--- a/js/node/src/inference_session_wrap.cc
+++ b/js/node/src/inference_session_wrap.cc
@@ -31,6 +31,7 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
Napi::Function func = DefineClass(
env, "InferenceSession",
{InstanceMethod("loadModel", &InferenceSessionWrap::LoadModel), InstanceMethod("run", &InferenceSessionWrap::Run),
+ InstanceMethod("dispose", &InferenceSessionWrap::Dispose),
InstanceAccessor("inputNames", &InferenceSessionWrap::GetInputNames, nullptr, napi_default, nullptr),
InstanceAccessor("outputNames", &InferenceSessionWrap::GetOutputNames, nullptr, napi_default, nullptr)});
@@ -45,7 +46,7 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
}
InferenceSessionWrap::InferenceSessionWrap(const Napi::CallbackInfo &info)
- : Napi::ObjectWrap(info), initialized_(false), session_(nullptr),
+ : Napi::ObjectWrap(info), initialized_(false), disposed_(false), session_(nullptr),
defaultRunOptions_(nullptr) {}
Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) {
@@ -53,6 +54,7 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) {
Napi::HandleScope scope(env);
ORT_NAPI_THROW_ERROR_IF(this->initialized_, env, "Model already loaded. Cannot load model multiple times.");
+ ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");
size_t argsLength = info.Length();
ORT_NAPI_THROW_TYPEERROR_IF(argsLength == 0, env, "Expect argument: model file path or buffer.");
@@ -129,6 +131,7 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) {
Napi::Value InferenceSessionWrap::GetInputNames(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
+ ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");
Napi::EscapableHandleScope scope(env);
return scope.Escape(CreateNapiArrayFrom(env, inputNames_));
@@ -137,6 +140,7 @@ Napi::Value InferenceSessionWrap::GetInputNames(const Napi::CallbackInfo &info)
Napi::Value InferenceSessionWrap::GetOutputNames(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
+ ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");
Napi::EscapableHandleScope scope(env);
return scope.Escape(CreateNapiArrayFrom(env, outputNames_));
@@ -145,6 +149,7 @@ Napi::Value InferenceSessionWrap::GetOutputNames(const Napi::CallbackInfo &info)
Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
+ ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");
ORT_NAPI_THROW_TYPEERROR_IF(info.Length() < 2, env, "Expect argument: inputs(feed) and outputs(fetch).");
ORT_NAPI_THROW_TYPEERROR_IF(!info[0].IsObject() || !info[1].IsObject(), env,
"Expect inputs(feed) and outputs(fetch) to be objects.");
@@ -209,6 +214,18 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo &info) {
}
}
+Napi::Value InferenceSessionWrap::Dispose(const Napi::CallbackInfo &info) {
+ Napi::Env env = info.Env();
+ ORT_NAPI_THROW_ERROR_IF(!this->initialized_, env, "Session is not initialized.");
+ ORT_NAPI_THROW_ERROR_IF(this->disposed_, env, "Session already disposed.");
+
+ this->defaultRunOptions_.reset(nullptr);
+ this->session_.reset(nullptr);
+
+ this->disposed_ = true;
+ return env.Undefined();
+}
+
Napi::Value InferenceSessionWrap::ListSupportedBackends(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
Napi::EscapableHandleScope scope(env);
diff --git a/js/node/src/inference_session_wrap.h b/js/node/src/inference_session_wrap.h
index 9eee45b72dcb1..1e789c4814cd6 100644
--- a/js/node/src/inference_session_wrap.h
+++ b/js/node/src/inference_session_wrap.h
@@ -55,6 +55,14 @@ class InferenceSessionWrap : public Napi::ObjectWrap {
*/
Napi::Value Run(const Napi::CallbackInfo &info);
+ /**
+ * [sync] dispose the session.
+ * @param nothing
+ * @returns nothing
+ * @throw nothing
+ */
+ Napi::Value Dispose(const Napi::CallbackInfo &info);
+
// private members
// persistent constructor
@@ -62,6 +70,7 @@ class InferenceSessionWrap : public Napi::ObjectWrap {
// session objects
bool initialized_;
+ bool disposed_;
std::unique_ptr session_;
std::unique_ptr defaultRunOptions_;
diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts
index 78edcc90f55f9..2d123cdb71290 100644
--- a/js/web/lib/backend-wasm.ts
+++ b/js/web/lib/backend-wasm.ts
@@ -4,7 +4,7 @@
import {cpus} from 'node:os';
import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common';
-import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper';
+import {initializeOrtEp, initializeWebAssemblyAndOrtRuntime} from './wasm/proxy-wrapper';
import {OnnxruntimeWebAssemblySessionHandler} from './wasm/session-handler-inference';
/**
@@ -33,12 +33,23 @@ export const initializeFlags = (): void => {
};
export class OnnxruntimeWebAssemblyBackend implements Backend {
- async init(): Promise {
+ /**
+ * This function initializes the WebAssembly backend.
+ *
+ * This function will be called only once for each backend name. It will be called the first time when
+ * `ort.InferenceSession.create()` is called with a registered backend name.
+ *
+ * @param backendName - the registered backend name.
+ */
+ async init(backendName: string): Promise {
// populate wasm flags
initializeFlags();
// init wasm
- await initializeWebAssemblyInstance();
+ await initializeWebAssemblyAndOrtRuntime();
+
+ // perform EP-specific initialization
+ await initializeOrtEp(backendName);
}
createInferenceSessionHandler(path: string, options?: InferenceSession.SessionOptions):
Promise;
diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts
index 6060271ced156..499327741c82b 100644
--- a/js/web/lib/index.ts
+++ b/js/web/lib/index.ts
@@ -21,7 +21,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
if (!BUILD_DEFS.DISABLE_WASM) {
const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend :
require('./backend-wasm-training').wasmBackend;
- if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) {
+ if (!BUILD_DEFS.DISABLE_WEBGPU) {
registerBackend('webgpu', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index 4f4a06c37a94f..6c3d22352772e 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -144,17 +144,7 @@ export class WebGpuBackend {
*/
sessionExternalDataMapping: Map> = new Map();
- async initialize(env: Env): Promise {
- if (!navigator.gpu) {
- // WebGPU is not available.
- throw new Error('WebGpuBackend: WebGPU is not available.');
- }
-
- const adapter = await navigator.gpu.requestAdapter();
- if (!adapter) {
- throw new Error('WebGpuBackend: Failed to get GPU adapter.');
- }
-
+ async initialize(env: Env, adapter: GPUAdapter): Promise {
this.env = env;
const requiredFeatures: GPUFeatureName[] = [];
const deviceDescriptor: GPUDeviceDescriptor = {
diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts
index e6db631c44eea..cad1e87b24a51 100644
--- a/js/web/lib/wasm/jsep/init.ts
+++ b/js/web/lib/wasm/jsep/init.ts
@@ -130,64 +130,76 @@ class ComputeContextImpl implements ComputeContext {
}
}
-export const init = async(module: OrtWasmModule, env: Env): Promise => {
- const init = module.jsepInit;
- if (init && navigator.gpu) {
- if (!env.wasm.simd) {
- throw new Error(
- 'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using WebGPU EP');
- }
- const backend = new WebGpuBackend();
- await backend.initialize(env);
-
- init(
- // backend
- backend,
-
- // jsepAlloc()
- (size: number) => backend.alloc(size),
-
- // jsepFree()
- (ptr: number) => backend.free(ptr),
-
- // jsepCopy(src, dst, size, isSourceGpu)
- (src: number, dst: number, size: number, isSourceGpu = false) => {
- if (isSourceGpu) {
- LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`);
- backend.memcpy(src, dst);
- } else {
- LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
- const data = module.HEAPU8.subarray(src, src + size);
- backend.upload(dst, data);
- }
- },
-
- // jsepCopyAsync(src, dst, size)
- async(gpuDataId: number, dataOffset: number, size: number):
- Promise => {
- LOG_DEBUG(
- 'verbose',
- () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`);
-
- await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size));
- },
-
- // jsepCreateKernel
- (name: string, kernel: number, attribute: unknown) => backend.createKernel(
- name, kernel, attribute,
- env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`),
-
- // jsepReleaseKernel
- (kernel: number) => backend.releaseKernel(kernel),
-
- // jsepRun
- (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => {
- LOG_DEBUG(
- 'verbose',
- () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${
- contextDataOffset}`);
- const context = new ComputeContextImpl(module, backend, contextDataOffset);
- return backend.computeKernel(kernel, context, errors);
- });
+/**
+ * Initialize JSEP with WebGPU backend.
+ *
+ * This function will be called only once after the WebAssembly module is loaded and initialized ("_OrtInit" is called).
+ * This function expects:
+ * - WebGPU is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false).
+ * - WebGPU is available in current environment. (a valid GPUAdapter is passed in)
+ * If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate
+ * 'webgpu' backend.
+ *
+ * @param module - the ORT WebAssembly module
+ * @param env - the ORT environment variable (ort.env)
+ * @param gpuAdapter - the pre-created GPU adapter
+ */
+export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapter): Promise => {
+ const jsepInit = module.jsepInit;
+ if (!jsepInit) {
+ throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.');
}
+
+ const backend = new WebGpuBackend();
+ await backend.initialize(env, gpuAdapter);
+
+ jsepInit(
+ // backend
+ backend,
+
+ // jsepAlloc()
+ (size: number) => backend.alloc(size),
+
+ // jsepFree()
+ (ptr: number) => backend.free(ptr),
+
+ // jsepCopy(src, dst, size, isSourceGpu)
+ (src: number, dst: number, size: number, isSourceGpu = false) => {
+ if (isSourceGpu) {
+ LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`);
+ backend.memcpy(src, dst);
+ } else {
+ LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
+ const data = module.HEAPU8.subarray(src, src + size);
+ backend.upload(dst, data);
+ }
+ },
+
+ // jsepCopyAsync(src, dst, size)
+ async(gpuDataId: number, dataOffset: number, size: number):
+ Promise => {
+ LOG_DEBUG(
+ 'verbose',
+ () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`);
+
+ await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size));
+ },
+
+ // jsepCreateKernel
+ (name: string, kernel: number, attribute: unknown) => backend.createKernel(
+ name, kernel, attribute,
+ env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`),
+
+ // jsepReleaseKernel
+ (kernel: number) => backend.releaseKernel(kernel),
+
+ // jsepRun
+ (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => {
+ LOG_DEBUG(
+ 'verbose',
+ () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${
+ contextDataOffset}`);
+ const context = new ComputeContextImpl(module, backend, contextDataOffset);
+ return backend.computeKernel(kernel, context, errors);
+ });
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
index 5fffa2f266603..0eb0d40a3ea5e 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -772,14 +772,14 @@ class ShaderHelperImpl implements ShaderHelper {
const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1;
const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3,
@builtin(local_invocation_id) local_id : vec3` :
- `@builtin(local_invocation_index) local_index : u32,
+ `@builtin(local_invocation_index) local_idx : u32,
@builtin(workgroup_id) workgroup_id : vec3,
@builtin(num_workgroups) num_workgroups : vec3`;
const globalIdxDefinition = is1DimensionDispatch ?
- 'let global_idx = global_id.x;' :
+ 'let global_idx = global_id.x; let local_idx = local_id.x;' :
`let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${
- workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_index;`;
+ workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_idx;`;
return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ})
fn main(${paramList}) {
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
index 6e9dee41ce488..1c5d28e4b8e3f 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
@@ -97,8 +97,8 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
- let m = global_id.x / N;
- let n = global_id.x % N;
+ let m = global_idx / N;
+ let n = global_idx % N;
var value = ${dataType}(0);
for (var k: u32 = 0u; k<${K}u; k++) {
@@ -107,7 +107,7 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
${calculateAlpha}
${calculateC}
- output[global_id.x] = value;
+ output[global_idx] = value;
}`;
return {
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts
index 1365d1e9a12a4..7c440cbffea7b 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts
@@ -141,7 +141,6 @@ export const createReduceSharedProgramInfo =
return ((a - 1u) / b + 1u);
}
${shaderHelper.mainStart(workgroupSize)}
- let local_idx = local_id.x;
let outputIndex = global_idx / ${workgroupSize};
let offset = outputIndex * uniforms.reduceSize;
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
index 378a7e738dac9..324dc3af1a710 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
@@ -73,8 +73,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
}
${shaderHelper.registerUniform('packedCols', 'i32').declareVariables(x, output)}
${shaderHelper.mainStart()}
- let gindex = i32(global_id.x);
- let lindex = i32(local_id.x);
+ let gindex = i32(global_idx);
+ let lindex = i32(local_idx);
const wg = ${WG};
let row = gindex / wg;
let cols = uniforms.packedCols;
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
index 51114d8a99dd1..a25e7fe4229b4 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -125,8 +125,8 @@ export interface ClipAttributes extends AttributeWithCacheKey {
}
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
- const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
- const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
+ const min = (inputs.length >= 2 && inputs[1].data !== 0) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
+ const max = (inputs.length >= 3 && inputs[2].data !== 0) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
return createAttributeWithCacheKey({min, max});
};
diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts
index efeb086256cf3..02246c9ee4767 100644
--- a/js/web/lib/wasm/proxy-messages.ts
+++ b/js/web/lib/wasm/proxy-messages.ts
@@ -3,6 +3,9 @@
import type {Env, InferenceSession, Tensor} from 'onnxruntime-common';
+/**
+ * Among all the tensor locations, only 'cpu' is serializable.
+ */
export type SerializableTensorMetadata =
[dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu'];
@@ -12,15 +15,28 @@ export type GpuBufferMetadata = {
dispose?: () => void;
};
+/**
+ * Tensors on location 'cpu-pinned' and 'gpu-buffer' are not serializable.
+ */
export type UnserializableTensorMetadata =
[dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']|
[dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned'];
+/**
+ * Tensor metadata is a tuple of [dataType, dims, data, location], where
+ * - dataType: tensor data type
+ * - dims: tensor dimensions
+ * - data: tensor data, which can be one of the following depending on the location:
+ * - cpu: Uint8Array
+ * - cpu-pinned: Uint8Array
+ * - gpu-buffer: GpuBufferMetadata
+ * - location: tensor data location
+ */
export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata;
export type SerializableSessionMetadata = [sessionHandle: number, inputNames: string[], outputNames: string[]];
-export type SerializableModeldata = [modelDataOffset: number, modelDataLength: number];
+export type SerializableInternalBuffer = [bufferOffset: number, bufferLength: number];
interface MessageError {
err?: string;
@@ -28,35 +44,32 @@ interface MessageError {
interface MessageInitWasm extends MessageError {
type: 'init-wasm';
- in ?: Env.WebAssemblyFlags;
-}
-
-interface MessageInitOrt extends MessageError {
- type: 'init-ort';
in ?: Env;
+ out?: never;
}
-interface MessageCreateSessionAllocate extends MessageError {
- type: 'create_allocate';
- in ?: {model: Uint8Array};
- out?: SerializableModeldata;
+interface MessageInitEp extends MessageError {
+ type: 'init-ep';
+ in ?: {env: Env; epName: string};
+ out?: never;
}
-interface MessageCreateSessionFinalize extends MessageError {
- type: 'create_finalize';
- in ?: {modeldata: SerializableModeldata; options?: InferenceSession.SessionOptions};
- out?: SerializableSessionMetadata;
+interface MessageCopyFromExternalBuffer extends MessageError {
+ type: 'copy-from';
+ in ?: {buffer: Uint8Array};
+ out?: SerializableInternalBuffer;
}
interface MessageCreateSession extends MessageError {
type: 'create';
- in ?: {model: Uint8Array; options?: InferenceSession.SessionOptions};
+ in ?: {model: SerializableInternalBuffer|Uint8Array; options?: InferenceSession.SessionOptions};
out?: SerializableSessionMetadata;
}
interface MessageReleaseSession extends MessageError {
type: 'release';
in ?: number;
+ out?: never;
}
interface MessageRun extends MessageError {
@@ -71,12 +84,8 @@ interface MessageRun extends MessageError {
interface MesssageEndProfiling extends MessageError {
type: 'end-profiling';
in ?: number;
+ out?: never;
}
-interface MessageIsOrtEnvInitialized extends MessageError {
- type: 'is-ort-env-initialized';
- out?: boolean;
-}
-
-export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize|
- MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized;
+export type OrtWasmMessage = MessageInitWasm|MessageInitEp|MessageCopyFromExternalBuffer|MessageCreateSession|
+ MessageReleaseSession|MessageRun|MesssageEndProfiling;
diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts
index 1cb6d9e391e4f..4df524cdcfb22 100644
--- a/js/web/lib/wasm/proxy-worker/main.ts
+++ b/js/web/lib/wasm/proxy-worker/main.ts
@@ -36,104 +36,82 @@ declare global {
}
import {OrtWasmMessage, SerializableTensorMetadata} from '../proxy-messages';
-import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl';
+import {createSession, copyFromExternalBuffer, endProfiling, extractTransferableBuffers, initEp, initRuntime, releaseSession, run} from '../wasm-core-impl';
import {initializeWebAssembly} from '../wasm-factory';
self.onmessage = (ev: MessageEvent): void => {
- switch (ev.data.type) {
- case 'init-wasm':
- try {
- initializeWebAssembly(ev.data.in!)
+ const {type, in : message} = ev.data;
+ try {
+ switch (type) {
+ case 'init-wasm':
+ initializeWebAssembly(message!.wasm)
.then(
- () => postMessage({type: 'init-wasm'} as OrtWasmMessage),
- err => postMessage({type: 'init-wasm', err} as OrtWasmMessage));
- } catch (err) {
- postMessage({type: 'init-wasm', err} as OrtWasmMessage);
- }
- break;
- case 'init-ort':
- try {
- initRuntime(ev.data.in!).then(() => postMessage({type: 'init-ort'} as OrtWasmMessage), err => postMessage({
- type: 'init-ort',
- err
- } as OrtWasmMessage));
- } catch (err) {
- postMessage({type: 'init-ort', err} as OrtWasmMessage);
- }
- break;
- case 'create_allocate':
- try {
- const {model} = ev.data.in!;
- const modeldata = createSessionAllocate(model);
- postMessage({type: 'create_allocate', out: modeldata} as OrtWasmMessage);
- } catch (err) {
- postMessage({type: 'create_allocate', err} as OrtWasmMessage);
+ () => {
+ initRuntime(message!).then(
+ () => {
+ postMessage({type});
+ },
+ err => {
+ postMessage({type, err});
+ });
+ },
+ err => {
+ postMessage({type, err});
+ });
+ break;
+ case 'init-ep': {
+ const {epName, env} = message!;
+ initEp(env, epName)
+ .then(
+ () => {
+ postMessage({type});
+ },
+ err => {
+ postMessage({type, err});
+ });
+ break;
}
- break;
- case 'create_finalize':
- try {
- const {modeldata, options} = ev.data.in!;
- const sessionMetadata = createSessionFinalize(modeldata, options);
- postMessage({type: 'create_finalize', out: sessionMetadata} as OrtWasmMessage);
- } catch (err) {
- postMessage({type: 'create_finalize', err} as OrtWasmMessage);
+ case 'copy-from': {
+ const {buffer} = message!;
+ const bufferData = copyFromExternalBuffer(buffer);
+ postMessage({type, out: bufferData} as OrtWasmMessage);
+ break;
}
- break;
- case 'create':
- try {
- const {model, options} = ev.data.in!;
+ case 'create': {
+ const {model, options} = message!;
const sessionMetadata = createSession(model, options);
- postMessage({type: 'create', out: sessionMetadata} as OrtWasmMessage);
- } catch (err) {
- postMessage({type: 'create', err} as OrtWasmMessage);
+ postMessage({type, out: sessionMetadata} as OrtWasmMessage);
+ break;
}
- break;
- case 'release':
- try {
- releaseSession(ev.data.in!);
- postMessage({type: 'release'} as OrtWasmMessage);
- } catch (err) {
- postMessage({type: 'release', err} as OrtWasmMessage);
- }
- break;
- case 'run':
- try {
- const {sessionId, inputIndices, inputs, outputIndices, options} = ev.data.in!;
+ case 'release':
+ releaseSession(message!);
+ postMessage({type});
+ break;
+ case 'run': {
+ const {sessionId, inputIndices, inputs, outputIndices, options} = message!;
run(sessionId, inputIndices, inputs, outputIndices, new Array(outputIndices.length).fill(null), options)
.then(
outputs => {
if (outputs.some(o => o[3] !== 'cpu')) {
- postMessage({type: 'run', err: 'Proxy does not support non-cpu tensor location.'});
+ postMessage({type, err: 'Proxy does not support non-cpu tensor location.'});
} else {
postMessage(
- {type: 'run', out: outputs} as OrtWasmMessage,
+ {type, out: outputs} as OrtWasmMessage,
extractTransferableBuffers(outputs as SerializableTensorMetadata[]));
}
},
err => {
- postMessage({type: 'run', err} as OrtWasmMessage);
+ postMessage({type, err});
});
- } catch (err) {
- postMessage({type: 'run', err} as OrtWasmMessage);
- }
- break;
- case 'end-profiling':
- try {
- const handler = ev.data.in!;
- endProfiling(handler);
- postMessage({type: 'end-profiling'} as OrtWasmMessage);
- } catch (err) {
- postMessage({type: 'end-profiling', err} as OrtWasmMessage);
- }
- break;
- case 'is-ort-env-initialized':
- try {
- const ortEnvInitialized = isOrtEnvInitialized();
- postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage);
- } catch (err) {
- postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage);
+ break;
}
- break;
- default:
+ case 'end-profiling':
+ endProfiling(message!);
+ postMessage({type});
+ break;
+ default:
+ }
+ } catch (err) {
+ postMessage({type, err} as OrtWasmMessage);
}
};
diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts
index 069a1fa452dbc..86017a4ec6904 100644
--- a/js/web/lib/wasm/proxy-wrapper.ts
+++ b/js/web/lib/wasm/proxy-wrapper.ts
@@ -1,9 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {Env, env, InferenceSession} from 'onnxruntime-common';
+import {env, InferenceSession} from 'onnxruntime-common';
-import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
+import {OrtWasmMessage, SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
import * as core from './wasm-core-impl';
import {initializeWebAssembly} from './wasm-factory';
@@ -13,18 +13,18 @@ let initializing = false;
let initialized = false;
let aborted = false;
-// resolve; reject
-type PromiseCallbacks = [(result: T) => void, (reason: unknown) => void];
-
+type PromiseCallbacks = [resolve: (result: T) => void, reject: (reason: unknown) => void];
let initWasmCallbacks: PromiseCallbacks;
-let initOrtCallbacks: PromiseCallbacks;
-const createSessionAllocateCallbacks: Array> = [];
-const createSessionFinalizeCallbacks: Array> = [];
-const createSessionCallbacks: Array> = [];
-const releaseSessionCallbacks: Array> = [];
-const runCallbacks: Array> = [];
-const endProfilingCallbacks: Array> = [];
-const isOrtEnvInitializedCallbacks: Array> = [];
+const queuedCallbacks: Map>> = new Map();
+
+const enqueueCallbacks = (type: OrtWasmMessage['type'], callbacks: PromiseCallbacks): void => {
+ const queue = queuedCallbacks.get(type);
+ if (queue) {
+ queue.push(callbacks);
+ } else {
+ queuedCallbacks.set(type, [callbacks]);
+ }
+};
const ensureWorker = (): void => {
if (initializing || !initialized || aborted || !proxyWorker) {
@@ -44,82 +44,40 @@ const onProxyWorkerMessage = (ev: MessageEvent): void => {
initWasmCallbacks[0]();
}
break;
- case 'init-ort':
- if (ev.data.err) {
- initOrtCallbacks[1](ev.data.err);
- } else {
- initOrtCallbacks[0]();
- }
- break;
- case 'create_allocate':
- if (ev.data.err) {
- createSessionAllocateCallbacks.shift()![1](ev.data.err);
- } else {
- createSessionAllocateCallbacks.shift()![0](ev.data.out!);
- }
- break;
- case 'create_finalize':
- if (ev.data.err) {
- createSessionFinalizeCallbacks.shift()![1](ev.data.err);
- } else {
- createSessionFinalizeCallbacks.shift()![0](ev.data.out!);
- }
- break;
+ case 'init-ep':
+ case 'copy-from':
case 'create':
- if (ev.data.err) {
- createSessionCallbacks.shift()![1](ev.data.err);
- } else {
- createSessionCallbacks.shift()![0](ev.data.out!);
- }
- break;
case 'release':
- if (ev.data.err) {
- releaseSessionCallbacks.shift()![1](ev.data.err);
- } else {
- releaseSessionCallbacks.shift()![0]();
- }
- break;
case 'run':
+ case 'end-profiling': {
+ const callbacks = queuedCallbacks.get(ev.data.type)!;
if (ev.data.err) {
- runCallbacks.shift()![1](ev.data.err);
- } else {
- runCallbacks.shift()![0](ev.data.out!);
- }
- break;
- case 'end-profiling':
- if (ev.data.err) {
- endProfilingCallbacks.shift()![1](ev.data.err);
- } else {
- endProfilingCallbacks.shift()![0]();
- }
- break;
- case 'is-ort-env-initialized':
- if (ev.data.err) {
- isOrtEnvInitializedCallbacks.shift()![1](ev.data.err);
+ callbacks.shift()![1](ev.data.err);
} else {
- isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!);
+ callbacks.shift()![0](ev.data.out!);
}
break;
+ }
default:
}
};
const scriptSrc = typeof document !== 'undefined' ? (document?.currentScript as HTMLScriptElement)?.src : undefined;
-export const initializeWebAssemblyInstance = async(): Promise => {
- if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
- if (initialized) {
- return;
- }
- if (initializing) {
- throw new Error('multiple calls to \'initWasm()\' detected.');
- }
- if (aborted) {
- throw new Error('previous call to \'initWasm()\' failed.');
- }
+export const initializeWebAssemblyAndOrtRuntime = async(): Promise => {
+ if (initialized) {
+ return;
+ }
+ if (initializing) {
+ throw new Error('multiple calls to \'initWasm()\' detected.');
+ }
+ if (aborted) {
+ throw new Error('previous call to \'initWasm()\' failed.');
+ }
- initializing = true;
+ initializing = true;
+ if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
// overwrite wasm filepaths
if (env.wasm.wasmPaths === undefined) {
if (scriptSrc && scriptSrc.indexOf('blob:') !== 0) {
@@ -142,78 +100,78 @@ export const initializeWebAssemblyInstance = async(): Promise => {
proxyWorker.onmessage = onProxyWorkerMessage;
URL.revokeObjectURL(workerUrl);
initWasmCallbacks = [resolve, reject];
- const message: OrtWasmMessage = {type: 'init-wasm', in : env.wasm};
+ const message: OrtWasmMessage = {type: 'init-wasm', in : env};
proxyWorker.postMessage(message);
});
} else {
- return initializeWebAssembly(env.wasm);
+ try {
+ await initializeWebAssembly(env.wasm);
+ await core.initRuntime(env);
+ initialized = true;
+ } catch (e) {
+ aborted = true;
+ throw e;
+ } finally {
+ initializing = false;
+ }
}
};
-export const initializeRuntime = async(env: Env): Promise => {
+export const initializeOrtEp = async(epName: string): Promise => {
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
ensureWorker();
return new Promise((resolve, reject) => {
- initOrtCallbacks = [resolve, reject];
- const message: OrtWasmMessage = {type: 'init-ort', in : env};
+ enqueueCallbacks('init-ep', [resolve, reject]);
+ const message: OrtWasmMessage = {type: 'init-ep', in : {epName, env}};
proxyWorker!.postMessage(message);
});
} else {
- await core.initRuntime(env);
+ await core.initEp(env, epName);
}
};
-export const createSessionAllocate = async(model: Uint8Array): Promise => {
+export const copyFromExternalBuffer = async(buffer: Uint8Array): Promise => {
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
ensureWorker();
- return new Promise((resolve, reject) => {
- createSessionAllocateCallbacks.push([resolve, reject]);
- const message: OrtWasmMessage = {type: 'create_allocate', in : {model}};
- proxyWorker!.postMessage(message, [model.buffer]);
+ return new Promise((resolve, reject) => {
+ enqueueCallbacks('copy-from', [resolve, reject]);
+ const message: OrtWasmMessage = {type: 'copy-from', in : {buffer}};
+ proxyWorker!.postMessage(message, [buffer.buffer]);
});
} else {
- return core.createSessionAllocate(model);
+ return core.copyFromExternalBuffer(buffer);
}
};
-export const createSessionFinalize = async(modeldata: SerializableModeldata, options?: InferenceSession.SessionOptions):
- Promise => {
- if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
- ensureWorker();
- return new Promise((resolve, reject) => {
- createSessionFinalizeCallbacks.push([resolve, reject]);
- const message: OrtWasmMessage = {type: 'create_finalize', in : {modeldata, options}};
- proxyWorker!.postMessage(message);
- });
- } else {
- return core.createSessionFinalize(modeldata, options);
- }
- };
-
export const createSession =
- async(model: Uint8Array, options?: InferenceSession.SessionOptions): Promise => {
- if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
- // check unsupported options
- if (options?.preferredOutputLocation) {
- throw new Error('session option "preferredOutputLocation" is not supported for proxy.');
- }
- ensureWorker();
- return new Promise((resolve, reject) => {
- createSessionCallbacks.push([resolve, reject]);
- const message: OrtWasmMessage = {type: 'create', in : {model, options}};
- proxyWorker!.postMessage(message, [model.buffer]);
- });
- } else {
- return core.createSession(model, options);
- }
-};
+ async(model: SerializableInternalBuffer|Uint8Array, options?: InferenceSession.SessionOptions):
+ Promise => {
+ if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
+ // check unsupported options
+ if (options?.preferredOutputLocation) {
+ throw new Error('session option "preferredOutputLocation" is not supported for proxy.');
+ }
+ ensureWorker();
+ return new Promise((resolve, reject) => {
+ enqueueCallbacks('create', [resolve, reject]);
+ const message: OrtWasmMessage = {type: 'create', in : {model, options}};
+ const transferable: Transferable[] = [];
+ if (model instanceof Uint8Array) {
+ transferable.push(model.buffer);
+ }
+ proxyWorker!.postMessage(message, transferable);
+ });
+ } else {
+ return core.createSession(model, options);
+ }
+ };
export const releaseSession = async(sessionId: number): Promise => {
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
ensureWorker();
return new Promise((resolve, reject) => {
- releaseSessionCallbacks.push([resolve, reject]);
+ enqueueCallbacks('release', [resolve, reject]);
const message: OrtWasmMessage = {type: 'release', in : sessionId};
proxyWorker!.postMessage(message);
});
@@ -236,7 +194,7 @@ export const run = async(
}
ensureWorker();
return new Promise((resolve, reject) => {
- runCallbacks.push([resolve, reject]);
+ enqueueCallbacks('run', [resolve, reject]);
const serializableInputs = inputs as SerializableTensorMetadata[]; // every input is on CPU.
const message: OrtWasmMessage =
{type: 'run', in : {sessionId, inputIndices, inputs: serializableInputs, outputIndices, options}};
@@ -251,7 +209,7 @@ export const endProfiling = async(sessionId: number): Promise => {
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
ensureWorker();
return new Promise((resolve, reject) => {
- endProfilingCallbacks.push([resolve, reject]);
+ enqueueCallbacks('end-profiling', [resolve, reject]);
const message: OrtWasmMessage = {type: 'end-profiling', in : sessionId};
proxyWorker!.postMessage(message);
});
@@ -259,16 +217,3 @@ export const endProfiling = async(sessionId: number): Promise => {
core.endProfiling(sessionId);
}
};
-
-export const isOrtEnvInitialized = async(): Promise => {
- if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
- ensureWorker();
- return new Promise((resolve, reject) => {
- isOrtEnvInitializedCallbacks.push([resolve, reject]);
- const message: OrtWasmMessage = {type: 'is-ort-env-initialized'};
- proxyWorker!.postMessage(message);
- });
- } else {
- return core.isOrtEnvInitialized();
- }
-};
diff --git a/js/web/lib/wasm/session-handler-inference.ts b/js/web/lib/wasm/session-handler-inference.ts
index 3ca34d957c572..b62287483208a 100644
--- a/js/web/lib/wasm/session-handler-inference.ts
+++ b/js/web/lib/wasm/session-handler-inference.ts
@@ -2,14 +2,12 @@
// Licensed under the MIT License.
import {readFile} from 'node:fs/promises';
-import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common';
+import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common';
-import {SerializableModeldata, TensorMetadata} from './proxy-messages';
-import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper';
+import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
+import {copyFromExternalBuffer, createSession, endProfiling, releaseSession, run} from './proxy-wrapper';
import {isGpuBufferSupportedType} from './wasm-common';
-let runtimeInitializationPromise: Promise|undefined;
-
export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
switch (tensor.location) {
case 'cpu':
@@ -44,7 +42,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
inputNames: string[];
outputNames: string[];
- async createSessionAllocate(path: string): Promise {
+ async fetchModelAndCopyToWasmMemory(path: string): Promise {
// fetch model from url and move to wasm heap. The arraybufffer that held the http
// response is freed once we return
const response = await fetch(path);
@@ -52,33 +50,26 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
throw new Error(`failed to load model: ${path}`);
}
const arrayBuffer = await response.arrayBuffer();
- return createSessionAllocate(new Uint8Array(arrayBuffer));
+ return copyFromExternalBuffer(new Uint8Array(arrayBuffer));
}
async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise {
- if (!(await isOrtEnvInitialized())) {
- if (!runtimeInitializationPromise) {
- runtimeInitializationPromise = initializeRuntime(env);
- }
- await runtimeInitializationPromise;
- runtimeInitializationPromise = undefined;
- }
+ let model: Parameters[0];
if (typeof pathOrBuffer === 'string') {
if (typeof process !== 'undefined' && process.versions && process.versions.node) {
// node
- const model = await readFile(pathOrBuffer);
- [this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options);
+ model = await readFile(pathOrBuffer);
} else {
// browser
- // fetch model and move to wasm heap.
- const modelData: SerializableModeldata = await this.createSessionAllocate(pathOrBuffer);
- // create the session
- [this.sessionId, this.inputNames, this.outputNames] = await createSessionFinalize(modelData, options);
+ // fetch model and copy to wasm heap.
+ model = await this.fetchModelAndCopyToWasmMemory(pathOrBuffer);
}
} else {
- [this.sessionId, this.inputNames, this.outputNames] = await createSession(pathOrBuffer, options);
+ model = pathOrBuffer;
}
+
+ [this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options);
}
async dispose(): Promise {
diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts
index 71815f21e650a..e35759192fe3c 100644
--- a/js/web/lib/wasm/session-handler-training.ts
+++ b/js/web/lib/wasm/session-handler-training.ts
@@ -1,11 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common';
+import {InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler} from 'onnxruntime-common';
-import {SerializableModeldata, TensorMetadata} from './proxy-messages';
+import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference';
-import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
+import {copyFromExternalBuffer} from './wasm-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, lazyResetGrad, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl';
export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
@@ -18,7 +18,7 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
evalInputNames: string[] = [];
evalOutputNames: string[] = [];
- async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise {
+ async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise {
let buffer: Uint8Array;
if (typeof uriOrBuffer === 'string') {
const response = await fetch(uriOrBuffer);
@@ -27,21 +27,18 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
} else {
buffer = uriOrBuffer;
}
- return createSessionAllocate(buffer);
+ return copyFromExternalBuffer(buffer);
}
async createTrainingSession(
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions) {
- if (!isOrtEnvInitialized()) {
- await initRuntime(env);
- }
- const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
- const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
+ const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
+ const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
// 0 is supposed to be the nullptr
- let evalModelData: SerializableModeldata = [0, 0];
- let optimizerModelData: SerializableModeldata = [0, 0];
+ let evalModelData: SerializableInternalBuffer = [0, 0];
+ let optimizerModelData: SerializableInternalBuffer = [0, 0];
if (evalModelUriOrBuffer !== '') {
evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer);
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 3aacf8f4d90e0..a9dfd9218bb6f 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -3,37 +3,60 @@
import {Env, InferenceSession, Tensor} from 'onnxruntime-common';
-import {SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
+import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
import {setRunOptions} from './run-options';
import {setSessionOptions} from './session-options';
import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
import {getInstance} from './wasm-factory';
import {allocWasmString, checkLastError} from './wasm-utils';
-let ortEnvInitialized = false;
+// #region Initializations
/**
- * get the input/output count of the session.
- * @param sessionHandle the handle representing the session. should be non-zero.
- * @returns a tuple including 2 numbers, representing the input count and output count.
+ * There are 4 different "initialization" steps for ORT. They happen in different places and different time.
+ *
+ * 1. JavaScript initialization for onnxruntime-common and onnxruntime-web.
+ * This is the first initialization step. In this step, onnxruntime-web calls onnxruntime-common's registerBackend()
+ * function multiple times to register all the available backends. The backend registration is very fast. It only
+ * registers the backend name with the uninitialized backend object. No heavy initialization is done in this step.
+ * Refer to web/lib/index.ts for the backend registration.
+ *
+ * 2. WebAssembly artifact initialization.
+ * This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` or
+ * `ort.TrainingSession.create()` is called). In this step, onnxruntime-web does the followings:
+ * - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled.
+ * - perform feature detection, locate correct WebAssembly artifact path and call the Emscripten generated
+ * JavaScript code to initialize the WebAssembly runtime.
+ * - if proxy is enabled, this step happens in the proxy worker using message 'init-wasm'.
+ * - downloading the 'ort-wasm{...}.wasm' file is done in this step.
+ * - if multi-thread is enabled, one or more webworker will be created to initialize the PThread threadpool.
+ *
+ * 3. ORT environment initialization.
+ * This happens after step 2. In this step, onnxruntime-web performs ONNX Runtime environment initialization.
+ * Function `_OrtInit()` is called in this step.
+ * - if proxy is enabled, this step happens in the proxy worker using message 'init-ort'.
+ * - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step.
+ *
+ * 4. Session initialization.
+ * This happens when `ort.InferenceSession.create()` or `ort.TrainingSession.create()` is called. Unlike the first 3
+ * steps (they only called once), this step will be done for each session. In this step, onnxruntime-web does the
+ * followings:
+ * If the parameter is a URL:
+ * - download the model data from the URL.
+ * - copy the model data to the WASM heap. (proxy: 'copy-from')
+ * - dereference the model buffer. This step allows the original ArrayBuffer to be garbage collected.
+ * - call `_OrtCreateSession()` to create the session. (proxy: 'create')
+ *
+ * If the parameter is a Uint8Array object:
+ * - copy the model data to the WASM heap. (proxy: 'copy-from')
+ * - call `_OrtCreateSession()` to create the session. (proxy: 'create')
+ *
+ *
*/
-const getSessionInputOutputCount = (sessionHandle: number): [number, number] => {
- const wasm = getInstance();
- const stack = wasm.stackSave();
- try {
- const dataOffset = wasm.stackAlloc(8);
- const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + 4);
- if (errorCode !== 0) {
- checkLastError('Can\'t get session input/output count.');
- }
- return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
- } finally {
- wasm.stackRestore(stack);
- }
-};
/**
* initialize ORT environment.
+ *
* @param numThreads SetGlobalIntraOpNumThreads(numThreads)
* @param loggingLevel CreateEnv(static_cast(logging_level))
*/
@@ -51,18 +74,41 @@ const initOrt = (numThreads: number, loggingLevel: number): void => {
export const initRuntime = async(env: Env): Promise => {
// init ORT
initOrt(env.wasm.numThreads!, logLevelStringToEnum(env.logLevel));
+};
+
+/**
+ * perform EP specific initialization.
+ *
+ * @param env
+ * @param epName
+ */
+export const initEp = async(env: Env, epName: string): Promise => {
+ if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') {
+ // perform WebGPU availability check
+ if (typeof navigator === 'undefined' || !navigator.gpu) {
+ throw new Error('WebGPU is not supported in current environment');
+ }
+ const adapter = await navigator.gpu.requestAdapter();
+ if (!adapter) {
+ throw new Error(
+ 'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
+ }
+
+ if (!env.wasm.simd) {
+ throw new Error(
+ 'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP');
+ }
- if (!BUILD_DEFS.DISABLE_WEBGPU) {
// init JSEP if available
// eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
const initJsep = require('./jsep/init').init;
- await initJsep(getInstance(), env);
+ await initJsep(getInstance(), env, adapter);
}
-
- ortEnvInitialized = true;
};
+// #endregion Initializations
+
/**
* valid data locations for input/output tensors.
*/
@@ -97,13 +143,33 @@ type SessionMetadata = [
const activeSessions = new Map();
-export const isOrtEnvInitialized = (): boolean => ortEnvInitialized;
+/**
+ * get the input/output count of the session.
+ * @param sessionHandle the handle representing the session. should be non-zero.
+ * @returns a tuple including 2 numbers, representing the input count and output count.
+ */
+const getSessionInputOutputCount = (sessionHandle: number): [number, number] => {
+ const wasm = getInstance();
+ const stack = wasm.stackSave();
+ try {
+ const dataOffset = wasm.stackAlloc(8);
+ const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + 4);
+ if (errorCode !== 0) {
+ checkLastError('Can\'t get session input/output count.');
+ }
+ return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
+ } finally {
+ wasm.stackRestore(stack);
+ }
+};
/**
- * allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession.
+ * allocate the memory and memcpy the external buffer.
+ *
+ * @param model - the external buffer containing the model data. Must not be the same buffer as the WASM heap.
* @returns a 2-elements tuple - the pointer and size of the allocated buffer
*/
-export const createSessionAllocate = (model: Uint8Array): [number, number] => {
+export const copyFromExternalBuffer = (model: Uint8Array): [number, number] => {
const wasm = getInstance();
const modelDataOffset = wasm._malloc(model.byteLength);
if (modelDataOffset === 0) {
@@ -114,15 +180,30 @@ export const createSessionAllocate = (model: Uint8Array): [number, number] => {
};
/**
- * create an inference session using the prepared buffer containing the model data.
- * @param modelData a 2-elements tuple containing the pointer and size of the model data buffer.
+ * create an inference session from a model data buffer.
+ *
+ * @param modelData - either a Uint8Array object representing the model data, or a 2-elements tuple containing the
+ * pointer and size of the model data buffer.
* @param options an optional session options object.
* @returns a 3-elements tuple containing [session handle, input names, output names]
*/
-export const createSessionFinalize =
- (modelData: SerializableModeldata, options?: InferenceSession.SessionOptions): SerializableSessionMetadata => {
+export const createSession =
+ (modelData: Uint8Array|SerializableInternalBuffer,
+ options?: InferenceSession.SessionOptions): SerializableSessionMetadata => {
+ let modelDataOffset: number, modelDataLength: number;
const wasm = getInstance();
+ if (Array.isArray(modelData)) {
+ // if model data is an array, it must be a 2-elements tuple containing the pointer and size of the model data
+ [modelDataOffset, modelDataLength] = modelData;
+ } else if (modelData.buffer === wasm.HEAPU8.buffer) {
+ // if model data uses the same buffer as the WASM heap, we don't need to copy it.
+ [modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength];
+ } else {
+ // otherwise, copy the model data to the WASM heap.
+ [modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData);
+ }
+
let sessionHandle = 0;
let sessionOptionsHandle = 0;
let ioBindingHandle = 0;
@@ -133,7 +214,7 @@ export const createSessionFinalize =
try {
[sessionOptionsHandle, allocs] = setSessionOptions(options);
- sessionHandle = wasm._OrtCreateSession(modelData[0], modelData[1], sessionOptionsHandle);
+ sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
if (sessionHandle === 0) {
checkLastError('Can\'t create a session.');
}
@@ -201,7 +282,7 @@ export const createSessionFinalize =
}
throw e;
} finally {
- wasm._free(modelData[0]);
+ wasm._free(modelDataOffset);
if (sessionOptionsHandle !== 0) {
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
}
@@ -209,17 +290,6 @@ export const createSessionFinalize =
}
};
-
-/**
- * create an instance of InferenceSession.
- * @returns the metadata of InferenceSession. 0-value handle for failure.
- */
-export const createSession =
- (model: Uint8Array, options?: InferenceSession.SessionOptions): SerializableSessionMetadata => {
- const modelData: SerializableModeldata = createSessionAllocate(model);
- return createSessionFinalize(modelData, options);
- };
-
export const releaseSession = (sessionId: number): void => {
const wasm = getInstance();
const session = activeSessions.get(sessionId);
diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts
index 0cc28188a6093..c65178e2358d2 100644
--- a/js/web/lib/wasm/wasm-training-core-impl.ts
+++ b/js/web/lib/wasm/wasm-training-core-impl.ts
@@ -3,7 +3,7 @@
import {InferenceSession, Tensor} from 'onnxruntime-common';
-import {SerializableModeldata, TensorMetadata} from './proxy-messages';
+import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
import {setRunOptions} from './run-options';
import {setSessionOptions} from './session-options';
import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
@@ -32,7 +32,7 @@ const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero
}
};
-export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => {
+export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => {
const wasm = getInstance();
const [checkpointDataOffset, checkpointDataLength] = checkpointData;
@@ -108,8 +108,8 @@ export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel:
};
export const createTrainingSessionHandle =
- (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata,
- optimizerModelData: SerializableModeldata, options: InferenceSession.SessionOptions): number => {
+ (checkpointHandle: number, trainModelData: SerializableInternalBuffer, evalModelData: SerializableInternalBuffer,
+ optimizerModelData: SerializableInternalBuffer, options: InferenceSession.SessionOptions): number => {
const wasm = getInstance();
let trainingSessionHandle = 0;
diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts
index 29acc07e118f9..5e9b0910a2c68 100644
--- a/js/web/test/test-runner.ts
+++ b/js/web/test/test-runner.ts
@@ -850,7 +850,7 @@ export class ProtoOpTestContext {
this.backendHint = test.backend!;
this.ioBindingMode = test.ioBinding;
- this.loadedData = onnx.ModelProto.encode(model).finish();
+ this.loadedData = onnx.ModelProto.encode(model).finish().slice();
// in debug mode, open a new tab in browser for the generated onnx model.
if (ort.env.debug) {
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index 320a05bb97dac..b060d500c6484 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -20,30 +20,158 @@ class MatMulNBits final : public OpKernel {
K_{narrow(info.GetAttr("K"))},
N_{narrow(info.GetAttr("N"))},
block_size_{narrow(info.GetAttr("block_size"))},
- nbits_{narrow(info.GetAttr("bits"))} {
+ nbits_{narrow(info.GetAttr("bits"))},
+ accuracy_level_{info.GetAttr("accuracy_level")} {
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
+ is_asym_ = info.GetInputCount() >= 4;
+ const Tensor* tensor_B = nullptr;
+ const Tensor* tensor_scale = nullptr;
+ const Tensor* tensor_zero_point = nullptr;
+ bool B_constant = info.TryGetConstantInput(1, &tensor_B);
+ bool scale_constant = info.TryGetConstantInput(2, &tensor_scale);
+ bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point);
+ all_constant_ = B_constant && scale_constant;
+ all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_;
}
Status Compute(OpKernelContext* context) const override;
+ Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+ /*out*/ bool& is_packed,
+ /*out*/ PrePackedWeights* prepacked_weights) override;
+
+ Status UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx,
+ /*out*/ bool& used_shared_buffers) override;
+
private:
const size_t K_;
const size_t N_;
const size_t block_size_;
const size_t nbits_;
+ const int64_t accuracy_level_;
const bool column_wise_quant_{true};
+ IAllocatorUniquePtr packed_b_;
+ size_t packed_b_size_{0};
+ bool is_asym_{false};
+ bool all_constant_{false};
};
+Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
+ /*out*/ bool& is_packed,
+ /*out*/ PrePackedWeights* prepacked_weights) {
+ is_packed = false;
+ if (!all_constant_) {
+ return Status::OK();
+ }
+ auto compt_type = static_cast(accuracy_level_);
+ MLAS_THREADPOOL* pool = NULL;
+ if (input_idx == 1) {
+ packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast(nbits_), is_asym_, compt_type);
+ if (packed_b_size_ == 0) return Status::OK();
+ auto qptr = tensor.Data();
+ packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true);
+ if (packed_b_ == nullptr) {
+ return Status::OK();
+ }
+ std::memset(packed_b_.get(), 0, packed_b_size_);
+ MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_),
+ is_asym_, false, compt_type, pool);
+ if (prepacked_weights) {
+ prepacked_weights->buffers_.push_back(std::move(packed_b_));
+ prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
+ }
+ is_packed = true;
+ }
+ if (input_idx == 2 && packed_b_ != nullptr) {
+ auto sptr = tensor.Data();
+ MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_),
+ is_asym_, !is_asym_, compt_type, pool);
+ if (prepacked_weights) {
+ prepacked_weights->buffers_.push_back(std::move(packed_b_));
+ prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
+ }
+ is_packed = true;
+ }
+ if (input_idx == 3 && packed_b_ != nullptr) {
+ auto zptr = tensor.Data();
+ MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast(nbits_),
+ is_asym_, is_asym_, compt_type, pool);
+ if (prepacked_weights) {
+ prepacked_weights->buffers_.push_back(std::move(packed_b_));
+ prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
+ }
+ is_packed = true;
+ }
+
+ return Status::OK();
+}
+
+Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx,
+ /*out*/ bool& used_shared_buffers) {
+ used_shared_buffers = false;
+ // Pack three tensors into one buffer
+ if (input_idx == 1) {
+ used_shared_buffers = true;
+ packed_b_ = std::move(prepacked_buffers[0]);
+ }
+ if (input_idx == 2) {
+ used_shared_buffers = true;
+ packed_b_ = std::move(prepacked_buffers[0]);
+ }
+ if (input_idx == 3) {
+ used_shared_buffers = true;
+ packed_b_ = std::move(prepacked_buffers[0]);
+ }
+ return Status::OK();
+}
+
Status MatMulNBits::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
const Tensor* a = ctx->Input(0);
+ const auto* a_data = a->Data();
+
+ if (packed_b_.get()) {
+ TensorShape b_shape({static_cast(N_), static_cast(K_)});
+
+ MatMulComputeHelper helper;
+ ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
+
+ Tensor* y = ctx->Output(0, helper.OutputShape());
+
+ // Bail out early if the output is going to be empty
+ if (y->Shape().Size() == 0) return Status::OK();
+
+ auto* y_data = y->MutableData();
+
+ const size_t max_len = helper.OutputOffsets().size();
+ const size_t M = static_cast(helper.M());
+ const size_t N = static_cast(helper.N());
+ const size_t K = static_cast(helper.K());
+ const size_t lda = helper.Lda(false);
+ std::vector gemm_params(max_len);
+ AllocatorPtr allocator;
+ auto status = ctx->GetTempSpaceAllocator(&allocator);
+ ORT_RETURN_IF_ERROR(status);
+ for (size_t i = 0; i < max_len; i++) {
+ gemm_params[i].A = a_data + helper.LeftOffsets()[i];
+ gemm_params[i].lda = lda;
+ gemm_params[i].B = packed_b_.get();
+ gemm_params[i].C = y_data + helper.OutputOffsets()[i];
+ gemm_params[i].ldc = N;
+ }
+ auto ws_size = MlasSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data());
+ // workspace for activation process(dynamic quantization and others)
+ auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size);
+ MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(),
+ thread_pool);
+ return Status::OK();
+ }
+
const Tensor* b = ctx->Input(1);
const Tensor* scales = ctx->Input(2);
const Tensor* zero_points = ctx->Input(3);
-
- const auto* a_data = a->Data();
const uint8_t* b_data = b->Data();
const auto* scales_data = scales->Data();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data();
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 26fca454c96f0..54eb43753931a 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3359,6 +3359,13 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
.Attr("N", "size of each output feature", AttributeProto::INT)
.Attr("bits", "number of bits used for weight quantization (default 4)", AttributeProto::INT)
.Attr("block_size", "number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT)
+ .Attr("accuracy_level",
+ "The minimum accuracy level of input A, can be: 0(unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8) "
+ "(default unset). It is used to control how input A is quantized or downcast internally while "
+ "doing computation, for example: 0 means input A will not be quantized or downcast while doing "
+ "computation. 4 means input A can be quantized with the same block_size to int8 internally from "
+ "type T1.",
+ AttributeProto::INT, static_cast(0))
.Input(0, "A", "The input tensor, not quantized", "T1")
.Input(1, "B", "1-dimensional data blob", "T2")
.Input(2, "scales", "quantization scale", "T1")
diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h
index 9620dd42d1da9..1e83dd1cec400 100644
--- a/onnxruntime/core/mlas/inc/mlas_qnbit.h
+++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -77,3 +77,144 @@ MlasIsSQNBitGemmAvailable(
size_t BlkBitWidth,
size_t BlkLen
);
+
+/**
+ * @brief Define compute types of block quantization
+ */
+typedef enum {
+ CompUndef = 0, /*!< undef */
+ CompFp32 = 1, /*!< input fp32, accumulator fp32 */
+ CompFp16 = 2, /*!< input fp16, accumulator fp16 */
+ CompBf16 = 3, /*!< input bf16, accumulator fp32 */
+ CompInt8 = 4 /*!< input int8, accumulator int32 */
+} MLAS_SQNBIT_COMPUTE_TYPE;
+
+/**
+ * @brief Data parameters for NBits GEMM routine
+ * C = A * B
+ * A, C must be a float32 matrix
+ * B must be a packed nbits blob
+ * All except C are [in] parameters
+ */
+struct MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS {
+ const float* A = nullptr; /**< address of A (float32 matrix)*/
+ const void* B = nullptr; /**< address of B (packed nbits blob)*/
+ float* C = nullptr; /**< address of result matrix */
+ size_t lda = 0; /**< leading dimension of A */
+ size_t ldc = 0; /**< leading dimension of C*/
+};
+
+/**
+ * @brief Compute the byte size of the parameter combination
+ *
+ * @param N the number of columns of matrix B.
+ * @param K the number of rows of matrix B.
+ * @param block_size size of the block to quantize, elements from the same block share the same
+ * scale and zero point
+ * @param nbits number of bits used for weight quantization
+ * @param is_asym flag for asymmetric quantization
+ * @param comp_type specify input data type and accumulator data type
+ * @return size of the packing buffer, 0 if the operation is not yet supported.
+ */
+size_t MLASCALL
+MlasNBitsGemmPackBSize(
+ size_t N, size_t K, size_t block_size, int nbits, bool is_asym, MLAS_SQNBIT_COMPUTE_TYPE comp_type
+);
+
+/**
+ * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers.
+ *
+ * @param PackedBuf packed data buffer
+ * @param QData quantized data buffer
+ * @param Scale scale pointer
+ * @param Zp zero point pointer
+ * @param N the number of columns of matrix B.
+ * @param K the number of rows of matrix B.
+ * @param ldb leading dimension of B
+ * @param block_size size of the block to quantize, elements from the same block share the same
+ * scale and zero point
+ * @param nbits number of bits used for weight quantization (default 4)
+ * @param is_asym flag for asymmetric quantization
+ * @param comp_type specify input data type and accumulator data type
+ * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor
+ * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where
+ * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up
+ * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale
+ * (is_asym is false) and Zp(is_asym is true).
+ * @param thread_pool
+ */
+void MLASCALL
+MlasNBitsGemmPackB(
+ void* PackedBuf,
+ const uint8_t* QData,
+ const float* Scale,
+ const uint8_t* Zp,
+ size_t N,
+ size_t K,
+ size_t ldb,
+ size_t block_size,
+ int nbits,
+ bool is_asym,
+ bool last_call,
+ MLAS_SQNBIT_COMPUTE_TYPE comp_type,
+ MLAS_THREADPOOL* thread_pool
+);
+
+/**
+ * @brief Unpack and dequantize to fp32
+ *
+ * @param FpData unpacked float32 data
+ * @param PackedBuf quantized and packed data
+ * @param N the number of columns of matrix B.
+ * @param K the number of rows of matrix B.
+ * @param ldb leading dimension of B
+ * @param thread_pool
+ */
+void MLASCALL
+MlasNBitsGemmUnPackB(
+ float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* thread_pool
+);
+
+/**
+ * @brief Get the workspace size required by computation.
+ *
+ * @param[in] M row size of matrix A and C
+ * @param[in] N column size of matrix B and C
+ * @param[in] K column size of matrix A and row size of matrix B
+ * @param[in] BatchN number of batches
+ * @param[inout] DataParams An array (size BatchN) of parameter blocks
+ * @return Workspace size in bytes
+ */
+size_t MLASCALL
+MlasSQNBitsGemmBatchWorkspaceSize(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
+);
+
+/**
+ * @brief Batched GEMM: C = A * B
+ * A, C must be a float32 matrix
+ * B must be a packed nbits blob
+ *
+ * @param[in] M row size of matrix A and C
+ * @param[in] N column size of matrix B and C
+ * @param[in] K column size of matrix A and row size of matrix B
+ * @param[in] BatchN number of batches
+ * @param[inout] DataParams An array (size BatchN) of parameter blocks
+ * @param[in] WorkSpace temporary buffer
+ * @param[in] ThreadPool
+ * @return
+ */
+void MLASCALL
+MlasSQNBitsGemmBatchPackedB(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
+ void* WorkSpace,
+ MLAS_THREADPOOL* ThreadPool = nullptr
+);
diff --git a/onnxruntime/core/mlas/lib/jblas_defs.h b/onnxruntime/core/mlas/lib/jblas_defs.h
new file mode 100644
index 0000000000000..9cd1711a3ffd2
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/jblas_defs.h
@@ -0,0 +1,73 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+--*/
+
+#pragma once
+
+#include "jblas/jit_blas_prologue_b.h"
+#include "jblas/jit_blas_wrapper.h"
+
+namespace jblas
+{
+
+/*
+Name conversion explanation:
+Fp32: comp type, determined by GemmCore, can be any jblas::gemm::SCorexxx(float GemmCore)
+S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(also support other integer and float weight
+classes)
+F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and
+jblas::epilogue::gemm::AccumulatorWriteBackFp32.
+
+Tips: jblas::epilogue::gemm::CompFp32BlockEpilogue is a fixed class for all fp32 accumulator GemmCores.
+*/
+template
+using tLauncher_Fp32_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock<
+ GemmCore_T::ISA,
+ GemmCore_T,
+ jblas::prologue_a::gemm::ActivationKBlockBaseF32,
+ jblas::prologue_b::gemm::WeightKBlockS4,
+ jblas::epilogue::gemm::CompFp32BlockEpilogue,
+ jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
+
+/*
+Name conversion explanation:
+Int8: comp type, determined by GemmCore, can be any jblas::gemm::ICorexxx(integer GemmCore)
+S4: weight dtype, determined by jblas::prologue_b::gemm::WeightKBlockS4(support integer weight classes only)
+F32F32: input/output dtype, determined by jblas::prologue_a::gemm::ActivationKBlockBaseF32 and
+jblas::epilogue::gemm::AccumulatorWriteBackFp32.
+
+Tips: jblas::epilogue::gemm::CompInt8BlockEpilogue is a fixed class for all int32 accumulator GemmCores.
+*/
+template
+using tLauncher_Int8_S4_F32F32 = jblas::wrapper::gemm::LauncherKBlock<
+ GemmCore_T::ISA,
+ GemmCore_T,
+ jblas::prologue_a::gemm::ActivationF32KBlockQuantize,
+ jblas::prologue_b::gemm::WeightKBlockS4,
+ jblas::epilogue::gemm::CompInt8BlockEpilogue,
+ jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
+
+using tAVX512F = jblas::gemm::SCoreRowNAvx512f<48, 8>;
+using tAMX_BF16 = jblas::gemm::HCoreRowNAmxbf16<64, 16>;
+using tAVX512_FP16 = jblas::gemm::HCoreRowNAvx512fp16<96, 8>;
+using tAVX_VNNI = jblas::gemm::ICoreRowNAvxvnni<48, 2>; // TODO(Yu) use 24x4 for higher efficiency
+using tAVX512_VNNI = jblas::gemm::ICoreRowNAvx512vnni<48, 8>;
+using tAMX_INT8_US = jblas::gemm::ICoreRowNAmxint8<64, 16>;
+using tAMX_INT8_SS = jblas::gemm::ICoreRowNAmxint8SS<64, 16>;
+using tAVX2 = jblas::gemm::SCoreRowNAvx2<48, 2>; // TODO(Yu) use 24x4 for higher efficiency
+
+class ORTThreading : public jblas::parallel::IThreading
+{
+ public:
+ ORTThreading(void* tp);
+ void parallel_for(const jblas::parallel::thread_func& func) override;
+ void set_threads(int nthreads) override { assert(0); }
+ void sync() override { assert(0); }
+ void* mTp;
+};
+
+} // namespace jblas
diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.cpp b/onnxruntime/core/mlas/lib/jblas_gemm.cpp
new file mode 100644
index 0000000000000..f3cae3186c28e
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/jblas_gemm.cpp
@@ -0,0 +1,534 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ jblas_gemm.cpp
+
+Abstract:
+
+ Currently only support Q4 gemm.
+--*/
+
+#include "jblas_gemm.h"
+
+#include "jblas_defs.h"
+#include "mlasi.h"
+
+using namespace jblas;
+
+jblas::ORTThreading::ORTThreading(void* tp)
+ : IThreading(MLAS_THREADPOOL::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp)
+{
+}
+
+void
+jblas::ORTThreading::parallel_for(const jblas::parallel::thread_func& func)
+{
+ MlasTrySimpleParallel(reinterpret_cast(mTp), mThreadNum, [&](ptrdiff_t tid) {
+ func(static_cast(tid));
+ });
+}
+
+template
+static void
+JblasSQ4GemmCompF32(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const float* A,
+ const size_t lda,
+ jblas::storage::gemm::StorageWeightKBlockS4* B,
+ float* C,
+ const size_t ldc,
+ int8_t* WorkSpace,
+ jblas::parallel::IThreading* th
+)
+{
+ auto M_ = static_cast(M);
+ auto N_ = static_cast(N);
+ auto K_ = static_cast(K);
+ auto lda_ = static_cast(lda);
+ auto ldc_ = static_cast(ldc);
+ if (M <= 16) {
+ using Parallel = jblas::parallel::gemm::SchedulerKBlock;
+ using Launcher = tLauncher_Fp32_S4_F32F32;
+ static Launcher kernel;
+ auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize);
+ if (B->mIsAsym) {
+ reduceA.assign(WorkSpace);
+ ORTThreading single(nullptr);
+ kernel.mProA.reduce({A, lda_}, &reduceA, M_, K_, &single);
+ }
+ typename Launcher::BEpiParam blkargs{
+ B->template SPtr(), B->mScaT, B->mCStep, B->template ZPtr(),
+ reduceA.template get(), reduceA.lda};
+
+ typename Launcher::Param args{M_, N_, K_, B->mBlockSize, {A, lda_}, {B}, blkargs, {C, ldc_}};
+ jblas::parallel::GemmKBlockRun(kernel, args, th);
+ } else {
+ using Parallel = jblas::parallel::gemm::SchedulerBase;
+ using Launcher = jblas::wrapper::gemm::LauncherBase<
+ GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase,
+ jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
+ static Launcher kernel;
+
+ typename Launcher::Param args{M_, N_, K_, {A, lda_}, {B}, {C, ldc_}};
+ jblas::parallel::GemmBaseRun(kernel, args, th);
+ }
+}
+
+template
+static void
+JblasSQ4GemmCompInt8(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const float* A,
+ const size_t lda,
+ jblas::storage::gemm::StorageWeightKBlockS4* B,
+ float* C,
+ const size_t ldc,
+ int8_t* WorkSpace,
+ jblas::parallel::IThreading* th
+)
+{
+ using Parallel = jblas::parallel::gemm::SchedulerKBlock;
+ using Launcher = tLauncher_Int8_S4_F32F32;
+ auto M_ = static_cast(M);
+ auto N_ = static_cast(N);
+ auto K_ = static_cast(K);
+ auto lda_ = static_cast(lda);
+ auto ldc_ = static_cast(ldc);
+ static Launcher kernel;
+ auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->mIsAsym);
+ quanA.assign(WorkSpace);
+ if (M <= 16) {
+ ORTThreading single(nullptr);
+ kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single);
+ } else {
+ kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th);
+ }
+ typename Launcher::Param args{
+ M_,
+ N_,
+ K_,
+ B->mBlockSize,
+ {A, lda_, &quanA},
+ {B},
+ {B->template SPtr(), B->mScaT, B->mCStep, quanA.template SPtr(), quanA.mCStep,
+ quanA.template ZPtr(), B->template RPtr(), B->mRedT, B->template ZPtr(),
+ quanA.template RPtr(), B->mBlockSize},
+ {C, ldc_}};
+ jblas::parallel::GemmKBlockRun(kernel, args, th);
+}
+
+bool
+JblasSQ4GemmBatchDriver(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
+ int8_t* WorkSpace,
+ MLAS_THREADPOOL* ThreadPool
+)
+{
+ GetCPUDevice();
+ ORTThreading orth(ThreadPool);
+ bool processed = true;
+ for (size_t i = 0; i < BatchN; i++) {
+ auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B);
+ auto uptr = std::unique_ptr(ptr);
+ if (ptr) {
+ if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) {
+ auto kptr = reinterpret_cast(ptr);
+ auto coretype = ptr->mCoreId;
+ auto NTile = jblas::gemm::CoreAttr::get_mask_val(
+ ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT
+ );
+ auto CType = jblas::gemm::CoreAttr::get_mask_val(
+ ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT
+ );
+ if (CType == uint32_t(gemm::CompType::COMP_FP32)) {
+ if (NTile == tAVX512F::NTILE && _cd->AVX512F()) {
+ JblasSQ4GemmCompF32(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+ WorkSpace, &orth
+ );
+ } else if (NTile == tAVX2::NTILE && _cd->AVX2()) {
+ JblasSQ4GemmCompF32(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+ WorkSpace, &orth
+ );
+ }
+ }
+ if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) {
+ if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) {
+ JblasSQ4GemmCompInt8(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+ WorkSpace, &orth
+ );
+ } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) {
+ JblasSQ4GemmCompInt8(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+ WorkSpace, &orth
+ );
+ } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) {
+ JblasSQ4GemmCompInt8(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+ WorkSpace, &orth
+ );
+ }
+ }
+ if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) {
+ if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) {
+ JblasSQ4GemmCompInt8(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc,
+ WorkSpace, &orth
+ );
+ }
+ }
+ }
+ } else {
+ processed = false;
+ break;
+ }
+ }
+ return processed;
+}
+
+template
+static size_t
+JblasSQ4GemmCompF32WorkspaceSize(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const float* A,
+ const size_t lda,
+ jblas::storage::gemm::StorageWeightKBlockS4* B,
+ float* C,
+ const size_t ldc
+)
+{
+ auto M_ = static_cast(M);
+ auto K_ = static_cast(K);
+ (void)(N);
+ (void)(lda);
+ (void)(ldc);
+ if (M <= 16) {
+ using Launcher = tLauncher_Fp32_S4_F32F32;
+ static Launcher kernel;
+ if (B->mIsAsym) {
+ auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize);
+ return reduceA.mSize;
+ }
+ return 0;
+ } else {
+ using Launcher = jblas::wrapper::gemm::LauncherBase<
+ GemmCore_T::ISA, GemmCore_T, jblas::prologue_a::gemm::ActivationBase,
+ jblas::prologue_b::gemm::WeightKBlockS4, jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
+ static Launcher kernel;
+ return 0;
+ }
+ return 0;
+}
+
+template
+static size_t
+JblasSQ4GemmCompInt8WorkspaceSize(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const float* A,
+ const size_t lda,
+ jblas::storage::gemm::StorageWeightKBlockS4* B,
+ float* C,
+ const size_t ldc
+)
+{
+ using Parallel = jblas::parallel::gemm::SchedulerKBlock;
+ using Launcher = tLauncher_Int8_S4_F32F32;
+ static Launcher kernel;
+ (void)(N);
+ (void)(lda);
+ (void)(ldc);
+ auto quanA = kernel.mProA.createStorage(
+ static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->mIsAsym
+ );
+ return quanA.mSize;
+}
+
+size_t
+JblasSQ4GemmBatchWorkspaceSize(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
+)
+{
+ GetCPUDevice();
+ size_t size = 0;
+ for (size_t i = 0; i < BatchN; i++) {
+ auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B);
+ auto uptr = std::unique_ptr(ptr);
+ if (ptr) {
+ if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) {
+ auto kptr = reinterpret_cast(ptr);
+ auto coretype = ptr->mCoreId;
+ auto NTile = jblas::gemm::CoreAttr::get_mask_val(
+ ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT
+ );
+ auto CType = jblas::gemm::CoreAttr::get_mask_val(
+ ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT
+ );
+ if (CType == uint32_t(gemm::CompType::COMP_FP32)) {
+ if (NTile == tAVX512F::NTILE && _cd->AVX512F()) {
+ size = std::max(
+ JblasSQ4GemmCompF32WorkspaceSize(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc
+ ),
+ size
+ );
+ } else if (NTile == tAVX2::NTILE && _cd->AVX2()) {
+ size = std::max(
+ JblasSQ4GemmCompF32WorkspaceSize(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc
+ ),
+ size
+ );
+ }
+ }
+ if (CType == uint32_t(gemm::CompType::COMP_INT8_US_INT32)) {
+ if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) {
+ size = std::max(
+ JblasSQ4GemmCompInt8WorkspaceSize(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc
+ ),
+ size
+ );
+ } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) {
+ size = std::max(
+ JblasSQ4GemmCompInt8WorkspaceSize(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc
+ ),
+ size
+ );
+ } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) {
+ size = std::max(
+ JblasSQ4GemmCompInt8WorkspaceSize(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc
+ ),
+ size
+ );
+ }
+ }
+ if (CType == uint32_t(gemm::CompType::COMP_INT8_SS_INT32)) {
+ if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) {
+ size = std::max(
+ JblasSQ4GemmCompInt8WorkspaceSize(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc
+ ),
+ size
+ );
+ }
+ }
+ }
+ }
+ }
+ return size;
+}
+
+template
+static size_t
+JblasQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym)
+{
+ static T launcher;
+ auto stor = launcher.mProB.createStorage(
+ static_cast(N), static_cast(K), static_cast(block_size), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32,
+ JBLAS_DTYPE::BF16, isAsym
+ );
+ // TODO(Yu) support more scale dtype
+ return stor.mSize;
+}
+
+size_t
+JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType)
+{
+ GetCPUDevice();
+ if (K % BlkSize != 0) {
+ return 0;
+ }
+ // from low precision to high precision
+ switch (CompType) {
+ case CompInt8:
+ if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) {
+ return JblasQ4BuSize>(BlkSize, N, K, isAsym);
+ }
+ if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) {
+ return JblasQ4BuSize>(BlkSize, N, K, isAsym);
+ }
+ if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) {
+ return JblasQ4BuSize>(BlkSize, N, K, isAsym);
+ }
+ case CompBf16:
+ case CompFp16:
+ case CompFp32:
+ case CompUndef:
+ if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
+ return JblasQ4BuSize>(BlkSize, N, K, isAsym);
+ }
+ if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
+ return JblasQ4BuSize>(BlkSize, N, K, isAsym);
+ }
+ break;
+ default:
+ return 0;
+ }
+ return 0;
+}
+
+template
+static void
+JblasQ4GemmPackBImpl(
+ void* PackedBuf,
+ size_t BlkSize,
+ const uint8_t* QData,
+ const float* Scale,
+ const uint8_t* Zp,
+ size_t N,
+ size_t K,
+ bool IsAsym,
+ bool lastCall,
+ size_t ldb,
+ MLAS_THREADPOOL* ThreadPool
+)
+{
+ static T JblasKernel;
+ auto N_ = static_cast(N);
+ auto K_ = static_cast(K);
+ auto stor = JblasKernel.mProB.createStorage(
+ N_, K_, static_cast(BlkSize), JBLAS_DTYPE::S4_CLIP, JBLAS_DTYPE::F32, JBLAS_DTYPE::BF16, IsAsym
+ );
+ stor.assign(reinterpret_cast(PackedBuf));
+ ORTThreading orth(ThreadPool);
+ JblasKernel.mProB.packNbitsWeight(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth);
+ if (lastCall) {
+ JblasKernel.mProB.reduceWeight(&stor, &orth);
+ }
+}
+
+bool
+JblasQ4GemmPackB(
+ void* PackedBuf,
+ const uint8_t* QData,
+ const float* Scale,
+ const uint8_t* Zp,
+ size_t N,
+ size_t K,
+ size_t ldb,
+ size_t BlkSize,
+ bool isAsym,
+ bool lastCall,
+ MLAS_SQNBIT_COMPUTE_TYPE CompType,
+ MLAS_THREADPOOL* ThreadPool
+)
+{
+ GetCPUDevice();
+    // NOTE: the case fall-through below is intentional (from low precision to high precision).
+ switch (CompType) {
+ case CompInt8:
+ if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS::KTILE == 0) {
+ JblasQ4GemmPackBImpl>(
+ PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool
+ );
+ return true;
+ }
+ if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI::KTILE == 0) {
+ JblasQ4GemmPackBImpl>(
+ PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool
+ );
+ return true;
+ }
+ if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI::KTILE == 0) {
+ JblasQ4GemmPackBImpl>(
+ PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool
+ );
+ return true;
+ }
+ case CompBf16:
+ case CompFp16:
+ case CompFp32:
+ case CompUndef:
+ if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
+ JblasQ4GemmPackBImpl>(
+ PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool
+ );
+ return true;
+ }
+ if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
+ JblasQ4GemmPackBImpl>(
+ PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool
+ );
+ return true;
+ }
+ default:
+ return false;
+ }
+ return false;
+}
+
+bool
+JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool)
+{
+ auto ptr = jblas::storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf);
+ auto uptr = std::unique_ptr(ptr);
+ ORTThreading orth(ThreadPool);
+ auto N_ = static_cast(N);
+ auto K_ = static_cast(K);
+ auto ldb_ = static_cast(ldb);
+ GetCPUDevice();
+ if (ptr) {
+ if (ptr->mPrologueID == JBLAS_PROLOGUEB_IDS::WeightKBlockS4) {
+ auto NTile = jblas::gemm::CoreAttr::get_mask_val(
+ ptr->mCoreId, jblas::gemm::CoreAttr::NTILE_MASK, jblas::gemm::CoreAttr::NTILE_SHIFT
+ );
+ auto CType = jblas::gemm::CoreAttr::get_mask_val(
+ ptr->mCoreId, jblas::gemm::CoreAttr::COMP_MASK, jblas::gemm::CoreAttr::COMP_SHIFT
+ );
+ if (CType == uint32_t(jblas::gemm::CompType::COMP_FP32)) {
+ if (NTile == tAVX512F::NTILE && _cd->AVX512F()) {
+ static jblas::prologue_b::gemm::WeightKBlockS4 proB;
+ proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth);
+ } else if (NTile == tAVX2::NTILE && _cd->AVX2()) {
+ static jblas::prologue_b::gemm::WeightKBlockS4 proB;
+ proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth);
+ }
+ }
+ if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_US_INT32)) {
+ if (NTile == tAMX_INT8_US::NTILE && _cd->AMX_INT8()) {
+ static jblas::prologue_b::gemm::WeightKBlockS4 proB;
+ proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth);
+ } else if (NTile == tAVX512_VNNI::NTILE && _cd->AVX512_VNNI()) {
+ static jblas::prologue_b::gemm::WeightKBlockS4 proB;
+ proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth);
+ } else if (NTile == tAVX_VNNI::NTILE && _cd->AVX_VNNI()) {
+ static jblas::prologue_b::gemm::WeightKBlockS4 proB;
+ proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth);
+ }
+ }
+ if (CType == uint32_t(jblas::gemm::CompType::COMP_INT8_SS_INT32)) {
+ if (NTile == tAMX_INT8_SS::NTILE && _cd->AMX_INT8()) {
+ static jblas::prologue_b::gemm::WeightKBlockS4 proB;
+ proB.unpackWeight(N_, K_, ptr, FpData, ldb_, &orth);
+ }
+ }
+ }
+ return true;
+ }
+ return false;
+}
diff --git a/onnxruntime/core/mlas/lib/jblas_gemm.h b/onnxruntime/core/mlas/lib/jblas_gemm.h
new file mode 100644
index 0000000000000..044dc5e849a0a
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/jblas_gemm.h
@@ -0,0 +1,61 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ jblas_gemm.h
+
+Abstract:
+
+ Currently only support Q4 gemm.
+--*/
+
+#pragma once
+
+#include "mlas_qnbit.h"
+
+size_t
+JblasQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType);
+
+bool
+JblasQ4GemmPackB(
+ void* PackedBuf,
+ const uint8_t* QData,
+ const float* Scale,
+ const uint8_t* Zp,
+ size_t N,
+ size_t K,
+ size_t ldb,
+ size_t BlkSize,
+ bool isAsym,
+ bool lastCall,
+ MLAS_SQNBIT_COMPUTE_TYPE CompType,
+ MLAS_THREADPOOL* ThreadPool
+);
+
+bool
+JblasQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb
+ , MLAS_THREADPOOL* ThreadPool);
+
+bool
+JblasSQ4GemmBatchDriver(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
+ int8_t* WorkSpace,
+ MLAS_THREADPOOL* ThreadPool
+);
+
+size_t
+JblasSQ4GemmBatchWorkspaceSize(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
+);
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index 7bda1bb504173..7bb8b17031a84 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -50,7 +50,9 @@ Module Name:
#include
#endif
#if defined(__x86_64__) || defined(__i386__)
+#if !defined(signature_VORTEX_ebx) && !defined(signature_NEXGEN_ebx) && !defined(signature_AMD_ebx)//workaround for Bug 96238 - [i386] cpuid.h header needs include guards
#include <cpuid.h>
+#endif
#if defined(__GNUC__) && __GNUC__ >= 12
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // GCC 12 warns about uninitialized variables in immintrin.h.
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
index f964b1affec31..7f1d1b084aec0 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
@@ -15,6 +15,9 @@ Module Name:
--*/
#include "sqnbitgemm.h"
+#ifdef MLAS_JBLAS
+#include "jblas_gemm.h"
+#endif
namespace
{
@@ -142,3 +145,127 @@ MlasIsSQNBitGemmAvailable(
return true;
}
+
+size_t MLASCALL
+MlasNBitsGemmPackBSize(
+ size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, MLAS_SQNBIT_COMPUTE_TYPE CompType
+)
+{
+#ifdef MLAS_JBLAS
+ if (nbits == 4) {
+ auto jsize = JblasQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType);
+ if (jsize) {
+ return jsize;
+ }
+ }
+#endif
+ (void)(N);
+ (void)(K);
+ (void)(BlkSize);
+ (void)(nbits);
+ (void)(isAsym);
+ (void)(CompType);
+ return 0;
+}
+
+void MLASCALL
+MlasNBitsGemmPackB(
+ void* PackedBuf,
+ const uint8_t* QData,
+ const float* Scale,
+ const uint8_t* Zp,
+ size_t N,
+ size_t K,
+ size_t ldb,
+ size_t BlkSize,
+ int nbits,
+ bool isAsym,
+ bool lastCall,
+ MLAS_SQNBIT_COMPUTE_TYPE CompType,
+ MLAS_THREADPOOL* ThreadPool
+)
+{
+#ifdef MLAS_JBLAS
+ if (nbits == 4) {
+ if (JblasQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) {
+ return;
+ }
+ }
+#endif
+ (void)(PackedBuf);
+ (void)(QData);
+ (void)(Scale);
+ (void)(Zp);
+ (void)(N);
+ (void)(K);
+ (void)(ldb);
+ (void)(BlkSize);
+ (void)(nbits);
+ (void)(isAsym);
+ (void)(lastCall);
+ (void)(CompType);
+ (void)(ThreadPool);
+}
+
+void MLASCALL
+MlasNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, MLAS_THREADPOOL* ThreadPool)
+{
+#ifdef MLAS_JBLAS
+ if (JblasQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) {
+ return;
+ }
+#endif
+ (void)(FpData);
+ (void)(PackedBuf);
+ (void)(N);
+ (void)(K);
+ (void)(ldb);
+ (void)(ThreadPool);
+}
+
+size_t MLASCALL
+MlasSQNBitsGemmBatchWorkspaceSize(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams
+)
+{
+#ifdef MLAS_JBLAS
+ return JblasSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams);
+#endif
+ (void)(M);
+ (void)(N);
+ (void)(K);
+ (void)(BatchN);
+ (void)(DataParams);
+ return 0;
+}
+
+void MLASCALL
+MlasSQNBitsGemmBatchPackedB(
+ const size_t M,
+ const size_t N,
+ const size_t K,
+ const size_t BatchN,
+ const MLAS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams,
+ void* WorkSpace,
+ MLAS_THREADPOOL* ThreadPool
+)
+{
+ GetMlasPlatform();
+#ifdef MLAS_JBLAS
+ if (JblasSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast<int8_t*>(WorkSpace), ThreadPool)) {
+ // PackedWeight is created by jblas
+ return;
+ }
+#endif
+ (void)(M);
+ (void)(N);
+ (void)(K);
+ (void)(BatchN);
+ (void)(DataParams);
+ (void)(WorkSpace);
+ (void)(ThreadPool);
+}
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format
new file mode 100644
index 0000000000000..84b876706161d
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/.clang-format
@@ -0,0 +1,7 @@
+Language: Cpp
+BasedOnStyle: Google
+DerivePointerAlignment: false
+ColumnLimit: 120
+SpaceBeforeParens: ControlStatements
+SpaceBeforeRangeBasedForLoopColon: true
+SortIncludes: false
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt
new file mode 100644
index 0000000000000..5d9c5edf45a96
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/CMakeLists.txt
@@ -0,0 +1,33 @@
+cmake_minimum_required(VERSION 3.5)
+
+project(jblas LANGUAGES CXX VERSION 0.1.0)
+
+file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp)
+file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp)
+
+add_library(${PROJECT_NAME} INTERFACE)
+add_library(${PROJECT_NAME}::${PROJECT_NAME} ALIAS ${PROJECT_NAME})
+
+target_include_directories(
+ ${PROJECT_NAME} INTERFACE
+ "$"
+ "$"
+)
+
+if(WIN32)
+ target_compile_definitions(${PROJECT_NAME} INTERFACE _CRT_SECURE_NO_WARNINGS NOMINMAX)
+ target_compile_options(${PROJECT_NAME} INTERFACE /wd4068 /wd4849 /wd6262 /wd4702 /wd4100)
+ #4068 ignore unroll and GCC flags
+ #4849 ignore collapse
+ #6262 ignore stack too large
+ #4702 unreachable code(false warning on constexpr condition)
+ #4100 unreferenced formal parameter
+
+ target_link_options(${PROJECT_NAME} INTERFACE /STACK:3145728) #Stack requires up to L2 cache size
+endif(WIN32)
+
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17)
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h
new file mode 100644
index 0000000000000..143adb771760b
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_base.h
@@ -0,0 +1,303 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include
+
+#include
+#include
+#include "xbyak/xbyak.h"
+#include "xbyak/xbyak_util.h"
+
+#define OFFSET(field) offsetof(params, field)
+
+namespace jblas {
+
+namespace xbyak {
+class JitBase : protected Xbyak::CodeGenerator {
+ protected:
+ JitBase(size_t size = 16 * 1024) : CodeGenerator(size) {}
+
+ void load32(const Xbyak::Reg64& reg, const Xbyak::Address& addr) {
+ xor_(reg, reg);
+ mov(reg.cvt32(), addr);
+ }
+
+ void vreg_push(const Xbyak::Reg64& baseaddr) {
+#ifdef _WIN32
+ for (int i = 0; i < 10; i++) {
+ movaps(xword[baseaddr + i * 16], Xbyak::Xmm(6 + i));
+ }
+#endif
+ }
+
+ void vreg_pop(const Xbyak::Reg64& baseaddr) {
+#ifdef _WIN32
+ for (int i = 0; i < 10; i++) {
+ movaps(Xbyak::Xmm(6 + i), xword[baseaddr + i * 16]);
+ }
+#endif
+ }
+
+ void padto_le(const Xbyak::Reg64& _src, int padding) {
+ // _src=_src/padding*padding
+ if (padding == 1) {
+ return;
+ }
+ for (int i = 1; i < 16; i++) {
+ if ((1 << i) == padding) {
+ shr(_src, i);
+ shl(_src, i);
+ return;
+ }
+ }
+ assert(0);
+ }
+
+ void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total,
+ const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) {
+ inLocalLabel();
+ lea(_tmp, _total);
+ sub(_tmp, _pos);
+ cmp(_tmp, N);
+ jb(".maskflag");
+ cmp(_tmp, 0);
+ jl(".zeroflag");
+ uint64_t allmask = (static_cast<uint64_t>(1) << N) - 1;
+ if (N == 64) {
+ allmask = static_cast<uint64_t>(-1);
+ }
+ mov(_tmp, allmask);
+ kmovq(_msk, _tmp);
+ jmp(".maskend");
+ L(".maskflag");
+ mov(_tmp1, 1);
+ shlx(_tmp1, _tmp1, _tmp);
+ sub(_tmp1, 1);
+ kmovq(_msk, _tmp1);
+ jmp(".maskend");
+ L(".zeroflag");
+ mov(_tmp1, 0);
+ kmovq(_msk, _tmp1);
+ L(".maskend");
+ outLocalLabel();
+ }
+ void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Reg64& _total,
+ const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) {
+ generate_Nbitsmask(_msk, _pos, ptr[_total], _tmp, _tmp1, N);
+ }
+};
+
+class JitAvx : protected JitBase {
+ protected:
+ static int constexpr VBits = 256;
+ static int constexpr VecBytes = VBits / 8;
+ static int constexpr RegCount = 16;
+ typedef Xbyak::Ymm vreg_t;
+};
+
+class JitAvx2 : protected JitAvx {
+ protected:
+ static int constexpr VBits = 256;
+ typedef Xbyak::Ymm vreg_t;
+ void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); }
+
+ void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) {
+ vpmovzxwd(dst, addr);
+ vpslld(dst, dst, 16);
+ }
+};
+
+class JitAvx512f : protected JitAvx2 {
+ protected:
+ static int constexpr VBits = 512;
+ static int constexpr VecBytes = VBits / 8;
+ static int constexpr RegCount = 32;
+ typedef Xbyak::Zmm vreg_t;
+
+ void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); }
+
+ void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) {
+ vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]);
+ vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]);
+ vshuff32x4(src_2regs[0], tmp_2reg[0], tmp_2reg[1], 0 | (1 << 2) | (0 << 4) | (1 << 6));
+ vshuff32x4(src_2regs[0], src_2regs[0], src_2regs[0], 0 | (2 << 2) | (1 << 4) | (3 << 6));
+ vshuff32x4(src_2regs[1], tmp_2reg[0], tmp_2reg[1], 2 | (3 << 2) | (2 << 4) | (3 << 6));
+ vshuff32x4(src_2regs[1], src_2regs[1], src_2regs[1], 0 | (2 << 2) | (1 << 4) | (3 << 6));
+ }
+
+ void transpose16x16_4B(Xbyak::Zmm* src, Xbyak::Zmm* tmp, const int N = 16) {
+ for (int i = 0; i < 8; ++i) {
+ vpunpckldq(tmp[2 * i + 0], src[2 * i], src[2 * i + 1]);
+ vpunpckhdq(tmp[2 * i + 1], src[2 * i], src[2 * i + 1]);
+ }
+
+ for (int i = 0; i < 4; ++i) {
+ vpunpcklqdq(src[4 * i + 0], tmp[4 * i + 0], tmp[4 * i + 2]);
+ vpunpckhqdq(src[4 * i + 1], tmp[4 * i + 0], tmp[4 * i + 2]);
+ vpunpcklqdq(src[4 * i + 2], tmp[4 * i + 1], tmp[4 * i + 3]);
+ vpunpckhqdq(src[4 * i + 3], tmp[4 * i + 1], tmp[4 * i + 3]);
+ }
+
+ for (int i = 0; i < 2; ++i) {
+ vshufi32x4(tmp[8 * i + 0], src[8 * i + 0], src[8 * i + 4], 0x88);
+ vshufi32x4(tmp[8 * i + 1], src[8 * i + 1], src[8 * i + 5], 0x88);
+ vshufi32x4(tmp[8 * i + 2], src[8 * i + 2], src[8 * i + 6], 0x88);
+ vshufi32x4(tmp[8 * i + 3], src[8 * i + 3], src[8 * i + 7], 0x88);
+ vshufi32x4(tmp[8 * i + 4], src[8 * i + 0], src[8 * i + 4], 0xdd);
+ vshufi32x4(tmp[8 * i + 5], src[8 * i + 1], src[8 * i + 5], 0xdd);
+ vshufi32x4(tmp[8 * i + 6], src[8 * i + 2], src[8 * i + 6], 0xdd);
+ vshufi32x4(tmp[8 * i + 7], src[8 * i + 3], src[8 * i + 7], 0xdd);
+ }
+
+ // last step and move out
+ for (int i = 0; i < N; ++i) {
+ vshufi32x4(src[i], tmp[i % 8], tmp[8 + i % 8], i < 8 ? 0x88 : 0xdd);
+ }
+ }
+
+ void interleave_4rows_6regs(Xbyak::Zmm* src_4regs, Xbyak::Zmm* tmp_regs, const Xbyak::Opmask* masks) {
+ vpunpcklbw(tmp_regs[0], src_4regs[0], src_4regs[1]);
+ vpunpckhbw(tmp_regs[1], src_4regs[0], src_4regs[1]);
+ vpunpcklbw(tmp_regs[2], src_4regs[2], src_4regs[3]);
+ vpunpckhbw(tmp_regs[3], src_4regs[2], src_4regs[3]);
+
+ vpunpcklwd(tmp_regs[4], tmp_regs[0], tmp_regs[2]);
+ vpunpckhwd(tmp_regs[5], tmp_regs[0], tmp_regs[2]);
+ vpunpcklwd(tmp_regs[0], tmp_regs[1], tmp_regs[3]);
+ vpunpckhwd(tmp_regs[2], tmp_regs[1], tmp_regs[3]);
+ vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (4 << 4) | 4);
+ vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (4 << 4) | 4);
+ vmovups(src_4regs[0], tmp_regs[1]);
+ vshuff32x4(src_4regs[0] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6));
+ vmovups(src_4regs[1], tmp_regs[3]);
+ vshuff32x4(src_4regs[1] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6));
+ vshuff32x4(tmp_regs[1], tmp_regs[4], tmp_regs[0], (14 << 4) | 14);
+ vshuff32x4(tmp_regs[3], tmp_regs[5], tmp_regs[2], (14 << 4) | 14);
+ vmovups(src_4regs[2], tmp_regs[1]);
+ vshuff32x4(src_4regs[2] | masks[0], tmp_regs[3], tmp_regs[3], 0 | (0 << 2) | (0 << 4) | (2 << 6));
+ vmovups(src_4regs[3], tmp_regs[3]);
+ vshuff32x4(src_4regs[3] | masks[1], tmp_regs[1], tmp_regs[1], 1 | (0 << 2) | (3 << 4) | (0 << 6));
+ }
+
+ void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) {
+ vpsrld(_fp32, _fp32, 16);
+ vpmovdw(_bf16, _fp32);
+ }
+
+ void loadbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Address& addr) {
+ vpmovzxwd(dst, addr);
+ vpslld(dst, dst, 16);
+ }
+
+ void broadcastbf16_f32(const Xbyak::Zmm& dst, const Xbyak::Reg64& tmp, const Xbyak::Address& addr) {
+ mov(tmp.cvt16(), addr);
+ shl(tmp.cvt32(), 16);
+ vpbroadcastd(dst, tmp.cvt32());
+ }
+
+ void store_fp32_bf16(const Xbyak::Zmm& _fp32, const Xbyak::Address& _add) {
+ auto bf16 = Xbyak::Ymm(_fp32.getIdx());
+ cvt_fp32_bf16(bf16, _fp32);
+ vmovups(_add, bf16);
+ }
+};
+
+class JitAvx512_bf16 : protected JitAvx512f {};
+
+class JitAvx512_fp16 : protected JitAvx512f {};
+
+class JitAvx512vnni : protected JitAvx512f {
+ protected:
+ void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
+ vpdpbusds(x1, x2, op, Xbyak::EvexEncoding);
+ }
+};
+
+class JitAvxvnni : protected JitAvx2 {
+ protected:
+ void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
+ vpdpbusds(x1, x2, op, Xbyak::VexEncoding);
+ }
+};
+
+class JitAmxtile : protected JitAvx512f {
+ public:
+ struct alignas(64) tileconfig_t {
+ uint8_t palette_id;
+ uint8_t reserved[15];
+ uint16_t colb[16];
+ uint8_t rows[16];
+ };
+ static int constexpr TileCount = 8;
+
+ typedef long long (*configure_t)(void*);
+
+ static void generate_config(Xbyak::CodeGenerator* g) {
+ Xbyak::util::StackFrame st(g, 1, 0, 0);
+ auto& parambase = st.p[0];
+ g->ldtilecfg(g->ptr[parambase]);
+ }
+
+ static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum,
+ int CNum) {
+ // Filling tile configure structure. Could be done offline.
+ tc.palette_id = 1;
+ // Configure C tiles
+ int t = 0;
+ for (; t < CNum; ++t) {
+ tc.rows[t] = static_cast<uint8_t>(TILE_M);
+ tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
+ }
+ // Configure A tiles
+ for (; t < CNum + ANum; ++t) {
+ tc.rows[t] = static_cast<uint8_t>(TILE_M);
+ tc.colb[t] = static_cast<uint16_t>(TILE_K * elesize);
+ }
+ // Configure B tile. B effectively has 64 rows and 16 columns.
+ int kpack = 4 / elesize;
+ for (; t < CNum + ANum + BNum; ++t) {
+ tc.rows[t] = static_cast<uint8_t>(TILE_K / kpack);
+ tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
+ }
+ }
+};
+
+class JitAmxbf16 : protected JitAmxtile {
+ protected:
+ void cvt_fp32_bf16(const Xbyak::Ymm& _bf16, const Xbyak::Zmm& _fp32) { vcvtneps2bf16(_bf16, _fp32); }
+};
+
+class JitAmxint8 : protected JitAmxtile {
+ protected:
+ template <bool SignedA, bool SignedB>
+ void _tdpb(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3);
+};
+template <>
+inline void JitAmxint8::_tdpb<true, true>(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) {
+ tdpbssd(x1, x2, x3);
+}
+template <>
+inline void JitAmxint8::_tdpb<true, false>(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) {
+ tdpbsud(x1, x2, x3);
+}
+template <>
+inline void JitAmxint8::_tdpb<false, true>(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) {
+ tdpbusd(x1, x2, x3);
+}
+template <>
+inline void JitAmxint8::_tdpb<false, false>(const Xbyak::Tmm& x1, const Xbyak::Tmm& x2, const Xbyak::Tmm& x3) {
+ tdpbuud(x1, x2, x3);
+}
+} // namespace xbyak
+} // namespace jblas
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h
new file mode 100644
index 0000000000000..8ecf3535c17f4
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas.h
@@ -0,0 +1,96 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include <stdint.h>
+enum JBLAS_CODE {
+ JblasSuccess = 0,
+ JblasInvalidParam = 1,
+ JblasInvalidISA = 2,
+ JblasRuntimeError = 4,
+ JblasNotSupport = 8,
+};
+enum JBLAS_ISA : uint32_t {
+ JblasNoSIMD = 0,
+ JblasAVX,
+ JblasAVX2,
+ JblasAVX_VNNI,
+ JblasAVX512F,
+ JblasAVX512_VNNI,
+ JblasAMX_BF16,
+ JblasAMX_INT8,
+ JblasAVX512_FP16,
+ JblasAVX512_BF16,
+};
+enum class JBLAS_DTYPE : uint32_t {
+ EleBitsMask = 0xff,
+ EleBitsUndef = 0,
+ EleBits4 = 4,
+ EleBits8 = 8,
+ EleBits16 = 16,
+ EleBits32 = 32,
+ EleBits64 = 64,
+ TypeMask = 0xff00,
+ TypeFloat = 0 << 8,
+ TypeInt = 1 << 8,
+ SubTypeMask = 0xff0000,
+ SubType0 = 0 << 16,
+ SubType1 = 1 << 16,
+ SubType2 = 2 << 16,
+ F64 = EleBits64 | TypeFloat,
+ F32 = EleBits32 | TypeFloat,
+ F16 = EleBits16 | TypeFloat,
+ BF16 = EleBits16 | TypeFloat | SubType1,
+ F8_E4M3 = EleBits8 | TypeFloat,
+ F8_E5M2 = EleBits8 | TypeFloat | SubType1,
+ F8_E3M4 = EleBits8 | TypeFloat | SubType2,
+ S8 = EleBits8 | TypeInt,
+ U8 = EleBits8 | TypeInt | SubType1,
+ S4_CLIP = EleBits4 | TypeInt,
+ S4_FULLRANGE = EleBits4 | TypeInt | SubType1,
+ F4_E2M1 = EleBits4 | TypeFloat,
+ F4_BNB = EleBits4 | TypeFloat | SubType1,
+ F4_NF4 = EleBits4 | TypeFloat | SubType2,
+ S32 = EleBits32 | TypeInt,
+ U32 = EleBits32 | TypeInt | SubType1,
+};
+
+enum JBLAS_LAYOUT { JblasRowMajor = 101, JblasColMajor = 102 };
+enum JBLAS_TRANSPOSE {
+ JblasNoTrans = 111,
+ JblasTrans = 112,
+ JblasConjTrans = 113,
+};
+enum JBLAS_ELTWISEOP {
+ GELU,
+ SWISH,
+ TANH,
+ EXP,
+ LOW_PRECISION_EXP,
+ RELU,
+ LINEAR,
+};
+
+enum class JBLAS_PROLOGUEB_IDS : uint32_t {
+ Undef = (uint32_t)-1,
+ Begin = 0,
+ NormalBegin = Begin,
+ WeightPack = NormalBegin,
+ NormalEnd,
+ KBlockBegin = NormalEnd,
+ WeightKBlockS8 = KBlockBegin,
+ WeightKBlockS4,
+ WeightKBlockF4,
+ KBlockEnd,
+ End,
+};
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h
new file mode 100644
index 0000000000000..5cac1080bc610
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_device.h
@@ -0,0 +1,277 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include "jit_blas.h"
+#include "xbyak/xbyak_util.h"
+
+namespace jblas {
+
+namespace device {
+
+struct X64_ISA {
+ int64_t MMX : 1; // 0
+ int64_t SSE : 1; // 1
+ int64_t SSE2 : 1; // 2
+ int64_t SSE3 : 1; // 3
+ int64_t SSSE3 : 1; // 4
+ int64_t SSE41 : 1; // 5
+ int64_t SSE42 : 1; // 6
+ int64_t AVX : 1; // 7
+ int64_t F16C : 1; // 8
+ int64_t FMA : 1; // 9
+ int64_t AVX2 : 1; // 10
+ int64_t AVX_VNNI : 1; // 11
+ int64_t AVX_VNNI_INT8 : 1; // 12
+ int64_t AVX_NE_CONVERT : 1; // 13
+ int64_t AVX_IFMA : 1; // 14
+ int64_t AVX512F : 1; // 15
+ int64_t AVX512BW : 1; // 16
+ int64_t AVX512CD : 1; // 17
+ int64_t AVX512DQ : 1; // 18
+ int64_t AVX512ER : 1; // 19
+ int64_t AVX512IFMA52 : 1; // 20
+ int64_t AVX512PF : 1; // 21
+ int64_t AVX512VL : 1; // 22
+ int64_t AVX512VPOPCNTDQ : 1; // 23
+ int64_t AVX512_4FMAPS : 1; // 24
+ int64_t AVX512_4VNNIW : 1; // 25
+ int64_t AVX512_BF16 : 1; // 26
+ int64_t AVX512_BITALG : 1; // 27
+ int64_t AVX512_VBMI : 1; // 28
+ int64_t AVX512_VBMI2 : 1; // 29
+ int64_t AVX512_VNNI : 1; // 30
+ int64_t AVX512_VP2INTERSECT : 1; // 31
+ int64_t AVX512_FP16 : 1; // 32
+ int64_t AMX_TILE : 1; // 33
+ int64_t AMX_BF16 : 1; // 34
+ int64_t AMX_INT8 : 1; // 35
+ int64_t AMX_FP16 : 1; // 36
+ int64_t AMX_COMPLEX : 1; // 37
+ int64_t reserved : (64 - 38);
+};
+
+class AVX2_Default {
+ public:
+ static constexpr bool MMX = 1;
+ static constexpr bool SSE = 1;
+ static constexpr bool SSE2 = 1;
+ static constexpr bool SSE3 = 1;
+ static constexpr bool SSSE3 = 1;
+ static constexpr bool SSE41 = 1;
+ static constexpr bool SSE42 = 1;
+ static constexpr bool AVX = 1;
+ static constexpr bool F16C = 1;
+ static constexpr bool FMA = 1;
+ static constexpr bool AVX2 = 1;
+ static constexpr bool AVX_VNNI = 0;
+ static constexpr bool AVX_VNNI_INT8 = 0;
+ static constexpr bool AVX_NE_CONVERT = 0;
+ static constexpr bool AVX_IFMA = 0;
+ static constexpr bool AVX512F = 0;
+ static constexpr bool AVX512BW = 0;
+ static constexpr bool AVX512CD = 0;
+ static constexpr bool AVX512DQ = 0;
+ static constexpr bool AVX512ER = 0;
+ static constexpr bool AVX512IFMA52 = 0;
+ static constexpr bool AVX512PF = 0;
+ static constexpr bool AVX512VL = 0;
+ static constexpr bool AVX512VPOPCNTDQ = 0;
+ static constexpr bool AVX512_4FMAPS = 0;
+ static constexpr bool AVX512_4VNNIW = 0;
+ static constexpr bool AVX512_BF16 = 0;
+ static constexpr bool AVX512_BITALG = 0;
+ static constexpr bool AVX512_VBMI = 0;
+ static constexpr bool AVX512_VBMI2 = 0;
+ static constexpr bool AVX512_VNNI = 0;
+ static constexpr bool AVX512_VP2INTERSECT = 0;
+ static constexpr bool AVX512_FP16 = 0;
+ static constexpr bool AMX_TILE = 0;
+ static constexpr bool AMX_BF16 = 0;
+ static constexpr bool AMX_INT8 = 0;
+ static constexpr bool AMX_FP16 = 0;
+ static constexpr bool AMX_COMPLEX = 0;
+};
+
+class AVX512_VNNI_Default {
+ public:
+ static constexpr bool MMX = 1;
+ static constexpr bool SSE = 1;
+ static constexpr bool SSE2 = 1;
+ static constexpr bool SSE3 = 1;
+ static constexpr bool SSSE3 = 1;
+ static constexpr bool SSE41 = 1;
+ static constexpr bool SSE42 = 1;
+ static constexpr bool AVX = 1;
+ static constexpr bool F16C = 1;
+ static constexpr bool FMA = 1;
+ static constexpr bool AVX2 = 1;
+ static constexpr bool AVX_VNNI = 0;
+ static constexpr bool AVX_VNNI_INT8 = 0;
+ static constexpr bool AVX_NE_CONVERT = 0;
+ static constexpr bool AVX_IFMA = 0;
+ static constexpr bool AVX512F = 1;
+ static constexpr bool AVX512BW = 1;
+ static constexpr bool AVX512CD = 1;
+ static constexpr bool AVX512DQ = 1;
+ static constexpr bool AVX512ER = 0;
+ static constexpr bool AVX512IFMA52 = 0;
+ static constexpr bool AVX512PF = 0;
+ static constexpr bool AVX512VL = 1;
+ static constexpr bool AVX512VPOPCNTDQ = 0;
+ static constexpr bool AVX512_4FMAPS = 0;
+ static constexpr bool AVX512_4VNNIW = 0;
+ static constexpr bool AVX512_BF16 = 0;
+ static constexpr bool AVX512_BITALG = 0;
+ static constexpr bool AVX512_VBMI = 0;
+ static constexpr bool AVX512_VBMI2 = 0;
+ static constexpr bool AVX512_VNNI = 1;
+ static constexpr bool AVX512_VP2INTERSECT = 0;
+ static constexpr bool AVX512_FP16 = 0;
+ static constexpr bool AMX_TILE = 0;
+ static constexpr bool AMX_BF16 = 0;
+ static constexpr bool AMX_INT8 = 0;
+ static constexpr bool AMX_FP16 = 0;
+ static constexpr bool AMX_COMPLEX = 0;
+};
+
+class SapphireRapids {
+ public:
+ static constexpr bool MMX = 1;
+ static constexpr bool SSE = 1;
+ static constexpr bool SSE2 = 1;
+ static constexpr bool SSE3 = 1;
+ static constexpr bool SSSE3 = 1;
+ static constexpr bool SSE41 = 1;
+ static constexpr bool SSE42 = 1;
+ static constexpr bool AVX = 1;
+ static constexpr bool F16C = 1;
+ static constexpr bool FMA = 1;
+ static constexpr bool AVX2 = 1;
+ static constexpr bool AVX_VNNI = 0;
+ static constexpr bool AVX_VNNI_INT8 = 0;
+ static constexpr bool AVX_NE_CONVERT = 0;
+ static constexpr bool AVX_IFMA = 0;
+ static constexpr bool AVX512F = 1;
+ static constexpr bool AVX512BW = 1;
+ static constexpr bool AVX512CD = 1;
+ static constexpr bool AVX512DQ = 1;
+ static constexpr bool AVX512ER = 0;
+ static constexpr bool AVX512IFMA52 = 0;
+ static constexpr bool AVX512PF = 0;
+ static constexpr bool AVX512VL = 1;
+ static constexpr bool AVX512VPOPCNTDQ = 0;
+ static constexpr bool AVX512_4FMAPS = 0;
+ static constexpr bool AVX512_4VNNIW = 0;
+ static constexpr bool AVX512_BF16 = 0;
+ static constexpr bool AVX512_BITALG = 0;
+ static constexpr bool AVX512_VBMI = 0;
+ static constexpr bool AVX512_VBMI2 = 0;
+ static constexpr bool AVX512_VNNI = 1;
+ static constexpr bool AVX512_VP2INTERSECT = 0;
+ static constexpr bool AVX512_FP16 = 0;
+ static constexpr bool AMX_TILE = 1;
+ static constexpr bool AMX_BF16 = 1;
+ static constexpr bool AMX_INT8 = 1;
+ static constexpr bool AMX_FP16 = 0;
+ static constexpr bool AMX_COMPLEX = 0;
+};
+
+template <JBLAS_ISA ISA_T>
+class isa_base {
+ public:
+ static bool constexpr avx = ISA_T >= JblasAVX;
+ static bool constexpr avx2 = ISA_T >= JblasAVX2;
+ static bool constexpr avx512f = ISA_T >= JblasAVX512F;
+ static bool constexpr avx512_vnni = ISA_T >= JblasAVX512_VNNI;
+ static bool constexpr avx512_fp16 = ISA_T >= JblasAVX512_FP16;
+ static bool constexpr amx_bf16 = ISA_T >= JblasAMX_BF16;
+ static bool constexpr amx_int8 = ISA_T >= JblasAMX_INT8;
+};
+
+class CpuDevice {
+ public:
+ inline void setThreads(int _nth) {
+ if (_nth <= 0) {
+ numthreads = numcores;
+ } else {
+ numthreads = std::min(numcores, _nth);
+ }
+ }
+ inline int getThreads() { return numthreads; }
+ inline int getCores() { return numcores; }
+ inline uint32_t getL2CacheSize() { return L2Cache; }
+ inline uint32_t getL1CacheSize() { return L1Cache; }
+ inline bool AVX() { return mHasAVX; }
+ inline bool AVX2() { return mHasAVX2; }
+ inline bool AVX_VNNI() { return mHasAVX_VNNI; }
+ inline bool AVX512F() { return mHasAVX512F; }
+ inline bool AVX512_VNNI() { return mHasAVX512_VNNI; }
+ inline bool AMX_INT8() { return mHasAMX_INT8; }
+ inline bool AMX_BF16() { return mHasAMX_BF16; }
+ inline bool AVX512_BF16() { return mHasAVX512_BF16; }
+ inline bool AVX512_FP16() { return mHasAVX512_FP16; }
+#define ADD_FLAG(isa) mHas##isa = _cpu.has(_cpu.t##isa)
+ CpuDevice() {
+ static Xbyak::util::Cpu _cpu;
+ L1Cache = _cpu.getDataCacheSize(0);
+ L2Cache = _cpu.getDataCacheSize(1);
+ ADD_FLAG(AVX);
+ ADD_FLAG(AVX2);
+ ADD_FLAG(AVX512F);
+ ADD_FLAG(AVX512_VNNI);
+ ADD_FLAG(AVX_VNNI);
+ ADD_FLAG(AMX_BF16);
+ ADD_FLAG(AMX_INT8);
+ ADD_FLAG(AVX512_BF16);
+ ADD_FLAG(AVX512_FP16);
+ numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
+ numthreads = numcores;
+ }
+
+ static CpuDevice* getInstance() {
+ static CpuDevice instance;
+ return &instance;
+ }
+
+ void print() {
+ printf(
+ "AVX:%d AVX2:%d AVX512F:%d AVX_VNNI:%d AVX512_VNNI:%d AMX_INT8:%d AMX_BF16:%d AVX512_BF16:%d AVX512_FP16:%d\n",
+ mHasAVX, mHasAVX2, mHasAVX512F, mHasAVX_VNNI, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512_BF16,
+ mHasAVX512_FP16);
+ }
+#undef ADD_FLAG
+
+ protected:
+ uint32_t L2Cache, L1Cache;
+ bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512_BF16,
+ mHasAVX512_FP16;
+ int numcores;
+ int numthreads;
+};
+
+#define GetCPUDevice() auto _cd = jblas::device::CpuDevice::getInstance();
+
+class CpuBase {
+ public:
+ CpuBase() {
+ GetCPUDevice();
+ mL2Cache = _cd->getL2CacheSize();
+ mL1Cache = _cd->getL1CacheSize();
+ mNumThreads = _cd->getThreads();
+ }
+ size_t mL2Cache, mL1Cache;
+ int mNumThreads;
+};
+} // namespace device
+} // namespace jblas
diff --git a/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h
new file mode 100644
index 0000000000000..ceb7a545092d8
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/x86_64/jblas/jblas/jit_blas_epilogue.h
@@ -0,0 +1,329 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+// NOTE(review): the system-header name on the next include was lost in extraction
+// (the line read a bare "#include"). <immintrin.h> matches the sibling jblas
+// headers upstream — TODO confirm. <cassert>, <tuple> and <type_traits> are added
+// because this header visibly uses assert(), std::tuple and std::is_same.
+#include <immintrin.h>
+#include <cassert>
+#include <tuple>
+#include <type_traits>
+
+#include "jit_base.h"
+#include "jit_blas.h"
+#include "jit_blas_utils.h"
+#include "kernel_wrapper.h"
+
+namespace jblas {
+namespace epilogue {
+namespace gemm {
+
+// Epilogue functor: copies the GEMM accumulator tile (`cacheptr`, row stride
+// `cachestep`) into destination matrix C at (M_offset, N_offset), converting
+// _SRC_T -> _DST_T where a dedicated copy kernel exists.
+// NOTE(review): the template parameter lists on this class, forward(), the
+// std::is_same arguments and the forward<...> instantiations were stripped by
+// extraction; reconstructed from the upstream jblas header — confirm against
+// onnxruntime/core/mlas/lib/x86_64/jblas before merging.
+template <JBLAS_ISA ISA_T, typename _SRC_T, typename _DST_T>
+class AccumulatorWriteBack {
+ public:
+  using SType = _SRC_T;
+  using DType = _DST_T;
+  struct Param {
+    DType* C;           // destination matrix base pointer
+    int ldc;            // leading dimension of C, in elements
+    void* elt_const_v;  // opaque constants for optional fused element-wise ops
+  };
+
+  template <typename... Eltops>
+  JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+                     const int N, const Param& _param, void* tmpcache, size_t cachesize, Eltops... ops) {
+    auto COffset = M_offset * _param.ldc + N_offset;
+    auto cptr = _param.C + COffset;
+    // Only fp32 sources may narrow to bf16; any other narrowing fails to compile.
+    bool constexpr Valid = !std::is_same<DType, utils::bf16>::value || std::is_same<SType, float>::value;
+    static_assert(Valid, "fp32 to bf16 conversion only.");
+    if constexpr (std::is_same<DType, utils::bf16>::value) {
+      return kernel::wrapper::Memcpy2DFp32CvtBf16::template forward<ISA_T>(
+          const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false);
+    } else if constexpr (std::is_same<std::tuple<SType, DType>, std::tuple<utils::fp16, float>>::value) {
+      return kernel::wrapper::Memcpy2DFp16CvtFp32::template forward<ISA_T>(
+          const_cast<_SRC_T*>(cacheptr), cptr, M, N, cachestep * sizeof(SType), _param.ldc * sizeof(DType), false);
+    } else if constexpr (sizeof(SType) == sizeof(DType)) {
+      // Same-width types: plain 2-D copy with the optional fused ops applied.
+      return kernel::wrapper::Memcpy2D::template forward<ISA_T, SType, DType>(cacheptr, cptr, M, N, cachestep,
+                                                                              _param.ldc, _param.elt_const_v, ops...);
+    } else {
+      assert(false);
+      // fix: the original fell off the end of a JBLAS_CODE-returning function here.
+      return JblasNotSupport;
+    }
+  }
+};
+
+// Write-back with one compile-time element-wise op fused into the 2-D copy.
+// Only the fp32 -> fp32 path is implemented.
+// NOTE(review): the template parameter list and the forward1<...> instantiation
+// were stripped by extraction; reconstructed (including the _OP parameter) from
+// the upstream jblas header — confirm before merging.
+template <JBLAS_ISA ISA_T, typename _SRC_T, typename _DST_T, JBLAS_ELTWISEOP _OP>
+class CustomAccumulatorWriteBackWithEltop {
+ public:
+  struct Param {
+    _DST_T* C;          // destination matrix base pointer
+    int ldc;            // leading dimension of C, in elements
+    void* elt_const_v;  // constants consumed by the fused element-wise op
+  };
+  JBLAS_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
+                     const int N, const Param& _param, void* tmpcache, size_t cachesize) {
+    auto COffset = M_offset * _param.ldc + N_offset;
+    auto cptr = _param.C + COffset;
+    if constexpr (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) {
+      return kernel::wrapper::Memcpy2D::template forward1<ISA_T, float, float, _OP>(cacheptr, cptr, M, N, cachestep,
+                                                                                    _param.ldc, _param.elt_const_v);
+    } else {
+      assert(false);
+      // fix: the original fell off the end of a JBLAS_CODE-returning function here.
+      return JblasNotSupport;
+    }
+  }
+};
+// Convenience aliases for the common (source, destination) element-type pairs.
+// NOTE(review): the template parameter lists and alias arguments were stripped
+// by extraction; reconstructed from the alias names and the upstream jblas
+// header — confirm before merging.
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackFp32 = AccumulatorWriteBack<ISA_T, float, float>;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackInt32 = AccumulatorWriteBack<ISA_T, int, int>;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackBf16 = AccumulatorWriteBack<ISA_T, utils::bf16, utils::bf16>;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackFp16 = AccumulatorWriteBack<ISA_T, utils::fp16, utils::fp16>;
+template <JBLAS_ISA ISA_T>
+using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack<ISA_T, utils::fp16, float>;
+template
+using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack