forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
264 additions
and
14 deletions.
There are no files selected for viewing
170 changes: 170 additions & 0 deletions
170
onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "core/providers/common.h" | ||
#include "core/providers/webgpu/math/binary_elementwise_ops.h" | ||
#include "core/providers/webgpu/shader_helper.h" | ||
#include "core/providers/webgpu/webgpu_supported_types.h" | ||
|
||
namespace onnxruntime { | ||
namespace webgpu { | ||
Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { | ||
const auto& a = shader.AddInput("input_a", ShaderVariable::UseUniform); | ||
const auto& b = shader.AddInput("input_b", ShaderVariable::UseUniform); | ||
const auto& c = shader.AddOutput("output", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); | ||
std::string get_a_data = is_lhs_scalar_ ? "let a = output_value_t(" + a.GetByOffset("0") + ".x" + ");\n" : | ||
"let a = " + a.GetByOffset("global_idx") + ";\n"; | ||
std::string get_b_data = is_rhs_scalar_ ? "let b = output_value_t(" + b.GetByOffset("0") + ".x" + ");\n" : | ||
"let b = " + b.GetByOffset("global_idx") + ";\n"; | ||
if (!is_lhs_scalar_ && !is_rhs_scalar_ && is_broadcast_) { | ||
if (vectorize_) { | ||
std::string common = "let outputIndices = " + c.OffsetToIndices("global_idx * 4u") + ";\n" | ||
"let offset_a = " + a.BroadcastedIndicesToOffset("outputIndices", c) + ";\n" | ||
"let offset_b = " + b.BroadcastedIndicesToOffset("outputIndices", c) + ";\n"; | ||
const std::string a_data = shared_dimension_divisible_by_4_ || a_last_dim_divisible_by_4_ ? | ||
"let a = " + a.GetByOffset("offset_a / 4") + ";\n" : | ||
"let a = output_value_t(" + a.GetByOffset("offset_a") + ");\n"; | ||
get_a_data = common + a_data; | ||
get_b_data = shared_dimension_divisible_by_4_ || b_last_dim_divisible_by_4_ ? | ||
"let b = " + b.GetByOffset("offset_b / 4") + ";\n" : | ||
"let b = output_value_t(" + b.GetByOffset("offset_b") + ");\n"; | ||
} else { | ||
std::string common = "let outputIndices = " + c.OffsetToIndices("global_idx") + ";\n" | ||
"let offset_a = " + a.BroadcastedIndicesToOffset("outputIndices", c) + ";\n" | ||
"let offset_b = " + b.BroadcastedIndicesToOffset("outputIndices", c) + ";\n"; | ||
get_a_data = common + "let a = " + a.GetByOffset("offset_a") + ";\n"; | ||
get_b_data = "let b = " + b.GetByOffset("offset_b") + ";\n"; | ||
} | ||
} | ||
|
||
shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), | ||
get_a_data, get_b_data, | ||
c.SetByOffset("global_idx", expression_)); | ||
return Status::OK(); | ||
} | ||
|
||
Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { | ||
auto lhs_tensor = context.Input(0); | ||
auto rhs_tensor = context.Input(1); | ||
const auto& lhs_shape = lhs_tensor->Shape(); | ||
const auto& rhs_shape = rhs_tensor->Shape(); | ||
|
||
TensorShape output_shape; | ||
ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), lhs_shape, rhs_shape, output_shape)); | ||
auto output_tensor = context.Output(0, output_shape); | ||
int64_t size = output_shape.Size(); | ||
if (size == 0) { | ||
return Status::OK(); | ||
} | ||
|
||
bool is_broadcast = lhs_shape != rhs_shape; | ||
bool is_lhs_scalar = lhs_shape.IsScalar(); | ||
bool is_rhs_scalar = rhs_shape.IsScalar(); | ||
|
||
bool vectorize = is_lhs_scalar || is_rhs_scalar || !is_broadcast; | ||
bool a_last_dim_divisible_by_4 = false; | ||
bool b_last_dim_divisible_by_4 = false; | ||
bool shared_dimension_divisible_by_4 = false; | ||
if (!vectorize) { | ||
// check whether vectorize can be enabled | ||
a_last_dim_divisible_by_4 = lhs_shape.NumDimensions() > 0 && lhs_shape[lhs_shape.NumDimensions() - 1] % 4 == 0; | ||
b_last_dim_divisible_by_4 = rhs_shape.NumDimensions() > 0 && rhs_shape[rhs_shape.NumDimensions() - 1] % 4 == 0; | ||
if (a_last_dim_divisible_by_4 || b_last_dim_divisible_by_4) { | ||
vectorize = true; | ||
} else { | ||
size_t shared_dimension = 1; | ||
for (size_t i = 1; i < output_shape.NumDimensions(); i++) { | ||
size_t dimA = lhs_shape.NumDimensions() >= i ? lhs_shape[lhs_shape.NumDimensions() - i] : 1; | ||
size_t dimB = rhs_shape.NumDimensions() >= i ? rhs_shape[rhs_shape.NumDimensions() - i] : 1; | ||
if (dimA == dimB) { | ||
shared_dimension *= dimA; | ||
} else { | ||
break; | ||
} | ||
} | ||
if (shared_dimension % 4 == 0) { | ||
shared_dimension_divisible_by_4 = true; | ||
vectorize = true; | ||
} | ||
} | ||
} | ||
|
||
SafeInt<uint32_t> vec_size = vectorize ? (size + 3) / 4 : size; | ||
BinaryElementwiseProgram program{kernel_name_, expression_, is_broadcast, is_lhs_scalar, is_rhs_scalar, vectorize, | ||
a_last_dim_divisible_by_4, b_last_dim_divisible_by_4, shared_dimension_divisible_by_4}; | ||
program | ||
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) | ||
.AddUniformVariables({ | ||
{static_cast<uint32_t>(vec_size)}, | ||
}); | ||
if (is_lhs_scalar || is_rhs_scalar || !is_broadcast) { | ||
program | ||
.AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::Rank, {vec_size}, 4}, | ||
{rhs_tensor, ProgramTensorMetadataDependency::Rank, {vec_size}, 4}}) | ||
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, {vec_size}, 4}}); | ||
} else { | ||
program | ||
.AddInputs({{lhs_tensor, ProgramTensorMetadataDependency::Rank, shared_dimension_divisible_by_4 || a_last_dim_divisible_by_4 ? 4 : 1}, | ||
{rhs_tensor, ProgramTensorMetadataDependency::Rank, shared_dimension_divisible_by_4 || b_last_dim_divisible_by_4 ? 4 : 1}}) | ||
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, vectorize ? 4 : 1}}); | ||
} | ||
std::ostringstream ss; | ||
ss.imbue(std::locale::classic()); | ||
ss << "is_broadcast = " << is_broadcast << "|" | ||
<< "is_lhs_scalar = " << is_lhs_scalar << "|" | ||
<< "is_rhs_scalar = " << is_rhs_scalar << "|" | ||
<< "vectorize = " << vectorize << "|" | ||
<< "a_last_dim_divisible_by_4 = " << a_last_dim_divisible_by_4 << "|" | ||
<< "b_last_dim_divisible_by_4 = " << b_last_dim_divisible_by_4 << "|" | ||
<< "shared_dimension_divisible_by_4 = " << shared_dimension_divisible_by_4; | ||
program.CacheHint(ss.str()); | ||
|
||
return context.RunProgram(program); | ||
} | ||
|
||
#define WEBGPU_BINARY_IMPL(OP_TYPE, ...) \ | ||
class OP_TYPE final : public BinaryElementwise { \ | ||
public: \ | ||
OP_TYPE(const OpKernelInfo& info) : BinaryElementwise{info, #OP_TYPE, __VA_ARGS__} {} \ | ||
}; | ||
|
||
#define WEBGPU_BINARY_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ | ||
ONNX_OPERATOR_KERNEL_EX( \ | ||
OP_TYPE, \ | ||
kOnnxDomain, \ | ||
VERSION, \ | ||
kWebGpuExecutionProvider, \ | ||
KernelDefBuilder().TypeConstraint("T", TYPE), \ | ||
KERNEL_CLASS); | ||
|
||
#define WEBGPU_BINARY_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ | ||
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ | ||
OP_TYPE, \ | ||
kOnnxDomain, \ | ||
VERSION_FROM, VERSION_TO, \ | ||
kWebGpuExecutionProvider, \ | ||
KernelDefBuilder().TypeConstraint("T", TYPE), \ | ||
KERNEL_CLASS); | ||
|
||
WEBGPU_BINARY_IMPL(Add, "a + b") | ||
WEBGPU_BINARY_VERSIONED_KERNEL(Add, 7, 12, Add, WebGpuSupportedDataTypes()) | ||
WEBGPU_BINARY_VERSIONED_KERNEL(Add, 13, 13, Add, WebGpuSupportedDataTypes()) | ||
WEBGPU_BINARY_KERNEL(Add, 14, Add, WebGpuSupportedDataTypes()) | ||
|
||
WEBGPU_BINARY_IMPL(Div, "a / b") | ||
WEBGPU_BINARY_VERSIONED_KERNEL(Div, 7, 12, Div, WebGpuSupportedDataTypes()) | ||
WEBGPU_BINARY_VERSIONED_KERNEL(Div, 13, 13, Div, WebGpuSupportedDataTypes()) | ||
WEBGPU_BINARY_KERNEL(Div, 14, Div, WebGpuSupportedDataTypes()) | ||
|
||
WEBGPU_BINARY_IMPL(Mul, "a * b") | ||
WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 7, 12, Mul, WebGpuSupportedDataTypes()) | ||
WEBGPU_BINARY_VERSIONED_KERNEL(Mul, 13, 13, Mul, WebGpuSupportedDataTypes()) | ||
WEBGPU_BINARY_KERNEL(Mul, 14, Mul, WebGpuSupportedDataTypes()) | ||
|
||
WEBGPU_BINARY_IMPL(Sub, "a - b") | ||
WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 7, 12, Sub, WebGpuSupportedDataTypes()) | ||
WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 13, 13, Sub, WebGpuSupportedDataTypes()) | ||
WEBGPU_BINARY_KERNEL(Sub, 14, Sub, WebGpuSupportedDataTypes()) | ||
|
||
} // namespace webgpu | ||
} // namespace onnxruntime |
58 changes: 58 additions & 0 deletions
58
onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/providers/webgpu/webgpu_kernel.h" | ||
#include "core/providers/webgpu/shader_helper.h" | ||
#include "core/providers/webgpu/program.h" | ||
|
||
namespace onnxruntime { | ||
namespace webgpu { | ||
|
||
class BinaryElementwiseProgram final : public Program<BinaryElementwiseProgram> { | ||
public: | ||
BinaryElementwiseProgram(const std::string& kernel_name, const std::string& expression, const bool is_broadcast, const bool is_lhs_scalar, const bool is_rhs_scalar, const bool vectorize, const bool a_last_dim_divisible_by_4, const bool b_last_dim_divisible_by_4, const bool shared_dimension_divisible_by_4) | ||
: Program{kernel_name}, expression_{expression}, is_broadcast_{is_broadcast}, is_lhs_scalar_{is_lhs_scalar}, is_rhs_scalar_{is_rhs_scalar}, vectorize_{vectorize}, a_last_dim_divisible_by_4_{a_last_dim_divisible_by_4}, b_last_dim_divisible_by_4_{b_last_dim_divisible_by_4}, shared_dimension_divisible_by_4_{shared_dimension_divisible_by_4} {} | ||
|
||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); | ||
|
||
private: | ||
std::string expression_; | ||
bool is_broadcast_; | ||
bool is_lhs_scalar_; | ||
bool is_rhs_scalar_; | ||
bool vectorize_; | ||
bool a_last_dim_divisible_by_4_; | ||
bool b_last_dim_divisible_by_4_; | ||
bool shared_dimension_divisible_by_4_; | ||
}; | ||
|
||
class BinaryElementwise : public WebGpuKernel { | ||
public: | ||
BinaryElementwise(const OpKernelInfo& info, | ||
const std::string& kernel_name, | ||
const std::string& expression, | ||
const std::string& additional_impl = "", | ||
ShaderVariable::Usage usage = ShaderVariable::None) : WebGpuKernel{info}, | ||
kernel_name_{kernel_name}, | ||
expression_{expression}, | ||
additional_impl_{additional_impl}, | ||
additional_usage_{usage} {} | ||
|
||
protected: | ||
std::string cache_hint; | ||
|
||
Status ComputeInternal(ComputeContext& context) const final; | ||
|
||
private: | ||
std::string kernel_name_; | ||
std::string expression_; | ||
std::string additional_impl_; | ||
ShaderVariable::Usage additional_usage_; | ||
}; | ||
|
||
} // namespace webgpu | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters