Skip to content

Commit

Permalink
[js/webgpu] enable f16 for concat (#18528)
Browse files Browse the repository at this point in the history
### Description
With this PR `realesrgan-t64-f16` models becomes 32.8 ms from 1052.55
ms. Now the whole model run on jsep.
  • Loading branch information
qjia7 authored Nov 21, 2023
1 parent 81a763a commit ac8598a
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions onnxruntime/core/providers/js/operators/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
1, 3,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
.TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>()}),
Concat);

Expand All @@ -22,7 +23,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
4, 10,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
.TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>()}),
Concat);

Expand All @@ -32,7 +34,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
11, 12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
.TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>()}),
Concat);

Expand All @@ -42,7 +45,8 @@ ONNX_OPERATOR_KERNEL_EX(
13,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
.TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>()}),
Concat);

Expand Down

0 comments on commit ac8598a

Please sign in to comment.