Skip to content

Commit

Permalink
[webgpu-native] Add MatmulNBits (microsoft#22150)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Yulong Wang <[email protected]>
  • Loading branch information
qjia7 and fs-eire authored Sep 27, 2024
1 parent 92a08e2 commit 8da1f7a
Show file tree
Hide file tree
Showing 8 changed files with 501 additions and 102 deletions.
294 changes: 294 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

Large diffs are not rendered by default.

53 changes: 53 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
public:
MatMulNBitsProgram(uint32_t output_number, int components_b, bool has_zero_points) : Program{"MatMulNBits"},
output_number_{output_number},
components_b_{components_b},
has_zero_points_{has_zero_points} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32});

private:
uint32_t output_number_;
int components_b_;
bool has_zero_points_;
};

class MatMulNBits final : public WebGpuKernel {
public:
MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) {
K_ = info.GetAttr<int64_t>("K");
N_ = info.GetAttr<int64_t>("N");
block_size_ = info.GetAttr<int64_t>("block_size");
int64_t bits = info.GetAttr<int64_t>("bits");
ORT_ENFORCE(bits == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
}

Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;

private:
int64_t K_;
int64_t N_;
int64_t block_size_;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
// // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
Expand Down
90 changes: 57 additions & 33 deletions onnxruntime/core/providers/webgpu/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,24 @@ std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency dep)

#ifndef NDEBUG
constexpr std::string_view ProgramVariableDataTypeName[] = {
"f32", // f32
"f32x2", // vec2f32
"f32x4", // vec4f32
"f16", // f16
"f16x2", // vec2f16
"f16x4", // vec4f16
"i32", // i32
"i32x2", // vec2i32
"i32x4", // vec4i32
"u32", // u32
"u32x2", // vec2u32
"u32x4", // vec4u32
"i64", // int64
"u64", // uint64
"boolx4", // vec4bool
"f32", // Float32
"f32x2", // Float32x2
"f32x4", // Float32x4
"f16", // Float16
"f16x2", // Float16x2
"f16x4", // Float16x4
"i32", // Int32
"i32x2", // Int32x2
"i32x4", // Int32x4
"u32", // Uint32
"u32x2", // Uint32x2
"u32x4", // Uint32x4
"i64", // Int64
"u64", // Uint64
"boolx4", // Boolx4
"u8x4", // Uint8x4
"u8x8", // Uint8x8
"u8x16", // Uint8x16
};
std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) {
os << ProgramVariableDataTypeName[std::underlying_type<decltype(type)>::type(type)];
Expand All @@ -115,17 +118,22 @@ int NumberOfComponents(ProgramVariableDataType type) {
case ProgramVariableDataType::Uint64:
case ProgramVariableDataType::Float16:
return 1;
case ProgramVariableDataType::Vec2Float32:
case ProgramVariableDataType::Vec2Int32:
case ProgramVariableDataType::Vec2Uint32:
case ProgramVariableDataType::Vec2Float16:
case ProgramVariableDataType::Float32x2:
case ProgramVariableDataType::Int32x2:
case ProgramVariableDataType::Uint32x2:
case ProgramVariableDataType::Float16x2:
return 2;
case ProgramVariableDataType::Vec4Float32:
case ProgramVariableDataType::Vec4Int32:
case ProgramVariableDataType::Vec4Uint32:
case ProgramVariableDataType::Vec4Float16:
case ProgramVariableDataType::Vec4Bool:
case ProgramVariableDataType::Float32x4:
case ProgramVariableDataType::Int32x4:
case ProgramVariableDataType::Uint32x4:
case ProgramVariableDataType::Float16x4:
case ProgramVariableDataType::Boolx4:
case ProgramVariableDataType::Uint8x4:
return 4;
case ProgramVariableDataType::Uint8x8:
return 8;
case ProgramVariableDataType::Uint8x16:
return 16;
default:
return -1;
}
Expand All @@ -152,28 +160,44 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp
} else if (component == 2) {
switch (element_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return ProgramVariableDataType::Vec2Float32;
return ProgramVariableDataType::Float32x2;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
return ProgramVariableDataType::Vec2Float16;
return ProgramVariableDataType::Float16x2;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
return ProgramVariableDataType::Vec2Int32;
return ProgramVariableDataType::Int32x2;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
return ProgramVariableDataType::Vec2Uint32;
return ProgramVariableDataType::Uint32x2;
default:
return ProgramVariableDataType::InvalidType;
}
} else if (component == 4) {
switch (element_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return ProgramVariableDataType::Uint8x4;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return ProgramVariableDataType::Vec4Float32;
return ProgramVariableDataType::Float32x4;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
return ProgramVariableDataType::Vec4Float16;
return ProgramVariableDataType::Float16x4;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
return ProgramVariableDataType::Vec4Int32;
return ProgramVariableDataType::Int32x4;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
return ProgramVariableDataType::Vec4Uint32;
return ProgramVariableDataType::Uint32x4;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
return ProgramVariableDataType::Vec4Bool;
return ProgramVariableDataType::Boolx4;
default:
return ProgramVariableDataType::InvalidType;
}
} else if (component == 8) {
switch (element_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return ProgramVariableDataType::Uint8x8;
default:
return ProgramVariableDataType::InvalidType;
}
} else if (component == 16) {
switch (element_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return ProgramVariableDataType::Uint8x16;
default:
return ProgramVariableDataType::InvalidType;
}
Expand Down
21 changes: 12 additions & 9 deletions onnxruntime/core/providers/webgpu/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,20 +172,23 @@ constexpr SafeInt<uint32_t> WORKGROUP_SIZE = 64;
enum class ProgramVariableDataType {
InvalidType = -1,
Float32,
Vec2Float32,
Vec4Float32,
Float32x2,
Float32x4,
Float16,
Vec2Float16,
Vec4Float16,
Float16x2,
Float16x4,
Int32,
Vec2Int32,
Vec4Int32,
Int32x2,
Int32x4,
Uint32,
Vec2Uint32,
Vec4Uint32,
Uint32x2,
Uint32x4,
Int64,
Uint64,
Vec4Bool,
Boolx4,
Uint8x4,
Uint8x8,
Uint8x16
};
#ifndef NDEBUG
std::ostream& operator<<(std::ostream& os, ProgramVariableDataType);
Expand Down
24 changes: 15 additions & 9 deletions onnxruntime/core/providers/webgpu/shader_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,27 +114,27 @@ Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType va
switch (element_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float32 ||
var_type == ProgramVariableDataType::Vec2Float32 ||
var_type == ProgramVariableDataType::Vec4Float32,
var_type == ProgramVariableDataType::Float32x2 ||
var_type == ProgramVariableDataType::Float32x4,
"Unexpected program variable type ", int(var_type), " for float32 tensor");
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float16 ||
var_type == ProgramVariableDataType::Vec2Float16 ||
var_type == ProgramVariableDataType::Vec4Float16,
var_type == ProgramVariableDataType::Float16x2 ||
var_type == ProgramVariableDataType::Float16x4,
"Unexpected program variable type ", int(var_type), " for float16 tensor");

break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 ||
var_type == ProgramVariableDataType::Vec2Int32 ||
var_type == ProgramVariableDataType::Vec4Int32,
var_type == ProgramVariableDataType::Int32x2 ||
var_type == ProgramVariableDataType::Int32x4,
"Unexpected program variable type ", int(var_type), " for int32 tensor");
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint32 ||
var_type == ProgramVariableDataType::Vec2Uint32 ||
var_type == ProgramVariableDataType::Vec4Uint32,
var_type == ProgramVariableDataType::Uint32x2 ||
var_type == ProgramVariableDataType::Uint32x4,
"Unexpected program variable type ", int(var_type), " for uint32 tensor");
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
Expand All @@ -146,9 +146,15 @@ Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType va
"Unexpected program variable type ", int(var_type), " for uint64 tensor");
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Vec4Bool,
ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Boolx4,
"Unexpected program variable type ", int(var_type), " for bool tensor");
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint8x4 ||
var_type == ProgramVariableDataType::Uint8x8 ||
var_type == ProgramVariableDataType::Uint8x16,
"Unexpected program variable type ", int(var_type), " for uint8 tensor");
break;
default:
ORT_RETURN_IF(true, "Unsupported data type: ", element_type);
// todo: add int4/uint4
Expand Down
Loading

0 comments on commit 8da1f7a

Please sign in to comment.