From ac8598a837083dca599ca260152b71d14946de98 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 22 Nov 2023 06:26:00 +0800 Subject: [PATCH] [js/webgpu] enable f16 for concat (#18528) ### Description With this PR `realesrgan-t64-f16` models becomes 32.8 ms from 1052.55 ms. Now the whole model run on jsep. --- onnxruntime/core/providers/js/operators/concat.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/js/operators/concat.cc b/onnxruntime/core/providers/js/operators/concat.cc index 3a6a7e1cafd7a..17c6b0466c3a5 100644 --- a/onnxruntime/core/providers/js/operators/concat.cc +++ b/onnxruntime/core/providers/js/operators/concat.cc @@ -12,7 +12,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 1, 3, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Concat); @@ -22,7 +23,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 4, 10, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Concat); @@ -32,7 +34,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 11, 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Concat); @@ -42,7 +45,8 @@ ONNX_OPERATOR_KERNEL_EX( 13, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Concat);