diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 9fc886cb69bbf..809616660aa9e 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -11,14 +11,20 @@ namespace onnxruntime { namespace webgpu { Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); - - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") - << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" - << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n " - << output.SetByOffset("global_idx", input.GetByOffset("input_offset")); - + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"); + if (input.NumComponents() != output.NumComponents()) { + const auto& output_indices = shader.AddIndices("output_indices"); + shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n" + << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n " + << " let value = vec4(" << input.GetByOffset("input_offset") << ");\n" + << output.SetByOffset("global_idx", "value"); + } else { + shader.MainFunctionBody() << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n " + << output.SetByOffset("global_idx", input.GetByOffset("input_offset")); + } return Status::OK(); } @@ -28,18 +34,27 @@ Status Expand::ComputeInternal(ComputeContext& context) const { auto output_dims = input_shape_tensor->DataAsSpan(); TensorShape output_shape{}; - ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape)); + TensorShape input_shape = input_tensor->Shape(); + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_shape, output_dims, output_shape)); auto* output_tensor = context.Output(0, output_shape); - uint32_t data_size = gsl::narrow(output_shape.Size()); + const int components_i = input_shape.IsScalar() ? 1 : input_shape[input_shape.NumDimensions() - 1] % 4 == 0 ? 4 + : 1; + const int components_o = output_shape.IsScalar() ? 1 : output_shape[output_shape.NumDimensions() - 1] % 4 == 0 ? 4 + : 1; + uint32_t data_size = gsl::narrow(output_shape.Size() / components_o); + ExpandProgram program{}; program - .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank}}) + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_i}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_o}}) .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .AddUniformVariables({ {data_size}, }); + if (components_i != components_o) { + program.AddIndices(output_shape); + } return context.RunProgram(program); } @@ -55,8 +70,8 @@ Status Expand::ComputeInternal(ComputeContext& context) const { KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \ KERNEL_CLASS); -WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedFloatTypes()) -WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedFloatTypes()) +WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedNumberTypes()) +WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedNumberTypes()) } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/expand_test.cc b/onnxruntime/test/providers/cpu/tensor/expand_test.cc index 4b0f4e84ca37d..38e3bc3af6d6b 100644 --- a/onnxruntime/test/providers/cpu/tensor/expand_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/expand_test.cc @@ -167,6 +167,20 @@ TEST(ExpandOpTest, Expand_2x2x1x2x1_float) { test.Run(); } +TEST(ExpandOpTest, Expand_3x1x8_float) { + OpTester test("Expand", 8); + test.AddInput("data_0", {3, 2, 1}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + test.AddInput("data_1", {3}, {3, 1, 8}); + test.AddOutput("result", {3, 2, 8}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, + 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, + 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, + 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, + 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f}); + test.Run(); +} + #ifndef USE_TENSORRT TEST(ExpandOpTest, Expand_scalar_float) { OpTester test("Expand", 8);