Skip to content

Commit

Permalink
add binary ops
Browse files — browse the repository at this point in the history
  • Loading branch information
qjia7 committed Sep 11, 2024
1 parent f65ade9 commit 264b9a2
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 14 deletions.
170 changes: 170 additions & 0 deletions onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/webgpu/math/binary_elementwise_ops.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {
// Emits the WGSL body for "output = <expression_>(a, b)", one vec4 of output
// per invocation. Three code paths are generated depending on the broadcast
// configuration decided in ComputeInternal:
//   1. no broadcast, or one side scalar: both inputs read directly at global_idx;
//   2. broadcast with vectorize_: output indices are computed per 4-element
//      chunk, and each input is read as a whole vec4 only when its layout
//      allows it (shared dimension or its own last dim divisible by 4);
//   3. broadcast without vectorization: plain per-element indexed reads.
Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& a = shader.AddInput("input_a", ShaderVariable::UseUniform);
const auto& b = shader.AddInput("input_b", ShaderVariable::UseUniform);
const auto& c = shader.AddOutput("output", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias);
// Scalar inputs are still bound with component size 4 (see ComputeInternal),
// so element 0 is a vec4; ".x" extracts the single scalar value, which
// output_value_t(...) then splats to the output's vector type.
std::string get_a_data = is_lhs_scalar_ ? "let a = output_value_t(" + a.GetByOffset("0") + ".x" + ");\n" :
"let a = " + a.GetByOffset("global_idx") + ";\n";
std::string get_b_data = is_rhs_scalar_ ? "let b = output_value_t(" + b.GetByOffset("0") + ".x" + ");\n" :
"let b = " + b.GetByOffset("global_idx") + ";\n";
// General broadcast: recompute the loads using broadcast index arithmetic.
if (!is_lhs_scalar_ && !is_rhs_scalar_ && is_broadcast_) {
if (vectorize_) {
// global_idx addresses vec4 chunks, so the element offset is global_idx * 4.
std::string common = "let outputIndices = " + c.OffsetToIndices("global_idx * 4u") + ";\n"
"let offset_a = " + a.BroadcastedIndicesToOffset("outputIndices", c) + ";\n"
"let offset_b = " + b.BroadcastedIndicesToOffset("outputIndices", c) + ";\n";
// When the input's layout is 4-aligned, read a whole vec4 at offset/4;
// otherwise load one element and splat it via output_value_t(...).
const std::string a_data = shared_dimension_divisible_by_4_ || a_last_dim_divisible_by_4_ ?
"let a = " + a.GetByOffset("offset_a / 4") + ";\n" :
"let a = output_value_t(" + a.GetByOffset("offset_a") + ");\n";
get_a_data = common + a_data;
get_b_data = shared_dimension_divisible_by_4_ || b_last_dim_divisible_by_4_ ?
"let b = " + b.GetByOffset("offset_b / 4") + ";\n" :
"let b = output_value_t(" + b.GetByOffset("offset_b") + ");\n";
} else {
// Scalar (component size 1) path: one element per invocation.
// Note: `common` declares offset_b as well, so get_b_data can use it —
// the two snippets are always emitted together in the main body below.
std::string common = "let outputIndices = " + c.OffsetToIndices("global_idx") + ";\n"
"let offset_a = " + a.BroadcastedIndicesToOffset("outputIndices", c) + ";\n"
"let offset_b = " + b.BroadcastedIndicesToOffset("outputIndices", c) + ";\n";
get_a_data = common + "let a = " + a.GetByOffset("offset_a") + ";\n";
get_b_data = "let b = " + b.GetByOffset("offset_b") + ";\n";
}
}

// uniforms.vec_size is the dispatch size in vec4 chunks (or elements when
// not vectorized) — see the matching uniform set up in ComputeInternal.
shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"),
get_a_data, get_b_data,
c.SetByOffset("global_idx", expression_));
return Status::OK();
}

// Computes lhs <op> rhs with multidirectional (NumPy-style) broadcasting.
// Strategy: process 4 elements per GPU invocation ("vectorize") whenever the
// data layout allows the inputs to be read as vec4; otherwise fall back to
// scalar loads. The chosen flags are forwarded to the shader generator and
// also folded into the program cache key.
Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
auto lhs_tensor = context.Input(0);
auto rhs_tensor = context.Input(1);
const auto& lhs_shape = lhs_tensor->Shape();
const auto& rhs_shape = rhs_tensor->Shape();

TensorShape output_shape;
ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape));
auto output_tensor = context.Output(0, output_shape);
int64_t size = output_shape.Size();
// Empty output: nothing to dispatch.
if (size == 0) {
return Status::OK();
}

bool is_broadcast = lhs_shape != rhs_shape;
bool is_lhs_scalar = lhs_shape.IsScalar();
bool is_rhs_scalar = rhs_shape.IsScalar();

// Trivially vectorizable: identical shapes, or one side is a scalar that the
// shader splats across a vec4.
bool vectorize = is_lhs_scalar || is_rhs_scalar || !is_broadcast;
bool a_last_dim_divisible_by_4 = false;
bool b_last_dim_divisible_by_4 = false;
bool shared_dimension_divisible_by_4 = false;
if (!vectorize) {
// check whether vectorize can be enabled
a_last_dim_divisible_by_4 = lhs_shape.NumDimensions() > 0 && lhs_shape[lhs_shape.NumDimensions() - 1] % 4 == 0;
b_last_dim_divisible_by_4 = rhs_shape.NumDimensions() > 0 && rhs_shape[rhs_shape.NumDimensions() - 1] % 4 == 0;
if (a_last_dim_divisible_by_4 || b_last_dim_divisible_by_4) {
vectorize = true;
} else {
// Fold the trailing dimensions that lhs and rhs have in common (walking
// from the innermost dimension outward) into one "shared" extent; if that
// extent is a multiple of 4, vec4 loads stay aligned for both inputs.
// NOTE(review): the loop runs i in [1, output rank), so the outermost
// output dimension is never folded in — confirm this is intentional.
size_t shared_dimension = 1;
for (size_t i = 1; i < output_shape.NumDimensions(); i++) {
// A missing leading dimension counts as 1 (standard broadcast rule).
size_t dimA = lhs_shape.NumDimensions() >= i ? lhs_shape[lhs_shape.NumDimensions() - i] : 1;
size_t dimB = rhs_shape.NumDimensions() >= i ? rhs_shape[rhs_shape.NumDimensions() - i] : 1;
if (dimA == dimB) {
shared_dimension *= dimA;
} else {
break;
}
}
if (shared_dimension % 4 == 0) {
shared_dimension_divisible_by_4 = true;
vectorize = true;
}
}
}

// One invocation per vec4 chunk (rounded up), or per element when scalar.
SafeInt<uint32_t> vec_size = vectorize ? (size + 3) / 4 : size;
BinaryElementwiseProgram program{kernel_name_, expression_, is_broadcast, is_lhs_scalar, is_rhs_scalar, vectorize,
a_last_dim_divisible_by_4, b_last_dim_divisible_by_4, shared_dimension_divisible_by_4};
program
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({
{static_cast<uint32_t>(vec_size)},
});
if (is_lhs_scalar || is_rhs_scalar || !is_broadcast) {
// Simple path: override every shape to the flat vec4 count so that both
// inputs and the output are addressed identically by global_idx.
program
.AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::Rank, {vec_size}, 4},
{rhs_tensor, ProgramTensorMetadataDependency::Rank, {vec_size}, 4}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, {vec_size}, 4}});
} else {
// Broadcast path: keep original shapes; per-input component size is 4 only
// when that input's layout permits aligned vec4 loads (see shader).
program
.AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::Rank, shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4 ? 4 : 1},
{rhs_tensor, ProgramTensorMetadataDependency::Rank, shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4 ? 4 : 1}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, vectorize ? 4 : 1}});
}
// Every flag below changes the generated WGSL, so all of them must be part
// of the program cache key. classic locale keeps the hint locale-independent.
std::ostringstream ss;
ss.imbue(std::locale::classic());
ss << "is_broadcast = " << is_broadcast << "|"
<< "is_lhs_scalar = " << is_lhs_scalar << "|"
<< "is_rhs_scalar = " << is_rhs_scalar << "|"
<< "vectorize = " << vectorize << "|"
<< "a_last_dim_divisible_by_4 = " << a_last_dim_divisible_by_4 << "|"
<< "b_last_dim_divisible_by_4 = " << b_last_dim_divisible_by_4 << "|"
<< "shared_dimension_divisible_by_4 = " << shared_dimension_divisible_by_4;
program.CacheHint(ss.str());

return context.RunProgram(program);
}

// Declares a concrete kernel class OP_TYPE whose constructor forwards the
// op name (stringized) and its WGSL expression (plus any optional trailing
// arguments) to the BinaryElementwise base.
#define WEBGPU_BINARY_IMPL(OP_TYPE, ...) \
class OP_TYPE final : public BinaryElementwise { \
public: \
OP_TYPE(const OpKernelInfo& info) : BinaryElementwise{info, #OP_TYPE, __VA_ARGS__} {} \
};

// Registers KERNEL_CLASS for OP_TYPE from opset VERSION onward.
#define WEBGPU_BINARY_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION, \
kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE), \
KERNEL_CLASS);

// Registers KERNEL_CLASS for OP_TYPE over the inclusive opset range
// [VERSION_FROM, VERSION_TO].
#define WEBGPU_BINARY_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
OP_TYPE, \
kOnnxDomain, \
VERSION_FROM, VERSION_TO, \
kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE), \
KERNEL_CLASS);

// Each op below is registered for opset ranges [7,12], [13,13] and [14, +inf)
// — the ONNX opsets where Add/Sub/Mul/Div changed their specification.
WEBGPU_BINARY_IMPL(Add, "a + b")
WEBGPU_BINARY_VERSIONED_KERNEL(Add, 7, 12, Add, WebGpuSupportedDataTypes())
WEBGPU_BINARY_VERSIONED_KERNEL(Add, 13, 13, Add, WebGpuSupportedDataTypes())
WEBGPU_BINARY_KERNEL(Add, 14, Add, WebGpuSupportedDataTypes())

WEBGPU_BINARY_IMPL(Div, "a / b")
WEBGPU_BINARY_VERSIONED_KERNEL(Div, 7, 12, Div, WebGpuSupportedDataTypes())
WEBGPU_BINARY_VERSIONED_KERNEL(Div, 13, 13, Div, WebGpuSupportedDataTypes())
WEBGPU_BINARY_KERNEL(Div, 14, Div, WebGpuSupportedDataTypes())

WEBGPU_BINARY_IMPL(Mul, "a * b")
WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 7, 12, Mul, WebGpuSupportedDataTypes())
WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 13, 13, Mul, WebGpuSupportedDataTypes())
WEBGPU_BINARY_KERNEL(Mul, 14, Mul, WebGpuSupportedDataTypes())

WEBGPU_BINARY_IMPL(Sub, "a - b")
WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 7, 12, Sub, WebGpuSupportedDataTypes())
WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 13, 13, Sub, WebGpuSupportedDataTypes())
WEBGPU_BINARY_KERNEL(Sub, 14, Sub, WebGpuSupportedDataTypes())

} // namespace webgpu
} // namespace onnxruntime
58 changes: 58 additions & 0 deletions onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/program.h"

namespace onnxruntime {
namespace webgpu {

// Program descriptor for one binary elementwise operator instance.
// Carries the WGSL expression plus the broadcast/vectorization flags that
// GenerateShaderCode uses to choose a code path; the same flags form part of
// the shader cache key.
class BinaryElementwiseProgram final : public Program<BinaryElementwiseProgram> {
 public:
  BinaryElementwiseProgram(const std::string& kernel_name,
                           const std::string& expression,
                           const bool is_broadcast,
                           const bool is_lhs_scalar,
                           const bool is_rhs_scalar,
                           const bool vectorize,
                           const bool a_last_dim_divisible_by_4,
                           const bool b_last_dim_divisible_by_4,
                           const bool shared_dimension_divisible_by_4)
      : Program{kernel_name},
        expression_{expression},
        is_broadcast_{is_broadcast},
        is_lhs_scalar_{is_lhs_scalar},
        is_rhs_scalar_{is_rhs_scalar},
        vectorize_{vectorize},
        a_last_dim_divisible_by_4_{a_last_dim_divisible_by_4},
        b_last_dim_divisible_by_4_{b_last_dim_divisible_by_4},
        shared_dimension_divisible_by_4_{shared_dimension_divisible_by_4} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});

 private:
  std::string expression_;  // WGSL expression combining `a` and `b`, e.g. "a + b"
  bool is_broadcast_;
  bool is_lhs_scalar_;
  bool is_rhs_scalar_;
  bool vectorize_;
  bool a_last_dim_divisible_by_4_;
  bool b_last_dim_divisible_by_4_;
  bool shared_dimension_divisible_by_4_;
};

// Base kernel for binary elementwise operators on the WebGPU EP.
// Concrete ops (Add, Sub, Mul, Div, ...) supply their name and WGSL
// expression — typically via the WEBGPU_BINARY_IMPL macro — and the shared
// ComputeInternal handles broadcasting, vectorization and dispatch.
class BinaryElementwise : public WebGpuKernel {
 public:
  BinaryElementwise(const OpKernelInfo& info,
                    const std::string& kernel_name,
                    const std::string& expression,
                    const std::string& additional_impl = "",
                    ShaderVariable::Usage usage = ShaderVariable::None)
      : WebGpuKernel{info},
        kernel_name_{kernel_name},
        expression_{expression},
        additional_impl_{additional_impl},
        additional_usage_{usage} {}

 protected:
  // NOTE(review): not referenced in this change and missing the trailing
  // underscore used by the other members — confirm whether it is needed.
  std::string cache_hint;

  Status ComputeInternal(ComputeContext& context) const final;

 private:
  std::string kernel_name_;                 // op name used for the program / cache
  std::string expression_;                  // WGSL expression, e.g. "a + b"
  std::string additional_impl_;             // optional extra WGSL helper code
  ShaderVariable::Usage additional_usage_;  // extra usage flags for shader variables
};

} // namespace webgpu
} // namespace onnxruntime
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webgpu/shader_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ Status ValidateVariableShape(const TensorShape& origin_shape,
ORT_RETURN_IF_NOT((origin_shape.Size() + num_components - 1) / num_components == override_shape.Size(),
"Tensor original shape ", origin_shape, " cannot reshape to ", override_shape, " with component number ", num_components);
} else if (num_components > 1) {
// if shape is not overriden, assert origin_shape[-1] % 4 == 0
ORT_RETURN_IF_NOT(origin_shape.Size() > 0 && origin_shape[origin_shape.NumDimensions() - 1] % num_components == 0,
// if shape is not overriden, assert origin_size % 4 == 0
ORT_RETURN_IF_NOT(origin_shape.Size() > 0 && origin_shape.Size() % num_components == 0,
"Tensor original shape ", origin_shape, " cannot be divided by component number ", num_components);
}

Expand Down
24 changes: 12 additions & 12 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -452,18 +452,18 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
KERNEL_CREATE_INFO(20, Gelu),

// // binary - math
// KERNEL_CREATE_INFO_VERSIONED(7, 12, Add),
// KERNEL_CREATE_INFO_VERSIONED(13, 13, Add),
// KERNEL_CREATE_INFO(14, Add),
// KERNEL_CREATE_INFO_VERSIONED(7, 12, Sub),
// KERNEL_CREATE_INFO_VERSIONED(13, 13, Sub),
// KERNEL_CREATE_INFO(14, Sub),
// KERNEL_CREATE_INFO_VERSIONED(7, 12, Mul),
// KERNEL_CREATE_INFO_VERSIONED(13, 13, Mul),
// KERNEL_CREATE_INFO(14, Mul),
// KERNEL_CREATE_INFO_VERSIONED(7, 12, Div),
// KERNEL_CREATE_INFO_VERSIONED(13, 13, Div),
// KERNEL_CREATE_INFO(14, Div),
KERNEL_CREATE_INFO_VERSIONED(7, 12, Add),
KERNEL_CREATE_INFO_VERSIONED(13, 13, Add),
KERNEL_CREATE_INFO(14, Add),
KERNEL_CREATE_INFO_VERSIONED(7, 12, Sub),
KERNEL_CREATE_INFO_VERSIONED(13, 13, Sub),
KERNEL_CREATE_INFO(14, Sub),
KERNEL_CREATE_INFO_VERSIONED(7, 12, Mul),
KERNEL_CREATE_INFO_VERSIONED(13, 13, Mul),
KERNEL_CREATE_INFO(14, Mul),
KERNEL_CREATE_INFO_VERSIONED(7, 12, Div),
KERNEL_CREATE_INFO_VERSIONED(13, 13, Div),
KERNEL_CREATE_INFO(14, Div),
// KERNEL_CREATE_INFO_VERSIONED(7, 11, Pow),
// KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow),
// KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow),
Expand Down
22 changes: 22 additions & 0 deletions onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,28 @@ TEST(MathOpTest, Add_Broadcast_3x2_3x1) {
#endif
}

// Add with B broadcast along the leading axis: A{2,2,2} + B{1,2,2} -> C{2,2,2}.
TEST(MathOpTest, Add_Broadcast_2x2x2_1x2x2) {
  OpTester test("Add");

  // Row-major data; B's {2,2} block repeats across A's outermost dimension.
  const std::vector<float> a_values = {101.0f, 102.0f, 103.0f, 104.0f,
                                       201.0f, 202.0f, 203.0f, 204.0f};
  const std::vector<float> b_values = {10.0f, 20.0f, 30.0f, 40.0f};

  // Expected output: each 4-element inner block of A gets B added elementwise.
  std::vector<float> expected(a_values.size());
  for (size_t i = 0; i < a_values.size(); ++i) {
    expected[i] = a_values[i] + b_values[i % b_values.size()];
  }

  test.AddInput<float>("A", {2, 2, 2}, a_values);
  test.AddInput<float>("B", {1, 2, 2}, b_values);
  test.AddOutput<float>("C", {2, 2, 2}, expected);

  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

TEST(MathOpTest, Add_Broadcast_2x1x4_1x3x1) {
OpTester test("Add");

Expand Down

0 comments on commit 264b9a2

Please sign in to comment.