[webgpu] Optimize Expand (#23052)
### Description

Use components = 4 (vec4 element access) when the innermost dimension is a multiple of 4.

This is the WebGPU native implementation of the change from #22752.
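
As a rough illustration of what "components = 4" means here, the sketch below mirrors the selection rule added in `expand.cc`: a tensor can be read or written as `vec4` values whenever its innermost dimension is a multiple of 4, and the dispatched element count shrinks by the same factor. This is a minimal standalone sketch; `PickComponents` and `NumElements` are made-up helpers, not ONNX Runtime APIs.

```cpp
// Illustrative sketch only -- mirrors the selection rule added in expand.cc,
// but PickComponents/NumElements are hypothetical stand-ins, not ORT code.
#include <cstdint>
#include <iostream>
#include <vector>

// 4-wide (vec4) access when the innermost dimension is a multiple of 4, else scalar.
int PickComponents(const std::vector<int64_t>& shape) {
  if (shape.empty()) return 1;  // treat an empty dims vector as a scalar tensor
  return shape.back() % 4 == 0 ? 4 : 1;
}

int64_t NumElements(const std::vector<int64_t>& shape) {
  int64_t n = 1;
  for (int64_t d : shape) n *= d;
  return n;
}

int main() {
  const std::vector<int64_t> input_shape{3, 2, 1};   // broadcast source
  const std::vector<int64_t> output_shape{3, 2, 8};  // expanded result

  const int components_i = PickComponents(input_shape);   // 1: last dim is 1
  const int components_o = PickComponents(output_shape);  // 4: last dim is 8

  // Each shader invocation now handles one vec4, so fewer invocations are dispatched.
  const int64_t data_size = NumElements(output_shape) / components_o;  // 48 / 4 = 12

  std::cout << "components_i=" << components_i
            << " components_o=" << components_o
            << " data_size=" << data_size << "\n";
  return 0;
}
```

When the two component counts differ, as with the shapes above, the shader keeps per-element indexing on the input but still writes `vec4` outputs, splatting the scalar it loaded.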
qjia7 authored Dec 10, 2024
1 parent bf4d3e1 commit defcc4f
Showing 2 changed files with 42 additions and 13 deletions.
41 changes: 28 additions & 13 deletions onnxruntime/core/providers/webgpu/tensor/expand.cc
@@ -11,14 +11,20 @@ namespace onnxruntime {
 namespace webgpu {

 Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform);
+  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
   const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);

-  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size")
-                            << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"
-                            << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n "
-                            << output.SetByOffset("global_idx", input.GetByOffset("input_offset"));
-
+  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size");
+  if (input.NumComponents() != output.NumComponents()) {
+    const auto& output_indices = shader.AddIndices("output_indices");
+    shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n"
+                              << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n "
+                              << " let value = vec4<input_value_t>(" << input.GetByOffset("input_offset") << ");\n"
+                              << output.SetByOffset("global_idx", "value");
+  } else {
+    shader.MainFunctionBody() << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"
+                              << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n "
+                              << output.SetByOffset("global_idx", input.GetByOffset("input_offset"));
+  }
   return Status::OK();
 }

@@ -28,18 +34,27 @@ Status Expand::ComputeInternal(ComputeContext& context) const {

   auto output_dims = input_shape_tensor->DataAsSpan<int64_t>();
   TensorShape output_shape{};
-  ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape));
+  TensorShape input_shape = input_tensor->Shape();
+  ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_shape, output_dims, output_shape));

   auto* output_tensor = context.Output(0, output_shape);
-  uint32_t data_size = gsl::narrow<uint32_t>(output_shape.Size());
+  const int components_i = input_shape.IsScalar() ? 1 : input_shape[input_shape.NumDimensions() - 1] % 4 == 0 ? 4
+                                                                                                              : 1;
+  const int components_o = output_shape.IsScalar() ? 1 : output_shape[output_shape.NumDimensions() - 1] % 4 == 0 ? 4
+                                                                                                                 : 1;
+  uint32_t data_size = gsl::narrow<uint32_t>(output_shape.Size() / components_o);

   ExpandProgram program{};
   program
-      .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
-      .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank}})
+      .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_i}})
+      .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, components_o}})
       .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
       .AddUniformVariables({
           {data_size},
       });
+  if (components_i != components_o) {
+    program.AddIndices(output_shape);
+  }
   return context.RunProgram(program);
 }

@@ -55,8 +70,8 @@ Status Expand::ComputeInternal(ComputeContext& context) const {
                     KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \
                     KERNEL_CLASS);

-WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedFloatTypes())
-WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedFloatTypes())
+WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedNumberTypes())
+WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedNumberTypes())

 } // namespace webgpu
 } // namespace onnxruntime
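
To make the two shader paths in `GenerateShaderCode` concrete, here is a hedged CPU-side analogue in plain C++ (not WGSL, and not part of this commit). Each simulated invocation writes one group of `components_o` contiguous output values; because the input is not vectorized, it loads a single scalar and splats it, which is what the `vec4<input_value_t>(...)` branch does on the GPU. The shapes match the `Expand_3x1x8_float` test added below.

```cpp
// CPU analogue of the vectorized Expand path (illustrative only, not ORT code).
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};  // shape {3, 2, 1}
  const std::size_t out_last_dim = 8;                                     // output shape {3, 2, 8}
  const std::size_t components_o = 4;                                     // vec4 stores on the output
  std::vector<float> output(input.size() * out_last_dim);

  const std::size_t num_invocations = output.size() / components_o;  // 48 / 4 = 12, matches data_size
  for (std::size_t global_idx = 0; global_idx < num_invocations; ++global_idx) {
    const std::size_t flat = global_idx * components_o;  // offset of this vec4 in the output
    // Flattened index over the leading {3, 2} dims; the input's last dim is 1,
    // so this doubles as the input offset (the broadcast).
    const std::size_t row = flat / out_last_dim;
    const float value = input[row];  // scalar load, then splat -- the mixed-components branch
    for (std::size_t c = 0; c < components_o; ++c) {
      output[flat + c] = value;
    }
  }

  std::cout << "output[0..7]:";
  for (std::size_t i = 0; i < out_last_dim; ++i) std::cout << ' ' << output[i];
  std::cout << '\n';  // expected: 1 1 1 1 1 1 1 1
  return 0;
}
```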
14 changes: 14 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/expand_test.cc
@@ -167,6 +167,20 @@ TEST(ExpandOpTest, Expand_2x2x1x2x1_float) {
   test.Run();
 }

+TEST(ExpandOpTest, Expand_3x1x8_float) {
+  OpTester test("Expand", 8);
+  test.AddInput<float>("data_0", {3, 2, 1}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+  test.AddInput<int64_t>("data_1", {3}, {3, 1, 8});
+  test.AddOutput<float>("result", {3, 2, 8},
+                        {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
+                         2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f,
+                         3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f, 3.0f,
+                         4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f, 4.0f,
+                         5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f,
+                         6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f, 6.0f});
+  test.Run();
+}
+
 #ifndef USE_TENSORRT
 TEST(ExpandOpTest, Expand_scalar_float) {
   OpTester test("Expand", 8);
