From ce329e701195ceb28ffa7bcc62b0be0b29855abb Mon Sep 17 00:00:00 2001
From: Arthur Islamov
Date: Mon, 11 Sep 2023 15:02:46 +0400
Subject: [PATCH] FP16 Gemm, Softmax & Transpose

---
 js/web/lib/wasm/jsep/webgpu/ops/gemm.ts       | 10 +--
 js/web/lib/wasm/jsep/webgpu/ops/softmax.ts    | 14 ++---
 js/web/lib/wasm/jsep/webgpu/ops/transpose.ts  |  6 --
 .../core/providers/js/js_data_types.cc        | 18 ++++++
 onnxruntime/core/providers/js/js_data_types.h |  3 +-
 .../providers/js/js_execution_provider.cc     | 28 ++++-----
 onnxruntime/core/providers/js/js_kernel.h     |  1 +
 .../core/providers/js/operators/gemm.cc       | 63 +++++++++----------
 .../core/providers/js/operators/gemm.h        |  1 -
 .../core/providers/js/operators/softmax.cc    | 14 ++---
 .../core/providers/js/operators/softmax.h     |  1 -
 .../core/providers/js/operators/transpose.cc  |  8 +--
 12 files changed, 79 insertions(+), 88 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
index 46816a2410586..efcf9ce0dd230 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
@@ -1,13 +1,12 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {DataType} from '../../../wasm-common';
 import {TensorView} from '../../tensor';
 import {GemmUtil, ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
 
-import {ShaderHelper} from './common';
+import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
 
 const validateInputs = (inputs: readonly TensorView[]): void => {
   if (!inputs) {
@@ -22,11 +21,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
     throw new Error('Invalid input shape of C');
   }
 
-  if ((inputs[0].dataType !== DataType.float) || (inputs[1].dataType !== DataType.float) ||
-      (inputs.length === 3 && inputs[2].dataType !== DataType.float)) {
-    throw new Error('Invalid input type.');
-  }
-
   if ((inputs[0].dataType !== inputs[1].dataType) ||
       (inputs.length === 3 && inputs[0].dataType !== inputs[2].dataType)) {
     throw new Error('Input types are mismatched');
@@ -81,7 +75,7 @@ const createGemmProgramInfo =
         line = 'value += a[m * K + k] * b[k * N + n];';
       }
 
-      const dataType = 'f32';  // TODO: support other data type
+      const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);  // TODO: support other data type
       const calculateAlpha = attributes.alpha === 1 ? '' : 'value *= alpha;';
       const calculateC = inputs.length === 3 ? `value += beta * c[${offsetC(M, N, inputs[2].dims)}];` : '';
       const inputStorageBuffersDeclarations = [
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
index bdbf05e2f185e..e2443b24410a5 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
@@ -5,21 +5,17 @@
 // performance limitations when the reduced axis is long. Need to add
 // a optimized codepath for this.
 
-import {DataType} from '../../../wasm-common';
 import {TensorView} from '../../tensor';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType, ProgramInfo} from '../types';
 
-import {ShaderHelper} from './common';
+import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
 
 const validateInputs = (inputs: readonly TensorView[]): void => {
   if (!inputs || inputs.length !== 1) {
     throw new Error('Softmax op requires 1 input.');
   }
-  if (inputs[0].dataType !== DataType.float) {
-    throw new Error('Softmax input needs to be float.');
-  }
 };
 
 export interface SoftmaxAttributes extends AttributeWithCacheKey {
@@ -33,7 +29,7 @@ export const softmaxProgramMetadata = {
 
 
 const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
-  const dataType = 'f32';
+  const dataType = tensorTypeToWsglStorageType(input.dataType);
   const shape = input.dims;
   const outputSize = ShapeUtil.size(shape);
   const WG = 64;
@@ -48,6 +44,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
   const cols = shape[axis];
   const rows = outputSize / cols;
 
+  // 6.2.4 in wgsl spec
+  const threadMaxDecl = dataType === 'f32' ? 'var threadMax: f32 = -3.402823e+38f;' : 'var threadMax: f16 = -65504.0h;';
   const getShaderSource = (_shaderHelper: ShaderHelper) => `
       var rowMaxShared : ${dataType};
       var rowSumShared : ${dataType};
@@ -76,7 +74,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
         let row_stride : i32 = ${cols};
 
         // find the rows max
-        var threadMax = -3.402823e+38f;  // 6.2.4 in wgsl spec
+        ${threadMaxDecl}
         for (var col = lindex; col < cols; col += wg) {
           let value = getValue(row, col, row_stride);
           threadMax = max(threadMax, value);
@@ -100,7 +98,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
         workgroupBarrier();
 
         // find the rows sum
-        var threadSum = 0.0;
+        var threadSum: ${dataType} = 0.0;
         for (var col = lindex; col < cols; col += wg) {
           let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
           threadSum += subExp;
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
index ebedc61712e8a..9243b0e4af6b6 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
@@ -1,7 +1,6 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {DataType} from '../../../wasm-common';
 import {TensorView} from '../../tensor';
 import {ShapeUtil} from '../../util';
 import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
@@ -22,11 +21,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
   if (!inputs || inputs.length !== 1) {
     throw new Error('Transpose requires 1 input.');
   }
-
-  if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.int32 &&
-      inputs[0].dataType !== DataType.uint32) {
-    throw new Error('Transpose only support float, int32, and uint32 data types');
-  }
 };
 
 const getAdjustedPerm = (inputShape: readonly number[], perm: number[]): number[] =>
diff --git a/onnxruntime/core/providers/js/js_data_types.cc b/onnxruntime/core/providers/js/js_data_types.cc
index 69d5bd4f9de8f..e39773fd07137 100644
--- a/onnxruntime/core/providers/js/js_data_types.cc
+++ b/onnxruntime/core/providers/js/js_data_types.cc
@@ -9,12 +9,30 @@ namespace js {
 
 using SupportedTypes = TypeList<
     float,
+#ifdef ENABLE_WEBASSEMBLY_FLOAT16
+    MLFloat16,
+#endif
     int32_t,
     uint32_t>;
 
+using SupportedFloats =
+#ifdef ENABLE_WEBASSEMBLY_FLOAT16
+    TypeList<
+        float,
+        MLFloat16>;
+#else
+    TypeList<float>;
+#endif
+
 const std::vector<MLDataType>& JsepSupportedDataTypes() {
   static const std::vector<MLDataType> supportedDataTypes = BuildKernelDefConstraintsFromTypeList<SupportedTypes>();
   return supportedDataTypes;
 }
+
+const std::vector<MLDataType>& JsepSupportedFloatTypes() {
+  static const std::vector<MLDataType> supportedDataTypes = BuildKernelDefConstraintsFromTypeList<SupportedFloats>();
+  return supportedDataTypes;
+}
+
 }  // namespace js
 }  // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/providers/js/js_data_types.h b/onnxruntime/core/providers/js/js_data_types.h
index d6b6ac00401b3..968e79124645f 100644
--- a/onnxruntime/core/providers/js/js_data_types.h
+++ b/onnxruntime/core/providers/js/js_data_types.h
@@ -6,5 +6,6 @@
 namespace onnxruntime {
 namespace js {
 std::vector<MLDataType>& JsepSupportedDataTypes();
-}
+std::vector<MLDataType>& JsepSupportedFloatTypes();
+}  // namespace js
 }  // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc
index c5b3b1933e04c..dbf918ac2c499 100644
--- a/onnxruntime/core/providers/js/js_execution_provider.cc
+++ b/onnxruntime/core/providers/js/js_execution_provider.cc
@@ -244,10 +244,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gemm);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul);
 
@@ -269,9 +269,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin);
 
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax);
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Softmax);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Softmax);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Softmax);
 
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat);
@@ -496,10 +496,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gemm)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul)>,
 
@@ -522,9 +522,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
 
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Softmax)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Softmax)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Softmax)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Softmax)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Softmax)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Softmax)>,
 
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 3, Concat)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat)>,
diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h
index 3accd80875d1b..177c0a9e691ed 100644
--- a/onnxruntime/core/providers/js/js_kernel.h
+++ b/onnxruntime/core/providers/js/js_kernel.h
@@ -11,6 +11,7 @@
 
 #include "core/framework/op_kernel.h"
 #include "core/providers/js/js_execution_provider.h"
+#include "core/providers/js/js_data_types.h"
 
 struct pthreadpool;
 
diff --git a/onnxruntime/core/providers/js/operators/gemm.cc b/onnxruntime/core/providers/js/operators/gemm.cc
index 04700d0f54705..de27288f2ee0e 100644
--- a/onnxruntime/core/providers/js/operators/gemm.cc
+++ b/onnxruntime/core/providers/js/operators/gemm.cc
@@ -8,41 +8,34 @@
 namespace onnxruntime {
 namespace js {
 
-#define REGISTER_KERNEL_TYPED(T) \
-  ONNX_OPERATOR_TYPED_KERNEL_EX( \
-      Gemm, \
-      kOnnxDomain, \
-      13, \
-      T, \
-      kJsExecutionProvider, \
-      (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
-      Gemm); \
-  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
-      Gemm, \
-      kOnnxDomain, \
-      11, 12, \
-      T, \
-      kJsExecutionProvider, \
-      (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
-      Gemm); \
-  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
-      Gemm, \
-      kOnnxDomain, \
-      9, 10, \
-      T, \
-      kJsExecutionProvider, \
-      (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
-      Gemm); \
-  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
-      Gemm, \
-      kOnnxDomain, \
-      7, 8, \
-      T, \
-      kJsExecutionProvider, \
-      (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
-      Gemm);
-
-REGISTER_KERNEL_TYPED(float)
+ONNX_OPERATOR_KERNEL_EX(
+    Gemm,
+    kOnnxDomain,
+    13,
+    kJsExecutionProvider,
+    (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
+    Gemm);
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    Gemm,
+    kOnnxDomain,
+    11, 12,
+    kJsExecutionProvider,
+    (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
+    Gemm);
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    Gemm,
+    kOnnxDomain,
+    9, 10,
+    kJsExecutionProvider,
+    (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
+    Gemm);
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    Gemm,
+    kOnnxDomain,
+    7, 8,
+    kJsExecutionProvider,
+    (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
+    Gemm);
 
 }  // namespace js
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/js/operators/gemm.h b/onnxruntime/core/providers/js/operators/gemm.h
index 27c41788ccfbd..74091526f8411 100644
--- a/onnxruntime/core/providers/js/operators/gemm.h
+++ b/onnxruntime/core/providers/js/operators/gemm.h
@@ -8,7 +8,6 @@
 namespace onnxruntime {
 namespace js {
 
-template <typename T>
 class Gemm : public JsKernel {
  public:
   Gemm(const OpKernelInfo& info) : JsKernel(info) {
diff --git a/onnxruntime/core/providers/js/operators/softmax.cc b/onnxruntime/core/providers/js/operators/softmax.cc
index cbaecf9e4c975..292bd5006fb30 100644
--- a/onnxruntime/core/providers/js/operators/softmax.cc
+++ b/onnxruntime/core/providers/js/operators/softmax.cc
@@ -7,27 +7,25 @@ namespace onnxruntime {
 namespace js {
 
 #define REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(SoftmaxOp, sinceVersion, endVersion) \
-  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+  ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
       SoftmaxOp, \
       kOnnxDomain, \
       sinceVersion, endVersion, \
-      float, \
       kJsExecutionProvider, \
       (*KernelDefBuilder::Create()) \
-          .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), \
-      SoftmaxOp);
+          .TypeConstraint("T", JsepSupportedFloatTypes()), \
+      SoftmaxOp);
 
 #define REGISTER_SOFTMAX_ELEMENTWISE_KERNEL(SoftmaxOp, sinceVersion) \
-  ONNX_OPERATOR_TYPED_KERNEL_EX( \
+  ONNX_OPERATOR_KERNEL_EX( \
      SoftmaxOp, \
       kOnnxDomain, \
       sinceVersion, \
-      float, \
       kJsExecutionProvider, \
       (*KernelDefBuilder::Create()) \
-          .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()) \
+          .TypeConstraint("T", JsepSupportedFloatTypes()) \
           .InputMemoryType(OrtMemTypeCPU, 1), \
-      SoftmaxOp);
+      SoftmaxOp);
 
 REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 1, 10);
 REGISTER_SOFTMAX_ELEMENTWISE_VERSIONED_KERNEL(Softmax, 11, 12);
diff --git a/onnxruntime/core/providers/js/operators/softmax.h b/onnxruntime/core/providers/js/operators/softmax.h
index 068a59e6b24e3..87259e8b6f206 100644
--- a/onnxruntime/core/providers/js/operators/softmax.h
+++ b/onnxruntime/core/providers/js/operators/softmax.h
@@ -8,7 +8,6 @@
 namespace onnxruntime {
 namespace js {
 
-template <typename T>
 class Softmax : public JsKernel {
  public:
   Softmax(const OpKernelInfo& info) : JsKernel(info) {
diff --git a/onnxruntime/core/providers/js/operators/transpose.cc b/onnxruntime/core/providers/js/operators/transpose.cc
index ef1e49046ae8c..332bd35f2434c 100644
--- a/onnxruntime/core/providers/js/operators/transpose.cc
+++ b/onnxruntime/core/providers/js/operators/transpose.cc
@@ -12,9 +12,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
     1, 12,
     kJsExecutionProvider,
     (*KernelDefBuilder::Create())
-        .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
-                              DataTypeImpl::GetTensorType<int32_t>(),
-                              DataTypeImpl::GetTensorType<uint32_t>()}),
+        .TypeConstraint("T", JsepSupportedDataTypes()),
     Transpose);
 
 ONNX_OPERATOR_KERNEL_EX(
@@ -23,9 +21,7 @@ ONNX_OPERATOR_KERNEL_EX(
     13,
     kJsExecutionProvider,
     (*KernelDefBuilder::Create())
-        .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
-                              DataTypeImpl::GetTensorType<int32_t>(),
-                              DataTypeImpl::GetTensorType<uint32_t>()}),
+        .TypeConstraint("T", JsepSupportedDataTypes()),
     Transpose);
 
 }  // namespace js
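
Note on the shader-side change (a standalone sketch, not part of the patch): the WGSL generated for Gemm and Softmax now takes its scalar type from tensorTypeToWsglStorageType() instead of hard-coding f32, and the Softmax row-max reduction has to start from the lowest finite value of that type, which is why the patch switches between -3.402823e+38f and -65504.0h. The TypeScript below illustrates that selection under those assumptions; the helper names (WsglStorageType, lowestFiniteLiteral, threadMaxDecl as a function) are illustrative, not identifiers from the ONNX Runtime codebase.

    // Sketch only: mirrors the threadMaxDecl logic added to softmax.ts.
    type WsglStorageType = 'f32' | 'f16';

    // Lowest finite value literal for each WGSL storage type
    // (f32 value per section 6.2.4 of the WGSL spec; -65504 is the lowest finite f16 value).
    const lowestFiniteLiteral = (storageType: WsglStorageType): string =>
        storageType === 'f32' ? '-3.402823e+38f' : '-65504.0h';

    // Emit the initializer for the per-thread running maximum used by the row-max reduction.
    const threadMaxDecl = (storageType: WsglStorageType): string =>
        `var threadMax: ${storageType} = ${lowestFiniteLiteral(storageType)};`;

    console.log(threadMaxDecl('f16'));  // var threadMax: f16 = -65504.0h;

Starting the accumulator at the type's own lowest finite value matters because -3.402823e+38 lies far outside the f16 range, so the f32 initializer cannot simply be reused in an f16 shader.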