From a24273e5bb497349bbe6f4a210f9179378549e65 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 23 Jan 2024 07:53:26 +0800 Subject: [PATCH 01/51] [js/webgpu] Add HardSigmoid support (#19215) ### Description This op is required in mobilenetv3-small-100. With this PR, mobilenetv3-small-100 model becomes less than 10 ms from over 100 ms on ADL. --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 1 + js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 20 +++++++++++++++++++ js/web/test/suite-test-list.jsonc | 6 +++--- .../providers/js/js_execution_provider.cc | 2 ++ .../core/providers/js/operators/unary.cc | 3 +++ 6 files changed, 30 insertions(+), 3 deletions(-) diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 2f510308d9306..2557971eb4ded 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -52,6 +52,7 @@ Do not modify directly.* | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | | GreaterOrEqual | ai.onnx(12-15,16+) | | +| HardSigmoid | ai.onnx(6+) | | | If | ai.onnx(1-10,11-12,13-18,19+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(17+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 90e02da986b8f..cc504093ca0d7 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -82,6 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], + ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]], ['InstanceNormalization', [instanceNorm]], ['LayerNormalization', [layerNorm]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index a25e7fe4229b4..82311d72e58b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -242,6 +242,26 @@ export const sigmoid = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`)); }; +export interface HardSigmoidAttributes extends AttributeWithCacheKey { + readonly alpha: number; + readonly beta: number; +} + +export const parseHardSigmoidAttributes = (attributes: Record): HardSigmoidAttributes => + createAttributeWithCacheKey(attributes as { + alpha: number; + beta: number; + }); + +export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'HardSigmoid', + a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${ + attributes.beta})))`, + undefined, attributes.cacheKey)); +}; + export const sin = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin')); }; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 033b3b3f4b0f5..373b3c645df57 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -597,9 +597,9 @@ // // "test_hardmax_example", // // "test_hardmax_negative_axis", // // "test_hardmax_one_hot", - // // "test_hardsigmoid_default", - // // "test_hardsigmoid_example", - // // "test_hardsigmoid", + "test_hardsigmoid_default", + "test_hardsigmoid_example", + "test_hardsigmoid", // // "test_hardswish_expanded", // // "test_hardswish", "test_if", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index c2ff2ebc39e13..af9658271d210 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -98,6 +98,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Erf); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Sigmoid); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, HardSigmoid); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Log); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Log); @@ -392,6 +393,7 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO(13, Erf), KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), KERNEL_CREATE_INFO(13, Sigmoid), + KERNEL_CREATE_INFO(6, HardSigmoid), KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), KERNEL_CREATE_INFO(13, Log), diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index 78563d30b0136..9082527e3a8d7 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -77,6 +77,9 @@ JSEP_KERNEL_IMPL(Sigmoid, Sigmoid) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid) JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid) +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(HardSigmoid, HardSigmoid, alpha, 0.2, beta, 0.5) +JSEP_ELEMENTWISE_KERNEL(HardSigmoid, 6, HardSigmoid) + JSEP_KERNEL_IMPL(Log, Log) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log) JSEP_ELEMENTWISE_KERNEL(Log, 13, Log) From c44d4977043ab066644d2114d9a2837b5c72b314 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 23 Jan 2024 08:08:55 +0800 Subject: [PATCH 02/51] [js/webgpu] set query type in onRunStart (#19202) ### Description `env.webgpu.profiling` is a global flag. It may change before each session.run. So the best place is to update it in `onRunStart` event. After this, we can directly check `this.queryType`'s value. Without this pr, we need to make sure that `getCommandEncoder()` is called before checking `this.queryType`. Otherwise, it may happen that `pendingKernels`'s length is not equal to `pendingDispatchNumber`'s length. See the two ugly workarounds [1)](https://github.com/microsoft/onnxruntime/pull/18989/commits/e630dbf528fc3a955702cceb968930d0abdfc652#diff-006fc84d3997f96a29b8033bd2075d6a0a9509211bd5812a6b934fc74fedfd9dR267-R268) and [2)](https://github.com/microsoft/onnxruntime/pull/18989/commits/e630dbf528fc3a955702cceb968930d0abdfc652#diff-618fe297fbe7a1da586380163b8fd2627311ccc217640a3c5cdc9c17a33472c1R73-R80) if we don't introduce `onRunStart`. Or we need to call `setQueryType` in each kernel run. --- js/web/lib/wasm/binding/ort-wasm.d.ts | 4 ++++ js/web/lib/wasm/jsep/backend-webgpu.ts | 9 +++++---- js/web/lib/wasm/wasm-core-impl.ts | 2 +- onnxruntime/wasm/js_internal_api.js | 3 +++ 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 9d4d5875310b7..68054210e79a7 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -182,6 +182,10 @@ export interface OrtWasmModule extends EmscriptenModule { jsepCreateDownloader: (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; + /** + * [exported from js_internal_api.js] Called when InferenceSession.run started. + */ + jsepOnRunStart: () => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 2956ec1cad4da..afef7042a4280 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -208,7 +208,7 @@ export class WebGpuBackend { Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); - // init queryType, which is necessary for createKernel + // init queryType, which is necessary for InferenceSession.create this.setQueryType(); } @@ -223,8 +223,6 @@ export class WebGpuBackend { if (!this.commandEncoder) { this.commandEncoder = this.device.createCommandEncoder(); - // refresh queryType, as sometimes we only need to enable query for a specific run - this.setQueryType(); if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { this.querySet = this.device.createQuerySet({ type: 'timestamp', @@ -639,6 +637,7 @@ export class WebGpuBackend { return createView(data.buffer, type); }; } + // #endregion writeTimestamp(index: number): void { if (this.queryType !== 'inside-passes') { return; @@ -657,5 +656,7 @@ export class WebGpuBackend { } } } - // #endregion + onRunStart(): void { + this.setQueryType(); + } } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5821fac3c468f..8768643fa7257 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -488,8 +488,8 @@ export const run = async( } } + wasm.jsepOnRunStart?.(); let errorCode: number; - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 25ece9c700d5d..7c70515e73eab 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -186,4 +186,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; + Module['jsepOnRunStart'] = () => { + return backend['onRunStart'](); + }; }; From 254b543f010755132640fd33438663fbb7703716 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 24 Jan 2024 00:25:05 +0800 Subject: [PATCH 03/51] [js/webgpu] Add FusedConv clip test case (#18900) Bug: https://github.com/microsoft/onnxruntime/issues/18899 --- js/web/test/data/ops/fused-conv.jsonc | 34 +++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index 812e9d7c2def0..ad1c0a72c11d3 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -108,5 +108,39 @@ ] } ] + }, + { + "name": "fused conv with clip", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Clip", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [400.0, 600.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [400, 470, 600, 600], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] } ] From 7282e230cb68ec4a411cd792952bd45f7e936e7f Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Thu, 25 Jan 2024 01:12:21 +0530 Subject: [PATCH 04/51] [JS/WebGPU] Added Uniforms to SkipLayerNorm. (#18788) ### Description Added Uniforms to SkipLayerNorm ### Motivation and Context Improve performance --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 4 +- .../wasm/jsep/webgpu/ops/skip-layer-norm.ts | 123 ++++++++++-------- 2 files changed, 69 insertions(+), 58 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index cc504093ca0d7..d737a28654220 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -25,7 +25,7 @@ import * as pool from './ops/pool'; import {range} from './ops/range'; import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; -import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm'; +import {skipLayerNorm} from './ops/skip-layer-norm'; import {parseSliceAttributes, slice} from './ops/slice'; import {parseSoftmaxAttributes, softmax} from './ops/softmax'; import {parseSplitAttributes, split} from './ops/split'; @@ -116,7 +116,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Sin', [unaryOps.sin]], ['Sinh', [unaryOps.sinh]], ['Slice', [slice, parseSliceAttributes]], - ['SkipLayerNormalization', [skipLayerNorm, parseSkipLayerNormAttributes]], + ['SkipLayerNormalization', [skipLayerNorm]], ['Split', [split, parseSplitAttributes]], ['Sqrt', [unaryOps.sqrt]], ['Softmax', [softmax, parseSoftmaxAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index a2fda9f07d09f..509a722f4b52a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -4,10 +4,10 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {AttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; export interface SkipLayerNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -86,60 +86,74 @@ const createSkipLayerNormProgramInfo = const hasInputSkipBiasSumOutput = outputCount > 3; const components = getMaxComponents(hiddenSize); - const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims, components), - inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), - inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), - ]; - if (hasBetaInput) { - variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); - } - if (hasBiasInput) { - variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); - } - variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); - if (hasMeanOutput) { - variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInvStdDevOutput) { - variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInputSkipBiasSumOutput) { - variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components)); - } - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const hiddenSize: f32 = ${hiddenSize}; - const hiddenSizeVectorized: u32 = ${hiddenSize / components}; - const epsilon: f32 = ${attributes.epsilon}; - ${shaderHelper.declareVariables(...variables)} + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, + {type: 'uint32', data: components}, + {type: 'uint32', data: hiddenSize}, + {type: 'float32', data: attributes.epsilon}, + ]; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniformsArray: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, + {name: 'components', type: 'u32'}, + {name: 'hidden_size', type: 'u32'}, + {name: 'epsilon', type: 'f32'}, + ]; + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), + inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), + ]; + if (hasBetaInput) { + variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); + } + if (hasBiasInput) { + variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanOutput) { + variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdDevOutput) { + variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); + } + if (hasInputSkipBiasSumOutput) { + variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components)); + } + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return ` + + ${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)} - let offset = global_idx * hiddenSizeVectorized; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size / uniforms.hidden_size')} + let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components; + let offset = global_idx * hidden_size_vectorized; var sum = ${fillVector('f32', components)}; var squareSum = ${fillVector('f32', components)}; - for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { - let skipValue = skip[offset + i]; - let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'}; - let inputValue = x[offset + i]; - let value = inputValue + skipValue + biasValue; - ${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''} + for (var i: u32 = 0; i < hidden_size_vectorized; i++) { + let skip_value = skip[offset + i]; + let bias_value = ${hasBiasInput ? 'bias[i]' : '0.0'}; + let input_value = x[offset + i]; + let value = input_value + skip_value + bias_value; + ${hasInputSkipBiasSumOutput ? 'input_skip_bias_sum[offset + i] = value;' : ''} output[offset + i] = value; - let f32Value = ${castToF32(dataType, components, 'value')}; - sum += f32Value; - squareSum += f32Value * f32Value; + let f32_value = ${castToF32(dataType, components, 'value')}; + sum += f32_value; + squareSum += f32_value * f32_value; } - let mean = ${sumVector('sum', components)} / hiddenSize; - let invStdDev = inverseSqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon); - ${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''} - ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = invStdDev;' : ''} - for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { - output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(invStdDev) * gamma[i] - + ${hasBetaInput ? 'beta[i]' : '0.0'}; + let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size); + let inv_std_dev = inverseSqrt(${ + sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon); + ${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''} + ${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''} + for (var i: u32 = 0; i < hidden_size_vectorized; i++) { + output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(inv_std_dev) * gamma[i] + ${ + hasBetaInput ? 'beta[i]' : '0.0'}; } }`; + }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; if (outputCount > 1) { outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); @@ -150,12 +164,14 @@ const createSkipLayerNormProgramInfo = if (outputCount > 3) { outputs.push({dims: inputShape, dataType: inputs[0].dataType}); } - return { name: 'SkipLayerNormalization', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: { + hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`, + inputDependencies: inputs.map((_input, _index) => 'type') + }, getShaderSource, - getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}), + getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}), }; }; @@ -178,8 +194,3 @@ export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNorm context.compute( createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), {outputs}); }; - -export const parseSkipLayerNormAttributes = (attributes: Record): SkipLayerNormAttributes => { - const epsilon = attributes.epsilon as number; - return createAttributeWithCacheKey({epsilon}); -}; From 6bd8586fe7f479b76f70356b9476de022918436c Mon Sep 17 00:00:00 2001 From: Yang Gu Date: Thu, 25 Jan 2024 06:49:37 +0800 Subject: [PATCH 05/51] [js/webgpu] Fix issue of timestamp query (#19258) When we enable webgpu profiling mode between session.create and session.run, current implementation has a problem to create querySet (and also queryResolveBuffer) if we share the commandEncoder with inputs upload. This PR fixes this by moving the querySet creation to the place we set queryType. --- js/web/lib/wasm/jsep/backend-webgpu.ts | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index afef7042a4280..8ca025d66550c 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -222,16 +222,6 @@ export class WebGpuBackend { getCommandEncoder(): GPUCommandEncoder { if (!this.commandEncoder) { this.commandEncoder = this.device.createCommandEncoder(); - - if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { - this.querySet = this.device.createQuerySet({ - type: 'timestamp', - count: this.maxDispatchNumber * 2, - }); - this.queryResolveBuffer = this.device.createBuffer( - // eslint-disable-next-line no-bitwise - {size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE}); - } } return this.commandEncoder; } @@ -654,6 +644,16 @@ export class WebGpuBackend { } else if (this.device.features.has('timestamp-query')) { this.queryType = 'at-passes'; } + + if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { + this.querySet = this.device.createQuerySet({ + type: 'timestamp', + count: this.maxDispatchNumber * 2, + }); + this.queryResolveBuffer = this.device.createBuffer( + // eslint-disable-next-line no-bitwise + {size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE}); + } } } onRunStart(): void { From 4eafe7364f6aac5162d582570ca9b24da5e5999b Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 25 Jan 2024 07:37:35 +0800 Subject: [PATCH 06/51] [WebNN EP] Support WebNN async API with Asyncify (#19145) --- js/web/lib/build-def.d.ts | 4 --- js/web/lib/index.ts | 4 +-- js/web/lib/wasm/binding/ort-wasm.d.ts | 2 +- js/web/lib/wasm/wasm-core-impl.ts | 4 +-- js/web/script/build.ts | 7 +--- js/web/script/test-runner-cli-args.ts | 4 --- .../core/providers/webnn/builders/model.cc | 35 ++++++++----------- .../providers/webnn/builders/model_builder.cc | 12 +++---- .../webnn/webnn_execution_provider.cc | 3 +- onnxruntime/wasm/js_internal_api.js | 4 +++ 10 files changed, 30 insertions(+), 49 deletions(-) diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index b3868871a4753..2c9cd88a375bd 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -21,10 +21,6 @@ interface BuildDefinitions { /** * defines whether to disable the whole WebNN backend in the build. */ - readonly DISABLE_WEBNN: boolean; - /** - * defines whether to disable the whole WebAssembly backend in the build. - */ readonly DISABLE_WASM: boolean; /** * defines whether to disable proxy feature in WebAssembly backend in the build. diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index baf45e74addea..b212c0f49df3b 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -23,12 +23,10 @@ if (!BUILD_DEFS.DISABLE_WASM) { require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU) { registerBackend('webgpu', wasmBackend, 5); + registerBackend('webnn', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - if (!BUILD_DEFS.DISABLE_WEBNN) { - registerBackend('webnn', wasmBackend, 9); - } } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 68054210e79a7..24d7062c85fcb 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -31,7 +31,7 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void; - _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number; + _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise; _OrtReleaseSession(sessionHandle: number): void; _OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; _OrtGetInputName(sessionHandle: number, index: number): number; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 8768643fa7257..046336dc9cac0 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -84,7 +84,7 @@ export const initRuntime = async(env: Env): Promise => { * @param epName */ export const initEp = async(env: Env, epName: string): Promise => { - if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') { + if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) { // perform WebGPU availability check if (typeof navigator === 'undefined' || !navigator.gpu) { throw new Error('WebGPU is not supported in current environment'); @@ -228,7 +228,7 @@ export const createSession = async( await Promise.all(loadingPromises); } - sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); + sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); if (sessionHandle === 0) { checkLastError('Can\'t create a session.'); } diff --git a/js/web/script/build.ts b/js/web/script/build.ts index ea0c122cb51de..d3652f3820357 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -44,7 +44,6 @@ const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); // /js/ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_WEBGL': 'false', 'BUILD_DEFS.DISABLE_WEBGPU': 'false', - 'BUILD_DEFS.DISABLE_WEBNN': 'false', 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', 'BUILD_DEFS.DISABLE_WASM_THREAD': 'false', @@ -364,7 +363,6 @@ async function main() { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', }, @@ -397,7 +395,7 @@ async function main() { // ort.webgpu[.min].js await addAllWebBuildTasks({ outputBundleName: 'ort.webgpu', - define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'true'}, + define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'}, }); // ort.wasm[.min].js await addAllWebBuildTasks({ @@ -411,7 +409,6 @@ async function main() { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WASM': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', }, }); // ort.wasm-core[.min].js @@ -421,7 +418,6 @@ async function main() { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', }, @@ -434,7 +430,6 @@ async function main() { 'BUILD_DEFS.DISABLE_TRAINING': 'false', 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', }, }); } diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index 8f6c5f6f04122..ed4dd76a6e315 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -396,10 +396,6 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs const globalEnvFlags = parseGlobalEnvFlags(args); - if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) { - throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.'); - } - // Options: // --log-verbose=<...> // --log-info=<...> diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index eaf549ef4e072..ef807a8c4fa26 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -70,22 +70,13 @@ Status Model::Predict(const InlinedHashMap& inputs, "The input of graph has unsupported type, name: ", name, " type: ", tensor.tensor_info.data_type); } -#ifdef ENABLE_WEBASSEMBLY_THREADS - // Copy the inputs from Wasm SharedArrayBuffer to the pre-allocated ArrayBuffers. + // Copy the inputs from Wasm ArrayBuffer to the WebNN inputs ArrayBuffer. + // As Wasm ArrayBuffer is not detachable. wnn_inputs_[name].call("set", view); -#else - wnn_inputs_.set(name, view); -#endif } -#ifdef ENABLE_WEBASSEMBLY_THREADS - // This vector uses for recording output buffers from WebNN graph compution when WebAssembly - // multi-threads is enabled, since WebNN API only accepts non-shared ArrayBufferView, - // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews - // and at this time the 'view' defined by Emscripten is shared ArrayBufferView, the memory - // address is different from the non-shared one, additional memory copy is required here. InlinedHashMap output_views; -#endif + for (const auto& output : outputs) { const std::string& name = output.first; const struct OnnxTensorData tensor = output.second; @@ -131,21 +122,23 @@ Status Model::Predict(const InlinedHashMap& inputs, name, " type: ", tensor.tensor_info.data_type); } -#ifdef ENABLE_WEBASSEMBLY_THREADS output_views.insert({name, view}); -#else - wnn_outputs_.set(name, view); -#endif } - wnn_context_.call("computeSync", wnn_graph_, wnn_inputs_, wnn_outputs_); -#ifdef ENABLE_WEBASSEMBLY_THREADS - // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm SharedArrayBuffer. + emscripten::val results = wnn_context_.call( + "compute", wnn_graph_, wnn_inputs_, wnn_outputs_) + .await(); + + // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer. for (const auto& output : outputs) { const std::string& name = output.first; emscripten::val view = output_views.at(name); - view.call("set", wnn_outputs_[name]); + view.call("set", results["outputs"][name]); } -#endif + // WebNN compute() method would return the input and output buffers via the promise + // resolution. Reuse the buffers to avoid additional allocation. + wnn_inputs_ = results["inputs"]; + wnn_outputs_ = results["outputs"]; + return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index cf8a0e23db43b..56f7ead8ccf5d 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -386,7 +386,8 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { for (auto& name : output_names_) { named_operands.set(name, wnn_operands_.at(name)); } - emscripten::val wnn_graph = wnn_builder_.call("buildSync", named_operands); + + emscripten::val wnn_graph = wnn_builder_.call("build", named_operands).await(); if (!wnn_graph.as()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph."); } @@ -395,13 +396,10 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { model->SetOutputs(std::move(output_names_)); model->SetScalarOutputs(std::move(scalar_outputs_)); model->SetInputOutputInfo(std::move(input_output_info_)); -#ifdef ENABLE_WEBASSEMBLY_THREADS - // Pre-allocate the input and output tensors for the WebNN graph - // when WebAssembly multi-threads is enabled since WebNN API only - // accepts non-shared ArrayBufferView. - // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews + // Wasm heap is not transferrable, we have to pre-allocate the MLNamedArrayBufferViews + // for inputs and outputs because they will be transferred after compute() done. + // https://webmachinelearning.github.io/webnn/#api-mlcontext-async-execution model->AllocateInputOutputBuffers(); -#endif return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 2922cf9540a8e..df7871614b267 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -42,7 +42,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f if (webnn_power_flags.compare("default") != 0) { context_options.set("powerPreference", emscripten::val(webnn_power_flags)); } - wnn_context_ = ml.call("createContextSync", context_options); + + wnn_context_ = ml.call("createContext", context_options).await(); if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 7c70515e73eab..7e9c0a6f99c32 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -160,6 +160,10 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea }; // replace the original functions with asyncified versions + Module['_OrtCreateSession'] = jsepWrapAsync( + Module['_OrtCreateSession'], + () => Module['_OrtCreateSession'], + v => Module['_OrtCreateSession'] = v); Module['_OrtRun'] = runAsync(jsepWrapAsync( Module['_OrtRun'], () => Module['_OrtRun'], From b03a1c588c2fbe8dcfba53d3e3bc95a4b4871bdf Mon Sep 17 00:00:00 2001 From: Jiajie Hu Date: Fri, 26 Jan 2024 00:25:35 +0800 Subject: [PATCH 07/51] [js/webgpu] Fix Tanh explosion (#19201) ### Description ```math \tanh(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}}= \left\{ \begin{array}{cc} -\frac{1-e^{-2\cdot(-x)}}{1+e^{-2\cdot(-x)}}, & x<0 \\ 0, & x=0 \\ \frac{1-e^{-2x}}{1+e^{-2x}}, & x>0 \end{array} \right. ``` ### Motivation and Context On some platforms, $$\tanh(1000)=\frac{e^{1000}-e^{-1000}}{e^{1000}+e^{-1000}}$$ would produce NaN instead of 0.999... or 1 (imagine $e^{1000}=\infty$ and $\frac{\infty}{\infty}$ explodes). --- js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 4 +++- js/web/test/data/ops/tanh.jsonc | 26 +++++++++++++++++++++ js/web/test/suite-test-list.jsonc | 1 + 3 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 js/web/test/data/ops/tanh.jsonc diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 82311d72e58b9..76929efb32537 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -279,7 +279,9 @@ export const tan = (context: ComputeContext): void => { }; export const tanh = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', 'tanh')); + // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`)); }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { diff --git a/js/web/test/data/ops/tanh.jsonc b/js/web/test/data/ops/tanh.jsonc new file mode 100644 index 0000000000000..f7691535bd71c --- /dev/null +++ b/js/web/test/data/ops/tanh.jsonc @@ -0,0 +1,26 @@ +[ + { + "name": "tanh with no attributes", + "operator": "Tanh", + "attributes": [], + "cases": [ + { + "name": "T[2,4]", + "inputs": [ + { + "data": [-1000, -1, 0, 0.1, 0.2, 0.3, 0.4, 1000], + "dims": [2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-1, -0.761594, 0, 0.099668, 0.197375, 0.291313, 0.379949, 1], + "dims": [2, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 373b3c645df57..56db28b0a379c 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1389,6 +1389,7 @@ "sub.jsonc", "sub_int32.jsonc", "tan.jsonc", + "tanh.jsonc", "tile.jsonc", "transpose.jsonc", "transpose_int32_uint32.jsonc", From f02accb99b2ef19bf64531f46be9222a1a0f1be6 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 26 Jan 2024 07:37:05 +0800 Subject: [PATCH 08/51] [js/webgpu] Support uniforms for conv, conv transpose, conv grouped (#18753) --- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 125 +++++++------ .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 154 ++++++++-------- .../ops/3rd-party/conv_backprop_webgpu.ts | 174 +++++++++++------- .../ops/3rd-party/matmul_packed_webgpu.ts | 108 +++++------ .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 86 +++++---- .../wasm/jsep/webgpu/ops/conv-transpose.ts | 15 +- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 18 +- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 39 ++-- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 43 +++-- 9 files changed, 418 insertions(+), 344 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 3638938df7dbe..1a03621512888 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -21,8 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvAttributes} from '../conv'; import {getActivationSnippet} from '../fuse-utils'; @@ -88,10 +88,10 @@ const conv2dCommonSnippet = let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; - let WRow = ${col} / (filterDims[1] * inChannels); - let WCol = ${col} / inChannels % filterDims[1]; - let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0]; - let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1]; + let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels); + let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]); + let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0]; + let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1]; let xCh = ${col} % inChannels; var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0); // The bounds checking is always needed since we use it to pad zero for @@ -108,7 +108,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : @@ -117,7 +117,7 @@ const conv2dCommonSnippet = ${readXSnippet}` : ` let col = colIn * ${innerElementSizeX}; - if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readXSnippet} } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); @@ -129,9 +129,8 @@ const conv2dCommonSnippet = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, resType); + const applyActivation = getActivationSnippet(attributes, resType); const userCode = ` - ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { ${isChannelsLast ? sampleX : sampleW} } @@ -142,7 +141,7 @@ const conv2dCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) { let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueIn; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; @@ -181,31 +180,46 @@ export const createConv2DMatMulProgramInfo = LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; - const tileAOuter = workGroupSize[1] * elementsPerThread[1]; const tileBOuter = workGroupSize[0] * elementsPerThread[0]; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); - const fitAOuter = dimAOuter % tileAOuter === 0; const fitBOuter = dimBOuter % tileBOuter === 0; const fitInner = dimInner % tileInner === 0; - const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; - const t = tensorTypeToWsglStorageType(inputs[0].dataType); - // TODO: support component 2, 3. - const components = isVec4 ? 4 : 1; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - const x = - inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); - const inputVariables = [x, w]; + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, + {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides}, + {type: 'int32', data: attributes.dilations} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, + {name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2}, + {name: 'dilation', type: 'i32', length: 2} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } - let declareFunctions = ` + // TODO: support component 2, 3. + const components = isVec4 ? 4 : 1; + const t = tensorTypeToWsglStorageType(inputs[0].dataType); + let declareFunctions = ` fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); } @@ -213,51 +227,50 @@ export const createConv2DMatMulProgramInfo = let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; - if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); - - programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - - declareFunctions += ` + const x = inputVariable( + 'x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); + const inputVariables = [x, w]; + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; - } - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - programUniforms.push(...createTensorShapeVariables(outputShape)); - return { - name: 'Conv2DMatMul', - shaderCache: {hint: attributes.cacheKey}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms, - }), - getShaderSource: (shaderHelper: ShaderHelper) => ` + } + + return ` ${utilFunctions('uniforms.result_strides')} //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4, // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2, // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; - ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .declareVariables(...inputVariables, output)} - const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); - const pad : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); - const stride : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} ${ conv2dCommonSnippet( isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1], elementsSize[2], t)} - ${ + ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined, - sequentialAccessByThreads)}` + sequentialAccessByThreads)}`; + }; + return { + name: 'Conv2DMatMul', + shaderCache: { + hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${ + tileAOuter};${tileBOuter};${tileInner}`, + inputDependencies + }, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms, + }), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index d425155857e14..33e50a9a39cb9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -21,8 +21,8 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; -import {ProgramInfo, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; import {getActivationSnippet} from '../fuse-utils'; @@ -74,21 +74,21 @@ const conv2dTransposeCommonSnippet = col % outWidth); `; - const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]'; - const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; + const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; + const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; const row = isChannelsLast ? 'row' : 'col'; const col = isChannelsLast ? 'col' : 'row'; const readASnippet = ` - let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; let outRow = ${row} / outWidth; let outCol = ${row} % outWidth; - let WRow = ${col} / (filterDims[1] * inChannels); - let WCol = ${col} / inChannels % filterDims[1]; - let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); - let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); + let WRow = ${col} / (uniforms.filter_dims[1] * inChannels); + let WCol = ${col} / inChannels % uniforms.filter_dims[1]; + let xR = f32(outRow - uniforms.pads[0] + uniforms.dilations[0] * WRow) / f32(uniforms.strides[0]); + let xC = f32(outCol - uniforms.pads[1] + uniforms.dilations[1] * WCol) / f32(uniforms.strides[1]); if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { return ${type}(0.0); } @@ -103,25 +103,25 @@ const conv2dTransposeCommonSnippet = const sampleA = isChannelsLast ? ` let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${readASnippet} } return ${type}(0.0);` : ` let col = colIn * ${innerElementSize}; - if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${readASnippet} } return ${type}(0.0);`; const sampleW = ` let col = colIn * ${innerElementSize}; - let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; - let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); - let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; + let coordX = uniforms.filter_dims[0] - 1 - row / (uniforms.filter_dims[1] * inChannels); + let coordY = uniforms.filter_dims[1] - 1 - (row / inChannels) % uniforms.filter_dims[1]; if (${ - isChannelsLast ? 'row < uniforms.dimInner && col < uniforms.dimBOuter' : - 'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) { + isChannelsLast ? 'row < uniforms.dim_inner && col < uniforms.dim_b_outer' : + 'row < uniforms.dim_inner && col < uniforms.dim_a_outer'} && coordX >= 0 && coordY >= 0) { let rowInner = row % inChannels; let coord = vec4(coordX, coordY, col, rowInner); ${getWSnippet(innerElementSize)} @@ -129,9 +129,8 @@ const conv2dTransposeCommonSnippet = return ${type}(0.0); `; - const {activationFunction, applyActivation} = getActivationSnippet(attributes, type); + const applyActivation = getActivationSnippet(attributes, type); const userCode = ` - ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleA : sampleW} } @@ -142,7 +141,7 @@ const conv2dTransposeCommonSnippet = fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { let col = colIn * ${innerElementSize}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueInput; let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; ${coordResSnippet} @@ -186,65 +185,64 @@ export const createConv2DTransposeMatMulProgramInfo = const innerElementSize = isVec4 ? 4 : 1; const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); const components = isVec4 ? 4 : 1; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); - const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); - const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); - const inputVariables = [x, w]; - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + const filterDims = + [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const effectiveFilterDims = [ + filterDims[0] + (attributes.dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + (attributes.dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (attributes.dilations[1] - 1)) + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2) + ]; - let declareFunctions = ''; + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, + {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations}, + {type: 'int32', data: filterDims}, {type: 'int32', data: pads} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); - inputVariables.push(bias); programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - - declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { - return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; - }`; + inputDependencies.push('rank'); } - programUniforms.push(...createTensorShapeVariables(outputShape)); - return { - name: 'Conv2DTransposeMatMul', - shaderCache: {hint: attributes.cacheKey}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, - programUniforms - }), - getShaderSource: (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); + const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const inputVariables = [x, w]; + + let declareFunctions = ''; + if (hasBias) { + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); + inputVariables.push(bias); + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; + }`; + } + + const uniforms: UniformsArrayType = [ + {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, + {name: 'strides', type: 'i32', length: 2}, {name: 'dilations', type: 'i32', length: 2}, + {name: 'filter_dims', type: 'i32', length: filterDims.length}, + {name: 'pads', type: 'i32', length: pads.length} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + return ` ${utilFunctions('uniforms.result_strides')} - ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .declareVariables(...inputVariables, output)}; - const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ - attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const effectiveFilterDims : vec2 = filterDims + vec2( - ${ - attributes.dilations[0] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, - ${ - attributes.dilations[1] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); - const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ - attributes.pads[0] + attributes.pads[2]})/2, - i32(effectiveFilterDims[1]) - 1 - (${ - attributes.pads[1] + attributes.pads[3]})/2); - const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const dimAOuter : i32 = ${dimAOuter}; - const dimBOuter : i32 = ${dimBOuter}; - const dimInner : i32 = ${dimInner}; + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; ${declareFunctions} ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} ${ @@ -252,6 +250,18 @@ export const createConv2DTransposeMatMulProgramInfo = elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, - undefined, sequentialAccessByThreads)}` + undefined, sequentialAccessByThreads)}`; + }; + + return { + name: 'Conv2DTransposeMatMul', + shaderCache: + {hint: `${attributes.cacheKey};${elementsPerThread};${workGroupSize};${isVec4}`, inputDependencies}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + programUniforms + }), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 50b0841a0200a..380efc8bc577a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -20,24 +20,18 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; -import {ProgramInfo} from '../../types'; -import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = - (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false, - dataType: string): string => { - const isChannelsLast = attributes.format === 'NHWC'; + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], outputShape: readonly number[], hasBias: boolean, + is1DimensionDispatch: boolean, isVec4 = false, dataType: string, uniforms: UniformsArrayType, + isChannelsLast = false): string => { const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; const channelDim = isChannelsLast ? 3 : 1; - const outputSize = ShapeUtil.size(outputShape); const workPerThread = isVec4 ? 2 : 1; - const group = attributes.group; - const wShape = inputs[1].dims; - const inputChannelsPerGroup = wShape[0] / group; - const outputChannelsPerGroup = wShape[1]; let declareFunctions = ` fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { @@ -50,20 +44,21 @@ const createConvTranspose2DOpProgramShaderSource = }`; } const components = isVec4 ? 4 : 1; - const w = inputVariable('W', inputs[1].dataType, inputs[1].dims, components); - const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims, components); + const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); + const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components); const inputVariables = [dy, w]; if (hasBias) { - inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]], components)); + inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); } - const output = outputVariable('result', inputs[0].dataType, outputShape, components); + const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); + const codeSnippet4 = `{ - let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / outShape[1]; - let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % outShape[1]; + let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1]; + let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1]; let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread}; let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4; - let dyCorner = vec2(i32(r), i32(c)) - vec2(pads); + let dyCorner = vec2(i32(r), i32(c)) - vec2(uniforms.pads); // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. @@ -71,29 +66,29 @@ const createConvTranspose2DOpProgramShaderSource = for (var i = 0; i < ${workPerThread}; i++) { dotProd[i] = vec4<${dataType}>(0.0); } - for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x); - let wRPerm = filterDims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) || + for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) { + var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x); + let wRPerm = uniforms.filter_dims[0] - 1 - wR; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); - for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y); - let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y); - let wCPerm = filterDims[1] - 1 - wC; + for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) { + let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims[1] - 1 - wC; if (wCPerm < 0) { continue; } var bDyCVal = true; var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) || + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) || fract(dyC) > 0.0) { bDyCVal = false; } - if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) || + if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) || fract(dyC2) > 0.0) { bDyCVal2 = false; } @@ -101,7 +96,7 @@ const createConvTranspose2DOpProgramShaderSource = let idyC: u32 = u32(dyC); let idyC2: u32 = u32(dyC2); if (bDyCVal && bDyCVal2) { - let d2Length = outBackprop[3]; + let d2Length = uniforms.Dy_shape[3]; for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -123,7 +118,7 @@ const createConvTranspose2DOpProgramShaderSource = dot(xValue, wValue3)); } } else if (bDyCVal) { - let d2Length = outBackprop[${channelDim}]; + let d2Length = uniforms.Dy_shape[${channelDim}]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -138,7 +133,7 @@ const createConvTranspose2DOpProgramShaderSource = dotProd[0] = dotProd[0] + tmpval; } } else if (bDyCVal2) { - let d2Length = outBackprop[3]; + let d2Length = uniforms.Dy_shape[3]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; @@ -167,39 +162,39 @@ const createConvTranspose2DOpProgramShaderSource = let d1 = ${output.indicesGet('outputIndices', channelDim)}; let r = ${output.indicesGet('outputIndices', rowDim)}; let c = ${output.indicesGet('outputIndices', colDim)}; - let dyCorner = vec2(i32(r), i32(c)) - pads; + let dyCorner = vec2(i32(r), i32(c)) - uniforms.pads; let dyRCorner = dyCorner.x; let dyCCorner = dyCorner.y; - let groupId = d1 / ${outputChannelsPerGroup}; - let wOutChannel = d1 - groupId * ${outputChannelsPerGroup}; + let groupId = d1 / uniforms.output_channels_per_group; + let wOutChannel = d1 - groupId * uniforms.output_channels_per_group; // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. var dotProd = ${dataType}(0.0); - for (var wR: u32 = 0; wR < effectiveFilterDims.x; wR = wR + 1) { - if (wR % dilations.x != 0) { + for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { + if (wR % uniforms.dilations.x != 0) { continue; } - let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]); - let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]); + let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; + if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); - for (var wC: u32 = 0; wC < effectiveFilterDims.y; wC = wC + 1) { - if (wC % dilations.y != 0) { + for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { + if (wC % uniforms.dilations.y != 0) { continue; } - let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y); - let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) || + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); + let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; + if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } let idyC: u32 = u32(dyC); - var inputChannel = groupId * ${inputChannelsPerGroup}; - for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { + var inputChannel = groupId * uniforms.input_channels_per_group; + for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { let xValue = ${ isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; @@ -214,27 +209,11 @@ const createConvTranspose2DOpProgramShaderSource = `; return ` - ${shaderHelper.declareVariables(...inputVariables, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${declareFunctions} - const outShape : vec4 = vec4(${outputShape.join(',')}); - const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); - const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); - const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ - attributes.kernelShape[isChannelsLast ? 2 : 3]}); - const dilations : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); - const effectiveFilterDims : vec2 = filterDims + vec2( - ${ - attributes.dilations[0] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, - ${ - attributes.dilations[1] <= 1 ? - 0 : - (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); - const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2, - i32(effectiveFilterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2); + ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; ${isVec4 ? codeSnippet4 : codeSnippet}}`; }; @@ -257,19 +236,72 @@ export const createConvTranspose2DProgramInfo = ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const isChannelsLast = attributes.format === 'NHWC'; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + const strides = [attributes.strides[0], attributes.strides[1]]; + const filterDims = + [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; + const dilations = [attributes.dilations[0], attributes.dilations[1]]; + const effectiveFilterDims = [ + filterDims[0] + + (attributes.dilations[0] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)), + filterDims[1] + + (attributes.dilations[1] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)) + ]; + const pads = [ + effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), + effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2 + ]; + + const isVec4 = false; + const group = attributes.group; + const wShape = inputs[1].dims; + const inputChannelsPerGroup = wShape[0] / group; + const outputChannelsPerGroup = wShape[1]; + + const programUniforms: ProgramUniform[] = [ + {type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims}, + {type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads}, + {type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup}, + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims) + ]; + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); + + const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'strides', type: 'u32', length: strides.length}, + {name: 'filter_dims', type: 'u32', length: filterDims.length}, + {name: 'dilations', type: 'u32', length: filterDims.length}, + {name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length}, + {name: 'pads', type: 'i32', length: pads.length}, {name: 'input_channels_per_group', type: 'u32'}, + {name: 'output_channels_per_group', type: 'u32'} + ]; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return `${ + createConvTranspose2DOpProgramShaderSource( + shaderHelper, inputs, outputShape, hasBias, is1DimensionDispatch, isVec4, dataType, uniforms, + isChannelsLast)}`; + }; return { name: 'ConvTranspose2D', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: `${attributes.cacheKey};`, inputDependencies}, getRunData: () => ({ dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, outputs: [{ dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, dataType: inputs[0].dataType - }] + }], + programUniforms }), - getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false, - dataType), + getShaderSource }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 47ec16a296712..ee71110245252 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -22,7 +22,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; +import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -112,14 +112,14 @@ fn main(@builtin(local_invocation_id) localId : vec3, ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + let num_tiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc: array, rowPerThread>; // Loop over shared dimension. let tileRowB = localRow * ${rowPerThreadB}; - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { let inputRow = tileRow + innerRow; @@ -204,7 +204,7 @@ export const makeMatMulPackedSource = let globalColStart = i32(workgroupId.x) * ${tileBOuter}; // Loop over shared dimension. - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var inputRow = localRow; inputRow < ${tileAHight}; inputRow = inputRow + ${workgroupSize[1]}) { for (var inputCol = localCol; inputCol < ${tileAWidth}; inputCol = inputCol + ${workgroupSize[0]}) { @@ -260,7 +260,7 @@ let tileRowA = i32(localId.y) * ${rowPerThreadA}; let tileColA = i32(localId.x) * ${colPerThreadA}; let tileRowB = i32(localId.y) * ${rowPerThreadB}; // Loop over shared dimension. -for (var t = 0; t < numTiles; t = t + 1) { +for (var t = 0; t < num_tiles; t = t + 1) { // Load one tile of A into local memory. for (var innerRow = 0; innerRow < ${rowPerThreadA}; innerRow = innerRow + 1) { for (var innerCol = 0; innerCol < ${colPerThreadA}; innerCol = innerCol + 1) { @@ -322,7 +322,8 @@ fn main(@builtin(local_invocation_id) localId : vec3, @builtin(workgroup_id) workgroupId : vec3) { let batch = ${splitK ? '0' : 'i32(globalId.z)'}; ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} - let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + let num_tiles = ${ + splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; var acc : array, rowPerThread>; @@ -379,7 +380,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < uniforms.dimAOuter && col < uniforms.dimInner) + if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ${getAIndices()} value = ${aVariable.getByIndices('aIndices')}; @@ -391,7 +392,7 @@ const matMulReadWriteFnSource = typeSnippet(component, dataType)} { var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; - if(row < uniforms.dimInner && col < uniforms.dimBOuter) + if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${getBIndices()} value = ${bVariable.getByIndices('bIndices')}; @@ -401,7 +402,7 @@ const matMulReadWriteFnSource = fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; - if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { var value = valueIn; let coords = vec3(batch, row, colIn); ${ @@ -422,16 +423,10 @@ export const createMatmulProgramInfo = isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { const aShape = inputs[0].dims; const bShape = inputs[1].dims; - const outerDimsA = aShape.slice(0, -2); const outerDimsB = bShape.slice(0, -2); - const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); - const enableBatchUniforms = enableShapesUniforms(outerDims.length); - const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims; - const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); const batchSize = ShapeUtil.size(outerDims); - const dimAOuter = aShape[aShape.length - 2]; const dimInner = aShape[aShape.length - 1]; const dimBOuter = bShape[bShape.length - 1]; @@ -446,72 +441,67 @@ export const createMatmulProgramInfo = Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) ]; - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; - const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; - const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length); - const aShapeOrRank = enableAShapesUniforms ? aShapeTemp.length : aShapeTemp; - + const aShapeOrRank = aShapeTemp.length; const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; - const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length); - const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp; - + const bShapeOrRank = bShapeTemp.length; const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; - - const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); - const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); - const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); - const inputVariables = [A, B]; const programUniforms: ProgramUniform[] = [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - if (enableBatchUniforms) { - programUniforms.push(...createTensorShapeVariables(outerDims)); + if (activationAttributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: activationAttributes.clipMax!}, + {type: 'float32', data: activationAttributes.clipMin!}); } - if (enableAShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(aShapeTemp)); - } - if (enableBShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(bShapeTemp)); - } - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims'); - inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims'); + programUniforms.push( + ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), + ...createTensorShapeVariables(bShapeTemp)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const hasBias = inputs.length > 2; - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); - const declareFunctions = matMulReadWriteFnSource( - components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], - isChannelsLast); if (hasBias) { - const biasComponents = isChannelsLast ? components : 1; - inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); - inputDependencies.push('rank'); } programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => { + const batchShapeOrRank = outerDims.length; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + + const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); + const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); + const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); + const inputVariables = [A, B]; + if (hasBias) { + const biasComponents = isChannelsLast ? components : 1; + inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + } + const uniforms: UniformsArrayType = + [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; + if (activationAttributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + const applyActivation = getActivationSnippet(activationAttributes, output.type.value); + const declareFunctions = matMulReadWriteFnSource( + components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], + isChannelsLast); + return ` ${ - shaderHelper.registerUniform('dimAOuter', 'i32') - .registerUniform('dimBOuter', 'i32') - .registerUniform('dimInner', 'i32') - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${activationFunction} + shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( + ...inputVariables, output)} ${declareFunctions} ${ - isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : - makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} + isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : + makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} `; - // TODO: turn clipMax and clipMin to uniforms. + }; return { name: 'MatMul', shaderCache: { - hint: activationAttributes.activationCacheKey + `${elementsPerThread}` + - `${isVec4}` + - `${isChannelsLast}`, + hint: `${elementsPerThread};${activationAttributes.activation};${isVec4};${isChannelsLast}`, inputDependencies }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 21b4953d3f90c..f81d6577890c5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -3,9 +3,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {ProgramInfo, ProgramUniform} from '../types'; +import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; import {getActivationSnippet} from './fuse-utils'; @@ -27,52 +27,75 @@ export const createGroupedConvProgramInfo = xShape, wShape, attributes.dilations, attributes.pads, attributes.strides, isChannelLast); const outputSize = ShapeUtil.size(outputShape); - const output = outputVariable('output', inputs[0].dataType, outputShape); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); - const x = inputVariable('x', inputs[0].dataType, xShape); - const w = inputVariable('w', inputs[1].dataType, wShape); - const inputVars = [x, w]; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations}, + {type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]}, + {type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup} + ]; + if (attributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), + ...createTensorShapeVariables(outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { - inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + inputDependencies.push('rank'); } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const strides: vec2 = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u); - const pads: vec2 = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u); - - ${shaderHelper.declareVariables(...inputVars, output)} + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const applyActivation = getActivationSnippet(attributes, output.type.value); + const x = inputVariable('x', inputs[0].dataType, xShape.length); + const w = inputVariable('w', inputs[1].dataType, wShape.length); + const inputVars = [x, w]; + if (hasBias) { + inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims)); + } - ${activationFunction} + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'dilations', type: 'u32', length: attributes.dilations.length}, + {name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2}, + {name: 'output_channels_per_group', type: 'u32'} + ]; + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let outputIndices = ${output.offsetToIndices('global_idx')}; let batch: u32 = outputIndices[0]; let output_channel: u32 = outputIndices[${isChannelLast ? 3 : 1}]; let xRCCorner: vec2 = vec2(outputIndices[${isChannelLast ? 1 : 2}], outputIndices[${ - isChannelLast ? 2 : 3}]) * strides - pads; - let group_id: u32 = output_channel / ${outputChannelsPerGroup}u; + isChannelLast ? 2 : 3}]) * uniforms.strides - uniforms.pads; + let group_id: u32 = output_channel / uniforms.output_channels_per_group; var value: ${output.type.value} = ${output.type.value}(0); - for (var wInChannel: u32 = 0u; wInChannel < ${wShape[1]}u; wInChannel++) { - let input_channel = group_id * ${wShape[1]}u + wInChannel; - for (var wHeight: u32 = 0u; wHeight < ${wShape[2]}u; wHeight++) { - let xHeight = xRCCorner.x + wHeight * ${attributes.dilations[0]}u; + for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) { + let input_channel = group_id * uniforms.w_shape[1] + wInChannel; + for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) { + let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0]; - if (xHeight < 0u || xHeight >= ${xShape[isChannelLast ? 1 : 2]}u) { + if (xHeight < 0u || xHeight >= uniforms.x_shape[${isChannelLast ? 1 : 2}]) { continue; } - for (var wWidth: u32 = 0u; wWidth < ${wShape[3]}u; wWidth++) { - let xWidth = xRCCorner.y + wWidth * ${attributes.dilations[1]}u; - if (xWidth < 0u || xWidth >= ${xShape[isChannelLast ? 2 : 3]}u) { + for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) { + let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1]; + if (xWidth < 0u || xWidth >= uniforms.x_shape[${isChannelLast ? 2 : 3}]) { continue; } let xVal = ${ - isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : - x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; + isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : + x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')}; value += xVal*wVal; } @@ -82,15 +105,17 @@ export const createGroupedConvProgramInfo = ${applyActivation} ${output.setByOffset('global_idx', 'value')} }`; + }; return { name: 'GroupedConv', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: attributes.cacheKey, inputDependencies}, getRunData: () => ({ outputs: [{ dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, dataType: inputs[0].dataType }], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }; @@ -114,7 +139,7 @@ export const createGroupedConvVectorizeProgramInfo = const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); + const applyActivation = getActivationSnippet(attributes, output.type.value); const x = inputVariable('x', inputs[0].dataType, xShape.length, components); const w = inputVariable('w', inputs[1].dataType, wShape.length, components); const inputVars = [x, w]; @@ -129,7 +154,6 @@ export const createGroupedConvVectorizeProgramInfo = .registerUniform('strides', 'i32', 2) .registerUniform('pads', 'i32', 2) .declareVariables(...inputVars, output)} - ${activationFunction} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let width0 = uniforms.output_shape[3]; @@ -179,7 +203,7 @@ export const createGroupedConvVectorizeProgramInfo = return { name: 'GroupedConv-Vectorize', shaderCache: { - hint: `${attributes.activationCacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, + hint: `${attributes.cacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank'] }, getRunData: () => ({ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 32b1d52ed94ca..33d16754c737a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -2,7 +2,6 @@ // Licensed under the MIT License. import {TensorView} from '../../tensor-view'; -import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; @@ -59,7 +58,6 @@ export interface ConvTransposeAttributes extends ConvAttributes { readonly outputShape: readonly number[]; } - const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); @@ -96,11 +94,7 @@ const getAdjustedConvTransposeAttributes = // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - const cacheKey = attributes.cacheKey + [ - kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','), - dilations.join(',') - ].join('_'); - Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey}); + Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides}); return newAttributes; }; @@ -119,7 +113,7 @@ export const parseConvTransposeAttributes = (attributes: Record const wIsConst = (attributes.wIsConst as () => boolean)(); const outputPadding = attributes.outputPadding as [number, number, number, number]; const outputShape = attributes.outputShape as [number, number]; - return createAttributeWithCacheKey({ + return { autoPad, format, dilations, @@ -130,8 +124,9 @@ export const parseConvTransposeAttributes = (attributes: Record pads, strides, wIsConst, - ...activationAttributes - }); + ...activationAttributes, + cacheKey: `${attributes.format};${activationAttributes.activation};` + }; }; const validateInputs = (inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 7af2c5db49f40..5afec0389fac8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -3,7 +3,7 @@ import {TensorView} from '../../tensor-view'; import {PoolConvUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {AttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext} from '../types'; import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; @@ -110,7 +110,7 @@ const getAdjustedConvAttributes = (attributes: T, inpu // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign(newAttributes, {kernelShape, pads, cacheKey: attributes.cacheKey}); + Object.assign(newAttributes, {kernelShape, pads}); return newAttributes; }; @@ -126,8 +126,18 @@ export const parseConvAttributes = (attributes: Record): ConvAt const strides = attributes.strides as [number, number]; const wIsConst = (attributes.w_is_const as () => boolean)(); - return createAttributeWithCacheKey( - {autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes}); + return { + autoPad, + format, + dilations, + group, + kernelShape, + pads, + strides, + wIsConst, + ...activationAttributes, + cacheKey: `${attributes.format};${activationAttributes.activation};` + }; }; const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 0b5c0db2b5112..2e0aa33a957dc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -7,30 +7,21 @@ export interface InternalActivationAttributes { readonly activation: string; readonly clipMin?: number; readonly clipMax?: number; - readonly activationCacheKey: string; } -export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): - {activationFunction: string; applyActivation: string} => { - switch (attributes.activation) { - case 'Relu': - return {activationFunction: '', applyActivation: `value = max(value, ${valueType}(0.0));`}; - case 'Sigmoid': - return { - activationFunction: '', - applyActivation: `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));` - }; - case 'Clip': - return { - activationFunction: `const clip_min_=${valueType}(${attributes.clipMin!});const clip_max_=${valueType}(${ - attributes.clipMax!});`, - applyActivation: 'value = clamp(value, clip_min_, clip_max_);' - }; - // TODO: adding other activations that can be fused. - default: - return {activationFunction: '', applyActivation: ''}; - } - }; +export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => { + switch (attributes.activation) { + case 'Relu': + return `value = max(value, ${valueType}(0.0));`; + case 'Sigmoid': + return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; + case 'Clip': + return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`; + // TODO: adding other activations that can be fused. + default: + return ''; + } +}; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { @@ -38,7 +29,7 @@ export const parseInternalActivationAttributes = if (activation === 'Clip') { const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; - return {activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}`}; + return {activation, clipMax, clipMin}; } - return {activation, activationCacheKey: activation}; + return {activation}; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index de9309d1e436f..c946ea6366123 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -6,7 +6,7 @@ import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper,} from './common'; +import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common'; import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; export const createNaiveMatmulProgramInfo = @@ -27,11 +27,19 @@ export const createNaiveMatmulProgramInfo = const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); const batchSize = ShapeUtil.size(outerDims); const outputShapeInShader = [batchSize, M, N]; + const programUniforms: ProgramUniform[] = [ {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, - {type: 'uint32', data: K}, ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), - ...createTensorShapeVariables(bShape) + {type: 'uint32', data: K} ]; + if (activationAttributes.activation === 'Clip') { + programUniforms.push( + {type: 'float32', data: activationAttributes.clipMax!}, + {type: 'float32', data: activationAttributes.clipMin!}); + } + programUniforms.push( + ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), + ...createTensorShapeVariables(bShape)); if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); } @@ -42,7 +50,7 @@ export const createNaiveMatmulProgramInfo = const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); const b = inputVariable('b', inputs[1].dataType, bShape.length, components); const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value); const inputVariables = [a, b]; let processBias = ''; if (hasBias) { @@ -57,6 +65,14 @@ export const createNaiveMatmulProgramInfo = const outerDimsB = bShape.slice(0, -2); const broadCastADims = getBroadcastDims(outerDimsA, outerDims); const broadCastBDims = getBroadcastDims(outerDimsB, outerDims); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, + {name: 'K', type: 'u32'} + ]; + if (activationAttributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } + const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { const rank = variable.rank; const name = variable.name; @@ -96,15 +112,10 @@ export const createNaiveMatmulProgramInfo = return ` ${ - shaderHelper.registerUniform('outputSize', 'u32') - .registerUniform('M', 'u32') - .registerUniform('N', 'u32') - .registerUniform('K', 'u32') - .registerInternalVariables(batchDims) - .declareVariables(...inputVariables, output)} - ${activationFunction} + shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( + ...inputVariables, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let col = (global_idx % (uniforms.N / ${components})) * ${components}; var index1 = global_idx / (uniforms.N / ${components}); let stride1 = uniforms.M / ${outputNumber}; @@ -134,8 +145,7 @@ export const createNaiveMatmulProgramInfo = return { name: 'MatMulNaive', shaderCache: { - hint: `${activationAttributes.activationCacheKey}_${components}_${aComponents}_${outputNumber}_${ - isChannelsLast}`, + hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`, inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'] }, getRunData: () => ({ @@ -166,9 +176,8 @@ export const matMul = (context: ComputeContext): void => { const N = outputShape[outputShape.length - 1]; const K = context.inputs[0].dims[context.inputs[0].dims.length - 1]; if (N < 8 && K < 8) { - context.compute( - createNaiveMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + context.compute(createNaiveMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); } else { - context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + context.compute(createMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); } }; From 3abc3db2b77dd46f0268ba1e20a072df46395cb8 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 26 Jan 2024 08:58:22 +0800 Subject: [PATCH 09/51] [js/webgpu] Support f16 uniform (#19098) ### Description ### Motivation and Context --- js/web/lib/wasm/jsep/backend-webgpu.ts | 26 +++++++++--- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 40 +++++++++++++------ js/web/lib/wasm/jsep/webgpu/ops/pad.ts | 4 +- js/web/lib/wasm/jsep/webgpu/types.ts | 2 +- .../core/providers/js/operators/pad.cc | 10 ++--- 5 files changed, 56 insertions(+), 26 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 8ca025d66550c..a48fe99570abf 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -428,13 +428,26 @@ export class WebGpuBackend { return; } // https://www.w3.org/TR/WGSL/#alignof - const baseAlignment = data.length <= 2 ? data.length * 4 : 16; + const sizeOfElement = v.type === 'float16' ? 2 : 4; + let sizeOfVecOrMat; + let baseAlignment; + if (v.type === 'float16') { + baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement); + sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length; + } else { + baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16; + sizeOfVecOrMat = 16; + } currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment; offsets.push(currentOffset); - // When data.length > 4, the uniform variable is of type array,N>, where N = - // Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * - // SizeOf(vec4). - currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4; + // For non-float16 type, when data.length > 4, the uniform variable is of type array,N>, where + // N = Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * + // SizeOf(vec4). For float16 type, when data.length > 4, the uniform variable is of type + // array,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte + // length is N * SizeOf(mat2x4). + const elementPerVecOrMat = v.type === 'float16' ? 8 : 4; + currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : + data.length * sizeOfElement; }); // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set @@ -449,6 +462,9 @@ export class WebGpuBackend { new Int32Array(arrayBuffer, offset, data.length).set(data); } else if (v.type === 'uint32') { new Uint32Array(arrayBuffer, offset, data.length).set(data); + } else if (v.type === 'float16') { + // TODO: use Float16Array. + new Uint16Array(arrayBuffer, offset, data.length).set(data); } else { new Float32Array(arrayBuffer, offset, data.length).set(data); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index bc3265be955f0..643744108c0f4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -330,18 +330,28 @@ export const sumVector = (name: string, components: number) => { * @param name - the name of variable. * @param index - the index of variable element. * @param length - the length of variable. + * @param type - the type of variable, optional. */ -export const getElementAt = (name: string, index: number|string, length: number): string => { - if (name.startsWith('uniforms.') && length > 4) { - if (typeof (index) === 'string') { - return `${name}[(${index}) / 4][(${index}) % 4]`; - } else { - return `${name}[${Math.floor(index / 4)}][${index % 4}]`; - } - } else { - return length > 1 ? `${name}[${index}]` : name; - } -}; +export const getElementAt = + (name: string, index: number|string, length: number, type?: UniformDataElementType): string => { + if (name.startsWith('uniforms.') && length > 4) { + if (typeof (index) === 'string') { + if (type === 'f16') { + return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`; + } else { + return `${name}[(${index}) / 4][(${index}) % 4]`; + } + } else { + if (type === 'f16') { + return `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]`; + } else { + return `${name}[${Math.floor(index / 4)}][${index % 4}]`; + } + } + } else { + return length > 1 ? `${name}[${index}]` : name; + } + }; /** * A helper function to get a IndicesHelper for a given input or output. @@ -688,7 +698,7 @@ export const internalVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'internal', components); -export type UniformDataElementType = 'u32'|'f32'|'i32'; +export type UniformDataElementType = 'u32'|'f16'|'f32'|'i32'; export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>; /** @@ -861,7 +871,11 @@ class ShaderHelperImpl implements ShaderHelper { const uniformSnippets: string[] = []; for (const {name, type, length} of this.uniforms) { if (length && length > 4) { - uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + if (type === 'f16') { + uniformSnippets.push(`@align(16) ${name}:array, ${Math.ceil(length / 8)}>`); + } else { + uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + } } else { const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`; uniformSnippets.push(`${name}:${typeTemp}`); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index eca3fa7d944bb..c65b741e1105a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -19,8 +19,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length < 1) { throw new Error('Too few inputs'); } - if (inputs[0].dataType !== DataType.float) { - throw new Error('Input type must be float.'); + if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16) { + throw new Error('Input type must be float or float16.'); } if (inputs.length >= 2) { diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index e55bfb6ba9f16..789ac70a6913a 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -24,7 +24,7 @@ export interface TensorInfo { } export interface ProgramUniform { - type: 'int32'|'float32'|'uint32'; + type: 'int32'|'float16'|'float32'|'uint32'; data: number|readonly number[]; } diff --git a/onnxruntime/core/providers/js/operators/pad.cc b/onnxruntime/core/providers/js/operators/pad.cc index 24ba85cbf6e0d..83fee35481aa6 100644 --- a/onnxruntime/core/providers/js/operators/pad.cc +++ b/onnxruntime/core/providers/js/operators/pad.cc @@ -14,7 +14,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 2, 10, kJsExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Pad); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -24,7 +24,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -37,7 +37,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 17, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -50,7 +50,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 18, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), @@ -62,7 +62,7 @@ ONNX_OPERATOR_KERNEL_EX( 19, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", JsepSupportedFloatTypes()) .InputMemoryType(OrtMemTypeCPU, 1) .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3), From 5ef244a47937e693a2d5fe1650a123da1e261f38 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 29 Jan 2024 10:13:46 -0800 Subject: [PATCH 10/51] fix f16 for attention, enable slice and flatten for more types (#19262) --- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 2 +- onnxruntime/core/providers/js/operators/flatten.cc | 8 ++++---- onnxruntime/core/providers/js/operators/slice.cc | 12 ++++-------- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index ef8038dff487e..f07a21a343fa8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -297,7 +297,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView if (sum == 0) { for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { - x[offset + i] = ${fillVector('f32', components, 'uniforms.d_inv')}; + x[offset + i] = ${fillVector(elemValueType, components, 'uniforms.d_inv')}; } } else { for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { diff --git a/onnxruntime/core/providers/js/operators/flatten.cc b/onnxruntime/core/providers/js/operators/flatten.cc index 7e4b4c350951b..1aacae819e304 100644 --- a/onnxruntime/core/providers/js/operators/flatten.cc +++ b/onnxruntime/core/providers/js/operators/flatten.cc @@ -13,7 +13,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", JsepSupportedFloatTypes()), Flatten); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -23,7 +23,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", JsepSupportedFloatTypes()), Flatten); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -33,7 +33,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", JsepSupportedFloatTypes()), Flatten); ONNX_OPERATOR_KERNEL_EX( @@ -43,7 +43,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", JsepSupportedFloatTypes()), Flatten); } // namespace js diff --git a/onnxruntime/core/providers/js/operators/slice.cc b/onnxruntime/core/providers/js/operators/slice.cc index bbafe40ea92ac..869b5450501e1 100644 --- a/onnxruntime/core/providers/js/operators/slice.cc +++ b/onnxruntime/core/providers/js/operators/slice.cc @@ -12,8 +12,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 1, 9, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + .TypeConstraint("T", JsepSupportedDataTypes()), Slice_1); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -26,8 +25,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3) .InputMemoryType(OrtMemTypeCPU, 4) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + .TypeConstraint("T", JsepSupportedDataTypes()), Slice); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -40,8 +38,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3) .InputMemoryType(OrtMemTypeCPU, 4) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + .TypeConstraint("T", JsepSupportedDataTypes()), Slice); ONNX_OPERATOR_KERNEL_EX( @@ -54,8 +51,7 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3) .InputMemoryType(OrtMemTypeCPU, 4) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + .TypeConstraint("T", JsepSupportedDataTypes()), Slice); } // namespace js From 55cede94bc8caca5cca12775fdff9db78e8a492b Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 30 Jan 2024 09:49:06 +0800 Subject: [PATCH 11/51] [js/webgpu] Remove enableShapesUniforms (#19279) --- .../ops/3rd-party/matmul_packed_webgpu.ts | 12 +++--- js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 37 +++++++----------- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 3 -- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 26 ++++--------- js/web/lib/wasm/jsep/webgpu/ops/einsum.ts | 31 +++++---------- js/web/lib/wasm/jsep/webgpu/ops/expand.ts | 25 ++++-------- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 39 ++++++------------- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 25 +++++------- 9 files changed, 68 insertions(+), 134 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index ee71110245252..5881c055ef135 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -443,9 +443,9 @@ export const createMatmulProgramInfo = const components = isVec4 ? 4 : 1; const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; - const aShapeOrRank = aShapeTemp.length; + const aRank = aShapeTemp.length; const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; - const bShapeOrRank = bShapeTemp.length; + const bRank = bShapeTemp.length; const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; const programUniforms: ProgramUniform[] = [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; @@ -467,12 +467,12 @@ export const createMatmulProgramInfo = programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); const getShaderSource = (shaderHelper: ShaderHelper) => { - const batchShapeOrRank = outerDims.length; - const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); + const batchRank = outerDims.length; + const batchDims = internalVariable('batchDims', inputs[0].dataType, batchRank, 1); const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); - const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); + const A = inputVariable('a', inputs[0].dataType, aRank, components); + const B = inputVariable('b', inputs[1].dataType, bRank, components); const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); const inputVariables = [A, B]; if (hasBias) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts index 00a6ca75b34fa..159b971636765 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts @@ -8,7 +8,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; export interface BatchNormAttributes extends AttributeWithCacheKey { readonly epsilon: number; @@ -61,7 +61,7 @@ const createBatchNormInferenceProgramInfo = const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1; const outputSize = ShapeUtil.size(yShape) / components; // Only support uniforms for opset version >= 9 (spatial = true). - const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial; + const useShapesUniforms = spatial; const shapeOrRank = useShapesUniforms ? yShape.length : yShape; const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components); const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index c033c0ba05356..8e144a36dc1b0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; type BuiltinFunctionName = string; type BinaryCustomExpression = (expressionA: string, expressionB: string) => string; @@ -18,8 +18,7 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{ const createBinaryOpProgramShader = (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[], vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall, - typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean, - additionalImplementation?: string) => { + typeA: number, typeB: number, typeOutput: number, additionalImplementation?: string) => { let expressionScalar: BinaryCustomExpression; let expressionVector: BinaryCustomExpression; if (typeof funcCall === 'string') { @@ -31,12 +30,9 @@ const createBinaryOpProgramShader = expressionVector = funcCall.vector; } - const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA; - const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB; - const outputShapeOrRank = useShapesUniforms ? dimsOutput.length : dimsOutput; - const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4); - const a = inputVariable('aData', typeA, inputAShapeOrRank, 4); - const b = inputVariable('bData', typeB, inputBShapeOrRank, 4); + const output = outputVariable('outputData', typeOutput, dimsOutput.length, 4); + const a = inputVariable('aData', typeA, dimsA.length, 4); + const b = inputVariable('bData', typeB, dimsB.length, 4); let assignment: string; if (vectorize) { @@ -169,30 +165,25 @@ const createBinaryOpProgramInfo = vectorize = true; } cacheKeyAux.push(vectorize); - const useShapesUniforms = enableShapesUniforms(a.dims.length) && enableShapesUniforms(b.dims.length) && - enableShapesUniforms(outputShape.length); + return { name, shaderCache: { hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'), - inputDependencies: useShapesUniforms ? ['rank', 'rank'] : ['dims', 'dims'], + inputDependencies: ['rank', 'rank'], }, getShaderSource: (shaderHelper) => createBinaryOpProgramShader( shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall, - a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation), + a.dataType, b.dataType, outputDataType, additionalImplementation), getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, - programUniforms: useShapesUniforms ? - [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, - ...createTensorShapeVariables(a.dims), - ...createTensorShapeVariables(b.dims), - ...createTensorShapeVariables(outputShape), - ] : - [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, - ], + programUniforms: [ + {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, + ...createTensorShapeVariables(a.dims), + ...createTensorShapeVariables(b.dims), + ...createTensorShapeVariables(outputShape), + ], }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 643744108c0f4..1bedf31ee4e38 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -922,6 +922,3 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly } return dims; }; - -// TODO: remove this when all related uses have been removed. -export const enableShapesUniforms = (_rank: number): boolean => true; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 43cc4a4c080bd..daa326b1a34e2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; @@ -94,32 +94,22 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P let previousSum = 0; const inputDependencies: ProgramInputTensorInfoDependency[] = []; - const inputShapeOrRanks = []; - const enableInputShapesUniforms = []; + const inputRanks = []; const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; for (let i = 0; i < inputs.length; ++i) { previousSum += inputs[i].dims[adjustedAxis]; sizeInConcatAxis[i] = previousSum; - enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length)); - inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims); - inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]); - inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims'); + inputRanks.push(inputs[i].dims.length); + inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); + inputDependencies.push('rank'); programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]}); } for (let i = 0; i < inputs.length; ++i) { - if (enableInputShapesUniforms[i]) { - programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); - } - } - - const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); - if (enableOutputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); } + programUniforms.push(...createTensorShapeVariables(outputShape)); - const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; - const output = outputVariable('output', dataType, outputShapeOrRank); - + const output = outputVariable('output', dataType, outputShape.length); const indicesAxis = output.indicesGet('indices', adjustedAxis); const sizeInConcatAxisStr = Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index 4db7c04ad67be..9e1f58bbfa127 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -6,8 +6,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; - +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; export interface EinsumAttributes extends AttributeWithCacheKey { readonly equation: string; @@ -181,14 +180,12 @@ class EinsumEquation { const appendMax = (name: string): string => name + '_max'; const createEinsumProgramInfo = - (enableInputShapesUniforms: readonly boolean[], inputShapes: Array, dataType: number, - einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => { - const shapeOrRanks = inputShapes.map((dims, index) => enableInputShapesUniforms[index] ? dims.length : dims); - const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank)); + (inputShapes: Array, dataType: number, einsumEquation: EinsumEquation, + outputShape: readonly number[]): ProgramInfo => { + const ranks = inputShapes.map((dims) => dims.length); + const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank)); const outputSize = ShapeUtil.size(outputShape); - const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); - const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; - const output = outputVariable('output', dataType, outputShapeOrRank); + const output = outputVariable('output', dataType, outputShape.length); const uniformsSymbols = [...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol)); const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -269,10 +266,7 @@ const createEinsumProgramInfo = }; return { name: 'Einsum', - shaderCache: { - hint: einsumEquation.equation, - inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims') - }, + shaderCache: {hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank')}, getRunData: () => { // The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The // filter is added to make sure that dimValue is never 0. @@ -281,12 +275,9 @@ const createEinsumProgramInfo = .map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); programUniformsInit.push({type: 'uint32', data: outputSize}); const programUniforms: ProgramUniform[] = - inputShapes.filter((_, index) => enableInputShapesUniforms[index]) - .map((dims, _) => [...createTensorShapeVariables(dims)]) + inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)]) .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit); - if (enableOutputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } + programUniforms.push(...createTensorShapeVariables(outputShape)); return ({ outputs: [{dims: outputShape, dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, @@ -299,11 +290,9 @@ const createEinsumProgramInfo = export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => { const einsumEquation = new EinsumEquation(context.inputs, attributes.equation); - const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length)); const outputShape = einsumEquation.outputDims; const inputShapes = context.inputs.map((input, _) => input.dims); - context.compute(createEinsumProgramInfo( - enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape)); + context.compute(createEinsumProgramInfo(inputShapes, context.inputs[0].dataType, einsumEquation, outputShape)); }; export const parseEinsumAttributes = (attributes: Record): EinsumAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index 035d89755c7d7..dd18bd23a5912 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 2) { @@ -49,15 +49,9 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const components = dataType === DataType.bool ? 4 : 1; const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); - const enableInputShapeUniform = enableShapesUniforms(inputShape.length); - const enableOutputShapeUniform = enableShapesUniforms(outputShape.length); - - const getShaderSource = (shaderHelper: ShaderHelper) => { - const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape; - const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape; - const input = inputVariable('input', dataType, inputShapeOrRank, components); - const output = outputVariable('output', dataType, outputShapeOrRank, components); + const input = inputVariable('input', dataType, inputShape.length, components); + const output = outputVariable('output', dataType, outputShape.length, components); let assignment: string; if (dataType === DataType.bool) { const singleAssignment = (resStr: string, x: number, typeCast = '') => ` @@ -90,16 +84,13 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => ${assignment}`; }; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; - if (enableInputShapeUniform) { - programUniforms.push(...createTensorShapeVariables(inputShape)); - } - if (enableOutputShapeUniform) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape), + ...createTensorShapeVariables(outputShape) + ]; return { name: 'Expand', - shaderCache: {hint: `${outputShape.length}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']}, + shaderCache: {hint: `${outputShape.length}`, inputDependencies: ['rank']}, getShaderSource, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 469249f92ff28..e2a62c6655c72 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; export interface GatherAttributes extends AttributeWithCacheKey { axis: number; @@ -33,33 +33,16 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const components = inputs[0].dataType === DataType.bool ? 4 : 1; const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); - const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length); - const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims; - const enableIndicesShapesUniforms = enableShapesUniforms(inputs[1].dims.length); - const indicesShapeOrRank = enableIndicesShapesUniforms ? inputs[1].dims.length : inputs[1].dims; - const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); - const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; - - const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; - if (enableInputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); - } - if (enableIndicesShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); - } - if (enableOutputShapesUniforms) { - programUniforms.push(...createTensorShapeVariables(outputShape)); - } - - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims'); - inputDependencies.push(enableIndicesShapesUniforms ? 'rank' : 'dims'); + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}, + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims), + ...createTensorShapeVariables(outputShape) + ]; const getShaderSource = (shaderHelper: ShaderHelper) => { - const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank, components); - const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank); - const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank, components); + const data = inputVariable('data', inputs[0].dataType, inputs[0].dims.length, components); + const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims.length); + const output = outputVariable('output', inputs[0].dataType, outputShape.length, components); const calcDataIndices = (x: number|string): string => { const indicesRank = indicesShape.length; @@ -127,7 +110,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath }; return { name: 'Gather', - shaderCache: {hint: attributes.cacheKey, inputDependencies}, + shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank', 'rank']}, getRunData: () => ({ outputs: [ {dims: outputShape, dataType: inputs[0].dataType}, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index c4d43e9f466f5..ab9a9ac8dd1f0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface TransposeAttributes extends AttributeWithCacheKey { readonly perm: number[]; @@ -39,12 +39,9 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu const inputDataType = inputTensor.dataType; const inputRank = inputTensor.dims.length; const perm = getAdjustedPerm(inputRank, permAttr); - const useShapesUniforms = enableShapesUniforms(inputRank); const outputShape = getOutputShape(inputTensor.dims, perm); - const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape; - const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims; - const output = outputVariable('output', inputDataType, outShapeOrRank); - const input = inputVariable('a', inputDataType, inShapeOrRank); + const output = outputVariable('output', inputDataType, outputShape.length); + const input = inputVariable('a', inputDataType, inputRank); const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} @@ -61,21 +58,17 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu }`; return { name: 'Transpose', - shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']}, + shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']}, getRunData: (inputs) => { const outputSize = ShapeUtil.size(outputShape); return { outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: useShapesUniforms ? - [ - {type: 'uint32', data: outputSize}, - ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape), - ] : - [ - {type: 'uint32', data: outputSize}, - ], + programUniforms: [ + {type: 'uint32', data: outputSize}, + ...createTensorShapeVariables(inputs[0].dims), + ...createTensorShapeVariables(outputShape), + ], }; }, getShaderSource, From c61a8e5364738b6fda42ba7478556d00809098e8 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 31 Jan 2024 08:28:53 +0800 Subject: [PATCH 12/51] [js/webgpu] Add hardSigmoid activation for fusedConv (#19233) ### Description Add hardSigmoid activation for fusedConv. It will be used by mobilenetv3-small-100 model. --- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 11 +- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 11 +- .../ops/3rd-party/matmul_packed_webgpu.ts | 12 +- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 37 +++-- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 35 ++++- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 12 +- js/web/test/data/ops/fused-conv.jsonc | 144 ++++++++++++++++++ .../core/optimizer/conv_activation_fusion.cc | 2 +- 8 files changed, 207 insertions(+), 57 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 1a03621512888..e5ca3204d4433 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -24,7 +24,7 @@ import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvAttributes} from '../conv'; -import {getActivationSnippet} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; import {biasSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; @@ -193,10 +193,7 @@ export const createConv2DMatMulProgramInfo = {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations} ]; - if (attributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); - } + appendActivationUniformsData(attributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; @@ -212,9 +209,7 @@ export const createConv2DMatMulProgramInfo = {name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2}, {name: 'dilation', type: 'i32', length: 2} ]; - if (attributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(attributes, uniforms); // TODO: support component 2, 3. const components = isVec4 ? 4 : 1; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 33e50a9a39cb9..e50733559dbe9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -24,7 +24,7 @@ import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; -import {getActivationSnippet} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; import {biasSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; @@ -201,10 +201,7 @@ export const createConv2DTransposeMatMulProgramInfo = {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations}, {type: 'int32', data: filterDims}, {type: 'int32', data: pads} ]; - if (attributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); - } + appendActivationUniformsData(attributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); @@ -237,9 +234,7 @@ export const createConv2DTransposeMatMulProgramInfo = {name: 'filter_dims', type: 'i32', length: filterDims.length}, {name: 'pads', type: 'i32', length: pads.length} ]; - if (attributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(attributes, uniforms); return ` ${utilFunctions('uniforms.result_strides')} ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 5881c055ef135..00c1f86d67419 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -23,7 +23,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; -import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -449,11 +449,7 @@ export const createMatmulProgramInfo = const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; const programUniforms: ProgramUniform[] = [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - if (activationAttributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: activationAttributes.clipMax!}, - {type: 'float32', data: activationAttributes.clipMin!}); - } + appendActivationUniformsData(activationAttributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), ...createTensorShapeVariables(bShapeTemp)); @@ -481,9 +477,7 @@ export const createMatmulProgramInfo = } const uniforms: UniformsArrayType = [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; - if (activationAttributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(activationAttributes, uniforms); const applyActivation = getActivationSnippet(activationAttributes, output.type.value); const declareFunctions = matMulReadWriteFnSource( components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index f81d6577890c5..c0aaaa7ce134b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -7,7 +7,7 @@ import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../ import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; -import {getActivationSnippet} from './fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from './fuse-utils'; /** * naive grouped conv implementation, supports 1d/2d conv @@ -32,10 +32,7 @@ export const createGroupedConvProgramInfo = {type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]}, {type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup} ]; - if (attributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); - } + appendActivationUniformsData(attributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), ...createTensorShapeVariables(outputShape)); @@ -61,9 +58,7 @@ export const createGroupedConvProgramInfo = {name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2}, {name: 'output_channels_per_group', type: 'u32'} ]; - if (attributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(attributes, uniforms); return ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} @@ -132,10 +127,13 @@ export const createGroupedConvVectorizeProgramInfo = const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: attributes.strides}, - {type: 'int32', data: attributes.pads}, ...createTensorShapeVariables(xShape), - ...createTensorShapeVariables(wShape), ...createTensorShapeVariables(outputShapeInShader) + {type: 'uint32', data: outputSize}, {type: 'int32', data: [attributes.strides[0], attributes.strides[1]]}, + {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]} ]; + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push( + ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), + ...createTensorShapeVariables(outputShapeInShader)); const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); @@ -147,13 +145,14 @@ export const createGroupedConvVectorizeProgramInfo = inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components)); } const processBias = hasBias ? 'value += b[output_channel];' : ''; - + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, + {name: 'strides', type: 'i32', length: 2}, + {name: 'pads', type: 'i32', length: 2}, + ]; + appendActivationUniforms(attributes, uniforms); return ` - ${ - shaderHelper.registerUniform('output_size', 'u32') - .registerUniform('strides', 'i32', 2) - .registerUniform('pads', 'i32', 2) - .declareVariables(...inputVars, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let width0 = uniforms.output_shape[3]; @@ -173,7 +172,7 @@ export const createGroupedConvVectorizeProgramInfo = // Use constant instead of uniform can give better performance for w's height/width. for (var w_height: u32 = 0u; w_height < ${wShape[0]}; w_height++) { let x_height = x_corner.x + i32(w_height); - if (x_height >= 0 || u32(x_height) < uniforms.x_shape[1]) { + if (x_height >= 0 && u32(x_height) < uniforms.x_shape[1]) { for (var i = 0; i < ${xNumber}; i++) { let x_width = x_corner.y + i; if (x_width >= 0 && u32(x_width) < uniforms.x_shape[2]) { @@ -185,7 +184,7 @@ export const createGroupedConvVectorizeProgramInfo = for (var w_width: u32 = 0u; w_width < ${wShape[1]}; w_width++) { let w_val = ${w.get('w_height', 'w_width', '0', 'output_channel')}; for (var i = 0u; i < ${outputNumber}u; i++) { - values[i] = fma(x_vals[i * ${attributes.strides[1]}u + w_width], w_val, values[i]); + values[i] = fma(x_vals[i * u32(uniforms.strides[1]) + w_width], w_val, values[i]); } } } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 2e0aa33a957dc..e1dc9a5e0ab7d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -2,11 +2,16 @@ // Licensed under the MIT License. import {MAX_CLIP, MIN_CLIP} from '../../util'; +import {ProgramUniform} from '../types'; + +import {UniformsArrayType} from './common'; export interface InternalActivationAttributes { readonly activation: string; readonly clipMin?: number; readonly clipMax?: number; + readonly alpha?: number; + readonly beta?: number; } export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => { @@ -17,17 +22,41 @@ export const getActivationSnippet = (attributes: InternalActivationAttributes, v return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; case 'Clip': return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`; + case 'HardSigmoid': + return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${valueType}(uniforms.alpha) * value + ${ + valueType}(uniforms.beta)));`; + case '': + return ''; // TODO: adding other activations that can be fused. default: - return ''; + throw new Error(`Unsupported activation ${attributes.activation}`); + } +}; + +export const appendActivationUniformsData = + (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => { + if (attributes.activation === 'Clip') { + programUniform.push({type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } else if (attributes.activation === 'HardSigmoid') { + programUniform.push({type: 'float32', data: attributes.alpha!}, {type: 'float32', data: attributes.beta!}); + } + }; + +export const appendActivationUniforms = (attributes: InternalActivationAttributes, uniforms: UniformsArrayType) => { + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } else if (attributes.activation === 'HardSigmoid') { + uniforms.push({name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'}); } }; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { const activation = attributes?.activation as string || ''; - - if (activation === 'Clip') { + if (activation === 'HardSigmoid') { + const [alpha, beta] = attributes?.activation_params as [number, number] || [0.2, 0.5]; + return {activation, alpha, beta}; + } else if (activation === 'Clip') { const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; return {activation, clipMax, clipMin}; } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index c946ea6366123..188b88b2510d8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -7,7 +7,7 @@ import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common'; -import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; export const createNaiveMatmulProgramInfo = (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], @@ -32,11 +32,7 @@ export const createNaiveMatmulProgramInfo = {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, {type: 'uint32', data: K} ]; - if (activationAttributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: activationAttributes.clipMax!}, - {type: 'float32', data: activationAttributes.clipMin!}); - } + appendActivationUniformsData(activationAttributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), ...createTensorShapeVariables(bShape)); @@ -69,9 +65,7 @@ export const createNaiveMatmulProgramInfo = {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'K', type: 'u32'} ]; - if (activationAttributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(activationAttributes, uniforms); const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { const rank = variable.rank; diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index ad1c0a72c11d3..c734d6db9b92a 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -142,5 +142,149 @@ ] } ] + }, + { + "name": "fused conv with HardSigmoid", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 1, 1], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv with HardSigmoid", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 1, 1], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused group-conv with HardSigmoid", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1], + "dims": [1, 3, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC group-conv with HardSigmoid", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [1, 2, 2, 3], + "type": "float32" + } + ] + } + ] } ] diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index d27603e4ab3a1..b7cb3ba488c62 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -111,7 +111,7 @@ class ConvActivationSelector : public NodeSelector { if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) { return std::nullopt; } - } else if (node_ep.empty() || node_ep == kCpuExecutionProvider) { + } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider) { if (!is_supported_non_cuda_rocm_ep_activation(*next_node) && !graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "HardSigmoid", {6})) { return std::nullopt; From 43b95b05120caaf471647daa3b3763944daa0ccc Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 31 Jan 2024 10:28:03 +0800 Subject: [PATCH 13/51] [js/webgpu] Support capture and replay for jsep (#18989) This PR expands the graph capture capability to JS EP, which is similar to #16081. But for JS EP, we don't use the CUDA Graph, instead, we records all gpu commands and replay them, which removes most of the cpu overhead to avoid the the situation that gpu waiting for cpu. mobilenetv2-12 becomes 3.7ms from 6ms on NV 3090 and becomes 3.38ms from 4.58ms on Intel A770. All limitations are similar with CUDA EP: 1. Models with control-flow ops (i.e. If, Loop and Scan ops) are not supported. 2. Usage of graph capture is limited to models where-in all ops in the model can be partitioned to the JS EP or CPU EP and no memory copy between them. 3. Shapes of inputs/outputs cannot change across inference calls. 4. IObinding is required. The usage is like below: Method 1: specify outputs buffers explicitly. ``` const sessionOptions = { executionProviders: [ { name: "webgpu", }, ], enableGraphCapture: true, }; const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions); // prepare the inputBuffer/outputBuffer ... ... const feeds = { 'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims }) }; const fetches = { 'output': ort.Tensor.fromGpuBuffer(outputBuffer, { dataType: 'float32', dims: [1, 1000] }) }; let results = await session.run(feeds, fetches); // The first run will begin to capture the graph. // update inputBuffer content ... ... results = = await session.run(feeds, fetches); // The 2ed run and after will directly call replay to execute the graph. ... ... session.release(); ``` Method 2: Don't specify outputs buffers explicitly. Internally, when graph capture is enabled, it will set all outputs location to 'gpu-buffer'. ``` const sessionOptions = { executionProviders: [ { name: "webgpu", }, ], enableGraphCapture: true, }; const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions); // prepare the inputBuffer ... ... const feeds = { 'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims }) }; let results = await session.run(feeds); // The first run will begin to capture the graph. // update inputBuffer content ... ... results = = await session.run(feeds); // The 2ed run and after will directly call replay to execute the graph. ... ... session.release(); --- js/common/lib/inference-session.ts | 8 +- js/web/lib/wasm/binding/ort-wasm.d.ts | 25 ++- js/web/lib/wasm/jsep/backend-webgpu.ts | 100 ++++++++++- js/web/lib/wasm/jsep/init.ts | 8 +- .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 74 ++++++-- .../lib/wasm/jsep/webgpu/program-manager.ts | 15 +- js/web/lib/wasm/jsep/webgpu/types.ts | 2 + js/web/lib/wasm/session-options.ts | 12 ++ js/web/lib/wasm/wasm-core-impl.ts | 166 +++++++++++------- .../providers/js/js_execution_provider.cc | 49 +++++- .../core/providers/js/js_execution_provider.h | 18 +- .../core/providers/js/js_provider_factory.cc | 11 +- .../js/js_provider_factory_creator.h | 4 +- onnxruntime/core/session/inference_session.cc | 66 ++++--- .../core/session/provider_registration.cc | 2 +- onnxruntime/wasm/js_internal_api.js | 15 +- 16 files changed, 439 insertions(+), 136 deletions(-) diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 1221b52cd4985..4f85c3b46e253 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -111,7 +111,7 @@ export declare namespace InferenceSession { optimizedModelFilePath?: string; /** - * Wether enable profiling. + * Whether enable profiling. * * This setting is a placeholder for a future use. */ @@ -154,6 +154,12 @@ export declare namespace InferenceSession { */ preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation}; + /** + * Whether enable graph capture. + * This setting is available only in ONNXRuntime Web for WebGPU EP. + */ + enableGraphCapture?: boolean; + /** * Store configurations for a session. See * https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/ diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 24d7062c85fcb..5dd715191c830 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -13,6 +13,9 @@ export declare namespace JSEP { type ReleaseKernelFunction = (kernel: number) => void; type RunFunction = (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; + type CaptureBeginFunction = () => void; + type CaptureEndFunction = () => void; + type ReplayFunction = () => void; } export interface OrtWasmModule extends EmscriptenModule { @@ -128,7 +131,8 @@ export interface OrtWasmModule extends EmscriptenModule { jsepInit? (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, - releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; + releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction, + captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void; /** * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). @@ -158,12 +162,6 @@ export interface OrtWasmModule extends EmscriptenModule { * @returns the GPU data ID for the registered GPU buffer. */ jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; - /** - * [exported from js_internal_api.js] Unregister all user GPU buffers for a session. - * - * @param sessionId - specify the session ID. - */ - jsepUnregisterBuffers?: (sessionId: number) => void; /** * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. * @@ -183,9 +181,18 @@ export interface OrtWasmModule extends EmscriptenModule { (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; /** - * [exported from js_internal_api.js] Called when InferenceSession.run started. + * [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before + * _OrtRun[WithBinding]() is called. + * @param sessionId - specify the session ID. + */ + jsepOnRunStart: (sessionId: number) => void; + /** + * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is + * called. + * @param sessionId - specify the session ID. + * @returns */ - jsepOnRunStart: () => void; + jsepOnReleaseSession: (sessionId: number) => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index a48fe99570abf..e1faecfc046e3 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -10,7 +10,14 @@ import {createView, TensorView} from './tensor-view'; import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; -import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, TimestampQuery} from './webgpu/types'; +import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types'; + +interface CommandInfo { + readonly kernelId: number; + readonly computePipeline: GPUComputePipeline; + readonly bindGroup: GPUBindGroup; + readonly dispatchGroup: [number, number, number]; +} interface KernelInfo { readonly kernelType: string; @@ -103,6 +110,13 @@ export class WebGpuBackend { */ programManager: ProgramManager; + /** + * representing the session ID of which is currently being run. + * `null` means no session is being run. + * only valid when session.run is executed. + */ + currentSessionId: number|null = null; + /** * representing the kernel ID of which is currently being computed (CPU code perspective). * `null` means no kernel is being computed. @@ -155,6 +169,16 @@ export class WebGpuBackend { queryType: TimestampQuery; env: Env; + sessionStatus: SessionState = 'default'; + /** + * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session. + */ + capturedCommandList: Map = new Map(); + + /** + * a SessionID -> PendingKernelInfo[] mapping for profiling. + */ + private capturedPendingKernels: Map = new Map(); /** * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. @@ -228,6 +252,7 @@ export class WebGpuBackend { getComputePassEncoder(): GPUComputePassEncoder { if (!this.computePassEncoder) { + const commandEncoder = this.getCommandEncoder(); const computePassDescriptor: GPUComputePassDescriptor = {}; if (this.queryType === 'at-passes') { @@ -238,7 +263,7 @@ export class WebGpuBackend { }; } - this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor); + this.computePassEncoder = commandEncoder.beginComputePass(computePassDescriptor); } return this.computePassEncoder; } @@ -494,7 +519,7 @@ export class WebGpuBackend { () => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`); - if (this.queryType !== 'none') { + if (this.queryType !== 'none' || this.sessionStatus === 'capturing') { const pendingKernelInfo: PendingKernelInfo = { kernelId: this.currentKernelId!, programName: artifact.programInfo.name, @@ -502,6 +527,9 @@ export class WebGpuBackend { outputTensorViews, }; this.pendingKernels.push(pendingKernelInfo); + + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + sessionPendingKernels!.push(pendingKernelInfo); } this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding); @@ -672,7 +700,71 @@ export class WebGpuBackend { } } } - onRunStart(): void { + + captureBegin(): void { + LOG_DEBUG('info', 'captureBegin'); + let sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); + let sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + if (!sessionCommandList) { + sessionCommandList = []; + this.capturedCommandList.set(this.currentSessionId!, sessionCommandList); + sessionPendingKernels = []; + this.capturedPendingKernels.set(this.currentSessionId!, sessionPendingKernels); + } + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'capturing'; + } + captureEnd(): void { + LOG_DEBUG('info', 'captureEnd'); + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'default'; + } + replay(): void { + LOG_DEBUG('info', 'replay'); + this.sessionStatus = 'replaying'; + const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + const length = sessionCommandList!.length; + this.pendingKernels = []; + for (let i = 0; i < length; i++) { + const computePassEncoder = this.getComputePassEncoder(); + const command = sessionCommandList![i]; + this.writeTimestamp(this.pendingDispatchNumber * 2); + computePassEncoder.setPipeline(command.computePipeline); + computePassEncoder.setBindGroup(0, command.bindGroup); + computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); + this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); + this.pendingDispatchNumber++; + if (this.queryType !== 'none') { + this.pendingKernels.push(sessionPendingKernels![i]); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') { + this.endComputePass(); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber) { + this.flush(); + } + } + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'default'; + } + + onReleaseSession(sessionId: number): void { + this.unregisterBuffers(sessionId); + if (this.capturedCommandList.has(sessionId)) { + this.capturedCommandList.delete(sessionId); + } + if (this.capturedPendingKernels.has(sessionId)) { + this.capturedPendingKernels.delete(sessionId); + } + this.gpuDataManager.onReleaseSession(sessionId); + } + + onRunStart(sessionId: number): void { + this.currentSessionId = sessionId; this.setQueryType(); } } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index f1794d71579bf..786ae41646554 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -201,5 +201,11 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte contextDataOffset}`); const context = new ComputeContextImpl(module, backend, contextDataOffset); return backend.computeKernel(kernel, context, errors); - }); + }, + // jsepCaptureBegin + () => backend.captureBegin(), + // jsepCaptureEnd + () => backend.captureEnd(), + // jsepReplay + () => backend.replay()); }; diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 6f3d9a52d9f5d..c17bd1e1477ec 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -60,9 +60,15 @@ export interface GpuDataManager { unregisterExternalBuffer(buffer: GPUBuffer): void; /** - * destroy all gpu buffers. Call this when the session.release is called. + * destroy all gpu buffers. */ dispose(): void; + + /** + * release session related data. + * @param sessionId - specify the session ID. + */ + onReleaseSession(sessionId: number): void; } interface StorageCacheValue { @@ -139,6 +145,10 @@ class GpuDataManagerImpl implements GpuDataManager { // The external buffers registered users for IO Binding. private externalBuffers: Map; + // The pendingBuffers for capture graph. + // a SessionID -> GPUBuffer[] mapping. + private capturedPendingBuffers: Map; + constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); @@ -146,6 +156,7 @@ class GpuDataManagerImpl implements GpuDataManager { this.buffersForUploadingPending = []; this.buffersPending = []; this.externalBuffers = new Map(); + this.capturedPendingBuffers = new Map(); } upload(id: GpuDataId, data: Uint8Array): void { @@ -220,6 +231,9 @@ class GpuDataManagerImpl implements GpuDataManager { () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ id}, buffer is the same, skip.`); return id; + } else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) { + throw new Error(`Registering a different external buffer under graph capture mode is not supported yet. + Please use the previous external buffer!`); } this.externalBuffers.delete(previousBuffer); } else { @@ -312,20 +326,39 @@ class GpuDataManagerImpl implements GpuDataManager { buffer.destroy(); } this.buffersForUploadingPending = []; - for (const buffer of this.buffersPending) { - // eslint-disable-next-line no-bitwise - if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { - // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. - this.freeBuffers.get(buffer.size)!.push(buffer); + + if (this.buffersPending.length === 0) { + return; + } + + if (this.backend.sessionStatus === 'default') { + for (const buffer of this.buffersPending) { // eslint-disable-next-line no-bitwise - } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { - // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. - this.freeUniformBuffers.get(buffer.size)!.push(buffer); - } else { - buffer.destroy(); + if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { + // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. + this.freeBuffers.get(buffer.size)!.push(buffer); + // eslint-disable-next-line no-bitwise + } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { + // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. + this.freeUniformBuffers.get(buffer.size)!.push(buffer); + } else { + buffer.destroy(); + } + } + this.buffersPending = []; + } else { + // Don't release intermediate tensors in non-default mode. + // TODO: reuse the storage buffers in non-default mode. + let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!); + if (!capturedBuffers) { + capturedBuffers = []; + this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers); } + for (const buffer of this.buffersPending) { + capturedBuffers.push(buffer); + } + this.buffersPending = []; } - this.buffersPending = []; } dispose() { @@ -344,9 +377,26 @@ class GpuDataManagerImpl implements GpuDataManager { storage.gpuData.buffer.destroy(); }); + this.capturedPendingBuffers.forEach((buffers) => { + buffers.forEach(buffer => { + buffer.destroy(); + }); + }); this.storageCache = new Map(); this.freeBuffers = new Map(); this.freeUniformBuffers = new Map(); + this.capturedPendingBuffers = new Map(); + } + + onReleaseSession(sessionId: number) { + // release the captured pending buffers. + const pendingBuffers = this.capturedPendingBuffers.get(sessionId); + if (pendingBuffers) { + pendingBuffers.forEach(buffer => { + buffer.destroy(); + }); + this.capturedPendingBuffers.delete(sessionId); + } } } diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 72eb9713e26a8..9d05f607f817f 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -38,7 +38,6 @@ export class ProgramManager { const device = this.backend.device; const computePassEncoder = this.backend.getComputePassEncoder(); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2); - computePassEncoder.setPipeline(buildArtifact.computePipeline); const entries = []; for (const input of inputs) { entries.push({binding: entries.length, resource: {buffer: input.buffer}}); @@ -51,8 +50,20 @@ export class ProgramManager { } const bindGroup = device.createBindGroup( {layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name}); - computePassEncoder.setBindGroup(0, bindGroup); + if (this.backend.sessionStatus === 'capturing') { + const commandInfo = { + kernelId: this.backend.currentKernelId!, + computePipeline: buildArtifact.computePipeline, + bindGroup, + dispatchGroup + }; + const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); + sessionCommandList!.push(commandInfo); + } + + computePassEncoder.setPipeline(buildArtifact.computePipeline); + computePassEncoder.setBindGroup(0, bindGroup); computePassEncoder.dispatchWorkgroups(...dispatchGroup); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); this.backend.pendingDispatchNumber++; diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 789ac70a6913a..a34b6190b7244 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -5,6 +5,8 @@ import {TensorView} from '../tensor-view'; import {ShaderHelper} from './ops/common'; +export type SessionState = 'default'|'capturing'|'replaying'; + export enum GpuDataType { default = 0, upload = 1, diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 41ab2d52ca209..48eac57494726 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -168,6 +168,18 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); } + if (sessionOptions.enableGraphCapture !== undefined) { + if (typeof sessionOptions.enableGraphCapture !== 'boolean') { + throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`); + } + const keyDataOffset = allocWasmString('enableGraphCapture', allocs); + const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs); + if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError( + `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`); + } + } + if (sessionOptions.freeDimensionOverrides) { for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) { if (typeof name !== 'string') { diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 046336dc9cac0..37b9ed6a1002f 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -139,7 +139,7 @@ type IOBindingState = { */ type SessionMetadata = [ inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], - bindingState: IOBindingState|null + bindingState: IOBindingState|null, enableGraphCapture: boolean, inputOutputBound: boolean ]; const activeSessions = new Map(); @@ -235,6 +235,8 @@ export const createSession = async( const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); + const enableGraphCapture = !!options?.enableGraphCapture; + const inputNames = []; const outputNames = []; const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = []; @@ -256,12 +258,20 @@ export const createSession = async( outputNames.push(nameString); if (!BUILD_DEFS.DISABLE_WEBGPU) { + if (enableGraphCapture && options?.preferredOutputLocation === undefined) { + outputPreferredLocations.push('gpu-buffer'); + continue; + } const location = typeof options?.preferredOutputLocation === 'string' ? options.preferredOutputLocation : options?.preferredOutputLocation?.[nameString] ?? 'cpu'; if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { throw new Error(`Not supported preferred output location: ${location}.`); } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${ + location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`); + } outputPreferredLocations.push(location); } } @@ -281,7 +291,9 @@ export const createSession = async( }; } - activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]); + activeSessions.set( + sessionHandle, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, enableGraphCapture, false]); return [sessionHandle, inputNames, outputNames]; } catch (e) { inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -313,13 +325,16 @@ export const releaseSession = (sessionId: number): void => { if (!session) { throw new Error(`cannot release session. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture] = session; if (ioBindingState) { + if (enableGraphCapture) { + wasm._OrtClearBoundOutputs(ioBindingState.handle); + } wasm._OrtReleaseBinding(ioBindingState.handle); } - wasm.jsepUnregisterBuffers?.(sessionId); + wasm.jsepOnReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -328,70 +343,75 @@ export const releaseSession = (sessionId: number): void => { }; export const prepareInputOutputTensor = - (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): - void => { - if (!tensor) { - tensorHandles.push(0); - return; - } + (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number, + enableGraphCapture = false): void => { + if (!tensor) { + tensorHandles.push(0); + return; + } - const wasm = getInstance(); + const wasm = getInstance(); - const dataType = tensor[0]; - const dims = tensor[1]; - const location = tensor[3]; + const dataType = tensor[0]; + const dims = tensor[1]; + const location = tensor[3]; - let rawData: number; - let dataByteLength: number; + let rawData: number; + let dataByteLength: number; - if (dataType === 'string' && location === 'gpu-buffer') { - throw new Error('String tensor is not supported on GPU.'); - } + if (dataType === 'string' && location === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); + } - if (location === 'gpu-buffer') { - const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; - const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; - dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; - rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); - } else { - const data = tensor[2]; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - let dataIndex = rawData / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); - } - } else { - dataByteLength = data.byteLength; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); - } - } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error( + `External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`); + } - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, - dataLocationStringToEnum(location)); - if (tensor === 0) { - checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + if (location === 'gpu-buffer') { + const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; + dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else { + const data = tensor[2]; + + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + let dataIndex = rawData / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); } - tensorHandles.push(tensor); - } finally { - wasm.stackRestore(stack); + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); } - }; + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(location)); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } + }; /** * perform inference run @@ -404,7 +424,12 @@ export const run = async( if (!session) { throw new Error(`cannot run inference. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const sessionHandle = session[0]; + const inputNamesUTF8Encoded = session[1]; + const outputNamesUTF8Encoded = session[2]; + const ioBindingState = session[3]; + const enableGraphCapture = session[4]; + const inputOutputBound = session[5]; const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -427,13 +452,15 @@ export const run = async( // create input tensors for (let i = 0; i < inputCount; i++) { - prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]); + prepareInputOutputTensor( + inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], enableGraphCapture); } // create output tensors for (let i = 0; i < outputCount; i++) { prepareInputOutputTensor( - outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]); + outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i], + enableGraphCapture); } let inputValuesIndex = inputValuesOffset / 4; @@ -449,7 +476,7 @@ export const run = async( wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState && !inputOutputBound) { const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; if (inputNamesUTF8Encoded.length !== inputCount) { @@ -486,9 +513,12 @@ export const run = async( } } } + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, true]); } - wasm.jsepOnRunStart?.(); + wasm.jsepOnRunStart?.(sessionHandle); let errorCode: number; if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( @@ -595,10 +625,12 @@ export const run = async( } } - if (ioBindingState) { + if (ioBindingState && !enableGraphCapture) { wasm._OrtClearBoundOutputs(ioBindingState.handle); + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, false]); } - return output; } finally { wasm.stackRestore(beforeRunStack); diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index af9658271d210..308e1c7d952d9 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -3,6 +3,7 @@ #include "js_execution_provider.h" +#include #include #include #include @@ -681,9 +682,13 @@ std::unique_ptr RegisterKernels() { using namespace js; -JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info) +JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options) : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), true}, preferred_data_layout_{info.data_layout} { + if (session_options) { + enable_graph_capture_ = session_options->config_options.GetConfigOrDefault("enableGraphCapture", "false") == "true"; + LOGS_DEFAULT(VERBOSE) << "Graph capture enable: " << enable_graph_capture_; + } } std::vector JsExecutionProvider::CreatePreferredAllocators() { @@ -751,4 +756,46 @@ std::unique_ptr JsExecutionProvider::GetDataTransfer JsExecutionProvider::~JsExecutionProvider() { } +Status JsExecutionProvider::OnRunStart() { + if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; + EM_ASM({ Module.jsepCaptureBegin(); }); + } + return Status::OK(); +} + +Status JsExecutionProvider::OnRunEnd(bool sync_stream) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { + if (IsGraphCaptureAllowed()) { + EM_ASM({ Module.jsepCaptureEnd(); }); + is_graph_captured_ = true; + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + + return Status::OK(); +} + +bool JsExecutionProvider::IsGraphCaptureEnabled() const { + return enable_graph_capture_; +} + +bool JsExecutionProvider::IsGraphCaptured() const { + return is_graph_captured_; +} + +Status JsExecutionProvider::ReplayGraph() { + ORT_ENFORCE(IsGraphCaptured()); + EM_ASM({ Module.jsepReplay(); }); + return Status::OK(); +} + +bool JsExecutionProvider::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +} + +void JsExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + ++regular_run_count_before_graph_capture_; +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 39d43498c0717..91a3256ec2bd5 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -5,6 +5,7 @@ #pragma once #include "core/framework/execution_provider.h" +#include "core/framework/session_options.h" #include "core/graph/constants.h" #include "core/providers/providers.h" @@ -38,7 +39,7 @@ struct JsExecutionProviderInfo { class JsExecutionProvider : public IExecutionProvider { public: - JsExecutionProvider(const JsExecutionProviderInfo& info); + JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options); ~JsExecutionProvider() override; std::vector> GetCapability( @@ -57,7 +58,22 @@ class JsExecutionProvider : public IExecutionProvider { bool ConcurrentRunSupported() const override { return false; } std::vector CreatePreferredAllocators() override; + + Status OnRunStart() override; + Status OnRunEnd(bool sync_stream) override; + + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured() const override; + Status ReplayGraph() override; + + private: + bool IsGraphCaptureAllowed() const; + void IncrementRegularRunCountBeforeGraphCapture(); DataLayout preferred_data_layout_; + bool enable_graph_capture_ = false; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_provider_factory.cc b/onnxruntime/core/providers/js/js_provider_factory.cc index 5b7329a87cf6a..cbdf99f702150 100644 --- a/onnxruntime/core/providers/js/js_provider_factory.cc +++ b/onnxruntime/core/providers/js/js_provider_factory.cc @@ -10,21 +10,22 @@ namespace onnxruntime { struct JsProviderFactory : IExecutionProviderFactory { - JsProviderFactory(const ProviderOptions& provider_options) - : info_{provider_options} { + JsProviderFactory(const ProviderOptions& provider_options, const SessionOptions* session_options) + : info_{provider_options}, session_options_(session_options) { } std::unique_ptr CreateProvider() override { - return std::make_unique(info_); + return std::make_unique(info_, session_options_); } private: JsExecutionProviderInfo info_; + const SessionOptions* session_options_; }; std::shared_ptr JsProviderFactoryCreator::Create( - const ProviderOptions& provider_options) { - return std::make_shared(provider_options); + const ProviderOptions& provider_options, const SessionOptions* session_options) { + return std::make_shared(provider_options, session_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_provider_factory_creator.h b/onnxruntime/core/providers/js/js_provider_factory_creator.h index dbabe255c2d7b..510b0fb4248ca 100644 --- a/onnxruntime/core/providers/js/js_provider_factory_creator.h +++ b/onnxruntime/core/providers/js/js_provider_factory_creator.h @@ -9,9 +9,11 @@ #include "core/providers/providers.h" namespace onnxruntime { +struct SessionOptions; struct JsProviderFactoryCreator { - static std::shared_ptr Create(const ProviderOptions& provider_options); + static std::shared_ptr Create(const ProviderOptions& provider_options, + const SessionOptions* session_options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c8fc812fe1238..94a750940f6ef 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -146,28 +146,29 @@ static bool HasMemcpyNodes(const Graph& graph) { return false; } -static bool AreAllComputeNodesAssignedToCudaEp(const Graph& graph) { - bool nodes_on_cpu_and_cuda_eps_only = true; +static bool AreAllComputeNodesAssignedToCudaOrJsEp(const Graph& graph) { + bool nodes_on_cpu_and_cuda_and_js_eps_only = true; for (const auto& node : graph.Nodes()) { const auto& node_provider = node.GetExecutionProviderType(); // Empty node provider means CPU EP if (!node_provider.empty() && - node_provider != kCudaExecutionProvider && + !(node_provider == kCudaExecutionProvider || + node_provider == kJsExecutionProvider) && node_provider != kCpuExecutionProvider) { - nodes_on_cpu_and_cuda_eps_only = false; + nodes_on_cpu_and_cuda_and_js_eps_only = false; break; } } - // If we see nodes assigned to EPs other than CPU or CUDA + // If we see nodes assigned to EPs other than CPU, or CUDA/JS // (or) if there are Memcpy nodes, then all compute nodes have - // not been parititoned to the CUDA EP. + // not been parititoned to the CUDA/JS EP. // We allow CPU EPs to show up in the EP list as long as thre is no Memcpy // involved as shape subgraphs will be forced onto CPU and these will not have // Memcpy nodes involved. - return nodes_on_cpu_and_cuda_eps_only && !HasMemcpyNodes(graph); + return nodes_on_cpu_and_cuda_and_js_eps_only && !HasMemcpyNodes(graph); } static bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) { @@ -1761,7 +1762,7 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); - // Currently CUDA graph is only considered by CUDA EP and TRT EP. + // Currently graph capture is only considered by CUDA EP, TRT EP and JS EP. // // Check for CUDA EP: // If the CUDA EP is part of the providers list for this session AND @@ -1774,47 +1775,62 @@ common::Status InferenceSession::Initialize() { // The TRT EP is configured to do a graph capture AND // All the graph nodes have been assigned to the TRT EP, // Then the TRT EP is cached for triggering a ReplayGraph() in Run(). - std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider}; + // + // Check for JS EP: + // If the JS EP is part of the providers list for this session AND + // The JS EP is configured to do a graph capture AND + // All the "compute" graph nodes have been assigned to the JS EP, + // Then the JS EP is cached for triggering a ReplayGraph() in Run(). + // + std::vector graph_support_ep_list = { + onnxruntime::kTensorrtExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kJsExecutionProvider}; - for (auto& it : cuda_graph_support_ep_list) { + for (auto& it : graph_support_ep_list) { auto* target_ep = execution_providers_.Get(it); if (target_ep && target_ep->IsGraphCaptureEnabled()) { - // CUDA Graphs can't work with control flow nodes + // Graphs capture can't work with control flow nodes if (HasControlflowNodes(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << "as the model has control flow nodes which can't be supported by CUDA Graphs."; + LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " + << "as the model has control flow nodes which can't be supported by " + << target_ep->Type(); ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - "as the model has control flow nodes which can't be supported by CUDA Graphs.")); + "This session cannot use the graph capture feature as requested by the user " + "as the model has control flow nodes which can't be supported by" + + target_ep->Type())); } - if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0) { - // Ensure that all nodes have been partitioned to CUDA or CPU EP && there are no memcpy nodes + if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || + strcmp(target_ep->Type().c_str(), onnxruntime::kJsExecutionProvider) == 0) { + // Ensure that all nodes have been partitioned to CUDA/JS or CPU EP && there are no memcpy nodes // The reasoning behind this logic is that certain shape nodes will be forced onto CPU // and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP // which is all we care about. - if (!AreAllComputeNodesAssignedToCudaEp(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as all compute graph nodes have not been partitioned to the CUDA EP."; + if (!AreAllComputeNodesAssignedToCudaOrJsEp(graph)) { + LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " + << " as all compute graph nodes have not been partitioned to the " + << target_ep->Type(); ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as all compute graph nodes have not been partitioned to the CUDA EP.")); + "This session cannot use the graph capture feature as requested by the user " + " as all compute graph nodes have not been partitioned to the " + + target_ep->Type())); } // Log a warning for the user to know that there are shape subgraphs that will execute on CPU if (HasShapeSubgraphNodes(graph)) { LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. " - << "Use the CUDA Graph feature with caution. " + << "Use the graph capture feature with caution. " << "As long as the intermediate shapes produced in the model " - << "using the representative input used to capture the CUDA graph, " + << "using the representative input used to capture the graph, " << "will match the shapes produced in the model for other inputs " << "of the same shape as the representative input (common case), " - << "it is safe to use the CUDA Graph feature."; + << "it is safe to use the graph capture feature."; } } else { // Following code path is for TRT EP currently. diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index ac059bfd00668..de5d93ef0a434 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -149,7 +149,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, if (options->value.config_options.TryGetConfigEntry("preferredLayout", preferred_layout)) { provider_options["preferred_layout"] = preferred_layout; } - options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options)); + options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); #endif diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 7e9c0a6f99c32..cbc60c70b57aa 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -24,7 +24,7 @@ Module['unmountExternalData'] = () => { /** * init JSEP */ -Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel) => { +Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel, captureBegin, captureEnd, replay) => { Module.jsepBackend = backend; Module.jsepAlloc = alloc; Module.jsepFree = free; @@ -33,6 +33,9 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module.jsepCreateKernel = createKernel; Module.jsepReleaseKernel = releaseKernel; Module.jsepRunKernel = runKernel; + Module.jsepCaptureBegin = captureBegin; + Module.jsepCaptureEnd = captureEnd; + Module.jsepReplay = replay; // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) // It removes some overhead in cwarp() and ccall() that we don't need. @@ -181,16 +184,16 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { return backend['registerBuffer'](sessionId, index, buffer, size); }; - Module['jsepUnregisterBuffers'] = sessionId => { - backend['unregisterBuffers'](sessionId); - }; Module['jsepGetBuffer'] = (dataId) => { return backend['getBuffer'](dataId); }; Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; - Module['jsepOnRunStart'] = () => { - return backend['onRunStart'](); + Module['jsepOnReleaseSession'] = sessionId => { + backend['onReleaseSession'](sessionId); + }; + Module['jsepOnRunStart'] = sessionId => { + return backend['onRunStart'](sessionId); }; }; From d2db872d2be2f2e16b040b81c845e4e84e038386 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 31 Jan 2024 13:05:08 +0800 Subject: [PATCH 14/51] [js/webgpu] Use DataType as uniform cpu type (#19281) This saves turning data type to string by tensorDataTypeEnumToString. --- js/web/lib/wasm/jsep/backend-webgpu.ts | 18 ++++++----- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 7 +++-- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 8 +++-- .../ops/3rd-party/conv_backprop_webgpu.ts | 8 +++-- .../ops/3rd-party/matmul_packed_webgpu.ts | 7 +++-- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 30 +++++++++---------- js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts | 5 ++-- js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 5 ++-- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 5 ++-- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 13 ++++---- js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/einsum.ts | 7 +++-- js/web/lib/wasm/jsep/webgpu/ops/expand.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 7 +++-- .../wasm/jsep/webgpu/ops/gather-elements.ts | 7 +++-- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 6 ++-- js/web/lib/wasm/jsep/webgpu/ops/gemm.ts | 6 ++-- .../lib/wasm/jsep/webgpu/ops/instance-norm.ts | 14 +++++---- js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts | 5 ++-- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 5 ++-- .../jsep/webgpu/ops/multi-head-attentiion.ts | 7 +++-- js/web/lib/wasm/jsep/webgpu/ops/pad.ts | 7 ++--- js/web/lib/wasm/jsep/webgpu/ops/pool.ts | 20 +++++++------ js/web/lib/wasm/jsep/webgpu/ops/range.ts | 5 ++-- .../lib/wasm/jsep/webgpu/ops/reduce-shared.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/reduce.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/resize.ts | 7 +++-- .../wasm/jsep/webgpu/ops/skip-layer-norm.ts | 8 ++--- js/web/lib/wasm/jsep/webgpu/ops/slice.ts | 6 ++-- js/web/lib/wasm/jsep/webgpu/ops/softmax.ts | 3 +- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 5 ++-- js/web/lib/wasm/jsep/webgpu/ops/tile.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 3 +- js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/where.ts | 5 ++-- js/web/lib/wasm/jsep/webgpu/types.ts | 3 +- 37 files changed, 148 insertions(+), 108 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index e1faecfc046e3..58efa795dba48 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -3,7 +3,7 @@ import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common'; -import {tensorDataTypeEnumToString} from '../wasm-common'; +import {DataType, tensorDataTypeEnumToString} from '../wasm-common'; import {configureLogger, LOG_DEBUG} from './log'; import {createView, TensorView} from './tensor-view'; @@ -453,10 +453,10 @@ export class WebGpuBackend { return; } // https://www.w3.org/TR/WGSL/#alignof - const sizeOfElement = v.type === 'float16' ? 2 : 4; + const sizeOfElement = v.type === DataType.float16 ? 2 : 4; let sizeOfVecOrMat; let baseAlignment; - if (v.type === 'float16') { + if (v.type === DataType.float16) { baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement); sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length; } else { @@ -470,7 +470,7 @@ export class WebGpuBackend { // SizeOf(vec4). For float16 type, when data.length > 4, the uniform variable is of type // array,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte // length is N * SizeOf(mat2x4). - const elementPerVecOrMat = v.type === 'float16' ? 8 : 4; + const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4; currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : data.length * sizeOfElement; }); @@ -483,15 +483,17 @@ export class WebGpuBackend { programUniforms.forEach((v, i) => { const offset = offsets[i]; const data = typeof v.data === 'number' ? [v.data] : v.data; - if (v.type === 'int32') { + if (v.type === DataType.int32) { new Int32Array(arrayBuffer, offset, data.length).set(data); - } else if (v.type === 'uint32') { + } else if (v.type === DataType.uint32) { new Uint32Array(arrayBuffer, offset, data.length).set(data); - } else if (v.type === 'float16') { + } else if (v.type === DataType.float16) { // TODO: use Float16Array. new Uint16Array(arrayBuffer, offset, data.length).set(data); - } else { + } else if (v.type === DataType.float) { new Float32Array(arrayBuffer, offset, data.length).set(data); + } else { + throw new Error(`Unsupported uniform type: ${tensorDataTypeEnumToString(v.type)}`); } }); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index e5ca3204d4433..bc39bd94e3072 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -19,6 +19,7 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; @@ -189,9 +190,9 @@ export const createConv2DMatMulProgramInfo = const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; const programUniforms: ProgramUniform[] = [ - {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, - {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides}, - {type: 'int32', data: attributes.dilations} + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]}, + {type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations} ]; appendActivationUniformsData(attributes, programUniforms); programUniforms.push( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index e50733559dbe9..d18f8586dd071 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -19,6 +19,7 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; @@ -197,9 +198,10 @@ export const createConv2DTransposeMatMulProgramInfo = ]; const programUniforms: ProgramUniform[] = [ - {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, - {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations}, - {type: 'int32', data: filterDims}, {type: 'int32', data: pads} + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: attributes.strides}, + {type: DataType.int32, data: attributes.dilations}, {type: DataType.int32, data: filterDims}, + {type: DataType.int32, data: pads} ]; appendActivationUniformsData(attributes, programUniforms); programUniforms.push( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 380efc8bc577a..ba6776e9d8c94 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -17,6 +17,7 @@ // sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; @@ -264,9 +265,10 @@ export const createConvTranspose2DProgramInfo = const outputChannelsPerGroup = wShape[1]; const programUniforms: ProgramUniform[] = [ - {type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims}, - {type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads}, - {type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup}, + {type: DataType.int32, data: outputSize}, {type: DataType.uint32, data: strides}, + {type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations}, + {type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads}, + {type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims) ]; if (hasBias) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 00c1f86d67419..d9a8d59f731de 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -19,6 +19,7 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; @@ -447,8 +448,10 @@ export const createMatmulProgramInfo = const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; const bRank = bShapeTemp.length; const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + const programUniforms: ProgramUniform[] = [ + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner} + ]; appendActivationUniformsData(activationAttributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index f07a21a343fa8..2cfe6356dd6e7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {tensorDataTypeEnumToString} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ComputeContext, GpuDataType, ProgramUniform} from '../types'; @@ -241,9 +241,10 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView WG = Math.ceil(dComp / 8); } const elementsPerWG = Math.ceil(d / components / WG); - const tensorDataType = tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type']; - const programUniforms: ProgramUniform[] = - [{type: tensorDataType, data: 1 / d}, {type: 'uint32', data: dComp}, {type: 'uint32', data: elementsPerWG}]; + const programUniforms: ProgramUniform[] = [ + {type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp}, + {type: DataType.uint32, data: elementsPerWG} + ]; const dataType = tensorTypeToWsglStorageType(input.dataType, components); const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -336,11 +337,10 @@ const computeAttentionProbs = y: Math.ceil(parameters.sequenceLength / TILE_SIZE), z: parameters.batchSize * parameters.numHeads }; - const tensorDataType = tensorDataTypeEnumToString(q.dataType) as ProgramUniform['type']; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: parameters.sequenceLength}, {type: 'uint32', data: vectorizedHeadSize}, - {type: 'uint32', data: parameters.totalSequenceLength}, {type: 'uint32', data: parameters.kvSequenceLength}, - {type: tensorDataType, data: alpha} + {type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize}, + {type: DataType.uint32, data: parameters.totalSequenceLength}, + {type: DataType.uint32, data: parameters.kvSequenceLength}, {type: q.dataType, data: alpha} ]; const inputs = [q, key]; @@ -430,9 +430,9 @@ const computeVxAttentionScore = z: params.batchSize * params.numHeads }; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: params.sequenceLength}, {type: 'uint32', data: params.totalSequenceLength}, - {type: 'uint32', data: params.vHeadSize}, {type: 'uint32', data: params.numHeads}, - {type: 'uint32', data: params.vHiddenSize} + {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength}, + {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads}, + {type: DataType.uint32, data: params.vHiddenSize} ]; const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -526,10 +526,10 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { }; const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: M}, {type: 'uint32', data: K}, {type: 'uint32', data: N}, - {type: 'uint32', data: parameters.numHeads}, {type: 'uint32', data: parameters.headSize}, - {type: 'uint32', data: parameters.hiddenSize}, - {type: 'uint32', data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} + {type: DataType.uint32, data: M}, {type: DataType.uint32, data: K}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.headSize}, + {type: DataType.uint32, data: parameters.hiddenSize}, + {type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} ]; const getShaderSource = (shaderHelper: ShaderHelper) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts index 159b971636765..39b932375891b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts @@ -3,6 +3,7 @@ import {env} from 'onnxruntime-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -123,11 +124,11 @@ const createBatchNormInferenceProgramInfo = dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: useShapesUniforms ? [ - {type: 'uint32', data: outputSize}, + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(yShape), ] : [ - {type: 'uint32', data: outputSize}, + {type: DataType.uint32, data: outputSize}, ], }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 8e144a36dc1b0..51f0c76ed8824 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -179,7 +179,7 @@ const createBinaryOpProgramInfo = outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, programUniforms: [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, + {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, ...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(b.dims), ...createTensorShapeVariables(outputShape), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 1bedf31ee4e38..3de57d5ac7f7c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -259,8 +259,9 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = return typeof mappedType === 'string' ? mappedType : mappedType[1]; }; -export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => - dims.length === 0 ? [] : [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}]; +export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => dims.length === 0 ? + [] : + [{type: DataType.uint32, data: dims}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dims)}]; /** * A helper function to get maximum vector size for specified data length diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index daa326b1a34e2..b06c9fb496d15 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -95,14 +96,14 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P let previousSum = 0; const inputDependencies: ProgramInputTensorInfoDependency[] = []; const inputRanks = []; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; + const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; for (let i = 0; i < inputs.length; ++i) { previousSum += inputs[i].dims[adjustedAxis]; sizeInConcatAxis[i] = previousSum; inputRanks.push(inputs[i].dims.length); inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); inputDependencies.push('rank'); - programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]}); + programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]}); } for (let i = 0; i < inputs.length; ++i) { programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index c0aaaa7ce134b..3c2c3cc4e046c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; @@ -28,9 +29,10 @@ export const createGroupedConvProgramInfo = const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations}, - {type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]}, - {type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup} + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.dilations}, + {type: DataType.uint32, data: [attributes.strides[0], attributes.strides[1]]}, + {type: DataType.uint32, data: [attributes.pads[0], attributes.pads[1]]}, + {type: DataType.uint32, data: outputChannelsPerGroup} ]; appendActivationUniformsData(attributes, programUniforms); programUniforms.push( @@ -127,8 +129,9 @@ export const createGroupedConvVectorizeProgramInfo = const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: [attributes.strides[0], attributes.strides[1]]}, - {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]} + {type: DataType.uint32, data: outputSize}, + {type: DataType.int32, data: [attributes.strides[0], attributes.strides[1]]}, + {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]} ]; appendActivationUniformsData(attributes, programUniforms); programUniforms.push( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts index 2ff909c30e62e..fb17202cd042f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts @@ -54,7 +54,7 @@ const createCumsumProgramInfo = outputs: [{dims: inputShape, dataType: inputType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: axis}, + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis}, ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape) ] diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index 9e1f58bbfa127..19a009c2eb79b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -272,8 +273,10 @@ const createEinsumProgramInfo = // filter is added to make sure that dimValue is never 0. const programUniformsInit: ProgramUniform[] = uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol)) - .map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); - programUniformsInit.push({type: 'uint32', data: outputSize}); + .map( + (symbol) => + ({type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); + programUniformsInit.push({type: DataType.uint32, data: outputSize}); const programUniforms: ProgramUniform[] = inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)]) .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index dd18bd23a5912..f8fdb63160380 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -85,7 +85,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => }; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape), + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(outputShape) ]; return { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index e1dc9a5e0ab7d..60067c014613b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {MAX_CLIP, MIN_CLIP} from '../../util'; import {ProgramUniform} from '../types'; @@ -36,9 +37,11 @@ export const getActivationSnippet = (attributes: InternalActivationAttributes, v export const appendActivationUniformsData = (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => { if (attributes.activation === 'Clip') { - programUniform.push({type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + programUniform.push( + {type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!}); } else if (attributes.activation === 'HardSigmoid') { - programUniform.push({type: 'float32', data: attributes.alpha!}, {type: 'float32', data: attributes.beta!}); + programUniform.push( + {type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!}); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts index a945954adcaa4..a2d4e3d28f7c5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -46,8 +47,10 @@ const createGatherElementsProgramInfo = const output = outputVariable('output', inputOutputDataType, outputShape.length); - const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, + {type: DataType.uint32, data: axis} + ]; programUniforms.push(...createTensorShapeVariables(inputShape)); programUniforms.push(...createTensorShapeVariables(indicesShape)); programUniforms.push(...createTensorShapeVariables(outputShape)); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index e2a62c6655c72..f2c71a9cd4188 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -34,9 +34,9 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}, - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims), - ...createTensorShapeVariables(outputShape) + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, + {type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims), + ...createTensorShapeVariables(inputs[1].dims), ...createTensorShapeVariables(outputShape) ]; const getShaderSource = (shaderHelper: ShaderHelper) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index a0d4021516bf7..76302e1af2e53 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {GemmUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -45,8 +46,9 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt } const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, {type: 'uint32', data: K}, - {type: 'float32', data: attributes.alpha}, {type: 'float32', data: attributes.beta} + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: K}, {type: DataType.float, data: attributes.alpha}, + {type: DataType.float, data: attributes.beta} ]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; if (inputs.length === 3) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index a835c90bd5451..2096b898b5d40 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -25,7 +25,7 @@ const createInstanceNormProgramInfo = const inputShape = [xShape[0], xShape[1], normPackedSize]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: normSize}, {type: 'uint32', data: normPackedSize}]; + [{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}]; programUniforms.push(...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)); const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -132,8 +132,9 @@ const computeMean = const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type']; const meanProgramUniforms: ProgramUniform[] = [ - {type: 'uint32', data: wgSize}, {type: 'uint32', data: h}, {type: 'uint32', data: Math.floor(c / components)}, - {type: 'uint32', data: Math.floor(h * c / components)} + {type: DataType.uint32, data: wgSize}, {type: DataType.uint32, data: h}, + {type: DataType.uint32, data: Math.floor(c / components)}, + {type: DataType.uint32, data: Math.floor(h * c / components)} ]; const getMeanShaderSource = (shaderHelper: ShaderHelper) => { @@ -182,8 +183,9 @@ const computeMean = {inputs: [input], outputs: [-1]})[0]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: unitsOfWork}, {type: 'uint32', data: h}, - {type: 'uint32', data: Math.floor(c / components)}, {type: 'uint32', data: Math.floor(WG * c / components)} + {type: DataType.uint32, data: unitsOfWork}, {type: DataType.uint32, data: h}, + {type: DataType.uint32, data: Math.floor(c / components)}, + {type: DataType.uint32, data: Math.floor(WG * c / components)} ]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type']; const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -246,7 +248,7 @@ const createInstanceNormNHWCProgramInfo = const components = getMaxComponents(C); const outputSize = ShapeUtil.size(outputShape) / components; const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: H}, {type: 'uint32', data: Math.floor(C / components)}]; + [{type: DataType.uint32, data: H}, {type: DataType.uint32, data: Math.floor(C / components)}]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; // first compute mean const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 3c9f6ce71bb67..3f73d9cb7c5bc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -49,8 +49,9 @@ const createLayerNormProgramInfo = const components = getMaxComponents(normSize); const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: normCount}, {type: 'float32', data: normSize}, - {type: 'uint32', data: Math.floor(normSize / components)}, {type: 'float32', data: attributes.epsilon} + {type: DataType.uint32, data: normCount}, {type: DataType.float, data: normSize}, + {type: DataType.uint32, data: Math.floor(normSize / components)}, + {type: DataType.float, data: attributes.epsilon} ]; if (bias) { inputDependencies.push('type'); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 188b88b2510d8..b263451b99134 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; @@ -29,8 +30,8 @@ export const createNaiveMatmulProgramInfo = const outputShapeInShader = [batchSize, M, N]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, - {type: 'uint32', data: K} + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: K} ]; appendActivationUniformsData(activationAttributes, programUniforms); programUniforms.push( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts index 6d22e3780efd9..5c5c849d99811 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -238,8 +239,10 @@ const addBiasTranspose = hiddenSize: number, biasOffset: number) => { const outputShape = [batchSize, sequenceLength, hiddenSize]; const outputSize = ShapeUtil.size(outputShape); - const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'uint32', data: biasOffset}, {type: 'uint32', data: hiddenSize}]; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: biasOffset}, + {type: DataType.uint32, data: hiddenSize} + ]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('qkv_with_bias', qkv.dataType, outputShape); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index c65b741e1105a..9f5e60773f080 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType, tensorDataTypeEnumToString} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; @@ -153,10 +153,9 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr const inputDims = inputs[0].dims; const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.pads}]; + [{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.pads}]; if (attributes.mode === 0) { - const tensorDataType = tensorDataTypeEnumToString(inputs[0].dataType) as ProgramUniform['type']; - programUniforms.push({type: tensorDataType, data: attributes.value}); + programUniforms.push({type: inputs[0].dataType, data: attributes.value}); } programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape)); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 9e9b361c1af1c..70b8acc3146a0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -3,6 +3,7 @@ import {env} from 'onnxruntime-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -56,7 +57,8 @@ const getUniformAndPadInfo = ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: outputSize}, - programUniforms: [{type: 'uint32', data: reduceSize}] + programUniforms: [{type: DataType.uint32, data: reduceSize}] }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index e8851ac546942..123eb38a1fb93 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -101,7 +101,7 @@ export const createReduceProgramInfo = outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape), + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(outputShape) ] }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index f68526acc0e63..edfd856aeb850 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -641,9 +642,9 @@ const createResizeProgramInfo = outputs: [{dims: outputShape, dataType: inputTensor.dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, - {type: 'float32', data: scales}, - {type: 'float32', data: roi}, + {type: DataType.uint32, data: outputSize}, + {type: DataType.float, data: scales}, + {type: DataType.float, data: roi}, ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(outputShape), ] diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index 509a722f4b52a..7be9ceec6bc65 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -88,10 +88,10 @@ const createSkipLayerNormProgramInfo = const components = getMaxComponents(hiddenSize); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, - {type: 'uint32', data: components}, - {type: 'uint32', data: hiddenSize}, - {type: 'float32', data: attributes.epsilon}, + {type: DataType.uint32, data: outputSize}, + {type: DataType.uint32, data: components}, + {type: DataType.uint32, data: hiddenSize}, + {type: DataType.float, data: attributes.epsilon}, ]; const getShaderSource = (shaderHelper: ShaderHelper) => { const uniformsArray: UniformsArrayType = [ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 5212c6475dce0..6baa634f69f82 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -155,9 +155,9 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice ]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: starts}, {type: 'int32', data: signs}, - {type: 'uint32', data: steps}, ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape) + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: starts}, + {type: DataType.int32, data: signs}, {type: DataType.uint32, data: steps}, + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape) ]; const getShaderSource = (shaderHelper: ShaderHelper) => ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 324dc3af1a710..6f8bfa08d7b62 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -5,6 +5,7 @@ // performance limitations when the reduced axis is long. Need to add // a optimized codepath for this. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -136,7 +137,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut getRunData: () => ({ outputs: [{dims: shape, dataType: input.dataType}], dispatchGroup: {x: rows}, - programUniforms: [{type: 'uint32', data: packedCols}] + programUniforms: [{type: DataType.uint32, data: packedCols}] }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index b8582614fa214..0b703de2ffa1c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -72,7 +73,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const outputsTensorInfo: TensorInfo[] = []; const outputShapes: number[][] = []; let previousSum = 0; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: inputSize}]; + const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: inputSize}]; for (let i = 0; i < attributes.numOutputs; i++) { previousSum += attributes.splitSizes[i]; sizeInSplitAxis[i] = previousSum; @@ -82,7 +83,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split outputs[i] = outputVariable(`output${i}`, dataType, outputShape); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } - programUniforms.push({type: 'uint32', data: sizeInSplitAxis}); + programUniforms.push({type: DataType.uint32, data: sizeInSplitAxis}); programUniforms.push(...createTensorShapeVariables(inputShape)); outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape))); const getShaderSource = (shaderHelper: ShaderHelper) => ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index 90a36a7bec2a9..b080767d2faac 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -80,7 +80,7 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape) ], }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index ab9a9ac8dd1f0..920da04398832 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -65,7 +66,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape), ], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 76929efb32537..1accfac18b876 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -53,7 +53,7 @@ const createElementwiseProgramInfo = dispatchGroup: {x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)}, programUniforms: [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, + {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, ], }) }); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 2ef9637bcda5e..51e8f56c229bd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -98,8 +98,9 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}, programUniforms: [ - {type: 'uint32', data: vecSize}, ...createTensorShapeVariables(dimsC), ...createTensorShapeVariables(dimsA), - ...createTensorShapeVariables(dimsB), ...createTensorShapeVariables(outputShape) + {type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC), + ...createTensorShapeVariables(dimsA), ...createTensorShapeVariables(dimsB), + ...createTensorShapeVariables(outputShape) ], }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index a34b6190b7244..ba5b84fcfe067 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../wasm-common'; import {TensorView} from '../tensor-view'; import {ShaderHelper} from './ops/common'; @@ -26,7 +27,7 @@ export interface TensorInfo { } export interface ProgramUniform { - type: 'int32'|'float16'|'float32'|'uint32'; + type: DataType; data: number|readonly number[]; } From e305794db005d79daaf74ec6437d401430ffda7d Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 30 Jan 2024 21:06:21 -0800 Subject: [PATCH 15/51] [js/webgpu] resolve codescan alert (#19343) ### Description resolve codescan alert: https://github.com/microsoft/onnxruntime/security/code-scanning/17687 --- js/web/lib/wasm/jsep/backend-webgpu.ts | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 58efa795dba48..4b544595d76bb 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -705,13 +705,11 @@ export class WebGpuBackend { captureBegin(): void { LOG_DEBUG('info', 'captureBegin'); - let sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); - let sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); - if (!sessionCommandList) { - sessionCommandList = []; - this.capturedCommandList.set(this.currentSessionId!, sessionCommandList); - sessionPendingKernels = []; - this.capturedPendingKernels.set(this.currentSessionId!, sessionPendingKernels); + if (!this.capturedCommandList.get(this.currentSessionId!)) { + this.capturedCommandList.set(this.currentSessionId!, []); + } + if (!this.capturedPendingKernels.get(this.currentSessionId!)) { + this.capturedPendingKernels.set(this.currentSessionId!, []); } // flush the left commands before we change the status. this.flush(); From fcefc67268ad5169484d50a257a75ad48154e38f Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 2 Feb 2024 09:59:00 +0800 Subject: [PATCH 16/51] [js/webgpu] Refactor createTensorShapeVariables (#18883) --- .../jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 3 +-- .../webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts | 3 +-- .../webgpu/ops/3rd-party/conv_backprop_webgpu.ts | 2 +- .../webgpu/ops/3rd-party/matmul_packed_webgpu.ts | 4 +--- js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 4 +--- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 13 ++++++++++--- js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 8 ++------ js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/expand.ts | 6 ++---- js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts | 4 +--- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 3 +-- js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 4 +--- js/web/lib/wasm/jsep/webgpu/ops/pad.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/pool.ts | 4 ++-- js/web/lib/wasm/jsep/webgpu/ops/reduce.ts | 6 ++---- js/web/lib/wasm/jsep/webgpu/ops/resize.ts | 7 ++----- js/web/lib/wasm/jsep/webgpu/ops/slice.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 5 ++--- js/web/lib/wasm/jsep/webgpu/ops/tile.ts | 6 ++---- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 7 ++----- js/web/lib/wasm/jsep/webgpu/ops/where.ts | 7 ++----- 22 files changed, 40 insertions(+), 64 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index bc39bd94e3072..fc2146068de70 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -195,8 +195,7 @@ export const createConv2DMatMulProgramInfo = {type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations} ]; appendActivationUniformsData(attributes, programUniforms); - programUniforms.push( - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index d18f8586dd071..b5b6a2a15cd8c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -204,8 +204,7 @@ export const createConv2DTransposeMatMulProgramInfo = {type: DataType.int32, data: pads} ]; appendActivationUniformsData(attributes, programUniforms); - programUniforms.push( - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index ba6776e9d8c94..846ad49c5222b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -269,7 +269,7 @@ export const createConvTranspose2DProgramInfo = {type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations}, {type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads}, {type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup}, - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims) + ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims) ]; if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index d9a8d59f731de..8abc27a24861d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -453,9 +453,7 @@ export const createMatmulProgramInfo = {type: DataType.int32, data: dimInner} ]; appendActivationUniformsData(activationAttributes, programUniforms); - programUniforms.push( - ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), - ...createTensorShapeVariables(bShapeTemp)); + programUniforms.push(...createTensorShapeVariables(outerDims, aShapeTemp, bShapeTemp)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const hasBias = inputs.length > 2; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 51f0c76ed8824..a094fffe239c4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -180,9 +180,7 @@ const createBinaryOpProgramInfo = dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, programUniforms: [ {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, - ...createTensorShapeVariables(a.dims), - ...createTensorShapeVariables(b.dims), - ...createTensorShapeVariables(outputShape), + ...createTensorShapeVariables(a.dims, b.dims, outputShape) ], }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 3de57d5ac7f7c..516094d0ef87b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -259,9 +259,16 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = return typeof mappedType === 'string' ? mappedType : mappedType[1]; }; -export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => dims.length === 0 ? - [] : - [{type: DataType.uint32, data: dims}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dims)}]; +export const createTensorShapeVariables = (...dims: ReadonlyArray): ProgramUniform[] => { + const programUniforms: ProgramUniform[] = []; + dims.forEach(dim => { + if (dim.length !== 0) { + programUniforms.push( + {type: DataType.uint32, data: dim}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dim)}); + } + }); + return programUniforms; +}; /** * A helper function to get maximum vector size for specified data length diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 3c2c3cc4e046c..8495f9040a1b6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -35,9 +35,7 @@ export const createGroupedConvProgramInfo = {type: DataType.uint32, data: outputChannelsPerGroup} ]; appendActivationUniformsData(attributes, programUniforms); - programUniforms.push( - ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), - ...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShape)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); @@ -134,9 +132,7 @@ export const createGroupedConvVectorizeProgramInfo = {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]} ]; appendActivationUniformsData(attributes, programUniforms); - programUniforms.push( - ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), - ...createTensorShapeVariables(outputShapeInShader)); + programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShapeInShader)); const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts index fb17202cd042f..6080301d9946b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts @@ -55,7 +55,7 @@ const createCumsumProgramInfo = dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis}, - ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape) + ...createTensorShapeVariables(inputShape, inputShape) ] }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index f8fdb63160380..80ee906423e19 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -84,10 +84,8 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => ${assignment}`; }; - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape), - ...createTensorShapeVariables(outputShape) - ]; + const programUniforms: ProgramUniform[] = + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)]; return { name: 'Expand', shaderCache: {hint: `${outputShape.length}`, inputDependencies: ['rank']}, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts index a2d4e3d28f7c5..4ab6c175a67e2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -51,9 +51,7 @@ const createGatherElementsProgramInfo = {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, {type: DataType.uint32, data: axis} ]; - programUniforms.push(...createTensorShapeVariables(inputShape)); - programUniforms.push(...createTensorShapeVariables(indicesShape)); - programUniforms.push(...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(inputShape, indicesShape, outputShape)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index f2c71a9cd4188..5c31e6dd86c00 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -35,8 +35,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, - {type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(inputs[1].dims), ...createTensorShapeVariables(outputShape) + {type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, outputShape) ]; const getShaderSource = (shaderHelper: ShaderHelper) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 2096b898b5d40..2f652dbd310ab 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -26,7 +26,7 @@ const createInstanceNormProgramInfo = const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}]; - programUniforms.push(...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)); + programUniforms.push(...createTensorShapeVariables(inputShape, inputShape)); const getShaderSource = (shaderHelper: ShaderHelper) => { const x = inputVariable('x', inputs[0].dataType, inputShape.length, components); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index b263451b99134..0c533974e2b26 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -34,9 +34,7 @@ export const createNaiveMatmulProgramInfo = {type: DataType.uint32, data: K} ]; appendActivationUniformsData(activationAttributes, programUniforms); - programUniforms.push( - ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), - ...createTensorShapeVariables(bShape)); + programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape)); if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index 9f5e60773f080..236fc29fdf1ab 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -158,7 +158,7 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr programUniforms.push({type: inputs[0].dataType, data: attributes.value}); } - programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, outputShape)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; const getShaderSource = (shaderHelper: ShaderHelper) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 70b8acc3146a0..4e933573b9137 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -298,7 +298,7 @@ const createAveragePoolProgramInfo = } const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = getUniformAndPadInfo(outputShape, adjustedAttributes); - programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(input.dims, outputShape)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; return { name, @@ -370,7 +370,7 @@ const createMaxPoolProgramInfo = const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = getUniformAndPadInfo(outputShape, adjustedAttributes); - programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(input.dims, outputShape)); return { name, shaderCache: diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index 123eb38a1fb93..e8205ba6fd928 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -100,10 +100,8 @@ export const createReduceProgramInfo = getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ - {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape), - ...createTensorShapeVariables(outputShape) - ] + programUniforms: + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)] }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index edfd856aeb850..2c6b537de1f00 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -642,11 +642,8 @@ const createResizeProgramInfo = outputs: [{dims: outputShape, dataType: inputTensor.dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: DataType.uint32, data: outputSize}, - {type: DataType.float, data: scales}, - {type: DataType.float, data: roi}, - ...createTensorShapeVariables(inputShape), - ...createTensorShapeVariables(outputShape), + {type: DataType.uint32, data: outputSize}, {type: DataType.float, data: scales}, + {type: DataType.float, data: roi}, ...createTensorShapeVariables(inputShape, outputShape) ] }) }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 6baa634f69f82..a5e71f30e5966 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -157,7 +157,7 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: starts}, {type: DataType.int32, data: signs}, {type: DataType.uint32, data: steps}, - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape) + ...createTensorShapeVariables(inputs[0].dims, outputShape) ]; const getShaderSource = (shaderHelper: ShaderHelper) => ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 0b703de2ffa1c..14d6f37927590 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -83,9 +83,8 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split outputs[i] = outputVariable(`output${i}`, dataType, outputShape); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } - programUniforms.push({type: DataType.uint32, data: sizeInSplitAxis}); - programUniforms.push(...createTensorShapeVariables(inputShape)); - outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape))); + programUniforms.push( + {type: DataType.uint32, data: sizeInSplitAxis}, ...createTensorShapeVariables(inputShape, ...outputShapes)); const getShaderSource = (shaderHelper: ShaderHelper) => ` ${ shaderHelper.registerUniform('input_size', 'u32') diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index b080767d2faac..f9728575fe072 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -79,10 +79,8 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ - {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape) - ], + programUniforms: + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)], }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 920da04398832..7ae801222b875 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -65,11 +65,8 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu return { outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ - {type: DataType.uint32, data: outputSize}, - ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape), - ], + programUniforms: + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)], }; }, getShaderSource, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 51e8f56c229bd..cfee07a9239d7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -97,11 +97,8 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}, - programUniforms: [ - {type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC), - ...createTensorShapeVariables(dimsA), ...createTensorShapeVariables(dimsB), - ...createTensorShapeVariables(outputShape) - ], + programUniforms: + [{type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC, dimsA, dimsB, outputShape)], }), }; }; From bff2f5b51617f2ca0054a24ab043727739e0bc15 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 2 Feb 2024 18:04:06 +0800 Subject: [PATCH 17/51] [js/webgpu] Fix the undefined push error (#19366) ### Description This PR fixes below errors when enable webgpu profiling: ``` TypeError: Cannot read properties of undefined (reading 'push') ``` --- js/web/lib/wasm/jsep/backend-webgpu.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 4b544595d76bb..98990a6fe477b 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -530,8 +530,10 @@ export class WebGpuBackend { }; this.pendingKernels.push(pendingKernelInfo); - const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); - sessionPendingKernels!.push(pendingKernelInfo); + if (this.sessionStatus === 'capturing') { + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + sessionPendingKernels!.push(pendingKernelInfo); + } } this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding); From 8cf59d262da55bbffdc93475ccfc27932088ec51 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Sat, 3 Feb 2024 01:06:38 +0800 Subject: [PATCH 18/51] [js/webgpu] Add LeakyRelu activation for fusedConv (#19369) ### Description This PR 1) adds LeakyRelu activation for fusedConv; 2) makes `vec4` value work with `float32` uniforms attributes. For example: `clamp(value, vec4(uniforms.clip_min), vec4(uniforms.clip_max)` will throw compilation errors since `uniforms.clip_min` and `uniforms.clip_min` are `f32` not `f16`. So we need to change it to `clamp(value, vec4(f16(uniforms.clip_min)), vec4(f16(uniforms.clip_max))` And above problem was introduced when we make activation attributes as uniforms instead of constant. BTW, after adding LeakyRelu, `realesrgan-t256` model can pass. --- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 2 +- .../ops/3rd-party/matmul_packed_webgpu.ts | 3 +- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 8 +- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 47 +++--- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 5 +- js/web/test/data/ops/fused-conv.jsonc | 144 ++++++++++++++++++ 6 files changed, 184 insertions(+), 25 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index fc2146068de70..24006d393592a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -130,7 +130,7 @@ const conv2dCommonSnippet = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); - const applyActivation = getActivationSnippet(attributes, resType); + const applyActivation = getActivationSnippet(attributes, resType, dataType); const userCode = ` fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { ${isChannelsLast ? sampleX : sampleW} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 8abc27a24861d..29c7941e6bd30 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -479,7 +479,8 @@ export const createMatmulProgramInfo = const uniforms: UniformsArrayType = [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; appendActivationUniforms(activationAttributes, uniforms); - const applyActivation = getActivationSnippet(activationAttributes, output.type.value); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); const declareFunctions = matMulReadWriteFnSource( components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], isChannelsLast); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 8495f9040a1b6..7d424305c715f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from './fuse-utils'; @@ -45,7 +45,8 @@ export const createGroupedConvProgramInfo = const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShape.length); - const applyActivation = getActivationSnippet(attributes, output.type.value); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); const x = inputVariable('x', inputs[0].dataType, xShape.length); const w = inputVariable('w', inputs[1].dataType, wShape.length); const inputVars = [x, w]; @@ -136,7 +137,8 @@ export const createGroupedConvVectorizeProgramInfo = const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const applyActivation = getActivationSnippet(attributes, output.type.value); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); const x = inputVariable('x', inputs[0].dataType, xShape.length, components); const w = inputVariable('w', inputs[1].dataType, wShape.length, components); const inputVars = [x, w]; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 60067c014613b..6e66abacf3471 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -15,24 +15,28 @@ export interface InternalActivationAttributes { readonly beta?: number; } -export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => { - switch (attributes.activation) { - case 'Relu': - return `value = max(value, ${valueType}(0.0));`; - case 'Sigmoid': - return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; - case 'Clip': - return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`; - case 'HardSigmoid': - return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${valueType}(uniforms.alpha) * value + ${ - valueType}(uniforms.beta)));`; - case '': - return ''; - // TODO: adding other activations that can be fused. - default: - throw new Error(`Unsupported activation ${attributes.activation}`); - } -}; +export const getActivationSnippet = + (attributes: InternalActivationAttributes, valueType: string, baseType = 'f32'): string => { + switch (attributes.activation) { + case 'Relu': + return `value = max(value, ${valueType}(0.0));`; + case 'Sigmoid': + return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; + case 'Clip': + return `value = clamp(value, ${valueType}(${baseType}(uniforms.clip_min)), ${valueType}(${ + baseType}(uniforms.clip_max)));`; + case 'HardSigmoid': + return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${baseType}(uniforms.alpha) * value + ${ + baseType}(uniforms.beta)));`; + case 'LeakyRelu': + return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`; + case '': + return ''; + // TODO: adding other activations that can be fused. + default: + throw new Error(`Unsupported activation ${attributes.activation}`); + } + }; export const appendActivationUniformsData = (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => { @@ -42,6 +46,8 @@ export const appendActivationUniformsData = } else if (attributes.activation === 'HardSigmoid') { programUniform.push( {type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!}); + } else if (attributes.activation === 'LeakyRelu') { + programUniform.push({type: DataType.float, data: attributes.alpha!}); } }; @@ -50,6 +56,8 @@ export const appendActivationUniforms = (attributes: InternalActivationAttribute uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); } else if (attributes.activation === 'HardSigmoid') { uniforms.push({name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'}); + } else if (attributes.activation === 'LeakyRelu') { + uniforms.push({name: 'alpha', type: 'f32'}); } }; @@ -62,6 +70,9 @@ export const parseInternalActivationAttributes = } else if (activation === 'Clip') { const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; return {activation, clipMax, clipMin}; + } else if (activation === 'LeakyRelu') { + const [alpha] = attributes?.activation_params as [number] || [0.01]; + return {activation, alpha}; } return {activation}; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 0c533974e2b26..1a92d861002fb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -7,7 +7,7 @@ import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common'; +import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; export const createNaiveMatmulProgramInfo = @@ -45,7 +45,8 @@ export const createNaiveMatmulProgramInfo = const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); const b = inputVariable('b', inputs[1].dataType, bShape.length, components); const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const applyActivation = getActivationSnippet(activationAttributes, output.type.value); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); const inputVariables = [a, b]; let processBias = ''; if (hasBias) { diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index c734d6db9b92a..6a10e3b96a26a 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -286,5 +286,149 @@ ] } ] + }, + { + "name": "fused group-conv with LeakyRelu", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [9, -6, 51, 47, -170, -10, 251, 229, 847, 889, 973, 1015], + "dims": [1, 3, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC group-conv with LeakyRelu", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-162, 63, -158, 33, 281, 85, 105, 337, 455, 177, 515, 609], + "dims": [1, 2, 2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused conv with LeakyRelu", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-540, -860, 390, 430], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv with LeakyRelu", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-540, -860, 390, 430], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] } ] From 257cf5e9ba6a0a2275525bb5644ca53242d4c2d8 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 6 Feb 2024 09:07:31 -0800 Subject: [PATCH 19/51] [js/webgpu] support customop FastGelu (#19392) ### Description Support WebGPU custom operator FastGelu. --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + .../wasm/jsep/webgpu/ops/bias-split-gelu.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts | 69 ++++++ js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 33 ++- js/web/test/data/ops/fast-gelu.jsonc | 211 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 1 + onnxruntime/contrib_ops/js/fast_gelu.cc | 23 ++ onnxruntime/contrib_ops/js/fast_gelu.h | 17 ++ .../contrib_ops/js/js_contrib_kernels.cc | 2 + 10 files changed, 353 insertions(+), 8 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts create mode 100644 js/web/test/data/ops/fast-gelu.jsonc create mode 100644 onnxruntime/contrib_ops/js/fast_gelu.cc create mode 100644 onnxruntime/contrib_ops/js/fast_gelu.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 2557971eb4ded..b21af8e715db3 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -41,6 +41,7 @@ Do not modify directly.* | Erf | ai.onnx(9-12,13+) | | | Exp | ai.onnx(6-12,13+) | | | Expand | ai.onnx(8-12,13+) | | +| FastGelu | com.microsoft(1+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | | FusedConv | com.microsoft(1+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index d737a28654220..ac08c5fb1f7ab 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -13,6 +13,7 @@ import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose' import {cumsum, parseCumSumAttributes} from './ops/cumsum'; import {einsum, parseEinsumAttributes} from './ops/einsum'; import {expand} from './ops/expand'; +import {fastGelu} from './ops/fast-gelu'; import {gather, parseGatherAttributes} from './ops/gather'; import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements'; import {gemm, parseGemmAttributes} from './ops/gemm'; @@ -72,6 +73,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Erf', [unaryOps.erf]], ['Exp', [unaryOps.exp]], ['Expand', [expand]], + ['FastGelu', [fastGelu]], ['Floor', [unaryOps.floor]], ['FusedConv', [conv, parseConvAttributes]], ['Gather', [gather, parseGatherAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts index a81a7a8f1df5c..089fecd758e30 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts @@ -43,7 +43,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI ${shaderHelper.declareVariables(input, bias, output)} - ${erfImpl(`vec4<${dataType}>`, dataType)} + ${erfImpl(dataType)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts new file mode 100644 index 0000000000000..f50a6a3f011fe --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType, UniformsArrayType, WORKGROUP_SIZE} from './common'; +import * as unary from './unary-op'; + +// GELU is defined as Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X)), where X may pre-add a bias. + +const createFastGeluProgramInfo = (inputTensors: readonly TensorView[]): ProgramInfo => { + const dataType = inputTensors[0].dataType; + const outputSize = ShapeUtil.size(inputTensors[0].dims); + const biasLength = ShapeUtil.size(inputTensors[1].dims); + // can only use vec4 when bias length is multiple of 4 + const useVec4 = biasLength % 4 === 0; + const getShaderSource = (shaderHelper: ShaderHelper): string => { + const x = inputVariable('x', dataType, [1], 4); + const bias = inputVariable('bias', dataType, [1], 4); + const y = outputVariable('y', dataType, [1], 4); + + const uniforms: UniformsArrayType = [{name: 'output_vec_size', type: 'u32'}, {name: 'bias_size', type: 'u32'}]; + + const singleElementBias = (i: 0|1|2|3) => ` + let bias${i}_offset: u32 = (global_idx * 4 + ${i}) % uniforms.bias_size; + let bias${i} = ${bias.getByOffset(`bias${i}_offset / 4`)}[bias${i}_offset % 4];`; + const biasGetExpression = useVec4 ? + ` + let bias = ${bias.getByOffset('global_idx % (uniforms.bias_size / 4)')};` : + `${singleElementBias(0)}${singleElementBias(1)}${singleElementBias(2)}${singleElementBias(3)} + let bias = ${x.type.value}(bias0, bias1, bias2, bias3);`; + + return `${shaderHelper.registerUniforms(uniforms).declareVariables(x, bias, y)} + + ${unary.fastGeluImpl(tensorTypeToWsglValueType(dataType))} + + ${shaderHelper.mainStart(WORKGROUP_SIZE)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_vec_size')} + + let x = ${x.getByOffset('global_idx')}; + ${biasGetExpression} + let x_in = x + bias; + ${y.setByOffset('global_idx', unary.fastGeluExpression('x_in'))} + }`; + }; + + return { + name: 'FastGeluWithBias', + shaderCache: {hint: `${useVec4}`, inputDependencies: ['type', 'type']}, + getShaderSource, + getRunData: (inputs) => ({ + outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}], + programUniforms: + [{type: DataType.uint32, data: Math.ceil(outputSize / 4)}, {type: DataType.uint32, data: biasLength}], + dispatchGroup: {x: Math.ceil(outputSize / WORKGROUP_SIZE / 4)} + }) + }; +}; + +export const fastGelu = (context: ComputeContext): void => { + if (context.inputs.length < 2 || ShapeUtil.size(context.inputs[1].dims) === 0) { + unary.fastGelu(context); + } else { + context.compute(createFastGeluProgramInfo(context.inputs)); + } +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 1accfac18b876..5f105c745739e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -178,7 +178,7 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void attributes.cacheKey)); }; -export const erfImpl = (dataType: string, varType = 'f32') => ` +export const erfImpl = (varType = 'f32') => ` const r0: ${varType} = 0.3275911; const r1: ${varType} = 0.254829592; const r2: ${varType} = -0.284496736; @@ -186,7 +186,7 @@ const r3: ${varType} = 1.421413741; const r4: ${varType} = -1.453152027; const r5: ${varType} = 1.061405429; -fn erf_vf32(v: ${dataType}) -> ${dataType} { +fn erf_vf32(v: vec4<${varType}>) -> vec4<${varType}> { let absv = abs(v); let x = 1.0 / (1.0 + r0 * absv); return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); @@ -194,8 +194,7 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} { export const erf = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType))); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(dataType))); }; export const exp = (context: ComputeContext): void => { @@ -209,8 +208,7 @@ export const floor = (context: ComputeContext): void => { export const gelu = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, - erfImpl(`vec4<${dataType}>`, dataType))); + context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, erfImpl(dataType))); }; export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => { @@ -278,10 +276,31 @@ export const tan = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tan', 'tan')); }; +export const tanhExpression = (a: string) => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`; + export const tanh = (context: ComputeContext): void => { // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', tanhExpression)); +}; + +export const fastGeluImpl = (varType = 'f32') => ` +const fast_gelu_a: ${varType} = 0.5; +const fast_gelu_b: ${varType} = 0.7978845608028654; +const fast_gelu_c: ${varType} = 0.035677408136300125; + +fn tanh_v(v: vec4<${varType}>) -> vec4<${varType}> { + return ${tanhExpression('v')}; +} +`; + +export const fastGeluExpression = (x: string) => + `(fast_gelu_a + fast_gelu_a * tanh_v(${x} * (fast_gelu_c * ${x} * ${x} + fast_gelu_b))) * ${x}`; + +export const fastGelu = (context: ComputeContext): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`)); + context.inputs[0], 'FastGelu', fastGeluExpression, fastGeluImpl(dataType), undefined, + context.inputs[0].dataType)); }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { diff --git a/js/web/test/data/ops/fast-gelu.jsonc b/js/web/test/data/ops/fast-gelu.jsonc new file mode 100644 index 0000000000000..2550173e95402 --- /dev/null +++ b/js/web/test/data/ops/fast-gelu.jsonc @@ -0,0 +1,211 @@ +[ + { + "name": "FastGelu test without bias", + "operator": "FastGelu", + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "scalar", + "inputs": [ + { + "data": [1], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.841192], + "dims": [], + "type": "float32" + } + ] + }, + { + "name": "[2x4]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.435415, 0.53057, 0.630432], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[3x5]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.841192, 1.9546, 2.99636, 3.99993, 5, 0.950581, + 1.0617, 1.17393, 1.28671, 1.39957 + ], + "dims": [3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "FastGelu test with bias", + "operator": "FastGelu", + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "scalar", + "inputs": [ + { + "data": [1], + "dims": [], + "type": "float32" + }, + { + "data": [0.5], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.39957], + "dims": [], + "type": "float32" + } + ] + }, + { + "name": "[2x4], [4]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.950581, 2.16968, 3.29869, 4.39999, 1.39957, 2.58835, 3.69973, 4.8], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[2x4], [3]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.950581, 2.16968, 3.29869, 1.28671, 2.48492, 3.59959, 1.62411, 2.79331], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[3x5], [2]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + }, + { + "data": [2, 3], + "dims": [2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.06267, 3.19813, 2.27567, 3.39909, 2.48492, 3.99993, 3.99993, 6, 6, 8, 3.09737, 4.19997, 3.29869, + 4.39999, 3.49938 + ], + "dims": [3, 5], + "type": "float32" + } + ] + }, + { + "name": "[3x5], [7]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + }, + { + "data": [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7], + "dims": [7], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.16968, 2.38072, 2.58835, 2.79331, 2.99636, 3.59959, 4.7, 5.1, 6.2, 7.3, 3.49938, 3.69973, 3.89989, + 4.09996, 3.59959 + ], + "dims": [3, 5], + "type": "float32" + } + ] + }, + { + "name": "[4x4], [8]", + "inputs": [ + { + "data": [0.8, -0.5, 0.0, 1, 1.3, 2.1, -0.2, 1.1, 0.5, 0.2, 0.3, -0.6, 3.1, 2.2, -1.1, 0.0], + "dims": [4, 4], + "type": "float32" + }, + { + "data": [-0.5, 0.6, 1.2, 2.1, 1.3, -1, 0, 3.1], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.185371, 0.0539828, 1.0617, 3.09737, 2.58835, 0.950581, -0.0841486, 4.19997, 0, 0.630432, 1.39957, + 1.39957, 4.39999, 1.0617, -0.149419, 3.09737 + ], + "dims": [4, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 56db28b0a379c..55b21283025c2 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1352,6 +1352,7 @@ "equal.jsonc", "exp.jsonc", "expand.jsonc", + "fast-gelu.jsonc", "floor.jsonc", "gather-elements.jsonc", "gemm.jsonc", diff --git a/onnxruntime/contrib_ops/js/fast_gelu.cc b/onnxruntime/contrib_ops/js/fast_gelu.cc new file mode 100644 index 0000000000000..62c538318160d --- /dev/null +++ b/onnxruntime/contrib_ops/js/fast_gelu.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "fast_gelu.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + FastGelu, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + FastGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/fast_gelu.h b/onnxruntime/contrib_ops/js/fast_gelu.h new file mode 100644 index 0000000000000..68c7892741c66 --- /dev/null +++ b/onnxruntime/contrib_ops/js/fast_gelu.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; +JSEP_KERNEL_IMPL(FastGelu, FastGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 498a9f5679eb5..bd58dded026a6 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -8,6 +8,7 @@ namespace contrib { namespace js { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); @@ -24,6 +25,7 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, From 2eb5f3b52fae868c0a0bf5cde8698dbc90fb1e11 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 16 Feb 2024 18:28:27 -0800 Subject: [PATCH 20/51] [js/webgpu] allow uint8 tensors for webgpu (#19545) ### Description allow uint8 tensors for webgpu --- js/common/lib/tensor-impl.ts | 2 +- js/common/lib/tensor.ts | 2 +- js/web/lib/wasm/wasm-common.ts | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index e3e2b9c728556..de18126a9d0ae 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -103,7 +103,7 @@ export class Tensor implements TensorInterface { } case 'gpu-buffer': { if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' && - type !== 'bool')) { + type !== 'uint8' && type !== 'bool')) { throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`); } this.gpuBufferData = arg0.gpuBuffer; diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 6c08d1fe8e057..d5da33640dc7d 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -135,7 +135,7 @@ export declare namespace Tensor { /** * supported data types for constructing a tensor from a WebGPU buffer */ - export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'bool'; + export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool'; /** * represent where the tensor data is stored diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index b9eff45e890c4..93910af1f1bf0 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -169,7 +169,8 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro * Check whether the given tensor type is supported by GPU buffer */ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' || - type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32'; + type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint8' || + type === 'bool'; /** * Map string data location to integer value From 0f97b5bfb0e247c58b9fbcceb2ab3a2c98a87088 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Sat, 17 Feb 2024 09:19:17 -0800 Subject: [PATCH 21/51] [JS/WebGPU] Add MatMulNBits (#19446) ### Description Add MatMulNBits to support MatMul using 4-bit quantized weights ### Motivation and Context --- js/web/docs/webgpu-operators.md | 1 + js/web/lib/wasm/jsep/util.ts | 28 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + .../lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 184 ++ js/web/test/data/ops/matmulnbits.jsonc | 1527 +++++++++++++++++ js/web/test/suite-test-list.jsonc | 1 + .../contrib_ops/js/js_contrib_kernels.cc | 16 +- .../js/quantization/matmul_nbits.cc | 25 + .../js/quantization/matmul_nbits.h | 48 + 9 files changed, 1825 insertions(+), 7 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts create mode 100644 js/web/test/data/ops/matmulnbits.jsonc create mode 100644 onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc create mode 100644 onnxruntime/contrib_ops/js/quantization/matmul_nbits.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index b21af8e715db3..4a8c92bb97bfd 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -62,6 +62,7 @@ Do not modify directly.* | LessOrEqual | ai.onnx(12-15,16+) | | | Log | ai.onnx(6-12,13+) | | | MatMul | ai.onnx(1-12,13+) | | +| MatMulNBits | com.microsoft(1+) | | | MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation | | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index 6922d7ff5df6e..c0517ce363644 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -92,6 +92,34 @@ export class ShapeUtil { return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length); } + /** + * convert dims corresponding to type change to pack. ex. uint8 data to uint32 + */ + static convertShape(dims: readonly number[], size = 4): readonly number[] { + const rank = dims.length; + if (rank === 0) { + return []; + } + const newDims = new Array(rank); + let i = rank - 1; + while (i >= 0) { + if (dims[i] % size === 0) { + newDims[i] = dims[i] / size; + break; + } + if (size % dims[i] !== 0) { + throw new Error('cannot convert shape'); + } + newDims[i] = 1; + size /= dims[i]; + i--; + } + for (i--; i >= 0; i--) { + newDims[i] = dims[i]; + } + return newDims; + } + /** * calculate the size (number of elements) from the given axis (inclusive) */ diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index ac08c5fb1f7ab..ba874c8dd0f80 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -20,6 +20,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm} from './ops/instance-norm'; import {layerNorm} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {matMulNBits, parseMatMulNBitsAttributes} from './ops/matmulnbits'; import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad} from './ops/pad'; import * as pool from './ops/pool'; @@ -92,6 +93,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['LessOrEqual', [binaryOps.lessOrEqual]], ['Log', [unaryOps.log]], ['MatMul', [matMul]], + ['MatMulNBits', [matMulNBits, parseMatMulNBitsAttributes]], // TODO: support new attributes for MaxPool-8 and MaxPool-10 ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], ['Mul', [binaryOps.mul]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts new file mode 100644 index 0000000000000..ead7635cf3ac4 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; + +// TODO support quantization bits not equal to 4 +export interface MatMulNBitsAttributes extends AttributeWithCacheKey { + k: number; + n: number; + accuracyLevel: number; + bits: number; + blockSize: number; +} + +const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): void => { + if (inputs.length < 3 || inputs.length > 4) { + throw new Error('MatMulNBits requires 3 or 4 inputs'); + } + const a = inputs[0]; + const aRank = a.dims.length; + if (a.dims[aRank - 1] !== attributes.k) { + throw new Error('The last dim of input shape does not match the k value'); + } + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const blobSize = attributes.blockSize / 8 * attributes.bits; + const b = inputs[1]; + if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) { + throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize'); + } + const scales = inputs[2]; + const scalesShape = scales.dims; + if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) { + throw new Error('scales input size error.'); + } + if (inputs.length === 4) { + const zeroPoints = inputs[3]; + const zeroPointsShape = zeroPoints.dims; + const expectedZeroPointsSize = + attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2); + if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) { + throw new Error('zeroPoints input size error.'); + } + } +}; + +export const createMatMulNBitsProgramInfo = + (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => { + const a = inputs[0]; + const b = inputs[1]; + const scales = inputs[2]; + const aRank = a.dims.length; + const outputShape = a.dims.slice(0, aRank - 1).concat(attributes.n); + const outputSize = ShapeUtil.size(outputShape); + + + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k}, + {type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel}, + {type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize} + ]; + programUniforms.push(...createTensorShapeVariables(a.dims)); + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(b.dims))); + programUniforms.push(...createTensorShapeVariables(scales.dims)); + if (inputs.length === 4) { + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const a = inputVariable('a', inputs[0].dataType, inputs[0].dims.length); + const b = inputVariable('b', DataType.uint32, inputs[1].dims.length); + const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); + const inputVariables = [a, b, scales]; + const zeroPoints = + inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined; + if (zeroPoints) { + inputVariables.push(zeroPoints); + } + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'k', type: 'u32'}, {name: 'n', type: 'u32'}, + {name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'} + ]; + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const blobSize = attributes.blockSize / 8 * attributes.bits; + const wordPerBlob = blobSize / 4; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return ` + fn ortUnpack8x4snorm(value: u32) -> array<${dataType}, 8>{ + var result = array<${dataType}, 8>(); + var offset: u32 = 0; + let count: u32 = 4; + for (var i: u32 = 0; i < 8u; i++) { + result[i] = ${dataType}(extractBits(value, offset, count)); + offset += count; + } + return result; + } + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + var value: ${dataType} = 0.0; + let output_indices = ${output.offsetToIndices('global_idx')}; + var a_indices: ${a.type.indices} = output_indices; + var n = ${output.indicesGet('output_indices', aRank - 1)}; + // Two zero points are packed into one byte because uniforms.bits <= 4. + // zero_point_offset is either 0 or 4. It is bit offset within one byte. + // TODO support zero_point_offset for bits > 4 + ${ + zeroPoints ? ` + var zero_point_index: u32 = n * ((${nBlocksPerCol} + 1) / 2) / 4; + var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; + var zero_point_offset: u32 = 0;` : + ''} + var scale_idex = n * ${nBlocksPerCol}; + var b_indices: ${b.type.indices}; + ${b.indicesSet('b_indices', '0', 'n')}; + var block_offset: u32 = 0; + for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) { + // The scale and zero points are computed per block. + let scale = ${scales.getByOffset('scale_idex')}; + // The default zero point is 8 for unsigned 4-bit quantization. + let zero_point: ${dataType} = ${ + zeroPoints ? `${dataType}(extractBits(zero_point_word, zero_point_offset, 4))` : 8.0}; + ${b.indicesSet('b_indices', '1', 'block')}; + var word_offset: u32 = block_offset; + for (var word: u32 = 0; word < ${wordPerBlob}; word++) { + ${b.indicesSet('b_indices', '2', 'word')}; + let b_value = ${b.getByIndices('b_indices')}; + let b_quantized_values: array<${dataType}, 8> = ortUnpack8x4snorm(b_value); + // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 + var offset: u32 = word_offset; + for (var i: u32 = 0; i < 8; i++) { + ${a.indicesSet('a_indices', aRank - 1, 'offset')}; + let a_value = ${a.getByIndices('a_indices')}; + let b_quantized_value = b_quantized_values[i]; + let b_dequantized_value = (b_quantized_value - zero_point) * scale; + value += a_value * b_dequantized_value; + offset++; + } + word_offset += 8; + } + scale_idex++; + ${ + zeroPoints ? ` + if (zero_point_offset == 28) { + zero_point_offset = 0; + zero_point_index++; + zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; + } else { + zero_point_offset += 4; + }` : + ''} + block_offset += uniforms.block_size; + } + ${output.setByOffset('global_idx', 'value')}; + } + `; + }; + return { + name: 'MatMulNBits', + shaderCache: + {hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64)}, + programUniforms + }), + getShaderSource + }; + }; + +export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { + validateInputs(context.inputs, attributes); + context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes)); +}; + +export const parseMatMulNBitsAttributes = (attributes: Record): MatMulNBitsAttributes => + createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/test/data/ops/matmulnbits.jsonc b/js/web/test/data/ops/matmulnbits.jsonc new file mode 100644 index 0000000000000..c57c431afb3ce --- /dev/null +++ b/js/web/test/data/ops/matmulnbits.jsonc @@ -0,0 +1,1527 @@ +[ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, -2944, -1161, -3040, -715, -2880, -13, -2464, 945, 0, + -1073, -3808, -2643, -6848, -3445, -9120, -3479, -10624, -2745, -11360, -1243, -11328, 1027, -10528, 4065, + 0, -1761, -6496, -4323, -11712, -5605, -15648, -5607, -18304, -4329, -19680, -1771, -19776, 2067, -18592, + 7185, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, -25984, -5913, -28000, -2299, -28224, 3107, + -26656, 10305, 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, -33664, -7497, -36320, -2827, + -36672, 4147, -34720, 13425, 0, -3825, -14560, -9363, -26304, -12085, -35232, -11991, -41344, -9081, + -44640, -3355, -45120, 5187, -42784, 16545, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, + -49024, -10665, -52960, -3883, -53568, 6227, -50848, 19665, 0, -5201, -19936, -12723, -36032, -16405, + -48288, -16247, -56704, -12249, -61280, -4411, -62016, 7267, -58912, 22785, 0, -5889, -22624, -14403, + -40896, -18565, -54816, -18375, -64384, -13833, -69600, -4939, -70464, 8307, -66976, 25905, 0, -6577, + -25312, -16083, -45760, -20725, -61344, -20503, -72064, -15417, -77920, -5467, -78912, 9347, -75040, + 29025, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, -79744, -17001, -86240, -5995, -87360, + 10387, -83104, 32145, 0, -7953, -30688, -19443, -55488, -25045, -74400, -24759, -87424, -18585, -94560, + -6523, -95808, 11427, -91168, 35265, 0, -8641, -33376, -21123, -60352, -27205, -80928, -26887, -95104, + -20169, -102880, -7051, -104256, 12467, -99232, 38385, 0, -9329, -36064, -22803, -65216, -29365, -87456, + -29015, -102784, -21753, -111200, -7579, -112704, 13507, -107296, 41505, 0, -10017, -38752, -24483, + -70080, -31525, -93984, -31143, -110464, -23337, -119520, -8107, -121152, 14547, -115360, 44625, 0, + -10705, -41440, -26163, -74944, -33685, -100512, -33271, -118144, -24921, -127840, -8635, -129600, 15587, + -123424, 47745 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, 728, 688, 2376, 1632, 4280, 2832, 6440, 4288, 8856, 6000, 11528, 7968, 14456, 10192, 17640, 0, 2200, + 1840, 7176, 4448, 12920, 7824, 19432, 11968, 26712, 16880, 34760, 22560, 43576, 29008, 53160, 0, 3672, + 2992, 11976, 7264, 21560, 12816, 32424, 19648, 44568, 27760, 57992, 37152, 72696, 47824, 88680, 0, 5144, + 4144, 16776, 10080, 30200, 17808, 45416, 27328, 62424, 38640, 81224, 51744, 101816, 66640, 124200, 0, + 6616, 5296, 21576, 12896, 38840, 22800, 58408, 35008, 80280, 49520, 104456, 66336, 130936, 85456, 159720, + 0, 8088, 6448, 26376, 15712, 47480, 27792, 71400, 42688, 98136, 60400, 127688, 80928, 160056, 104272, + 195240, 0, 9560, 7600, 31176, 18528, 56120, 32784, 84392, 50368, 115992, 71280, 150920, 95520, 189176, + 123088, 230760, 0, 11032, 8752, 35976, 21344, 64760, 37776, 97384, 58048, 133848, 82160, 174152, 110112, + 218296, 141904, 266280, 0, 12504, 9904, 40776, 24160, 73400, 42768, 110376, 65728, 151704, 93040, 197384, + 124704, 247416, 160720, 301800, 0, 13976, 11056, 45576, 26976, 82040, 47760, 123368, 73408, 169560, + 103920, 220616, 139296, 276536, 179536, 337320, 0, 15448, 12208, 50376, 29792, 90680, 52752, 136360, + 81088, 187416, 114800, 243848, 153888, 305656, 198352, 372840, 0, 16920, 13360, 55176, 32608, 99320, + 57744, 149352, 88768, 205272, 125680, 267080, 168480, 334776, 217168, 408360, 0, 18392, 14512, 59976, + 35424, 107960, 62736, 162344, 96448, 223128, 136560, 290312, 183072, 363896, 235984, 443880, 0, 19864, + 15664, 64776, 38240, 116600, 67728, 175336, 104128, 240984, 147440, 313544, 197664, 393016, 254800, + 479400, 0, 21336, 16816, 69576, 41056, 125240, 72720, 188328, 111808, 258840, 158320, 336776, 212256, + 422136, 273616, 514920, 0, 22808, 17968, 74376, 43872, 133880, 77712, 201320, 119488, 276696, 169200, + 360008, 226848, 451256, 292432, 550440 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -428, -1288, -1068, -2288, -1420, -3000, -1484, -3424, -1260, -3560, -748, -3408, 52, -2968, 1140, + -2272, 2516, -1224, 4180, 80, 6132, 1672, 8372, 3552, 10900, 5720, 13716, 8176, 16820, 10920, 12276, 0, + -1116, -3976, -2748, -7152, -3580, -9528, -3612, -11104, -2844, -11880, -1276, -11856, 1092, -11032, 4260, + -8160, 8228, -6984, 12996, -3760, 18564, 264, 24932, 5088, 32100, 10712, 40068, 17136, 48836, 24360, + 42532, 0, -1804, -6664, -4428, -12016, -5740, -16056, -5740, -18784, -4428, -20200, -1804, -20304, 2132, + -19096, 7380, -14048, 13940, -12744, 21812, -7600, 30996, -1144, 41492, 6624, 53300, 15704, 66420, 26096, + 80852, 37800, 72788, 0, -2492, -9352, -6108, -16880, -7900, -22584, -7868, -26464, -6012, -28520, -2332, + -28752, 3172, -27160, 10500, -19936, 19652, -18504, 30628, -11440, 43428, -2552, 58052, 8160, 74500, + 20696, 92772, 35056, 112868, 51240, 103044, 0, -3180, -12040, -7788, -21744, -10060, -29112, -9996, + -34144, -7596, -36840, -2860, -37200, 4212, -35224, 13620, -25824, 25364, -24264, 39444, -15280, 55860, + -3960, 74612, 9696, 95700, 25688, 119124, 44016, 144884, 64680, 133300, 0, -3868, -14728, -9468, -26608, + -12220, -35640, -12124, -41824, -9180, -45160, -3388, -45648, 5252, -43288, 16740, -31712, 31076, -30024, + 48260, -19120, 68292, -5368, 91172, 11232, 116900, 30680, 145476, 52976, 176900, 78120, 163556, 0, -4556, + -17416, -11148, -31472, -14380, -42168, -14252, -49504, -10764, -53480, -3916, -54096, 6292, -51352, + 19860, -37600, 36788, -35784, 57076, -22960, 80724, -6776, 107732, 12768, 138100, 35672, 171828, 61936, + 208916, 91560, 193812, 0, -5244, -20104, -12828, -36336, -16540, -48696, -16380, -57184, -12348, -61800, + -4444, -62544, 7332, -59416, 22980, -43488, 42500, -41544, 65892, -26800, 93156, -8184, 124292, 14304, + 159300, 40664, 198180, 70896, 240932, 105000, 224068, 0, -5932, -22792, -14508, -41200, -18700, -55224, + -18508, -64864, -13932, -70120, -4972, -70992, 8372, -67480, 26100, -49376, 48212, -47304, 74708, -30640, + 105588, -9592, 140852, 15840, 180500, 45656, 224532, 79856, 272948, 118440, 254324, 0, -6620, -25480, + -16188, -46064, -20860, -61752, -20636, -72544, -15516, -78440, -5500, -79440, 9412, -75544, 29220, + -55264, 53924, -53064, 83524, -34480, 118020, -11000, 157412, 17376, 201700, 50648, 250884, 88816, 304964, + 131880, 284580, 0, -7308, -28168, -17868, -50928, -23020, -68280, -22764, -80224, -17100, -86760, -6028, + -87888, 10452, -83608, 32340, -61152, 59636, -58824, 92340, -38320, 130452, -12408, 173972, 18912, 222900, + 55640, 277236, 97776, 336980, 145320, 314836, 0, -7996, -30856, -19548, -55792, -25180, -74808, -24892, + -87904, -18684, -95080, -6556, -96336, 11492, -91672, 35460, -67040, 65348, -64584, 101156, -42160, + 142884, -13816, 190532, 20448, 244100, 60632, 303588, 106736, 368996, 158760, 345092, 0, -8684, -33544, + -21228, -60656, -27340, -81336, -27020, -95584, -20268, -103400, -7084, -104784, 12532, -99736, 38580, + -72928, 71060, -70344, 109972, -46000, 155316, -15224, 207092, 21984, 265300, 65624, 329940, 115696, + 401012, 172200, 375348, 0, -9372, -36232, -22908, -65520, -29500, -87864, -29148, -103264, -21852, + -111720, -7612, -113232, 13572, -107800, 41700, -78816, 76772, -76104, 118788, -49840, 167748, -16632, + 223652, 23520, 286500, 70616, 356292, 124656, 433028, 185640, 405604, 0, -10060, -38920, -24588, -70384, + -31660, -94392, -31276, -110944, -23436, -120040, -8140, -121680, 14612, -115864, 44820, -84704, 82484, + -81864, 127604, -53680, 180180, -18040, 240212, 25056, 307700, 75608, 382644, 133616, 465044, 199080, + 435860, 0, -10748, -41608, -26268, -75248, -33820, -100920, -33404, -118624, -25020, -128360, -8668, + -130128, 15652, -123928, 47940, -90592, 88196, -87624, 136420, -57520, 192612, -19448, 256772, 26592, + 328900, 80600, 408996, 142576, 497060, 212520, 466116, 0, -11436, -44296, -27948, -80112, -35980, -107448, + -35532, -126304, -26604, -136680, -9196, -138576, 16692, -131992, 51060, -96480, 93908, -93384, 145236, + -61360, 205044, -20856, 273332, 28128, 350100, 85592, 435348, 151536, 529076, 225960, 496372, 0, -12124, + -46984, -29628, -84976, -38140, -113976, -37660, -133984, -28188, -145000, -9724, -147024, 17732, -140056, + 54180, -102368, 99620, -99144, 154052, -65200, 217476, -22264, 289892, 29664, 371300, 90584, 461700, + 160496, 561092, 239400, 526628, 0, -12812, -49672, -31308, -89840, -40300, -120504, -39788, -141664, + -29772, -153320, -10252, -155472, 18772, -148120, 57300, -108256, 105332, -104904, 162868, -69040, 229908, + -23672, 306452, 31200, 392500, 95576, 488052, 169456, 593108, 252840, 556884, 0, -13500, -52360, -32988, + -94704, -42460, -127032, -41916, -149344, -31356, -161640, -10780, -163920, 19812, -156184, 60420, + -114144, 111044, -110664, 171684, -72880, 242340, -25080, 323012, 32736, 413700, 100568, 514404, 178416, + 625124, 266280, 587140, 0, -14188, -55048, -34668, -99568, -44620, -133560, -44044, -157024, -32940, + -169960, -11308, -172368, 20852, -164248, 63540, -120032, 116756, -116424, 180500, -76720, 254772, -26488, + 339572, 34272, 434900, 105560, 540756, 187376, 657140, 279720, 617396, 0, -14876, -57736, -36348, -104432, + -46780, -140088, -46172, -164704, -34524, -178280, -11836, -180816, 21892, -172312, 66660, -125920, + 122468, -122184, 189316, -80560, 267204, -27896, 356132, 35808, 456100, 110552, 567108, 196336, 689156, + 293160, 647652, 0, -15564, -60424, -38028, -109296, -48940, -146616, -48300, -172384, -36108, -186600, + -12364, -189264, 22932, -180376, 69780, -131808, 128180, -127944, 198132, -84400, 279636, -29304, 372692, + 37344, 477300, 115544, 593460, 205296, 721172, 306600, 677908, 0, -16252, -63112, -39708, -114160, -51100, + -153144, -50428, -180064, -37692, -194920, -12892, -197712, 23972, -188440, 72900, -137696, 133892, + -133704, 206948, -88240, 292068, -30712, 389252, 38880, 498500, 120536, 619812, 214256, 753188, 320040, + 708164, 0, -16940, -65800, -41388, -119024, -53260, -159672, -52556, -187744, -39276, -203240, -13420, + -206160, 25012, -196504, 76020, -143584, 139604, -139464, 215764, -92080, 304500, -32120, 405812, 40416, + 519700, 125528, 646164, 223216, 785204, 333480, 738420, 0, -17628, -68488, -43068, -123888, -55420, + -166200, -54684, -195424, -40860, -211560, -13948, -214608, 26052, -204568, 79140, -149472, 145316, + -145224, 224580, -95920, 316932, -33528, 422372, 41952, 540900, 130520, 672516, 232176, 817220, 346920, + 768676, 0, -18316, -71176, -44748, -128752, -57580, -172728, -56812, -203104, -42444, -219880, -14476, + -223056, 27092, -212632, 82260, -155360, 151028, -150984, 233396, -99760, 329364, -34936, 438932, 43488, + 562100, 135512, 698868, 241136, 849236, 360360, 798932, 0, -19004, -73864, -46428, -133616, -59740, + -179256, -58940, -210784, -44028, -228200, -15004, -231504, 28132, -220696, 85380, -161248, 156740, + -156744, 242212, -103600, 341796, -36344, 455492, 45024, 583300, 140504, 725220, 250096, 881252, 373800, + 829188, 0, -19692, -76552, -48108, -138480, -61900, -185784, -61068, -218464, -45612, -236520, -15532, + -239952, 29172, -228760, 88500, -167136, 162452, -162504, 251028, -107440, 354228, -37752, 472052, 46560, + 604500, 145496, 751572, 259056, 913268, 387240, 859444, 0, -20380, -79240, -49788, -143344, -64060, + -192312, -63196, -226144, -47196, -244840, -16060, -248400, 30212, -236824, 91620, -173024, 168164, + -168264, 259844, -111280, 366660, -39160, 488612, 48096, 625700, 150488, 777924, 268016, 945284, 400680, + 889700, 0, -21068, -81928, -51468, -148208, -66220, -198840, -65324, -233824, -48780, -253160, -16588, + -256848, 31252, -244888, 94740, -178912, 173876, -174024, 268660, -115120, 379092, -40568, 505172, 49632, + 646900, 155480, 804276, 276976, 977300, 414120, 919956, 0, -21756, -84616, -53148, -153072, -68380, + -205368, -67452, -241504, -50364, -261480, -17116, -265296, 32292, -252952, 97860, -184800, 179588, + -179784, 277476, -118960, 391524, -41976, 521732, 51168, 668100, 160472, 830628, 285936, 1009316, 427560, + 950212 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 660, 888, 2196, 2064, 4020, 3528, 6132, 5280, 8532, 7320, 11220, 9648, 14196, 12264, 17460, 15136, + 21012, 18360, 24852, 21840, 28980, 25608, 33396, 29664, 38100, 34008, 43092, 38640, 48372, 43560, 46004, + 0, 2020, 2296, 6660, 5392, 12100, 9288, 18340, 13984, 25380, 19480, 33220, 25776, 41860, 32872, 51300, + 42016, 61540, 49464, 72580, 58960, 84420, 69256, 97060, 80352, 110500, 92248, 124740, 104944, 139780, + 118440, 139748, 0, 3380, 3704, 11124, 8720, 20180, 15048, 30548, 22688, 42228, 31640, 55220, 41904, 69524, + 53480, 85140, 68896, 102068, 80568, 120308, 96080, 139860, 112904, 160724, 131040, 182900, 150488, 206388, + 171248, 231188, 193320, 233492, 0, 4740, 5112, 15588, 12048, 28260, 20808, 42756, 31392, 59076, 43800, + 77220, 58032, 97188, 74088, 118980, 95776, 142596, 111672, 168036, 133200, 195300, 156552, 224388, 181728, + 255300, 208728, 288036, 237552, 322596, 268200, 327236, 0, 6100, 6520, 20052, 15376, 36340, 26568, 54964, + 40096, 75924, 55960, 99220, 74160, 124852, 94696, 152820, 122656, 183124, 142776, 215764, 170320, 250740, + 200200, 288052, 232416, 327700, 266968, 369684, 303856, 414004, 343080, 420980, 0, 7460, 7928, 24516, + 18704, 44420, 32328, 67172, 48800, 92772, 68120, 121220, 90288, 152516, 115304, 186660, 149536, 223652, + 173880, 263492, 207440, 306180, 243848, 351716, 283104, 400100, 325208, 451332, 370160, 505412, 417960, + 514724, 0, 8820, 9336, 28980, 22032, 52500, 38088, 79380, 57504, 109620, 80280, 143220, 106416, 180180, + 135912, 220500, 176416, 264180, 204984, 311220, 244560, 361620, 287496, 415380, 333792, 472500, 383448, + 532980, 436464, 596820, 492840, 608468, 0, 10180, 10744, 33444, 25360, 60580, 43848, 91588, 66208, 126468, + 92440, 165220, 122544, 207844, 156520, 254340, 203296, 304708, 236088, 358948, 281680, 417060, 331144, + 479044, 384480, 544900, 441688, 614628, 502768, 688228, 567720, 702212, 0, 11540, 12152, 37908, 28688, + 68660, 49608, 103796, 74912, 143316, 104600, 187220, 138672, 235508, 177128, 288180, 230176, 345236, + 267192, 406676, 318800, 472500, 374792, 542708, 435168, 617300, 499928, 696276, 569072, 779636, 642600, + 795956, 0, 12900, 13560, 42372, 32016, 76740, 55368, 116004, 83616, 160164, 116760, 209220, 154800, + 263172, 197736, 322020, 257056, 385764, 298296, 454404, 355920, 527940, 418440, 606372, 485856, 689700, + 558168, 777924, 635376, 871044, 717480, 889700, 0, 14260, 14968, 46836, 35344, 84820, 61128, 128212, + 92320, 177012, 128920, 231220, 170928, 290836, 218344, 355860, 283936, 426292, 329400, 502132, 393040, + 583380, 462088, 670036, 536544, 762100, 616408, 859572, 701680, 962452, 792360, 983444, 0, 15620, 16376, + 51300, 38672, 92900, 66888, 140420, 101024, 193860, 141080, 253220, 187056, 318500, 238952, 389700, + 310816, 466820, 360504, 549860, 430160, 638820, 505736, 733700, 587232, 834500, 674648, 941220, 767984, + 1053860, 867240, 1077188, 0, 16980, 17784, 55764, 42000, 100980, 72648, 152628, 109728, 210708, 153240, + 275220, 203184, 346164, 259560, 423540, 337696, 507348, 391608, 597588, 467280, 694260, 549384, 797364, + 637920, 906900, 732888, 1022868, 834288, 1145268, 942120, 1170932, 0, 18340, 19192, 60228, 45328, 109060, + 78408, 164836, 118432, 227556, 165400, 297220, 219312, 373828, 280168, 457380, 364576, 547876, 422712, + 645316, 504400, 749700, 593032, 861028, 688608, 979300, 791128, 1104516, 900592, 1236676, 1017000, + 1264676, 0, 19700, 20600, 64692, 48656, 117140, 84168, 177044, 127136, 244404, 177560, 319220, 235440, + 401492, 300776, 491220, 391456, 588404, 453816, 693044, 541520, 805140, 636680, 924692, 739296, 1051700, + 849368, 1186164, 966896, 1328084, 1091880, 1358420, 0, 21060, 22008, 69156, 51984, 125220, 89928, 189252, + 135840, 261252, 189720, 341220, 251568, 429156, 321384, 525060, 418336, 628932, 484920, 740772, 578640, + 860580, 680328, 988356, 789984, 1124100, 907608, 1267812, 1033200, 1419492, 1166760, 1452164, 0, 22420, + 23416, 73620, 55312, 133300, 95688, 201460, 144544, 278100, 201880, 363220, 267696, 456820, 341992, + 558900, 445216, 669460, 516024, 788500, 615760, 916020, 723976, 1052020, 840672, 1196500, 965848, 1349460, + 1099504, 1510900, 1241640, 1545908, 0, 23780, 24824, 78084, 58640, 141380, 101448, 213668, 153248, 294948, + 214040, 385220, 283824, 484484, 362600, 592740, 472096, 709988, 547128, 836228, 652880, 971460, 767624, + 1115684, 891360, 1268900, 1024088, 1431108, 1165808, 1602308, 1316520, 1639652, 0, 25140, 26232, 82548, + 61968, 149460, 107208, 225876, 161952, 311796, 226200, 407220, 299952, 512148, 383208, 626580, 498976, + 750516, 578232, 883956, 690000, 1026900, 811272, 1179348, 942048, 1341300, 1082328, 1512756, 1232112, + 1693716, 1391400, 1733396, 0, 26500, 27640, 87012, 65296, 157540, 112968, 238084, 170656, 328644, 238360, + 429220, 316080, 539812, 403816, 660420, 525856, 791044, 609336, 931684, 727120, 1082340, 854920, 1243012, + 992736, 1413700, 1140568, 1594404, 1298416, 1785124, 1466280, 1827140, 0, 27860, 29048, 91476, 68624, + 165620, 118728, 250292, 179360, 345492, 250520, 451220, 332208, 567476, 424424, 694260, 552736, 831572, + 640440, 979412, 764240, 1137780, 898568, 1306676, 1043424, 1486100, 1198808, 1676052, 1364720, 1876532, + 1541160, 1920884, 0, 29220, 30456, 95940, 71952, 173700, 124488, 262500, 188064, 362340, 262680, 473220, + 348336, 595140, 445032, 728100, 579616, 872100, 671544, 1027140, 801360, 1193220, 942216, 1370340, + 1094112, 1558500, 1257048, 1757700, 1431024, 1967940, 1616040, 2014628, 0, 30580, 31864, 100404, 75280, + 181780, 130248, 274708, 196768, 379188, 274840, 495220, 364464, 622804, 465640, 761940, 606496, 912628, + 702648, 1074868, 838480, 1248660, 985864, 1434004, 1144800, 1630900, 1315288, 1839348, 1497328, 2059348, + 1690920, 2108372, 0, 31940, 33272, 104868, 78608, 189860, 136008, 286916, 205472, 396036, 287000, 517220, + 380592, 650468, 486248, 795780, 633376, 953156, 733752, 1122596, 875600, 1304100, 1029512, 1497668, + 1195488, 1703300, 1373528, 1920996, 1563632, 2150756, 1765800, 2202116, 0, 33300, 34680, 109332, 81936, + 197940, 141768, 299124, 214176, 412884, 299160, 539220, 396720, 678132, 506856, 829620, 660256, 993684, + 764856, 1170324, 912720, 1359540, 1073160, 1561332, 1246176, 1775700, 1431768, 2002644, 1629936, 2242164, + 1840680, 2295860, 0, 34660, 36088, 113796, 85264, 206020, 147528, 311332, 222880, 429732, 311320, 561220, + 412848, 705796, 527464, 863460, 687136, 1034212, 795960, 1218052, 949840, 1414980, 1116808, 1624996, + 1296864, 1848100, 1490008, 2084292, 1696240, 2333572, 1915560, 2389604, 0, 36020, 37496, 118260, 88592, + 214100, 153288, 323540, 231584, 446580, 323480, 583220, 428976, 733460, 548072, 897300, 714016, 1074740, + 827064, 1265780, 986960, 1470420, 1160456, 1688660, 1347552, 1920500, 1548248, 2165940, 1762544, 2424980, + 1990440, 2483348, 0, 37380, 38904, 122724, 91920, 222180, 159048, 335748, 240288, 463428, 335640, 605220, + 445104, 761124, 568680, 931140, 740896, 1115268, 858168, 1313508, 1024080, 1525860, 1204104, 1752324, + 1398240, 1992900, 1606488, 2247588, 1828848, 2516388, 2065320, 2577092, 0, 38740, 40312, 127188, 95248, + 230260, 164808, 347956, 248992, 480276, 347800, 627220, 461232, 788788, 589288, 964980, 767776, 1155796, + 889272, 1361236, 1061200, 1581300, 1247752, 1815988, 1448928, 2065300, 1664728, 2329236, 1895152, 2607796, + 2140200, 2670836, 0, 40100, 41720, 131652, 98576, 238340, 170568, 360164, 257696, 497124, 359960, 649220, + 477360, 816452, 609896, 998820, 794656, 1196324, 920376, 1408964, 1098320, 1636740, 1291400, 1879652, + 1499616, 2137700, 1722968, 2410884, 1961456, 2699204, 2215080, 2764580, 0, 41460, 43128, 136116, 101904, + 246420, 176328, 372372, 266400, 513972, 372120, 671220, 493488, 844116, 630504, 1032660, 821536, 1236852, + 951480, 1456692, 1135440, 1692180, 1335048, 1943316, 1550304, 2210100, 1781208, 2492532, 2027760, 2790612, + 2289960, 2858324, 0, 42820, 44536, 140580, 105232, 254500, 182088, 384580, 275104, 530820, 384280, 693220, + 509616, 871780, 651112, 1066500, 848416, 1277380, 982584, 1504420, 1172560, 1747620, 1378696, 2006980, + 1600992, 2282500, 1839448, 2574180, 2094064, 2882020, 2364840, 2952068 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, 56908, 81124, + 108476, 138964, 140844, -3868, -21508, -33964, -41236, -43324, -40228, -31948, -18484, 5252, 23996, 53012, + 87212, 126596, 171164, 220916, 228236, -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, + 4900, 30108, 70196, 117516, 172068, 233852, 302868, 315628, -6620, -38980, -62060, -75860, -80380, -75620, + -61580, -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -7996, -47716, -76108, -93172, + -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, 466772, 490412, -9372, + -56452, -90156, -110484, -117436, -111012, -91212, -58036, 3844, 48444, 121748, 208428, 308484, 421916, + 548724, 577804, -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, + 238732, 353956, 484604, 630676, 665196, -12124, -73924, -118252, -145108, -154492, -146404, -120844, + -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -13500, -82660, -132300, -162420, + -173020, -164100, -135660, -87700, 2788, 66780, 173300, 299340, 444900, 609980, 794580, 839980, -14876, + -91396, -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, + 876532, 927372, -16252, -100132, -160396, -197044, -210076, -199492, -165292, -107476, 2084, 79004, + 207668, 359948, 535844, 735356, 958484, 1014764, -17628, -108868, -174444, -214356, -228604, -217188, + -180108, -117364, 1732, 85116, 224852, 390252, 581316, 798044, 1040436, 1102156, -19004, -117604, -188492, + -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, 860732, 1122388, + 1189548, -20380, -126340, -202540, -248980, -265660, -252580, -209740, -137140, 1028, 97340, 259220, + 450860, 672260, 923420, 1204340, 1276940, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, 170956, 205540, 243260, + 284116, 296364, -3868, -2948, 3156, 14444, 30916, 52572, 79412, 111436, 153732, 191036, 238612, 291372, + 349316, 412444, 480756, 506636, -5244, -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, + 337716, 411788, 493092, 581628, 677396, 716908, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, + 284100, 350716, 436820, 532204, 636868, 750812, 874036, 927180, -7996, -4580, 10164, 36236, 73636, 122364, + 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, -9372, -5124, 12500, + 43500, 87876, 145628, 216756, 301260, 414468, 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, + -10748, -5668, 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, + 1258364, 1463956, 1557996, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, 544836, 670076, + 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -13500, -6756, 19508, 65292, 130596, 215420, 319764, + 443628, 610020, 749916, 932340, 1134284, 1355748, 1596732, 1857236, 1978540, -14876, -7300, 21844, 72556, + 144836, 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, + -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, 1130548, 1375116, 1643300, + 1935100, 2250516, 2399084, -17628, -8388, 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, + 1229652, 1495532, 1787076, 2104284, 2447156, 2609356, -19004, -8932, 28852, 94348, 187556, 308476, 457108, + 633452, 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -20380, -9476, 31188, + 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, 2442652, 2840436, + 3029900, -21756, -10020, 33524, 108876, 216036, 355004, 525780, 728364, 1001124, 1228956, 1526964, + 1856780, 2218404, 2611836, 3037076, 3240172 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -59740, -53956, -47084, -39124, -30076, -19940, -8716, 3596, 16996, 31484, 47060, 63724, 81476, + 100316, 120244, 109004, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, + 56908, 81124, 108476, 138964, 140844, -199356, -184548, -166604, -145524, -121308, -93956, -63468, -29844, + 6916, 46812, 89844, 136012, 185316, 237756, 293332, 287532, -3868, -21508, -33964, -41236, -43324, -40228, + -31948, -18484, 5252, 23996, 53012, 87212, 126596, 171164, 220916, 228236, -338972, -315140, -286124, + -251924, -212540, -167972, -118220, -63284, -3164, 62140, 132628, 208300, 289156, 375196, 466420, 466060, + -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, 4900, 30108, 70196, 117516, 172068, 233852, + 302868, 315628, -478588, -445732, -405644, -358324, -303772, -241988, -172972, -96724, -13244, 77468, + 175412, 280588, 392996, 512636, 639508, 644588, -6620, -38980, -62060, -75860, -80380, -75620, -61580, + -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -618204, -576324, -525164, -464724, + -395004, -316004, -227724, -130164, -23324, 92796, 218196, 352876, 496836, 650076, 812596, 823116, -7996, + -47716, -76108, -93172, -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, + 466772, 490412, -757820, -706916, -644684, -571124, -486236, -390020, -282476, -163604, -33404, 108124, + 260980, 425164, 600676, 787516, 985684, 1001644, -9372, -56452, -90156, -110484, -117436, -111012, -91212, + -58036, 3844, 48444, 121748, 208428, 308484, 421916, 548724, 577804, -897436, -837508, -764204, -677524, + -577468, -464036, -337228, -197044, -43484, 123452, 303764, 497452, 704516, 924956, 1158772, 1180172, + -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, 238732, 353956, + 484604, 630676, 665196, -1037052, -968100, -883724, -783924, -668700, -538052, -391980, -230484, -53564, + 138780, 346548, 569740, 808356, 1062396, 1331860, 1358700, -12124, -73924, -118252, -145108, -154492, + -146404, -120844, -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -1176668, -1098692, + -1003244, -890324, -759932, -612068, -446732, -263924, -63644, 154108, 389332, 642028, 912196, 1199836, + 1504948, 1537228, -13500, -82660, -132300, -162420, -173020, -164100, -135660, -87700, 2788, 66780, + 173300, 299340, 444900, 609980, 794580, 839980, -1316284, -1229284, -1122764, -996724, -851164, -686084, + -501484, -297364, -73724, 169436, 432116, 714316, 1016036, 1337276, 1678036, 1715756, -14876, -91396, + -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, 876532, + 927372, -1455900, -1359876, -1242284, -1103124, -942396, -760100, -556236, -330804, -83804, 184764, + 474900, 786604, 1119876, 1474716, 1851124, 1894284, -16252, -100132, -160396, -197044, -210076, -199492, + -165292, -107476, 2084, 79004, 207668, 359948, 535844, 735356, 958484, 1014764, -1595516, -1490468, + -1361804, -1209524, -1033628, -834116, -610988, -364244, -93884, 200092, 517684, 858892, 1223716, 1612156, + 2024212, 2072812, -17628, -108868, -174444, -214356, -228604, -217188, -180108, -117364, 1732, 85116, + 224852, 390252, 581316, 798044, 1040436, 1102156, -1735132, -1621060, -1481324, -1315924, -1124860, + -908132, -665740, -397684, -103964, 215420, 560468, 931180, 1327556, 1749596, 2197300, 2251340, -19004, + -117604, -188492, -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, + 860732, 1122388, 1189548, -1874748, -1751652, -1600844, -1422324, -1216092, -982148, -720492, -431124, + -114044, 230748, 603252, 1003468, 1431396, 1887036, 2370388, 2429868, -20380, -126340, -202540, -248980, + -265660, -252580, -209740, -137140, 1028, 97340, 259220, 450860, 672260, 923420, 1204340, 1276940, + -2014364, -1882244, -1720364, -1528724, -1307324, -1056164, -775244, -464564, -124124, 246076, 646036, + 1075756, 1535236, 2024476, 2543476, 2608396, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332, -2153980, -2012836, -1839884, + -1635124, -1398556, -1130180, -829996, -498004, -134204, 261404, 688820, 1148044, 1639076, 2161916, + 2716564, 2786924, -23132, -143812, -230636, -283604, -302716, -287972, -239372, -156916, 324, 109564, + 293588, 511468, 763204, 1048796, 1368244, 1451724, -2293596, -2143428, -1959404, -1741524, -1489788, + -1204196, -884748, -531444, -144284, 276732, 731604, 1220332, 1742916, 2299356, 2889652, 2965452, -24508, + -152548, -244684, -300916, -321244, -305668, -254188, -166804, -28, 115676, 310772, 541772, 808676, + 1111484, 1450196, 1539116, -2433212, -2274020, -2078924, -1847924, -1581020, -1278212, -939500, -564884, + -154364, 292060, 774388, 1292620, 1846756, 2436796, 3062740, 3143980, -25884, -161284, -258732, -318228, + -339772, -323364, -269004, -176692, -380, 121788, 327956, 572076, 854148, 1174172, 1532148, 1626508, + -2572828, -2404612, -2198444, -1954324, -1672252, -1352228, -994252, -598324, -164444, 307388, 817172, + 1364908, 1950596, 2574236, 3235828, 3322508, -27260, -170020, -272780, -335540, -358300, -341060, -283820, + -186580, -732, 127900, 345140, 602380, 899620, 1236860, 1614100, 1713900, -2712444, -2535204, -2317964, + -2060724, -1763484, -1426244, -1049004, -631764, -174524, 322716, 859956, 1437196, 2054436, 2711676, + 3408916, 3501036, -28636, -178756, -286828, -352852, -376828, -358756, -298636, -196468, -1084, 134012, + 362324, 632684, 945092, 1299548, 1696052, 1801292, -2852060, -2665796, -2437484, -2167124, -1854716, + -1500260, -1103756, -665204, -184604, 338044, 902740, 1509484, 2158276, 2849116, 3582004, 3679564, -30012, + -187492, -300876, -370164, -395356, -376452, -313452, -206356, -1436, 140124, 379508, 662988, 990564, + 1362236, 1778004, 1888684, -2991676, -2796388, -2557004, -2273524, -1945948, -1574276, -1158508, -698644, + -194684, 353372, 945524, 1581772, 2262116, 2986556, 3755092, 3858092, -31388, -196228, -314924, -387476, + -413884, -394148, -328268, -216244, -1788, 146236, 396692, 693292, 1036036, 1424924, 1859956, 1976076, + -3131292, -2926980, -2676524, -2379924, -2037180, -1648292, -1213260, -732084, -204764, 368700, 988308, + 1654060, 2365956, 3123996, 3928180, 4036620, -32764, -204964, -328972, -404788, -432412, -411844, -343084, + -226132, -2140, 152348, 413876, 723596, 1081508, 1487612, 1941908, 2063468, -3270908, -3057572, -2796044, + -2486324, -2128412, -1722308, -1268012, -765524, -214844, 384028, 1031092, 1726348, 2469796, 3261436, + 4101268, 4215148, -34140, -213700, -343020, -422100, -450940, -429540, -357900, -236020, -2492, 158460, + 431060, 753900, 1126980, 1550300, 2023860, 2150860, -3410524, -3188164, -2915564, -2592724, -2219644, + -1796324, -1322764, -798964, -224924, 399356, 1073876, 1798636, 2573636, 3398876, 4274356, 4393676, + -35516, -222436, -357068, -439412, -469468, -447236, -372716, -245908, -2844, 164572, 448244, 784204, + 1172452, 1612988, 2105812, 2238252, -3550140, -3318756, -3035084, -2699124, -2310876, -1870340, -1377516, + -832404, -235004, 414684, 1116660, 1870924, 2677476, 3536316, 4447444, 4572204, -36892, -231172, -371116, + -456724, -487996, -464932, -387532, -255796, -3196, 170684, 465428, 814508, 1217924, 1675676, 2187764, + 2325644, -3689756, -3449348, -3154604, -2805524, -2402108, -1944356, -1432268, -865844, -245084, 430012, + 1159444, 1943212, 2781316, 3673756, 4620532, 4750732, -38268, -239908, -385164, -474036, -506524, -482628, + -402348, -265684, -3548, 176796, 482612, 844812, 1263396, 1738364, 2269716, 2413036, -3829372, -3579940, + -3274124, -2911924, -2493340, -2018372, -1487020, -899284, -255164, 445340, 1202228, 2015500, 2885156, + 3811196, 4793620, 4929260, -39644, -248644, -399212, -491348, -525052, -500324, -417164, -275572, -3900, + 182908, 499796, 875116, 1308868, 1801052, 2351668, 2500428, -3968988, -3710532, -3393644, -3018324, + -2584572, -2092388, -1541772, -932724, -265244, 460668, 1245012, 2087788, 2988996, 3948636, 4966708, + 5107788, -41020, -257380, -413260, -508660, -543580, -518020, -431980, -285460, -4252, 189020, 516980, + 905420, 1354340, 1863740, 2433620, 2587820, -4108604, -3841124, -3513164, -3124724, -2675804, -2166404, + -1596524, -966164, -275324, 475996, 1287796, 2160076, 3092836, 4086076, 5139796, 5286316, -42396, -266116, + -427308, -525972, -562108, -535716, -446796, -295348, -4604, 195132, 534164, 935724, 1399812, 1926428, + 2515572, 2675212, -4248220, -3971716, -3632684, -3231124, -2767036, -2240420, -1651276, -999604, -285404, + 491324, 1330580, 2232364, 3196676, 4223516, 5312884, 5464844, -43772, -274852, -441356, -543284, -580636, + -553412, -461612, -305236, -4956, 201244, 551348, 966028, 1445284, 1989116, 2597524, 2762604, -4387836, + -4102308, -3752204, -3337524, -2858268, -2314436, -1706028, -1033044, -295484, 506652, 1373364, 2304652, + 3300516, 4360956, 5485972, 5643372 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -24924, -16964, -7916, 2220, 13444, 25756, 39156, 53644, 69220, 85884, 103636, 122476, 142404, + 163420, 185524, 176460, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, + 170956, 205540, 243260, 284116, 296364, -33468, -8292, 20020, 51468, 86052, 123772, 164628, 208620, + 255748, 306012, 359412, 415948, 475620, 538428, 604372, 608940, -3868, -2948, 3156, 14444, 30916, 52572, + 79412, 111436, 153732, 191036, 238612, 291372, 349316, 412444, 480756, 506636, -42012, 380, 47956, 100716, + 158660, 221788, 290100, 363596, 442276, 526140, 615188, 709420, 808836, 913436, 1023220, 1041420, -5244, + -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, 337716, 411788, 493092, 581628, 677396, + 716908, -50556, 9052, 75892, 149964, 231268, 319804, 415572, 518572, 628804, 746268, 870964, 1002892, + 1142052, 1288444, 1442068, 1473900, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, 284100, + 350716, 436820, 532204, 636868, 750812, 874036, 927180, -59100, 17724, 103828, 199212, 303876, 417820, + 541044, 673548, 815332, 966396, 1126740, 1296364, 1475268, 1663452, 1860916, 1906380, -7996, -4580, 10164, + 36236, 73636, 122364, 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, + -67644, 26396, 131764, 248460, 376484, 515836, 666516, 828524, 1001860, 1186524, 1382516, 1589836, + 1808484, 2038460, 2279764, 2338860, -9372, -5124, 12500, 43500, 87876, 145628, 216756, 301260, 414468, + 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, -76188, 35068, 159700, 297708, 449092, 613852, + 791988, 983500, 1188388, 1406652, 1638292, 1883308, 2141700, 2413468, 2698612, 2771340, -10748, -5668, + 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, 1258364, 1463956, + 1557996, -84732, 43740, 187636, 346956, 521700, 711868, 917460, 1138476, 1374916, 1626780, 1894068, + 2176780, 2474916, 2788476, 3117460, 3203820, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, + 544836, 670076, 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -93276, 52412, 215572, 396204, + 594308, 809884, 1042932, 1293452, 1561444, 1846908, 2149844, 2470252, 2808132, 3163484, 3536308, 3636300, + -13500, -6756, 19508, 65292, 130596, 215420, 319764, 443628, 610020, 749916, 932340, 1134284, 1355748, + 1596732, 1857236, 1978540, -101820, 61084, 243508, 445452, 666916, 907900, 1168404, 1448428, 1747972, + 2067036, 2405620, 2763724, 3141348, 3538492, 3955156, 4068780, -14876, -7300, 21844, 72556, 144836, + 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, -110364, + 69756, 271444, 494700, 739524, 1005916, 1293876, 1603404, 1934500, 2287164, 2661396, 3057196, 3474564, + 3913500, 4374004, 4501260, -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, + 1130548, 1375116, 1643300, 1935100, 2250516, 2399084, -118908, 78428, 299380, 543948, 812132, 1103932, + 1419348, 1758380, 2121028, 2507292, 2917172, 3350668, 3807780, 4288508, 4792852, 4933740, -17628, -8388, + 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, 1229652, 1495532, 1787076, 2104284, 2447156, + 2609356, -127452, 87100, 327316, 593196, 884740, 1201948, 1544820, 1913356, 2307556, 2727420, 3172948, + 3644140, 4140996, 4663516, 5211700, 5366220, -19004, -8932, 28852, 94348, 187556, 308476, 457108, 633452, + 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -135996, 95772, 355252, 642444, + 957348, 1299964, 1670292, 2068332, 2494084, 2947548, 3428724, 3937612, 4474212, 5038524, 5630548, 5798700, + -20380, -9476, 31188, 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, + 2442652, 2840436, 3029900, -144540, 104444, 383188, 691692, 1029956, 1397980, 1795764, 2223308, 2680612, + 3167676, 3684500, 4231084, 4807428, 5413532, 6049396, 6231180, -21756, -10020, 33524, 108876, 216036, + 355004, 525780, 728364, 1001124, 1228956, 1526964, 1856780, 2218404, 2611836, 3037076, 3240172, -153084, + 113116, 411124, 740940, 1102564, 1495996, 1921236, 2378284, 2867140, 3387804, 3940276, 4524556, 5140644, + 5788540, 6468244, 6663660, -23132, -10564, 35860, 116140, 230276, 378268, 560116, 775820, 1066308, + 1308796, 1626068, 1977196, 2362180, 2781020, 3233716, 3450444, -161628, 121788, 439060, 790188, 1175172, + 1594012, 2046708, 2533260, 3053668, 3607932, 4196052, 4818028, 5473860, 6163548, 6887092, 7096140, -24508, + -11108, 38196, 123404, 244516, 401532, 594452, 823276, 1131492, 1388636, 1725172, 2097612, 2505956, + 2950204, 3430356, 3660716, -170172, 130460, 466996, 839436, 1247780, 1692028, 2172180, 2688236, 3240196, + 3828060, 4451828, 5111500, 5807076, 6538556, 7305940, 7528620, -25884, -11652, 40532, 130668, 258756, + 424796, 628788, 870732, 1196676, 1468476, 1824276, 2218028, 2649732, 3119388, 3626996, 3870988, -178716, + 139132, 494932, 888684, 1320388, 1790044, 2297652, 2843212, 3426724, 4048188, 4707604, 5404972, 6140292, + 6913564, 7724788, 7961100, -27260, -12196, 42868, 137932, 272996, 448060, 663124, 918188, 1261860, + 1548316, 1923380, 2338444, 2793508, 3288572, 3823636, 4081260, -187260, 147804, 522868, 937932, 1392996, + 1888060, 2423124, 2998188, 3613252, 4268316, 4963380, 5698444, 6473508, 7288572, 8143636, 8393580, -28636, + -12740, 45204, 145196, 287236, 471324, 697460, 965644, 1327044, 1628156, 2022484, 2458860, 2937284, + 3457756, 4020276, 4291532, -195804, 156476, 550804, 987180, 1465604, 1986076, 2548596, 3153164, 3799780, + 4488444, 5219156, 5991916, 6806724, 7663580, 8562484, 8826060, -30012, -13284, 47540, 152460, 301476, + 494588, 731796, 1013100, 1392228, 1707996, 2121588, 2579276, 3081060, 3626940, 4216916, 4501804, -204348, + 165148, 578740, 1036428, 1538212, 2084092, 2674068, 3308140, 3986308, 4708572, 5474932, 6285388, 7139940, + 8038588, 8981332, 9258540, -31388, -13828, 49876, 159724, 315716, 517852, 766132, 1060556, 1457412, + 1787836, 2220692, 2699692, 3224836, 3796124, 4413556, 4712076, -212892, 173820, 606676, 1085676, 1610820, + 2182108, 2799540, 3463116, 4172836, 4928700, 5730708, 6578860, 7473156, 8413596, 9400180, 9691020, -32764, + -14372, 52212, 166988, 329956, 541116, 800468, 1108012, 1522596, 1867676, 2319796, 2820108, 3368612, + 3965308, 4610196, 4922348, -221436, 182492, 634612, 1134924, 1683428, 2280124, 2925012, 3618092, 4359364, + 5148828, 5986484, 6872332, 7806372, 8788604, 9819028, 10123500, -34140, -14916, 54548, 174252, 344196, + 564380, 834804, 1155468, 1587780, 1947516, 2418900, 2940524, 3512388, 4134492, 4806836, 5132620, -229980, + 191164, 662548, 1184172, 1756036, 2378140, 3050484, 3773068, 4545892, 5368956, 6242260, 7165804, 8139588, + 9163612, 10237876, 10555980, -35516, -15460, 56884, 181516, 358436, 587644, 869140, 1202924, 1652964, + 2027356, 2518004, 3060940, 3656164, 4303676, 5003476, 5342892, -238524, 199836, 690484, 1233420, 1828644, + 2476156, 3175956, 3928044, 4732420, 5589084, 6498036, 7459276, 8472804, 9538620, 10656724, 10988460, + -36892, -16004, 59220, 188780, 372676, 610908, 903476, 1250380, 1718148, 2107196, 2617108, 3181356, + 3799940, 4472860, 5200116, 5553164, -247068, 208508, 718420, 1282668, 1901252, 2574172, 3301428, 4083020, + 4918948, 5809212, 6753812, 7752748, 8806020, 9913628, 11075572, 11420940, -38268, -16548, 61556, 196044, + 386916, 634172, 937812, 1297836, 1783332, 2187036, 2716212, 3301772, 3943716, 4642044, 5396756, 5763436, + -255612, 217180, 746356, 1331916, 1973860, 2672188, 3426900, 4237996, 5105476, 6029340, 7009588, 8046220, + 9139236, 10288636, 11494420, 11853420, -39644, -17092, 63892, 203308, 401156, 657436, 972148, 1345292, + 1848516, 2266876, 2815316, 3422188, 4087492, 4811228, 5593396, 5973708, -264156, 225852, 774292, 1381164, + 2046468, 2770204, 3552372, 4392972, 5292004, 6249468, 7265364, 8339692, 9472452, 10663644, 11913268, + 12285900, -41020, -17636, 66228, 210572, 415396, 680700, 1006484, 1392748, 1913700, 2346716, 2914420, + 3542604, 4231268, 4980412, 5790036, 6183980, -272700, 234524, 802228, 1430412, 2119076, 2868220, 3677844, + 4547948, 5478532, 6469596, 7521140, 8633164, 9805668, 11038652, 12332116, 12718380, -42396, -18180, 68564, + 217836, 429636, 703964, 1040820, 1440204, 1978884, 2426556, 3013524, 3663020, 4375044, 5149596, 5986676, + 6394252, -281244, 243196, 830164, 1479660, 2191684, 2966236, 3803316, 4702924, 5665060, 6689724, 7776916, + 8926636, 10138884, 11413660, 12750964, 13150860, -43772, -18724, 70900, 225100, 443876, 727228, 1075156, + 1487660, 2044068, 2506396, 3112628, 3783436, 4518820, 5318780, 6183316, 6604524, -289788, 251868, 858100, + 1528908, 2264292, 3064252, 3928788, 4857900, 5851588, 6909852, 8032692, 9220108, 10472100, 11788668, + 13169812, 13583340 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -1560, -2576, -3048, -2976, -2360, -1200, 504, 2736, 5544, 8880, 12760, 17184, 22152, 27664, 26040, + -29312, -26520, -23184, -19304, -14880, -9912, -4400, 1656, 8256, 15400, 23088, 31320, 40096, 49416, + 59280, 53816, 0, -5368, -9168, -11400, -12064, -11160, -8688, -4648, 2224, 8136, 16880, 27192, 39072, + 52520, 67536, 68760, -98432, -91256, -82512, -72200, -60320, -46872, -31856, -15272, 2880, 22600, 43888, + 66744, 91168, 117160, 144720, 142104, 0, -9176, -15760, -19752, -21152, -19960, -16176, -9800, 1712, + 10728, 24880, 41624, 60960, 82888, 107408, 111480, -167552, -155992, -141840, -125096, -105760, -83832, + -59312, -32200, -2496, 29800, 64688, 102168, 142240, 184904, 230160, 230392, 0, -12984, -22352, -28104, + -30240, -28760, -23664, -14952, 1200, 13320, 32880, 56056, 82848, 113256, 147280, 154200, -236672, + -220728, -201168, -177992, -151200, -120792, -86768, -49128, -7872, 37000, 85488, 137592, 193312, 252648, + 315600, 318680, 0, -16792, -28944, -36456, -39328, -37560, -31152, -20104, 688, 15912, 40880, 70488, + 104736, 143624, 187152, 196920, -305792, -285464, -260496, -230888, -196640, -157752, -114224, -66056, + -13248, 44200, 106288, 173016, 244384, 320392, 401040, 406968, 0, -20600, -35536, -44808, -48416, -46360, + -38640, -25256, 176, 18504, 48880, 84920, 126624, 173992, 227024, 239640, -374912, -350200, -319824, + -283784, -242080, -194712, -141680, -82984, -18624, 51400, 127088, 208440, 295456, 388136, 486480, 495256, + 0, -24408, -42128, -53160, -57504, -55160, -46128, -30408, -336, 21096, 56880, 99352, 148512, 204360, + 266896, 282360, -444032, -414936, -379152, -336680, -287520, -231672, -169136, -99912, -24000, 58600, + 147888, 243864, 346528, 455880, 571920, 583544, 0, -28216, -48720, -61512, -66592, -63960, -53616, -35560, + -848, 23688, 64880, 113784, 170400, 234728, 306768, 325080, -513152, -479672, -438480, -389576, -332960, + -268632, -196592, -116840, -29376, 65800, 168688, 279288, 397600, 523624, 657360, 671832, 0, -32024, + -55312, -69864, -75680, -72760, -61104, -40712, -1360, 26280, 72880, 128216, 192288, 265096, 346640, + 367800, -582272, -544408, -497808, -442472, -378400, -305592, -224048, -133768, -34752, 73000, 189488, + 314712, 448672, 591368, 742800, 760120, 0, -35832, -61904, -78216, -84768, -81560, -68592, -45864, -1872, + 28872, 80880, 142648, 214176, 295464, 386512, 410520, -651392, -609144, -557136, -495368, -423840, + -342552, -251504, -150696, -40128, 80200, 210288, 350136, 499744, 659112, 828240, 848408, 0, -39640, + -68496, -86568, -93856, -90360, -76080, -51016, -2384, 31464, 88880, 157080, 236064, 325832, 426384, + 453240, -720512, -673880, -616464, -548264, -469280, -379512, -278960, -167624, -45504, 87400, 231088, + 385560, 550816, 726856, 913680, 936696, 0, -43448, -75088, -94920, -102944, -99160, -83568, -56168, -2896, + 34056, 96880, 171512, 257952, 356200, 466256, 495960, -789632, -738616, -675792, -601160, -514720, + -416472, -306416, -184552, -50880, 94600, 251888, 420984, 601888, 794600, 999120, 1024984, 0, -47256, + -81680, -103272, -112032, -107960, -91056, -61320, -3408, 36648, 104880, 185944, 279840, 386568, 506128, + 538680, -858752, -803352, -735120, -654056, -560160, -453432, -333872, -201480, -56256, 101800, 272688, + 456408, 652960, 862344, 1084560, 1113272, 0, -51064, -88272, -111624, -121120, -116760, -98544, -66472, + -3920, 39240, 112880, 200376, 301728, 416936, 546000, 581400, -927872, -868088, -794448, -706952, -605600, + -490392, -361328, -218408, -61632, 109000, 293488, 491832, 704032, 930088, 1170000, 1201560, 0, -54872, + -94864, -119976, -130208, -125560, -106032, -71624, -4432, 41832, 120880, 214808, 323616, 447304, 585872, + 624120, -996992, -932824, -853776, -759848, -651040, -527352, -388784, -235336, -67008, 116200, 314288, + 527256, 755104, 997832, 1255440, 1289848, 0, -58680, -101456, -128328, -139296, -134360, -113520, -76776, + -4944, 44424, 128880, 229240, 345504, 477672, 625744, 666840, -1066112, -997560, -913104, -812744, + -696480, -564312, -416240, -252264, -72384, 123400, 335088, 562680, 806176, 1065576, 1340880, 1378136, 0, + -62488, -108048, -136680, -148384, -143160, -121008, -81928, -5456, 47016, 136880, 243672, 367392, 508040, + 665616, 709560, -1135232, -1062296, -972432, -865640, -741920, -601272, -443696, -269192, -77760, 130600, + 355888, 598104, 857248, 1133320, 1426320, 1466424, 0, -66296, -114640, -145032, -157472, -151960, -128496, + -87080, -5968, 49608, 144880, 258104, 389280, 538408, 705488, 752280, -1204352, -1127032, -1031760, + -918536, -787360, -638232, -471152, -286120, -83136, 137800, 376688, 633528, 908320, 1201064, 1511760, + 1554712, 0, -70104, -121232, -153384, -166560, -160760, -135984, -92232, -6480, 52200, 152880, 272536, + 411168, 568776, 745360, 795000, -1273472, -1191768, -1091088, -971432, -832800, -675192, -498608, -303048, + -88512, 145000, 397488, 668952, 959392, 1268808, 1597200, 1643000, 0, -73912, -127824, -161736, -175648, + -169560, -143472, -97384, -6992, 54792, 160880, 286968, 433056, 599144, 785232, 837720, -1342592, + -1256504, -1150416, -1024328, -878240, -712152, -526064, -319976, -93888, 152200, 418288, 704376, 1010464, + 1336552, 1682640, 1731288, 0, -77720, -134416, -170088, -184736, -178360, -150960, -102536, -7504, 57384, + 168880, 301400, 454944, 629512, 825104, 880440, -1411712, -1321240, -1209744, -1077224, -923680, -749112, + -553520, -336904, -99264, 159400, 439088, 739800, 1061536, 1404296, 1768080, 1819576, 0, -81528, -141008, + -178440, -193824, -187160, -158448, -107688, -8016, 59976, 176880, 315832, 476832, 659880, 864976, 923160, + -1480832, -1385976, -1269072, -1130120, -969120, -786072, -580976, -353832, -104640, 166600, 459888, + 775224, 1112608, 1472040, 1853520, 1907864, 0, -85336, -147600, -186792, -202912, -195960, -165936, + -112840, -8528, 62568, 184880, 330264, 498720, 690248, 904848, 965880, -1549952, -1450712, -1328400, + -1183016, -1014560, -823032, -608432, -370760, -110016, 173800, 480688, 810648, 1163680, 1539784, 1938960, + 1996152, 0, -89144, -154192, -195144, -212000, -204760, -173424, -117992, -9040, 65160, 192880, 344696, + 520608, 720616, 944720, 1008600, -1619072, -1515448, -1387728, -1235912, -1060000, -859992, -635888, + -387688, -115392, 181000, 501488, 846072, 1214752, 1607528, 2024400, 2084440, 0, -92952, -160784, -203496, + -221088, -213560, -180912, -123144, -9552, 67752, 200880, 359128, 542496, 750984, 984592, 1051320, + -1688192, -1580184, -1447056, -1288808, -1105440, -896952, -663344, -404616, -120768, 188200, 522288, + 881496, 1265824, 1675272, 2109840, 2172728, 0, -96760, -167376, -211848, -230176, -222360, -188400, + -128296, -10064, 70344, 208880, 373560, 564384, 781352, 1024464, 1094040, -1757312, -1644920, -1506384, + -1341704, -1150880, -933912, -690800, -421544, -126144, 195400, 543088, 916920, 1316896, 1743016, 2195280, + 2261016, 0, -100568, -173968, -220200, -239264, -231160, -195888, -133448, -10576, 72936, 216880, 387992, + 586272, 811720, 1064336, 1136760, -1826432, -1709656, -1565712, -1394600, -1196320, -970872, -718256, + -438472, -131520, 202600, 563888, 952344, 1367968, 1810760, 2280720, 2349304, 0, -104376, -180560, + -228552, -248352, -239960, -203376, -138600, -11088, 75528, 224880, 402424, 608160, 842088, 1104208, + 1179480, -1895552, -1774392, -1625040, -1447496, -1241760, -1007832, -745712, -455400, -136896, 209800, + 584688, 987768, 1419040, 1878504, 2366160, 2437592, 0, -108184, -187152, -236904, -257440, -248760, + -210864, -143752, -11600, 78120, 232880, 416856, 630048, 872456, 1144080, 1222200, -1964672, -1839128, + -1684368, -1500392, -1287200, -1044792, -773168, -472328, -142272, 217000, 605488, 1023192, 1470112, + 1946248, 2451600, 2525880, 0, -111992, -193744, -245256, -266528, -257560, -218352, -148904, -12112, + 80712, 240880, 431288, 651936, 902824, 1183952, 1264920, -2033792, -1903864, -1743696, -1553288, -1332640, + -1081752, -800624, -489256, -147648, 224200, 626288, 1058616, 1521184, 2013992, 2537040, 2614168, 0, + -115800, -200336, -253608, -275616, -266360, -225840, -154056, -12624, 83304, 248880, 445720, 673824, + 933192, 1223824, 1307640, -2102912, -1968600, -1803024, -1606184, -1378080, -1118712, -828080, -506184, + -153024, 231400, 647088, 1094040, 1572256, 2081736, 2622480, 2702456, 0, -119608, -206928, -261960, + -284704, -275160, -233328, -159208, -13136, 85896, 256880, 460152, 695712, 963560, 1263696, 1350360, + -2172032, -2033336, -1862352, -1659080, -1423520, -1155672, -855536, -523112, -158400, 238600, 667888, + 1129464, 1623328, 2149480, 2707920, 2790744 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 2664, 5872, 9624, 13920, 18760, 24144, 30072, 36528, 43560, 51120, 59224, 67872, 77064, 86800, 89400, + 38272, 45288, 52848, 60952, 69600, 78792, 88528, 98808, 109632, 121000, 132912, 145368, 158368, 171912, + 186000, 184760, 0, 7048, 15664, 25848, 37600, 50920, 65808, 82264, 101552, 119880, 141040, 163768, 188064, + 213928, 241360, 255000, 100224, 119816, 140976, 163704, 188000, 213864, 241296, 270296, 300864, 333000, + 366704, 401976, 438816, 477224, 517200, 527000, 0, 11432, 25456, 42072, 61280, 83080, 107472, 134456, + 166576, 196200, 230960, 268312, 308256, 350792, 395920, 420600, 162176, 194344, 229104, 266456, 306400, + 348936, 394064, 441784, 492096, 545000, 600496, 658584, 719264, 782536, 848400, 869240, 0, 15816, 35248, + 58296, 84960, 115240, 149136, 186648, 231600, 272520, 320880, 372856, 428448, 487656, 550480, 586200, + 224128, 268872, 317232, 369208, 424800, 484008, 546832, 613272, 683328, 757000, 834288, 915192, 999712, + 1087848, 1179600, 1211480, 0, 20200, 45040, 74520, 108640, 147400, 190800, 238840, 296624, 348840, 410800, + 477400, 548640, 624520, 705040, 751800, 286080, 343400, 405360, 471960, 543200, 619080, 699600, 784760, + 874560, 969000, 1068080, 1171800, 1280160, 1393160, 1510800, 1553720, 0, 24584, 54832, 90744, 132320, + 179560, 232464, 291032, 361648, 425160, 500720, 581944, 668832, 761384, 859600, 917400, 348032, 417928, + 493488, 574712, 661600, 754152, 852368, 956248, 1065792, 1181000, 1301872, 1428408, 1560608, 1698472, + 1842000, 1895960, 0, 28968, 64624, 106968, 156000, 211720, 274128, 343224, 426672, 501480, 590640, 686488, + 789024, 898248, 1014160, 1083000, 409984, 492456, 581616, 677464, 780000, 889224, 1005136, 1127736, + 1257024, 1393000, 1535664, 1685016, 1841056, 2003784, 2173200, 2238200, 0, 33352, 74416, 123192, 179680, + 243880, 315792, 395416, 491696, 577800, 680560, 791032, 909216, 1035112, 1168720, 1248600, 471936, 566984, + 669744, 780216, 898400, 1024296, 1157904, 1299224, 1448256, 1605000, 1769456, 1941624, 2121504, 2309096, + 2504400, 2580440, 0, 37736, 84208, 139416, 203360, 276040, 357456, 447608, 556720, 654120, 770480, 895576, + 1029408, 1171976, 1323280, 1414200, 533888, 641512, 757872, 882968, 1016800, 1159368, 1310672, 1470712, + 1639488, 1817000, 2003248, 2198232, 2401952, 2614408, 2835600, 2922680, 0, 42120, 94000, 155640, 227040, + 308200, 399120, 499800, 621744, 730440, 860400, 1000120, 1149600, 1308840, 1477840, 1579800, 595840, + 716040, 846000, 985720, 1135200, 1294440, 1463440, 1642200, 1830720, 2029000, 2237040, 2454840, 2682400, + 2919720, 3166800, 3264920, 0, 46504, 103792, 171864, 250720, 340360, 440784, 551992, 686768, 806760, + 950320, 1104664, 1269792, 1445704, 1632400, 1745400, 657792, 790568, 934128, 1088472, 1253600, 1429512, + 1616208, 1813688, 2021952, 2241000, 2470832, 2711448, 2962848, 3225032, 3498000, 3607160, 0, 50888, + 113584, 188088, 274400, 372520, 482448, 604184, 751792, 883080, 1040240, 1209208, 1389984, 1582568, + 1786960, 1911000, 719744, 865096, 1022256, 1191224, 1372000, 1564584, 1768976, 1985176, 2213184, 2453000, + 2704624, 2968056, 3243296, 3530344, 3829200, 3949400, 0, 55272, 123376, 204312, 298080, 404680, 524112, + 656376, 816816, 959400, 1130160, 1313752, 1510176, 1719432, 1941520, 2076600, 781696, 939624, 1110384, + 1293976, 1490400, 1699656, 1921744, 2156664, 2404416, 2665000, 2938416, 3224664, 3523744, 3835656, + 4160400, 4291640, 0, 59656, 133168, 220536, 321760, 436840, 565776, 708568, 881840, 1035720, 1220080, + 1418296, 1630368, 1856296, 2096080, 2242200, 843648, 1014152, 1198512, 1396728, 1608800, 1834728, 2074512, + 2328152, 2595648, 2877000, 3172208, 3481272, 3804192, 4140968, 4491600, 4633880, 0, 64040, 142960, 236760, + 345440, 469000, 607440, 760760, 946864, 1112040, 1310000, 1522840, 1750560, 1993160, 2250640, 2407800, + 905600, 1088680, 1286640, 1499480, 1727200, 1969800, 2227280, 2499640, 2786880, 3089000, 3406000, 3737880, + 4084640, 4446280, 4822800, 4976120, 0, 68424, 152752, 252984, 369120, 501160, 649104, 812952, 1011888, + 1188360, 1399920, 1627384, 1870752, 2130024, 2405200, 2573400, 967552, 1163208, 1374768, 1602232, 1845600, + 2104872, 2380048, 2671128, 2978112, 3301000, 3639792, 3994488, 4365088, 4751592, 5154000, 5318360, 0, + 72808, 162544, 269208, 392800, 533320, 690768, 865144, 1076912, 1264680, 1489840, 1731928, 1990944, + 2266888, 2559760, 2739000, 1029504, 1237736, 1462896, 1704984, 1964000, 2239944, 2532816, 2842616, + 3169344, 3513000, 3873584, 4251096, 4645536, 5056904, 5485200, 5660600, 0, 77192, 172336, 285432, 416480, + 565480, 732432, 917336, 1141936, 1341000, 1579760, 1836472, 2111136, 2403752, 2714320, 2904600, 1091456, + 1312264, 1551024, 1807736, 2082400, 2375016, 2685584, 3014104, 3360576, 3725000, 4107376, 4507704, + 4925984, 5362216, 5816400, 6002840, 0, 81576, 182128, 301656, 440160, 597640, 774096, 969528, 1206960, + 1417320, 1669680, 1941016, 2231328, 2540616, 2868880, 3070200, 1153408, 1386792, 1639152, 1910488, + 2200800, 2510088, 2838352, 3185592, 3551808, 3937000, 4341168, 4764312, 5206432, 5667528, 6147600, + 6345080, 0, 85960, 191920, 317880, 463840, 629800, 815760, 1021720, 1271984, 1493640, 1759600, 2045560, + 2351520, 2677480, 3023440, 3235800, 1215360, 1461320, 1727280, 2013240, 2319200, 2645160, 2991120, + 3357080, 3743040, 4149000, 4574960, 5020920, 5486880, 5972840, 6478800, 6687320, 0, 90344, 201712, 334104, + 487520, 661960, 857424, 1073912, 1337008, 1569960, 1849520, 2150104, 2471712, 2814344, 3178000, 3401400, + 1277312, 1535848, 1815408, 2115992, 2437600, 2780232, 3143888, 3528568, 3934272, 4361000, 4808752, + 5277528, 5767328, 6278152, 6810000, 7029560, 0, 94728, 211504, 350328, 511200, 694120, 899088, 1126104, + 1402032, 1646280, 1939440, 2254648, 2591904, 2951208, 3332560, 3567000, 1339264, 1610376, 1903536, + 2218744, 2556000, 2915304, 3296656, 3700056, 4125504, 4573000, 5042544, 5534136, 6047776, 6583464, + 7141200, 7371800, 0, 99112, 221296, 366552, 534880, 726280, 940752, 1178296, 1467056, 1722600, 2029360, + 2359192, 2712096, 3088072, 3487120, 3732600, 1401216, 1684904, 1991664, 2321496, 2674400, 3050376, + 3449424, 3871544, 4316736, 4785000, 5276336, 5790744, 6328224, 6888776, 7472400, 7714040, 0, 103496, + 231088, 382776, 558560, 758440, 982416, 1230488, 1532080, 1798920, 2119280, 2463736, 2832288, 3224936, + 3641680, 3898200, 1463168, 1759432, 2079792, 2424248, 2792800, 3185448, 3602192, 4043032, 4507968, + 4997000, 5510128, 6047352, 6608672, 7194088, 7803600, 8056280, 0, 107880, 240880, 399000, 582240, 790600, + 1024080, 1282680, 1597104, 1875240, 2209200, 2568280, 2952480, 3361800, 3796240, 4063800, 1525120, + 1833960, 2167920, 2527000, 2911200, 3320520, 3754960, 4214520, 4699200, 5209000, 5743920, 6303960, + 6889120, 7499400, 8134800, 8398520, 0, 112264, 250672, 415224, 605920, 822760, 1065744, 1334872, 1662128, + 1951560, 2299120, 2672824, 3072672, 3498664, 3950800, 4229400, 1587072, 1908488, 2256048, 2629752, + 3029600, 3455592, 3907728, 4386008, 4890432, 5421000, 5977712, 6560568, 7169568, 7804712, 8466000, + 8740760, 0, 116648, 260464, 431448, 629600, 854920, 1107408, 1387064, 1727152, 2027880, 2389040, 2777368, + 3192864, 3635528, 4105360, 4395000, 1649024, 1983016, 2344176, 2732504, 3148000, 3590664, 4060496, + 4557496, 5081664, 5633000, 6211504, 6817176, 7450016, 8110024, 8797200, 9083000, 0, 121032, 270256, + 447672, 653280, 887080, 1149072, 1439256, 1792176, 2104200, 2478960, 2881912, 3313056, 3772392, 4259920, + 4560600, 1710976, 2057544, 2432304, 2835256, 3266400, 3725736, 4213264, 4728984, 5272896, 5845000, + 6445296, 7073784, 7730464, 8415336, 9128400, 9425240, 0, 125416, 280048, 463896, 676960, 919240, 1190736, + 1491448, 1857200, 2180520, 2568880, 2986456, 3433248, 3909256, 4414480, 4726200, 1772928, 2132072, + 2520432, 2938008, 3384800, 3860808, 4366032, 4900472, 5464128, 6057000, 6679088, 7330392, 8010912, + 8720648, 9459600, 9767480, 0, 129800, 289840, 480120, 700640, 951400, 1232400, 1543640, 1922224, 2256840, + 2658800, 3091000, 3553440, 4046120, 4569040, 4891800, 1834880, 2206600, 2608560, 3040760, 3503200, + 3995880, 4518800, 5071960, 5655360, 6269000, 6912880, 7587000, 8291360, 9025960, 9790800, 10109720, 0, + 134184, 299632, 496344, 724320, 983560, 1274064, 1595832, 1987248, 2333160, 2748720, 3195544, 3673632, + 4182984, 4723600, 5057400, 1896832, 2281128, 2696688, 3143512, 3621600, 4130952, 4671568, 5243448, + 5846592, 6481000, 7146672, 7843608, 8571808, 9331272, 10122000, 10451960, 0, 138568, 309424, 512568, + 748000, 1015720, 1315728, 1648024, 2052272, 2409480, 2838640, 3300088, 3793824, 4319848, 4878160, 5223000, + 1958784, 2355656, 2784816, 3246264, 3740000, 4266024, 4824336, 5414936, 6037824, 6693000, 7380464, + 8100216, 8852256, 9636584, 10453200, 10794200 + ] + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 55b21283025c2..1c61518ddcdd2 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1362,6 +1362,7 @@ "less.jsonc", "log.jsonc", "matmul.jsonc", + "matmulnbits.jsonc", "matmul-broadcast.jsonc", "mul.jsonc", "mul_int32.jsonc", diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index bd58dded026a6..25e7567a2e9fc 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -8,13 +8,14 @@ namespace contrib { namespace js { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -25,14 +26,15 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + SkipLayerNormalization)>}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..888db0fd161f2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/js/quantization/matmul_nbits.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", JsepSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..cca2c4757765b --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; + +class MatMulNBits final : public JsKernel { + public: + MatMulNBits(const OpKernelInfo& info) : JsKernel(info), + K_{narrow(info.GetAttr("K"))}, + N_{narrow(info.GetAttr("N"))}, + accuracy_level_{info.GetAttrOrDefault("accuracy_level", 0)}, + nbits_{narrow(info.GetAttr("bits"))}, + block_size_{narrow(info.GetAttr("block_size"))} { + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)), + "Block size must be a power of 2 and greater than or equal to 16."); + JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({ + "k" : $1, + "n" : $2, + "accuracyLevel" : $3, + "bits" : $4, + "blockSize" : $5 + }), + static_cast(K_), + static_cast(N_), + static_cast(accuracy_level_), + static_cast(nbits_), + static_cast(block_size_)); + } + + private: + const size_t K_; + const size_t N_; + const int64_t accuracy_level_; + const size_t nbits_; + const size_t block_size_; +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime From ed9d178b9aef411fa101a1c89194c78a58d5dd67 Mon Sep 17 00:00:00 2001 From: Jiajie Hu Date: Wed, 21 Feb 2024 01:24:34 +0800 Subject: [PATCH 22/51] [js/webgpu] Create Split indices helpers by rank, not by shape (#19554) ### Description This is required to make shape uniforms really work. ### Motivation and Context The bug was unveiled in a model with multiple Split nodes. The later nodes would try to reuse a previous pipeline cache, while the old shapes were hardcoded as constants in cache. --- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 14d6f37927590..a09ac78b17006 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -68,7 +68,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const dataType = inputs[0].dataType; const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); const outputs = new Array(attributes.numOutputs); - const input = inputVariable('input', dataType, inputShape); + const input = inputVariable('input', dataType, inputShape.length); const sizeInSplitAxis = new Array(attributes.numOutputs); const outputsTensorInfo: TensorInfo[] = []; const outputShapes: number[][] = []; @@ -80,7 +80,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const outputShape = inputShape.slice(); outputShape[attributes.axis] = attributes.splitSizes[i]; outputShapes.push(outputShape); - outputs[i] = outputVariable(`output${i}`, dataType, outputShape); + outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } programUniforms.push( From 9641002e6bfa4199a5fd465dfa8aa4de35e51100 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:23:01 -0800 Subject: [PATCH 23/51] [js] small fix to workaround formatter (#19400) ### Description Rename shader variable names to snake_case naming and also to avoid formatter behaving inconsistently in win/linux. --- js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 3f73d9cb7c5bc..d5f97213e49ce 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -85,28 +85,28 @@ const createLayerNormProgramInfo = ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')} let offset = global_idx * uniforms.norm_size_vectorized; - var meanVector = ${fillVector('f32', components)}; - var meanSquareVector = ${fillVector('f32', components)}; + var mean_vector = ${fillVector('f32', components)}; + var mean_square_vector = ${fillVector('f32', components)}; for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) { let value = ${castToF32(dataType, components, 'x[h + offset]')}; - meanVector += value; - meanSquareVector += value * value; + mean_vector += value; + mean_square_vector += value * value; } - let mean = ${sumVector('meanVector', components)} / uniforms.norm_size; - let invStdDev = - inverseSqrt(${sumVector('meanSquareVector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon); + let mean = ${sumVector('mean_vector', components)} / uniforms.norm_size; + let inv_std_dev = inverseSqrt(${ + sumVector('mean_square_vector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon); for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) { let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; let f32scale = ${castToF32(dataType, components, 'scale[j]')}; - output[j + offset] = ${variables[0].type.value}((f32input - mean) * invStdDev * f32scale + output[j + offset] = ${variables[0].type.value}((f32input - mean) * inv_std_dev * f32scale ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''} ); } ${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''}; - ${hasInvStdOutput ? 'inv_std_output[global_idx] = invStdDev' : ''}; + ${hasInvStdOutput ? 'inv_std_output[global_idx] = inv_std_dev' : ''}; }`; }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; From a750c39b29fa77a8f3fdf013f130c471a0099300 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:33:37 -0800 Subject: [PATCH 24/51] [js/common] upgrade tsc in common from 4.9.5 to 5.2.2 (#19317) ### Description upgrade tsc in common from 4.9.5 to 5.2.2 --- js/common/package-lock.json | 106 +++++++++++++++++------------------ js/common/package.json | 4 +- js/common/test/tsconfig.json | 2 +- 3 files changed, 56 insertions(+), 56 deletions(-) diff --git a/js/common/package-lock.json b/js/common/package-lock.json index db60eeb87709a..ea972e4097b67 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -9,13 +9,13 @@ "version": "1.17.2", "license": "MIT", "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "node_modules/ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "node_modules/balanced-match": { @@ -34,9 +34,9 @@ } }, "node_modules/jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "node_modules/lunr": { @@ -46,9 +46,9 @@ "dev": true }, "node_modules/marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true, "bin": { "marked": "bin/marked.js" @@ -58,24 +58,24 @@ } }, "node_modules/minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "dependencies": { "brace-expansion": "^2.0.1" }, "engines": { - "node": ">=10" + "node": ">=16 || 14 >=14.17" }, "funding": { "url": "https://github.com/sponsors/isaacs" } }, "node_modules/shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "dependencies": { "ansi-sequence-parser": "^1.1.0", @@ -85,30 +85,30 @@ } }, "node_modules/typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "dependencies": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" }, "bin": { "typedoc": "bin/typedoc" }, "engines": { - "node": ">= 14.14" + "node": ">= 16" }, "peerDependencies": { - "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x" + "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x || 5.0.x || 5.1.x || 5.2.x || 5.3.x" } }, "node_modules/typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true, "bin": { @@ -116,7 +116,7 @@ "tsserver": "bin/tsserver" }, "engines": { - "node": ">=4.2.0" + "node": ">=14.17" } }, "node_modules/vscode-oniguruma": { @@ -134,9 +134,9 @@ }, "dependencies": { "ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "balanced-match": { @@ -155,9 +155,9 @@ } }, "jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "lunr": { @@ -167,24 +167,24 @@ "dev": true }, "marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true }, "minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "requires": { "brace-expansion": "^2.0.1" } }, "shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "requires": { "ansi-sequence-parser": "^1.1.0", @@ -194,21 +194,21 @@ } }, "typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "requires": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" } }, "typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true }, diff --git a/js/common/package.json b/js/common/package.json index a9e14969762fd..3bd830d6b3ba0 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -9,7 +9,7 @@ }, "author": "fs-eire", "scripts": { - "build:cjs": "tsc --module commonjs --outDir ./dist/cjs", + "build:cjs": "tsc --module commonjs --moduleResolution node10 --outDir ./dist/cjs", "build:esm": "tsc", "build:bundles": "webpack", "build": "node ./build.js", @@ -18,7 +18,7 @@ "test": "mocha ./test/**/*.js --timeout 30000" }, "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" }, "main": "dist/cjs/index.js", "exports": { diff --git a/js/common/test/tsconfig.json b/js/common/test/tsconfig.json index 2e4927ac3b325..e9068ad837a81 100644 --- a/js/common/test/tsconfig.json +++ b/js/common/test/tsconfig.json @@ -2,7 +2,7 @@ "extends": "../../tsconfig.tools.json", "exclude": ["type-tests/**/*.ts"], "compilerOptions": { - "module": "ES2022", + "module": "Node16", "sourceMap": true } } From a092546e6ee1575e35b13f80b665566b73f450cf Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 21 Feb 2024 00:31:06 -0800 Subject: [PATCH 25/51] [js] changes to allow Float16Array if any polyfill is available (#19305) ### Description This change adds only necessary code to enable ort-web works with any Float16Array polyfill. Unlike #19302, in this PR, ort-web does not include any specific polyfill; instead, it's user's choice for how to use a polyfill. ORT-web uses Float16Array if it's available; otherwise, fallback to use Uint16Array. ```js // case 1: user does not use polyfill: import * as ort from 'onnxruntime-web'; const myF16Data = new Uint16Array(...); // need to use Uint16Array const myF16tensor = new ort.Tensor('float16', myF16Data, dims); ``` ```js // case 2: user use polyfill: import * as ort from 'onnxruntime-web'; import { Float16Array, isFloat16Array, isTypedArray, getFloat16, setFloat16, f16round, } from "@petamoriken/float16"; globalThis.Float16Array = Float16Array; // ort-web will pick the global Float16Array const myF16Data = new Float16Array(...); // Use the polyfilled Float16Array type const myF16tensor = new ort.Tensor('float16', myF16Data, dims); ``` --- js/common/lib/tensor-impl-type-mapping.ts | 34 +++++++++++++++-------- js/common/lib/tensor-impl.ts | 10 ++++--- js/web/lib/wasm/wasm-common.ts | 9 +++++- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts index c4a43ea27fea1..b29cb8cbd6d35 100644 --- a/js/common/lib/tensor-impl-type-mapping.ts +++ b/js/common/lib/tensor-impl-type-mapping.ts @@ -14,7 +14,6 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map { - if (!isBigIntChecked) { - isBigIntChecked = true; - const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function'; - const isBigUint64ArrayAvailable = - typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + +// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for +// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array +// polyfill if available. +let isTypedArrayChecked = false; +export const checkTypedArray = () => { + if (!isTypedArrayChecked) { + isTypedArrayChecked = true; + const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from; + const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from; + const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from; if (isBigInt64ArrayAvailable) { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array); @@ -53,5 +58,12 @@ export const checkBigInt = () => { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array); NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64'); } + if (isFloat16ArrayAvailable) { + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Float16Array); + NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(Float16Array, 'float16'); + } else { + // if Float16Array is not available, use 'Uint16Array' to store the data. + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array); + } } }; diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index de18126a9d0ae..56682ef98e117 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -5,7 +5,7 @@ import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js'; import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js'; import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js'; import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js'; -import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js'; +import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js'; import {calculateSize, tensorReshape} from './tensor-utils-impl.js'; import {Tensor as TensorInterface} from './tensor.js'; @@ -67,8 +67,8 @@ export class Tensor implements TensorInterface { arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters| TextureConstructorParameters|GpuBufferConstructorParameters, arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) { - // perform one-time check for BigInt support - checkBigInt(); + // perform one-time check for BigInt/Float16Array support + checkTypedArray(); let type: TensorType; let dims: readonly number[]; @@ -142,7 +142,9 @@ export class Tensor implements TensorInterface { throw new TypeError(`Unsupported tensor type: ${arg0}.`); } if (Array.isArray(arg1)) { - if (arg0 === 'float16') { + if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) { + // When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array. + // // Throw error here because when user try to use number array as data, // e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call // Uint16Array.from(arg1) which generates wrong data. diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index 93910af1f1bf0..54eaf5e0c43cc 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -3,6 +3,12 @@ import {Tensor} from 'onnxruntime-common'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + // This file includes common definitions. They do NOT have dependency on the WebAssembly instance. /** @@ -117,7 +123,8 @@ export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32Arr Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => { switch (type) { case 'float16': - return Uint16Array; + // allow Float16Array polyfill. + return typeof Float16Array !== 'undefined' && Float16Array.from ? Float16Array : Uint16Array; case 'float32': return Float32Array; case 'uint8': From 4d0a6851c6b413fd0568fec7c053463d4c65d5ec Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Thu, 22 Feb 2024 00:08:47 +0800 Subject: [PATCH 26/51] [js/web] Fix fused-conv is not included in npm test (#19581) BUG: https://github.com/microsoft/onnxruntime/issues/18855 ### Description ### Motivation and Context --- js/web/test/suite-test-list.jsonc | 1 + 1 file changed, 1 insertion(+) diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 1c61518ddcdd2..b43b1ac37e37d 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1354,6 +1354,7 @@ "expand.jsonc", "fast-gelu.jsonc", "floor.jsonc", + "fused-conv.jsonc", "gather-elements.jsonc", "gemm.jsonc", "global-average-pool.jsonc", From fb9d285c0a2a8decfc60b56570fd16d392751121 Mon Sep 17 00:00:00 2001 From: Matttttt <18152455+martholomew@users.noreply.github.com> Date: Wed, 21 Feb 2024 21:38:18 +0000 Subject: [PATCH 27/51] Misspelling in README.md (#19433) Fixed a misspelling. --- js/web/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/README.md b/js/web/README.md index c75a40ad6da28..906c78a1b7ec4 100644 --- a/js/web/README.md +++ b/js/web/README.md @@ -12,7 +12,7 @@ The [Open Neural Network Exchange](http://onnx.ai/) (ONNX) is an open standard f With ONNX Runtime Web, web developers can score models directly on browsers with various benefits including reducing server-client communication and protecting user privacy, as well as offering install-free and cross-platform in-browser ML experience. -ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web complies the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. +ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web compiles the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. See [Compatibility](#Compatibility) and [Operators Supported](#Operators) for a list of platforms and operators ONNX Runtime Web currently supports. @@ -22,7 +22,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun ## Documents -### Developement +### Development Refer to the following links for development information: From c8e7e8b7550ae5d97b53e6e63c7b40de6dc35ff7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 21 Feb 2024 13:58:53 -0800 Subject: [PATCH 28/51] Bump ip from 1.1.8 to 1.1.9 in /js/react_native (#19582) Bumps [ip](https://github.com/indutny/node-ip) from 1.1.8 to 1.1.9.
Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ip&package-manager=npm_and_yarn&previous-version=1.1.8&new-version=1.1.9)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) Dependabot will merge this PR once CI passes on it, as requested by @fs-eire. [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/react_native/yarn.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index a6a15fc69b06c..adc2d5396a3e7 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -3701,9 +3701,9 @@ invariant@^2.2.4: loose-envify "^1.0.0" ip@^1.1.5: - version "1.1.8" - resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48" - integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg== + version "1.1.9" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396" + integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ== is-absolute@^1.0.0: version "1.0.0" From 12101deed991a5748f97a57200445e0815e588ae Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 23 Feb 2024 05:09:28 +0800 Subject: [PATCH 29/51] [js/webgpu] Fix Conv2DTransposeMatMul f16 compilation failure (#19596) This is used in sam-h-decoder-f16. ### Description ### Motivation and Context --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index b5b6a2a15cd8c..11c8778b72335 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -23,17 +23,17 @@ import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; -import {biasSnippet, typeSnippet} from './activation_util'; +import {biasSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; const conv2dTransposeCommonSnippet = - (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => { - const type = typeSnippet(innerElementSize, 'f32'); + (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string, + innerElementSize = 4): string => { const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: @@ -47,7 +47,7 @@ const conv2dTransposeCommonSnippet = let v1 = w[getIndexFromCoords4D(coord1, vec4(uniforms.w_shape))]; let v2 = w[getIndexFromCoords4D(coord2, vec4(uniforms.w_shape))]; let v3 = w[getIndexFromCoords4D(coord3, vec4(uniforms.w_shape))]; - return vec4(v0, v1, v2, v3); + return ${type}(v0, v1, v2, v3); `; default: throw new Error(`innerElementSize ${innerElementSize} is not supported.`); @@ -224,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo = const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); inputVariables.push(bias); declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${bias.type.value} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -236,16 +236,20 @@ export const createConv2DTransposeMatMulProgramInfo = {name: 'pads', type: 'i32', length: pads.length} ]; appendActivationUniforms(attributes, uniforms); + const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1); + if (elemType !== 'f16' && elemType !== 'f32') { + throw new Error(`elemType ${elemType} is not supported.`); + } return ` ${utilFunctions('uniforms.result_strides')} ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; ${declareFunctions} - ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false, undefined, sequentialAccessByThreads)}`; }; From bb51cea463c797bd14398ce8ac7964864ea8cd9b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:58:17 -0800 Subject: [PATCH 30/51] Bump ip from 1.1.8 to 1.1.9 in /js/react_native/e2e (#19583) Bumps [ip](https://github.com/indutny/node-ip) from 1.1.8 to 1.1.9.
Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ip&package-manager=npm_and_yarn&previous-version=1.1.8&new-version=1.1.9)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) Dependabot will merge this PR once CI passes on it, as requested by @fs-eire. [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/react_native/e2e/yarn.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/react_native/e2e/yarn.lock b/js/react_native/e2e/yarn.lock index 9e20a286c4e27..6f05faf046098 100644 --- a/js/react_native/e2e/yarn.lock +++ b/js/react_native/e2e/yarn.lock @@ -3351,9 +3351,9 @@ invariant@^2.2.4: loose-envify "^1.0.0" ip@^1.1.5: - version "1.1.8" - resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48" - integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg== + version "1.1.9" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396" + integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ== is-accessor-descriptor@^0.1.6: version "0.1.6" From 676832697231e010449bd7fcaa102f74039e2cca Mon Sep 17 00:00:00 2001 From: Segev Finer Date: Fri, 23 Feb 2024 04:53:50 +0200 Subject: [PATCH 31/51] [node] Switch to setImmediate to avoid starving the Node.js event loop (#19610) ### Description Switch to setImmediate to avoid starving the Node.js event loop There should really be a true async version though, running computationally intensive things on the event loop will stop everything else from happening while it is running, e.g. a web server from answering requests. This can be done by wrapping `RunAsync` behind a [`napi::Promise`](https://github.com/nodejs/node-addon-api/blob/main/doc/promises.md) to run on the onnxruntime thread pool or [`AsyncWorker`]( https://github.com/nodejs/node-addon-api/blob/main/doc/async_worker.md) for the Node.js/libuv thread pool. ### Motivation and Context Without this, if you run inference in a tight loop, without anything else in between that is async/deferred, `process.nextTick` will lead to starving the event loop and not letting anything else run, `setImmediate` at least lets the event loop spin between calls to `run`. See https://dev.to/ynmanware/setimmediate-settimeout-and-process-nexttick-3mfd Contributed on behalf of [Swimm](https://swimm.io/) --- js/node/lib/backend.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index e8eb0e9babf5a..927953b4f1dd6 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -36,7 +36,7 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise { return new Promise((resolve, reject) => { - process.nextTick(() => { + setImmediate(() => { try { resolve(this.#inferenceSession.run(feeds, fetches, options)); } catch (e) { @@ -56,7 +56,7 @@ class OnnxruntimeBackend implements Backend { async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { return new Promise((resolve, reject) => { - process.nextTick(() => { + setImmediate(() => { try { resolve(new OnnxruntimeSessionHandler(pathOrBuffer, options || {})); } catch (e) { From 846f4452bdabc06bc9ce65ec717086a20f3e8639 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Fri, 23 Feb 2024 00:21:15 -0800 Subject: [PATCH 32/51] [JS/WebGPU] Fix Split and Where to handle corner cases. (#19613) ### Description 1. Fix Where operator to handle Boolean input less than 4 bytes. 2. Fix JSEP test harness to use tensor names consistently. ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/where.ts | 3 ++- js/web/test/data/ops/where.jsonc | 34 ++++++++++++++++++++++++ js/web/test/test-runner.ts | 4 +-- 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index cfee07a9239d7..a6375847fc42f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -27,7 +27,7 @@ const createWhereOpProgramShader = const expressionA = `a_data[index_a${x}][component_a${x}]`; const expressionB = `b_data[index_b${x}][component_b${x}]`; // eslint-disable-next-line no-bitwise - const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`; return ` let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)}; @@ -38,6 +38,7 @@ const createWhereOpProgramShader = let index_c${x} = offset_c${x} / 4u; let component_a${x} = offset_a${x} % 4u; let component_b${x} = offset_b${x} % 4u; + let component_c${x} = offset_c${x} % 4u; ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); `; }; diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc index 047fd6fd7511b..990120dd3708e 100644 --- a/js/web/test/data/ops/where.jsonc +++ b/js/web/test/data/ops/where.jsonc @@ -168,5 +168,39 @@ ] } ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1", + "inputs": [ + { + "data": [true, false], + "dims": [1, 1, 2, 1], + "type": "bool" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 4], + "type": "float32" + }, + { + "data": [5, 6, 7, 8, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 442cb1bcf1f34..4dba78f3852e4 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -632,8 +632,8 @@ export async function runModelTestSet( try { const feeds: Record = {}; const outputsMetaInfo: Record = {}; - testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor); - testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor); + testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor); + testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor); const [start, end, outputs] = await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding}); if (context.perfData.count === 0) { From 0eff9837182998ad46dee886c4447aa830dbd106 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 23 Feb 2024 12:52:47 -0800 Subject: [PATCH 33/51] [js/webgpu] allows a ProgramInfo's RunData to use zero sized output (#19614) ### Description This PR allows zero-sized output. To make the implementation simple, it does not support partial zero-sized tensor. Which means, either all outputs are zero-sized, or an error will be reported. added 2 tests: - op test of `Add` with input T[2,0] T[2,1], and - test_split_zero_size_splits --- js/web/lib/wasm/jsep/backend-webgpu.ts | 32 ++++++++++++++++++++++---- js/web/lib/wasm/jsep/init.ts | 3 ++- js/web/lib/wasm/jsep/util.ts | 11 ++++++++- js/web/test/data/ops/add.jsonc | 22 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 2 +- js/web/test/test-runner.ts | 10 ++++++-- 6 files changed, 71 insertions(+), 9 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 98990a6fe477b..3e3a191ec3ead 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -385,11 +385,16 @@ export class WebGpuBackend { // create info for inputs const inputDatas: GpuData[] = []; for (let i = 0; i < inputTensorViews.length; ++i) { - const gpuData = this.gpuDataManager.get(inputTensorViews[i].data); + const data = inputTensorViews[i].data; + // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it. + if (data === 0) { + continue; + } + const gpuData = this.gpuDataManager.get(data); if (!gpuData) { - throw new Error(`no GPU data for input: ${inputTensorViews[i].data}`); + throw new Error(`no GPU data for input: ${data}`); } - inputDatas[i] = gpuData; + inputDatas.push(gpuData); } const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews); @@ -419,6 +424,11 @@ export class WebGpuBackend { const tensorView = (isTemporary || isPersistent) ? createIntermediateOutput(outputs[i].dataType, outputs[i].dims) : createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); + outputTensorViews.push(tensorView); + // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it. + if (tensorView.data === 0) { + continue; + } const gpuData = this.gpuDataManager.get(tensorView.data); if (!gpuData) { throw new Error(`no GPU data for output: ${tensorView.data}`); @@ -434,10 +444,24 @@ export class WebGpuBackend { } persistentData.push(gpuData); } - outputTensorViews.push(tensorView); outputDatas.push(gpuData); } + // when there are any zero-sized tensor in the inputs or outputs, we should report error unless all outputs are + // zero-sized tensors. + if (inputDatas.length !== inputTensorViews.length || outputDatas.length !== outputTensorViews.length) { + // if all outputs are zero-sized tensors, there is no need to run the program. + if (outputDatas.length === 0) { + TRACE_FUNC_END(program.name); + return outputTensorViews; + } + // if some outputs are zero-sized tensors, report an error. + // + // TODO: so far we don't see any use case that outputs include both zero-sized tensors and non-zero-sized tensors. + // If we see such use case, we need to make a change here to support it. + throw new Error( + `Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`); + } // load uniforms // TODO: add cache for uniform (is it necessary?) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 786ae41646554..b64abf9cc5424 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -104,7 +104,8 @@ class ComputeContextImpl implements ComputeContext { throw new Error(`Unsupported data type: ${dataType}`); } const bufferSize = elementSize * ShapeUtil.size(dims); - return new TensorViewImpl(this.module, dataType, this.backend.gpuDataManager.create(bufferSize).id, dims); + const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0; + return new TensorViewImpl(this.module, dataType, gpuDataId, dims); }; return this.backend.run(program, mappedInputs, outputIndices, createKernelOutput, createTemporaryOutput); } diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index c0517ce363644..9a1d5463f7843 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -56,7 +56,16 @@ export class BroadcastUtil { if (aLen !== bLen && aLen > 1 && bLen > 1) { return undefined; } - cdims[crank - i] = Math.max(aLen, bLen); + const max = Math.max(aLen, bLen); + if (aLen && bLen) { + cdims[crank - i] = Math.max(aLen, bLen); + } else { + // when either aLen or bLen is 0, the other should be either 0 or 1, otherwise it is not broadcastable. + if (max > 1) { + return undefined; + } + cdims[crank - i] = 0; + } } return cdims; diff --git a/js/web/test/data/ops/add.jsonc b/js/web/test/data/ops/add.jsonc index e5b4ff2b53148..dd15134861ef0 100644 --- a/js/web/test/data/ops/add.jsonc +++ b/js/web/test/data/ops/add.jsonc @@ -157,6 +157,28 @@ "type": "float32" } ] + }, + { + "name": "T[2,0] T[2,1]", + "inputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + }, + { + "data": [1, 2], + "dims": [2, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + } + ] } ] } diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index b43b1ac37e37d..88555a27be82e 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1231,7 +1231,7 @@ "test_split_variable_parts_1d", "test_split_variable_parts_2d", "test_split_variable_parts_default_axis", - // // "test_split_zero_size_splits", + "test_split_zero_size_splits", "test_sqrt_example", "test_sqrt", "test_squeeze_negative_axes", diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 4dba78f3852e4..14089c9e146f0 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -578,7 +578,9 @@ export async function sessionRun(options: { // replace the CPU tensors in feeds into GPU tensors for (const name in feeds) { if (Object.hasOwnProperty.call(feeds, name)) { - feeds[name] = createGpuTensorForInput(feeds[name]); + if (feeds[name].size > 0) { + feeds[name] = createGpuTensorForInput(feeds[name]); + } } } } @@ -587,7 +589,11 @@ export async function sessionRun(options: { for (const name in options.outputsMetaInfo) { if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) { const {type, dims} = options.outputsMetaInfo[name]; - fetches[name] = createGpuTensorForOutput(type, dims); + if (dims.some(d => d === 0)) { + fetches[name] = new ort.Tensor(type, [], dims); + } else { + fetches[name] = createGpuTensorForOutput(type, dims); + } } } } From 16c55468e8eb1909fb8e2f088001dd748949ed79 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 23 Feb 2024 15:45:30 -0800 Subject: [PATCH 34/51] [js/webgpu] minor fixes to make tinyllama work (#19564) --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 4 +++- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index b06c9fb496d15..b142a82e551a7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -154,7 +154,9 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { validateInputs(context.inputs); - context.compute(createConcatProgramInfo(context.inputs, attributes.axis)); + // 0 length tensors are valid for concat, remove them + const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0); + context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs}); }; export const parseConcatAttributes = (attributes: Record): ConcatAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 5c31e6dd86c00..d48bb909f7f8f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -55,7 +55,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath if (idx${x} < 0) { idx${x} = idx${x} + uniforms.axisDimLimit; } - var dataIndices${x} = ${data.type.indices}(0); + var dataIndices${x} : ${data.type.indices}; `; for (let i = 0, j = 0; i < inputRank; i++) { if (i === axis) { From d6a7bd6759a835ede9286233e348c0cfb97aa97d Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 24 Feb 2024 10:09:07 -0800 Subject: [PATCH 35/51] [js/web] fix suite test list for zero sized tensor (#19638) ### Description Fixes build break brought by #19614 Currently WebGL backend does not support zero sized tensor. This change split test data into 2 parts, and only enable zero sized tensor tests for WebGPU. --- js/web/test/data/ops/add.jsonc | 22 - js/web/test/data/ops/add_zero-sized.jsonc | 31 + js/web/test/data/ops/concat_zero-sized.jsonc | 561 +++++++++++++++++++ js/web/test/suite-test-list.jsonc | 2 + 4 files changed, 594 insertions(+), 22 deletions(-) create mode 100644 js/web/test/data/ops/add_zero-sized.jsonc create mode 100644 js/web/test/data/ops/concat_zero-sized.jsonc diff --git a/js/web/test/data/ops/add.jsonc b/js/web/test/data/ops/add.jsonc index dd15134861ef0..e5b4ff2b53148 100644 --- a/js/web/test/data/ops/add.jsonc +++ b/js/web/test/data/ops/add.jsonc @@ -157,28 +157,6 @@ "type": "float32" } ] - }, - { - "name": "T[2,0] T[2,1]", - "inputs": [ - { - "data": [], - "dims": [2, 0], - "type": "float32" - }, - { - "data": [1, 2], - "dims": [2, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [], - "dims": [2, 0], - "type": "float32" - } - ] } ] } diff --git a/js/web/test/data/ops/add_zero-sized.jsonc b/js/web/test/data/ops/add_zero-sized.jsonc new file mode 100644 index 0000000000000..37e08cd7f20ac --- /dev/null +++ b/js/web/test/data/ops/add_zero-sized.jsonc @@ -0,0 +1,31 @@ +[ + { + "name": "Add with no attributes", + "operator": "Add", + "attributes": [], + "cases": [ + { + "name": "T[2,0] T[2,1]", + "inputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + }, + { + "data": [1, 2], + "dims": [2, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/concat_zero-sized.jsonc b/js/web/test/data/ops/concat_zero-sized.jsonc new file mode 100644 index 0000000000000..7be8e8c1cc602 --- /dev/null +++ b/js/web/test/data/ops/concat_zero-sized.jsonc @@ -0,0 +1,561 @@ +[ + { + "name": "Concat 2D axis=0", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": -2, "type": "int" }], + "cases": [ + { + "name": "X", + "inputs": [ + { + "data": [], + "dims": [1, 4, 0, 64], + "type": "float32" + }, + { + "data": [ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], + "dims": [1, 4, 36, 64], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], + "dims": [1, 4, 36, 64], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 88555a27be82e..e96a0aa045bc8 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1334,6 +1334,7 @@ "acos.jsonc", "add.jsonc", "add_int32.jsonc", + "add_zero-sized.jsonc", //"and.jsonc", "asin.jsonc", "attention.jsonc", @@ -1343,6 +1344,7 @@ "ceil.jsonc", "concat.jsonc", "concat_int32.jsonc", + "concat_zero-sized.jsonc", "cast.jsonc", "conv.jsonc", "cos.jsonc", From d61ef6f13e4706f4c107fd5a29c5a0bf5c719999 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:07:15 -0800 Subject: [PATCH 36/51] [js/common] move 'env.wasm.trace' to 'env.trace' (#19617) ### Description Try to move 'env.wasm.trace' to 'env.trace' to make it less confusing, because it also works in webgpu. Marked 'env.wasm.trace' as deprecated. --- js/common/lib/env.ts | 9 +++++++++ js/common/lib/trace.ts | 6 +++--- js/web/lib/wasm/jsep/backend-webgpu.ts | 3 ++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 6299c26159400..73a47d1a4f937 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -36,6 +36,7 @@ export declare namespace Env { /** * set or get a boolean value indicating whether to enable trace. * + * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored. * @defaultValue `false` */ trace?: boolean; @@ -167,6 +168,7 @@ export interface Env { * @defaultValue `'warning'` */ logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal'; + /** * Indicate whether run in debug mode. * @@ -174,6 +176,13 @@ export interface Env { */ debug?: boolean; + /** + * set or get a boolean value indicating whether to enable trace. + * + * @defaultValue `false` + */ + trace?: boolean; + /** * Get version of the current package. */ diff --git a/js/common/lib/trace.ts b/js/common/lib/trace.ts index 404f7ef8089af..7e0487b350198 100644 --- a/js/common/lib/trace.ts +++ b/js/common/lib/trace.ts @@ -4,7 +4,7 @@ import {env} from './env-impl.js'; export const TRACE = (deviceType: string, label: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } // eslint-disable-next-line no-console @@ -30,14 +30,14 @@ const TRACE_FUNC = (msg: string, extraMsg?: string) => { }; export const TRACE_FUNC_BEGIN = (extraMsg?: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } TRACE_FUNC('BEGIN', extraMsg); }; export const TRACE_FUNC_END = (extraMsg?: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } TRACE_FUNC('END', extraMsg); diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 3e3a191ec3ead..27c5566ab9fed 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -710,7 +710,8 @@ export class WebGpuBackend { } setQueryType(): void { this.queryType = 'none'; - if (this.env.webgpu.profiling?.mode === 'default' || this.env.wasm.trace) { + if (this.env.webgpu.profiling?.mode === 'default' || + (typeof this.env.trace === 'undefined' ? this.env.wasm.trace : this.env.trace)) { if (this.device.features.has('chromium-experimental-timestamp-query-inside-passes')) { this.queryType = 'inside-passes'; } else if (this.device.features.has('timestamp-query')) { From ffe15832205d06f8e0ecf276518a19f3421c931a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Feb 2024 16:05:08 -0800 Subject: [PATCH 37/51] [js/webgpu] use Headless for webgpu test by default (#19702) ### Description use Chromium Headless for webgpu test by default. Still use normal Chromium with window when debug=true or perfMode=true. Use the [`--headless=new`](https://developer.chrome.com/docs/chromium/new-headless) mode. ### Motivation and Context try to use a more stable way to launch npm tests to avoid a "chrome not found" issue in pipeline, which may potentially caused by windowed application. --- js/web/karma.conf.js | 4 ++-- js/web/script/test-runner-cli.ts | 29 +++++++---------------------- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js index 8fce79843f617..9e44d9c0d9652 100644 --- a/js/web/karma.conf.js +++ b/js/web/karma.conf.js @@ -86,11 +86,11 @@ module.exports = function(config) { hostname, listenAddress, customLaunchers: { - // the following flags are used to make sure Edge on CI agents to initialize WebGPU correctly. + // Chromium-based browsers EdgeTest: {base: 'Edge', flags: chromiumFlags}, ChromeTest: {base: 'Chrome', flags: chromiumFlags}, - ChromeTestHeadless: {base: 'ChromeHeadless', flags: chromiumFlags}, ChromeCanaryTest: {base: 'ChromeCanary', flags: chromiumFlags}, + // // ==== BrowserStack browsers ==== // diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index d56792c6e3595..1d889152c61a6 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -495,14 +495,13 @@ async function main() { npmlog.info('TestRunnerCli.Run', '(4/4) Running karma to start test runner...'); const webgpu = args.backends.indexOf('webgpu') > -1; const webnn = args.backends.indexOf('webnn') > -1; - const browser = getBrowserNameFromEnv( - args.env, - args.bundleMode === 'perf' ? 'perf' : - args.debug ? 'debug' : - 'test', - webgpu); + const browser = getBrowserNameFromEnv(args.env); const karmaArgs = ['karma', 'start', `--browsers ${browser}`]; const chromiumFlags = ['--enable-features=SharedArrayBuffer', ...args.chromiumFlags]; + if (args.bundleMode === 'dev' && !args.debug) { + // use headless for 'test' mode (when 'perf' and 'debug' are OFF) + chromiumFlags.push('--headless=new'); + } if (args.debug) { karmaArgs.push('--log-level info --timeout-mocha 9999999'); chromiumFlags.push('--remote-debugging-port=9333'); @@ -615,10 +614,10 @@ async function main() { fs.writeJSONSync(path.join(TEST_ROOT, './testdata-config.json'), config); } - function getBrowserNameFromEnv(env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean) { + function getBrowserNameFromEnv(env: TestRunnerCliArgs['env']) { switch (env) { case 'chrome': - return selectChromeBrowser(mode, webgpu); + return 'ChromeTest'; case 'edge': return 'EdgeTest'; case 'firefox': @@ -633,20 +632,6 @@ async function main() { throw new Error(`env "${env}" not supported.`); } } - - function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean) { - if (webgpu) { - return 'ChromeTest'; - } else { - switch (mode) { - case 'debug': - case 'perf': - return 'ChromeTest'; - default: - return 'ChromeTestHeadless'; - } - } - } } void main(); From ce612c7afff2fc06b6919aa6db678b0eef976933 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 1 Mar 2024 14:50:06 -0800 Subject: [PATCH 38/51] [js/web] transfer input buffer back to caller thread (#19677) ### Description When using proxy worker, input buffers should be transferred back to the caller thread after `run()` call is done. Fixes #19488 --- js/web/lib/wasm/proxy-worker/main.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index 6cbd38c76ccc8..3ce37a2d6b652 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -103,7 +103,7 @@ self.onmessage = (ev: MessageEvent): void => { } else { postMessage( {type, out: outputs} as OrtWasmMessage, - extractTransferableBuffers(outputs as SerializableTensorMetadata[])); + extractTransferableBuffers([...inputs, ...outputs] as SerializableTensorMetadata[])); } }, err => { From 4ccd6201cf7b0539c54ee508a842f430b331ea0f Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Thu, 7 Mar 2024 19:07:49 -0800 Subject: [PATCH 39/51] [JS/WebGPU] Preserve zero size input tensor dims. (#19737) ### Description For Concat operation, the zero-size input tensor shape need to be preserved and, unlike non-zero tensors, the dims are not constrained to match other input tensors' dims. ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 146 +++++++++---------- js/web/test/data/ops/concat_zero-sized.jsonc | 80 ++++++++++ 2 files changed, 149 insertions(+), 77 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index b142a82e551a7..010ee589c44fa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -13,25 +13,32 @@ export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; } -const validateInputs = (inputs: readonly TensorView[]): void => { +const validateInputs = (inputs: readonly TensorView[], axis: number): void => { if (!inputs || inputs.length < 1) { throw new Error('too few inputs'); } - - const inputType = inputs[0].dataType; - const inputDimensionality = inputs[0].dims.length; - - for (const input of inputs) { + const referenceIndex = 0; + const referenceInput = inputs[referenceIndex]; + const inputType = referenceInput.dataType; + const inputRank = referenceInput.dims.length; + inputs.forEach((input, i) => { + if (i === referenceIndex) { + return; + } // make sure types of all inputs match if (input.dataType !== inputType) { throw new Error('input tensors should be one type'); } - // make sure the dimensionality of all inputs are the same - if (input.dims.length !== inputDimensionality) { + if (input.dims.length !== inputRank) { throw new Error('input tensors should have the same shape'); } - } + input.dims.forEach((dim, i) => { + if (i !== axis && dim !== referenceInput.dims[i]) { + throw new Error('non concat dimensions must match'); + } + }); + }); }; const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => ` @@ -64,65 +71,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe return codeLines.join('\n'); }; -const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): ProgramInfo => { - const inputShape = inputs[0].dims.slice(); - if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { - throw new Error('axis specified for concat doesn\'t match input dimensionality'); - } - const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis; - // ensure all of the non-concatenated axes match each other - // calculate the shape of the output tensor while we do that - const outputShape = inputShape.slice(0); - for (let i = 1; i < inputs.length; i++) { - const dataNShape = inputs[i].dims.slice(); - for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { - // add to the placeholder for computing output shape - if (axisIndex === adjustedAxis) { - outputShape[adjustedAxis] += dataNShape[axisIndex]; +const createConcatProgramInfo = + (inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => { + const outputSize = ShapeUtil.size(outputShape); + + const sizeInConcatAxis = new Array(inputs.length); + const inputVars = new Array(inputs.length); + + let previousSum = 0; + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + const inputRanks = []; + const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; + for (let i = 0; i < inputs.length; ++i) { + previousSum += inputs[i].dims[adjustedAxis]; + sizeInConcatAxis[i] = previousSum; + inputRanks.push(inputs[i].dims.length); + inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); + inputDependencies.push('rank'); + programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]}); } - // ensure all non-cancatenated axes match each other - else if (inputShape[axisIndex] !== dataNShape[axisIndex]) { - throw new Error('non concat dimensions must match'); + for (let i = 0; i < inputs.length; ++i) { + programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); } - } - } - - const outputSize = ShapeUtil.size(outputShape); - - const sizeInConcatAxis = new Array(inputs.length); - const inputVars = new Array(inputs.length); - const dataType = inputs[0].dataType; - - let previousSum = 0; - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - const inputRanks = []; - const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; - for (let i = 0; i < inputs.length; ++i) { - previousSum += inputs[i].dims[adjustedAxis]; - sizeInConcatAxis[i] = previousSum; - inputRanks.push(inputs[i].dims.length); - inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); - inputDependencies.push('rank'); - programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]}); - } - for (let i = 0; i < inputs.length; ++i) { - programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); - } - programUniforms.push(...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(outputShape)); - const output = outputVariable('output', dataType, outputShape.length); - const indicesAxis = output.indicesGet('indices', adjustedAxis); - const sizeInConcatAxisStr = - Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const output = outputVariable('output', dataType, outputShape.length); + const indicesAxis = output.indicesGet('indices', adjustedAxis); + const sizeInConcatAxisStr = + Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); + const getShaderSource = (shaderHelper: ShaderHelper) => ` ${(() => { - shaderHelper.registerUniform('outputSize', 'u32'); - for (let i = 0; i < inputs.length; i++) { - shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); - } - return shaderHelper.declareVariables(...inputVars, output); - })()} + shaderHelper.registerUniform('outputSize', 'u32'); + for (let i = 0; i < inputs.length; i++) { + shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); + } + return shaderHelper.declareVariables(...inputVars, output); + })()} ${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)} @@ -140,23 +125,30 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P ${assignOutputData(inputVars, output)} }`; - return { - name: 'Concat', - shaderCache: {hint: `${axis}`, inputDependencies}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms, - }), - getShaderSource, - }; -}; + return { + name: 'Concat', + shaderCache: {hint: `${adjustedAxis}`, inputDependencies}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms, + }), + getShaderSource, + }; + }; export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { - validateInputs(context.inputs); + const inputs = context.inputs; + const inputShape = inputs[0].dims; + const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); + validateInputs(inputs, adjustedAxis); + const outputShape = inputShape.slice(); + outputShape[adjustedAxis] = + inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0); // 0 length tensors are valid for concat, remove them - const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0); - context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs}); + const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0); + context.compute( + createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs}); }; export const parseConcatAttributes = (attributes: Record): ConcatAttributes => diff --git a/js/web/test/data/ops/concat_zero-sized.jsonc b/js/web/test/data/ops/concat_zero-sized.jsonc index 7be8e8c1cc602..be9625145d157 100644 --- a/js/web/test/data/ops/concat_zero-sized.jsonc +++ b/js/web/test/data/ops/concat_zero-sized.jsonc @@ -557,5 +557,85 @@ ] } ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 0, + "type": "int" + } + ], + "cases": [ + { + "name": "Some but not all input tensors are zero-sized", + "inputs": [ + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 1, + "type": "int" + } + ], + "cases": [ + { + "name": "All input tensors are zero-sized", + "inputs": [ + { + "data": [], + "dims": [0, 0], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 2], + "type": "float32" + }, + { + "data": [], + "dims": [0, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [0, 6], + "type": "float32" + } + ] + } + ] } ] From 18027bed05718b065d16b1f8c91a1f98099e224a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 12 Mar 2024 19:50:51 -0700 Subject: [PATCH 40/51] [js/webgpu] expose a few properties in WebGPU API (#19857) ### Description This change exposes a few properties in `ort.env.webgpu` to resolve feature requirement mentioned in properties in https://github.com/microsoft/onnxruntime/pull/14579#discussion_r1519612619. - Add `powerPreference` and `forceFallbackAdapter` in `ort.env.webgpu`, to allow users to set the value of the properties before the first inference session is created. - Add readonly property `adapter` in `ort.env.webgpu` to allow users to get the adapter instance. Now users can access `ort.env.webgpu.device` and `ort.env.webgpu.adapter`. @xenova @beaufortfrancois --- js/common/lib/env.ts | 35 ++++++++++++++++++++++++++ js/web/lib/wasm/jsep/backend-webgpu.ts | 1 + js/web/lib/wasm/wasm-core-impl.ts | 10 +++++++- 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 73a47d1a4f937..dd8bde2b596f4 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -143,9 +143,44 @@ export declare namespace Env { */ ondata?: (data: WebGpuProfilingData) => void; }; + /** + * Set or get the power preference. + * + * Setting this property only has effect before the first WebGPU inference session is created. The value will be + * used as options for `navigator.gpu.requestAdapter()`. + * + * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details. + * + * @defaultValue `undefined` + */ + powerPreference?: 'low-power'|'high-performance'; + /** + * Set or get the force fallback adapter flag. + * + * Setting this property only has effect before the first WebGPU inference session is created. The value will be + * used as options for `navigator.gpu.requestAdapter()`. + * + * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details. + * + * @defaultValue `undefined` + */ + forceFallbackAdapter?: boolean; + /** + * Get the adapter for WebGPU. + * + * This property is only available after the first WebGPU inference session is created. + * + * When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types". + * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type. + * + * see comments on {@link GpuBufferType} + */ + readonly adapter: unknown; /** * Get the device for WebGPU. * + * This property is only available after the first WebGPU inference session is created. + * * When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types". * Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type. * diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 27c5566ab9fed..182c1cd351c9d 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -231,6 +231,7 @@ export class WebGpuBackend { }; Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); + Object.defineProperty(this.env.webgpu, 'adapter', {value: adapter}); // init queryType, which is necessary for InferenceSession.create this.setQueryType(); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 37b9ed6a1002f..afab9ba00b0c4 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -89,7 +89,15 @@ export const initEp = async(env: Env, epName: string): Promise => { if (typeof navigator === 'undefined' || !navigator.gpu) { throw new Error('WebGPU is not supported in current environment'); } - const adapter = await navigator.gpu.requestAdapter(); + const powerPreference = env.webgpu?.powerPreference; + if (powerPreference !== undefined && powerPreference !== 'low-power' && powerPreference !== 'high-performance') { + throw new Error(`Invalid powerPreference setting: "${powerPreference}"`); + } + const forceFallbackAdapter = env.webgpu?.forceFallbackAdapter; + if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') { + throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`); + } + const adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter}); if (!adapter) { throw new Error( 'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); From b54dd28f4a87a99ca2c36067458e2bd52b83ca83 Mon Sep 17 00:00:00 2001 From: Yang Gu Date: Wed, 13 Mar 2024 13:25:07 +0800 Subject: [PATCH 41/51] [js/webgpu] Enable GroupedConvVectorize path (#19791) Vectorize met 2 failed cases in a CI bot with NVIDIA GPU, but we couldn't repro with all the GPUs at hand, including NVIDIA GPUs. This PR introduces GPUAdapterInfo and enables this opt on non-NVIDIA GPUs to make the bots happy. No obivous perf gain can be seen if we enable vectorize on NVIDIA. However, it shows big perf improvement on Intel. On my Gen12 Intel GPU, mobilenetv2-12 perf was improved from 11.14ms to 7.1ms. --- js/web/lib/wasm/jsep/backend-webgpu.ts | 24 +++++++++++++++++++++++- js/web/lib/wasm/jsep/init.ts | 4 +++- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 7 ++++--- js/web/lib/wasm/jsep/webgpu/types.ts | 12 ++++++++++++ 4 files changed, 42 insertions(+), 5 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 182c1cd351c9d..d92b8ac68dbe7 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -10,7 +10,7 @@ import {createView, TensorView} from './tensor-view'; import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; -import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types'; +import {AdapterInfo, ComputeContext, GpuArchitecture, GpuData, GpuVendor, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types'; interface CommandInfo { readonly kernelId: number; @@ -94,11 +94,32 @@ const getProgramInfoUniqueKey = return key; }; +class AdapterInfoImpl implements AdapterInfo { + readonly architecture?: string; + readonly vendor?: string; + + constructor(adapterInfo: GPUAdapterInfo) { + if (adapterInfo) { + this.architecture = adapterInfo.architecture; + this.vendor = adapterInfo.vendor; + } + } + + isArchitecture(architecture: GpuArchitecture): boolean { + return this.architecture === architecture; + } + + isVendor(vendor: GpuVendor): boolean { + return this.vendor === vendor; + } +} + /** * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as * the first parameter so that it is stored for future use. */ export class WebGpuBackend { + adapterInfo: AdapterInfoImpl; device: GPUDevice; /** * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping @@ -212,6 +233,7 @@ export class WebGpuBackend { } this.device = await adapter.requestDevice(deviceDescriptor); + this.adapterInfo = new AdapterInfoImpl(await adapter.requestAdapterInfo()); this.gpuDataManager = createGpuDataManager(this); this.programManager = new ProgramManager(this); this.kernels = new Map(); diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index b64abf9cc5424..4936b94ef7a86 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -10,7 +10,7 @@ import {WebGpuBackend} from './backend-webgpu'; import {LOG_DEBUG} from './log'; import {TensorView} from './tensor-view'; import {ShapeUtil} from './util'; -import {ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types'; +import {AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types'; /* eslint-disable no-bitwise */ @@ -54,6 +54,7 @@ class TensorViewImpl implements TensorView { } class ComputeContextImpl implements ComputeContext { + readonly adapterInfo: AdapterInfo; readonly opKernelContext: number; readonly inputs: readonly TensorView[]; readonly outputCount: number; @@ -66,6 +67,7 @@ class ComputeContextImpl implements ComputeContext { private customDataOffset = 0; private customDataSize = 0; constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) { + this.adapterInfo = backend.adapterInfo; const heapU32 = module.HEAPU32; // extract context data diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 5afec0389fac8..b68d4dcae4cb9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -148,11 +148,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */ const isChannelsLast = attributes.format === 'NHWC'; if (attributes.group !== 1) { - // Temporarily disable createGroupedConvVectorizeProgramInfo path due to bots failures with below two cases: + // NVIDIA GPU with ampere architecture fails with below 2 cases, but we couldn't repro them with any other + // GPUs. So just disable vectorize on NVIDIA ampere to ensure always correct outputs. // [webgpu]Conv - conv - vectorize group - B // [webgpu]Conv - conv - vectorize group - D - const disableGroupedConvVectorize = true; - if (!disableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group && + const enableGroupedConvVectorize = !context.adapterInfo.isArchitecture('ampere'); + if (enableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group && inputs[1].dims[1] === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1) { const outputShape = calculateOutputShape( inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides, diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index ba5b84fcfe067..48e0855f01a97 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -15,6 +15,13 @@ export enum GpuDataType { } export type GpuDataId = number; +export type GpuArchitecture = 'ampere'; +export type GpuVendor = 'amd'|'intel'|'nvidia'; +export interface AdapterInfo { + isArchitecture: (architecture: GpuArchitecture) => boolean; + isVendor: (vendor: GpuVendor) => boolean; +} + export interface GpuData { type: GpuDataType; id: GpuDataId; @@ -146,6 +153,11 @@ export interface ComputeContextInputsOutputsMapping { * A ComputeContext instance carries the states that representing the current running of a kernel. */ export interface ComputeContext { + /** + * gpu adapter info + */ + readonly adapterInfo: AdapterInfo; + /** * stores the pointer to OpKernelContext */ From 1df991198a0bed8284a5a58b74fcecea576f10bd Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Wed, 13 Mar 2024 10:33:14 -0700 Subject: [PATCH 42/51] [JS/WebGPU] Optimize MatMulNBits (#19852) ### Description Use vec<2> or vec<4>, operands in MatMulNBits ### Motivation and Context Improve performance --- .../lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 208 ++++++++++++------ js/web/test/data/ops/matmulnbits.jsonc | 57 +++++ 2 files changed, 194 insertions(+), 71 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index ead7635cf3ac4..9bf5e4066139d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; // TODO support quantization bits not equal to 4 export interface MatMulNBitsAttributes extends AttributeWithCacheKey { @@ -51,29 +51,38 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt export const createMatMulNBitsProgramInfo = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => { - const a = inputs[0]; - const b = inputs[1]; - const scales = inputs[2]; - const aRank = a.dims.length; - const outputShape = a.dims.slice(0, aRank - 1).concat(attributes.n); - const outputSize = ShapeUtil.size(outputShape); - - + const inputShape = inputs[0].dims; + const aRank = inputShape.length; + const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.n); + const m = inputShape[aRank - 2]; + const blobSize = attributes.blockSize / 8 * attributes.bits; + const blobSizeInWords = blobSize / 4; + const outputNumber = getMaxComponents(m); + const components = getMaxComponents(attributes.n); + const aComponents = getMaxComponents(attributes.k); + const bComponents = getMaxComponents(blobSizeInWords); + const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k}, {type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel}, {type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize} ]; - programUniforms.push(...createTensorShapeVariables(a.dims)); - programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(b.dims))); - programUniforms.push(...createTensorShapeVariables(scales.dims)); + const aShape = inputShape.slice(); + aShape.splice(-1, 1, attributes.k / aComponents); + const bShape = ShapeUtil.convertShape(inputs[1].dims).slice(); + bShape.splice(-1, 1, blobSizeInWords / bComponents); + programUniforms.push(...createTensorShapeVariables(aShape)); + programUniforms.push(...createTensorShapeVariables(bShape)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); if (inputs.length === 4) { programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); } - programUniforms.push(...createTensorShapeVariables(outputShape)); + const oShape = outputShape.slice(); + oShape.splice(-1, 1, attributes.n / components); + programUniforms.push(...createTensorShapeVariables(oShape)); const getShaderSource = (shaderHelper: ShaderHelper) => { - const a = inputVariable('a', inputs[0].dataType, inputs[0].dims.length); - const b = inputVariable('b', DataType.uint32, inputs[1].dims.length); + const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); + const b = inputVariable('b', DataType.uint32, bShape.length, bComponents); const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); const inputVariables = [a, b, scales]; const zeroPoints = @@ -81,86 +90,143 @@ export const createMatMulNBitsProgramInfo = if (zeroPoints) { inputVariables.push(zeroPoints); } - const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const output = outputVariable('output', inputs[0].dataType, outputShape.length, components); const uniforms: UniformsArrayType = [ - {name: 'output_size', type: 'u32'}, {name: 'k', type: 'u32'}, {name: 'n', type: 'u32'}, + {name: 'output_size', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'} ]; const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); - const blobSize = attributes.blockSize / 8 * attributes.bits; - const wordPerBlob = blobSize / 4; const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - return ` - fn ortUnpack8x4snorm(value: u32) -> array<${dataType}, 8>{ - var result = array<${dataType}, 8>(); + + const qDqDataType = (() => { + switch (aComponents) { + case 1: + return `array<${dataType}, 8>`; + case 2: + return `mat4x2<${dataType}>`; + case 4: + return `mat2x4<${dataType}>`; + default: + throw new Error(`${aComponents}-component is not supported.`); + } + })(); + + const dequantizeImpl = ` + fn dequantize(quantized: ${qDqDataType}, zero_point: ${dataType}, scale: ${dataType}) -> ${qDqDataType} { + ${(() => { + if (aComponents === 1) { + return `var dequantized = ${qDqDataType}(${ + Array.from({length: 8}, (_, i) => `(quantized[${i}] - zero_point) * scale`).join(', ')}); + return dequantized;`; + } else { + return `var zero_points: ${qDqDataType} = ${qDqDataType}(${Array(8).fill('zero_point').join(',')}); + return (quantized - zero_points) * scale;`; + } + })()} + }`; + const ortUnpack8x4snormImpl = ` + fn ortUnpack8x4snorm(value: u32) -> ${qDqDataType} { + var quantized: ${qDqDataType}; var offset: u32 = 0; let count: u32 = 4; for (var i: u32 = 0; i < 8u; i++) { - result[i] = ${dataType}(extractBits(value, offset, count)); + var result = ${dataType}(extractBits(value, offset, count)); + ${(() => { + switch (aComponents) { + case 1: + return 'quantized[i] = result;'; + case 2: + return 'quantized[i / 2][i % 2] = result;'; + case 4: + return 'quantized[i / 4][i % 4] = result;'; + default: + throw new Error(`${aComponents}-component is not supported.`); + } + })()} offset += count; } - return result; - } + return quantized; + }`; + + const updateZeroPointIndex = zeroPoints ? ` + zero_point_offset += 4; + if (zero_point_offset == 32) { + zero_point_offset = 0; + zero_point_index++; + zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; + }` : + ''; + + return ` + ${dequantizeImpl}; + ${ortUnpack8x4snormImpl}; ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - var value: ${dataType} = 0.0; - let output_indices = ${output.offsetToIndices('global_idx')}; - var a_indices: ${a.type.indices} = output_indices; + var output_values: array<${output.type.value}, ${outputNumber}>; + var output_indices = ${output.offsetToIndices('global_idx')}; var n = ${output.indicesGet('output_indices', aRank - 1)}; + var m = ${output.indicesGet('output_indices', aRank - 2)}; + var a_indices: ${a.type.indices} = output_indices; // Two zero points are packed into one byte because uniforms.bits <= 4. // zero_point_offset is either 0 or 4. It is bit offset within one byte. // TODO support zero_point_offset for bits > 4 ${ zeroPoints ? ` - var zero_point_index: u32 = n * ((${nBlocksPerCol} + 1) / 2) / 4; - var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; - var zero_point_offset: u32 = 0;` : + var zero_point_index: u32 = n * ${components} * ((${nBlocksPerCol} + 1) / 2) / 4; + var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; + var zero_point_offset: u32 = 0;` : ''} - var scale_idex = n * ${nBlocksPerCol}; + var scale_index = n * ${nBlocksPerCol * components}; var b_indices: ${b.type.indices}; - ${b.indicesSet('b_indices', '0', 'n')}; - var block_offset: u32 = 0; - for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) { - // The scale and zero points are computed per block. - let scale = ${scales.getByOffset('scale_idex')}; - // The default zero point is 8 for unsigned 4-bit quantization. - let zero_point: ${dataType} = ${ - zeroPoints ? `${dataType}(extractBits(zero_point_word, zero_point_offset, 4))` : 8.0}; - ${b.indicesSet('b_indices', '1', 'block')}; - var word_offset: u32 = block_offset; - for (var word: u32 = 0; word < ${wordPerBlob}; word++) { - ${b.indicesSet('b_indices', '2', 'word')}; - let b_value = ${b.getByIndices('b_indices')}; - let b_quantized_values: array<${dataType}, 8> = ortUnpack8x4snorm(b_value); - // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 - var offset: u32 = word_offset; - for (var i: u32 = 0; i < 8; i++) { - ${a.indicesSet('a_indices', aRank - 1, 'offset')}; - let a_value = ${a.getByIndices('a_indices')}; - let b_quantized_value = b_quantized_values[i]; - let b_dequantized_value = (b_quantized_value - zero_point) * scale; - value += a_value * b_dequantized_value; - offset++; + for (var c: u32 = 0; c < ${components}; c++) { + ${b.indicesSet('b_indices', '0', `n * ${components} + c`)}; + var block_offset: u32 = 0; + for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) { + // The scale and zero points are computed per block. + let scale = ${scales.getByOffset('scale_index')}; + // The default zero point is 8 for unsigned 4-bit quantization. + let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0}); + ${b.indicesSet('b_indices', '1', 'block')}; + var word_offset: u32 = block_offset; + for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) { + ${b.indicesSet('b_indices', '2', 'word')}; + let b_data = ${b.getByIndices('b_indices')}; + for (var i: u32 = 0; i < ${bComponents}; i++) { + let b_value = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'}; + let b_quantized_values: ${qDqDataType} = ortUnpack8x4snorm(b_value); + let b_dequantized_values = dequantize(b_quantized_values, zero_point, scale); + // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 + var offset: u32 = word_offset; + for (var j: u32 = 0; j < 8/${aComponents}; j++) { + ${a.indicesSet('a_indices', aRank - 1, `offset/${aComponents}`)}; + for (var k: u32 = 0; k < ${outputNumber}u; k++) { + ${a.indicesSet('a_indices', aRank - 2, `m * ${outputNumber} + k`)}; + let a_data = ${a.getByIndices('a_indices')}; + output_values[k]${components > 1 ? '[c]' : ''} += ${ + aComponents === 1 ? 'a_data * b_dequantized_values[j]' : 'dot(a_data, b_dequantized_values[j])'}; + } + offset += ${aComponents}; + } + word_offset += 8; + } } - word_offset += 8; + scale_index++; + ${updateZeroPointIndex} + block_offset += uniforms.block_size; } - scale_idex++; + // Drop the trailing 4 bits if the zero_poit_offset is not a byte boundary to align with the next byte. ${ - zeroPoints ? ` - if (zero_point_offset == 28) { - zero_point_offset = 0; - zero_point_index++; - zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; - } else { - zero_point_offset += 4; - }` : + zeroPoints ? `if (zero_point_offset % 8 > 0) { + ${updateZeroPointIndex} + }` : ''} - block_offset += uniforms.block_size; - } - ${output.setByOffset('global_idx', 'value')}; - } - `; + } + for (var k: u32 = 0u; k < ${outputNumber}u; k++) { + ${output.indicesSet('output_indices', aRank - 2, `${outputNumber + ' * m + k'}`)}; + ${output.setByIndices('output_indices', 'output_values[k]')} + } + }`; }; return { name: 'MatMulNBits', @@ -168,7 +234,7 @@ export const createMatMulNBitsProgramInfo = {hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64)}, + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms }), getShaderSource diff --git a/js/web/test/data/ops/matmulnbits.jsonc b/js/web/test/data/ops/matmulnbits.jsonc index c57c431afb3ce..175be78cc0818 100644 --- a/js/web/test/data/ops/matmulnbits.jsonc +++ b/js/web/test/data/ops/matmulnbits.jsonc @@ -1,4 +1,61 @@ [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ], + "dims": [8, 16], + "type": "float32" + }, + { + "dims": [8, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ] + }, + { + "dims": [8], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0, + -1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, + 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232, + -11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032, + -16405, -48288, -16247 + ] + } + ] + } + ] + }, { "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", "operator": "MatMulNBits", From 648cc425d2eddbe3e73a2fc45d8939a09120ebc1 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 15 Mar 2024 11:47:45 -0700 Subject: [PATCH 43/51] [js/web] rewrite backend resolve to allow multiple EPs (#19735) ### Description This PR rewrite the backend resolve logic to support specifying multiple EPs. #### Backend The first version of ONNX Runtime Web actually carried some existing code from [ONNX.js](https://github.com/microsoft/onnxjs), which includes the "backend" concept. The original "backend" in ONNX.js is designed in a way assuming there is only one backend from user's backend hint list will be used. For example, in ONNX.js, if user specify a backend hint as `['webgl', 'wasm']`, ONNX.js will first try to use WebGL backend - if it loads successfully (the browser supports webgl), then "webgl" backend will be used and "wasm" will be ignored; otherwise, "webgl" will be ignored and try to load "wasm" backend. In short: only one backend will be used when initializing a session. #### Execution Provider Execution Provider, or EP, in ONNX Runtime is a different concept. One of the differences is that users are allow to specify multiple EPs, and if one does not support a particular kernel, it can fallback to other EP. This is a very common case when using a GPU EP in ONNX Runtime. #### Current Status: Backend v.s. EP Because of the history reasons mentioned above, the current status is quite confusing. There are **real backend**s, which means it's different implementation in code; and there are **backend hint**s, which are used as string names for backend hint; and there are **EP**s of the ONNX Runtime concepts. currently there are only 2 **backend**s in our code base: The "onnxjs backend", and the "wasm backend". The "onnxjs backend" currently only powers backend hint "webgl", which go into the old onnx.js code path. All other backend hints including "wasm", "cpu"(alias to wasm), "webgpu" and "webnn" are all powered by "wasm backend". And because ORT Web treat "backend" as an internal concept and want to align with ONNX Runtime, so those names of backend hints are becoming EP names. The following table shows today's status: | Execution Provider Name (public) / Backend Hint (internal) | Backend | EP in ORT | -------- | ------- | ------- | | "wasm"/"cpu" | WasmBackend | CPU EP | "webgl" | OnnxjsBackend | \* technically not an EP | "webgpu" | WasmBackend | JSEP | "webnn" | WasmBackend | WebNN EP #### Problem While the API allows to specify multiple EPs, the backend resolving only allows one backend. This causes issues when user specify multiple EP names in session options, the backend resolve behavior and EP registration behavior is inconsistent. Specifically, in this issue: https://github.com/microsoft/onnxruntime/issues/15796#issuecomment-1925363908: EP list `['webgpu', 'wasm']` on a browser without WebGPU support resolves to 'wasm' backend, but the full EP list is passed in session options, so JSEP is still enabled, causing the runtime error. #### Solution Since we still need WebGL backend, we cannot totally remove the backend register/resolve system. In this PR I made the following changes: - initialize every backend from the EP list, instead of only do that for the first successful one. - for the first resolved backend, filter all EP using the exact same backend. Remove all EPs not using this backend from session options - for every explicitly specified EP, if it's removed, show a warning message in console --- js/common/lib/backend-impl.ts | 121 +++++++++--- js/common/lib/inference-session-impl.ts | 10 +- js/common/lib/training-session-impl.ts | 11 +- js/web/lib/wasm/binding/ort-wasm.d.ts | 240 +++++++++++++----------- js/web/lib/wasm/jsep/init.ts | 38 ++-- js/web/lib/wasm/proxy-wrapper.ts | 2 +- js/web/lib/wasm/wasm-core-impl.ts | 76 +++++--- onnxruntime/wasm/js_internal_api.js | 82 ++++---- 8 files changed, 348 insertions(+), 232 deletions(-) diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts index 3e1e833addb91..e90efd7b97c29 100644 --- a/js/common/lib/backend-impl.ts +++ b/js/common/lib/backend-impl.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {Backend} from './backend.js'; +import {InferenceSession} from './inference-session.js'; interface BackendInfo { backend: Backend; @@ -10,6 +11,7 @@ interface BackendInfo { initPromise?: Promise; initialized?: boolean; aborted?: boolean; + error?: string; } const backends: Map = new Map(); @@ -60,43 +62,100 @@ export const registerBackend = (name: string, backend: Backend, priority: number }; /** - * Resolve backend by specified hints. + * Try to resolve and initialize a backend. * - * @param backendHints - a list of execution provider names to lookup. If omitted use registered backends as list. - * @returns a promise that resolves to the backend. + * @param backendName - the name of the backend. + * @returns the backend instance if resolved and initialized successfully, or an error message if failed. + */ +const tryResolveAndInitializeBackend = async(backendName: string): Promise => { + const backendInfo = backends.get(backendName); + if (!backendInfo) { + return 'backend not found.'; + } + + if (backendInfo.initialized) { + return backendInfo.backend; + } else if (backendInfo.aborted) { + return backendInfo.error!; + } else { + const isInitializing = !!backendInfo.initPromise; + try { + if (!isInitializing) { + backendInfo.initPromise = backendInfo.backend.init(backendName); + } + await backendInfo.initPromise; + backendInfo.initialized = true; + return backendInfo.backend; + } catch (e) { + if (!isInitializing) { + backendInfo.error = `${e}`; + backendInfo.aborted = true; + } + return backendInfo.error!; + } finally { + delete backendInfo.initPromise; + } + } +}; + +/** + * Resolve execution providers from the specific session options. + * + * @param options - the session options object. + * @returns a promise that resolves to a tuple of an initialized backend instance and a session options object with + * filtered EP list. * * @ignore */ -export const resolveBackend = async(backendHints: readonly string[]): Promise => { - const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints; - const errors = []; - for (const backendName of backendNames) { - const backendInfo = backends.get(backendName); - if (backendInfo) { - if (backendInfo.initialized) { - return backendInfo.backend; - } else if (backendInfo.aborted) { - continue; // current backend is unavailable; try next - } +export const resolveBackendAndExecutionProviders = async(options: InferenceSession.SessionOptions): + Promise<[backend: Backend, options: InferenceSession.SessionOptions]> => { + // extract backend hints from session options + const eps = options.executionProviders || []; + const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); + const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints; - const isInitializing = !!backendInfo.initPromise; - try { - if (!isInitializing) { - backendInfo.initPromise = backendInfo.backend.init(backendName); + // try to resolve and initialize all requested backends + let backend: Backend|undefined; + const errors = []; + const availableBackendNames = new Set(); + for (const backendName of backendNames) { + const resolveResult = await tryResolveAndInitializeBackend(backendName); + if (typeof resolveResult === 'string') { + errors.push({name: backendName, err: resolveResult}); + } else { + if (!backend) { + backend = resolveResult; + } + if (backend === resolveResult) { + availableBackendNames.add(backendName); + } } - await backendInfo.initPromise; - backendInfo.initialized = true; - return backendInfo.backend; - } catch (e) { - if (!isInitializing) { - errors.push({name: backendName, err: e}); + } + + // if no backend is available, throw error. + if (!backend) { + throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`); + } + + // for each explicitly requested backend, if it's not available, output warning message. + for (const {name, err} of errors) { + if (backendHints.includes(name)) { + // eslint-disable-next-line no-console + console.warn(`removing requested execution provider "${ + name}" from session options because it is not available: ${err}`); } - backendInfo.aborted = true; - } finally { - delete backendInfo.initPromise; } - } - } - throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`); -}; + const filteredEps = eps.filter(i => availableBackendNames.has(typeof i === 'string' ? i : i.name)); + + return [ + backend, new Proxy(options, { + get: (target, prop) => { + if (prop === 'executionProviders') { + return filteredEps; + } + return Reflect.get(target, prop); + } + }) + ]; + }; diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts index 55f40c8907a89..ab4c6a3e0c46b 100644 --- a/js/common/lib/inference-session-impl.ts +++ b/js/common/lib/inference-session-impl.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {resolveBackend} from './backend-impl.js'; +import {resolveBackendAndExecutionProviders} from './backend-impl.js'; import {InferenceSessionHandler} from './backend.js'; import {InferenceSession as InferenceSessionInterface} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; @@ -195,11 +195,9 @@ export class InferenceSession implements InferenceSessionInterface { throw new TypeError('Unexpected argument[0]: must be \'path\' or \'buffer\'.'); } - // get backend hints - const eps = options.executionProviders || []; - const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); - const backend = await resolveBackend(backendHints); - const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options); + // resolve backend, update session options with validated EPs, and create session handler + const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options); + const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, optionsWithValidatedEPs); TRACE_FUNC_END(); return new InferenceSession(handler); } diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index 23bd4421ae672..bae38b0dfda5a 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {resolveBackend} from './backend-impl.js'; +import {resolveBackendAndExecutionProviders} from './backend-impl.js'; import {SessionHandler, TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; @@ -55,13 +55,12 @@ export class TrainingSession implements TrainingSessionInterface { const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || ''; const options: SessionOptions = sessionOptions || {}; - // get backend hints - const eps = options.executionProviders || []; - const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); - const backend = await resolveBackend(backendHints); + // resolve backend, update session options with validated EPs, and create session handler + const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options); if (backend.createTrainingSessionHandler) { const handler = await backend.createTrainingSessionHandler( - trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options); + trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, + optionsWithValidatedEPs); return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel); } else { throw new Error(noBackendErrMsg); diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 5dd715191c830..56925b728e9a3 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -16,20 +16,97 @@ export declare namespace JSEP { type CaptureBeginFunction = () => void; type CaptureEndFunction = () => void; type ReplayFunction = () => void; -} -export interface OrtWasmModule extends EmscriptenModule { - // #region emscripten functions - stackSave(): number; - stackRestore(stack: number): void; - stackAlloc(size: number): number; - - UTF8ToString(offset: number, maxBytesToRead?: number): string; - lengthBytesUTF8(str: string): number; - stringToUTF8(str: string, offset: number, maxBytes: number): void; - // #endregion + export interface Module extends WebGpuModule { + /** + * Mount the external data file to an internal map, which will be used during session initialization. + * + * @param externalDataFilePath - specify the relative path of the external data file. + * @param externalDataFileData - specify the content data. + */ + mountExternalData(externalDataFilePath: string, externalDataFileData: Uint8Array): void; + /** + * Unmount all external data files from the internal map. + */ + unmountExternalData(): void; + + /** + * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime per + * backend. This function initializes Asyncify support. If name is 'webgpu', also initializes WebGPU backend and + * registers a few callbacks that will be called in C++ code. + */ + jsepInit(name: 'webgpu', initParams: [ + backend: BackendType, alloc: AllocFunction, free: FreeFunction, upload: UploadFunction, + download: DownloadFunction, createKernel: CreateKernelFunction, releaseKernel: ReleaseKernelFunction, + run: RunFunction, captureBegin: CaptureBeginFunction, captureEnd: CaptureEndFunction, replay: ReplayFunction + ]): void; + jsepInit(name: 'webnn', initParams?: never): void; + } + + export interface WebGpuModule { + /** + * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). + * + * @param context - specify the kernel context pointer. + * @param index - specify the index of the output. + * @param data - specify the pointer to encoded data of type and dims. + */ + _JsepOutput(context: number, index: number, data: number): number; + /** + * [exported from wasm] Get name of an operator node. + * + * @param kernel - specify the kernel pointer. + * @returns the pointer to a C-style UTF8 encoded string representing the node name. + */ + _JsepGetNodeName(kernel: number): number; + + /** + * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output. + * + * @param sessionId - specify the session ID. + * @param index - specify an integer to represent which input/output it is registering for. For input, it is the + * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index + * corresponding to the session's ouputNames. + * @param buffer - specify the GPU buffer to register. + * @param size - specify the original data size in byte. + * @returns the GPU data ID for the registered GPU buffer. + */ + jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; + /** + * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. + * + * @param dataId - specify the GPU data ID + * @returns the GPU buffer. + */ + jsepGetBuffer: (dataId: number) => GPUBuffer; + /** + * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor. + * + * @param gpuBuffer - specify the GPU buffer + * @param size - specify the original data size in byte. + * @param type - specify the tensor type. + * @returns the generated downloader function. + */ + jsepCreateDownloader: + (gpuBuffer: GPUBuffer, size: number, + type: Tensor.GpuBufferDataTypes) => () => Promise; + /** + * [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before + * _OrtRun[WithBinding]() is called. + * @param sessionId - specify the session ID. + */ + jsepOnRunStart: (sessionId: number) => void; + /** + * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is + * called. + * @param sessionId - specify the session ID. + * @returns + */ + jsepOnReleaseSession: (sessionId: number) => void; + } +} - // #region ORT APIs +export interface OrtInferenceAPIs { _OrtInit(numThreads: number, loggingLevel: number): number; _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void; @@ -74,126 +151,61 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtReleaseRunOptions(runOptionsHandle: number): void; _OrtEndProfiling(sessionHandle: number): number; - // #endregion +} + +export interface OrtTrainingAPIs { + _OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number; - // #region ORT Training APIs - _OrtTrainingLoadCheckpoint?(dataOffset: number, dataLength: number): number; + _OrtTrainingReleaseCheckpoint(checkpointHandle: number): void; - _OrtTrainingReleaseCheckpoint?(checkpointHandle: number): void; + _OrtTrainingCreateSession( + sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number, + evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number; - _OrtTrainingCreateSession? - (sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number, - evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number; + _OrtTrainingLazyResetGrad(trainingHandle: number): number; - _OrtTrainingLazyResetGrad?(trainingHandle: number): number; + _OrtTrainingRunTrainStep( + trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, + runOptionsHandle: number): number; - _OrtTrainingRunTrainStep? - (trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, - runOptionsHandle: number): number; + _OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number; - _OrtTrainingOptimizerStep?(trainingHandle: number, runOptionsHandle: number): number; + _OrtTrainingEvalStep( + trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, + runOptionsHandle: number): number; - _OrtTrainingEvalStep? - (trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number, - runOptionsHandle: number): number; + _OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number; - _OrtTrainingGetParametersSize?(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number; + _OrtTrainingCopyParametersToBuffer( + trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingCopyParametersToBuffer? - (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + _OrtTrainingCopyParametersFromBuffer( + trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; - _OrtTrainingCopyParametersFromBuffer? - (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + _OrtTrainingGetModelInputOutputCount( + trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; + _OrtTrainingGetModelInputOutputName(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): + number; + + _OrtTrainingReleaseSession(trainingHandle: number): void; +} - _OrtTrainingGetModelInputOutputCount? - (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; - _OrtTrainingGetModelInputOutputName? - (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; +export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial, + Partial { + // #region emscripten functions + stackSave(): number; + stackRestore(stack: number): void; + stackAlloc(size: number): number; - _OrtTrainingReleaseSession?(trainingHandle: number): void; + UTF8ToString(offset: number, maxBytesToRead?: number): string; + lengthBytesUTF8(str: string): number; + stringToUTF8(str: string, offset: number, maxBytes: number): void; // #endregion // #region config numThreads?: number; mainScriptUrlOrBlob?: string|Blob; // #endregion - - // #region external data API - mountExternalData?(externalDataFilePath: string, externalDataFileData: Uint8Array): void; - unmountExternalData?(): void; - // #endregion - - // #region JSEP - /** - * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime. - * This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code. - */ - jsepInit? - (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, - download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, - releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction, - captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void; - - /** - * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). - * - * @param context - specify the kernel context pointer. - * @param index - specify the index of the output. - * @param data - specify the pointer to encoded data of type and dims. - */ - _JsepOutput(context: number, index: number, data: number): number; - /** - * [exported from wasm] Get name of an operator node. - * - * @param kernel - specify the kernel pointer. - * @returns the pointer to a C-style UTF8 encoded string representing the node name. - */ - _JsepGetNodeName(kernel: number): number; - - /** - * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output. - * - * @param sessionId - specify the session ID. - * @param index - specify an integer to represent which input/output it is registering for. For input, it is the - * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index - * corresponding to the session's ouputNames. - * @param buffer - specify the GPU buffer to register. - * @param size - specify the original data size in byte. - * @returns the GPU data ID for the registered GPU buffer. - */ - jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; - /** - * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. - * - * @param dataId - specify the GPU data ID - * @returns the GPU buffer. - */ - jsepGetBuffer: (dataId: number) => GPUBuffer; - /** - * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor. - * - * @param gpuBuffer - specify the GPU buffer - * @param size - specify the original data size in byte. - * @param type - specify the tensor type. - * @returns the generated downloader function. - */ - jsepCreateDownloader: - (gpuBuffer: GPUBuffer, size: number, - type: Tensor.GpuBufferDataTypes) => () => Promise; - /** - * [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before - * _OrtRun[WithBinding]() is called. - * @param sessionId - specify the session ID. - */ - jsepOnRunStart: (sessionId: number) => void; - /** - * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is - * called. - * @param sessionId - specify the session ID. - * @returns - */ - jsepOnReleaseSession: (sessionId: number) => void; - // #endregion } declare const moduleFactory: EmscriptenModuleFactory; diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 4936b94ef7a86..adcaa145cdca8 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -121,7 +121,7 @@ class ComputeContextImpl implements ComputeContext { for (let i = 0; i < dims.length; i++) { this.module.HEAPU32[offset++] = dims[i]; } - return this.module._JsepOutput(this.opKernelContext, index, data); + return this.module._JsepOutput!(this.opKernelContext, index, data); } catch (e) { throw new Error( `Failed to generate kernel's output[${index}] with dims [${dims}]. ` + @@ -136,27 +136,39 @@ class ComputeContextImpl implements ComputeContext { /** * Initialize JSEP with WebGPU backend. * - * This function will be called only once after the WebAssembly module is loaded and initialized ("_OrtInit" is called). - * This function expects: + * This function will be called after the WebAssembly module is loaded and initialized ("_OrtInit" is called), once for + * each of the following EPs if they are specified: + * - "webgpu" + * - "webnn" + * + * For WebGPU, this function expects: * - WebGPU is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false). * - WebGPU is available in current environment. (a valid GPUAdapter is passed in) + * + * For WebNN, this function expects: + * - WebNN is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false). + * - WebNN is available in current environment. (navigator.ml is not undefined) + * * If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate - * 'webgpu' backend. + * 'webgpu'/'webnn' backend. * + * @param name - the name of the EP, either "webgpu" or "webnn" * @param module - the ORT WebAssembly module * @param env - the ORT environment variable (ort.env) * @param gpuAdapter - the pre-created GPU adapter */ -export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapter): Promise => { +export const init = + async(name: 'webgpu'|'webnn', module: OrtWasmModule, env: Env, gpuAdapter?: GPUAdapter): Promise => { const jsepInit = module.jsepInit; if (!jsepInit) { throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.'); } - const backend = new WebGpuBackend(); - await backend.initialize(env, gpuAdapter); + if (name === 'webgpu') { + const backend = new WebGpuBackend(); + await backend.initialize(env, gpuAdapter!); - jsepInit( + jsepInit('webgpu', [ // backend backend, @@ -190,8 +202,8 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte }, // jsepCreateKernel - (kernelType: string, kernelId: number, attribute: unknown) => - backend.createKernel(kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName(kernelId))), + (kernelType: string, kernelId: number, attribute: unknown) => backend.createKernel( + kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))), // jsepReleaseKernel (kernel: number) => backend.releaseKernel(kernel), @@ -210,5 +222,9 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte // jsepCaptureEnd () => backend.captureEnd(), // jsepReplay - () => backend.replay()); + () => backend.replay() + ]); + } else { + jsepInit('webnn'); + } }; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 86017a4ec6904..6ff4e86b1235e 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -155,7 +155,7 @@ export const createSession = ensureWorker(); return new Promise((resolve, reject) => { enqueueCallbacks('create', [resolve, reject]); - const message: OrtWasmMessage = {type: 'create', in : {model, options}}; + const message: OrtWasmMessage = {type: 'create', in : {model, options: {...options}}}; const transferable: Transferable[] = []; if (model instanceof Uint8Array) { transferable.push(model.buffer); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index afab9ba00b0c4..7019758be0efd 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -84,35 +84,44 @@ export const initRuntime = async(env: Env): Promise => { * @param epName */ export const initEp = async(env: Env, epName: string): Promise => { - if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) { - // perform WebGPU availability check - if (typeof navigator === 'undefined' || !navigator.gpu) { - throw new Error('WebGPU is not supported in current environment'); - } - const powerPreference = env.webgpu?.powerPreference; - if (powerPreference !== undefined && powerPreference !== 'low-power' && powerPreference !== 'high-performance') { - throw new Error(`Invalid powerPreference setting: "${powerPreference}"`); - } - const forceFallbackAdapter = env.webgpu?.forceFallbackAdapter; - if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') { - throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`); - } - const adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter}); - if (!adapter) { - throw new Error( - 'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); - } + if (!BUILD_DEFS.DISABLE_WEBGPU) { + // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires + const initJsep = require('./jsep/init').init; - if (!env.wasm.simd) { - throw new Error( - 'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP'); - } + if (epName === 'webgpu') { + // perform WebGPU availability check + if (typeof navigator === 'undefined' || !navigator.gpu) { + throw new Error('WebGPU is not supported in current environment'); + } + const powerPreference = env.webgpu?.powerPreference; + if (powerPreference !== undefined && powerPreference !== 'low-power' && powerPreference !== 'high-performance') { + throw new Error(`Invalid powerPreference setting: "${powerPreference}"`); + } + const forceFallbackAdapter = env.webgpu?.forceFallbackAdapter; + if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') { + throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`); + } + const adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter}); + if (!adapter) { + throw new Error( + 'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); + } - // init JSEP if available + if (!env.wasm.simd) { + throw new Error( + 'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP'); + } - // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires - const initJsep = require('./jsep/init').init; - await initJsep(getInstance(), env, adapter); + await initJsep('webgpu', getInstance(), env, adapter); + } + if (epName === 'webnn') { + // perform WebNN availability check + if (typeof navigator === 'undefined' || !(navigator as unknown as {ml: unknown}).ml) { + throw new Error('WebNN is not supported in current environment'); + } + + await initJsep('webnn', getInstance(), env); + } } }; @@ -380,7 +389,12 @@ export const prepareInputOutputTensor = const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; - rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); + + const registerBuffer = wasm.jsepRegisterBuffer; + if (!registerBuffer) { + throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); + } + rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); } else { const data = tensor[2]; @@ -595,7 +609,11 @@ export const run = async( // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU // tensor for it. There is no mapping GPU buffer for an empty tensor. if (preferredLocation === 'gpu-buffer' && size > 0) { - const gpuBuffer = wasm.jsepGetBuffer(dataOffset); + const getBuffer = wasm.jsepGetBuffer; + if (!getBuffer) { + throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.'); + } + const gpuBuffer = getBuffer(dataOffset); const elementSize = getTensorElementSize(dataType); if (elementSize === undefined || !isGpuBufferSupportedType(type)) { throw new Error(`Unsupported data type: ${type}`); @@ -607,7 +625,7 @@ export const run = async( output.push([ type, dims, { gpuBuffer, - download: wasm.jsepCreateDownloader(gpuBuffer, size * elementSize, type), + download: wasm.jsepCreateDownloader!(gpuBuffer, size * elementSize, type), dispose: () => { wasm._OrtReleaseTensor(tensor); } diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index cbc60c70b57aa..90d8b737252e5 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -4,39 +4,27 @@ 'use strict'; /** - * Mount external data files of a model to the virtual file system (MEMFS). + * Mount external data files of a model to an internal map, which will be used during session initialization. * * @param {string} externalDataFilesPath * @param {Uint8Array} externalDataFilesData */ Module['mountExternalData'] = (externalDataFilePath, externalDataFileData) => { const files = Module.MountedFiles || (Module.MountedFiles = new Map()); - files.set(externalDataFilePath, externalDataFileData); + files.set(externalDataFilePath, externalDataFileData); }; /** - * Unmount external data files of a model from the virtual file system (MEMFS). + * Unmount external data files of a model. */ Module['unmountExternalData'] = () => { delete Module.MountedFiles; }; /** - * init JSEP + * initialize JSEP for asyncify support. */ -Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel, captureBegin, captureEnd, replay) => { - Module.jsepBackend = backend; - Module.jsepAlloc = alloc; - Module.jsepFree = free; - Module.jsepCopy = copy; - Module.jsepCopyAsync = copyAsync; - Module.jsepCreateKernel = createKernel; - Module.jsepReleaseKernel = releaseKernel; - Module.jsepRunKernel = runKernel; - Module.jsepCaptureBegin = captureBegin; - Module.jsepCaptureEnd = captureEnd; - Module.jsepReplay = replay; - +let jsepInitAsync = () => { // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) // It removes some overhead in cwarp() and ccall() that we don't need. // @@ -143,7 +131,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea } // Flush the backend. This will submit all pending commands to the GPU. - backend['flush'](); + Module.jsepBackend?.['flush'](); // Await all pending promises. This includes GPU validation promises for diagnostic purposes. const errorPromises = state.errors; @@ -180,20 +168,46 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea () => Module['_OrtBindInput'], v => Module['_OrtBindInput'] = v); - // expose webgpu backend functions - Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { - return backend['registerBuffer'](sessionId, index, buffer, size); - }; - Module['jsepGetBuffer'] = (dataId) => { - return backend['getBuffer'](dataId); - }; - Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { - return backend['createDownloader'](gpuBuffer, size, type); - }; - Module['jsepOnReleaseSession'] = sessionId => { - backend['onReleaseSession'](sessionId); - }; - Module['jsepOnRunStart'] = sessionId => { - return backend['onRunStart'](sessionId); - }; + // remove this function to make sure it is called only once. + jsepInitAsync = undefined; +}; + + +/** + * initialize JSEP for WebGPU. + */ +Module['jsepInit'] = (name, params) => { + jsepInitAsync?.(); + + if (name === 'webgpu') { + [Module.jsepBackend, + Module.jsepAlloc, + Module.jsepFree, + Module.jsepCopy, + Module.jsepCopyAsync, + Module.jsepCreateKernel, + Module.jsepReleaseKernel, + Module.jsepRunKernel, + Module.jsepCaptureBegin, + Module.jsepCaptureEnd, + Module.jsepReplay] = params; + + // expose webgpu backend functions + const backend = Module.jsepBackend; + Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { + return backend['registerBuffer'](sessionId, index, buffer, size); + }; + Module['jsepGetBuffer'] = (dataId) => { + return backend['getBuffer'](dataId); + }; + Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { + return backend['createDownloader'](gpuBuffer, size, type); + }; + Module['jsepOnReleaseSession'] = sessionId => { + backend['onReleaseSession'](sessionId); + }; + Module['jsepOnRunStart'] = sessionId => { + return backend['onRunStart'](sessionId); + }; + } }; From 0208b1e9aab22b715ed3679b6f90692ca838bea8 Mon Sep 17 00:00:00 2001 From: Belem Zhang Date: Sat, 16 Mar 2024 10:00:30 +0800 Subject: [PATCH 44/51] Fix #19931 broken Get Started link of "ONNX Runtime JavaScript API" page (#19932) ### Description Fix #19931 broken Get Started link HTTP 404 for "Get Started" link in "ONNX Runtime JavaScript API" page Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- js/common/lib/index.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts index d7c98380f3fa4..18cc2aba03f63 100644 --- a/js/common/lib/index.ts +++ b/js/common/lib/index.ts @@ -11,7 +11,7 @@ * - [onnxruntime-react-native](https://www.npmjs.com/package/onnxruntime-react-native) * * See also: - * - [Get Started](https://onnxruntime.ai/docs/get-started/with-javascript.html) + * - [Get Started](https://onnxruntime.ai/docs/get-started/with-javascript/) * - [Inference examples](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/js) * * @packageDocumentation From 0ece4eff1909b309f8dae5a3a59d0a1eb7bb9f5f Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 15 Mar 2024 19:01:50 -0700 Subject: [PATCH 45/51] [js/common] fix typedoc warnings (#19933) ### Description Fix a few warnings in typedoc (for generating JS API): ``` [warning] The signature TrainingSession.loadParametersBuffer has an @param with name "buffer", which was not used. [warning] NonTensorType, defined in ./lib/onnx-value.ts, is referenced by OnnxValue but not included in the documentation. [warning] TensorFactory, defined in ./lib/tensor-factory.ts, is referenced by Tensor but not included in the documentation. [warning] ExternalDataFileType, defined in ./lib/onnx-model.ts, is referenced by InferenceSession.SessionOptions.externalData but not included in the documentation. [warning] TensorToDataUrlOptions, defined in ./lib/tensor-conversion.ts, is referenced by Tensor.toDataURL.toDataURL.options but not included in the documentation. [warning] TensorToImageDataOptions, defined in ./lib/tensor-conversion.ts, is referenced by Tensor.toImageData.toImageData.options but not included in the documentation. [warning] Failed to resolve link to "GpuBufferType" in comment for Env.WebGpuFlags.adapter. [warning] Failed to resolve link to "GpuBufferType" in comment for Env.WebGpuFlags.device. ``` Changes highlighted: - Merge `CoreMlExecutionProviderOption` and `CoreMLExecutionProviderOption`. They expose 2 set of different options for React-native and ORT nodejs binding. This should be fixed in future. - Fix a few inconsistency of names between JSDoc and parameters - Fix broken type links - Exclude trace functions --- js/common/lib/backend.ts | 6 +-- js/common/lib/env.ts | 4 +- js/common/lib/index.ts | 3 ++ js/common/lib/inference-session.ts | 43 +++++++++++++++---- js/common/lib/onnx-value.ts | 2 +- js/common/lib/tensor-factory.ts | 2 +- js/common/lib/tensor.ts | 4 +- js/common/lib/trace.ts | 9 ++++ js/common/lib/training-session.ts | 16 +++---- .../templates/linux-web-init-and-check.yml | 4 ++ 10 files changed, 68 insertions(+), 25 deletions(-) diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index 9bfcb12206057..8c07bdd5c5c4a 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -58,7 +58,7 @@ export interface TrainingSessionHandler extends SessionHandler { options: InferenceSession.RunOptions): Promise; getParametersSize(trainableOnly: boolean): Promise; - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise; getContiguousParameters(trainableOnly: boolean): Promise; } @@ -77,8 +77,8 @@ export interface Backend { Promise; createTrainingSessionHandler? - (checkpointStateUriOrBuffer: TrainingSession.URIorBuffer, trainModelUriOrBuffer: TrainingSession.URIorBuffer, - evalModelUriOrBuffer: TrainingSession.URIorBuffer, optimizerModelUriOrBuffer: TrainingSession.URIorBuffer, + (checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer, trainModelUriOrBuffer: TrainingSession.UriOrBuffer, + evalModelUriOrBuffer: TrainingSession.UriOrBuffer, optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer, options: InferenceSession.SessionOptions): Promise; } diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index dd8bde2b596f4..b139c719e863f 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -173,7 +173,7 @@ export declare namespace Env { * When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types". * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type. * - * see comments on {@link GpuBufferType} + * see comments on {@link Tensor.GpuBufferType} */ readonly adapter: unknown; /** @@ -184,7 +184,7 @@ export declare namespace Env { * When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types". * Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type. * - * see comments on {@link GpuBufferType} for more details about why not use types defined in "@webgpu/types". + * see comments on {@link Tensor.GpuBufferType} for more details about why not use types defined in "@webgpu/types". */ readonly device: unknown; /** diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts index 18cc2aba03f63..3ed56b3c2e812 100644 --- a/js/common/lib/index.ts +++ b/js/common/lib/index.ts @@ -21,6 +21,9 @@ export * from './backend.js'; export * from './env.js'; export * from './inference-session.js'; export * from './tensor.js'; +export * from './tensor-conversion.js'; +export * from './tensor-factory.js'; export * from './trace.js'; +export * from './onnx-model.js'; export * from './onnx-value.js'; export * from './training-session.js'; diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 4f85c3b46e253..4f7fbdcdcf0ca 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -186,22 +186,22 @@ export declare namespace InferenceSession { // #region execution providers // Currently, we have the following backends to support execution providers: - // Backend Node.js binding: supports 'cpu' and 'cuda'. + // Backend Node.js binding: supports 'cpu', 'dml' (win32), 'coreml' (macOS) and 'cuda' (linux). // Backend WebAssembly: supports 'cpu', 'wasm', 'webgpu' and 'webnn'. // Backend ONNX.js: supports 'webgl'. // Backend React Native: supports 'cpu', 'xnnpack', 'coreml' (iOS), 'nnapi' (Android). interface ExecutionProviderOptionMap { + coreml: CoreMLExecutionProviderOption; cpu: CpuExecutionProviderOption; - coreml: CoreMlExecutionProviderOption; cuda: CudaExecutionProviderOption; dml: DmlExecutionProviderOption; + nnapi: NnapiExecutionProviderOption; tensorrt: TensorRtExecutionProviderOption; wasm: WebAssemblyExecutionProviderOption; webgl: WebGLExecutionProviderOption; - xnnpack: XnnpackExecutionProviderOption; webgpu: WebGpuExecutionProviderOption; webnn: WebNNExecutionProviderOption; - nnapi: NnapiExecutionProviderOption; + xnnpack: XnnpackExecutionProviderOption; } type ExecutionProviderName = keyof ExecutionProviderOptionMap; @@ -219,10 +219,6 @@ export declare namespace InferenceSession { readonly name: 'cuda'; deviceId?: number; } - export interface CoreMlExecutionProviderOption extends ExecutionProviderOption { - readonly name: 'coreml'; - coreMlFlags?: number; - } export interface DmlExecutionProviderOption extends ExecutionProviderOption { readonly name: 'dml'; deviceId?: number; @@ -253,8 +249,39 @@ export declare namespace InferenceSession { } export interface CoreMLExecutionProviderOption extends ExecutionProviderOption { readonly name: 'coreml'; + /** + * The bit flags for CoreML execution provider. + * + * ``` + * COREML_FLAG_USE_CPU_ONLY = 0x001 + * COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002 + * COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004 + * COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008 + * COREML_FLAG_CREATE_MLPROGRAM = 0x010 + * ``` + * + * See include/onnxruntime/core/providers/coreml/coreml_provider_factory.h for more details. + * + * This flag is available only in ONNXRuntime (Node.js binding). + */ + coreMlFlags?: number; + /** + * Specify whether to use CPU only in CoreML EP. + * + * This setting is available only in ONNXRuntime (react-native). + */ useCPUOnly?: boolean; + /** + * Specify whether to enable CoreML EP on subgraph. + * + * This setting is available only in ONNXRuntime (react-native). + */ enableOnSubgraph?: boolean; + /** + * Specify whether to only enable CoreML EP for Apple devices with ANE (Apple Neural Engine). + * + * This setting is available only in ONNXRuntime (react-native). + */ onlyEnableDeviceWithANE?: boolean; } export interface NnapiExecutionProviderOption extends ExecutionProviderOption { diff --git a/js/common/lib/onnx-value.ts b/js/common/lib/onnx-value.ts index a16a30d25d839..72369ce8b4209 100644 --- a/js/common/lib/onnx-value.ts +++ b/js/common/lib/onnx-value.ts @@ -3,7 +3,7 @@ import {Tensor} from './tensor.js'; -type NonTensorType = never; +export type NonTensorType = never; /** * Type OnnxValue Represents both tensors and non-tensors value for model's inputs/outputs. diff --git a/js/common/lib/tensor-factory.ts b/js/common/lib/tensor-factory.ts index 6e19d7fb898a3..431de4c3635c2 100644 --- a/js/common/lib/tensor-factory.ts +++ b/js/common/lib/tensor-factory.ts @@ -253,7 +253,7 @@ export interface TensorFactory { /** * create a tensor from an ImageBitmap object * - * @param bitMap - the ImageBitmap object to create tensor from + * @param bitmap - the ImageBitmap object to create tensor from * @param options - An optional object representing options for creating tensor from URL. * * The following default settings will be applied: diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index d5da33640dc7d..20319ebb800c2 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -160,7 +160,7 @@ export interface Tensor extends TypedTensorBase, TypedTensorUtils { if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; @@ -29,6 +32,9 @@ const TRACE_FUNC = (msg: string, extraMsg?: string) => { } }; +/** + * @ignore + */ export const TRACE_FUNC_BEGIN = (extraMsg?: string) => { if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; @@ -36,6 +42,9 @@ export const TRACE_FUNC_BEGIN = (extraMsg?: string) => { TRACE_FUNC('BEGIN', extraMsg); }; +/** + * @ignore + */ export const TRACE_FUNC_END = (extraMsg?: string) => { if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts index e54aed90e702c..f9de77e3ac7d0 100644 --- a/js/common/lib/training-session.ts +++ b/js/common/lib/training-session.ts @@ -11,7 +11,7 @@ export declare namespace TrainingSession { /** * Either URI file path (string) or Uint8Array containing model or checkpoint information. */ - type URIorBuffer = string|Uint8Array; + type UriOrBuffer = string|Uint8Array; } /** @@ -98,13 +98,13 @@ export interface TrainingSession { getParametersSize(trainableOnly: boolean): Promise; /** - * Copies parameter values from the given array to the training state. Currently, only supporting models with + * Copies parameter values from the given buffer to the training state. Currently, only supporting models with * parameters of type Float32. * - * @param buffer - Float32 buffer containing parameters converted to a Uint8Array. + * @param buffer - A Uint8Array representation of Float32 parameters. * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true. */ - loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise; /** * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning. @@ -157,19 +157,19 @@ export interface TrainingSessionCreateOptions { /** * URI or buffer for a .ckpt file that contains the checkpoint for the training model. */ - checkpointState: TrainingSession.URIorBuffer; + checkpointState: TrainingSession.UriOrBuffer; /** * URI or buffer for the .onnx training file. */ - trainModel: TrainingSession.URIorBuffer; + trainModel: TrainingSession.UriOrBuffer; /** * Optional. URI or buffer for the .onnx optimizer model file. */ - optimizerModel?: TrainingSession.URIorBuffer; + optimizerModel?: TrainingSession.UriOrBuffer; /** * Optional. URI or buffer for the .onnx eval model file. */ - evalModel?: TrainingSession.URIorBuffer; + evalModel?: TrainingSession.UriOrBuffer; } /** diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml b/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml index e788e4b3dddaa..a4d5a73118ea2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml @@ -31,6 +31,10 @@ steps: node -e "a=require('child_process').execSync('git diff --name-only').toString();if(a)throw new Error('Following source files are not formatted: (did you run \"npm run format\"?)\n'+a)" workingDirectory: '$(Build.SourcesDirectory)/js' displayName: 'Check unformatted files' +- script: | + npx typedoc --emit none --treatWarningsAsErrors + workingDirectory: '$(Build.SourcesDirectory)/js/common' + displayName: 'TypeDoc Validation' - script: | npm run build:doc workingDirectory: '$(Build.SourcesDirectory)/js/web' From 8f03331d0620f4fe28129718ba627a1248370049 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 16 Mar 2024 18:53:17 -0700 Subject: [PATCH 46/51] Bump follow-redirects from 1.15.4 to 1.15.6 in /js/web (#19949) Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.15.4 to 1.15.6.
Commits
  • 35a517c Release version 1.15.6 of the npm package.
  • c4f847f Drop Proxy-Authorization across hosts.
  • 8526b4a Use GitHub for disclosure.
  • b1677ce Release version 1.15.5 of the npm package.
  • d8914f7 Preserve fragment in responseUrl.
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=follow-redirects&package-manager=npm_and_yarn&previous-version=1.15.4&new-version=1.15.6)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/web/package-lock.json | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/js/web/package-lock.json b/js/web/package-lock.json index f177fbaf10be0..8a6c20e55004c 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -52,7 +52,7 @@ "version": "1.17.0", "license": "MIT", "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "node_modules/@chiragrupani/karma-chromium-edge-launcher": { @@ -1351,9 +1351,9 @@ "dev": true }, "node_modules/follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true, "funding": [ { @@ -4595,9 +4595,9 @@ "dev": true }, "follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true }, "from": { @@ -5503,7 +5503,7 @@ "onnxruntime-common": { "version": "file:../common", "requires": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "p-cancelable": { From bdb3b0123fcc8c1c5713d6c5f3f6203d1dba2481 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 16 Mar 2024 18:54:53 -0700 Subject: [PATCH 47/51] Bump follow-redirects from 1.15.4 to 1.15.6 in /js/node (#19951) Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.15.4 to 1.15.6.
Commits
  • 35a517c Release version 1.15.6 of the npm package.
  • c4f847f Drop Proxy-Authorization across hosts.
  • 8526b4a Use GitHub for disclosure.
  • b1677ce Release version 1.15.5 of the npm package.
  • d8914f7 Preserve fragment in responseUrl.
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=follow-redirects&package-manager=npm_and_yarn&previous-version=1.15.4&new-version=1.15.6)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/node/package-lock.json | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/js/node/package-lock.json b/js/node/package-lock.json index 094c0b8a7de04..3d78902dcc56a 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -30,7 +30,7 @@ "version": "1.17.0", "license": "MIT", "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "node_modules/@protobufjs/aspromise": { @@ -336,9 +336,9 @@ "dev": true }, "node_modules/follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true, "funding": [ { @@ -1242,9 +1242,9 @@ "dev": true }, "follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true }, "form-data": { @@ -1503,7 +1503,7 @@ "onnxruntime-common": { "version": "file:../common", "requires": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "parse-json": { From b2c48b97009376257ba4e4916b57214afb98423d Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 18 Mar 2024 08:28:43 -0700 Subject: [PATCH 48/51] accumulate in fp32 for Reduce* (#19868) --- js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts index a9b28d7c034f3..210b3ee7e2fca 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -131,7 +131,7 @@ export const createReduceSharedProgramInfo = const workgroupSize = 32; const sharedMemorySnippet = ` - var aBestValues : array<${output.type.storage}, ${workgroupSize}>; + var aBestValues : array; `; const getShaderSource = (shaderHelper: ShaderHelper) => ` @@ -145,10 +145,10 @@ export const createReduceSharedProgramInfo = let outputIndex = global_idx / ${workgroupSize}; let offset = outputIndex * uniforms.reduceSize; - var bestValue = ${output.type.storage}(${reduceInitValues[reduceType]}); + var bestValue = f32(${reduceInitValues[reduceType]}); let Length = uniforms.reduceSize; for (var k = local_idx; k < Length; k = k + ${workgroupSize}) { - let candidate = ${output.type.storage}(${input.getByOffset('offset + k')}); + let candidate = f32(${input.getByOffset('offset + k')}); bestValue = ${reduceOps[reduceType]}; } aBestValues[local_idx] = bestValue; @@ -172,8 +172,8 @@ export const createReduceSharedProgramInfo = output.setByOffset( 'outputIndex', `${ - reduceType === 'mean' ? `bestValue / ${output.type.storage}(uniforms.reduceSize)` : - `${reduceOutputValues[reduceType]}`}`)}; + reduceType === 'mean' ? `${output.type.storage}(bestValue / f32(uniforms.reduceSize))` : + `${output.type.storage}(${reduceOutputValues[reduceType]})`}`)}; } }`; From 333efa02d652a83a403687852f9a1ba6adc306e6 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 19 Mar 2024 13:59:32 +0800 Subject: [PATCH 49/51] [js/webgpu] Fix NAN caused by un-initialized buffer in instance-norm (#19387) The added case will be NAN because of the un-initialized buffer. --- .../lib/wasm/jsep/webgpu/ops/instance-norm.ts | 2 +- js/web/test/data/ops/instance-norm.jsonc | 80 +++++++++++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 2f652dbd310ab..2c72def089144 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -207,7 +207,7 @@ const computeMean = let offset = currentImageNumber * uniforms.image_size; var sum = ${fillVector('f32', components)}; var squaredSum = ${fillVector('f32', components)}; - for (var i: u32 = 0; i < ${WG}; i++) { + for (var i: u32 = 0; i < min(${WG}, uniforms.H); i++) { let value = input[offset + i + currentChannelNumber * ${WG}]; sum += value[0]; squaredSum += value[1]; diff --git a/js/web/test/data/ops/instance-norm.jsonc b/js/web/test/data/ops/instance-norm.jsonc index e89ac2da3795f..f28b016d47ab9 100644 --- a/js/web/test/data/ops/instance-norm.jsonc +++ b/js/web/test/data/ops/instance-norm.jsonc @@ -224,5 +224,85 @@ ] } ] + }, + { + "name": "Simple test with NHWC, components 1, buffer reuse", + "operator": "InstanceNormalization", + "inputShapeDefinitions": "rankOnly", + "opset": { + "domain": "", + "version": 17 + }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3, 1, 1], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [4, 5, 6], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4, 5, 6, 4, 5, 6], + "dims": [2, 3, 1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Simple test with NHWC, components 2, buffer reuse", + "operator": "InstanceNormalization", + "inputShapeDefinitions": "rankOnly", + "opset": { + "domain": "", + "version": 17 + }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2], + "dims": [1, 6, 1, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [6], + "type": "float32" + }, + { + "data": [4, 5, 6, 7, 8, 9], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.775264263153076, 4, 5.224735260009766, 2.5505285263061523, 5, 7.449470520019531, 2.325794219970703, 6, + 9.674205780029297, 11.898944854736328, 7, 2.1010589599609375, 14.123676300048828, 8, 1.876321792602539, + 16.348413467407227, 9, 1.6515865325927734 + ], + "dims": [1, 6, 1, 3], + "type": "float32" + } + ] + } + ] } ] From 53377a3e9589b9cd57d3dff26a68b28b7a8be2c7 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 19 Mar 2024 12:55:00 -0700 Subject: [PATCH 50/51] [js/webgpu] allow setting env.webgpu.adapter (#19940) ### Description Allow user to set `env.webgpu.adapter` before creating the first inference session. Feature request: https://github.com/microsoft/onnxruntime/pull/19857#issuecomment-1999984753 @xenova --- js/common/lib/env.ts | 10 +++++--- js/web/lib/wasm/jsep/backend-webgpu.ts | 6 +++-- js/web/lib/wasm/wasm-core-impl.ts | 35 ++++++++++++++++++-------- 3 files changed, 35 insertions(+), 16 deletions(-) diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index b139c719e863f..c8df1613b3268 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -166,16 +166,20 @@ export declare namespace Env { */ forceFallbackAdapter?: boolean; /** - * Get the adapter for WebGPU. + * Set or get the adapter for WebGPU. * - * This property is only available after the first WebGPU inference session is created. + * Setting this property only has effect before the first WebGPU inference session is created. The value will be + * used as the GPU adapter for the underlying WebGPU backend to create GPU device. + * + * If this property is not set, it will be available to get after the first WebGPU inference session is created. The + * value will be the GPU adapter that created by the underlying WebGPU backend. * * When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types". * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type. * * see comments on {@link Tensor.GpuBufferType} */ - readonly adapter: unknown; + adapter: unknown; /** * Get the device for WebGPU. * diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index d92b8ac68dbe7..b36dc73330d46 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -252,8 +252,10 @@ export class WebGpuBackend { } }; - Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); - Object.defineProperty(this.env.webgpu, 'adapter', {value: adapter}); + Object.defineProperty( + this.env.webgpu, 'device', {value: this.device, writable: false, enumerable: true, configurable: false}); + Object.defineProperty( + this.env.webgpu, 'adapter', {value: adapter, writable: false, enumerable: true, configurable: false}); // init queryType, which is necessary for InferenceSession.create this.setQueryType(); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 7019758be0efd..9b27051f1b9fe 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -93,18 +93,31 @@ export const initEp = async(env: Env, epName: string): Promise => { if (typeof navigator === 'undefined' || !navigator.gpu) { throw new Error('WebGPU is not supported in current environment'); } - const powerPreference = env.webgpu?.powerPreference; - if (powerPreference !== undefined && powerPreference !== 'low-power' && powerPreference !== 'high-performance') { - throw new Error(`Invalid powerPreference setting: "${powerPreference}"`); - } - const forceFallbackAdapter = env.webgpu?.forceFallbackAdapter; - if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') { - throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`); - } - const adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter}); + + let adapter = env.webgpu.adapter as GPUAdapter | null; if (!adapter) { - throw new Error( - 'Failed to get GPU adapter. You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); + // if adapter is not set, request a new adapter. + const powerPreference = env.webgpu.powerPreference; + if (powerPreference !== undefined && powerPreference !== 'low-power' && + powerPreference !== 'high-performance') { + throw new Error(`Invalid powerPreference setting: "${powerPreference}"`); + } + const forceFallbackAdapter = env.webgpu.forceFallbackAdapter; + if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') { + throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`); + } + adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter}); + if (!adapter) { + throw new Error( + 'Failed to get GPU adapter. ' + + 'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.'); + } + } else { + // if adapter is set, validate it. + if (typeof adapter.limits !== 'object' || typeof adapter.features !== 'object' || + typeof adapter.requestDevice !== 'function') { + throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.'); + } } if (!env.wasm.simd) { From 310c099b9999e644381d3daa980974fbdf43d925 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 19 Mar 2024 16:15:49 -0700 Subject: [PATCH 51/51] [js/webgpu] fix maxpool / fp16 (#19981) --- js/web/lib/wasm/jsep/webgpu/ops/pool.ts | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 4e933573b9137..5521650e8ded4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -381,8 +381,9 @@ const createMaxPoolProgramInfo = programUniforms }), getShaderSource: shaderHelper => generatePoolingCode( - shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, -1e5, uniforms, - hasPads, pwStartEndNotZero, phStartEndNotZero), + shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, + (input.dataType === DataType.float16) ? -65504 : -1e5, uniforms, hasPads, pwStartEndNotZero, + phStartEndNotZero), }; };