Skip to content

Commit

Permalink
Add implementation of WebGPU EP (microsoft#22591)
Browse files Browse the repository at this point in the history
### Description

This PR adds the actual implementation of the WebGPU EP based on
microsoft#22318.

This change includes the following:

<details>
<summary><b>core framework of WebGPU EP</b></summary>

  - WebGPU EP factory classes for:
    - handling WebGPU options
    - creating WebGPU EP instance
    - creating WebGPU context
  - WebGPU Execution Provider classes
    - GPU Buffer allocator
    - data transfer
  - Buffer management classes
    - Buffer Manager
    - BufferCacheManager
      - DisabledCacheManager
      - SimpleCacheManager
      - LazyReleaseCacheManager
      - BucketCacheManager
  - Program classes
    - Program (base)
    - Program Cache Key
    - Program Manager
  - Shader helper classes
    - Shader Helper
    - ShaderIndicesHelper
    - ShaderVariableHelper
  - Utils
    - GPU Query based profiler
    - compute context
    - string utils
  - Miscs
    - Python binding webgpu support (basic)
 
</details>

<details>
<summary><b>Kernel implementation</b></summary>


  - onnx.ai (default opset):
- Elementwise (math): Abs, Neg, Floor, Ceil, Reciprocal, Sqrt, Exp, Erf,
Log, Sin, Cos, Tan, Asin, Acos, Atan, Sinh, Cosh, Asinh, Acosh, Atanh,
Tanh, Not, Cast
- Elementwise (activation): Sigmoid, HardSigmoid, Clip, Elu, Relu,
LeakyRelu, ThresholdedRelu, Gelu
- Binary (math): Add, Sub, Mul, Div, Pow, Equal, Greater,
GreaterOrEqual, Less, LessOrEqual
    - (Tensors): Shape, Reshape, Squeeze, Unsqueeze
    - Where
    - Transpose
    - Concat
    - Expand
    - Gather
    - Tile
    - Range
    - LayerNormalization
  - com.microsoft
    - FastGelu
    - MatMulNBits
    - MultiHeadAttention
    - RotaryEmbedding
    - SkipLayerNormalization
    - LayerNormalization
    - SimplifiedLayerNormalization
    - SkipSimplifiedLayerNormalization

</details>

<details>
<summary><b>Build, test and CI pipeline integration</b></summary>

  - build works for Windows, macOS and iOS
  - support onnxruntime_test_all and python node test
  - added a new unit test for `--use_external_dawn` build flag.
  - updated MacOS pipeline to build with WebGPU support
  - added a new pipeline for WebGPU Windows

</details>

This change does not include:

- Node.js binding support for WebGPU (will be a separate PR)
  • Loading branch information
fs-eire authored Oct 30, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 5cc7fb4 commit 7a8fa12
Showing 102 changed files with 10,241 additions and 92 deletions.
4 changes: 4 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -148,6 +148,7 @@ option(onnxruntime_TVM_USE_HASH "Build ipp-crypto library for support hash algor
option(onnxruntime_USE_XNNPACK "Build with XNNPACK support. Provides an alternative math library on ARM, WebAssembly and x86." OFF)
option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware acceleration in web browsers." OFF)
option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C++ interface." OFF)
option(onnxruntime_USE_EXTERNAL_DAWN "Build with treating Dawn as external dependency. Will not link Dawn at build time." OFF)

# Options related to reducing the binary size produced by the build
# XNNPACK EP requires the internal NHWC contrib ops to be available, so this option must be OFF when onnxruntime_USE_XNNPACK is ON
@@ -958,6 +959,9 @@ if (onnxruntime_USE_WEBGPU)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBGPU=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBGPU=1)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES webgpu)
if (onnxruntime_USE_EXTERNAL_DAWN)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_EXTERNAL_DAWN=1)
endif()
endif()
if (onnxruntime_USE_CANN)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1)
7 changes: 6 additions & 1 deletion cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
@@ -656,11 +656,16 @@ if (onnxruntime_USE_WEBGPU)

# Vulkan may optionally be included in a Windows build. Exclude until we have an explicit use case that requires it.
set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE)
# We are currently always using the D3D12 backend.
set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE)
endif()

onnxruntime_fetchcontent_makeavailable(dawn)

list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native dawn::dawn_proc)
if (NOT onnxruntime_USE_EXTERNAL_DAWN)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native)
endif()
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc)
endif()

set(onnxruntime_LINK_DIRS)
5 changes: 4 additions & 1 deletion cmake/onnxruntime_providers_webgpu.cmake
Original file line number Diff line number Diff line change
@@ -22,6 +22,9 @@
onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_webgpu
onnxruntime_common dawn::dawncpp_headers dawn::dawn_headers onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface)
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native dawn::dawn_proc)
if (NOT onnxruntime_USE_EXTERNAL_DAWN)
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native)
endif()
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_proc)

set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime")
12 changes: 12 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
@@ -523,6 +523,9 @@ set (onnxruntime_global_thread_pools_test_SRC
${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_main.cc
${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_inference.cc)

set (onnxruntime_webgpu_external_dawn_test_SRC
${TEST_SRC_DIR}/webgpu/external_dawn/main.cc)

# tests from lowest level library up.
# the order of libraries should be maintained, with higher libraries being added first in the list

@@ -1884,4 +1887,13 @@ if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD
endif()
endif()

if (onnxruntime_USE_WEBGPU AND onnxruntime_USE_EXTERNAL_DAWN)
AddTest(TARGET onnxruntime_webgpu_external_dawn_test
SOURCES ${onnxruntime_webgpu_external_dawn_test_SRC}
LIBS dawn::dawn_native ${onnxruntime_test_providers_libs}
DEPENDS ${all_dependencies}
)
onnxruntime_add_include_to_target(onnxruntime_webgpu_external_dawn_test dawn::dawncpp_headers dawn::dawn_headers)
endif()

include(onnxruntime_fuzz_test.cmake)
45 changes: 30 additions & 15 deletions cmake/patches/dawn/dawn.patch
Original file line number Diff line number Diff line change
@@ -15,40 +15,55 @@ index 9c0bd6fa4e..bf8a57aeac 100644
###############################################################################
# Do the 'complete_lib' build.
diff --git a/src/dawn/native/Surface_metal.mm b/src/dawn/native/Surface_metal.mm
index ce55acbd43..baa4835362 100644
index ce55acbd43..2cfd363479 100644
--- a/src/dawn/native/Surface_metal.mm
+++ b/src/dawn/native/Surface_metal.mm
@@ -36,7 +36,13 @@
@@ -33,10 +33,18 @@

#import <QuartzCore/CAMetalLayer.h>

+#include "dawn/common/Platform.h"
+
namespace dawn::native {

bool InheritsFromCAMetalLayer(void* obj) {
- id<NSObject> object = static_cast<id>(obj);
+ id<NSObject> object =
+#if TARGET_OS_IOS
+#if DAWN_PLATFORM_IS(IOS)
+ (__bridge id)obj;
+#else
+#else // DAWN_PLATFORM_IS(IOS)
+ static_cast<id>(obj);
+#endif
+#endif // DAWN_PLATFORM_IS(IOS)
+
return [object isKindOfClass:[CAMetalLayer class]];
}

diff --git a/src/dawn/native/metal/SharedFenceMTL.mm b/src/dawn/native/metal/SharedFenceMTL.mm
index bde8bfea07..f2f6459e91 100644
index bde8bfea07..8906185d6f 100644
--- a/src/dawn/native/metal/SharedFenceMTL.mm
+++ b/src/dawn/native/metal/SharedFenceMTL.mm
@@ -40,7 +40,13 @@ ResultOrError<Ref<SharedFence>> SharedFence::Create(
@@ -25,6 +25,8 @@
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+#include "dawn/common/Platform.h"
+
#include "dawn/native/metal/SharedFenceMTL.h"

#include "dawn/native/ChainUtils.h"
@@ -39,8 +41,13 @@ ResultOrError<Ref<SharedFence>> SharedFence::Create(
const SharedFenceMTLSharedEventDescriptor* descriptor) {
DAWN_INVALID_IF(descriptor->sharedEvent == nullptr, "MTLSharedEvent is missing.");
if (@available(macOS 10.14, iOS 12.0, *)) {
return AcquireRef(new SharedFence(
- return AcquireRef(new SharedFence(
- device, label, static_cast<id<MTLSharedEvent>>(descriptor->sharedEvent)));
+ device, label,
+#if TARGET_OS_IOS
+ (__bridge id<MTLSharedEvent>)(descriptor->sharedEvent)
+#else
+ static_cast<id<MTLSharedEvent>>(descriptor->sharedEvent)
+#endif
+ ));
+ return AcquireRef(new SharedFence(device, label,
+#if DAWN_PLATFORM_IS(IOS)
+ (__bridge id<MTLSharedEvent>)(descriptor->sharedEvent)
+#else // DAWN_PLATFORM_IS(IOS)
+ static_cast<id<MTLSharedEvent>>(descriptor->sharedEvent)
+#endif // DAWN_PLATFORM_IS(IOS)
+ ));
} else {
return DAWN_INTERNAL_ERROR("MTLSharedEvent not supported.");
}
84 changes: 84 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/math/unary_elementwise_ops.h"
#include "contrib_ops/webgpu/bert/fast_gelu.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

ONNX_OPERATOR_KERNEL_EX(
FastGelu,
kMSDomain,
1,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
FastGelu);

Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& y = shader.AddOutput("y", ShaderUsage::UseUniform);

shader.AdditionalImplementation() << TanhImpl;
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size")
<< " var a = " << x.GetByOffset("global_idx") << ";\n";
if (Inputs().size() > 1) {
const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride);
if (bias_components_ == 1) {
shader.MainFunctionBody() << " let bias_offset = global_idx * 4;\n"
" a += x_value_t("
<< bias.GetByOffset("bias_offset % uniforms.bias_shape") << ", "
<< bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") << ", "
<< bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") << ", "
<< bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") << ");\n";
} else {
shader.MainFunctionBody() << " a += " << bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n";
}
}
shader.MainFunctionBody() << y.SetByOffset("global_idx", onnxruntime::webgpu::FastGeluExpr);

return Status::OK();
}

Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const auto* input = context.Input(0);
const auto* bias = context.Input(1);
auto* output = context.Output(0, input->Shape());

uint32_t data_size = gsl::narrow<uint32_t>(output->Shape().Size());
if (data_size == 0) {
return Status::OK();
}

const auto vec_size = (data_size + 3) / 4;
uint32_t bias_size = 0;
int bias_components = 1;

if (bias != nullptr) {
bias_size = gsl::narrow<uint32_t>(bias->Shape().Size());
if (bias_size % 4 == 0) {
bias_components = 4;
bias_size = bias_size / 4;
}
}

FastGeluProgram program{bias_components};
program.AddInput({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
.AddOutput({output, ProgramTensorMetadataDependency::None, {vec_size}, 4})
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariable({vec_size});

if (bias != nullptr) {
program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components});
}
return context.RunProgram(program);
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
38 changes: 38 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;
using onnxruntime::webgpu::ComputeContext;

class FastGeluProgram final : public Program<FastGeluProgram> {
public:
FastGeluProgram(int bias_components) : Program{"FastGelu"}, bias_components_{bias_components} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});

private:
int bias_components_;
};

class FastGelu final : public WebGpuKernel {
public:
FastGelu(const OpKernelInfo& info) : WebGpuKernel(info) {}

Status ComputeInternal(ComputeContext& context) const override;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
36 changes: 36 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/nn/layer_norm.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;
using onnxruntime::webgpu::ComputeContext;

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
LayerNormalization,
kOnnxDomain,
1,
16,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
onnxruntime::webgpu::LayerNorm<false>);

ONNX_OPERATOR_KERNEL_EX(
SimplifiedLayerNormalization,
kOnnxDomain,
1,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
onnxruntime::webgpu::LayerNorm<true>);

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
Loading

0 comments on commit 7a8fa12

Please sign in to comment.