refactor uint8 components
qjia7 committed Sep 20, 2024
1 parent e8a07cc commit 46ccca4
Showing 6 changed files with 58 additions and 21 deletions.
25 changes: 14 additions & 11 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -12,7 +12,8 @@ namespace contrib {
namespace webgpu {

namespace {
uint32_t getMaxComponents(uint32_t size) {
// TODO: move this to a common place?
uint32_t GetMaxComponents(uint32_t size) {
// we cannot use vec3 type since it has alignment of 16 bytes
if (size % 4 == 0) {
return 4;
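
The rest of this helper is folded in this view. A minimal sketch of what the complete function presumably looks like (the 2-component and scalar fallbacks are assumptions inferred from the visible check):

uint32_t GetMaxComponents(uint32_t size) {
  // vec3 is skipped because it has 16-byte alignment in WGSL.
  if (size % 4 == 0) {
    return 4;
  } else if (size % 2 == 0) {
    return 2;
  }
  return 1;
}
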
@@ -116,7 +117,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "}\n";
for (int c = 0; c < output_element_number; c++) {
process_one_word << "b_value = " << "b" << c << "_data";
if (b.NumComponents() > 1) {
if (components_b_ > 1) {
process_one_word << "[i]";
}
process_one_word << ";\n"
@@ -186,7 +187,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {

const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE;
std::string offset = "workgroup_idx * " + std::to_string(output_number_);
shader.AppendImplementation("var<workgroup> workgroup_shared : array<output_value_t," + std::to_string(shared_memory_size) + ">;\n const workgroup_size = " + std::to_string(static_cast<uint32_t>(WORKGROUP_SIZE)) + "u;\n");
shader.AppendImplementation("var<workgroup> workgroup_shared : array<output_value_t," + std::to_string(shared_memory_size) + ">;\n");
shader.SetMainFunctionBody("let output_indices = ", y.OffsetToIndices(offset),
";\n"
"let col = output_indices[2];\n"
@@ -195,13 +196,13 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
"let n_blocks_per_col = uniforms.input_b_shape[1];\n"
"let blob_size = uniforms.input_b_shape[2]"
";\n"
"for (var block = local_id.x; block < n_blocks_per_col; block += workgroup_size) {\n"
"for (var block = local_id.x; block < n_blocks_per_col; block += workgroup_size_x) {\n"
" var word_offset = block * uniforms.block_size / ",
a.NumComponents(), ";\n",
prepare_scale_and_zero_point.str(),
" for (var word: u32 = 0; word < blob_size; word += 1) {\n",
prepare_b_data.str(),
" for (var i: u32 = 0; i < ", b.NumComponents(), "; i++) {\n",
" for (var i: u32 = 0; i < ", components_b_, "; i++) {\n",
process_one_word.str(),
" word_offset += ", 8 / a.NumComponents(),
";\n"
@@ -260,9 +261,9 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
const uint32_t n_blocks_per_col = (K + block_size - 1) / block_size;
const uint32_t blob_size = (block_size / 8) * nbits;
const uint32_t blob_size_in_words = blob_size / 4;
const uint32_t components_a = getMaxComponents(K);
const uint32_t components_b = getMaxComponents(blob_size_in_words);
const uint32_t components = getMaxComponents(N);
const uint32_t components_a = GetMaxComponents(K);
const uint32_t components_b = GetMaxComponents(blob_size_in_words);
const uint32_t components = GetMaxComponents(N);
// TODO: Support output_number > 1. Some cases fail when output_number > 1.
// const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1;
const uint32_t output_number = 1;
@@ -272,15 +273,17 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
TensorShape reshaped_y_shape{batch_count, M, N / components};

const bool has_zero_points = zero_points != nullptr;
MatMulNBitsProgram program{output_number, has_zero_points};
MatMulNBitsProgram program{output_number, SafeInt<int>(components_b), has_zero_points};
program
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, SafeInt<int>(components_a)}, {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, SafeInt<int>(components_b)}, {scales, ProgramTensorMetadataDependency::None}})
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, SafeInt<int>(components_a)},
{b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, SafeInt<int>(components_b * 4 /** b is accessed as uint32, and each uint32 packs 4 uint8 values, so multiply by 4. */)},
{scales, ProgramTensorMetadataDependency::None}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, SafeInt<int>(components)})
.SetDispatchGroupSize(data_size / components / output_number)
.AddUniformVariable({block_size})
.CacheHint(std::to_string(output_number));
if (has_zero_points) {
program.AddInput({zero_points, ProgramTensorMetadataDependency::None});
program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
return context.RunProgram(program);
}
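
To make the sizing arithmetic above concrete, here is a hedged worked example; the nbits and block_size values are illustrative, not taken from the diff:

const uint32_t nbits = 4;
const uint32_t block_size = 32;
const uint32_t blob_size = (block_size / 8) * nbits;                  // 16 bytes per block
const uint32_t blob_size_in_words = blob_size / 4;                    // 4 u32 words per block
const uint32_t components_b = GetMaxComponents(blob_size_in_words);   // 4, so B is bound as vec4<u32>
// components_b * 4 == 16: each vec4<u32> element of B covers 16 uint8 values,
// which is why the uint8 tensor metadata above is created with components_b * 4.
// Each u32 word still packs eight 4-bit weights, so the generated shader advances
// word_offset by 8 / components_a after processing one word.
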
6 changes: 5 additions & 1 deletion onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
@@ -14,14 +14,18 @@ using namespace onnxruntime::webgpu;

class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
public:
MatMulNBitsProgram(uint32_t output_number, bool has_zero_points) : Program{"MatMulNBits"}, output_number_{output_number}, has_zero_points_{has_zero_points} {
MatMulNBitsProgram(const uint32_t output_number, const int components_b, const bool has_zero_points) : Program{"MatMulNBits"},
output_number_{output_number},
components_b_{components_b},
has_zero_points_{has_zero_points} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32});

private:
uint32_t output_number_;
int components_b_;
bool has_zero_points_;
};
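
For clarity, the constructor now also takes the component count of the quantized B tensor. A hedged example call with illustrative values (the real values come from ComputeInternal above):

// components_b would be GetMaxComponents(blob_size_in_words) at the call site.
MatMulNBitsProgram program{/*output_number=*/1, /*components_b=*/4, /*has_zero_points=*/true};
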

28 changes: 23 additions & 5 deletions onnxruntime/core/providers/webgpu/program.cc
@@ -99,6 +99,9 @@ constexpr std::string_view ProgramVariableDataTypeName[] = {
"i64", // int64
"u64", // uint64
"boolx4", // vec4bool
"u32", // uint8
"u32x2", // vec2uint8
"u32x4", // vec4uint8
};
std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) {
os << ProgramVariableDataTypeName[std::underlying_type<decltype(type)>::type(type)];
@@ -125,7 +128,12 @@ int NumberOfComponents(ProgramVariableDataType type) {
case ProgramVariableDataType::Vec4Uint32:
case ProgramVariableDataType::Vec4Float16:
case ProgramVariableDataType::Vec4Bool:
case ProgramVariableDataType::Uint8:
return 4;
case ProgramVariableDataType::Vec2Uint8:
return 8;
case ProgramVariableDataType::Vec4Uint8:
return 16;
default:
return -1;
}
@@ -134,8 +142,6 @@
ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component /* = 1 */) {
if (component == 1) {
switch (element_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return ProgramVariableDataType::Uint32;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return ProgramVariableDataType::Float32;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
@@ -153,8 +159,6 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp
}
} else if (component == 2) {
switch (element_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return ProgramVariableDataType::Vec2Uint32;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return ProgramVariableDataType::Vec2Float32;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
@@ -169,7 +173,7 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp
} else if (component == 4) {
switch (element_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return ProgramVariableDataType::Vec4Uint32;
return ProgramVariableDataType::Uint8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return ProgramVariableDataType::Vec4Float32;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
@@ -183,6 +187,20 @@ ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int comp
default:
return ProgramVariableDataType::InvalidType;
}
} else if (component == 8) {
switch (element_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return ProgramVariableDataType::Vec2Uint8;
default:
return ProgramVariableDataType::InvalidType;
}
} else if (component == 16) {
switch (element_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return ProgramVariableDataType::Vec4Uint8;
default:
return ProgramVariableDataType::InvalidType;
}
} else {
return ProgramVariableDataType::InvalidType;
}
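
Put differently, a uint8 tensor's component count is now expressed in uint8 elements rather than in u32 lanes. A hedged sanity-check sketch of the new mapping (assuming the folded branches are unchanged):

assert(ToProgramVariableDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 4) == ProgramVariableDataType::Uint8);       // stored as u32,       4 packed uint8
assert(ToProgramVariableDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 8) == ProgramVariableDataType::Vec2Uint8);   // stored as vec2<u32>, 8 packed uint8
assert(ToProgramVariableDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 16) == ProgramVariableDataType::Vec4Uint8);  // stored as vec4<u32>, 16 packed uint8
assert(ToProgramVariableDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, 2) == ProgramVariableDataType::InvalidType); // uint8 no longer maps at 1 or 2 components
assert(NumberOfComponents(ProgramVariableDataType::Vec4Uint8) == 16);                                              // component counts are in uint8 elements
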
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/webgpu/program.h
@@ -186,6 +186,9 @@ enum class ProgramVariableDataType {
Int64,
Uint64,
Vec4Bool,
Uint8,
Vec2Uint8,
Vec4Uint8
};
#ifndef NDEBUG
std::ostream& operator<<(std::ostream& os, ProgramVariableDataType);
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/webgpu/shader_helper.cc
@@ -150,9 +150,9 @@ Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType va
"Unexpected program variable type ", int(var_type), " for bool tensor");
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint32 ||
var_type == ProgramVariableDataType::Vec2Uint32 ||
var_type == ProgramVariableDataType::Vec4Uint32,
ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint8 ||
var_type == ProgramVariableDataType::Vec2Uint8 ||
var_type == ProgramVariableDataType::Vec4Uint8,
"Unexpected program variable type ", int(var_type), " for UINT8 tensor");
break;
default:
@@ -169,7 +173,7 @@ Status ValidateVariableShape(const TensorShape& origin_shape,
int num_components) {
if (use_override_shape) {
// if override shape specified, assert override_size == ceil( origin_size / 4 )
ORT_RETURN_IF_NOT((origin_shape.Size() + num_components - 1) / num_components == override_shape.Size() || ((origin_shape.Size() + num_components - 1) / num_components / 4 == override_shape.Size()),
ORT_RETURN_IF_NOT((origin_shape.Size() + num_components - 1) / num_components == override_shape.Size(),
"Tensor original shape ", origin_shape, " cannot reshape to ", override_shape, " with component number ", num_components);
}
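
The relaxed check lines up with the zero_points input added in matmul_nbits.cc. A hedged illustration of the arithmetic, using a made-up element count:

// Suppose zero_points holds 70 uint8 values. ComputeInternal binds it with
// override shape {(70 + 3) / 4} = {18} and component count 4 (ProgramVariableDataType::Uint8).
const int64_t origin_size = 70;
const int64_t override_size = (origin_size + 3) / 4;                   // 18
const int num_components = 4;                                          // NumberOfComponents(Uint8)
const bool ok = (origin_size + num_components - 1) / num_components ==
                override_size;                                         // true, no extra "/ 4" special case needed
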

9 changes: 9 additions & 0 deletions onnxruntime/core/providers/webgpu/shader_variable.cc
@@ -30,6 +30,9 @@ constexpr static const std::string_view STORAGE_TYPE[] = {
"vec2<u32>", // int64
"vec2<u32>", // uint64
"u32", // vec4bool
"u32", // uint8
"vec2<u32>", // vec2uint8
"vec4<u32>", // vec4uint8
};

constexpr static const std::string_view VALUE_TYPE[] = {
@@ -48,6 +51,9 @@ constexpr static const std::string_view VALUE_TYPE[] = {
"i32", // int64 (trancated to i32)
"u32", // uint64 (trancated to u32)
"vec4<bool>", // vec4bool
"u32", // uint8
"vec2<u32>", // vec2uint8
"vec4<u32>", // vec4uint8
};

constexpr static const std::string_view ELEMENT_TYPE[] = {
@@ -66,6 +72,9 @@ constexpr static const std::string_view ELEMENT_TYPE[] = {
"i32", // int64
"u32", // uint64
"bool", // vec4bool
"u32", // uint8
"u32", // vec2uint8
"u32", // vec4uint8
};

inline std::string GetIndicesType(int rank) {
