Skip to content

Commit

Permalink
[webgpu] Add Alias def for Flatten
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Dec 6, 2024
1 parent d27fecd commit b805dad
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions onnxruntime/core/providers/webgpu/tensor/flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,55 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kOnnxDomain,
1, 8,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Flatten,
kOnnxDomain,
9, 10,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Flatten,
kOnnxDomain,
11, 12,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Flatten,
kOnnxDomain,
13, 20,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

ONNX_OPERATOR_KERNEL_EX(
Flatten,
kOnnxDomain,
21,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()).InputMemoryType(OrtMemTypeCPU, 1),
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.InputMemoryType(OrtMemTypeCPU, 1),
Flatten);

} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime

0 comments on commit b805dad

Please sign in to comment.