Skip to content

Commit

Permalink
[webgpu] Add Alias def for Flatten (microsoft#23038)
Browse files Browse the repository at this point in the history
### Description

Add `Alias` definition for Flatten in WebGPU EP.

also add int32/uint32 in type constraint T.
  • Loading branch information
fs-eire authored and ankitm3k committed Dec 11, 2024
1 parent 6cd9a06 commit 8dd8bc5
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions onnxruntime/core/providers/webgpu/tensor/flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,55 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kOnnxDomain,
1, 8,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Flatten,
kOnnxDomain,
9, 10,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Flatten,
kOnnxDomain,
11, 12,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Flatten,
kOnnxDomain,
13, 20,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

ONNX_OPERATOR_KERNEL_EX(
Flatten,
kOnnxDomain,
21,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime

0 comments on commit 8dd8bc5

Please sign in to comment.