From 0f7a5f6077f0885aa32b0ede324023419badb3c2 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 27 Sep 2024 13:49:09 +0800 Subject: [PATCH 1/5] [webgpu native] Add RotaryEmbedding op (#22194) ### Description ### Motivation and Context --- .../webgpu/bert/rotary_embedding.cc | 134 ++++++++++++++++++ .../webgpu/bert/rotary_embedding.h | 47 ++++++ .../webgpu/webgpu_contrib_kernels.cc | 2 +- .../contrib_ops/rotary_embedding_op_test.cc | 4 + 4 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc new file mode 100644 index 0000000000000..eb5cfad87597f --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/rotary_embedding.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + RotaryEmbedding, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + RotaryEmbedding); + +Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); + const auto& position_ids = shader.AddInput("position_ids", ShaderUsage::UseUniform); + const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform); + const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + // TODO: remove output_indices. + const auto& output_indices = shader.AddIndices("output_indices", false); + const auto interleaved_str = interleaved_ ? 
"true" : "false"; + shader.SetMainFunctionBody( + " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n" + " let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n" + " let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n", + " if (global_idx >= size) { return; }\n" + " if (bsnh[3] < half_rotary_emb_dim) {\n" + " let position_ids_idx = " + + position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) + ";\n" + + " let position_id = u32(" + + position_ids.GetByOffset("position_ids_idx") + ")" + + " + select(0, bsnh[1], position_ids_idx == 0);\n" + " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " + + interleaved_str + + ");\n" + " let j = i + select(half_rotary_emb_dim, 1, " + + interleaved_str + + ");\n" + " let re = " + + input.GetByOffset("i") + " * " + cos_cache.GetByIndices("vec2(position_id, bsnh[3])") + "-" + + input.GetByOffset("j") + " * " + sin_cache.GetByIndices("vec2(position_id, bsnh[3])") + ";\n" + + " " + output.SetByOffset("i", "re") + "\n" + + " let im = " + input.GetByOffset("i") + " * " + + sin_cache.GetByIndices("vec2(position_id, bsnh[3])") + + "+ " + input.GetByOffset("j") + + " * " + cos_cache.GetByIndices("vec2(position_id, bsnh[3])") + + ";\n " + output.SetByOffset("j", "im") + + "\n" + " } else { \n" + " let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n" + + " " + output.SetByOffset("k", input.GetByOffset("k")) + + "\n" + " }"); + + return Status::OK(); +} + +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : WebGpuKernel(info) { + scale_ = info.GetAttrOrDefault("scale", 1.0); + rotary_embedding_dim_ = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads_ = static_cast(info.GetAttrOrDefault("num_heads", 0)); + interleaved_ = (info.GetAttrOrDefault("interleaved", 0) == 1); + is_packed_batching_ = (info.GetAttrOrDefault("is_packed_batching", 0) == 1); +} + +Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto input_shape = input->Shape(); + const auto* position_ids = context.Input(1); + const auto* cos_cache = context.Input(2); + const auto* sin_cache = context.Input(3); + auto* output = context.Output(0, input_shape); + + const auto batch_size = gsl::narrow_cast(input->Shape()[0]); + const auto batch_stride = gsl::narrow_cast(input_shape.SizeFromDimension(1)); + const auto sequence_length = gsl::narrow_cast(input_shape[input_shape.NumDimensions() - 2]); + const auto hidden_size = batch_stride / sequence_length; + const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); + const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_; + + // Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape + // [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy] + // to unfold the global index in shader. 
+ const TensorShape global_shape({batch_size, + sequence_length, + hidden_size / head_size, + head_size - half_rotary_embedding_dim}); + + const auto rank = global_shape.NumDimensions(); + std::vector global_dims(rank); + std::vector global_strides(rank); + for (size_t j = 0; j < rank; ++j) { + global_dims[j] = gsl::narrow_cast(global_shape[j]); + global_strides[j] = gsl::narrow_cast(global_shape.SizeFromDimension(j + 1)); + } + + const auto output_size = gsl::narrow_cast(global_shape.Size()); + RotaryEmbeddingProgram program{interleaved_}; + const auto input_output_strides = + input_shape.NumDimensions() == 3 + ? std::vector({batch_stride, hidden_size, head_size, 1}) + : (input_shape.NumDimensions() == 4 + ? std::vector({batch_stride, head_size, sequence_length * head_size, 1}) + : std::vector({})); + + program + .CacheHint(interleaved_) + .AddInputs({{input, ProgramTensorMetadataDependency::Rank}, + {position_ids, ProgramTensorMetadataDependency::Rank}, + {cos_cache, ProgramTensorMetadataDependency::Rank}, + {sin_cache, ProgramTensorMetadataDependency::Rank}}) + .AddOutput({output, ProgramTensorMetadataDependency::None}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{scale_}, + {gsl::make_span(global_dims)}, + {gsl::make_span(global_strides)}, + {gsl::make_span(input_output_strides)}}) + .AddIndices(TensorShape{1, 1}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h new file mode 100644 index 0000000000000..0d73b89fb62df --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
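+
+// WebGPU implementation of the com.microsoft RotaryEmbedding contrib op. The kernel rotates pairs of
+// elements within each attention head by the cos/sin values looked up from the caches at the token's
+// position id; elements beyond the rotary embedding dimension are copied through unchanged.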
+ +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class RotaryEmbeddingProgram final : public Program { + public: + RotaryEmbeddingProgram(bool interleaved) : Program{"RotaryEmbedding"}, interleaved_{interleaved} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"scale", ProgramUniformVariableDataType::Float32}, + {"global_shape", ProgramUniformVariableDataType::Uint32}, + {"global_stride", ProgramUniformVariableDataType::Uint32}, + {"input_output_stride", ProgramUniformVariableDataType::Uint32}); + + private: + const bool interleaved_; +}; + +class RotaryEmbedding final : public WebGpuKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(ComputeContext& context) const override; + + private: + float scale_; + int num_heads_; + int rotary_embedding_dim_; + bool interleaved_; + bool is_packed_batching_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index def104b6cb108..01c8a28d45069 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -47,7 +47,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo Date: Fri, 27 Sep 2024 14:57:21 +0800 Subject: [PATCH 2/5] [webgpu native] Add transpose shared (#22098) ### Description ### Motivation and Context --- .../core/providers/webgpu/tensor/transpose.cc | 91 +++++++++++++++---- .../core/providers/webgpu/tensor/transpose.h | 24 ++--- 2 files changed, 87 insertions(+), 28 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 0962d9191d785..e0a0113e13224 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -47,11 +47,11 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", WebGpuSupportedNumberTypes()), Transpose); -const std::string AppendPermFunction(gsl::span perm) { +const std::string AppendPermFunction(gsl::span perm) { std::ostringstream ss; ss.imbue(std::locale::classic()); - ss << "fn perm(i: y_indices_t)->x_indices_t {\n" - " var a: x_indices_t;\n"; + ss << "fn perm(i: output_indices_t)->a_indices_t {\n" + " var a: a_indices_t;\n"; for (size_t i = 0; i < perm.size(); ++i) { ss << " a[" << perm[i] << "] = i[" << i << "];\n"; } @@ -60,21 +60,52 @@ const std::string AppendPermFunction(gsl::span perm) { return ss.str(); } +auto SqueezeShape(const gsl::span& shape, const gsl::span& adjusted_perm, InlinedVector& new_shape, InlinedVector& new_perm) { + for (auto i = 0; i < shape.size(); ++i) { + if (shape[i] != 1) { + new_shape.push_back(shape[i]); + } + if (shape[adjusted_perm[i]] != 1) { + new_perm.push_back(adjusted_perm[i]); + } + } +}; + Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - const auto& output = 
shader.AddOutput("y", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - shader.AppendImplementation(AppendPermFunction(this->perm_)); - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), - " let indices = ", output.OffsetToIndices("global_idx"), - ";\n" - " let x_indices = perm(indices); \n" - " ", - output.SetByOffset("global_idx", input.GetByIndices("x_indices"))); + const auto& input = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + + if (use_shared_) { + shader.AppendImplementation("var tile : array, tile_size>;\n"); + shader.SetMainFunctionBody( + " let stride = (uniforms.output_shape[1] - 1) / tile_size + 1;\n" + " let workgroup_id_x = workgroup_idx % stride;\n" + " let workgroup_id_y = workgroup_idx / stride;\n" + " let input_col = workgroup_id_y * tile_size + local_id.x;\n" + " let input_row = workgroup_id_x * tile_size + local_id.y;\n" + " if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n" + " tile[local_id.y][local_id.x] = " + + input.GetByIndices("a_indices_t(input_row, input_col)") + + ";\n" + " }\n" + " workgroupBarrier();\n" + " let output_col = workgroup_id_x * tile_size + local_id.x;\n" + " let output_row = workgroup_id_y * tile_size + local_id.y;\n" + " if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n " + + output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") + "\n }"); + } else { + shader.AppendImplementation(AppendPermFunction(this->perm_)); + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), + " let indices = ", output.OffsetToIndices("global_idx"), + ";\n" + " let x_indices = perm(indices);\n", + " ", + output.SetByOffset("global_idx", input.GetByIndices("x_indices"))); + } return Status::OK(); } Status Transpose::ComputeInternal(ComputeContext& context) const { - // TODO: there is an optimized version of transpose to port. const auto* input_tensor = context.Input(0); const TensorShape& input_shape = input_tensor->Shape(); int32_t rank = gsl::narrow_cast(input_shape.NumDimensions()); @@ -86,16 +117,42 @@ Status Transpose::ComputeInternal(ComputeContext& context) const { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); + InlinedVector new_shape{}; + InlinedVector new_perm{}; + SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm); + const bool channels_last = new_perm == InlinedVector({2, 3, 1}); + const bool channels_first = new_perm == InlinedVector({3, 1, 2}); + const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first; + auto new_input_shape = input_shape; + TensorShape new_output_shape(output_dims); + if (use_shared) { + new_input_shape = channels_last + ? TensorShape({new_shape[0], new_shape[1] * new_shape[2]}) + : channels_first + ? 
TensorShape({new_shape[0] * new_shape[1], new_shape[2]}) + : new_shape; + new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]}); + } + uint32_t output_size = gsl::narrow_cast(input_tensor->Shape().Size()); - TransposeProgram program{*p_perm}; + TransposeProgram program{*p_perm, use_shared}; + if (use_shared) { + program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1); + } + program .CacheHint(absl::StrJoin(*p_perm, "-")) - .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutputs({output_tensor}) - .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, new_output_shape, 1}}) + .SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), + static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) .AddUniformVariables({ {static_cast(output_size)}, }); + + use_shared ? program.SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), + static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) + : program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); return context.RunProgram(program); } diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h index 3ca5674d5dfab..7cf5c1fe0865d 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.h +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -11,26 +11,28 @@ namespace onnxruntime { namespace webgpu { +class Transpose final : public WebGpuKernel, public TransposeBase { + public: + Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { + } + Status ComputeInternal(ComputeContext& context) const override; + constexpr static uint32_t TILE_SIZE = 16; +}; + class TransposeProgram final : public Program { public: - TransposeProgram(const gsl::span& permutations) - : Program{"Transpose"}, perm_(permutations.begin(), permutations.end()) { + TransposeProgram(const gsl::span& permutations, bool use_shared) + : Program{"Transpose"}, perm_(permutations.begin(), permutations.end()), use_shared_(use_shared) { } Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_CONSTANTS({"tile_size", Transpose::TILE_SIZE}); private: - InlinedVector perm_; -}; - -class Transpose final : public WebGpuKernel, public TransposeBase { - public: - Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { - } - - Status ComputeInternal(ComputeContext& context) const override; + InlinedVector perm_; + const bool use_shared_; }; } // namespace webgpu From b1b5e1fd1a718c1dcae3a96233a32383c26a1db5 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 27 Sep 2024 15:19:28 +0800 Subject: [PATCH 3/5] [webgpu-native] Add gather (#22183) ### Description ### Motivation and Context --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../core/providers/webgpu/tensor/gather.cc | 82 +++++++++++++++++++ .../core/providers/webgpu/tensor/gather.h | 34 ++++++++ .../webgpu/webgpu_execution_provider.cc | 6 +- 3 files changed, 119 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/tensor/gather.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/gather.h 
diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc new file mode 100644 index 0000000000000..31e0a9e883239 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/gather.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& data = shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& indices = shader.AddInput("input_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + + std::ostringstream calc_data_indices; + calc_data_indices.imbue(std::locale::classic()); + calc_data_indices << " var indices_indices = input_indices_indices_t(0);\n"; + for (int i = 0; i < indices.Rank(); i++) { + calc_data_indices << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", axis_ + i)) << ";\n"; + } + calc_data_indices << " var idx = " << indices.GetByIndices("indices_indices") << ";\n" + << " if (idx < 0) {\n" + << " idx = idx + input_indices_value_t(uniforms.data_shape[" << axis_ << "]);\n" + << " }\n" + << " var data_indices : data_indices_t;\n"; + for (int i = 0, j = 0; i < data.Rank(); i++) { + if (i == SafeInt(axis_)) { + calc_data_indices << " " << data.IndicesSet("data_indices", i, "u32(idx)") << ";\n"; + j += indices.Rank(); + } else { + calc_data_indices << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n"; + j++; + } + } + + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), + " let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", + calc_data_indices.str(), " ", + output.SetByOffset("global_idx", data.GetByIndices("data_indices"))); + + return Status::OK(); +} + +Status Gather::ComputeInternal(ComputeContext& context) const { + Prepare p; + ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), p)); + uint32_t data_size = SafeInt(p.output_tensor->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + uint32_t axis = static_cast(p.axis); + GatherProgram program{axis}; + program + .AddInputs({{p.input_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {p.indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({p.output_tensor, ProgramTensorMetadataDependency::Rank}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .CacheHint(std::to_string(axis)) + .AddUniformVariables({{data_size}}); + return context.RunProgram(program); +} + +#define WEBGPU_GATHER_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), \ + KERNEL_CLASS); + +#define WEBGPU_GATHER_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", 
BuildKernelDefConstraintsFromTypeList>()), \ + KERNEL_CLASS); + +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 1, 10, Gather, WebGpuSupportedNumberTypes()) +WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 11, 12, Gather, WebGpuSupportedNumberTypes()) +WEBGPU_GATHER_KERNEL(Gather, 13, Gather, WebGpuSupportedNumberTypes()) + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.h b/onnxruntime/core/providers/webgpu/tensor/gather.h new file mode 100644 index 0000000000000..bebe13519ce43 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/tensor/gatherbase.h" + +namespace onnxruntime { +namespace webgpu { + +class GatherProgram final : public Program { + public: + GatherProgram(const uint32_t axis) : Program{"Gather"}, axis_{axis} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t axis_; +}; + +class Gather final : public WebGpuKernel, public GatherBase { + public: + Gather(const OpKernelInfo& info) : WebGpuKernel(info), GatherBase(info) {} + + protected: + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index f5d66d6a24134..df2a2caa0a1fa 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -641,9 +641,9 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, From 92a08e2d13f3498fd8eeb8ab572f096ceb2c86e7 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Fri, 27 Sep 2024 02:55:55 -0700 Subject: [PATCH 4/5] [Native-WebGPU] Add Concat (#22225) ### Description Add Concat operator support ### Motivation and Context Required for WebGPU EP --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../core/providers/webgpu/compute_context.h | 4 + .../core/providers/webgpu/tensor/concat.cc | 155 ++++++++++++++++++ .../core/providers/webgpu/tensor/concat.h | 36 ++++ .../webgpu/webgpu_execution_provider.cc | 8 +- 4 files changed, 199 insertions(+), 4 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/tensor/concat.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/concat.h diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index b7ea8a58e232b..455eb4452f858 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -113,6 +113,10 @@ class ComputeContext { return webgpu_context_.Run(*this, program); } + inline OpKernelContext& GetKernelContext() { + return kernel_context_; + } + // // Push error scope. 
// diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc new file mode 100644 index 0000000000000..671a6a1ed072c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/webgpu/tensor/concat.h" + +#include "core/common/inlined_containers.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +#define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + start, \ + end, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + Concat); + +#define WEBGPU_CONCAT_KERNEL(version) \ + ONNX_OPERATOR_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + version, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + Concat); + +WEBGPU_CONCAT_VERSIONED_KERNEL(1, 3) +WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) +WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) +WEBGPU_CONCAT_KERNEL(13) + +const std::string AppendCalCulateInputIndexFunction(size_t input_count) { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + ss << "fn calculate_input_index(index: u32) -> u32 {" << std::endl + << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {" << std::endl + << " if (index < uniforms.size_in_concat_axis[i]) {" << std::endl + << " return i;" << std::endl + << " }" << std::endl + << " }" << std::endl + << " return " << input_count << ";" << std::endl + << "}" << std::endl; + return ss.str(); +} + +const void AppendAssignOutput(std::ostringstream& ss, const ShaderVariableHelper& input, const ShaderVariableHelper& output) { + ss << output.SetByOffset("global_idx", input.GetByIndices("indices")) << ";" << std::endl; +} + +const std::string AppendAssignOutputDataFunction(gsl::span inputs, const ShaderVariableHelper& output) { + std::ostringstream ss; + size_t input_count = inputs.size(); + ss.imbue(std::locale::classic()); + ss << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {" << std::endl; + if (input_count == 0) { + AppendAssignOutput(ss, *inputs[0], output); + } else { + for (size_t i = 0; i < input_count; ++i) { + if (i == 0) { + ss << " if (input_index == 0u) {" << std::endl; + } else if (i == input_count - 1) { + ss << " } else {" << std::endl; + } else { + ss << " } else if (input_index == " << i << "u) {" << std::endl; + } + ss << " "; + AppendAssignOutput(ss, *inputs[i], output); + } + ss << " }" << std::endl; + } + ss << "}" << std::endl; + return ss.str(); +} +Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { + size_t input_count = Inputs().size(); + std::vector inputs; + inputs.reserve(input_count); + for (size_t i = 0; i < input_count; ++i) { + inputs.push_back(&shader.AddInput("input_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias)); + } + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + shader.AppendImplementation(AppendCalCulateInputIndexFunction(input_count)); + 
shader.AppendImplementation(AppendAssignOutputDataFunction(gsl::make_span(inputs), output)); + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), + " var indices = ", output.OffsetToIndices("global_idx"), ";\n", + " let indices_axis = ", output.IndicesGet("indices", axis_), ";\n", + " let input_index = calculate_input_index(indices_axis);\n", + " if (input_index != 0u) {\n", + " ", output.IndicesSet("indices", axis_, "indices_axis - uniforms.size_in_concat_axis[input_index - 1]"), ";\n", + " }\n", + " assign_output_data(global_idx, input_index, indices);\n"); + return Status::OK(); +} + +Status Concat::ComputeInternal(ComputeContext& context) const { + int input_count = context.InputCount(); + InlinedTensorsVector input_tensors; + input_tensors.reserve(input_count); + for (int i = 0; i < input_count; ++i) { + input_tensors.push_back(context.Input(i)); + } + + Prepare prepare; + ORT_RETURN_IF_ERROR(PrepareForCompute(&context.GetKernelContext(), input_tensors, prepare)); + if (prepare.output_num_elements == 0) { + return Status::OK(); + } + + uint32_t output_size = gsl::narrow_cast(prepare.output_tensor->Shape().Size()); + + ConcatProgram program{prepare.axis}; + + std::vector sizes_in_concat_axis; + sizes_in_concat_axis.reserve(input_count); + uint32_t sum = 0; + for (int i = 0; i < input_count; ++i) { + const auto& input = prepare.inputs[i]; + if (input.tensor->Shape().Size() == 0) { + continue; + } + program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); + + auto axis_size = input.tensor->Shape()[prepare.axis]; + sum += static_cast(axis_size); + sizes_in_concat_axis.push_back(sum); + } + + size_t non_empty_input_count = sizes_in_concat_axis.size(); + + if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) { + // TODO: support when input_count + 1 > maxStorageBuffersPerShaderStage, by raising the limit or run the program in multiple passes. + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The number of storage buffer (input=", + input_count, ", output=1) exceeds the limit (", + context.DeviceLimits().maxStorageBuffersPerShaderStage, ") of the device."); + } + + program.CacheHint(absl::StrJoin(std::make_tuple(non_empty_input_count, prepare.axis), ",")) + .AddOutputs({prepare.output_tensor}) + .SetDispatchGroupSize((prepare.output_num_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), + output_size}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.h b/onnxruntime/core/providers/webgpu/tensor/concat.h new file mode 100644 index 0000000000000..0f6e6dd327e33 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/concat.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
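+
+// WebGPU implementation of the ONNX Concat op. For every output element the shader determines which
+// (non-empty) input it falls into along the concatenation axis and copies the value from that input.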
+ +#pragma once +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/concatbase.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class ConcatProgram final : public Program { + public: + ConcatProgram(size_t axis) : Program{"Concat"}, axis_{axis} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"size_in_concat_axis", ProgramUniformVariableDataType::Uint32}, + {"output_size", ProgramUniformVariableDataType::Uint32}); + + private: + size_t axis_; +}; + +class Concat final : public WebGpuKernel, public ConcatBase { + public: + Concat(const OpKernelInfo& info) : WebGpuKernel(info), ConcatBase(info) { + } + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index df2a2caa0a1fa..c1f13d652413d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -628,10 +628,10 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, From 8da1f7a157b554261da9482c3d7ffae2ae0617f6 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 27 Sep 2024 18:11:27 +0800 Subject: [PATCH 5/5] [webgpu-native] Add MatmulNBits (#22150) ### Description ### Motivation and Context --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../webgpu/quantization/matmul_nbits.cc | 294 ++++++++++++++++++ .../webgpu/quantization/matmul_nbits.h | 53 ++++ .../webgpu/webgpu_contrib_kernels.cc | 2 +- onnxruntime/core/providers/webgpu/program.cc | 90 ++++-- onnxruntime/core/providers/webgpu/program.h | 21 +- .../core/providers/webgpu/shader_helper.cc | 24 +- .../core/providers/webgpu/shader_variable.cc | 103 +++--- .../test/contrib_ops/matmul_4bits_test.cc | 16 +- 8 files changed, 501 insertions(+), 102 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..b1f1a3a9ad8d0 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -0,0 +1,294 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/quantization/matmul_nbits.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +namespace { +// Put it to a common place? 
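+// Returns the widest vector width (4, 2 or 1) that evenly divides `size`, so tensors can be accessed
+// as vec4/vec2 where possible.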
+uint32_t GetMaxComponents(uint32_t size) { + // we cannot use vec3 type since it has alignment of 16 bytes + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + + return 1; +} + +std::string QuantizedDataType(int components) { + switch (components) { + case 1: + return "array"; + case 2: + return "mat4x2"; + case 4: + return "mat2x4"; + default: + return "array"; + } +} + +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", WebGpuSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + MatMulNBits); + +Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); + const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); + + const std::string quantized_data_type = QuantizedDataType(a.NumComponents()); + const int output_element_number = y.NumComponents() * SafeInt(output_number_); + std::ostringstream prepare_scale_and_zero_point; + prepare_scale_and_zero_point.imbue(std::locale::classic()); + prepare_scale_and_zero_point << " var col_index = col * " << y.NumComponents() << ";\n"; + if (has_zero_points_) { + const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); + prepare_scale_and_zero_point << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" + << " var zero_point_byte_count: u32;\n" + << " var zero_point_word_index: u32;\n" + << " var zero_point_byte_offset: u32;\n" + << " let zero_point_nibble_offset: u32 = block & 0x1u;\n" + << " var zero_point_bits_offset: u32;\n" + << " var zero_point_word: u32;\n"; + for (int c = 0; c < output_element_number; c++) { + prepare_scale_and_zero_point << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n"; + prepare_scale_and_zero_point << " zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);\n" + << " zero_point_word_index = zero_point_byte_count >> 0x2u;\n" + << " zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" + << " zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" + << " zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" + << " let zero_point" << c << " = output_element_t((zero_point_word) & 0xFu);\n"; + prepare_scale_and_zero_point << " col_index += 1;\n"; + } + } else { + prepare_scale_and_zero_point << " let zero_point = output_element_t(8.0);\n"; + for (int c = 0; c < output_element_number; c++) { + prepare_scale_and_zero_point << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n"; + prepare_scale_and_zero_point << " col_index += 1;\n"; + } + } + + std::ostringstream prepare_b_data; + prepare_b_data.imbue(std::locale::classic()); + prepare_b_data << " col_index = col * " << y.NumComponents() << ";\n"; + for (int c = 0; c < output_element_number; c++) { + prepare_b_data 
<< " let b" << c << "_data = " << b.GetByIndices("input_b_indices_t(col_index, block, word)") << ";\n" + << " col_index += 1;\n"; + } + prepare_b_data << " var b_value : u32;\n" + << " let b_mask : u32 = 0x0F0F0F0Fu;\n" + << " var b_value_lower : vec4;\n" + << " var b_value_upper : vec4;\n" + << " var b_quantized_values : " << quantized_data_type << ";\n" + << " var b_dequantized_values : " << quantized_data_type << ";\n"; + + std::ostringstream process_one_word; + process_one_word.imbue(std::locale::classic()); + process_one_word << " var input_offset = " << a.IndicesToOffset("input_a_indices_t(batch, row, word_offset)") << ";\n" + << " var a_data: " << quantized_data_type << ";\n" + << " for (var j: u32 = 0; j < " << (8 / a.NumComponents()) << "; j++) {\n" + << " if (word_offset + j < uniforms.input_a_shape[2]) {\n" + << " a_data[j] = " << a.GetByOffset("input_offset") << ";\n" + << " input_offset++;\n" + << " } else {\n" + << " a_data[j] = input_a_value_t(0);\n" + << " }\n" + << " }\n"; + for (int c = 0; c < output_element_number; c++) { + process_one_word << " b_value = " << "b" << c << "_data"; + if (components_b_ > 1) { + process_one_word << "[i]"; + } + process_one_word << ";\n" + << " b_value_lower = unpack4xU8(b_value & b_mask);\n" + << " b_value_upper = unpack4xU8((b_value >> 4) & b_mask);\n" + << " b_quantized_values = " << quantized_data_type << "(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" + << " b_dequantized_values = "; + if (a.NumComponents() == 1) { + if (has_zero_points_) { + process_one_word << quantized_data_type << "((b_quantized_values[0] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[1] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[2] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[3] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[4] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[5] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[6] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[7] - zero_point" << c << ") * scale" << c << ");\n"; + } else { + process_one_word << quantized_data_type << "((b_quantized_values[0] - zero_point) * scale" << c << ", " + << "(b_quantized_values[1] - zero_point) * scale" << c << "," + << "(b_quantized_values[2] - zero_point) * scale" << c << "," + << "(b_quantized_values[3] - zero_point) * scale" << c << "," + << "(b_quantized_values[4] - zero_point) * scale" << c << "," + << "(b_quantized_values[5] - zero_point) * scale" << c << "," + << "(b_quantized_values[6] - zero_point) * scale" << c << "," + << "(b_quantized_values[7] - zero_point) * scale" << c << ");\n"; + } + } else { + process_one_word << "(b_quantized_values - " << quantized_data_type << "("; + for (int i = 0; i < 8; i++) { + if (has_zero_points_) { + process_one_word << "zero_point" << c; + } else { + process_one_word << "zero_point"; + } + if (i < 7) { + process_one_word << ", "; + } + } + process_one_word << ")) * scale" << c << ";\n"; + } + + process_one_word << " workgroup_shared[local_id.x * " << output_number_ << " + " << c / y.NumComponents() << "]"; + if (y.NumComponents() > 1) { + process_one_word << "[" << c % y.NumComponents() << "]"; + } + 
process_one_word << " += "; + if (a.NumComponents() == 1) { + process_one_word << "a_data[0] * b_dequantized_values[0] + " + << "a_data[1] * b_dequantized_values[1] + " + << "a_data[2] * b_dequantized_values[2] + " + << "a_data[3] * b_dequantized_values[3] + " + << "a_data[4] * b_dequantized_values[4] + " + << "a_data[5] * b_dequantized_values[5] + " + << "a_data[6] * b_dequantized_values[6] + " + << "a_data[7] * b_dequantized_values[7];\n"; + } else if (a.NumComponents() == 2) { + process_one_word << "dot(a_data[0], b_dequantized_values[0]) + " + << "dot(a_data[1], b_dequantized_values[1]) + " + << "dot(a_data[2], b_dequantized_values[2]) + " + << "dot(a_data[3], b_dequantized_values[3]);\n"; + } else if (a.NumComponents() == 4) { + process_one_word << "dot(a_data[0], b_dequantized_values[0]) + " + << "dot(a_data[1], b_dequantized_values[1]);\n"; + } + } + + const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; + std::string offset = "workgroup_idx * " + std::to_string(output_number_); + shader.AppendImplementation("var workgroup_shared : array;\n"); + shader.SetMainFunctionBody(" let output_indices = ", y.OffsetToIndices(offset), + ";\n" + " let col = output_indices[2];\n" + " let row = output_indices[1];\n" + " let batch = output_indices[0];\n" + " let n_blocks_per_col = uniforms.input_b_shape[1];\n" + " let blob_size = uniforms.input_b_shape[2]" + ";\n" + " for (var block = local_id.x; block < n_blocks_per_col; block += workgroup_size_x) {\n" + " var word_offset = block * uniforms.block_size / ", + a.NumComponents(), ";\n", + prepare_scale_and_zero_point.str(), + " for (var word: u32 = 0; word < blob_size; word += 1) {\n", + prepare_b_data.str(), + " for (var i: u32 = 0; i < ", components_b_, "; i++) {\n", + process_one_word.str(), + " word_offset += ", 8 / a.NumComponents(), + ";\n" + " }\n" + " }\n" + " }\n" + " workgroupBarrier();\n" + " if (local_id.x < ", + output_number_, + ") {\n" + " var output_value = output_value_t(0);\n" + " var workgroup_shared_offset = local_id.x;\n" + " let blocks_num = min(", + shared_memory_size, + ", n_blocks_per_col);\n" + " for (var b = 0u; b < blocks_num; b++) {\n" + " output_value += workgroup_shared[workgroup_shared_offset];\n" + " workgroup_shared_offset += ", + output_number_, + ";\n" + " }\n", + " ", + y.SetByIndices("output_indices_t(batch, row, col + local_id.x)", "output_value"), + "\n" + " }\n"); + + return Status::OK(); +} + +Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* a = context.Input(0); + const Tensor* b = context.Input(1); + const Tensor* scales = context.Input(2); + const Tensor* zero_points = context.Input(3); + const Tensor* g_idx = context.Input(4); + const Tensor* bias = context.Input(5); + + ORT_ENFORCE(g_idx == nullptr, "group_idx as input is not supported yet."); + ORT_ENFORCE(bias == nullptr, "bias as input is not supported yet."); + + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + auto* y = context.Output(0, helper.OutputShape()); + const uint32_t data_size = SafeInt(y->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + const uint32_t batch_count = SafeInt(helper.OutputOffsets().size()); + const uint32_t M = SafeInt(helper.M()); + const uint32_t N = SafeInt(helper.N()); + const uint32_t K = SafeInt(helper.K()); + const uint32_t block_size = SafeInt(block_size_); + const uint32_t nbits = 4; + + const uint32_t n_blocks_per_col = (K + 
block_size - 1) / block_size; + const uint32_t blob_size = (block_size / 8) * nbits; + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_a = GetMaxComponents(K); + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + const uint32_t components = GetMaxComponents(N); + // TODO: Support output_number > 1. Some cases are failed when output_number > 1. + // const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1; + const uint32_t output_number = 1; + + TensorShape reshaped_a_shape{batch_count, M, K / components_a}; + TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; + TensorShape reshaped_y_shape{batch_count, M, N / components}; + + const bool has_zero_points = zero_points != nullptr; + MatMulNBitsProgram program{output_number, SafeInt(components_b), has_zero_points}; + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, SafeInt(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, SafeInt(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, + {scales, ProgramTensorMetadataDependency::None}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, SafeInt(components)}) + .SetDispatchGroupSize(data_size / components / output_number) + .AddUniformVariable({block_size}) + .CacheHint(std::to_string(output_number)); + if (has_zero_points) { + program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + } + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..7fec1423faf04 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
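+
+// WebGPU implementation of the com.microsoft MatMulNBits contrib op. It multiplies a float activation
+// matrix by a block-wise 4-bit-quantized weight matrix, dequantizing the weights in the shader with
+// per-block scales and optional zero points.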
+ +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class MatMulNBitsProgram final : public Program { + public: + MatMulNBitsProgram(uint32_t output_number, int components_b, bool has_zero_points) : Program{"MatMulNBits"}, + output_number_{output_number}, + components_b_{components_b}, + has_zero_points_{has_zero_points} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t output_number_; + int components_b_; + bool has_zero_points_; +}; + +class MatMulNBits final : public WebGpuKernel { + public: + MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) { + K_ = info.GetAttr("K"); + N_ = info.GetAttr("N"); + block_size_ = info.GetAttr("block_size"); + int64_t bits = info.GetAttr("bits"); + ORT_ENFORCE(bits == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + } + + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 01c8a28d45069..b5d7a90b9bbfd 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -44,7 +44,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 75c3c9ee96081..25c0a4278be89 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -84,21 +84,24 @@ std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency dep) #ifndef NDEBUG constexpr std::string_view ProgramVariableDataTypeName[] = { - "f32", // f32 - "f32x2", // vec2f32 - "f32x4", // vec4f32 - "f16", // f16 - "f16x2", // vec2f16 - "f16x4", // vec4f16 - "i32", // i32 - "i32x2", // vec2i32 - "i32x4", // vec4i32 - "u32", // u32 - "u32x2", // vec2u32 - "u32x4", // vec4u32 - "i64", // int64 - "u64", // uint64 - "boolx4", // vec4bool + "f32", // Float32 + "f32x2", // Float32x2 + "f32x4", // Float32x4 + "f16", // Float16 + "f16x2", // Float16x2 + "f16x4", // Float16x4 + "i32", // Int32 + "i32x2", // Int32x2 + "i32x4", // Int32x4 + "u32", // Uint32 + "u32x2", // Uint32x2 + "u32x4", // Uint32x4 + "i64", // Int64 + "u64", // Uint64 + "boolx4", // Boolx4 + "u8x4", // Uint8x4 + "u8x8", // Uint8x8 + "u8x16", // Uint8x16 }; std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) { os << ProgramVariableDataTypeName[std::underlying_type::type(type)]; @@ -115,17 +118,22 @@ int NumberOfComponents(ProgramVariableDataType type) { case ProgramVariableDataType::Uint64: case ProgramVariableDataType::Float16: return 1; - case ProgramVariableDataType::Vec2Float32: - case 
ProgramVariableDataType::Vec2Int32: - case ProgramVariableDataType::Vec2Uint32: - case ProgramVariableDataType::Vec2Float16: + case ProgramVariableDataType::Float32x2: + case ProgramVariableDataType::Int32x2: + case ProgramVariableDataType::Uint32x2: + case ProgramVariableDataType::Float16x2: return 2; - case ProgramVariableDataType::Vec4Float32: - case ProgramVariableDataType::Vec4Int32: - case ProgramVariableDataType::Vec4Uint32: - case ProgramVariableDataType::Vec4Float16: - case ProgramVariableDataType::Vec4Bool: + case ProgramVariableDataType::Float32x4: + case ProgramVariableDataType::Int32x4: + case ProgramVariableDataType::Uint32x4: + case ProgramVariableDataType::Float16x4: + case ProgramVariableDataType::Boolx4: + case ProgramVariableDataType::Uint8x4: return 4; + case ProgramVariableDataType::Uint8x8: + return 8; + case ProgramVariableDataType::Uint8x16: + return 16; default: return -1; } @@ -152,28 +160,44 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp } else if (component == 2) { switch (element_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - return ProgramVariableDataType::Vec2Float32; + return ProgramVariableDataType::Float32x2; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - return ProgramVariableDataType::Vec2Float16; + return ProgramVariableDataType::Float16x2; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - return ProgramVariableDataType::Vec2Int32; + return ProgramVariableDataType::Int32x2; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - return ProgramVariableDataType::Vec2Uint32; + return ProgramVariableDataType::Uint32x2; default: return ProgramVariableDataType::InvalidType; } } else if (component == 4) { switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x4; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - return ProgramVariableDataType::Vec4Float32; + return ProgramVariableDataType::Float32x4; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: - return ProgramVariableDataType::Vec4Float16; + return ProgramVariableDataType::Float16x4; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: - return ProgramVariableDataType::Vec4Int32; + return ProgramVariableDataType::Int32x4; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: - return ProgramVariableDataType::Vec4Uint32; + return ProgramVariableDataType::Uint32x4; case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - return ProgramVariableDataType::Vec4Bool; + return ProgramVariableDataType::Boolx4; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 8) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x8; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 16) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ProgramVariableDataType::Uint8x16; default: return ProgramVariableDataType::InvalidType; } diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index f05ca9c2bf224..bd9a26b0fcfbd 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -172,20 +172,23 @@ constexpr SafeInt WORKGROUP_SIZE = 64; enum class ProgramVariableDataType { InvalidType = -1, Float32, - Vec2Float32, - Vec4Float32, + Float32x2, + Float32x4, Float16, - Vec2Float16, - Vec4Float16, + Float16x2, + Float16x4, Int32, - Vec2Int32, - Vec4Int32, + Int32x2, + Int32x4, Uint32, - Vec2Uint32, - Vec4Uint32, + Uint32x2, + Uint32x4, Int64, Uint64, - 
Vec4Bool, + Boolx4, + Uint8x4, + Uint8x8, + Uint8x16 }; #ifndef NDEBUG std::ostream& operator<<(std::ostream& os, ProgramVariableDataType); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index c229e821cbf8c..a88687fce18b5 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -114,27 +114,27 @@ Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType va switch (element_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float32 || - var_type == ProgramVariableDataType::Vec2Float32 || - var_type == ProgramVariableDataType::Vec4Float32, + var_type == ProgramVariableDataType::Float32x2 || + var_type == ProgramVariableDataType::Float32x4, "Unexpected program variable type ", int(var_type), " for float32 tensor"); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float16 || - var_type == ProgramVariableDataType::Vec2Float16 || - var_type == ProgramVariableDataType::Vec4Float16, + var_type == ProgramVariableDataType::Float16x2 || + var_type == ProgramVariableDataType::Float16x4, "Unexpected program variable type ", int(var_type), " for float16 tensor"); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || - var_type == ProgramVariableDataType::Vec2Int32 || - var_type == ProgramVariableDataType::Vec4Int32, + var_type == ProgramVariableDataType::Int32x2 || + var_type == ProgramVariableDataType::Int32x4, "Unexpected program variable type ", int(var_type), " for int32 tensor"); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint32 || - var_type == ProgramVariableDataType::Vec2Uint32 || - var_type == ProgramVariableDataType::Vec4Uint32, + var_type == ProgramVariableDataType::Uint32x2 || + var_type == ProgramVariableDataType::Uint32x4, "Unexpected program variable type ", int(var_type), " for uint32 tensor"); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: @@ -146,9 +146,15 @@ Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType va "Unexpected program variable type ", int(var_type), " for uint64 tensor"); break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Vec4Bool, + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Boolx4, "Unexpected program variable type ", int(var_type), " for bool tensor"); break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint8x4 || + var_type == ProgramVariableDataType::Uint8x8 || + var_type == ProgramVariableDataType::Uint8x16, + "Unexpected program variable type ", int(var_type), " for uint8 tensor"); + break; default: ORT_RETURN_IF(true, "Unsupported data type: ", element_type); // todo: add int4/uint4 diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index f2a5b049b4777..cbc39c86e504d 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -15,57 +15,66 @@ namespace webgpu { namespace { constexpr static const std::string_view STORAGE_TYPE[] = { - "f32", // f32 - "vec2", // vec2f32 - "vec4", // vec4f32 - "f16", // f16 - "vec2", // vec2f16 - "vec4", // vec4f16 - "i32", // i32 - "vec2", // vec2i32 - "vec4", // 
vec4i32 - "u32", // u32 - "vec2", // vec2u32 - "vec4", // vec4u32 - "vec2", // int64 - "vec2", // uint64 - "u32", // vec4bool + "f32", // Float32 + "vec2", // Float32x2 + "vec4", // Float32x4 + "f16", // Float16 + "vec2", // Float16x2 + "vec4", // Float16x4 + "i32", // Int32 + "vec2", // Int32x2 + "vec4", // Int32x4 + "u32", // Uint32 + "vec2", // Uint32x2 + "vec4", // Uint32x4 + "vec2", // Int64 + "vec2", // Uint64 + "u32", // Boolx4 + "u32", // Uint8x4 + "vec2", // Uint8x8 + "vec4", // Uint8x16 }; constexpr static const std::string_view VALUE_TYPE[] = { - "f32", // f32 - "vec2", // vec2f32 - "vec4", // vec4f32 - "f16", // f16 - "vec2", // vec2f16 - "vec4", // vec4f16 - "i32", // i32 - "vec2", // vec2i32 - "vec4", // vec4i32 - "u32", // u32 - "vec2", // vec2u32 - "vec4", // vec4u32 - "i32", // int64 (trancated to i32) - "u32", // uint64 (trancated to u32) - "vec4", // vec4bool + "f32", // Float32 + "vec2", // Float32x2 + "vec4", // Float32x4 + "f16", // Float16 + "vec2", // Float16x2 + "vec4", // Float16x4 + "i32", // Int32 + "vec2", // Int32x2 + "vec4", // Int32x4 + "u32", // Uint32 + "vec2", // Uint32x2 + "vec4", // Uint32x4 + "i32", // Int64 (trancated to i32) + "u32", // Uint64 (trancated to u32) + "vec4", // Boolx4 + "u32", // Uint8x4 (u32 as 4 elements of uint8) + "vec2", // Uint8x8 (vec2 as 2x4 elements of uint8) + "vec4", // Uint8x16 (vec4 as 4x4 elements of uint8) }; constexpr static const std::string_view ELEMENT_TYPE[] = { - "f32", // f32 - "f32", // vec2f32 - "f32", // vec4f32 - "f16", // f16 - "f16", // vec2f16 - "f16", // vec4f16 - "i32", // i32 - "i32", // vec2i32 - "i32", // vec4i32 - "u32", // u32 - "u32", // vec2u32 - "u32", // vec4u32 - "i32", // int64 - "u32", // uint64 - "bool", // vec4bool + "f32", // Float32 + "f32", // Float32x2 + "f32", // Float32x4 + "f16", // Float16 + "f16", // Float16x2 + "f16", // Float16x4 + "i32", // Int32 + "i32", // Int32x2 + "i32", // Int32x4 + "u32", // Uint32 + "u32", // Uint32x2 + "u32", // Uint32x4 + "i32", // Int64 + "u32", // Uint64 + "bool", // Boolx4 + "u32", // Uint8x4 + "u32", // Uint8x8 + "u32", // Uint8x16 }; inline std::string GetIndicesType(int rank) { @@ -263,7 +272,7 @@ std::string ShaderVariableHelper::GetByOffsetImpl(std::string_view offset) const case onnxruntime::webgpu::ProgramVariableDataType::Uint64: ss << ElementType() << "(" << name_ << "[" << offset << "].x)"; break; - case onnxruntime::webgpu::ProgramVariableDataType::Vec4Bool: + case onnxruntime::webgpu::ProgramVariableDataType::Boolx4: ss << "vec4(bool(" << name_ << "[" << offset << "] & 0xFFu), bool(" << name_ << "[" << offset << "] & 0xFF00u), bool(" @@ -291,7 +300,7 @@ std::string ShaderVariableHelper::SetByOffsetImpl(std::string_view offset, std:: case onnxruntime::webgpu::ProgramVariableDataType::Uint64: ss << name_ << "[" << offset << "]=vec2(u32(" << value << "), 0u);"; break; - case onnxruntime::webgpu::ProgramVariableDataType::Vec4Bool: + case onnxruntime::webgpu::ProgramVariableDataType::Boolx4: ss << name_ << "[" << offset << "]=dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(" << value << "));"; break; default: diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index fa7c6bce7c23e..669beb055309b 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -278,7 +278,11 @@ void TestMatMulNBitsTyped() { base_opts.output_abs_error = 0.1f; } else { if constexpr (std::is_same::value) { +#ifdef USE_WEBGPU + 
base_opts.output_abs_error = 0.03f; +#else base_opts.output_abs_error = 0.01f; +#endif } } @@ -293,7 +297,7 @@ void TestMatMulNBitsTyped() { RunTest(opts); } -#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) +#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) && !defined(USE_WEBGPU) { TestOptions opts = base_opts; opts.has_g_idx = true; @@ -324,7 +328,7 @@ void TestMatMulNBitsTyped() { opts.has_zero_point = true, opts.zp_is_4bit = false; RunTest(opts); } -#endif // !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) +#endif // !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) && !defined(USE_WEBGPU) { TestOptions opts = base_opts; @@ -358,7 +362,7 @@ TEST(MatMulNBits, Float16) { #endif #endif -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) namespace { // Legacy test function. @@ -393,6 +397,9 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura #ifdef USE_DML execution_providers.push_back(DefaultDmlExecutionProvider()); #endif +#ifdef USE_WEBGPU + execution_providers.push_back(DefaultWebGpuExecutionProvider()); +#endif RunTest(opts, std::move(execution_providers)); } else { @@ -437,6 +444,9 @@ TEST(MatMulNBits, Float16Large) { // absolute error of 0.08, but the A10 has errors going as high as 0.22. Ultimately, given the large number // of elements in this test, ULPs should probably be used instead of absolute/relative tolerances. float abs_error = 0.3f; +#elif USE_WEBGPU + // See Intel A770 to pass these tests with an absolute error of 0.08. + float abs_error = 0.08f; #else float abs_error = 0.05f; #endif