From 7ba6ed6f67959d85121ea0247fa3f6446e7cc9b9 Mon Sep 17 00:00:00 2001
From: Jiajia Qin
Date: Wed, 11 Dec 2024 06:58:57 +0800
Subject: [PATCH] [webgpu] Optimize Expand (#23052)

### Description

Use components = 4 if possible.

This is the webgpu native implementation from #22752
---
 .../core/providers/webgpu/tensor/expand.cc    | 41 +++++++++++++------
 1 file changed, 28 insertions(+), 13 deletions(-)

diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc
index 9fc886cb69bbf..809616660aa9e 100644
--- a/onnxruntime/core/providers/webgpu/tensor/expand.cc
+++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc
@@ -11,14 +11,20 @@ namespace onnxruntime {
 namespace webgpu {
 
 Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform);
+  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
   const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);
-
-  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size")
-                            << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"
-                            << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n "
-                            << output.SetByOffset("global_idx", input.GetByOffset("input_offset"));
-
+  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size");
+  if (input.NumComponents() != output.NumComponents()) {
+    const auto& output_indices = shader.AddIndices("output_indices");
+    shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n"
+                              << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n "
+                              << " let value = vec4<input_value_t>(" << input.GetByOffset("input_offset") << ");\n"
+                              << output.SetByOffset("global_idx", "value");
+  } else {
+    shader.MainFunctionBody() << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"
+                              << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n "
+                              << output.SetByOffset("global_idx", input.GetByOffset("input_offset"));
+  }
   return Status::OK();
 }
 
@@ -28,18 +34,27 @@ Status Expand::ComputeInternal(ComputeContext& context) const {
 
   auto output_dims = input_shape_tensor->DataAsSpan<int64_t>();
   TensorShape output_shape{};
-  ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape));
+  TensorShape input_shape = input_tensor->Shape();
+  ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_shape, output_dims, output_shape));
 
   auto* output_tensor = context.Output(0, output_shape);
-  uint32_t data_size = gsl::narrow<uint32_t>(output_shape.Size());
+  const int components_i = input_shape.IsScalar() ? 1 : input_shape[input_shape.NumDimensions() - 1] % 4 == 0 ? 4
+                                                                                                              : 1;
+  const int components_o = output_shape.IsScalar() ? 1 : output_shape[output_shape.NumDimensions() - 1] % 4 == 0 ? 4
+                                                                                                                 : 1;
+  uint32_t data_size = gsl::narrow<uint32_t>(output_shape.Size() / components_o);
+
   ExpandProgram program{};
   program
-      .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
-      .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank}})
+      .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_i}})
+      .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_o}})
       .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
       .AddUniformVariables({
           {data_size},
       });
+  if (components_i != components_o) {
+    program.AddIndices(output_shape);
+  }
   return context.RunProgram(program);
 }
 
@@ -55,8 +70,8 @@ Status Expand::ComputeInternal(ComputeContext& context) const {
       KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \
       KERNEL_CLASS);
 
-WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedFloatTypes())
-WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedFloatTypes())
+WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedNumberTypes())
+WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedNumberTypes())
 
 }  // namespace webgpu
 }  // namespace onnxruntime