From 9e69606360d7e77f9ee869beec2b8c9e43517fae Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 29 Jan 2024 10:13:46 -0800 Subject: [PATCH] fix f16 for attention, enable slice and flatten for more types (#19262) --- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 2 +- onnxruntime/core/providers/js/operators/flatten.cc | 8 ++++---- onnxruntime/core/providers/js/operators/slice.cc | 12 ++++-------- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index ef8038dff487e..f07a21a343fa8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -297,7 +297,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView if (sum == 0) { for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { - x[offset + i] = ${fillVector('f32', components, 'uniforms.d_inv')}; + x[offset + i] = ${fillVector(elemValueType, components, 'uniforms.d_inv')}; } } else { for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { diff --git a/onnxruntime/core/providers/js/operators/flatten.cc b/onnxruntime/core/providers/js/operators/flatten.cc index 7e4b4c350951b..1aacae819e304 100644 --- a/onnxruntime/core/providers/js/operators/flatten.cc +++ b/onnxruntime/core/providers/js/operators/flatten.cc @@ -13,7 +13,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", JsepSupportedFloatTypes()), Flatten); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -23,7 +23,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", JsepSupportedFloatTypes()), Flatten); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -33,7 +33,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", JsepSupportedFloatTypes()), Flatten); ONNX_OPERATOR_KERNEL_EX( @@ -43,7 +43,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()) .Alias(0, 0) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", JsepSupportedFloatTypes()), Flatten); } // namespace js diff --git a/onnxruntime/core/providers/js/operators/slice.cc b/onnxruntime/core/providers/js/operators/slice.cc index bbafe40ea92ac..869b5450501e1 100644 --- a/onnxruntime/core/providers/js/operators/slice.cc +++ b/onnxruntime/core/providers/js/operators/slice.cc @@ -12,8 +12,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 1, 9, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + .TypeConstraint("T", JsepSupportedDataTypes()), Slice_1); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -26,8 +25,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3) .InputMemoryType(OrtMemTypeCPU, 4) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + .TypeConstraint("T", JsepSupportedDataTypes()), Slice); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -40,8 +38,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3) .InputMemoryType(OrtMemTypeCPU, 4) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + .TypeConstraint("T", JsepSupportedDataTypes()), Slice); ONNX_OPERATOR_KERNEL_EX( @@ -54,8 +51,7 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPU, 2) .InputMemoryType(OrtMemTypeCPU, 3) .InputMemoryType(OrtMemTypeCPU, 4) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), - DataTypeImpl::GetTensorType()}), + .TypeConstraint("T", JsepSupportedDataTypes()), Slice); } // namespace js