-
+#endif
namespace Ort {
namespace Custom {
diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
index 60196d0c80cbb..32a9f06464ace 100644
--- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
+++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -11,6 +11,8 @@
/// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
///
struct OrtTensorRTProviderOptionsV2 {
+ OrtTensorRTProviderOptionsV2& operator=(const OrtTensorRTProviderOptionsV2& other); // copy assignment operator
+
int device_id{0}; // cuda device id.
int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream.
void* user_compute_stream{nullptr}; // user specified CUDA compute stream.
@@ -46,8 +48,26 @@ struct OrtTensorRTProviderOptionsV2 {
const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT
- int trt_dump_ep_context_model{0}; // Dump EP context node model
- int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
- int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute
- const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
+
+ /*
+ * Please note that the following context-model-related provider options are subject to these rules:
+ *
+ * 1. When dumping the context model and loading the context model,
+ *    for security reasons, TRT EP doesn't allow the "ep_cache_context" node attribute of an EP context node to be
+ *    an absolute path, or a relative path that points outside of the context model directory.
+ *    This means the engine cache needs to be in the same directory as the context model, or in one of its sub-directories.
+ *
+ * 2. When dumping the context model, the engine cache path is rewritten to be relative to the context model directory.
+ *    For example:
+ *    If "trt_dump_ep_context_model" and "trt_engine_cache_enable" are both enabled,
+ *       and "trt_ep_context_file_path" is "./context_model_dir",
+ * - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir"
+ * - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir"
+ *
+ */
+ int trt_dump_ep_context_model{0}; // Dump EP context node model
+ const char* trt_ep_context_file_path{nullptr};   // Specify the file to dump the EP context node model to. Can be a file name, a directory path, or a file name with a path.
+ int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
+
+ const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
};
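
Usage sketch: these EP-context options can be set through the provider-options V2 C API. This is a minimal, hedged example; it assumes the string keys mirror the struct field names (as the V2 update API generally does), and all values shown are illustrative.

#include <onnxruntime_cxx_api.h>

void ConfigureTrtEpContext(Ort::SessionOptions& session_options) {
  const OrtApi& api = Ort::GetApi();
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));

  // Illustrative values: dump the context model and keep the engine cache in a sub-directory of it.
  const char* keys[] = {"trt_engine_cache_enable", "trt_engine_cache_path",
                        "trt_dump_ep_context_model", "trt_ep_context_file_path"};
  const char* values[] = {"1", "engine_dir", "1", "./context_model_dir"};
  Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(trt_options, keys, values, 4));

  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(session_options, trt_options));
  api.ReleaseTensorRTProviderOptions(trt_options);
}
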
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index b321b2b2bac27..101a578ec3e1d 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -38,7 +38,7 @@
*
* This value is used by some API functions to behave as this version of the header expects.
*/
-#define ORT_API_VERSION 17
+#define ORT_API_VERSION 18
#ifdef __cplusplus
extern "C" {
@@ -3608,6 +3608,14 @@ struct OrtApi {
* - "1": Faster preparation time, less optimal graph.
* - "2": Longer preparation time, more optimal graph.
* - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details.
+ * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown).
+ * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options:
+ * - "0": Default (none).
+ * - "68"
+ * - "69"
+ * - "73"
+ * - "75"
+ * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device).
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index df79cb6e5b21b..b282438795eb5 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -236,7 +236,7 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
"session.optimized_model_external_initializers_min_size_in_bytes";
-// Enable EP context feature to dump the partitioned graph which include the EP context into Onnx file.
+// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file.
// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
// "0": disable. (default)
// "1": enable.
@@ -249,4 +249,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p
// Flag to specify whether to dump the EP context into the Onnx model.
// "0": dump the EP context into separate file, keep the file name in the Onnx model.
// "1": dump the EP context into the Onnx model. (default).
-static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
\ No newline at end of file
+static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
+
+// Gemm fastmath mode provides fp32 gemm acceleration using bfloat16-based matmul.
+// Option values:
+// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
+// - "1": Gemm FastMath mode is enabled.
+static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
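
A short C++ sketch of setting these configuration entries on a session. It assumes the EP context enable key is "ep.context_enable", consistent with the other ep.context_* keys in this header, and the file path is hypothetical.

#include <onnxruntime_cxx_api.h>

void ConfigureSession(Ort::SessionOptions& session_options) {
  // Dump the partitioned graph (with EP context nodes) into an Onnx file for faster future loads.
  session_options.AddConfigEntry("ep.context_enable", "1");
  session_options.AddConfigEntry("ep.context_file_path", "./context_model_dir/model_ctx.onnx");  // hypothetical path
  session_options.AddConfigEntry("ep.context_embed_mode", "0");  // keep the engine cache as a separate file
  // Opt in to the bfloat16 fastmath GEMM kernels on ARM64.
  session_options.AddConfigEntry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1");
}
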
diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java
index 69ccb954e8afe..1c21387b50455 100644
--- a/java/src/main/java/ai/onnxruntime/TensorInfo.java
+++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java
@@ -7,6 +7,7 @@
import java.lang.reflect.Array;
import java.nio.Buffer;
import java.util.Arrays;
+import java.util.stream.Collectors;
/** Describes an {@link OnnxTensor}, including its size, shape and element type. */
public class TensorInfo implements ValueInfo {
@@ -159,6 +160,12 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
/** The shape of the tensor. */
final long[] shape;
+ /** The names of the unbound dimensions. */
+ final String[] dimensionNames;
+
+ /** True if any dimension name is non-empty. */
+ private final boolean hasNames;
+
/** The Java type of this tensor. */
public final OnnxJavaType type;
@@ -177,6 +184,9 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
*/
TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) {
this.shape = shape;
+ this.dimensionNames = new String[shape.length];
+ Arrays.fill(dimensionNames, "");
+ this.hasNames = false;
this.type = type;
this.onnxType = onnxType;
this.numElements = elementCount(shape);
@@ -188,10 +198,20 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
* Called from JNI.
*
* @param shape The tensor shape.
+ * @param names The dimension names.
* @param typeInt The native type int.
*/
- TensorInfo(long[] shape, int typeInt) {
+ TensorInfo(long[] shape, String[] names, int typeInt) {
this.shape = shape;
+ this.dimensionNames = names;
+ boolean hasNames = false;
+ for (String s : names) {
+ if (!s.isEmpty()) {
+ hasNames = true;
+ break;
+ }
+ }
+ this.hasNames = hasNames;
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
this.numElements = elementCount(shape);
@@ -206,15 +226,42 @@ public long[] getShape() {
return Arrays.copyOf(shape, shape.length);
}
+ /**
+ * Get a copy of the tensor's named dimensions.
+ *
+ * @return A copy of the tensor's named dimensions.
+ */
+ public String[] getDimensionNames() {
+ return Arrays.copyOf(dimensionNames, dimensionNames.length);
+ }
+
@Override
public String toString() {
- return "TensorInfo(javaType="
- + type.toString()
- + ",onnxType="
- + onnxType.toString()
- + ",shape="
- + Arrays.toString(shape)
- + ")";
+ String output =
+ "TensorInfo(javaType="
+ + type.toString()
+ + ",onnxType="
+ + onnxType.toString()
+ + ",shape="
+ + Arrays.toString(shape);
+ if (hasNames) {
+ output =
+ output
+ + ",dimNames=["
+ + Arrays.stream(dimensionNames)
+ .map(
+ a -> {
+ if (a.isEmpty()) {
+ return "\"\"";
+ } else {
+ return a;
+ }
+ })
+ .collect(Collectors.joining(","))
+ + "]";
+ }
+ output = output + ")";
+ return output;
}
/**
diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c
index 879ba8a310618..7b26291581395 100644
--- a/java/src/main/native/OrtJniUtil.c
+++ b/java/src/main/native/OrtJniUtil.c
@@ -342,7 +342,6 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
if (code != ORT_OK) {
return NULL;
}
- //printf("numDim %d\n",numDim);
int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim);
code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
if (code != ORT_OK) {
@@ -358,12 +357,31 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
free(dimensions);
dimensions = NULL;
+ // Create the string array for the names.
+ const char** dimensionNames = (const char**) malloc(sizeof(char*)*numDim);
+ if (dimensionNames == NULL) {
+ throwOrtException(jniEnv, 1, "Not enough memory");
+ return NULL;
+ }
+ code = checkOrtStatus(jniEnv, api, api->GetSymbolicDimensions(info, dimensionNames, numDim));
+ if (code != ORT_OK) {
+ // extraction failed, exception has been thrown, return to Java.
+ free(dimensionNames);
+ return NULL;
+ }
+ jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String");
+ jobjectArray names = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numDim), stringClazz, NULL);
+ for (size_t i = 0; i < numDim; i++) {
+ jobject javaName = (*jniEnv)->NewStringUTF(jniEnv, dimensionNames[i]);
+ (*jniEnv)->SetObjectArrayElement(jniEnv, names, safecast_size_t_to_jsize(i), javaName);
+ }
+ free(dimensionNames);
+
// Create the TensorInfo object
static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo";
jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName);
- jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([JI)V");
- //printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor);
- jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, onnxTypeInt);
+ jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([J[Ljava/lang/String;I)V");
+ jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, names, onnxTypeInt);
return tensorInfo;
}
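
For reference, the symbolic dimension names consumed by the JNI code above come from the existing C API. A minimal C++ sketch, assuming info is a valid OrtTensorTypeAndShapeInfo* obtained from a type-and-shape query:

#include <onnxruntime_cxx_api.h>
#include <string>
#include <vector>

std::vector<std::string> GetDimNames(const OrtTensorTypeAndShapeInfo* info) {
  const OrtApi& api = Ort::GetApi();
  size_t num_dims = 0;
  Ort::ThrowOnError(api.GetDimensionsCount(info, &num_dims));
  std::vector<const char*> raw(num_dims, nullptr);
  Ort::ThrowOnError(api.GetSymbolicDimensions(info, raw.data(), num_dims));
  // Unnamed dimensions come back as empty strings, mirroring TensorInfo.dimensionNames on the Java side.
  return std::vector<std::string>(raw.begin(), raw.end());
}
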
diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index f6f9da1829402..7fef2dc784b7b 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -590,6 +590,12 @@ public void testSymbolicDimensionAssignment() throws OrtException {
Map<String, NodeInfo> infoMap = session.getInputInfo();
TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
assertArrayEquals(new long[] {-1, 2}, aInfo.shape);
+ assertEquals(2, aInfo.dimensionNames.length);
+ assertEquals("n", aInfo.dimensionNames[0]);
+ assertEquals("", aInfo.dimensionNames[1]);
+ TensorInfo bInfo = (TensorInfo) infoMap.get("B").getInfo();
+ assertEquals(1, bInfo.dimensionNames.length);
+ assertEquals("m", bInfo.dimensionNames[0]);
}
}
// Check that when the options are assigned it overrides the symbolic dimension
diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts
index 96c2361cceabe..40f970ddf02ae 100644
--- a/js/common/lib/version.ts
+++ b/js/common/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.17.0';
+export const version = '1.18.0';
diff --git a/js/common/package-lock.json b/js/common/package-lock.json
index 84f6dba83fa59..a5ada877b916a 100644
--- a/js/common/package-lock.json
+++ b/js/common/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-common",
- "version": "1.17.0",
+ "version": "1.18.0",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-common",
- "version": "1.17.0",
+ "version": "1.18.0",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.23.22"
diff --git a/js/common/package.json b/js/common/package.json
index beab7d29be263..64ab2736adbe3 100644
--- a/js/common/package.json
+++ b/js/common/package.json
@@ -2,7 +2,7 @@
"license": "MIT",
"type": "module",
"name": "onnxruntime-common",
- "version": "1.17.0",
+ "version": "1.18.0",
"repository": {
"url": "https://github.com/Microsoft/onnxruntime.git",
"type": "git"
diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts
index 96c2361cceabe..40f970ddf02ae 100644
--- a/js/node/lib/version.ts
+++ b/js/node/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.17.0';
+export const version = '1.18.0';
diff --git a/js/node/package-lock.json b/js/node/package-lock.json
index 542eebe746d59..2d7c39c86097f 100644
--- a/js/node/package-lock.json
+++ b/js/node/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-node",
- "version": "1.17.0",
+ "version": "1.18.0",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-node",
- "version": "1.17.0",
+ "version": "1.18.0",
"license": "MIT",
"os": [
"win32",
@@ -27,7 +27,7 @@
},
"../common": {
"name": "onnxruntime-common",
- "version": "1.17.0",
+ "version": "1.18.0",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.23.22"
diff --git a/js/node/package.json b/js/node/package.json
index 8e591d8f46b9d..026840742e29e 100644
--- a/js/node/package.json
+++ b/js/node/package.json
@@ -13,7 +13,7 @@
3
]
},
- "version": "1.17.0",
+ "version": "1.18.0",
"dependencies": {
"onnxruntime-common": "file:../common"
},
diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts
index 96c2361cceabe..40f970ddf02ae 100644
--- a/js/react_native/lib/version.ts
+++ b/js/react_native/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.17.0';
+export const version = '1.18.0';
diff --git a/js/react_native/package.json b/js/react_native/package.json
index 39e6cb08bb06a..47324a76fe55f 100644
--- a/js/react_native/package.json
+++ b/js/react_native/package.json
@@ -36,7 +36,7 @@
"registry": "https://registry.npmjs.org/"
},
"source": "lib/index",
- "version": "1.17.0",
+ "version": "1.18.0",
"main": "dist/commonjs/index",
"homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md",
"files": [
diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock
index ff9be7fbe3a5b..4dca90d7415cf 100644
--- a/js/react_native/yarn.lock
+++ b/js/react_native/yarn.lock
@@ -5254,7 +5254,7 @@ onetime@^5.1.0, onetime@^5.1.2:
mimic-fn "^2.1.0"
"onnxruntime-common@file:../common":
- version "1.17.0"
+ version "1.18.0"
open@^6.2.0:
version "6.4.0"
diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md
index 2f510308d9306..2557971eb4ded 100644
--- a/js/web/docs/webgpu-operators.md
+++ b/js/web/docs/webgpu-operators.md
@@ -52,6 +52,7 @@ Do not modify directly.*
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| GreaterOrEqual | ai.onnx(12-15,16+) | |
+| HardSigmoid | ai.onnx(6+) | |
| If | ai.onnx(1-10,11-12,13-18,19+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |
diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts
index d9f63fec9c492..31ecffb07e40c 100644
--- a/js/web/lib/backend-wasm.ts
+++ b/js/web/lib/backend-wasm.ts
@@ -31,6 +31,12 @@ export const initializeFlags = (): void => {
}
if (typeof env.wasm.numThreads !== 'number' || !Number.isInteger(env.wasm.numThreads) || env.wasm.numThreads <= 0) {
+ // Web: when crossOriginIsolated is false, SharedArrayBuffer is not available so WebAssembly threads will not work.
+ // Node.js: onnxruntime-web does not support multi-threads in Node.js.
+ if ((typeof self !== 'undefined' && !self.crossOriginIsolated) ||
+ (typeof process !== 'undefined' && process.versions && process.versions.node)) {
+ env.wasm.numThreads = 1;
+ }
const numCpuLogicalCores = typeof navigator === 'undefined' ? cpus().length : navigator.hardwareConcurrency;
env.wasm.numThreads = Math.min(4, Math.ceil((numCpuLogicalCores || 1) / 2));
}
diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts
index 96c2361cceabe..40f970ddf02ae 100644
--- a/js/web/lib/version.ts
+++ b/js/web/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.17.0';
+export const version = '1.18.0';
diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts
index 9d4d5875310b7..68054210e79a7 100644
--- a/js/web/lib/wasm/binding/ort-wasm.d.ts
+++ b/js/web/lib/wasm/binding/ort-wasm.d.ts
@@ -182,6 +182,10 @@ export interface OrtWasmModule extends EmscriptenModule {
jsepCreateDownloader:
(gpuBuffer: GPUBuffer, size: number,
type: Tensor.GpuBufferDataTypes) => () => Promise;
+ /**
+ * [exported from js_internal_api.js] Called when InferenceSession.run starts.
+ */
+ jsepOnRunStart: () => void;
// #endregion
}
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index 2956ec1cad4da..afef7042a4280 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -208,7 +208,7 @@ export class WebGpuBackend {
Object.defineProperty(this.env.webgpu, 'device', {value: this.device});
- // init queryType, which is necessary for createKernel
+ // init queryType, which is necessary for InferenceSession.create
this.setQueryType();
}
@@ -223,8 +223,6 @@ export class WebGpuBackend {
if (!this.commandEncoder) {
this.commandEncoder = this.device.createCommandEncoder();
- // refresh queryType, as sometimes we only need to enable query for a specific run
- this.setQueryType();
if (this.queryType !== 'none' && typeof this.querySet === 'undefined') {
this.querySet = this.device.createQuerySet({
type: 'timestamp',
@@ -639,6 +637,7 @@ export class WebGpuBackend {
return createView(data.buffer, type);
};
}
+ // #endregion
writeTimestamp(index: number): void {
if (this.queryType !== 'inside-passes') {
return;
@@ -657,5 +656,7 @@ export class WebGpuBackend {
}
}
}
- // #endregion
+ onRunStart(): void {
+ this.setQueryType();
+ }
}
diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
index 90e02da986b8f..cc504093ca0d7 100644
--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -82,6 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
+ ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
['InstanceNormalization', [instanceNorm]],
['LayerNormalization', [layerNorm]],
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
index 30754c84413b7..a0d4021516bf7 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
@@ -100,8 +100,8 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
${calculateAlpha}
${(() => {
if (c != null) {
- return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += uniforms.beta * ${
- c.getByOffset('cOffset')};`;
+ return `let cOffset = ${c.broadcastedIndicesToOffset('vec2(m, n)', output)}; value += ${
+ dataType}(uniforms.beta) * ${c.getByOffset('cOffset')};`;
}
return '';
})()}
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
index a25e7fe4229b4..82311d72e58b9 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -242,6 +242,26 @@ export const sigmoid = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`));
};
+export interface HardSigmoidAttributes extends AttributeWithCacheKey {
+ readonly alpha: number;
+ readonly beta: number;
+}
+
+export const parseHardSigmoidAttributes = (attributes: Record<string, unknown>): HardSigmoidAttributes =>
+ createAttributeWithCacheKey(attributes as {
+ alpha: number;
+ beta: number;
+ });
+
+export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => {
+ const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
+ context.compute(createElementwiseProgramInfo(
+ context.inputs[0], 'HardSigmoid',
+ a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${
+ attributes.beta})))`,
+ undefined, attributes.cacheKey));
+};
+
export const sin = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin'));
};
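
The HardSigmoid kernel added above mirrors the standard ONNX definition max(0, min(1, alpha * x + beta)), with defaults alpha = 0.2 and beta = 0.5. A scalar C++ reference sketch of that formula (not the WGSL kernel itself):

#include <algorithm>

// Scalar reference for HardSigmoid; ONNX defaults are alpha = 0.2, beta = 0.5.
float HardSigmoid(float x, float alpha = 0.2f, float beta = 0.5f) {
  return std::max(0.0f, std::min(1.0f, alpha * x + beta));
}
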
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 5821fac3c468f..8768643fa7257 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -488,8 +488,8 @@ export const run = async(
}
}
+ wasm.jsepOnRunStart?.();
let errorCode: number;
-
if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) {
errorCode = await wasm._OrtRunWithBinding(
sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle);
diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts
index 81508a253ce8b..9b9334c93b78c 100644
--- a/js/web/lib/wasm/wasm-factory.ts
+++ b/js/web/lib/wasm/wasm-factory.ts
@@ -28,13 +28,34 @@ let initialized = false;
let initializing = false;
let aborted = false;
-const isMultiThreadSupported = (): boolean => {
- try {
- // If 'SharedArrayBuffer' is not available, WebAssembly threads will not work.
- if (typeof SharedArrayBuffer === 'undefined') {
- return false;
+const isMultiThreadSupported = (numThreads: number): boolean => {
+ // If numThreads is set to 1, multi-threading is explicitly disabled (single thread).
+ if (numThreads === 1) {
+ return false;
+ }
+
+ // If 'SharedArrayBuffer' is not available, WebAssembly threads will not work.
+ if (typeof SharedArrayBuffer === 'undefined') {
+ if (typeof self !== 'undefined' && !self.crossOriginIsolated) {
+ // eslint-disable-next-line no-console
+ console.warn(
+ 'env.wasm.numThreads is set to ' + numThreads +
+ ', but this will not work unless you enable crossOriginIsolated mode. ' +
+ 'See https://web.dev/cross-origin-isolation-guide/ for more info.');
}
+ return false;
+ }
+
+ // onnxruntime-web does not support multi-threads in Node.js.
+ if (typeof process !== 'undefined' && process.versions && process.versions.node) {
+ // eslint-disable-next-line no-console
+ console.warn(
+ 'env.wasm.numThreads is set to ' + numThreads +
+ ', however, currently onnxruntime-web does not support multi-threads in Node.js. ' +
+ 'Please consider using onnxruntime-node for performance critical scenarios.');
+ }
+ try {
// Test for transferability of SABs (for browsers. needed for Firefox)
// https://groups.google.com/forum/#!msg/mozilla.dev.platform/IHkBZlHETpA/dwsMNchWEQAJ
if (typeof MessageChannel !== 'undefined') {
@@ -106,7 +127,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise
const numThreads = flags.numThreads!;
const simd = flags.simd!;
- const useThreads = numThreads > 1 && isMultiThreadSupported();
+ const useThreads = isMultiThreadSupported(numThreads);
const useSimd = simd && isSimdSupported();
const wasmPaths = flags.wasmPaths;
diff --git a/js/web/lib/wasm/wasm-utils-load-file.ts b/js/web/lib/wasm/wasm-utils-load-file.ts
index abe480a43c790..c6cdba2320bde 100644
--- a/js/web/lib/wasm/wasm-utils-load-file.ts
+++ b/js/web/lib/wasm/wasm-utils-load-file.ts
@@ -47,9 +47,19 @@ export const loadFile = async(file: string|Blob|ArrayBufferLike|Uint8Array): Pro
}
const reader = response.body.getReader();
- // use WebAssembly Memory to allocate larger ArrayBuffer
- const pages = Math.ceil(fileSize / 65536);
- const buffer = new WebAssembly.Memory({initial: pages, maximum: pages}).buffer;
+ let buffer;
+ try {
+ // try to create ArrayBuffer directly
+ buffer = new ArrayBuffer(fileSize);
+ } catch (e) {
+ if (e instanceof RangeError) {
+ // use WebAssembly Memory to allocate larger ArrayBuffer
+ const pages = Math.ceil(fileSize / 65536);
+ buffer = new WebAssembly.Memory({initial: pages, maximum: pages}).buffer;
+ } else {
+ throw e;
+ }
+ }
let offset = 0;
// eslint-disable-next-line no-constant-condition
diff --git a/js/web/package-lock.json b/js/web/package-lock.json
index cd71c20ba4d2f..41c44aaa2679b 100644
--- a/js/web/package-lock.json
+++ b/js/web/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-web",
- "version": "1.17.0",
+ "version": "1.18.0",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-web",
- "version": "1.17.0",
+ "version": "1.18.0",
"license": "MIT",
"dependencies": {
"flatbuffers": "^1.12.0",
@@ -28,7 +28,7 @@
"@webgpu/types": "^0.1.38",
"base64-js": "^1.5.1",
"chai": "^4.3.7",
- "electron": "^23.1.2",
+ "electron": "^28.1.4",
"globby": "^13.1.3",
"karma": "^6.4.1",
"karma-browserstack-launcher": "^1.6.0",
@@ -49,7 +49,7 @@
},
"../common": {
"name": "onnxruntime-common",
- "version": "1.17.0",
+ "version": "1.18.0",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.23.22"
@@ -862,9 +862,9 @@
}
},
"node_modules/cross-spawn/node_modules/semver": {
- "version": "5.7.1",
- "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz",
- "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==",
+ "version": "5.7.2",
+ "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz",
+ "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==",
"dev": true,
"bin": {
"semver": "bin/semver"
@@ -1042,14 +1042,14 @@
"dev": true
},
"node_modules/electron": {
- "version": "23.3.13",
- "resolved": "https://registry.npmjs.org/electron/-/electron-23.3.13.tgz",
- "integrity": "sha512-BaXtHEb+KYKLouUXlUVDa/lj9pj4F5kiE0kwFdJV84Y2EU7euIDgPthfKtchhr5MVHmjtavRMIV/zAwEiSQ9rQ==",
+ "version": "28.1.4",
+ "resolved": "https://registry.npmjs.org/electron/-/electron-28.1.4.tgz",
+ "integrity": "sha512-WE6go611KOhtH6efRPMnVC7FE7DCKnQ3ZyHFeI1DbaCy8OU4UjZ8/CZGcuZmZgRdxSBEHoHdgaJkWRHZzF0FOg==",
"dev": true,
"hasInstallScript": true,
"dependencies": {
"@electron/get": "^2.0.0",
- "@types/node": "^16.11.26",
+ "@types/node": "^18.11.18",
"extract-zip": "^2.0.1"
},
"bin": {
@@ -1059,12 +1059,6 @@
"node": ">= 12.20.55"
}
},
- "node_modules/electron/node_modules/@types/node": {
- "version": "16.18.14",
- "resolved": "https://registry.npmjs.org/@types/node/-/node-16.18.14.tgz",
- "integrity": "sha512-wvzClDGQXOCVNU4APPopC2KtMYukaF1MN/W3xAmslx22Z4/IF1/izDMekuyoUlwfnDHYCIZGaj7jMwnJKBTxKw==",
- "dev": true
- },
"node_modules/emoji-regex": {
"version": "8.0.0",
"resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz",
@@ -1432,9 +1426,9 @@
}
},
"node_modules/get-func-name": {
- "version": "2.0.0",
- "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.0.tgz",
- "integrity": "sha512-Hm0ixYtaSZ/V7C8FJrtZIuBBI+iSgL+1Aq82zSu8VQNB4S3Gk8e7Qs3VwBDJAhmRZcFqkl3tQu36g/Foh5I5ig==",
+ "version": "2.0.2",
+ "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.2.tgz",
+ "integrity": "sha512-8vXOvuE167CtIc3OyItco7N/dpRtBbYOsPsXCz7X/PMnlGjYjSGuZJgM1Y7mmew7BKf9BqvLX2tnOVy1BBUsxQ==",
"dev": true,
"engines": {
"node": "*"
@@ -1542,9 +1536,9 @@
}
},
"node_modules/global-agent/node_modules/semver": {
- "version": "7.3.8",
- "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz",
- "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==",
+ "version": "7.5.4",
+ "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz",
+ "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==",
"dev": true,
"optional": true,
"dependencies": {
@@ -2908,9 +2902,9 @@
"dev": true
},
"node_modules/semver": {
- "version": "6.3.0",
- "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz",
- "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==",
+ "version": "6.3.1",
+ "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz",
+ "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==",
"dev": true,
"bin": {
"semver": "bin/semver.js"
@@ -4203,9 +4197,9 @@
},
"dependencies": {
"semver": {
- "version": "5.7.1",
- "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz",
- "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==",
+ "version": "5.7.2",
+ "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz",
+ "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==",
"dev": true
}
}
@@ -4339,22 +4333,14 @@
"dev": true
},
"electron": {
- "version": "23.3.13",
- "resolved": "https://registry.npmjs.org/electron/-/electron-23.3.13.tgz",
- "integrity": "sha512-BaXtHEb+KYKLouUXlUVDa/lj9pj4F5kiE0kwFdJV84Y2EU7euIDgPthfKtchhr5MVHmjtavRMIV/zAwEiSQ9rQ==",
+ "version": "28.1.4",
+ "resolved": "https://registry.npmjs.org/electron/-/electron-28.1.4.tgz",
+ "integrity": "sha512-WE6go611KOhtH6efRPMnVC7FE7DCKnQ3ZyHFeI1DbaCy8OU4UjZ8/CZGcuZmZgRdxSBEHoHdgaJkWRHZzF0FOg==",
"dev": true,
"requires": {
"@electron/get": "^2.0.0",
- "@types/node": "^16.11.26",
+ "@types/node": "^18.11.18",
"extract-zip": "^2.0.1"
- },
- "dependencies": {
- "@types/node": {
- "version": "16.18.14",
- "resolved": "https://registry.npmjs.org/@types/node/-/node-16.18.14.tgz",
- "integrity": "sha512-wvzClDGQXOCVNU4APPopC2KtMYukaF1MN/W3xAmslx22Z4/IF1/izDMekuyoUlwfnDHYCIZGaj7jMwnJKBTxKw==",
- "dev": true
- }
}
},
"emoji-regex": {
@@ -4657,9 +4643,9 @@
"dev": true
},
"get-func-name": {
- "version": "2.0.0",
- "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.0.tgz",
- "integrity": "sha512-Hm0ixYtaSZ/V7C8FJrtZIuBBI+iSgL+1Aq82zSu8VQNB4S3Gk8e7Qs3VwBDJAhmRZcFqkl3tQu36g/Foh5I5ig==",
+ "version": "2.0.2",
+ "resolved": "https://registry.npmjs.org/get-func-name/-/get-func-name-2.0.2.tgz",
+ "integrity": "sha512-8vXOvuE167CtIc3OyItco7N/dpRtBbYOsPsXCz7X/PMnlGjYjSGuZJgM1Y7mmew7BKf9BqvLX2tnOVy1BBUsxQ==",
"dev": true
},
"get-intrinsic": {
@@ -4742,9 +4728,9 @@
},
"dependencies": {
"semver": {
- "version": "7.3.8",
- "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz",
- "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==",
+ "version": "7.5.4",
+ "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz",
+ "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==",
"dev": true,
"optional": true,
"requires": {
@@ -5780,9 +5766,9 @@
"dev": true
},
"semver": {
- "version": "6.3.0",
- "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.0.tgz",
- "integrity": "sha512-b39TBaTSfV6yBrapU89p5fKekE2m/NwnDocOVruQFS1/veMgdzuPcnOM34M6CwxW8jH/lxEa5rBoDeUwu5HHTw==",
+ "version": "6.3.1",
+ "resolved": "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz",
+ "integrity": "sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==",
"dev": true
},
"semver-compare": {
diff --git a/js/web/package.json b/js/web/package.json
index 7ffc9ba16aaa9..a502c2b6b032d 100644
--- a/js/web/package.json
+++ b/js/web/package.json
@@ -8,7 +8,7 @@
"type": "git"
},
"author": "fs-eire",
- "version": "1.17.0",
+ "version": "1.18.0",
"jsdelivr": "dist/ort.min.js",
"dependencies": {
"flatbuffers": "^1.12.0",
@@ -47,7 +47,7 @@
"@webgpu/types": "^0.1.38",
"base64-js": "^1.5.1",
"chai": "^4.3.7",
- "electron": "^23.1.2",
+ "electron": "^28.1.4",
"globby": "^13.1.3",
"karma": "^6.4.1",
"karma-browserstack-launcher": "^1.6.0",
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index 033b3b3f4b0f5..373b3c645df57 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -597,9 +597,9 @@
// // "test_hardmax_example",
// // "test_hardmax_negative_axis",
// // "test_hardmax_one_hot",
- // // "test_hardsigmoid_default",
- // // "test_hardsigmoid_example",
- // // "test_hardsigmoid",
+ "test_hardsigmoid_default",
+ "test_hardsigmoid_example",
+ "test_hardsigmoid",
// // "test_hardswish_expanded",
// // "test_hardswish",
"test_if",
diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index 57219c50f39aa..c3699f0fb33ad 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -7,7 +7,7 @@
For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_
or the `Github project `_.
"""
-__version__ = "1.17.0"
+__version__ = "1.18.0"
__author__ = "Microsoft"
# we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index 4711ccf487cc8..768676259aa14 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -211,6 +211,12 @@ Status Attention::Compute(OpKernelContext* context) const {
relative_position_bias,
¶meters));
+ if (parameters.do_rotary) {
+ ORT_NOT_IMPLEMENTED(
+ "Rotary embedding is not supported in Attention CPU kernel. \
+ Please fuse the model with MHA + RotaryEmbedding.");
+ }
+
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int input_hidden_size = parameters.input_hidden_size;
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index 694c40bf3eda6..eb25d0fd7cc1e 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -40,6 +40,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i
  num_heads_ = static_cast<int>(num_heads);
  mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
+  is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
}
// Reshape Q/K/V from BxSxD to BxSxNxH
@@ -283,8 +284,9 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const {
nullptr,
¶meters,
num_heads_,
- scale,
mask_filter_value_,
+ scale,
+ is_unidirectional_,
past_present_share_buffer,
false));
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
index 4c86b777e9842..fb7da78a5c0a5 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
@@ -18,6 +18,7 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase {
protected:
int num_heads_; // number of attention heads
float mask_filter_value_;
+ bool is_unidirectional_;
};
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
index 00e82c9844b3d..c91f5b601b4e9 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -25,6 +25,7 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
float scale,
+ bool is_unidirectional,
bool past_present_share_buffer,
bool dmmha_packing) {
// key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None
@@ -315,7 +316,7 @@ Status CheckInputs(const T* query,
output_parameters->head_size = hidden_size / num_heads;
output_parameters->v_head_size = v_hidden_size / num_heads;
output_parameters->num_heads = num_heads;
- output_parameters->is_unidirectional = false;
+ output_parameters->is_unidirectional = is_unidirectional;
output_parameters->past_present_share_buffer = past_present_share_buffer;
output_parameters->mask_filter_value = mask_filter_value;
output_parameters->mask_type = mask_type;
@@ -342,6 +343,7 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
float scale,
+ bool is_unidirectional,
bool past_present_share_buffer,
bool dmmha_packing,
int max_threads_per_block) {
@@ -350,8 +352,8 @@ Status CheckInputs(const T* query,
}
return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value,
- past_seq_len, parameters, num_heads, mask_filter_value, scale, past_present_share_buffer,
- dmmha_packing);
+ past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional,
+ past_present_share_buffer, dmmha_packing);
}
} // namespace multihead_attention_helper
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
index 47f462d75fcc4..aa8b5b5f608fa 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -27,7 +27,13 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
  scale = info.GetAttrOrDefault<float>("scale", 1.0);
+  rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
+  num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
  interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
+
+ if (rotary_embedding_dim > 0) {
+ ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified");
+ }
}
template <typename T>
@@ -42,6 +48,8 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const {
position_ids,
cos_cache,
sin_cache,
+ num_heads,
+ rotary_embedding_dim,
¶meters));
Tensor* output = context->Output(0, input->Shape());
@@ -59,61 +67,66 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
- const int num_heads = parameters.num_heads;
+ const int n_heads = parameters.num_heads;
const int head_size = parameters.head_size;
const int position_ids_format = parameters.position_ids_format;
- const int half_head_size = head_size / 2;
+ const int rotary_emb_dim = parameters.rotary_embedding_dim;
+ const int half_rotary_emb_dim = rotary_emb_dim / 2;
+
// Default input tensor shape is [batch, seq_len, hidden_size]
int head_stride = head_size;
- int seq_stride = num_heads * head_stride;
+ int seq_stride = n_heads * head_stride;
int batch_stride = sequence_length * seq_stride;
if (parameters.transposed) {
- // Transposed input tensor shape is [batch, num_heads, seq_len, head_size]
+ // Transposed input tensor shape is [batch, n_heads, seq_len, head_size]
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
- batch_stride = num_heads * head_stride;
+ batch_stride = n_heads * head_stride;
}
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto* tp = context->GetOperatorThreadPool();
- const int loop_len = batch_size * sequence_length * num_heads;
-  const double cost = static_cast<double>(head_size);
+  const int loop_len = batch_size * sequence_length * n_heads;
+  const double cost = static_cast<double>(rotary_emb_dim);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
-      const int b = static_cast<int>((ptr / num_heads) / sequence_length);
-      const int s = static_cast<int>((ptr / num_heads) % sequence_length);
-      const int n = static_cast<int>(ptr % num_heads);
+      const int b = static_cast<int>((ptr / n_heads) / sequence_length);
+      const int s = static_cast<int>((ptr / n_heads) % sequence_length);
+      const int n = static_cast<int>(ptr % n_heads);
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
const T* input_data = input_src + block_offset;
T* output_data = output_dest + block_offset;
- // Cache is (M, H/2)
+ // Cache is (M, H/2) or (M, rotary_embedding_dim/2)
const int position_id = (position_ids_format == 0)
                                   ? static_cast<int>(pos_ids_data[0]) + s
                                   : static_cast<int>(pos_ids_data[b * sequence_length + s]);
- const int cache_offset = position_id * half_head_size;
+ const int cache_offset = position_id * half_rotary_emb_dim;
const T* cos_data = cos_cache_data + cache_offset;
const T* sin_data = sin_cache_data + cache_offset;
int cache_idx = 0;
T sign = 0;
int j = 0;
- for (int i = 0; i < head_size; i++) {
+ for (int i = 0; i < rotary_emb_dim; i++) {
if (interleaved) {
- cache_idx = (i / 2) % half_head_size;
+ cache_idx = (i / 2) % half_rotary_emb_dim;
          sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
} else {
- cache_idx = i % half_head_size;
-          sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1);
- j = (i + half_head_size) % head_size;
+ cache_idx = i % half_rotary_emb_dim;
+          sign = (i < half_rotary_emb_dim) ? static_cast<T>(-1) : static_cast<T>(1);
+ j = (i + half_rotary_emb_dim) % rotary_emb_dim;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
+ for (int i = rotary_emb_dim; i < head_size; i++) {
+ output_data[i] = input_data[i];
+ }
}
});
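
A condensed C++ sketch of the per-head math above for the non-interleaved case, illustrating that only the first rotary_embedding_dim elements are rotated while the remaining head elements are copied through. This is a standalone reference under those assumptions, not the kernel itself:

// input/output point at one head of head_size elements; cos_cache/sin_cache point at the
// cache row for this position (rotary_dim / 2 entries each).
void RotateOneHead(const float* input, float* output,
                   const float* cos_cache, const float* sin_cache,
                   int head_size, int rotary_dim) {
  const int half = rotary_dim / 2;
  for (int i = 0; i < rotary_dim; ++i) {
    const int cache_idx = i % half;
    const float sign = (i < half) ? -1.0f : 1.0f;
    const int j = (i + half) % rotary_dim;  // index of the paired element
    output[i] = input[i] * cos_cache[cache_idx] + sign * input[j] * sin_cache[cache_idx];
  }
  for (int i = rotary_dim; i < head_size; ++i) {
    output[i] = input[i];  // elements beyond rotary_dim pass through unchanged
  }
}
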
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
index be834a66cdc69..4e32424a22b6c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
@@ -16,6 +16,8 @@ class RotaryEmbedding final : public OpKernel {
protected:
float scale;
+ int num_heads;
+ int rotary_embedding_dim;
bool interleaved;
};
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
index 7b2e8289f7b06..dcbb36d1c4a3c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
@@ -11,14 +11,15 @@ namespace rotary_embedding_helper {
// Parameters deduced from node attributes and inputs/outputs.
struct RotaryParameters {
- int batch_size; // Batch size used by input
- int sequence_length; // Sequence length used by input
- int hidden_size; // Hidden size used by input
- int head_size; // Head size used by cos/sin cache * 2
- int num_heads; // num_heads = hidden_size / head_size
- int max_sequence_length; // Sequence length used by cos/sin cache
- int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
- bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
+ int batch_size; // Batch size used by input
+ int sequence_length; // Sequence length used by input
+ int hidden_size; // Hidden size used by input
+ int head_size; // Head size
+ int rotary_embedding_dim; // Rotary embedding dimension.
+ int num_heads; // num_heads = hidden_size / head_size
+ int max_sequence_length; // Sequence length used by cos/sin cache
+ int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
+ bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
};
template <typename T>
@@ -26,11 +27,13 @@ Status CheckInputs(const T* input,
const T* position_ids,
const T* cos_cache,
const T* sin_cache,
+ int num_heads,
+ int rotary_embedding_dim,
void* parameters) {
// input : (batch_size, sequence_length, hidden_size)
// position ids : (1) or (batch_size, sequence_length)
- // cos cache : (max_sequence_length, head_size / 2)
- // sin cache : (max_sequence_length, head_size / 2)
+ // cos cache : (max_sequence_length, rotary_embedding_dim / 2)
+ // sin cache : (max_sequence_length, rotary_embedding_dim / 2)
// Check input
const auto& input_dims = input->Shape().GetDims();
@@ -60,6 +63,12 @@ Status CheckInputs(const T* input,
"the same shape");
}
+ // Check num_heads and rotary_embedding_dim
+ if (rotary_embedding_dim > 0 && num_heads == 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ",
+ "specified");
+ }
+
// Get attributes from inputs
  int batch_size = static_cast<int>(input_dims[0]);
  int sequence_length = static_cast<int>(input_dims[1]);
@@ -73,8 +82,13 @@ Status CheckInputs(const T* input,
transposed = true;
}
  int max_sequence_length = static_cast<int>(cos_cache_dims[0]);
-  int head_size = static_cast<int>(cos_cache_dims[1]) * 2;
-  int num_heads = hidden_size / head_size;
+  int head_size = rotary_embedding_dim == 0 ? static_cast<int>(cos_cache_dims[1]) * 2
+                                            : static_cast<int>(hidden_size / num_heads);
+ if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ",
+ "head_size");
+ }
+
int position_ids_format = -1;
// Check position_ids input shapes
@@ -91,23 +105,15 @@ Status CheckInputs(const T* input,
} else {
position_ids_format = 0;
}
+
// Check cos_cache input shapes
  if (max_sequence_length != static_cast<int>(cos_cache_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ",
"max_sequence_length, got ", cos_cache_dims[0]);
}
-  if ((head_size / 2) != static_cast<int>(cos_cache_dims[1])) {
+  if ((head_size / 2) != static_cast<int>(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast<int>(cos_cache_dims[1]))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ",
- "head_size / 2, got ", cos_cache_dims[1]);
- }
- // Check sin_cache input shapes
-  if (max_sequence_length != static_cast<int>(sin_cache_dims[0])) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ",
- "max_sequence_length, got ", sin_cache_dims[0]);
- }
-  if ((head_size / 2) != static_cast<int>(sin_cache_dims[1])) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ",
- "head_size / 2, got ", sin_cache_dims[1]);
+ "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]);
}
// Set rotary parameters
@@ -117,10 +123,11 @@ Status CheckInputs(const T* input,
output_parameters->sequence_length = sequence_length;
output_parameters->hidden_size = hidden_size;
output_parameters->head_size = head_size;
- output_parameters->num_heads = num_heads;
+  output_parameters->num_heads = num_heads > 0 ? num_heads : static_cast<int>(hidden_size / head_size);
output_parameters->max_sequence_length = max_sequence_length;
output_parameters->position_ids_format = position_ids_format;
output_parameters->transposed = transposed;
+ output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size;
}
return Status::OK();
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index 406c73c95d444..72948c74d7877 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -9,6 +9,9 @@
#include "core/mlas/inc/mlas_q4.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"
+#ifdef ORT_NEURAL_SPEED
+#include "contrib_ops/cpu/quantization/neural_speed_gemm.h"
+#endif
namespace onnxruntime {
namespace contrib {
@@ -24,15 +27,17 @@ class MatMulNBits final : public OpKernel {
accuracy_level_{info.GetAttr("accuracy_level")} {
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
- is_asym_ = info.GetInputCount() >= 4;
+#ifdef ORT_NEURAL_SPEED
const Tensor* tensor_B = nullptr;
const Tensor* tensor_scale = nullptr;
const Tensor* tensor_zero_point = nullptr;
bool B_constant = info.TryGetConstantInput(1, &tensor_B);
bool scale_constant = info.TryGetConstantInput(2, &tensor_scale);
bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point);
+ is_asym_ = info.GetInputCount() >= 4;
all_constant_ = B_constant && scale_constant;
all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_;
+#endif
}
Status Compute(OpKernelContext* context) const override;
@@ -53,30 +58,34 @@ class MatMulNBits final : public OpKernel {
const bool column_wise_quant_{true};
IAllocatorUniquePtr packed_b_;
size_t packed_b_size_{0};
+#ifdef ORT_NEURAL_SPEED
bool is_asym_{false};
bool all_constant_{false};
+#endif
};
Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;
+#ifdef ORT_NEURAL_SPEED
if (!all_constant_) {
return Status::OK();
}
-
-#if defined(MLAS_JBLAS)
-
- auto compt_type = static_cast(accuracy_level_);
MLAS_THREADPOOL* pool = NULL;
+ if (nbits_ != 4) {
+ return Status::OK();
+ }
+ auto comp_type = static_cast(accuracy_level_);
+ auto nbits = static_cast(nbits_);
if (input_idx == 1) {
- packed_b_size_ = MlasNBitsGemmPackBSize(N_, K_, block_size_, static_cast(nbits_), is_asym_, compt_type);
+ packed_b_size_ = NSNBitsGemmPackBSize(N_, K_, block_size_, nbits, is_asym_, comp_type);
if (packed_b_size_ == 0) return Status::OK();
auto qptr = tensor.Data();
packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true);
std::memset(packed_b_.get(), 0, packed_b_size_);
- MlasNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_),
- is_asym_, false, compt_type, pool);
+ NSNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, false,
+ comp_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
@@ -85,8 +94,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
}
if (input_idx == 2 && packed_b_ != nullptr) {
auto sptr = tensor.Data();
- MlasNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, static_cast(nbits_),
- is_asym_, !is_asym_, compt_type, pool);
+ NSNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, !is_asym_,
+ comp_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
@@ -95,8 +104,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
}
if (input_idx == 3 && packed_b_ != nullptr) {
auto zptr = tensor.Data();
- MlasNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, static_cast(nbits_),
- is_asym_, is_asym_, compt_type, pool);
+ NSNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, nbits, is_asym_, is_asym_,
+ comp_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
@@ -104,7 +113,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
is_packed = true;
}
-#else // defined(MLAS_JBLAS)
+#else // defined(ORT_NEURAL_SPEED)
if (input_idx == 1) {
packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_);
@@ -119,7 +128,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
is_packed = true;
}
-#endif // defined(MLAS_JBLAS)
+#endif // defined(ORT_NEURAL_SPEED)
return Status::OK();
}
@@ -127,9 +136,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;
-
-#if defined(MLAS_JBLAS)
-
+#ifdef ORT_NEURAL_SPEED
// Pack three tensors into one buffer
if (input_idx == 1) {
used_shared_buffers = true;
@@ -144,14 +151,14 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep
packed_b_ = std::move(prepacked_buffers[0]);
}
-#else // defined(MLAS_JBLAS)
+#else // defined(ORT_NEURAL_SPEED)
if (input_idx == 1) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
-#endif // defined(MLAS_JBLAS)
+#endif // defined(ORT_NEURAL_SPEED)
return Status::OK();
}
@@ -160,9 +167,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const Tensor* a = ctx->Input(0);
const auto* a_data = a->Data();
-
-#if defined(MLAS_JBLAS)
-
+#ifdef ORT_NEURAL_SPEED
if (packed_b_.get()) {
TensorShape b_shape({static_cast(N_), static_cast(K_)});
@@ -181,7 +186,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const size_t N = static_cast(helper.N());
const size_t K = static_cast(helper.K());
const size_t lda = helper.Lda(false);
- std::vector gemm_params(max_len);
+ std::vector gemm_params(max_len);
AllocatorPtr allocator;
auto status = ctx->GetTempSpaceAllocator(&allocator);
ORT_RETURN_IF_ERROR(status);
@@ -192,15 +197,14 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
gemm_params[i].C = y_data + helper.OutputOffsets()[i];
gemm_params[i].ldc = N;
}
- auto ws_size = MlasSQNBitsGemmBatchPackedBWorkspaceSize(M, N, K, max_len, gemm_params.data());
+ auto ws_size = NSSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data());
// workspace for activation process(dynamic quantization and others)
auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size);
- MlasSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(),
- thread_pool);
+ NSSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), thread_pool);
return Status::OK();
}
-#endif // defined(MLAS_JBLAS)
+#endif // defined(ORT_NEURAL_SPEED)
const Tensor* scales = ctx->Input(2);
const Tensor* zero_points = ctx->Input(3);
diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h
new file mode 100644
index 0000000000000..864abffd131fe
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h
@@ -0,0 +1,45 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+--*/
+
+#pragma once
+
+#include "contrib_ops/cpu/quantization/neural_speed_wrapper.h"
+
+namespace bestla {
+
+using tAVX512F = gemm::SCoreRowNAvx512f<48, 8>;
+using tAMX_BF16 = gemm::HCoreRowNAmxbf16<64, 16>;
+using tAVX512_FP16 = gemm::HCoreRowNAvx512fp16<96, 8>;
+using tAVX_VNNI = gemm::ICoreRowNAvxvnni<24, 4>;
+using tAVX512_VNNI = gemm::ICoreRowNAvx512vnni<48, 8>;
+using tAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>;
+using tAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>;
+using tAVX2 = gemm::SCoreRowNAvx2<24, 4>;
+using tAVX_VNNI_KBlock = gemm::ICoreRowNAvxvnniKBlock<24, 2>;
+using tAVX512_VNNI_KBlock = gemm::ICoreRowNAvx512vnniKBlock<48, 4>;
+using tAMX_INT8_US_KBlock = gemm::ICoreRowNAmxint8KBlock<48, 16>;
+using tAMX_INT8_SS_KBlock = gemm::ICoreRowNAmxint8SSKBlock<48, 16>;
+
+template
+using tWeiNInt = prologue_b::gemm::WeightKBlockNInteger;
+template
+using tWeiNFloat = prologue_b::gemm::WeightKBlockNFloat;
+
+class ORTThreading : public parallel::IThreading {
+ public:
+ explicit ORTThreading(void* tp);
+ void parallel_for(const parallel::thread_func& func) const override;
+ void set_threads(int nthreads) override {
+ (void)(nthreads);
+ assert(0);
+ }
+ void sync() const override { assert(0); }
+ void* mTp;
+};
+
+} // namespace bestla
diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc
new file mode 100644
index 0000000000000..73aaa4ae61a6e
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc
@@ -0,0 +1,438 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ neural_speed_gemm.cc
+
+Abstract:
+
+ GEMM template combinations of neural_speed.
+--*/
+
+#include "contrib_ops/cpu/quantization/neural_speed_defs.h"
+#include "contrib_ops/cpu/quantization/neural_speed_gemm.h"
+#include "core/platform/threadpool.h"
+
+using ThreadPool = onnxruntime::concurrency::ThreadPool;
+
+namespace bestla {
+
+ORTThreading::ORTThreading(void* tp)
+ : IThreading(ThreadPool::DegreeOfParallelism(reinterpret_cast<ThreadPool*>(tp))), mTp(tp) {}
+
+void ORTThreading::parallel_for(const parallel::thread_func& func) const {
+ ThreadPool::TrySimpleParallelFor(reinterpret_cast<ThreadPool*>(mTp), mThreadNum,
+ [&](ptrdiff_t tid) { func(static_cast<int>(tid)); });
+}
+
+template
+static void NSSQ4GemmCompF32(size_t M, size_t N, size_t K, const float* A, size_t lda,
+ storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace,
+ parallel::IThreading* th) {
+ auto M_ = static_cast(M);
+ auto N_ = static_cast(N);
+ auto K_ = static_cast(K);
+ auto lda_ = static_cast(lda);
+ auto ldc_ = static_cast(ldc);
+ utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize);
+ if (M <= 16) {
+ using Parallel = parallel::gemm::SchedulerKBlock;
+ using Launcher =
+ wrapper::gemm::LauncherKBlock;
+ static Launcher kernel;
+ auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize);
+ if (B->IsAsym()) {
+ reduceA.assign(WorkSpace);
+ ORTThreading single(nullptr);
+ kernel.mProA.reduce({A, lda_, &reduceA}, M_, K_, B->mBlockSize, &single);
+ }
+ typename Launcher::Param args{gp,
+ {A, lda_, &reduceA},
+ {B},
+ {B->template SPtr(), B->SDtype(), B->CStep(), B->template ZPtr(),
+ reduceA.template RPtr(), reduceA.lda},
+ {C, ldc_, nullptr}};
+ parallel::GemmRun(kernel, args, th);
+ } else {
+ using Parallel = parallel::gemm::SchedulerBase;
+ using Launcher =
+ wrapper::gemm::LauncherBase;
+ static Launcher kernel;
+ typename Launcher::Param args{gp, {A, lda_}, {B}, {C, ldc_, nullptr}};
+ parallel::GemmRun(kernel, args, th);
+ }
+}
+
+template
+static void NSSQ4GemmCompInt8(size_t M, size_t N, size_t K, const float* A, size_t lda,
+ storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace,
+ parallel::IThreading* th) {
+ using Parallel = parallel::gemm::SchedulerKBlockS;
+ using Launcher =
+ wrapper::gemm::LauncherIntKBlock;
+ auto M_ = static_cast(M);
+ auto N_ = static_cast(N);
+ auto K_ = static_cast(K);
+ auto lda_ = static_cast(lda);
+ auto ldc_ = static_cast(ldc);
+ static Launcher kernel;
+ auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->IsAsym());
+ quanA.assign(WorkSpace);
+ if (M <= 16) {
+ ORTThreading single(nullptr);
+ kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single);
+ } else {
+ kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th);
+ }
+ utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize);
+ typename Launcher::Param args{gp, {A, lda_, &quanA}, {B}, {C, ldc_, nullptr}};
+ parallel::GemmRun(kernel, args, th);
+}
+
+template
+static size_t NSSQ4GemmCompF32WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda,
+ storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) {
+ auto M_ = static_cast(M);
+ auto K_ = static_cast(K);
+ (void)(A);
+ (void)(N);
+ (void)(C);
+ (void)(lda);
+ (void)(ldc);
+ if (M <= 16) {
+ using ProA = prologue_a::gemm::ActivationKBlockBaseF32;
+ static ProA proA;
+ if (B->IsAsym()) {
+ auto reduceA = proA.createStorage(M_, K_, B->mBlockSize);
+ return reduceA.mSize;
+ }
+ return 0;
+ } else {
+ // using ProA = prologue_a::gemm::ActivationBase;
+ return 0;
+ }
+}
+
+template
+static size_t NSSQ4GemmCompInt8WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda,
+ storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) {
+ (void)(N);
+ (void)(lda);
+ (void)(ldc);
+ (void)(A);
+ (void)(C);
+ using ProA = prologue_a::gemm::ActivationF32KBlockQuantize;
+ static ProA proA;
+ auto quanA =
+ proA.createStorage(static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->IsAsym());
+ return quanA.mSize;
+}
+
+} // namespace bestla
+
+using namespace bestla;
+
+static bool NSSQ4GemmBatchDriver(size_t M, size_t N, size_t K, size_t BatchN,
+ const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace,
+ void* ThreadPool) {
+ GetCPUDevice();
+ bestla::ORTThreading orth(ThreadPool);
+ bool processed = true;
+ for (size_t i = 0; i < BatchN; i++) {
+ auto ptr = bestla::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B);
+ auto uptr = std::unique_ptr(ptr);
+ if (ptr) {
+ auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT);
+ auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId);
+ auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId);
+ auto btype = static_cast(gemm::CompTypeHelper::get_B(CType));
+ if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) {
+ auto kptr = reinterpret_cast(ptr);
+ auto BlkSize = kptr->mBlockSize;
+ if (btype == gemm::CompType::tFP32 && PackRow == 1) {
+ if (NTile == bestla::tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
+ bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
+ DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth);
+ } else if (NTile == bestla::tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
+ bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C,
+ DataParams[i].ldc, WorkSpace, &orth);
+ }
+ }
+ if (btype == gemm::CompType::tS8 && PackRow == 4) {
+ if (NTile == bestla::tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() &&
+ BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) {
+ bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
+ DataParams[i].C, DataParams[i].ldc, WorkSpace,
+ &orth);
+ } else if (NTile == bestla::tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() &&
+ BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) {
+ bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
+ DataParams[i].C, DataParams[i].ldc, WorkSpace,
+ &orth);
+ } else if (NTile == bestla::tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() &&
+ BlkSize % tAVX_VNNI_KBlock::KTILE == 0) {
+ bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
+ DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth);
+ }
+ }
+ }
+ } else {
+ processed = false;
+ break;
+ }
+ }
+ return processed;
+}
+
+static size_t NSSQ4GemmBatchWorkspaceSize(size_t M, size_t N, size_t K, size_t BatchN,
+ const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) {
+ GetCPUDevice();
+ size_t size = 0;
+ for (size_t i = 0; i < BatchN; i++) {
+ auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B);
+ auto uptr = std::unique_ptr(ptr);
+ if (ptr) {
+ if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) {
+ auto kptr = reinterpret_cast(ptr);
+ auto NTile =
+ gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT);
+ auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId);
+ auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId);
+ auto btype = static_cast(gemm::CompTypeHelper::get_B(CType));
+ auto BlkSize = kptr->mBlockSize;
+ if (btype == gemm::CompType::tFP32 && PackRow == 1) {
+ if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
+ size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
+ DataParams[i].C, DataParams[i].ldc),
+ size);
+ } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
+ size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr,
+ DataParams[i].C, DataParams[i].ldc),
+ size);
+ }
+ }
+ if (btype == gemm::CompType::tS8 && PackRow == 4) {
+ if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) {
+ size = std::max(NSSQ4GemmCompInt8WorkspaceSize(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
+ size);
+ } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() &&
+ BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) {
+ size = std::max(NSSQ4GemmCompInt8WorkspaceSize(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
+ size);
+ } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) {
+ size = std::max(NSSQ4GemmCompInt8WorkspaceSize(
+ M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc),
+ size);
+ }
+ }
+ }
+ }
+ }
+ return size;
+}
+
+template
+static size_t NSQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) {
+ static T proB;
+ auto stor = proB.createStorage(static_cast(N), static_cast(K), static_cast(block_size),
+ BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, isAsym);
+ // TODO(Yu) support more scale dtype
+ return stor.mSize;
+}
+
+static bool NSQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) {
+ auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf);
+ auto uptr = std::unique_ptr(ptr);
+ ORTThreading orth(ThreadPool);
+ auto N_ = static_cast(N);
+ auto K_ = static_cast(K);
+ auto ldb_ = static_cast(ldb);
+ GetCPUDevice();
+ if (ptr) {
+ auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT);
+ auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId);
+ auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId);
+ auto btype = static_cast(gemm::CompTypeHelper::get_B(CType));
+ if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) {
+ auto wptr = reinterpret_cast(ptr);
+ auto BlkSize = wptr->mBlockSize;
+ if (btype == gemm::CompType::tFP32 && PackRow == 1) {
+ if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
+ static tWeiNInt proB;
+ proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
+ } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
+ static tWeiNInt proB;
+ proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
+ }
+ }
+ if (btype == gemm::CompType::tS8 && PackRow == 4) {
+ if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) {
+ static tWeiNInt proB;
+ proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
+ } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() &&
+ BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) {
+ static tWeiNInt proB;
+ proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
+ } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) {
+ static tWeiNInt proB;
+ proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth);
+ }
+ }
+ }
+ return true;
+ }
+ return false;
+}
+
+template
+static void NSQ4GemmPackBImpl(void* PackedBuf, size_t BlkSize, const uint8_t* QData, const float* Scale,
+ const uint8_t* Zp, size_t N, size_t K, bool IsAsym, bool lastCall, size_t ldb,
+ void* ThreadPool) {
+ static T proB;
+ auto N_ = static_cast(N);
+ auto K_ = static_cast(K);
+ auto stor = proB.createStorage(N_, K_, static_cast(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32,
+ BTLA_DTYPE::BF16, IsAsym);
+ stor.assign(reinterpret_cast(PackedBuf));
+ ORTThreading orth(ThreadPool);
+ proB.packNbitsWeightQ4(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth);
+ if (lastCall) {
+ proB.reduceWeight(&stor, &orth);
+ }
+}
+
+static size_t NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, NS_SQNBIT_COMPUTE_TYPE CompType) {
+ GetCPUDevice();
+ if (K % BlkSize != 0) {
+ return 0;
+ }
+ // from low precision to high precision
+ switch (CompType) {
+ case NSCompInt8:
+ if (!isAsym) { // asymmetric int8 is not optimized, so fall through to the other compute types.
+ if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) {
+ return NSQ4BuSize>(BlkSize, N, K, isAsym);
+ }
+ if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) {
+ return NSQ4BuSize>(BlkSize, N, K, isAsym);
+ }
+ if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) {
+ return NSQ4BuSize>(BlkSize, N, K, isAsym);
+ }
+ }
+ [[fallthrough]];
+ case NSCompBf16:
+ case NSCompFp16:
+ case NSCompFp32:
+ case NSCompUndef:
+ if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
+ return NSQ4BuSize>(BlkSize, N, K, isAsym);
+ }
+ if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
+ return NSQ4BuSize>(BlkSize, N, K, isAsym);
+ }
+ [[fallthrough]];
+ default:
+ return 0;
+ }
+}
+
+static bool NSQ4GemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N,
+ size_t K, size_t ldb, size_t BlkSize, bool isAsym, bool lastCall,
+ NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) {
+ GetCPUDevice();
+ // explicit fall-through between compute types.
+ switch (CompType) {
+ case NSCompInt8:
+ if (!isAsym) { // asymmetric int8 is not optimized, so fall through to the other compute types.
+ if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) {
+ NSQ4GemmPackBImpl>(
+ PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool);
+ return true;
+ }
+ if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) {
+ NSQ4GemmPackBImpl>(
+ PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool);
+ return true;
+ }
+ if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) {
+ NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N,
+ K, isAsym, lastCall, ldb, ThreadPool);
+ return true;
+ }
+ }
+ [[fallthrough]];
+ case NSCompBf16:
+ case NSCompFp16:
+ case NSCompFp32:
+ case NSCompUndef:
+ if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) {
+ NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym,
+ lastCall, ldb, ThreadPool);
+ return true;
+ }
+ if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) {
+ NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall,
+ ldb, ThreadPool);
+ return true;
+ }
+ [[fallthrough]];
+ default:
+ return false;
+ }
+}
+
+size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym,
+ NS_SQNBIT_COMPUTE_TYPE CompType) {
+ if (nbits == 4) {
+ auto jsize = NSQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType);
+ if (jsize) {
+ return jsize;
+ }
+ }
+ return 0;
+}
+
+void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K,
+ size_t ldb, size_t BlkSize, int nbits, bool isAsym, bool lastCall,
+ NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) {
+ if (nbits == 4) {
+ if (NSQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) {
+ return;
+ }
+ }
+}
+
+void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) {
+ // only nbits=4 can be packed, so there is no need to check nbits here
+ if (NSQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) {
+ return;
+ }
+}
+
+size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN,
+ const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) {
+ // only nbits=4 can be packed, so it is not necessary to check nbits in DataParams
+ return NSSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams);
+}
+
+void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN,
+ const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace,
+ void* ThreadPool) {
+ // only nbits=4 can be packed, so it is not necessary to check nbits in DataParams
+ if (NSSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) {
+ // PackedWeight is created by bestla
+ return;
+ }
+}
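Every dispatcher in this file applies the same selection rule before picking a GEMM core: the packed blob's NTILE must match the core, the CPU must report the corresponding ISA, and the quantization block size must be a multiple of the core's KTILE. A hypothetical helper condensing that repeated condition is sketched here; CoreMatches is not part of the file, only an illustration.

    // Sketch only: the repeated core-selection condition, factored into one predicate.
    template <class Core>
    static bool CoreMatches(int n_tile, int blk_size, bool isa_supported) {
      return isa_supported && n_tile == Core::NTILE && blk_size % Core::KTILE == 0;
    }

    // Usage mirrors the branches above, e.g.:
    //   CoreMatches<bestla::tAVX512F>(NTile, BlkSize, _cd->AVX512F())
    //   CoreMatches<bestla::tAMX_INT8_SS_KBlock>(NTile, BlkSize, _cd->AMX_INT8())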
diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h
new file mode 100644
index 0000000000000..ebcb3027a209f
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h
@@ -0,0 +1,129 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ neural_speed_gemm.h
+
+Abstract:
+
+ Prepack-weight GEMM APIs of neural_speed.
+--*/
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+/**
+ * @brief Compute types for block-wise quantization
+ */
+enum NS_SQNBIT_COMPUTE_TYPE {
+ NSCompUndef = 0, /*!< undef */
+ NSCompFp32 = 1, /*!< input fp32, accumulator fp32 */
+ NSCompFp16 = 2, /*!< input fp16, accumulator fp16 */
+ NSCompBf16 = 3, /*!< input bf16, accumulator fp32 */
+ NSCompInt8 = 4 /*!< input int8, accumulator int32 */
+};
+
+/**
+ * @brief Data parameters for NBits GEMM routine
+ * C = A * B
+ * A and C must be float32 matrices
+ * B must be a packed nbits blob
+ * All except C are [in] parameters
+ */
+struct NS_SQNBITS_GEMM_DATA_PACKED_PARAMS {
+ const float* A = nullptr; /**< address of A (float32 matrix)*/
+ const void* B = nullptr; /**< address of B (packed nbits blob)*/
+ float* C = nullptr; /**< address of result matrix */
+ size_t lda = 0; /**< leading dimension of A */
+ size_t ldc = 0; /**< leading dimension of C */
+};
+
+/**
+ * @brief Compute the size in bytes of the packed-weight buffer for the given parameter combination
+ *
+ * @param N the number of columns of matrix B.
+ * @param K the number of rows of matrix B.
+ * @param block_size size of the block to quantize, elements from the same block share the same
+ * scale and zero point
+ * @param nbits number of bits used for weight quantization
+ * @param is_asym flag for asymmetric quantization
+ * @param comp_type specify input data type and accumulator data type
+ * @return size of the packing buffer, 0 if the operation is not yet supported.
+ */
+size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t block_size, int nbits, bool is_asym,
+ NS_SQNBIT_COMPUTE_TYPE comp_type);
+
+/**
+ * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers.
+ *
+ * @param PackedBuf packed data buffer
+ * @param QData quantized data buffer
+ * @param Scale scale pointer
+ * @param Zp zero point pointer
+ * @param N the number of columns of matrix B.
+ * @param K the number of rows of matrix B.
+ * @param ldb leading dimension of B
+ * @param block_size size of the block to quantize, elements from the same block share the same
+ * scale and zero point
+ * @param nbits number of bits used for weight quantization (default 4)
+ * @param is_asym flag for asymmetric quantization
+ * @param comp_type specify input data type and accumulator data type
+ * @param last_call flag that triggers the epilogue step of packB. OpKernel::PrePack receives the input tensors
+ * one by one: QData, Scale, then Zp (if is_asym is true), but the kernel packs them into a single blob so that
+ * they can share common attributes such as block_size. The kernel also runs pre-computations that speed up
+ * inference and require the whole blob to be ready, so set this flag to true on the final call, i.e. when
+ * passing Scale (is_asym is false) or Zp (is_asym is true).
+ * @param thread_pool
+ */
+void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K,
+ size_t ldb, size_t block_size, int nbits, bool is_asym, bool last_call,
+ NS_SQNBIT_COMPUTE_TYPE comp_type, void* thread_pool);
+
+/**
+ * @brief Unpack and dequantize to fp32
+ *
+ * @param FpData unpacked float32 data
+ * @param PackedBuf quantized and packed data
+ * @param N the number of columns of matrix B.
+ * @param K the number of rows of matrix B.
+ * @param ldb leading dimension of B
+ * @param thread_pool
+ */
+void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* thread_pool);
+
+/**
+ * @brief Get the workspace size required by computation.
+ *
+ * @param[in] M row size of matrix A and C
+ * @param[in] N column size of matrix B and C
+ * @param[in] K column size of matrix A and row size of matrix B
+ * @param[in] BatchN number of batches
+ * @param[inout] DataParams An array (size BatchN) of parameter blocks
+ * @return Workspace size in bytes
+ */
+size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN,
+ const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams);
+
+/**
+ * @brief Batched GEMM: C = A * B
+ * A and C must be float32 matrices
+ * B must be a packed nbits blob
+ *
+ * @param[in] M row size of matrix A and C
+ * @param[in] N column size of matrix B and C
+ * @param[in] K column size of matrix A and row size of matrix B
+ * @param[in] BatchN number of batches
+ * @param[inout] DataParams An array (size BatchN) of parameter blocks
+ * @param[in] WorkSpace temporary buffer
+ * @param[in] ThreadPool
+ * @return
+ */
+void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN,
+ const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace,
+ void* ThreadPool = nullptr);
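A typical caller of this header packs B once (normally during OpKernel::PrePack) and reuses the blob for every Compute call. The sketch below exercises only the functions declared above; q_data, scales, zero_points, ldb and thread_pool are illustrative caller-owned values, and it assumes all weight tensors are already available so last_call can be true on the single packing call.

    // Sketch only: pack a 4-bit weight blob, then unpack it back to fp32 for verification.
    size_t packed_size = NSNBitsGemmPackBSize(N, K, block_size, /*nbits=*/4, is_asym, NSCompInt8);
    if (packed_size != 0) {
      std::vector<uint8_t> packed(packed_size);
      NSNBitsGemmPackB(packed.data(), q_data, scales, is_asym ? zero_points : nullptr,
                       N, K, ldb, block_size, /*nbits=*/4, is_asym, /*last_call=*/true,
                       NSCompInt8, thread_pool);

      std::vector<float> dense(N * K);
      NSNBitsGemmUnPackB(dense.data(), packed.data(), N, K, /*ldb=*/N, thread_pool);
    }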
diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h
new file mode 100644
index 0000000000000..d3902f9bd68c7
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h
@@ -0,0 +1,39 @@
+//-----------------------------------------------------------------------------
+//
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+//-----------------------------------------------------------------------------
+#pragma once
+#if defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#pragma GCC diagnostic ignored "-Wsign-compare"
+#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
+#pragma GCC diagnostic ignored "-Wunused-variable"
+#pragma GCC diagnostic ignored "-Wunused-value"
+#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
+#pragma GCC diagnostic ignored "-Wunused-function"
+#pragma GCC diagnostic ignored "-Wuninitialized"
+#pragma GCC diagnostic ignored "-Wclass-memaccess"
+#pragma GCC diagnostic ignored "-Wunused-but-set-variable"
+#pragma GCC diagnostic ignored "-Wunused-but-set-parameter"
+
+#elif defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4457)
+#pragma warning(disable : 4189)
+#pragma warning(disable : 4100)
+#pragma warning(disable : 4244)
+#pragma warning(disable : 4267)
+#pragma warning(disable : 4702)
+#endif
+
+#include "bestla/bestla_prologue_a.h"
+#include "bestla/bestla_wrapper.h"
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic pop
+#elif defined(_MSC_VER)
+#pragma warning(pop)
+#endif
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
index 56d950ca2f41e..dc72a038c3d58 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
@@ -397,12 +397,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch
output_sequences_scores);
// Output per token scores
- if (output_scores) {
- gsl::span target = output_scores->MutableDataAsSpan();
- gsl::span source = beam_state.scores;
- assert(target.size() == source.size());
- ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
- }
+ gsl::span per_token_scores = beam_state.scores;
+ this->beam_scorer_->OutputScores(per_token_scores, output_scores);
return status;
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
index 94547887d3a90..cd891a9508019 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -404,12 +404,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches
output_sequences_scores);
// Output per token scores
- if (output_scores) {
- gsl::span target = output_scores->MutableDataAsSpan();
- gsl::span source = beam_state.scores;
- assert(target.size() == source.size());
- ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
- }
+ gsl::span per_token_scores = beam_state.scores;
+ this->beam_scorer_->OutputScores(per_token_scores, output_scores);
return status;
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
index 91b93a125ad7a..4d6643c68a98b 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
@@ -500,12 +500,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe
output_sequences_scores);
// Output per token scores
- if (output_scores) {
- gsl::span target = output_scores->MutableDataAsSpan();
- gsl::span source = beam_state.scores;
- assert(target.size() == source.size());
- ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
- }
+ gsl::span per_token_scores = beam_state.scores;
+ this->beam_scorer_->OutputScores(per_token_scores, output_scores);
return status;
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc
index 7e2e5b2129221..0eccbe26605f5 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc
@@ -50,11 +50,12 @@ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_length) con
return beams_.back().score < current_score;
}
+template <typename T>
void BeamHypotheses::Output(
int top_k,
int max_length,
- gsl::span<int32_t>& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length)
- gsl::span<float>& sequences_scores) // buffer of shape (num_return_sequences) or empty
+ gsl::span<int32_t>& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length)
+ gsl::span<T>& sequences_scores) // buffer of shape (num_return_sequences) or empty
{
// Copy the top_k beams into the sequences
ORT_ENFORCE(top_k <= beams_used_);
@@ -67,7 +68,7 @@ void BeamHypotheses::Output(
gsl::copy(item.hypothesis, target);
if (!sequences_scores.empty())
- sequences_scores[index] = item.score;
+ sequences_scores[index] = (T)item.score;
}
}
@@ -181,21 +182,21 @@ void BeamSearchScorer::Process(ISequences& sequences,
}
}
-void BeamSearchScorer::Finalize(ISequences& sequences,
- gsl::span& final_beam_scores,
- Tensor* output_sequences,
- Tensor* output_sequence_scores) {
- ORT_ENFORCE(output_sequences != nullptr);
-
+template <typename T>
+void OutputSequenceScores(BeamSearchScorer* scorer,
+ ISequences& sequences,
+ gsl::span& final_beam_scores,
+ Tensor* output_sequences,
+ Tensor* output_sequence_scores) {
// Finalize all open beam hypotheses and add to generated hypotheses.
- for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) {
- BeamHypotheses& beam_hyp = beam_hyps_[batch_index];
+ for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) {
+ BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index];
if (beam_hyp.done_) {
continue;
}
- for (size_t beam_index = 0; beam_index < num_beams_; beam_index++) {
- size_t batch_beam_index = batch_index * num_beams_ + beam_index;
+ for (size_t beam_index = 0; beam_index < scorer->num_beams_; beam_index++) {
+ size_t batch_beam_index = batch_index * scorer->num_beams_ + beam_index;
float final_score = final_beam_scores[batch_beam_index];
auto final_tokens = sequences.GetSequence(narrow(batch_beam_index));
beam_hyp.Add(final_tokens, final_score);
@@ -206,26 +207,59 @@ void BeamSearchScorer::Finalize(ISequences& sequences,
gsl::span output = output_sequences->MutableDataAsSpan();
// Fill output sequences with pad token ID so that we do not need append it later.
- std::fill_n(output.data(), output.size(), pad_token_id_);
+ std::fill_n(output.data(), output.size(), scorer->pad_token_id_);
// Score of each sequence, with shape (batch_size * num_return_sequences).
- gsl::span<float> sequence_scores;
+ gsl::span<T> sequence_scores;
if (output_sequence_scores) {
- sequence_scores = output_sequence_scores->MutableDataAsSpan<float>();
+ sequence_scores = output_sequence_scores->MutableDataAsSpan<T>();
}
// Select the best hypotheses according to number of sequences to return.
- for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) {
- BeamHypotheses& beam_hyp = beam_hyps_[batch_index];
+ for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) {
+ BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index];
- auto batch_output = output.subspan(batch_index * num_return_sequences_ * max_length_,
- num_return_sequences_ * max_length_);
- gsl::span<float> sequence_scores_buffer;
+ auto batch_output = output.subspan(batch_index * scorer->num_return_sequences_ * scorer->max_length_,
+ scorer->num_return_sequences_ * scorer->max_length_);
+ gsl::span<T> sequence_scores_buffer;
if (!sequence_scores.empty())
- sequence_scores_buffer = sequence_scores.subspan(batch_index * num_return_sequences_, num_return_sequences_);
+ sequence_scores_buffer = sequence_scores.subspan(batch_index * scorer->num_return_sequences_, scorer->num_return_sequences_);
+
+ beam_hyp.template Output(narrow(scorer->num_return_sequences_), narrow(scorer->max_length_), batch_output,
+ sequence_scores_buffer);
+ }
+}
+
+void BeamSearchScorer::Finalize(ISequences& sequences,
+ gsl::span& final_beam_scores,
+ Tensor* output_sequences,
+ Tensor* output_sequence_scores) {
+ ORT_ENFORCE(output_sequences != nullptr);
- beam_hyp.Output(narrow(num_return_sequences_), narrow(max_length_), batch_output,
- sequence_scores_buffer);
+ if (output_sequence_scores == nullptr || output_sequence_scores->IsDataType<float>()) {
+ OutputSequenceScores<float>(this, sequences, final_beam_scores, output_sequences, output_sequence_scores);
+ } else {
+ ORT_ENFORCE(output_sequence_scores->IsDataType<MLFloat16>());
+ OutputSequenceScores<MLFloat16>(this, sequences, final_beam_scores, output_sequences, output_sequence_scores);
+ }
+}
+
+void BeamSearchScorer::OutputScores(gsl::span& final_scores, Tensor* output_scores) {
+ if (output_scores) {
+ if (output_scores->IsDataType<float>()) {
+ gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
+ ORT_ENFORCE(target.size() == final_scores.size());
+ std::copy_n(final_scores.data(), final_scores.size(), target.data());
+ } else {
+ ORT_ENFORCE(output_scores->IsDataType<MLFloat16>());
+ gsl::span<MLFloat16> target = output_scores->MutableDataAsSpan<MLFloat16>();
+ ORT_ENFORCE(target.size() == final_scores.size());
+ const float* src = final_scores.data();
+ MLFloat16* dst = target.data();
+ for (size_t i = 0; i < target.size(); i++) {
+ dst[i] = MLFloat16(src[i]);
+ }
+ }
}
}
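The two branches of OutputScores differ only in the destination element type: fp32 is copied directly, fp16 is converted element-wise through MLFloat16's float constructor. A small hypothetical helper capturing that shared step is sketched below; CopyScores is illustrative and not part of the change.

    // Sketch only: dtype-generic copy of per-token scores into the output tensor.
    template <typename T>
    static void CopyScores(gsl::span<const float> src, Tensor* output_scores) {
      gsl::span<T> dst = output_scores->MutableDataAsSpan<T>();
      ORT_ENFORCE(dst.size() == src.size());
      for (size_t i = 0; i < src.size(); ++i) {
        dst[i] = T(src[i]);  // identity for float, fp32 -> fp16 narrowing for MLFloat16
      }
    }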
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
index 94b6d340d9f4a..dc92e8038a68e 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
@@ -35,10 +35,11 @@ struct BeamHypotheses {
bool CanImprove(float best_sum_logprobs, int current_length) const;
// Output results
- void Output(int top_k, // number of sequences to return
- int max_length, // max sequence length
- gsl::span<int32_t>& sequences, // buffer with pad token, shape (num_return_sequences, max_length)
- gsl::span<float>& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences)
+ template <typename T>
+ void Output(int top_k, // number of sequences to return
+ int max_length, // max sequence length
+ gsl::span<int32_t>& sequences, // buffer with pad token, shape (num_return_sequences, max_length)
+ gsl::span<T>& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences)
gsl::span beams_; // Beam width sized array of hypotheses, sorted by highest scoring
int beams_used_; // Number of elements used in beams_
@@ -60,13 +61,14 @@ struct BeamSearchScorer : IBeamScorer {
Tensor* output_sequences,
Tensor* output_sequence_scores) override;
+ void OutputScores(gsl::span& final_scores, Tensor* output_scores) override;
+
bool IsDone() const override { return not_done_count_ == 0; }
gsl::span GetNextScores() override { return next_beam_scores_; }
gsl::span GetNextTokens() override { return next_beam_tokens_; }
gsl::span GetNextIndicesCPU() override { return next_beam_indices_; }
- private:
size_t batch_size_;
size_t num_beams_;
size_t max_length_;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index f6faf2e325f8f..cb62e2f7bf4da 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -120,6 +120,9 @@ struct IBeamScorer {
Tensor* output_sequences,
Tensor* output_sequence_scores) = 0;
+ virtual void OutputScores(gsl::span& final_scores,
+ Tensor* output_scores) = 0;
+
virtual bool IsDone() const = 0; // GPU version will return false here, as it asynchronously queues up the event
virtual bool IsDoneLater() const { return false; } // GPU version waits for the asynchronous result to complete here
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index ebd66d8c6528e..f978f50c6851f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -44,6 +44,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info)
mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f);
scale_ = info.GetAttrOrDefault("scale", 0.0f);
+ is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1;
+ ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead.");
disable_fused_self_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false);
@@ -105,6 +107,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
num_heads_,
mask_filter_value_,
scale_,
+ is_unidirectional_,
false, // past_present_share_buffer
false, // dmmha_packing
device_prop.maxThreadsPerBlock));
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
index c162f7133cc1c..86a32c92ce003 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
@@ -25,6 +25,7 @@ class MultiHeadAttention final : public CudaKernel {
int num_heads_; // number of attention heads
float mask_filter_value_;
float scale_;
+ bool is_unidirectional_;
bool disable_fused_self_attention_;
bool enable_trt_flash_attention_;
bool disable_fused_cross_attention_;
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
index 2d12e975d88d7..9de7ba3885c3c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
@@ -29,10 +29,13 @@ namespace cuda {
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
template
RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) {
scale = info.GetAttrOrDefault("scale", 1.0);
+ rotary_embedding_dim = static_cast(info.GetAttrOrDefault