From 2a5c9b86ebbdba8fb76f79de26524a2fdd2e5c2a Mon Sep 17 00:00:00 2001 From: zhijiang <43435212+zhijxu-MS@users.noreply.github.com> Date: Tue, 5 Mar 2024 10:11:19 +0800 Subject: [PATCH 1/7] Zhijxu/fix conv1d replacement (#19758) remove the constraint - "group number should be less than 3"; add more condition to make sure the conv1d replacement only happens on conv1d instead of conv2d/conv3d; add more tests; --- .../core/optimizer/conv1d_replacement.cc | 63 +++++++++++------- .../test/optimizer/graph_transform_test.cc | 64 ++++++++++++++++--- 2 files changed, 96 insertions(+), 31 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc index 0412000e04e1b..ff220fcb067b8 100644 --- a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc @@ -42,30 +42,45 @@ */ namespace onnxruntime { bool NodeCanBeReplacedByMatmul(const Node& node) { - // If node type is Conv, and attr "dilations" is 1, "kernel_shape" is 1, "stride" is 1, group is 1 or 2, - // then it can be replaced by MatMul - // Kernel_shape is 1 means it is conv1d + /* + If node type is Conv, and satisfy the following conditions then it can be replaced by MatMul: + - not bias as input which means only has 2 inputs: input and weight + - "dilations" should be [1] + size 1 means conv1d + - "strides" should be [1] + - "pads" should be [0,0] + - "autopad" should be "NOTSET" + - "kernel_shape" should be [1] + */ if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11})) { return false; } - const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations"); - const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape"); - const auto* stride = graph_utils::GetNodeAttribute(node, "strides"); - const auto* group = graph_utils::GetNodeAttribute(node, "group"); - if (dilations == nullptr || kernel_shape == nullptr || stride == nullptr || group == nullptr) { + + // TODO: bias input can also be supported if needed + if (node.InputDefs().size() != 2) { return false; } - if ((dilations->ints_size() && dilations->ints(0) != 1) || - (kernel_shape->ints_size() && kernel_shape->ints(0) != 1) || - (stride->ints_size() && stride->ints(0) != 1) || - group->i() >= 3) { + + const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations"); + const auto* strides = graph_utils::GetNodeAttribute(node, "strides"); + const auto* pads = graph_utils::GetNodeAttribute(node, "pads"); + const auto* autopad = graph_utils::GetNodeAttribute(node, "auto_pad"); + const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape"); + if (dilations == nullptr || strides == nullptr || pads == nullptr || autopad == nullptr || kernel_shape == nullptr) { return false; } - return true; + if ((dilations->ints_size() == 1 && dilations->ints(0) == 1) && + (strides->ints_size() == 1 && strides->ints(0) == 1) && + (autopad->s() == "NOTSET") && + (pads->ints_size() == 2 && pads->ints(0) == 0 && pads->ints(1) == 0) && + (kernel_shape->ints_size() == 1 && kernel_shape->ints(0) == 1)) { + return true; + } + return false; } -void Conv1dToMatmul(Graph& graph, Node& conv) { +void Conv1dToMatmul(Graph& graph, Node& conv, const std::string transformer_name) { // Shape of conv1d input: [batch_size, in_channels, in_length] // Shape of conv1d weight:[output_channels, input_channels/group, kernel_shape], kernel_shape is 1 // We need to split the input into "group", and squeeze&split the weight, and then do MatMul @@ -83,7 +98,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { conv1d_input_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( graph.GenerateNodeArgName("input_split_output"), nullptr)); } - auto& input_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, {conv1d_input}, + auto& input_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description, {conv1d_input}, {conv1d_input_splitted_outputs}); input_split.SetExecutionProviderType(execution_provider_type); input_split.AddAttribute("axis", int64_t(1)); @@ -93,23 +108,25 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { } // 2. Squeeze conv weight auto conv1d_weight = conv.MutableInputDefs()[1]; + // auto con1d_bias = xx; auto weight_squeeze_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("weight_squeeze_output"), nullptr); - auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName("WeightSqueeze"), "Squeeze", + auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName(transformer_name + "WeightSqueeze"), "Squeeze", node_description, {conv1d_weight}, {weight_squeeze_output}); + int64_t weight_squeeze_axis = 2; if (onnx_opset_version > 12) { // After onnx version 12, squeeze node has axes as input instead of attribute ONNX_NAMESPACE::TensorProto initializer_proto; - initializer_proto.set_name(graph.GenerateNodeName("ConstAsInitializer")); + initializer_proto.set_name(graph.GenerateNodeName(transformer_name + "ConstAsInitializer")); initializer_proto.add_dims(static_cast(1)); initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - InlinedVector initializer_proto_value{2}; + InlinedVector initializer_proto_value{weight_squeeze_axis}; initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t)); auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto); // Squeeze node doesn't have opschema here, so we need to set input args count manually weight_squeeze.MutableInputArgsCount().resize(2); graph_utils::AddNodeInput(weight_squeeze, 1, axes_input); } else { - weight_squeeze.AddAttribute("axes", std::vector{2}); + weight_squeeze.AddAttribute("axes", std::vector{weight_squeeze_axis}); } weight_squeeze.SetExecutionProviderType(execution_provider_type); // 3. Split conv weight @@ -118,7 +135,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { conv1d_weight_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( graph.GenerateNodeArgName("weight_split_output"), nullptr)); } - auto& weight_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, + auto& weight_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description, {weight_squeeze_output}, {conv1d_weight_splitted_outputs}); weight_split.AddAttribute("axis", int64_t(0)); weight_split.SetExecutionProviderType(execution_provider_type); @@ -130,13 +147,13 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { for (int i = 0; i < group_num; i++) { auto matmul_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("matmul_output"), nullptr); matmul_outputs.push_back(matmul_output); - auto& matmul = graph.AddNode(graph.GenerateNodeName("Matmul"), "MatMul", node_description, + auto& matmul = graph.AddNode(graph.GenerateNodeName(transformer_name + "Matmul"), "MatMul", node_description, {conv1d_weight_splitted_outputs[i], conv1d_input_splitted_outputs[i]}, {matmul_output}); matmul.SetExecutionProviderType(execution_provider_type); } // 5. Concat matmul outputs - auto& concat_node = graph.AddNode(graph.GenerateNodeName("Concat"), "Concat", node_description, + auto& concat_node = graph.AddNode(graph.GenerateNodeName(transformer_name + "Concat"), "Concat", node_description, matmul_outputs, {}); concat_node.SetExecutionProviderType(execution_provider_type); concat_node.AddAttribute("axis", int64_t(1)); @@ -155,7 +172,7 @@ Status Conv1dReplacement::ApplyImpl(Graph& graph, bool& modified, int graph_leve ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); if (NodeCanBeReplacedByMatmul(node)) { LOGS(logger, VERBOSE) << "lora conv1d replacement, node name: " + node.Name(); - Conv1dToMatmul(graph, node); + Conv1dToMatmul(graph, node, Name()); modified = true; } } diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index bab7c09839273..109937ff96d1d 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -1200,7 +1200,7 @@ TEST_P(QDQFusionTestsParameterized, CheckModelComposition) { ASSERT_EQ(op_to_count_post_fusion["com.microsoft.FakeQuant"], 1); } -TEST_F(GraphTransformationTests, Conv1dReplacement) { +TEST_F(GraphTransformationTests, Conv1dReplacement_TakeEffect) { auto pre_graph_checker = [&](Graph& graph) { auto op_count_map = CountOpsInGraph(graph); TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); @@ -1208,7 +1208,7 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) { }; for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { - for (auto group : {1, 2}) { + for (auto group : {1, 2, 4}) { auto build_test_case = [&](ModelTestBuilder& builder) { auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); auto out_channel = 64; @@ -1222,6 +1222,8 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) { conv_node.AddAttribute("kernel_shape", std::vector{1}); conv_node.AddAttribute("strides", std::vector{1}); conv_node.AddAttribute("group", static_cast(group)); + conv_node.AddAttribute("pads", std::vector{0, 0}); + conv_node.AddAttribute("auto_pad", "NOTSET"); }; auto post_graph_checker = [&](Graph& graph) { @@ -1243,28 +1245,64 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) { } } -TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) { +// node has bias input so conv not replaced +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect1) { auto pre_graph_checker = [&](Graph& graph) { auto op_count_map = CountOpsInGraph(graph); TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); return Status::OK(); }; - // "group" is 3 so conv not replaced for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { auto build_test_case = [&](ModelTestBuilder& builder) { auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); auto out_channel = 64; auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); - auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / 3, 1}, {-1.0f, 1.0f}); + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); + auto* bias_arg = builder.MakeInitializer({out_channel}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg, bias_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{0, 0}); + conv_node.AddAttribute("auto_pad", "NOTSET"); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker)); + } +} + +// "auto_pad " is not NOTSET so conv not replaced +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect2) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; + + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); auto* conv_output = builder.MakeOutput(); auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); conv_node.AddAttribute("dilations", std::vector{1}); conv_node.AddAttribute("kernel_shape", std::vector{1}); conv_node.AddAttribute("strides", std::vector{1}); - conv_node.AddAttribute("group", static_cast(3)); + conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{0, 0}); + conv_node.AddAttribute("auto_pad", "VALID"); }; std::unique_ptr transformer = std::make_unique(); @@ -1272,8 +1310,16 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) { TransformerLevel::Level1, 1, pre_graph_checker, pre_graph_checker)); } +} + +// pads is not all zero, so conv not replaced +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect3) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; - // "kernel_shape" is not 1 so conv not replaced for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { auto build_test_case = [&](ModelTestBuilder& builder) { auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); @@ -1285,9 +1331,11 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) { auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); conv_node.AddAttribute("dilations", std::vector{1}); - conv_node.AddAttribute("kernel_shape", std::vector{2}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); conv_node.AddAttribute("strides", std::vector{1}); conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{1, 0}); + conv_node.AddAttribute("auto_pad", "NOTSET"); }; std::unique_ptr transformer = std::make_unique(); From 7e613ee821405b1192d0b71b9434a4f94643f1e4 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Tue, 5 Mar 2024 11:45:45 +0800 Subject: [PATCH 2/7] [quant] supports act_order inputs in Matmulnbits and new quantization algorithm "hqq" (#19106) ### Description 1. Support quantized GPTQ weight in huggingface like [TheBloke/Llama-2-7B-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ) 2. Support Act_order for GPTQ 3. Support [HQQ](https://mobiusml.github.io/hqq_blog/) algorithm to quantize matmul weight and add quant script ### Motivation and Context --- docs/ContribOperators.md | 43 +- docs/OperatorKernels.md | 4 +- .../cpu/quantization/matmul_nbits.cc | 105 ++++- .../cpu/quantization/matmul_nbits_impl.cc | 108 +++++ .../cpu/quantization/matmul_nbits_impl.h | 23 ++ .../cuda/quantization/dequantize_blockwise.cu | 159 ++++++-- .../quantization/dequantize_blockwise.cuh | 6 +- .../cuda/quantization/matmul_nbits.cc | 170 ++++---- .../cuda/quantization/matmul_nbits.h | 41 ++ .../core/graph/contrib_ops/contrib_defs.cc | 38 +- .../quantization/matmul_4bits_quantizer.py | 379 ++++++++++++++++-- .../test/contrib_ops/matmul_4bits_test.cc | 78 +++- .../test/python/quantization/op_test_utils.py | 3 +- .../quantization/test_op_matmul_4bits.py | 19 +- 14 files changed, 942 insertions(+), 234 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc create mode 100644 onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h create mode 100644 onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e295dfa203ae5..5f0100fad95a2 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2808,22 +2808,23 @@ This version of the operator has been available since version 1 of the 'com.micr And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. - Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - - n_blocks_per_col = (K + block_size - 1) / block_size - - blob_size = block_size / 8 * bits + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) + For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. + - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. + 4bit example: + |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) + - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. + 3bit example: + |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. + The last uint_8 may have some bits unused. - For a block blob. It is stored in format: - struct Blob { - uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization - uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization - uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization - } Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] - Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: - - [(N * n_blocks_per_col + 1) / 2] if bits <=4 - - [N * n_blocks_per_col] if bits > 4 - + Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. + - [CeilDiv((N * n_blocks_per_col + 1) *bits, 8)] + If zero_points has same type as A, it's not packed and has the same shape as Scales. #### Version @@ -2844,17 +2845,19 @@ This version of the operator has been available since version 1 of the 'com.micr
number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.
-#### Inputs (3 - 4) +#### Inputs (3 - 5)
A : T1
The input tensor, not quantized
B : T2
-
1-dimensional data blob
+
1 or 2 dimensional data blob
scales : T1
quantization scale
-
zero_points (optional) : T2
+
zero_points (optional) : T3
quantization zero points
+
g_idx (optional) : T4
+
group_idx
#### Outputs @@ -2869,8 +2872,12 @@ This version of the operator has been available since version 1 of the 'com.micr
T1 : tensor(float), tensor(float16)
Constrain input and output types to float/half_float tensors.
-
T2 : tensor(uint8)
-
Constrain quantized weight types to uint8.
+
T2 : tensor(uint8), tensor(int32)
+
Constrain quantized weight types to uint8/int32.
+
T3 : tensor(uint8), tensor(int32), tensor(float16), tensor(float)
+
Constrain quantized zero point types to uint8/int32/float16/float.
+
T4 : tensor(int32)
+
the index tensor.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 0e60b4622f2fb..71b0def659741 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -470,7 +470,7 @@ Do not modify directly.* |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(uint8)
**T4** = tensor(int32)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| @@ -855,7 +855,7 @@ Do not modify directly.* |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 166f5c8f52f54..602dd98d8c0d6 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -1,6 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" + +#include +#include + +#include "core/common/common.h" #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/framework/op_kernel.h" @@ -50,6 +56,17 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level } } // namespace +bool GetType(const NodeArg& node_arg, int32_t& type) { + type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + const auto* type_proto = node_arg.TypeAsProto(); + if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) { + return false; + } + + type = type_proto->tensor_type().elem_type(); + return true; +} + class MatMulNBits final : public OpKernel { public: MatMulNBits(const OpKernelInfo& info) @@ -59,6 +76,17 @@ class MatMulNBits final : public OpKernel { block_size_{narrow(info.GetAttr("block_size"))}, nbits_{narrow(info.GetAttr("bits"))}, accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { + const auto& node = info.node(); + auto input_defs = node.InputDefs(); + // g_idx + if (input_defs.size() > 4) { + act_order_ = true; + } + int32_t type; + if (input_defs.size() > 3 && GetType(*input_defs[3], type)) { + zero_point_is_not_quant_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8; + } + ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); #ifdef ORT_NEURAL_SPEED @@ -88,6 +116,8 @@ class MatMulNBits final : public OpKernel { const size_t N_; const size_t block_size_; const size_t nbits_; + bool act_order_{false}; + bool zero_point_is_not_quant_{false}; const int64_t accuracy_level_; const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; @@ -105,7 +135,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; - + if (act_order_ || zero_point_is_not_quant_) { + return Status::OK(); + } #if defined(ORT_NEURAL_SPEED) if (!all_constant_) { @@ -212,7 +244,6 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep Status MatMulNBits::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); - const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); @@ -257,11 +288,14 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { #endif // defined(ORT_NEURAL_SPEED) const Tensor* scales = ctx->Input(2); - const Tensor* zero_points = ctx->Input(3); + const Tensor* zero_points = ctx->InputCount() > 3 ? ctx->Input(3) : nullptr; + const Tensor* reorder_idx = ctx->InputCount() > 4 ? ctx->Input(4) : nullptr; + const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); TensorShape b_shape({static_cast(N_), static_cast(K_)}); + const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); @@ -281,8 +315,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), - [](size_t offset) { return offset == 0; }); + const bool has_single_b_matrix = + (!act_order_) && (!zero_point_is_not_quant_) && + std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; }); if (has_single_b_matrix) { const auto compute_type = static_cast(accuracy_level_); @@ -328,22 +363,50 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const uint8_t* b_data = b->Data(); const size_t ldb = helper.Ldb(true); - AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - // dequantize b, only 4b quantization is supported for now - MlasDequantizeBlockwise( - tmp_b_data_ptr.get(), // dequantized output - b_data, // quantized input - scales_data, // quantization scales - zero_points_data, // quantization zero points - static_cast(block_size_), // quantization block size - column_wise_quant_, // columnwise quantization or row-wise - static_cast(K_), // number of rows in quantized input - static_cast(N_), // number of columns in quantized input - thread_pool); - + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { + // dequantize b, only 4b quantization is supported for now + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { + ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); + // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! + if ((zero_points && zero_points->IsDataType())) { + DequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + reorder_idx_data, + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { + DequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + reorder_idx_data, + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } + } #if 0 // for debug auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); @@ -374,7 +437,9 @@ ONNX_OPERATOR_KERNEL_EX( kCpuExecutionProvider, KernelDefBuilder() .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), MatMulNBits); } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc new file mode 100644 index 0000000000000..f92e59e990ba5 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" + +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/float16.h" +#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { + +template +void Dequantize4BitsKernelReOrder( + T* output, const uint8_t* quant_data, const T* scale_data, + const zeroT* zero_points, const int32_t* reorder_idx, int block_size, + int groups_per_threadblock, int total_groups, int out_rows, int out_cols, + int blockIdx_x, int threadIdx_x) { + const int group_id = blockIdx_x * groups_per_threadblock + ((threadIdx_x * 8) / block_size); + if (group_id >= total_groups) { + return; + } + const int scales_shape_x = (out_cols + block_size - 1) / block_size; + const int zero_point_shape_x = (scales_shape_x + 1) / 2; + + int n_idx = group_id / scales_shape_x; + int kb_idx = group_id % scales_shape_x; + int element_offset = group_id * block_size + ((threadIdx_x * 8) & (block_size - 1)); + + const int out_x = element_offset % (scales_shape_x * block_size); + const int out_y = element_offset / (scales_shape_x * block_size); + if (out_y >= out_rows || out_x >= out_cols) { + return; + } + T* output_i = output + out_y * out_cols + out_x; + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + const int remain_x = std::min(8, out_cols - out_x); + for (int i = 0; i < remain_x; i++) { + int32_t rid = reorder_idx ? reorder_idx[kb_idx * block_size + i] : kb_idx; + T scale = *(scale_data + n_idx * scales_shape_x + rid); + float zp_f = 8; + if (zero_points) { + if constexpr (std::is_same_v) { + zp_f = *(zero_points + n_idx * scales_shape_x + rid); + } else { + uint8_t zp = 8; + zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * MLFloat16(zp_f); + output_i[i] = static_cast((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else { + T zp_adjust = -scale * zp_f; + output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } + } +} + +template +void DequantizeBlockwise( + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + const int32_t* reorder_idx, // reorder_idx for groupwise quantization + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input + onnxruntime::concurrency::ThreadPool* pool) { + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + constexpr int element_per_thread = 8; + int groups_per_threadblock = 256 * element_per_thread / block_size; + int groups_per_K = ceildiv(K, block_size); + int total_groups = N * groups_per_K; // total elemenets in quant_data + int blocks_per_grid = static_cast(ceildiv(total_groups, groups_per_threadblock)); + concurrency::ThreadPool::TrySimpleParallelFor( + pool, static_cast(blocks_per_grid), + [&](std::ptrdiff_t block_id) { + for (int j = 0; j < 256; j++) { + Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points, + reorder_idx, block_size, groups_per_threadblock, + total_groups, N, K, static_cast(block_id), j); + } + }); +} + +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const float* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h new file mode 100644 index 0000000000000..5061ac5c800a6 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { + +template +void DequantizeBlockwise( + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + const int32_t* reorder_idx, // quantization zero points + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input + onnxruntime::concurrency::ThreadPool* thread_pool); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index 6b66f1d84e221..cd6593352008b 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -2,10 +2,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include #include #include +#include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" @@ -56,41 +58,94 @@ __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, f } template -__global__ void Dequantize4BitsKernel( +__global__ void Dequantize4BitsKernelReOrder( T* output, const uint8_t* quant_data, const T* scale_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int block_size, - int blocks_per_K, - int blocks_per_threadblock, - int total_blks, - int shift) { - int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); - if (block_id >= total_blks) { + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (group_id >= total_groups) { return; } - int n_idx = block_id / blocks_per_K; - int kb_idx = block_id % blocks_per_K; - int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); + // T __shared__ zero_points_after_reorder[];//K + // T __shared__ scales_after_reorder[]; // K + // const int num_r_per_thread = k / 256; + + const int zero_point_shape_x = (groups_per_K + 1) / 2; + const int scales_shape_x = groups_per_K; + int n_idx = group_id / scales_shape_x; + int kb_idx = group_id % scales_shape_x; + int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); + T* output_i = output + element_offset; + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + for (int i = 0; i < 8; i++) { + int32_t rid = reorder_idx[kb_idx * block_size + i]; + T scale = *(scale_data + n_idx * scales_shape_x + rid); + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * __short2half_rn(zp); + output_i[i] = __uint2half_rn((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else { + T zp_adjust = -scale * T(zp); + output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } + } +} + +template +__global__ void Dequantize4BitsKernel( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const ZeroT* zero_points, + int block_size, + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int block_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (block_id >= total_groups) { + return; + } + int element_offset = block_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); T scale = *(scale_data + block_id); - uint8_t zp = 8; - if (zero_points) { - zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2]; - zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + T zero_point_value; + if constexpr (std::is_same_v) { + const int scales_shape_x = groups_per_K; + const int zero_point_shape_x = (groups_per_K + 1) / 2; + int kb_idx = block_id % scales_shape_x; + int n_idx = block_id / scales_shape_x; + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + kb_idx / 2]; + zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + zero_point_value = static_cast(zp); + } else { + zero_point_value = zero_points? *(zero_points + block_id):static_cast(8); } output = output + element_offset; - DequantizeEightElements(quant_value, scale, static_cast(zp), output); + DequantizeEightElements(quant_value, scale, zero_point_value, output); } -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2] + const ZeroT* zero_points, // shape: [N, (block_per_K + 1)/2] + const int32_t* reorder_idx, int k, int n, int block_size, @@ -98,47 +153,79 @@ Status Dequantize4Bits( // k is padded and equal to block_per_K * block_size ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size"); constexpr int element_per_thread = 8; - int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; - int blocks_per_K = k / block_size; - int total_blks = n * blocks_per_K; - int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock)); - int shift = static_cast(log2f(float(block_size))); - - Dequantize4BitsKernel<<>>( - output, - quant_data, - scales_data, - zero_points, - block_size, - blocks_per_K, - blocks_per_threadblock, - total_blks, - shift); + int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int groups_per_K = k / block_size; + int total_groups = n * groups_per_K; // total elemenets in quant_data + int groups_per_grid = static_cast(CeilDiv(total_groups, groups_per_threadblock)); + if (!reorder_idx) { + Dequantize4BitsKernel<<>>( + output, + quant_data, + scales_data, + zero_points, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } else { + // static_assert(std::is_same_v, "ZeroT must be uint8_t"); + Dequantize4BitsKernelReOrder<<>>( + output, + quant_data, + scales_data, + (const uint8_t*)zero_points, + reorder_idx, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } return Status::OK(); } -template Status Dequantize4Bits( +template Status Dequantize4Bits( float* output, const uint8_t* quant_data, const float* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); -template Status Dequantize4Bits( +template Status Dequantize4Bits( half* output, const uint8_t* quant_data, const half* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); +template Status Dequantize4Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const float* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - +template Status Dequantize4Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const half* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); /////////////////////////////////////////////////////////////////////////////// // A more general block-wise dequantization implementation that supports // different block sizes and block orientations (row-wise/column-wise). diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh index f9c09c55fd893..580b5087f3fa3 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -7,18 +7,18 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, + const ZeroT* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - /** * @brief Dequantize a block-wise quantized matrix, and store the result in a * column major matrix for use in subsequent GEMM. This implementation supports diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 015df70c8ec3c..1cec6f6a12f1c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -1,15 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// -// This module define MatMulFp32Q4 operator, it is basically -// matmul float32 with right hand side being a 2-D matrix -// pre-packed and block-compacted into int4 -// - -#include "core/common/safeint.h" -#include "core/providers/cuda/cuda_kernel.h" -#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/quantization/matmul_nbits.h" + +#include + +#include "core/common/status.h" +#include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" #include "matmul_nbits.cuh" #include "dequantize_blockwise.cuh" @@ -19,40 +16,19 @@ namespace contrib { namespace cuda { using namespace onnxruntime::cuda; -template -class MatMulNBits final : public CudaKernel { - public: - MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { - ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); - ORT_ENFORCE(nbits_ == 4, - "Only 4b quantization is supported for MatMulNBits op," - " additional bits support is planned."); - } - - Status ComputeInternal(OpKernelContext* context) const override; - - private: - int64_t K_; - int64_t N_; - int64_t block_size_; - int64_t nbits_; - bool column_wise_quant_blk_{true}; -}; - template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const Tensor* b = ctx->Input(1); const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); + const Tensor* reorder_idx = ctx->Input(4); const auto* a_data = a->Data(); const uint8_t* blob_data = b->Data(); const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); + const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); typedef typename ToCudaType::MappedType CudaT; @@ -67,77 +43,99 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); - bool is_4bit_done = TryMatMul4Bits( - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - blob_data, - reinterpret_cast(scales_data), - zero_points_data, - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), - SafeInt(block_size_), - SafeInt(GetDeviceProp().sharedMemPerBlock), - static_cast(ctx->GetComputeStream()->GetHandle())); - if (!is_4bit_done) { - int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; - IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); - auto* b_data = b_data_ptr.get(); - if (column_wise_quant_blk_) { - // column-wise block + bool is_4bit_done = (reorder_idx_data == nullptr) && + (!zero_points || !zero_points->IsDataType()) && + TryMatMul4Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + static_cast(zero_points_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + SafeInt(GetDeviceProp().sharedMemPerBlock), + static_cast(ctx->GetComputeStream()->GetHandle())); + + if (is_4bit_done) { + return Status::OK(); + } + + int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; + IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); + auto* b_data = b_data_ptr.get(); + if (column_wise_quant_blk_) { + if (reorder_idx) { + ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]"); + } + // column-wise block + if ((zero_points && zero_points->IsDataType())) { ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - zero_points_data, + (const CudaT*)zero_points_data, + reorder_idx_data, SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), static_cast(ctx->GetComputeStream()->GetHandle()))); } else { - // row-wise block - K_padded = K_; - - ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - zero_points_data, - SafeInt(block_size_), - column_wise_quant_blk_, - SafeInt(K_), + (const uint8_t*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), SafeInt(N_), + SafeInt(block_size_), static_cast(ctx->GetComputeStream()->GetHandle()))); } + } else { + // row-wise block + K_padded = K_; + + ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const uint8_t*)zero_points_data, + SafeInt(block_size_), + column_wise_quant_blk_, + SafeInt(K_), + SafeInt(N_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } #if 0 - cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); - T* b_data_cpu = new T[K_ * N_]; - cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); - delete[] b_data_cpu; +cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); +T* b_data_cpu = new T[K_ * N_]; +cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); +delete[] b_data_cpu; #endif - const CudaT alpha = ToCudaType::FromFloat(1.f); - const CudaT zero = ToCudaType::FromFloat(0.f); - - if (helper.OutputOffsets().size() == 1) { - CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( - GetCublasHandle(ctx), - CUBLAS_OP_T, - CUBLAS_OP_N, - SafeInt(helper.N()), - SafeInt(helper.M()), - SafeInt(helper.K()), - &alpha, - reinterpret_cast(b_data), - SafeInt(K_padded), - reinterpret_cast(a_data), - helper.Lda(transa), - &zero, - reinterpret_cast(Y->MutableData()), - helper.Ldc(), - GetDeviceProp(), - UseTF32())); - } + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + if (helper.OutputOffsets().size() == 1) { + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_data), + SafeInt(K_padded), + reinterpret_cast(a_data), + helper.Lda(transa), + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp(), + UseTF32())); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..f5c2c6c4e4fdf --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This module define MatMulNBits operator, it is basically +// matmul float with right hand side being a 2-D matrix +// pre-packed and block-compacted into int4 +// +#pragma once +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulNBits final : public CudaKernel { + public: + MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t nbits_; + bool column_wise_quant_blk_{true}; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index e33ce20737f80..f06a3785f362d 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3343,22 +3343,23 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7 And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. -Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: -- n_blocks_per_col = (K + block_size - 1) / block_size -- blob_size = block_size / 8 * bits - - For a block blob. It is stored in format: - struct Blob { - uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization - uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization - uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization - } + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) + For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. + - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. + 4bit example: + |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) + - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. + 3bit example: + |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. + The last uint_8 may have some bits unused. -Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] -Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: - - [(N * n_blocks_per_col + 1) / 2] if bits <=4 - - [N * n_blocks_per_col] if bits > 4 +Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] +Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. + - [CeilDiv((N * n_blocks_per_col + 1) *bits, 8)] + If zero_points has same type as A, it's not packed and has the same shape as Scales. )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) @@ -3377,12 +3378,15 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored "type T1.", AttributeProto::INT, static_cast(0)) .Input(0, "A", "The input tensor, not quantized", "T1") - .Input(1, "B", "1-dimensional data blob", "T2") + .Input(1, "B", "1 or 2 dimensional data blob", "T2") .Input(2, "scales", "quantization scale", "T1") - .Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional) + .Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional) + .Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional) .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") - .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.") + .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.") + .TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference propagateElemTypeFromInputToOutput(ctx, 0, 0); diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index eb7bbec997d59..a1916e806c5c0 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -65,7 +65,7 @@ def __init__( self, calibration_data_reader: CalibrationDataReader, percdamp=0.01, - blocksize=128, + block_size=128, actorder=False, mse=False, perchannel=True, @@ -79,7 +79,7 @@ def __init__( a calibration data reader. It enumerates calibration data and generates inputs for the original model. percdamp: percent of the average Hessian diagonal to use for dampening. - blocksize (int, optional): + block_size (int, optional): channel number in one block to execute a GPTQ quantization iteration. actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value. @@ -93,42 +93,285 @@ def __init__( ) self.calibration_data_reader = calibration_data_reader self.percdamp = percdamp - self.blocksize = blocksize + self.block_size = block_size self.actorder = actorder self.mse = mse self.perchannel = perchannel -class MatMul4BitsQuantizer: - """Perform 4b quantization of constant MatMul weights""" +class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig): + def __init__( + self, + block_size=128, + bits=4, + axis=1, + ): + """ + This is a class for HQQ algorithm Weight Only Quant Configuration. + HQQ algorithm quant weight without needing calibrate data. + + Args: + block_size (int, optional): + channel number in one block to execute a GPTQ quantization iteration. + bits (int, optional): + how many bits to represent weight. + axis (int, optional): + 0 or 1. which axis to quantize. https://arxiv.org/pdf/2309.15531.pdf + """ + super().__init__( + algorithm="HQQ", + ) + self.block_size = block_size + self.bits = bits + self.axis = axis + +class DefaultWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, - model: ModelProto | str, - block_size: int, - is_symmetric: bool, + block_size: int = 128, + is_symmetric: bool = False, accuracy_level: int | None = None, - nodes_to_exclude=None, - algo_config: WeightOnlyQuantConfig = None, ): - if nodes_to_exclude is None: - nodes_to_exclude = [] - self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model) - self.model_path = model if isinstance(model, str) else None + super().__init__(algorithm="DEFAULT") self.block_size = block_size self.is_symmetric = is_symmetric + self.bits = 4 self.accuracy_level = accuracy_level - self.nodes_to_exclude = set(nodes_to_exclude) - self.algo_config = algo_config + + +def is_divisible(val1, val2): + return int(val2 * np.ceil(val1 / val2)) == val1 + + +class HQQWeightOnlyQuantizer: + def __init__( + self, + config: HQQWeightOnlyQuantConfig, + ): + self.config = config + + # Proximal solver || weight - dequantize(quantize(weight))||_p^p + @staticmethod + def optimize_weights( + tensor, + scale, + zero, + min_max: list[int], + axis: int = 0, + opt_params: dict = None, # noqa: RUF013 + verbose=False, + ): + import torch + + opt_params = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} if opt_params is None else opt_params + lp_norm, beta, kappa, iters = ( + opt_params["lp_norm"], + opt_params["beta"], + opt_params["kappa"], + opt_params["iters"], + ) + + dtype = torch.float16 if tensor.is_cuda else torch.float32 + w_f = tensor.to(dtype) + scale = scale.to(dtype) + zero = zero.to(dtype) + + if lp_norm == 1: + + def shrink_op(x, beta): + return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta) + + else: + + def shrink_op(x, beta, p=lp_norm): + return torch.sign(x) * torch.nn.functional.relu( + torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x) + 1e-8, p - 1) + ) + + best_error = 1e4 + for i in range(iters): + w_q = torch.round(w_f * scale + zero).clamp(min_max[0], min_max[1]) + w_r = (w_q - zero) / scale + w_e = shrink_op(w_f - w_r, beta) + zero = torch.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdim=True) + beta *= kappa + + current_error = float(torch.abs(w_f - w_r).mean()) + if verbose: + print(i, np.round(current_error, 6)) + if current_error < best_error: + best_error = current_error + else: + break + + del w_f, w_q, w_r, w_e + + return scale, zero @staticmethod - def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: - for gid in range(len(graph_path) - 1, -1, -1): - graph = graph_path[gid] - for tensor in graph.initializer: - if tensor.name == name: - return tensor, graph - return None, None + def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits): + if pack_tensor.shape[0] == ori_int_tensor.shape[0]: + ori_int_tensor = ori_int_tensor.T + pack_tensor = pack_tensor.T + if bits in [2, 4, 8]: + compress_ratio = pack_tensor.element_size() * 8 // bits + for j in range(0, compress_ratio): + pack_tensor[0:] |= ori_int_tensor[j::compress_ratio] << (bits * (j)) + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + # from Official implementation of Half-Quadratic Quantization (HQQ) + def quantize_internal( + self, tensor, bits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True, axis=1 + ): + import torch + + weight = tensor.float() + ori_shape = weight.shape + + pad_len = (group_size - ori_shape[axis] % group_size) % group_size + if axis == 1: + weight = torch.nn.functional.pad(weight, (0, pad_len), "constant", 0) + else: + weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_len), "constant", 0) + shape = weight.shape + + # Reshape for grouping + if (group_size is not None) and channel_wise: + weight = weight.reshape([-1, group_size]) if (axis == 1) else weight.reshape([group_size, -1]) + + # Get min/max values + if channel_wise is False: + _min, _max = weight.min(), weight.max() + optimize = False + else: + _min = weight.min(axis=axis, keepdim=True)[0] + _max = weight.max(axis=axis, keepdim=True)[0] + + max_v = 2**bits - 1 + min_v = 0 + min_max = [min_v, max_v] + + # Note: here we work with the inverse of the scale to avoid division and quantize instead via weight*scale + zero, the scale is inverted later on. + # clamp to avoid half-precision problems + scale = (max_v / (_max - _min)).clamp(max=2e4) + #!!!!!!!!!!!!!!! + min_max_axis = _max - _min + if (min_max_axis == 0).sum().item() > 0: + min_max_axis[min_max_axis == 0] = max_v + scale = (max_v / min_max_axis).clamp(max=2e4) + zero = -_min * scale + + if round_zero: + zero = torch.round(zero) + + # Fine-tune weights + if optimize: + scale, zero = self.optimize_weights(tensor=weight, scale=scale, zero=zero, min_max=min_max, axis=axis) + + # Quantize + # Necessary for fake quantization backprop + w_q = torch.round(weight * scale + zero).clamp(min_max[0], min_max[1]) + w_q = w_q.reshape(shape).int() + + scale = 1.0 / scale + if axis == 1: + scale = scale.reshape(shape[0], -1) + zero = zero.reshape(shape[0], -1) + else: + scale = scale.reshape(-1, shape[-1]) + zero = zero.reshape(-1, shape[-1]) + # cleanup + del weight, _min, _max + + return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype) + + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): + """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + if node.op_type != "MatMul": + return node # only care about MatMul for now + import torch + + logger.info(f"start to quantize {node.name} ...") + inputB = node.input[1] # noqa: N806 + b_pb, bs_graph = get_initializer(inputB, graph_stack) + if b_pb is None: + logger.info("MatMul doesn't have const weight. Skip to quantize") + return node # only care about constant weight + + b_array = onnx.numpy_helper.to_array(b_pb) + if len(b_array.shape) != 2: + logger.info("MatMul weight is not 2D. Skip to quantize") + return node # can only process 2-D matrix + b_array_torch = torch.from_numpy(b_array) + if torch.cuda.is_available(): + b_array_torch = b_array_torch.cuda() + quant_weight_torch, scales_torch, zero_points_torch = self.quantize_internal( + b_array_torch.T, bits=self.config.bits, group_size=self.config.block_size + ) + quant_weight_torch = quant_weight_torch.contiguous() + scales_torch = scales_torch.contiguous() + zero_points_torch = zero_points_torch.contiguous() + + packed_torch = torch.zeros( + (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // 2), + dtype=torch.uint8, + device=quant_weight_torch.device, + ) + self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, self.config.bits) + scales = scales_torch.cpu().numpy() + zero_points = zero_points_torch.cpu().numpy() + b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy()) + b_quant.name = b_pb.name + "_Q4" + for input in bs_graph.input: + if input.name == inputB: + bs_graph.input.remove(input) + break + + scales_tensor = onnx.numpy_helper.from_array(scales) + scales_tensor.name = b_pb.name + "_scales" + bs_graph.initializer.extend([b_quant, scales_tensor]) + + input_names = [node.input[0], b_quant.name, scales_tensor.name] + zp_tensor = onnx.numpy_helper.from_array(zero_points) + zp_tensor.name = b_pb.name + "_zero_points" + bs_graph.initializer.extend([zp_tensor]) + input_names.append(zp_tensor.name) + + kwargs = {} + rows, cols = b_array.shape + kwargs["K"] = rows + kwargs["N"] = cols + kwargs["bits"] = self.config.bits + kwargs["block_size"] = self.config.block_size + + matmul_q4_node = onnx.helper.make_node( + "MatMulNBits", + inputs=input_names, + outputs=[node.output[0]], + name=node.name + "_Q4" if node.name else "", + domain="com.microsoft", + **kwargs, + ) + + logger.info(f"complete quantization of {node.name} ...") + + return matmul_q4_node + + +def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: + for gid in range(len(graph_path) - 1, -1, -1): + graph = graph_path[gid] + for tensor in graph.initializer: + if tensor.name == name: + return tensor, graph + return None, None + + +class DefaultWeightOnlyQuantizer: + def __init__(self, config: DefaultWeightOnlyQuantConfig): + self.config = config def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: """4b quantize fp32 weight to a blob""" @@ -137,7 +380,7 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: raise ValueError("Current int4 block quantization only supports 2D tensors!") rows, cols = fp32weight.shape - block_size = self.block_size + block_size = self.config.block_size blob_size = block_size // 2 k_blocks = (rows + block_size - 1) // block_size padded_rows = k_blocks * block_size @@ -149,23 +392,19 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") - quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.is_symmetric) + quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric) return (packed, scales, zero_point) - def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" if node.op_type != "MatMul": return node # only care about MatMul for now logger.info(f"start to quantize {node.name} ...") - if node.name in self.nodes_to_exclude: - logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") - return node - inputB = node.input[1] # noqa: N806 - B, Bs_graph = MatMul4BitsQuantizer.__get_initializer(inputB, graph_stack) # noqa: N806 + B, Bs_graph = get_initializer(inputB, graph_stack) # noqa: N806 if B is None: logger.info("MatMul doesn't have const weight. Skip to quantize") return node # only care about constant weight @@ -188,7 +427,7 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) Bs_graph.initializer.extend([B_quant, scales_tensor]) input_names = [node.input[0], B_quant.name, scales_tensor.name] - if not self.is_symmetric: + if not self.config.is_symmetric: zp_tensor = onnx.numpy_helper.from_array(zero_points) zp_tensor.name = B.name + "_zero_points" Bs_graph.initializer.extend([zp_tensor]) @@ -199,8 +438,8 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) kwargs["K"] = rows kwargs["N"] = cols kwargs["bits"] = 4 - kwargs["block_size"] = self.block_size - if self.accuracy_level is not None: + kwargs["block_size"] = self.config.block_size + if self.config.accuracy_level is not None: kwargs["accuracy_level"] = self.accuracy_level matmul_q4_node = onnx.helper.make_node( @@ -216,6 +455,38 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) return matmul_q4_node + +class MatMul4BitsQuantizer: + """Perform 4b quantization of constant MatMul weights""" + + def __init__( + self, + model: ModelProto | str, + block_size: int = 128, + is_symmetric: bool = False, + accuracy_level: int | None = None, + nodes_to_exclude=None, + algo_config: WeightOnlyQuantConfig = None, + ): + if nodes_to_exclude is None: + nodes_to_exclude = [] + self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model) + self.model_path = model if isinstance(model, str) else None + self.block_size = block_size + self.is_symmetric = is_symmetric + self.accuracy_level = accuracy_level + self.nodes_to_exclude = set(nodes_to_exclude) + self.node_quantizer = None + if algo_config is None: + algo_config = DefaultWeightOnlyQuantConfig( + block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level + ) + self.algo_config = algo_config + if algo_config.algorithm == "HQQ": + self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config) + elif algo_config.algorithm == "DEFAULT": + self.node_quantizer = DefaultWeightOnlyQuantizer(self.algo_config) + def _process_subgraph(self, graph_stack: list[GraphProto]): new_nodes = [] graph = graph_stack[-1] @@ -246,8 +517,15 @@ def _process_subgraph(self, graph_stack: list[GraphProto]): node = onnx.helper.make_node( # noqa: PLW2901 node.op_type, node.input, node.output, name=node.name, **kwargs ) - - new_nodes.append(self._q4_matmul_node_weight(node, graph_stack)) + out_node = None + if node.name in self.nodes_to_exclude: + logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") + out_node = node + elif self.algo_config is not None and self.algo_config.algorithm == "HQQ": + out_node = self.node_quantizer.quantize(node, graph_stack) + else: + out_node = self.node_quantizer.quantize(node, graph_stack) + new_nodes.append(out_node) graph.ClearField("node") graph.node.extend(new_nodes) @@ -300,7 +578,7 @@ def inc_dataloader(): from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize kwargs["percdamp"] = self.algo_config.percdamp - kwargs["blocksize"] = self.algo_config.blocksize + kwargs["blocksize"] = self.algo_config.block_size kwargs["actorder"] = self.algo_config.actorder kwargs["mse"] = self.algo_config.mse kwargs["perchannel"] = self.algo_config.perchannel @@ -316,7 +594,7 @@ def inc_dataloader(): logger.info(f"complete quantization of model with {algorithm} algorithm.") def process(self): - if self.algo_config is None: + if self.algo_config.algorithm in ["HQQ", "DEFAULT"]: # use a stack to keep track of sub-graphs graph_stack = [self.model.graph()] opset_import = self.model.opset_import() @@ -327,7 +605,6 @@ def process(self): has_ms_domain = True if not has_ms_domain: opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) - self._process_subgraph(graph_stack) self.model.clean_initializers() else: @@ -366,6 +643,14 @@ def parse_args(): parser.add_argument("--input_model", required=True, help="Path to the input model file") parser.add_argument("--output_model", required=True, help="Path to the output model file") parser.add_argument("--block_size", required=False, default=32, type=int, help="Block size for quantization") + parser.add_argument( + "--quant_method", + default="default", + type=str, + choices=["default", "hqq"], + help="the algorithm used to quantize weight", + ) + parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight") parser.add_argument( "--symmetric", required=False, @@ -411,12 +696,24 @@ def parse_args(): raise Exception(f"file {output_model_path} already exists") model = onnx.load(input_model_path) + if args.quant_method == "hqq": + quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits) + elif args.quant_method == "default": + quant_config = DefaultWeightOnlyQuantConfig( + block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level + ) + elif args.quant_method == "rtn": + quant_config = RTNWeightOnlyQuantConfig() + elif args.quant_method == "gptq": + quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size) + else: + raise ValueError(f"Unsupported quantization method: {args.quant_method}") + quant = MatMul4BitsQuantizer( model=model, - block_size=args.block_size, - is_symmetric=args.symmetric, accuracy_level=args.accuracy_level, nodes_to_exclude=args.nodes_to_exclude, + algo_config=quant_config, ) quant.process() quant.model.save_model_to_file(output_model_path, True) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 2ad20eafc2ef1..d294fd4e2b0e0 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #ifndef ORT_MINIMAL_BUILD +#include #include "core/common/span_utils.h" #include "core/framework/tensor.h" @@ -66,7 +67,9 @@ void QuantizeDequantize(std::vector& raw_vals, } void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level, - bool has_zeropoint, bool use_float16, float fp16_abs_error = 0.02f) { + bool has_zeropoint, bool use_float16, bool has_g_idx = false, + bool zp_is_4bit = true, float fp16_abs_error = 0.02f) { + zp_is_4bit = zp_is_4bit | has_g_idx; RandomValueGenerator random{1234}; std::vector input0_vals(random.Gaussian(std::vector({M, K}), 0.0f, 0.25f)); std::vector input1_f_vals(random.Gaussian(std::vector({K, N}), 0.0f, 0.25f)); @@ -113,12 +116,40 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura test.AddAttribute("block_size", block_size); test.AddAttribute("bits", QBits); test.AddAttribute("accuracy_level", accuracy_level); + auto ceildiv = [](int64_t a, int64_t b) { return (a + b - 1) / b; }; + if (use_float16) { test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); test.AddInput("B", {q_cols, q_rows}, input1_vals, true); test.AddInput("scales", {static_cast(q_scale_size)}, ToFloat16(scales), true); if (has_zeropoint) { - test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + if (zp_is_4bit) { + test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + } else { + std::vector zp_f; + zp_f.reserve(q_zp_size_in_bytes * 2); + for (size_t i = 0; i < zp.size(); i++) { + zp_f.push_back(static_cast(zp[i] & 0xf)); + zp_f.push_back(static_cast((zp[i] >> 4) & 0xf)); + } + size_t ind = zp_f.size() - 1; + while (zp_f.size() != q_scale_size) { + zp_f.erase(zp_f.begin() + ind); + ind -= q_scale_size / N + 1; + } + + test.AddInput("zero_points", {static_cast(q_scale_size)}, ToFloat16(zp_f), true); + } + } else { + test.AddInput("", {0}, {}); + } + if (has_g_idx) { + int K_pad = gsl::narrow(ceildiv(K, block_size) * block_size); + std::vector g_idx(K_pad); + for (int64_t i = 0; i < K_pad; i++) { + g_idx[i] = gsl::narrow(i / block_size); + } + test.AddInput("g_idx", {static_cast(K_pad)}, g_idx, true); } test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); @@ -132,9 +163,34 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura test.AddInput("B", {q_cols, q_rows}, input1_vals, true); test.AddInput("scales", {static_cast(q_scale_size)}, scales, true); if (has_zeropoint) { - test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); - } + if (zp_is_4bit) { + test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + } else { + std::vector zp_f; + zp_f.reserve(q_zp_size_in_bytes * 2); + for (size_t i = 0; i < zp.size(); i++) { + zp_f.push_back(static_cast(zp[i] & 0xf)); + zp_f.push_back(static_cast((zp[i] >> 4) & 0xf)); + } + size_t ind = zp_f.size() - 1; + while (zp_f.size() != q_scale_size) { + zp_f.erase(zp_f.begin() + ind); + ind -= q_scale_size / N + 1; + } + test.AddInput("zero_points", {static_cast(q_scale_size)}, zp_f, true); + } + } else { + test.AddInput("", {0}, {}); + } + if (has_g_idx) { + int K_pad = gsl::narrow(ceildiv(K, block_size) * block_size); + std::vector g_idx(K_pad); + for (int64_t i = 0; i < K_pad; i++) { + g_idx[i] = gsl::narrow(i / block_size); + } + test.AddInput("g_idx", {static_cast(K_pad)}, g_idx, true); + } test.AddOutput("Y", {M, N}, expected_vals); if (accuracy_level == 4) { test.SetOutputAbsErr("Y", 0.1f); @@ -158,6 +214,8 @@ TEST(MatMulNBits, Float32) { for (auto accuracy_level : {0}) { RunTest(M, N, K, block_size, accuracy_level, false, false); RunTest(M, N, K, block_size, accuracy_level, true, false); + RunTest(M, N, K, block_size, accuracy_level, false, false, true); + RunTest(M, N, K, block_size, accuracy_level, true, false, false, false); } #endif } @@ -172,8 +230,10 @@ TEST(MatMulNBits, Float16) { for (auto N : {1, 2, 32, 288}) { for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { for (auto block_size : {16, 32, 64, 128}) { - RunTest(M, N, K, block_size, 0, false, true); - RunTest(M, N, K, block_size, 0, true, true); + for (auto has_gidx : {true, false}) { + RunTest(M, N, K, block_size, 0, false, true, has_gidx); + RunTest(M, N, K, block_size, 0, true, true, has_gidx, false); + } } } } @@ -183,9 +243,9 @@ TEST(MatMulNBits, Float16) { TEST(MatMulNBits, Float16Large) { for (auto block_size : {16, 32, 64, 128}) { for (auto symmetric : {false, true}) { - RunTest(1, 4096, 4096, block_size, 0, symmetric, true, 0.05f); - RunTest(1, 4096, 11008, block_size, 0, symmetric, true, 0.05f); - RunTest(1, 11008, 4096, block_size, 0, symmetric, true, 0.05f); + RunTest(1, 4096, 4096, block_size, 0, symmetric, true, false, true, 0.05f); + RunTest(1, 4096, 11008, block_size, 0, symmetric, true, false, true, 0.05f); + RunTest(1, 11008, 4096, block_size, 0, symmetric, true, false, true, 0.05f); } } } diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index c1bbb49f10c7e..b30282f2ab41f 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -358,6 +358,7 @@ def check_model_correctness( model_onnx = onnx.load(f) ops_set = set(node.op_type for node in model_onnx.graph.node) check_reference_evaluator = not (ops_set & {"EmbedLayerNormalization", "Conv", "Attention", "Transpose"}) + check_target_evaluator = False with open(model_path_to_check, "rb") as f: model_check = onnx.load(f) @@ -413,7 +414,7 @@ def check_model_correctness( check_sign_f8_quantization(model_path_origin, model_path_to_check) # Verifies the expected outputs. - if check_reference_evaluator and onnx_recent_enough: + if check_target_evaluator and onnx_recent_enough: if op_matmul: reference_new_ops = [QLinearMatMul] else: diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 73dae08af8ece..88e5052db4e2e 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -125,7 +125,10 @@ def quant_test( from onnxruntime.quantization import matmul_4bits_quantizer model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) - quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric) + quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig( + block_size=block_size, is_symmetric=is_symmetric + ) + quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config) quant.process() quant.model.save_model_to_file(model_int4_path, False) @@ -165,6 +168,9 @@ def quant_test_with_algo( elif algorithm == "GPTQ": # test GPTQ algorithm algo_config = matmul_4bits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader) + elif algorithm == "HQQ": + # test HQQ algorithm + algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=block_size) model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config) @@ -227,6 +233,17 @@ def test_quantize_matmul_int4_using_gptq_algo(self): data_reader = self.input_feeds(1, {"input": [100, 52]}) self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False) + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_matmul_int4_using_hqq_algo(self): + if not find_spec("torch"): + self.skipTest("skip test_hqq_quant since torch is not installed") + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=False) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test_with_algo("HQQ", model_fp32_path, data_reader, 32, False) + if __name__ == "__main__": unittest.main() From cd56ea4a74ee41c040899d702667d2c86bee4ef0 Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Tue, 5 Mar 2024 13:15:30 +0800 Subject: [PATCH 3/7] enable embedding sparse optimization by default (#19714) --- docs/ORTModule_Training_Guidelines.md | 2 +- .../training/ortmodule/_graph_execution_manager.py | 14 +++++++++----- .../python/training/ortmodule/options.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index f50b18b736936..84631bd1f6555 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -246,7 +246,7 @@ to standard outputs. #### ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER - **Feature Area**: *ORTMODULE/Optimizations* -- **Description**: By default, this is disabled. This env var can be used for enabling or disabling the embedding input +- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the embedding input data sparsity based performance optimizations. ```bash diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index fda6e345da235..e189ffff9cc7f 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -681,11 +681,15 @@ def _enable_conditional_optimizations( ) if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0: - graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys()) - self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results) - self._runtime_options.embed_sparsity_ratio = ",".join( - [f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()] - ) + if detected_device.type == "cuda": + # Embedding sparsity optimization is only supported on CUDA devices. + graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys()) + self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results) + self._runtime_options.embed_sparsity_ratio = ",".join( + [f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()] + ) + else: + self._logger.info("Embedding sparsity-based optimization is not supported on non-CUDA devices.") # If users don't want to print input density, disable the input density observer to avoid overhead # when looping through inputs during training. diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 539859a0d58a6..93d24a34df6bd 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -271,7 +271,7 @@ def __init__(self, logger: Logger): self.enable_sparse_optimizer = True self.label_sparsity_ratio = "" self.embed_sparsity_ratio = "" - self.enable_embedding_sparse_optimizer = False # TODO(pengwa): remove once validation on more models are done. + self.enable_embedding_sparse_optimizer = True # Configuration for memory optimization. self.memory_optimization_level = ( From bdf678df93cb257e311de3fa82fe6409be2854ff Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Tue, 5 Mar 2024 17:09:42 +0100 Subject: [PATCH 4/7] Fix CUDA BatchNorm bugs and add support for NHWC (#19742) ### Description - Fix incorrect running_mean / running_var in training mode due to incorrect momentum and missing input mean/var. runnig_var could be correct, but has a too high epsilon. - Fix incorrect checks when using NHWC - Pass NHWC flag to NormalizeDims to get correct new dimensions from x_shape - Register missing double operations to get parity between NHWC/NCHW --- .../core/providers/cpu/nn/batch_norm_helper.h | 41 +++++++++++++------ .../providers/cuda/cuda_execution_provider.cc | 18 +++++--- .../core/providers/cuda/cuda_nhwc_kernels.cc | 16 ++++++++ .../core/providers/cuda/nn/batch_norm.cc | 11 ++++- .../providers/cpu/nn/batch_norm_op_test.cc | 1 + 5 files changed, 66 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h index a5d46aff83b50..ccecbabfa3db3 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h @@ -25,6 +25,8 @@ class BatchNormHelper { const Tensor* var, bool is_spatial = true, bool is_nhwc = false) { + // NHWC dependent shape: X + // All other shapes are assumed to be in NCHW layout? const auto& x_dims = X->Shape().GetDims(); // If x_dims size < 2, num_channels defaults to 1. @@ -48,16 +50,22 @@ class BatchNormHelper { // validate 'scales' shape const auto& scale_dims = scale->Shape().GetDims(); if (static_cast(scale_dims.size()) != kNumInputScaleDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions); } if (scale_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: 0th dimension != ", num_channels); } + // N & C do not belong to features + // skip the first element for NHWC and the first two elements for NCHW. + int feature_offset = is_nhwc ? 1 : 2; + // in non-spatial cases - the other dims of 'scale' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (scale_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (scale_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } @@ -65,7 +73,8 @@ class BatchNormHelper { // validate 'B' shape const auto& B_dims = B->Shape().GetDims(); if (static_cast(B_dims.size()) != kNumInputBiasDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions); } if (B_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: 0th dimension != ", num_channels); @@ -73,8 +82,9 @@ class BatchNormHelper { // in non-spatial cases - the other dims of 'B' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (B_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (B_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } @@ -82,16 +92,19 @@ class BatchNormHelper { // validate 'mean' shape const auto& mean_dims = mean->Shape().GetDims(); if (static_cast(mean_dims.size()) != kNumInputMeanDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions); } if (mean_dims[0] != num_channels) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: 0th dimension != ", num_channels); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input mean: 0th dimension != ", num_channels); } // in non-spatial cases - the other dims of 'mean' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (mean_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (mean_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } @@ -99,7 +112,8 @@ class BatchNormHelper { // validate 'var' shape const auto& var_dims = var->Shape().GetDims(); if (static_cast(var_dims.size()) != kNumInputVarianceDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions); } if (var_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: 0th dimension != ", num_channels); @@ -107,8 +121,9 @@ class BatchNormHelper { // in non-spatial cases - the other dims of 'var' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (var_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (var_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 1ce089fd93044..8ba282031a5d4 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1202,9 +1202,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin); @@ -2107,9 +2110,12 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc index f416caecd115f..64edc319e15ac 100644 --- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc @@ -18,10 +18,14 @@ namespace onnxruntime::cuda { class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float, @@ -72,10 +76,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalN class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16, BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, double, + BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16, BatchNormalization); @@ -86,18 +94,26 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) { kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, BatchNormalization)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo::ComputeInternal(OpKernelContext* p_op_kernel_context) CudnnTensor data_desc; vector new_dims; - BatchNormHelper::NormalizeDims(x_shape, new_dims); + BatchNormHelper::NormalizeDims(x_shape, new_dims, NHWC); ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType(), NHWC)); // For half data type, the alpha, beta, scale, B, mean, var need to be float type @@ -137,6 +137,12 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) auto saved_mean_data = reinterpret_cast(saved_mean->MutableData()); auto saved_inv_var_data = reinterpret_cast(saved_var->MutableData()); + auto stream = static_cast(p_op_kernel_context->GetComputeStream()->GetHandle()); + CUDA_RETURN_IF_ERROR( + cudaMemcpyAsync(running_mean_data, mean_data, mean->SizeInBytes(), cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR( + cudaMemcpyAsync(running_var_data, var_data, var->SizeInBytes(), cudaMemcpyDeviceToDevice, stream)); + CUDNN_RETURN_IF_ERROR(BatchNormalizationForwardTrainingHelper( GetCudnnHandle(p_op_kernel_context), cudnn_batch_norm_mode_, @@ -149,7 +155,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) bn_tensor_desc, scale_data, b_data, - momentum_, + 1.0 - momentum_, running_mean_data, running_var_data, epsilon_, @@ -186,6 +192,7 @@ SPECIALIZED_COMPUTE(MLFloat16, kOnnxDomain, false) #ifdef ENABLE_CUDA_NHWC_OPS SPECIALIZED_COMPUTE(float, kMSInternalNHWCDomain, true) +SPECIALIZED_COMPUTE(double, kMSInternalNHWCDomain, true) SPECIALIZED_COMPUTE(MLFloat16, kMSInternalNHWCDomain, true) #endif } // namespace cuda diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index cbb4531a50b7c..54e5c71bd753a 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -916,6 +916,7 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) { // exclude CUDA Execution Provider due to flakiness // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm() test.Run(OpTester::ExpectResult::kExpectSuccess, "", + // TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1 {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); } From 06e684c9f2f8495de5259967cc12bab24da3d522 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Tue, 5 Mar 2024 09:37:45 -0800 Subject: [PATCH 5/7] Adding cuda kernel (optimized for sm80) for block-wise 4b quantized float 16 GEMM. (#18619) ### Description Adding CUDA kernel for block-wise 4b quantized float 16 GEMM, this is specially optimized for Nvidia Ampere GPUs. ### Motivation and Context Trying to improve quantized LLM inference performance on Nvidia Ampere GPUs ### Note: This is implemented by extending CUTLASS, so it has a hard dependency on CUTLASS. However, in current build system, loading of CUTLASS dependency is guarded with: (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) If both of these options are turned off, then compilation will fail. Why CUTLASS dependency is guarded at all? It's a header file only library that does not introduce any binary if not instantiated. What's the downside of removing all the guards and just include CUTLASS unconditionally? --- .lintrunner.toml | 1 + cmake/CMakeLists.txt | 5 +- cmake/onnxruntime_providers_cuda.cmake | 2 +- cmake/onnxruntime_unittests.cmake | 1 + onnxruntime/core/mickey/README.md | 4 + .../core/mickey/blk_q4/f16_gemm_sm80.h | 208 +++ .../{prepack_sm80.h => f16_prepack_sm80.h} | 2 +- .../cutlass_ext/q4gemm/device/quantb_gemm.h | 481 ++++++ .../q4gemm/kernel/default_quantb_gemm.h | 255 ++++ .../cutlass_ext/q4gemm/kernel/quantb_gemm.h | 462 ++++++ .../q4gemm/threadblock/default_quantb_mma.h | 248 ++++ .../threadblock/default_quantb_mma_core.h | 340 +++++ .../optional_predicated_tile_access_iter.h | 314 ++++ .../optional_regular_tile_access_iter.h | 224 +++ .../threadblock/quantb_mma_multistage.h | 1290 +++++++++++++++++ .../warp/default_quantb_mma_tensor_op.h | 112 ++ .../quantb_meta_mma_tensor_op_tile_iterator.h | 883 +++++++++++ .../q4gemm/warp/quantb_mma_tensor_op.h | 361 +++++ onnxruntime/core/util/matrix_layout.h | 1 - .../test/cuda_host/blkq4_fp16_quant_sm80.h | 203 +++ .../cuda/test_cases/blkq4_fp16_gemm_sm80.h | 188 +++ .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 330 +++++ .../test_cases/blkq4_fp16_gemm_sm80_testcu.cu | 344 +++++ .../blkq4_fp16_sm80_prepack_test.cc | 507 ------- .../cuda_execution_provider_test.cc | 4 +- 25 files changed, 6257 insertions(+), 513 deletions(-) create mode 100644 onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h rename onnxruntime/core/mickey/blk_q4/{prepack_sm80.h => f16_prepack_sm80.h} (99%) create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h create mode 100644 onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h create mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h create mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc create mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu delete mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc diff --git a/.lintrunner.toml b/.lintrunner.toml index 4e5d077b08ff4..be95e03479cf9 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -132,6 +132,7 @@ exclude_patterns = [ 'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code 'onnxruntime/core/graph/contrib_ops/quantization_defs.cc', 'onnxruntime/core/mlas/**', # Contains assembly code + 'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS lib recommends NO automatic code formatting 'winml/lib/Api.Image/shaders/**', # Contains data chunks ] command = [ diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 8453da19ce3a6..0d55d4cab9826 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -727,6 +727,9 @@ if (onnxruntime_USE_CUDA) set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) + message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4") + endif() else() set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) @@ -747,8 +750,8 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1) endif() - endif() + if (onnxruntime_USE_VITISAI) list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_VITISAI=1) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 0f6d48bdb6ec8..7f295a59a0931 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -201,7 +201,7 @@ endif() include(cutlass) - target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) + target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 88f662075e177..b004054c616a5 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -774,6 +774,7 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock) + target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey) target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut) endif() diff --git a/onnxruntime/core/mickey/README.md b/onnxruntime/core/mickey/README.md index 7e8d30cd1805b..735ec4b80daf3 100644 --- a/onnxruntime/core/mickey/README.md +++ b/onnxruntime/core/mickey/README.md @@ -4,3 +4,7 @@ Playful name for a template library of high performance cuda code that are often shared by various AI operators. The intention is to make this header files only, with no binary impact unless it is instantiated where it is needed. + +Currently cuda code are scattered in multiple locations in the repo. +Hopefully this can be the starting point of consolidating all cuda +code. diff --git a/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h new file mode 100644 index 0000000000000..52bff7e40dbe3 --- /dev/null +++ b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h @@ -0,0 +1,208 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blk_q4/f16_gemm_sm80.h + * + * Abstract: + * Entry point for Q4F16 GEMM kernel for SM80 devices. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass_ext/q4gemm/device/quantb_gemm.h" + +namespace onnxruntime { +namespace cuda { + +// +// This is the implementation of the quantized GEMM kernel for 16b float x blocked quantized 4b data type +// +template < + typename ElementDequant_, // <- data type of dequantized elements for gemm, fp16 or bf16 + typename QuantBlocking_, // <- weights block per scale, cutlass::MatrixShape + bool SmallM, // <- true if M <= 16 + bool kHasQuantOffset> +struct BlkQ4F16GemmImpl { + // + // Type definitions + // + + using ElementDequant = ElementDequant_; + using QuantBlocking = QuantBlocking_; + + static_assert(sizeof(ElementDequant) == 2, "q4f16gemm kerenl only support 16b operands!"); + + // Data types that are fixed for this kernel + using ElementAccumulator = float; + using ElementComputeEpilogue = ElementAccumulator; + using ElementInputA = ElementDequant; + using ElementOutput = ElementDequant; + + using ElementW = uint8_t; // <- Weight is int4, uint8 for two of them + + // We pack 4 weights into one 16b element, so as to leverage cutlass tile iterators + // for async shared memory loading and minimize bank conflict + using ElementWPack = ElementDequant; + + using ElementQScale = ElementDequant; // <- data type of quantization scale + using ElementQOffset = uint8_t; + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputWPack = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + // Layout of quantization scale and offset, oriented to be loaded using less instructions + // in a warp tile + using LayoutInputQScale = + typename std::conditional::type; // <- layout of quantization scale + + using ShapeMMAThreadBlock = + typename std::conditional, + cutlass::gemm::GemmShape<128, 256, 64>>::type; + + static constexpr int MinN = QuantBlocking::kColumn > 32 ? QuantBlocking::kColumn : 32; + using ShapeMMAWarp = + typename std::conditional, + cutlass::gemm::GemmShape<64, 64, 64>>::type; + + using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + + // This code section describes the epilogue part of the kernel + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + + // Number of pipelines you want to use + static constexpr int NumStages = 3; + + using Gemm = cutlass::gemm::device::QuantBGemm< + ElementInputA, + LayoutInputA, + ElementWPack, + LayoutInputWPack, + ElementQScale, + typename std::conditional::type, + LayoutInputQScale, + QuantBlocking, + ElementOutput, + LayoutOutput, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ShapeMMAThreadBlock, + ShapeMMAWarp, + ShapeMMAOp, + EpilogueOp, + SwizzleThreadBlock, + NumStages>; + + using Arguments = typename Gemm::Arguments; + + // Invoke gemm kernel (the version with quantization offset) + static cutlass::Status run( + cudaStream_t stream, + const cutlass::gemm::GemmCoord& problem_size_, + cutlass::TensorRef ref_A_, + cutlass::TensorRef ref_B_, + cutlass::TensorRef ref_Qscale_, + cutlass::TensorRef ref_Qoffset_, + cutlass::TensorRef ref_C_, + cutlass::TensorRef ref_D_, + typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) { + if constexpr (!kHasQuantOffset) { + return cutlass::Status::kErrorNotSupported; + } else { + if constexpr (ShapeMMAThreadBlock::kM == 16) { + if (problem_size_.m() > 16) { + // For M > 16, the caller should have picked the + // kernel with bigger M + return cutlass::Status::kErrorNotSupported; + } + } + + // Construct Gemm arguments + Arguments args{ + problem_size_, + ref_A_, + ref_B_, + ref_Qscale_, + ref_Qoffset_, + ref_C_, + ref_D_, + epilogue_}; + + Gemm gemm_op; + + // Check if this GEMM can be run or not + cutlass::Status status = gemm_op.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Launch the CUTLASS GEMM kernel. + return gemm_op(args, nullptr, stream); + } + } + + // Invoke gemm kernel (the version without quantization offset) + static cutlass::Status run( + cudaStream_t stream, + const cutlass::gemm::GemmCoord& problem_size_, + cutlass::TensorRef ref_A_, + cutlass::TensorRef ref_B_, + cutlass::TensorRef ref_Qscale_, + cutlass::TensorRef ref_C_, + cutlass::TensorRef ref_D_, + typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) { + if constexpr (kHasQuantOffset) { + return cutlass::Status::kErrorNotSupported; + } else { + if constexpr (ShapeMMAThreadBlock::kM == 16) { + if (problem_size_.m() > 16) { + // For M > 16, the caller should have picked the + // kernel with bigger M + return cutlass::Status::kErrorNotSupported; + } + } + + // Construct Gemm arguments + Arguments args{ + problem_size_, + ref_A_, + ref_B_, + ref_Qscale_, + ref_C_, + ref_D_, + epilogue_}; + + Gemm gemm_op; + + // Check if this GEMM can be run or not + cutlass::Status status = gemm_op.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Launch the CUTLASS GEMM kernel. + return gemm_op(args, nullptr, stream); + } + } +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h similarity index 99% rename from onnxruntime/core/mickey/blk_q4/prepack_sm80.h rename to onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h index e291ab39e8aa3..a08cfb97eed4a 100644 --- a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h +++ b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h @@ -3,7 +3,7 @@ * Licensed under the MIT License. * * Module Name: - * prepack_sm80.h + * blk_q4/f16_prepack_sm80.h * * Abstract: * Prepack weights and quantization parameters (scales and offsets) for diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h new file mode 100644 index 0000000000000..38795291b0328 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h @@ -0,0 +1,481 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_gemm.h + * @brief Modified from cutlass/gemm/device/gemm.h, boilerplate code passing input pointers to the kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm.h" + +#include "cutlass_ext/q4gemm/kernel/default_quantb_gemm.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! A specialized GEMM operator for quantized B GEMM. + + It is modified from cutlass::gemm::device::Gemm. Both this class and the original Gemm class + are pretty much boilerplate code that construct the Gemm kernel class, and pass parameters + and controls to it. The only difference is that this class has a few more template parameters + to support quantization. + + This implementation pretty much follows the design of cutlass. But this class seems to be + just a wrapper of the Gemm kernel class. Consider combining them in future iterations. + +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout type for quant scales and offsets + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute> +class QuantBGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + // Quantization Parameters + static_assert(std::is_same::value, + "LayoutB, i.e. packed weights must appear ColumnMajor."); + static_assert(InstructionShape::kK == 16, + "InstructionShape::kK must be a multiple of 16 (2 tiles), required by 4b weight packing layout."); + using ElementQScale = ElementQScale_; + using ElementQOffset = ElementQOffset_; + using LayoutQMeta = LayoutQMeta_; + using QuantBlocking = QuantBlocking_; + static constexpr bool kHasQOffset = !(std::is_same::value); + + // TODO(chenfucn): consider moving to uint4_t or smaller for QOffset + static_assert(!kHasQOffset || std::is_same::value, "QOffset must be uint8_t"); + + /// Define the kernel + using GemmKernel = typename kernel::DefaultQuantBGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementQScale, + ElementQOffset, + LayoutQMeta, + QuantBlocking, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + GatherA, + GatherB, + ScatterD, + PermuteDLayout + >::GemmKernel; + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + TensorRef ref_Qscale; + TensorRef ref_Qoffset; + + typename EpilogueOutputOp::Params epilogue; + + // split-K parallelism (etc.) are not yet supported, keeping this for future extension + int split_k_slices{1}; + // For gather+scatter operations + int const *gather_A_indices{nullptr}; + int const *gather_B_indices{nullptr}; + int const *scatter_D_indices{nullptr}; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0) {} + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(!kHasQOffset); + } + + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_Qoffset_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_Qoffset(ref_Qoffset_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(kHasQOffset); + } + }; + + private: + /// Kernel parameters object + typename GemmKernel::Params params_; + + public: + /// Constructs the GEMM. + QuantBGemm() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.epilogue, + static_cast(workspace), + args.gather_A_indices, + args.gather_B_indices, + args.scatter_D_indices + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_Qscale.reset(args.ref_Qscale.non_const_ref().data()); + params_.ref_Qoffset.reset(args.ref_Qoffset.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.output_op = args.epilogue; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + std::cerr << "Failed to obtain maximum shared memory size " << smem_size << " for kernel: " + << cudaGetErrorString(result) << "\n"; + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h new file mode 100644 index 0000000000000..2f4460bb59e9f --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h @@ -0,0 +1,255 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_gemm.h + * @brief Modified from cutlass/gemm/kernel/default_gemm.h. templates for combining + * threadblock-scoped matrix multiply-add with the appropriate + * threadblock-scoped epilogue. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass_ext/q4gemm/kernel/quantb_gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +#include "cutlass/layout/permute.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" +#endif //CUTLASS_ARCH_WMMA_ENABLED + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout type for quant scales and offsets + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Access granularity of quant scales in units of elements + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute, + /// + typename Enable = void +> +struct DefaultQuantBGemm; + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale, + /// Element type for quant offsets + typename ElementQOffset, + /// Layout type for quant scales + typename LayoutQMeta, + /// Blocking dimensions for quantization + typename QuantBlocking, + /// Access granularity of quant scales in units of elements + typename ElementC, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout +> +struct DefaultQuantBGemm { + + static_assert((platform::is_same::value + || platform::is_same>::value), + "Epilogue in the kernel level must be row major"); + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultQuantBMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator, false, GatherA, GatherB, + PermuteALayout, PermuteBLayout>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using RegularEpilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; + + using Affine2Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpAffineRankN< + 2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + using Epilogue = typename platform::conditional::value, + RegularEpilogue, + Affine2Epilogue>::type; + + /// Define the kernel-level GEMM operator. + using GemmKernel = kernel::QuantBGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h new file mode 100644 index 0000000000000..6e5ad8f406147 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h @@ -0,0 +1,462 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_gemm.h + * @brief Modified from cutlass/gemm/kernel/gemm.h. + * Template for a pipelined GEMM kernel. Does not compute batching or support split-K. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. +> +struct QuantBGemm { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using OutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + static constexpr bool kHasQOffset = Mma::kHasQOffset; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorQScale::Params params_QScale; + typename Mma::IteratorQScale::TensorRef ref_QScale; + typename Mma::IteratorQOffset::Params params_QOffset; + typename Mma::IteratorQOffset::TensorRef ref_QOffset; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename OutputOp::Params output_op; + int *semaphore; + int gemm_k_size; // how many k vectors are processed by this threadblock + // For gather+scatter operations + int const *gather_A_indices; + int const *gather_B_indices; + int const *scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + typename OutputOp::Params output_op = typename OutputOp::Params(), + int *workspace = nullptr, + int const *gather_A_indices = nullptr, + int const *gather_B_indices = nullptr, + int const *scatter_D_indices = nullptr + ): + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(ref_A.layout()), + ref_A(ref_A), + params_B(ref_B.layout()), + ref_B(ref_B), + params_QScale(ref_QScale.layout()), + ref_QScale(ref_QScale), + params_QOffset(ref_QOffset.layout()), + ref_QOffset(ref_QOffset), + params_C(ref_C.layout()), + ref_C(ref_C), + params_D(ref_D.layout()), + ref_D(ref_D), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) { + int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + QuantBGemm() { } + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D) { + + // TODO check problem_size K, N must be multiple of QuantBlocking + + static int const kAlignmentA = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (problem_size.k() % Mma::Shape::kK != 0) { + // Currently we don't support this case due to the way + // predicate iterator works, it loads the partial tile + // in the first iteration and then the full tile in the + // remaining iterations. This will cause the blockwise + // quantization parameters to go out of step with the + // weights. We can fix this by adding a predicate iterator + // that loads the full tile in the first iterations and + // then the partial tile in the last iteration. + return Status::kErrorInvalidProblem; + } + + int qscale_k = problem_size.k() / Mma::QuantBlocking::kRow; + int qscale_n = problem_size.n() / Mma::QuantBlocking::kColumn; + if ((qscale_k == 0) || (qscale_k * Mma::QuantBlocking::kRow != problem_size.k())) { + // partial block not supported + return Status::kErrorInvalidProblem; + } + if ((qscale_n == 0) || (qscale_n * Mma::QuantBlocking::kColumn != problem_size.n())) { + // partial block not supported + return Status::kErrorInvalidProblem; + } + + if (!TensorRef_aligned(ref_QScale, Mma::IteratorQScale::AccessType::kElements)) { + return Status::kErrorMisalignedOperand; + } + + if constexpr(kHasQOffset) { + if (!TensorRef_aligned(ref_QOffset, Mma::IteratorQOffset::AccessType::kElements)) { + return Status::kErrorMisalignedOperand; + } + } + + if (!TensorRef_aligned(ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{ + (threadblock_tile_offset.k() * params.gemm_k_size) / 2, + (threadblock_tile_offset.n() * Mma::Shape::kN) / 2 + }; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min( + params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, + params.ref_B.data(), + {problem_size_k/2, params.problem_size.n()/2}, + thread_idx, + tb_offset_B, + params.gather_B_indices); + + const int qscale_k = problem_size_k / Mma::QuantBlocking::kRow; + const int qscale_n = params.problem_size.n() / Mma::QuantBlocking::kColumn; + + // should have been verified by can_implement() + assert((qscale_k > 0) && (qscale_k * Mma::QuantBlocking::kRow == problem_size_k)); + assert((qscale_n > 0) && (qscale_n * Mma::QuantBlocking::kColumn == params.problem_size.n())); + + cutlass::MatrixCoord tb_offset_QScale{ + threadblock_tile_offset.k() * (params.gemm_k_size/Mma::QuantBlocking::kRow), + threadblock_tile_offset.n() * (Mma::Shape::kN/Mma::QuantBlocking::kColumn) + }; + + typename Mma::IteratorQScale iterator_QScale( + params.params_QScale, + params.ref_QScale.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale, + nullptr); + + typename Mma::IteratorQOffset iterator_QOffset( + params.params_QOffset, + params.ref_QOffset.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + const int warp_idx = canonical_warp_idx(); + const int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_QScale, iterator_QOffset, accumulators); + } + + // + // Epilogue + // + + OutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + params.ref_C.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + params.ref_D.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h new file mode 100644 index 0000000000000..0af604f090e1f --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h @@ -0,0 +1,248 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma.h + * @brief Modified from cutlass/gemm/threadblock/default_mma.h. + * Defining global memory data layout and iterators, combinging with mma core and + * pipelined GEMM kernel. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/permute.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h" +#include "cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout for quant scales and offsets + typename LayoutQMeta_, + /// Blocking size for quantization + typename QuantBlocking_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute + > +struct DefaultQuantBMma; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale, + /// Element type for quant offsets + typename ElementQOffset, + /// Layout for quant scales and offsets + typename LayoutQMeta, + /// Blocking size for quantization + typename QuantBlocking, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout + > +struct DefaultQuantBMma { + + static_assert(platform::is_same::value + || platform::is_same>::value, + "simt epilogue must be row major"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultQuantBMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA, PermuteALayout>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, PermuteBLayout>; + + // Define iterators over tiles from the quant scales + using ThreadMapQScale = typename MmaCore::IteratorThreadMapQScale; + using AccessTypeQScale = + cutlass::Array; + using IteratorQScale = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + typename MmaCore::ThreadblockQShape, + ElementQScale, LayoutQMeta, 0, ThreadMapQScale, AccessTypeQScale>; + + using ThreadMapQOffset = typename MmaCore::IteratorThreadMapQOffset; + using AccessTypeQOffset = + cutlass::Array; + using IteratorQOffset = + cutlass::transform::threadblock::OptionalPredicatedTileAccessIterator< + typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQMeta, + 0, ThreadMapQOffset, AccessTypeQOffset, MmaCore::kThreads>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::QuantBMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, IteratorQScale, typename MmaCore::SmemIteratorQScale, + cutlass::arch::CacheOperation::Global, IteratorQOffset, + typename MmaCore::SmemIteratorQOffset, cutlass::arch::CacheOperation::Global, + ElementAccumulator, LayoutC, + typename MmaCore::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h new file mode 100644 index 0000000000000..ad322f6505200 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma_core.h + * @brief Modified from cutlass/gemm/threadblock/default_mma_core.h. + * Defining data layout in shared memory, and its iterators. + */ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" + +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" +#include "cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template defininng default matrix multiply operators inferred from threadblock tile size, +/// global memory data layout, and target math instruction. +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Element data type of quant scale + typename ElementQScale, + /// Element data type of quant offset + typename ElementQOffset, + /// Layout of quant scale + typename LayoutQMeta, + /// Blocking dimensions for quantization + typename QuantBlocking, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DefaultQuantBMmaCore; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Element data type of quant scale + typename ElementQScale_, + /// Element data type of quant offset + typename ElementQOffset_, + /// Layout of quant scale + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultQuantBMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + + using ElementQScale = ElementQScale_; + using ElementQOffset = ElementQOffset_; + using LayoutQMeta = LayoutQMeta_; + using QuantBlocking = QuantBlocking_; + + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + static int const kWarpThreadArrangementContiguousB = + (Shape::kK / 2) / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK/2>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + IteratorThreadMapB>; + + using SmemLayoutQScale = LayoutQMeta; + using SmemLayoutQOffset = LayoutQMeta; + + /// Threadblock-level quantization meta data shape + using ThreadblockQShape = MatrixShape; + static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow"); + static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn"); + static_assert(ThreadblockQShape::kCount > 0, "QuantBlocking too big to fit in a thread block!"); + static_assert(QuantBlocking::kRow == 1 || QuantBlocking::kColumn == 1, + "Only support single column or row quantize blocking!"); + static_assert(QuantBlocking::kColumn != 1 || std::is_same::value, + "Quant scale matrix's major dimension must have more elements, to facilitate fast loading!"); + + /// Threadblock-level quantization meta data shape in pitch-linear layout + using TBQPitchLinearShape = typename std::conditional< + std::is_same::value, + layout::PitchLinearShape, + layout::PitchLinearShape>::type; + + /// By default we would like to use 128b load. However, we can't load more than + /// a column at a time in a column major layout. + static int const kElementsPerAccessQScale = + (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous + ? TBQPitchLinearShape::kContiguous + : (kAccessSizeInBits / sizeof_bits::value); + + /// quant scale is tiny. Not all threads are needed. + static int const kAccessCntQScale = ThreadblockQShape::kCount / kElementsPerAccessQScale; + static int const kThreadsQScale = (kAccessCntQScale > kThreads) ? kThreads : kAccessCntQScale; + + using IteratorThreadMapQScale = transform::PitchLinearStripminedThreadMap< + TBQPitchLinearShape, kThreadsQScale, kElementsPerAccessQScale>; + + using SmemIteratorQScale = transform::threadblock::RegularTileAccessIterator< + ThreadblockQShape, ElementQScale, SmemLayoutQScale, 1, IteratorThreadMapQScale>; + + static int const kElementsPerAccessQOffset = + (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous + ? TBQPitchLinearShape::kContiguous + : (kAccessSizeInBits / sizeof_bits::value); + static int const kAccessCntQOffset = ThreadblockQShape::kCount / kElementsPerAccessQOffset; + static int const kThreadsQOffset = (kAccessCntQOffset > kThreads) ? kThreads : kAccessCntQOffset; + + using IteratorThreadMapQOffset = transform::PitchLinearStripminedThreadMap< + TBQPitchLinearShape, kThreadsQOffset, kElementsPerAccessQOffset>; + + using SmemIteratorQOffset = transform::threadblock::OptionalRegularTileAccessIterator< + ThreadblockQShape, ElementQOffset, SmemLayoutQOffset, 1, IteratorThreadMapQOffset, kThreads>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultQuantBMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementQScale, SmemLayoutQScale, ElementQOffset, SmemLayoutQScale, QuantBlocking, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h new file mode 100644 index 0000000000000..6f27a692a3a2e --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h @@ -0,0 +1,314 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file optional_predicated_tile_access_iter.h + * @brief Templates for loading and storing optional tiles of matrix data. + * This iterator is just a wrapper of PredicatedTileAccessIterator, with + * the option to turn it off at compile time and minimize its runtime + * footprint. Also, it utilize the higher numbered threads in the + * threadblock when the iterator can not utilize all the threads. + */ + +#pragma once + +#include + +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + + +//////////////////////////////////////////////////////////////////////////////// + +/// Optional 2-D matrix data loader, when element is std::monostate, the +/// iterator becomes no-op with minimal runtime footprint. Also, it utilize the +/// higher numbered threads in the threadblock when the iterator can not utilize +/// all the threads. +/// +template < + /// Tile shape of the iterator + typename Shape_, + /// Element data type of the iterator, no-op when it is std::monostate + typename Element_, + /// Layout of the source matrix + typename Layout_, + int AdvanceRank_, + typename ThreadMap_, + typename AccessType_, + /// Number of threads in the threadblock, when provided, the iterator + /// will utilize the higher numbered threads + int kThreadBlockSize_ = -1> +class OptionalPredicatedTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + static constexpr int kAdvanceRank = AdvanceRank_; + static constexpr int kThreadblockSize = kThreadBlockSize_; + + static_assert(!std::is_same::value, + "Disabled Iterator failed to match the specialized version below."); + static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + + using Base = PredicatedTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using Mask = typename Base::Mask; + using TensorCoord = typename Base::TensorCoord; + using TensorRef = typename Base::TensorRef; + using Params = typename Base::Params; + using Pointer = typename Base::Pointer; + + static constexpr int kAccessesPerVector = Base::kAccessesPerVector; + + CUTLASS_HOST_DEVICE + static int flip_thread_id(int thread_id){ + if constexpr (kThreadblockSize > 0) { + return kThreadblockSize - 1 - thread_id; + } + return thread_id; + } + + public: + Base base_; + + /// Default constructor + OptionalPredicatedTileAccessIterator(): base_() {}; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : base_(params, pointer, extent, flip_thread_id(thread_id), threadblock_offset) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : OptionalPredicatedTileAccessIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + base_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + base_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + base_.add_tile_offset(tile_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return base_.get(); + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator &operator++() { + ++base_; + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator operator++(int) { + OptionalPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + base_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + base_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + base_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + base_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return base_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for the disabled version +/// Reduce runtime overhead +/// +template < + /// Tile shape of the iterator + typename Shape_, + typename Layout_, + int AdvanceRank_, + typename ThreadMap_, + typename AccessType_, + int kThreadBlockSize_> +class OptionalPredicatedTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = std::monostate; + using Layout = Layout_; + static int const kAdvanceRank = AdvanceRank_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + static constexpr int kThreadblockSize = kThreadBlockSize_; + + using Base = PredicatedTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using Mask = typename Base::Mask; + using TensorCoord = typename Base::TensorCoord; + using TensorRef = typename Base::TensorRef; + using Params = typename Base::Params; + using Pointer = typename Base::Pointer; + + static constexpr int kAccessesPerVector = Base::kAccessesPerVector; + + public: + std::monostate base_; + + /// Default constructor + OptionalPredicatedTileAccessIterator(): base_() {}; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : base_() {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : base_() {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) {} + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return nullptr; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator &operator++() { + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator operator++(int) { + return *this; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) {} + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() {} + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) {} + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) {} + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return false; } +}; + +//////////////////////////////////////////////////////////////////////////////// +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h new file mode 100644 index 0000000000000..4b0ae5317f8bb --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h @@ -0,0 +1,224 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file optional_regular_tile_access_iter.h + * @brief Templates implementing the address computation of storing of tiles + * from pitch-linear rank=2 tensors. + * + * This iterator is just a wrapper of RegularTileAccessIterator, with the + * option to turn it off at compile time and minimize its runtime footprint. + * Also, it utilize the higher numbered threads in the threadblock when the + * iterator can not utilize all the threads. + * + * Must be used in conjunction with OptionalPredicatedTileAccessIterator, + * with the same template parameters. + */ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Optional 2-D tile iterator, when element is std::monostate, the iterator +/// becomes no-op with minimal runtime footprint. Also, it utilize the higher +/// numbered threads in the threadblock when the iterator can not utilize all +/// the threads. +/// +template < + /// Tile shape of the iterator + typename Shape_, + typename Element_, + typename Layout_, + int AdvanceRank, + typename ThreadMap_, + /// Number of threads in the threadblock, when not -1, the iterator + /// will utilize the higher numbered threads + int ThreadblockSize_ = -1, + int Alignment = + sizeof_bits::value * ThreadMap_::kElementsPerAccess / 8> +class OptionalRegularTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + static constexpr int kAlignment = Alignment; + static constexpr int kThreadblockSize = ThreadblockSize_; + + static_assert(!std::is_same::value, + "Disabled Iterator failed to match the specialized template"); + static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + + using Base = RegularTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using TensorRef = typename Base::TensorRef; + using TensorCoord = typename Base::TensorCoord; + using AccessType = typename Base::AccessType; + + CUTLASS_HOST_DEVICE + static int flip_thread_id(int thread_id){ + if constexpr (kThreadblockSize > 0) { + return kThreadblockSize - 1 - thread_id; + } + return thread_id; + } + + private: + + Base base_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : base_(ref, flip_thread_id(thread_id)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + base_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + base_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + return base_.get(); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator &operator++() { + ++base_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + base_.add_tile_offset(coord); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization when Element is std::monostate, the iterator becomes no-op +/// +template < + typename Shape_, + typename Layout_, + int AdvanceRank, + typename ThreadMap_, + int ThreadblockSize_, + int Alignment> +class OptionalRegularTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = std::monostate; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + static constexpr int kAlignment = Alignment; + static constexpr int kThreadblockSize = ThreadblockSize_; + + using Base = RegularTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using TensorRef = typename Base::TensorRef; + using TensorCoord = typename Base::TensorCoord; + using AccessType = typename Base::AccessType; + + private: + + std::monostate base_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : base_() {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) {} + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + return nullptr; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator &operator++() { + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator operator++(int) { + return *this; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) {} +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h new file mode 100644 index 0000000000000..8b6bac8c5099a --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h @@ -0,0 +1,1290 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_mma_multistage.h + * @brief Modified from cutlass/gemm/threadblock/mma_multistage.h. + * Added the quantized data memory pipeline, dequantization, and feeding + * to tensor cores. Mainloop pipeline is heavily modified. + */ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// +namespace{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Utilities for printing layout for the prepacked weights and quantization parameters +/// +template< + /// Data type of the prepacked weights + typename ElementWeight, + /// Data type of the quant scales + typename ElementQScale, + /// Data type of the quant offsets + typename ElementQOffset> +struct QuantBLayoutDebug{ + static constexpr bool debug_smem = true; + static constexpr bool debug_fragment = true; + ElementWeight* smem_b_ptr_; + ElementQScale* smem_qscale_ptr_; + ElementQOffset* smem_qoffset_ptr_; + int warp_id_; + int lane_id_; + int block_id_; + + template + CUTLASS_DEVICE + static void print_fragment(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + static_assert(Size % 4 == 0, "Size must be multiple of 4"); + if constexpr (debug_fragment){ + if (block_id == 1 && warp_id == 0){ + const Element* ptr = reinterpret_cast(&frag); + for (int i = 0; i < Size/4; i++, ptr+=4){ + if constexpr(std::is_integral::value){ + printf("T%.2d%c%d, %3d, %3d, %3d, %3d\n", + threadIdx.x, label, i, + ptr[0], ptr[1], ptr[2], ptr[3]); + } else { + printf("T%.2d%c%d, %.3f, %.3f, %.3f, %.3f\n", + threadIdx.x, label, i, + float(ptr[0]), float(ptr[1]), float(ptr[2]), float(ptr[3])); + } + } + } + } + } + + template + CUTLASS_DEVICE + static void print_as_int4(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + constexpr int I8Size = Size * cutlass::sizeof_bits::value / 8; + static_assert(I8Size % 2 == 0, "Size must be multiple of 4"); + if constexpr (debug_fragment){ + if (block_id == 1 && warp_id == 0){ + const uint8_t* ptr = reinterpret_cast(&frag); + for (int i = 0; i < I8Size/2; i++, ptr+=2){ + printf("T%.2dW%d, %d, %d, %d, %d\n", threadIdx.x, i, ptr[0] & 0x0f, ptr[0] >> 4, ptr[1] & 0x0f, ptr[1] >> 4); + } + } + } + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dummy type when quant offset is not used, to avoid compilation error, +/// and reduce runtime footprint +/// +struct DummyType{ + std::monostate dummy_; + public: + DummyType() = default; + + CUTLASS_HOST_DEVICE + void* data() const { + return nullptr; + } + + CUTLASS_HOST_DEVICE + std::monostate& operator[](int idx) { + return dummy_; + } +}; + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class QuantBMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + static constexpr bool kHasQOffset = !std::is_same::value; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the prepacked weights + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // Tensor reference to the quantization scales + using TensorRefQScale = TensorRef; + using TensorRefQOffset = TensorRef; + + // Block size of the quantization (one set of quantization parameters per block of weights) + using QuantBlocking = typename Operator::QuantBlocking; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the prepacked weights in shared memory + using ShapeB = + MatrixShape; + + /// Shape of the quantization parameter matrix in shared memory + /// Validation done in mma core class ThreadblockQShape + using ShapeQScale = + MatrixShape<(Shape::kK / QuantBlocking::kRow) * kStages, + Shape::kN / QuantBlocking::kColumn>; + + using BufTypeQOffset = std::conditional_t, + DummyType>; + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for prepacked weights + AlignedBuffer operand_B; + + /// Buffer for quantization scales + AlignedBuffer operand_QScale; + + /// Buffer for quantization offsets + BufTypeQOffset operand_QOffset; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + CUTLASS_HOST_DEVICE + static typename Operator::SmemLayoutQScale LayoutQMeta() { + return Operator::SmemLayoutQScale::packed({ShapeQScale::kRow, ShapeQScale::kColumn}); + } + + CUTLASS_HOST_DEVICE + static typename Operator::SmemLayoutQOffset LayoutQOffset() { + return Operator::SmemLayoutQOffset::packed({ShapeQScale::kRow, ShapeQScale::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the prepacked weights + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + + /// Returns a TensorRef to the quantization scales + CUTLASS_HOST_DEVICE + TensorRefQScale operand_QScale_ref() { + return TensorRefQScale{operand_QScale.data(), LayoutQMeta()}; + } + + CUTLASS_HOST_DEVICE + TensorRefQOffset operand_QOffset_ref() { + if constexpr (!kHasQOffset){ + return TensorRefQOffset(); + } else { + return TensorRefQOffset{operand_QOffset.data(), LayoutQOffset()}; + } + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + /// Iterator to load a warp-scoped tile of quant scales from shared memory + typename Operator::IteratorQMeta warp_tile_iterator_QScale_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + QuantBMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), + warp_tile_iterator_QScale_(shared_storage.operand_QScale_ref(), + shared_storage.operand_QOffset_ref(), lane_idx) + {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over tiles of quant scales in global memory + typename IteratorQScale_, + /// Iterators over tiles of quant scales in shared memory + typename SmemIteratorQScale_, + /// Cache operation for quant scales + cutlass::arch::CacheOperation::Kind CacheOpQScale, + /// Iterators over tiles of quant scales in global memory + typename IteratorQOffset_, + /// Iterators over tiles of quant scales in shared memory + typename SmemIteratorQOffset_, + /// Cache operation for quant scales + cutlass::arch::CacheOperation::Kind CacheOpQOffset, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class QuantBMmaMultistage : + public QuantBMmaBase { +public: + ///< Base class + using Base = QuantBMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using IteratorQScale = IteratorQScale_; + using IteratorQOffset = IteratorQOffset_; + using SmemIteratorQScale = SmemIteratorQScale_; + using SmemIteratorQOffset = SmemIteratorQOffset_; + using QuantBlocking = typename Base::QuantBlocking; + + static cutlass::arch::CacheOperation::Kind const kCacheOpQScale = CacheOpQScale; + static cutlass::arch::CacheOperation::Kind const kCacheOpQOffset = CacheOpQOffset; + static constexpr bool kHasQOffset = Base::kHasQOffset; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of packed weights + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const AsyncCopyIterationsPerStageQScale = + IteratorQScale::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of quant scale + static int const kAccessesPerGroupQScale = + (AsyncCopyIterationsPerStageQScale + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const AsyncCopyIterationsPerStageQOffset = + IteratorQOffset::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of quant offset + static int const kAccessesPerGroupQOffset = + (AsyncCopyIterationsPerStageQOffset + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical + // accuracy, where each mainloop iteration first accumulates into a temporary + // set of freshly-cleared accumulators, which are subsequently added to the + // final accumulator set. + static bool const kStagedAccumulation = arch::UseStagedAccumulation::value; + }; + + private: + + + // Structure encapsulating pipeline state live from one iteration to the next + struct PipeState { + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + /// Temporary accumulator to facilitate staged-accumulation + FragmentC tmp_accum_; + + /// Pair of A fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentA warp_loaded_frag_A_[2]; + + /// Pair of B fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentB warp_loaded_frag_B_; + WarpTransformedFragmentB warp_transformed_frag_B_[2]; + + using WarpLoadedFragmentQScale = typename Operator::FragmentQScale; + WarpLoadedFragmentQScale warp_loaded_frag_QScale_; + + using WarpLoadedFragmentQOffset = typename std::conditional::type; + WarpLoadedFragmentQOffset warp_loaded_frag_QOffset_; + }; + + + private: + + // + // Data members + // + + /// Warp-level MMA operator + Operator warp_mma_; + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of quant meta data to shared memory + SmemIteratorQScale smem_iterator_QScale_; + SmemIteratorQOffset smem_iterator_QOffset_; + + /// Shared memory write stage index + int smem_write_stage_idx_; + + /// Shared memory read stage index + int smem_read_stage_idx_; + + /// very small meta data tensor require less threads to load + bool const should_load_qscale_; + bool const should_load_qoffset_; + + /// Shared memory pointers for debug dumping + static constexpr bool debug_layout = false; + using LayoutDebugType = typename std::conditional, + std::monostate>::type; + LayoutDebugType layout_debug_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + QuantBMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_QScale_(shared_storage.operand_QScale_ref(), thread_idx), + smem_iterator_QOffset_(shared_storage.operand_QOffset_ref(), thread_idx), + should_load_qscale_(thread_idx < IteratorQScale::ThreadMap::kThreads), + should_load_qoffset_(thread_idx >= IteratorQOffset::kThreadblockSize - IteratorQOffset::ThreadMap::kThreads), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + if constexpr(debug_layout){ + layout_debug_.smem_b_ptr_ = shared_storage.operand_B_ref().data(); + layout_debug_.smem_qscale_ptr_ = shared_storage.operand_QScale_ref().data(); + if constexpr(kHasQOffset){ + layout_debug_.smem_qoffset_ptr_ = shared_storage.operand_QOffset_ref().data(); + } else { + layout_debug_.smem_qoffset_ptr_ = nullptr; + } + layout_debug_.warp_id_ = warp_idx; + layout_debug_.lane_id_ = lane_idx; + layout_debug_.block_id_ = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + } + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_QScale_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Advance shared memory read-iterators to the next stage + CUTLASS_DEVICE + void advance_smem_read_stage() + { + ++smem_read_stage_idx_; + + if (smem_read_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + this->warp_tile_iterator_QScale_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + + smem_read_stage_idx_ = 0; + } + } + + /// Advance global memory read-iterators and shared memory write-iterators to the stage + CUTLASS_DEVICE + void advance_smem_write_stage( + IteratorA &iterator_A, + IteratorB &iterator_B, + IteratorQScale &iterator_QScale, + IteratorQOffset &iterator_QOffset) + { + // Advance global iterators + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + iterator_QScale.add_tile_offset({1, 0}); + + // Advance shared iterators + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + smem_iterator_QScale_.add_tile_offset({1, 0}); + + if constexpr (kHasQOffset) { + iterator_QOffset.add_tile_offset({1, 0}); + smem_iterator_QOffset_.add_tile_offset({1, 0}); + } + + // Increment shared memory write stage index + ++smem_write_stage_idx_; + + if (smem_write_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_iterator_QScale_.add_tile_offset({-Base::kStages, 0}); + if constexpr (kHasQOffset) { + smem_iterator_QOffset_.add_tile_offset({-Base::kStages, 0}); + } + smem_write_stage_idx_ = 0; + } + } + + CUTLASS_DEVICE + void copy_qscale_tiles(IteratorQScale &iterator_QScale){ + // Quant scale matrix is 1/block_size of the B matrix, for a 64x64 warp tile, + // it's only 64x64/block_size elements. For blocking size 16 ~ 64, it only + // takes 4 ~ 16 cp.async instructions to load. One warp has 32 threads, so + // it should be loaded in less than one cp.async instruction per thread. + // Even less for quant offset matrix. + static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, + "Quant scale should be loaded in one shot!"); + static_assert(IteratorQScale::kAccessesPerVector == 1, + "Quant scale should 1 access per vector!"); + + // Async Copy for quantization scale + typename IteratorQScale::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QScale_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQScale::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QScale.get(), iterator_QScale.valid()); + } + + CUTLASS_DEVICE + void copy_qoffset_tiles(IteratorQOffset & iterator_QOffset) { + static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, + "Quant offset should be loaded in one shot!"); + static_assert(IteratorQOffset::kAccessesPerVector == 1, + "Quant offset should 1 access per vector!"); + + if constexpr(kHasQOffset) { + // Async Copy for quantization offset + typename IteratorQOffset::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QOffset_.get()); + + constexpr int kSrcBytes = sizeof_bits::value * + IteratorQOffset::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, + int group_start = 0) { + auto group_start_A = group_start * Detail::kAccessesPerGroupA; + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + auto group_start_B = group_start * Detail::kAccessesPerGroupB; + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching + /// the global fragments needed by the first kStages-1 threadblock mainloop iterations + CUTLASS_DEVICE + void prologue( + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Async Copy for quantization scale + static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, "Quant scale should be loaded in one shot!"); + static_assert(IteratorQScale::kAccessesPerVector == 1, "Quant scale should 1 access per vector!"); + + typename IteratorQScale::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QScale_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQScale::ThreadMap::kElementsPerAccess / 8; + + auto gmem_ptr = iterator_QScale.get(); + + cutlass::arch::cp_async( + dst_ptr, gmem_ptr, iterator_QScale.valid()); + + if constexpr (kHasQOffset) { + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + + // Async Copy for quantization offset + static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, "Quant offset should be loaded in one shot!"); + static_assert(IteratorQOffset::kAccessesPerVector == 1, "Quant offset should 1 access per vector!"); + typename IteratorQOffset::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QOffset_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQOffset::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + } + + // Move to the next write stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + + /// Wait until we have at least one completed global fetch stage + CUTLASS_DEVICE + void gmem_wait() + { + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + if constexpr(debug_layout) { + if (LayoutDebugType::debug_smem && layout_debug_.block_id_ == 1) { + if (threadIdx.x == 0){ + printf("stage: %d\n", smem_write_stage_idx_); + } + cutlass::debug::dump_shmem(layout_debug_.smem_qscale_ptr_, Base::SharedStorage::ShapeQScale::kCount); + if constexpr(kHasQOffset){ + cutlass::debug::dump_shmem(layout_debug_.smem_qoffset_ptr_, Base::SharedStorage::ShapeQScale::kCount); + } + } + } + } + + /// Perform a threadblock mainloop iteration of matrix multiply-accumulate + CUTLASS_DEVICE + void mac_loop_iter( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Loading next warp-level tiles from shared memory. This can be skipped on the very + // last iteration where: + // (gemm_k_iterations == (1 - Base::kStages)) && (warp_mma_k == (Base::kWarpGemmIterations - 1)) + // However, evaluating this condition seems more expensive than simply loading the tiles + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // All warp-tiles issue their share of global->shared fragment copies + copy_tiles_and_advance( + iterator_A, + iterator_B, + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_( + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } + + if (warp_mma_k == 0) { + copy_qscale_tiles(iterator_QScale); + } + if (warp_mma_k == 1) { + copy_qoffset_tiles(iterator_QOffset); + } + + // The second-to-last warp-tile also moves to the next global fetch stage + if (warp_mma_k == Base::kWarpGemmIterations - 2) { + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + } + + } + } + + /// Specialized mainloop iteration of matrix multiply-accumulate, for small M + CUTLASS_DEVICE + void mac_loop_iter_small_m( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // In the case of small M, memory latency dominates. We try to move uses far + // from their definitions to hide latency. + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + + // Loading next warp-level tiles from shared memory. + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // All warp-tiles issue their share of global->shared fragment copies + copy_tiles_and_advance( + iterator_A, + iterator_B, + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_( + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } + + // The second-to-last warp-tile also moves to the next global fetch stage + if (warp_mma_k == Base::kWarpGemmIterations - 2) { + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + copy_qscale_tiles(iterator_QScale); + copy_qoffset_tiles(iterator_QOffset); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + } + + } + } + + + /// Perform the specified number of threadblock mainloop iterations of matrix + /// multiply-accumulate. Assumes prologue has been initiated. + CUTLASS_DEVICE + void gemm_iters( + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC &accum, ///< [in|out] accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over QScale operand in global memory + IteratorQOffset &iterator_QOffset) ///< [in|out] iterator over QOffset operand in global memory + { + PipeState pipe_state; + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset) { + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + // Load first warp-tile's B fragment from shared memory + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + // Load first warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); + ++this->warp_tile_iterator_A_; + + copy_tiles_and_advance(iterator_A, iterator_B, 0); + + if constexpr(Shape::kM > 32) { + // the case of bigger m + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, 0); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[0], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[0], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } else { + // the case of small m + copy_qscale_tiles(iterator_QScale); + copy_qoffset_tiles(iterator_QOffset); + } + + if (Detail::kStagedAccumulation) { + pipe_state.tmp_accum_.clear(); + } + + // Mainloop + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + if constexpr(Shape::kM > 32) { + mac_loop_iter( + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); + } else { + mac_loop_iter_small_m( + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); + } + } + + if (Detail::kStagedAccumulation) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + } + + // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } + + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over quant scales in global memory + IteratorQScale iterator_QScale, + ///< Iterator over quant offsets in global memory + IteratorQOffset iterator_QOffset, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // Prologue (start fetching iterations of global fragments into shared memory) + prologue(iterator_A, iterator_B, iterator_QScale, iterator_QOffset, gemm_k_iterations); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Initialize destination accumulators with source accumulators + accum = src_accum; + + // Perform the MAC-iterations + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h new file mode 100644 index 0000000000000..2c49888c94504 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h @@ -0,0 +1,112 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma_tensor_op.h + * @brief Modified from cutlass/gemm/warp/default_mma_tensor_op.h + * Default warp-level GEMM operators selected by data type, size, and layouts of operands. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Data type of quant scales + typename ElementQScale, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale, + /// Data type of quant offsets + typename ElementQOffset, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset, + /// Blocking size of quantization + typename QuantBlocking, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Operator describing the tensor operation + typename Operator_ = arch::OpMultiplyAdd, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false> +struct DefaultQuantBMmaTensorOp { + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::QuantBMmaTensorOp< + WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, QuantBlocking, ElementC, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h new file mode 100644 index 0000000000000..4ba39dda3db8d --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -0,0 +1,883 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file quantb_meta_mma_tensor_op_tile_iterator.h + * @brief Templates for loading quantization meta data for operand B + * from shared memory to fragments. This is meant to be used in + * lock step with the operand B tile iterator. Containing logic + * to figure out the operand B layout in the tensor core, + * and deliver each meta data element to its corresponding + * operand B element for dequantization. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" + +#include "cutlass/platform/platform.h" +#include "cutlass/fast_math.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace{ + +struct b32_pair{ + uint32_t a; + uint32_t b; +}; + +struct fp16_quad{ + cutlass::half_t a; + cutlass::half_t b; + cutlass::half_t c; + cutlass::half_t d; +}; + +struct b16_quad{ + int16_t a; + int16_t b; + int16_t c; + int16_t d; +}; + +union b64 { + uint64_t single; + b32_pair pair; + b16_quad quard; + fp16_quad fp16_quad; +}; + +static_assert(sizeof(b64) == 8, "b64 should be 64 bits"); + +/// Convert packed 4b weights into fp16(weight + 16) +/// Current bit hacking only supports fp16, need to add bf16 later. +/// +template +CUTLASS_DEVICE +void weights2Half(cutlass::Array const &weights, + cutlass::Array& dest) +{ + static_assert(Size % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); + uint32_t* dest_pair = reinterpret_cast(dest.data()); + const uint32_t* w_oct = reinterpret_cast(weights.data()); + + CUTLASS_PRAGMA_UNROLL + for (int oct_idx = 0; oct_idx < Size/8; oct_idx++, w_oct++, dest_pair += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + // static_cast(16 + weight) + // 4b weights are prepacked into [0, 2, 4, 6, 1, 3, 5, 7], so that adjacent weights + // are in different 16b half words, making it easier to convert to fp16. + asm volatile( + "{\n\t" + " shl.b32 %0, %4, 6;\n" + " shl.b32 %1, %4, 2;\n" + " shr.u32 %2, %4, 2;\n" + " shr.u32 %3, %4, 6;\n" + " lop3.b32 %0, %0, 0x03c003c0, 0x4c004c00, 0xea;\n" // a & 0x03c0 | 0x4c00 + " lop3.b32 %1, %1, 0x03c003c0, 0x4c004c00, 0xea;\n" + " lop3.b32 %2, %2, 0x03c003c0, 0x4c004c00, 0xea;\n" + " lop3.b32 %3, %3, 0x03c003c0, 0x4c004c00, 0xea;\n" + "}\n" + : "=r"(dest_pair[0]), "=r"(dest_pair[1]), + "=r"(dest_pair[2]), "=r"(dest_pair[3]) + : "r"(*w_oct)); +#else + assert(0); +#endif + } + +} + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +// Traits to describe the layout of quantization meta data layout in a MMA fragment +// Since operand B is quantized on a per block basis, it's one meta data per block. + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTile{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ArchMmaOperator = ArchMmaOperator_; + + static_assert(Threads == 32, "This iterator should work in a warp only."); + + /// Shape of the curresponding operand B tile iterator + using TileShapeB = MatrixShape; + + // Tensor core operand B layout is a column major 4x8 tile, divided + // into 32 threads (T0 ~ T31) as shown below. Each element of the tile is 32b, + // so for fp16 it becomes 8 x 8, and int8 it becomes 16 x 8. + // T0 | T4 | T8 | T12 | T16 | T20 | T24 | T28 + // T1 | T5 | T9 | T13 | T17 | T21 | T25 | T29 + // T2 | T6 | T10 | T14 | T18 | T22 | T26 | T30 + // T3 | T7 | T11 | T15 | T19 | T23 | T27 | T31 + using CoreTile = layout::PitchLinearShape<4, 8>; + + /// Each thread holds a 32b fragment per tile: for half precision, it's 2 elements, 4 elements for int8 + static int const kNumBsPerCoreTileFragement = 32 / sizeof_bits::value; + + /// Each mma instruction can process either 1 or 2 tensor core operand B tiles (stacked on the k dimension) + static int const kBTilesPerMma = + sizeof_bits::value * ArchMmaOperator::FragmentB::kElements / 32; + static_assert(kBTilesPerMma == 1 || kBTilesPerMma == 2, "Only support 1 or 2 operand B tiles per mma."); + + /// Each operand B tile iterator load covers a number of mma instructions + static int const kMmaIterationsB = WarpShapeB::kColumn / ArchMmaOperator::Shape::kN; + + /// Number of B elements a fragment of meta data should cover + static int const kExpandedSize = kNumBsPerCoreTileFragement * kBTilesPerMma * kMmaIterationsB; + + // Now we figure out how many meta data elements to load for each TileShapeB + + /// Number of meta elements per CoreTile. + static int const kCoreTileFragementSize = (kNumBsPerCoreTileFragement + BlockingShape::kRow - 1) / BlockingShape::kRow; + + /// Number of core tiles per mma instruction, different from kBTilesPerMma when blocking size on K dimension + /// exceeds the tile depth, so two tiles share the same meta data + static int const kTilesPerMma = ((kBTilesPerMma == 2) && + (BlockingShape::kRow <= kNumBsPerCoreTileFragement * CoreTile::kContiguous)) + ? 2 : 1; + + /// stride to reach the meta data for the next CoreTile on the K dimension + static int const kKTileStride = (kNumBsPerCoreTileFragement * CoreTile::kContiguous + BlockingShape::kRow - 1) / BlockingShape::kRow; + + /// Stride on N dimension should be the tile width, shrunk by blocking size on this dimension. + static int const kNStride = (CoreTile::kStrided + BlockingShape::kColumn - 1) / BlockingShape::kColumn; + + /// On N dimension, how many tiles share the same meta data + static int const kNRepeats = (BlockingShape::kColumn + CoreTile::kStrided - 1) / CoreTile::kStrided; + + /// Each fragment should cover kMmaIterationsB number of mma intructions on the N dimension. + /// When blocking size on this dimension exceeds the tile width, multiple iterations + /// would share the same data. + static int const kMmaIterations = (kMmaIterationsB + kNRepeats - 1) / kNRepeats; + + static int const kFragementSize = kCoreTileFragementSize * kTilesPerMma * kMmaIterations; + + CUTLASS_DEVICE + static MatrixCoord lane_position(int lane_id) { + if constexpr(kNumBsPerCoreTileFragement == 2 + && kBTilesPerMma == 2 + && BlockingShape::kRow == 1){ + // Optimize for a special case of: + // 16b gemm (kNumBsPerCoreTileFragement == 2) + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking + // The scale and offset tensors are prepacked to reduce the number of load instructions. + return make_Coord((lane_id % CoreTile::kContiguous) * 4, + lane_id / CoreTile::kContiguous); + } else { + return make_Coord((lane_id % CoreTile::kContiguous) * kNumBsPerCoreTileFragement, + lane_id / CoreTile::kContiguous); + } + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is to load quantization meta data for operand B from +/// shared memory to fragments (hopefully allocated to registers by compilers). +/// Examples of meta data include scale or offsets. The operand B matrix is +/// quantized on a per block basis, meaning one element of meta data per block. +/// +/// This is meant to be used in lock step with the operand B tile iterator. +/// So all parameters are logical positions in the operand B tiles. +/// The goal here is to deliver each meta data element to its corresponding +/// operand B element for dequantization. As a result, we need to figure +/// out the operand B layout in the tensor core. +/// +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the quant scales + typename ElementScale_, + /// Layout of the quant scales + typename LayoutScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Layout of quant offsets + typename LayoutOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads, + /// Number of partitions along K dimension + int PartitionsK_ = 1> +class QuantBMetaMmaTensorOpTileIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column major layout + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTensorOpTileIterator{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ElementScale = ElementScale_; + using Layout = cutlass::layout::ColumnMajor; + using ElementOffset = ElementOffset_; + using ArchMmaOperator = ArchMmaOperator_; + + static constexpr bool kHasOffset = !(std::is_same::value); + + static_assert(BlockingShape::kRow == 1 && BlockingShape::kColumn > 1, + "Only support row blocking for column major layout"); + + using MetaTile = QuantBMetaMmaTile; + + /// Number of MMA instructions for this tile + static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB; + + /// Number of B elements per mma tile fragment (32b), 2 for half precision, 4 for int8 + static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement; + + /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension) + static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma; + + /// Number of B elements a fragment of meta data should cover + static constexpr int kExpandedSize = MetaTile::kExpandedSize; + + /// Number of meta elements per core tile fragment + static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize; + + /// stride for reaching the next core tile (if there is one) on the K dimension + static constexpr int kKTileStride = MetaTile::kKTileStride; + + /// do we need to load meta data for the next core tile on the K dimension? + static constexpr int kTilesPerMma = MetaTile::kTilesPerMma; + + static constexpr int kNStride = MetaTile::kNStride; + static constexpr int kNRepeats = MetaTile::kNRepeats; + static constexpr int kMmaIterations = MetaTile::kMmaIterations; + + using TensorRefScale = TensorRef; + using TensorRefOffset = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using FragmentScale = Array; + using FragmentOffset = typename std::conditional, + std::monostate>::type; + + using AccessTypeScale = Array; + using AccessTypeOffset = Array; + +private: + + ElementScale *pointer_; + Layout layout_; + + ElementOffset *pointer_offset_; + Layout layout_offset_; + + TensorCoord lane_position_; + +public: + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator() { } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator( + TensorRefScale const &ref, + TensorRefOffset const &ref_offset, + int lane_idx + ): + pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)){} + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(FragmentScale &frag, FragmentOffset &frag_offset) { + if constexpr(kNumBsPerCoreTileFragement == 2 + && kBTilesPerMma == 2){ + // Optimize for a special case of: + // 16b gemm (kNumBsPerCoreTileFragement == 2) + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking (BlockingShape::kRow == 1) + // The scale and offset tensors are prepacked to reduce the number of load instructions needed + const int row = lane_position_.row(); + const int column = lane_position_.column() / BlockingShape::kColumn; + + Array *dst_ptr = reinterpret_cast*>(frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + Array *src_ptr = reinterpret_cast*>(pointer_ + layout_({row, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + + if constexpr(kHasOffset){ + Array *dst_ptr_offset = reinterpret_cast*>(frag_offset.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + Array *src_ptr_offset = reinterpret_cast*>(pointer_offset_ + layout_offset_({row, c})); + *dst_ptr_offset = *src_ptr_offset; + dst_ptr_offset++; + } + } + + } else { + // Other cases, offsets and scales are not prepacked. + + const int row = lane_position_.row() / BlockingShape::kRow; + const int column = lane_position_.column() / BlockingShape::kColumn; + + AccessTypeScale* dst_ptr = reinterpret_cast(frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + AccessTypeScale* src_ptr = reinterpret_cast(pointer_ + layout_({r, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + } + + if constexpr(kHasOffset){ + AccessTypeOffset* dst_ptr = reinterpret_cast(frag_offset.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + AccessTypeOffset* src_ptr = reinterpret_cast(pointer_offset_ + layout_offset_({r, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + } + } + } + } + + template + CUTLASS_HOST_DEVICE + static Array debug_expand(Array const &frag){ + Array ret; + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; + ret[out_idx] = frag[idx]; + out_idx++; + } + } + } + return ret; + } + + CUTLASS_HOST_DEVICE + static void dequant(FragmentScale const &scales, + FragmentOffset const &offsets, + Array const &weights, + Array& dest){ + static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm."); + static_assert(kExpandedSize % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); + + // First convert 4b weight into fp16(weight + 16) + weights2Half(weights, dest); + + if constexpr(kBTilesPerMma == 2){ + // Optimize for a special case of: + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking (BlockingShape::kRow == 1) + + uint32_t* dest_pair = reinterpret_cast(dest.data()); + const b64* scales_ptr = reinterpret_cast(scales.data()); + const ElementOffset* offsets_ptr = nullptr; + if constexpr(kHasOffset) { offsets_ptr = offsets.data(); } + + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + // dequantize: d = scale * (weight - offset) + // to use FMA, d = scale * weight + (scale * (-offset)) + + b64 offsets; + if constexpr(kHasOffset){ + const uint32_t* p = reinterpret_cast(offsets_ptr); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, %4, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, %4, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(offsets.pair.a), "=r"(offsets.pair.b) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(p[0])); +#else + assert(0); +#endif + + offsets_ptr += 4; + } else { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(offsets.pair.a), "=r"(offsets.pair.b) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); +#else + offsets.fp16_quad.a = scales_ptr->fp16_quad.a * static_cast(-16-8); + offsets.fp16_quad.b = scales_ptr->fp16_quad.b * static_cast(-16-8); + offsets.fp16_quad.c = scales_ptr->fp16_quad.c * static_cast(-16-8); + offsets.fp16_quad.d = scales_ptr->fp16_quad.d * static_cast(-16-8); +#endif + } + + CUTLASS_PRAGMA_UNROLL + for (int n_r = 0; n_r < kNRepeats; n_r++){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " fma.rn.f16x2 %0, %2, %0, %4;\n" // dest = scale * (16 + weight) + (scale * (-16 - offset)) + " fma.rn.f16x2 %1, %3, %1, %5;\n" + "}\n" + : "+r"(dest_pair[0]), "+r"(dest_pair[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(offsets.pair.a), "r"(offsets.pair.b)); +#else + assert(0); +#endif + dest_pair += 2; + } + scales_ptr++; + } + + } else { + // unoptiomized path for other cases, very slow + int out_idx = 0; + ElementScale offset; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; + ElementScale s = scales[idx]; + if constexpr(kHasOffset){ + offset = s * static_cast(-16 - int(offsets[idx])); + } else { + offset = s * static_cast(-16-8); + } + dest[out_idx] = s * dest[out_idx] + offset; + out_idx++; + } + } + } + + } + + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator++() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; + int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn; + lane_position_ += TensorCoord(rows, columns); + return *this; + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row major layout + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTensorOpTileIterator{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ElementScale = ElementScale_; + using ElementOffset = ElementOffset_; + using Layout = cutlass::layout::RowMajor; + using ArchMmaOperator = ArchMmaOperator_; + + static constexpr bool kHasOffset = !(std::is_same::value); + + static_assert(BlockingShape::kColumn == 1 && BlockingShape::kRow > 1, + "Only support column blocking for row major layout"); + + using MetaTile = QuantBMetaMmaTile; + + /// Number of MMA instructions for this tile + static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB; + + /// Number of B elements per mma tile fragment (32b), 2 for half precision, 4 for int8 + static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement; + + /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension) + static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma; + + /// Number of B elements a fragment of meta data should cover + static constexpr int kExpandedSize = MetaTile::kExpandedSize; + + /// Number of meta elements per core tile fragment + static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize; + + /// stride for reaching the next core tile (if there is one) on the K dimension + static constexpr int kKTileStride = MetaTile::kKTileStride; + + /// do we need to load meta data for the next core tile on the K dimension? + static constexpr int kTilesPerMma = MetaTile::kTilesPerMma; + + static constexpr int kNStride = MetaTile::kNStride; + static constexpr int kNRepeats = MetaTile::kNRepeats; + static constexpr int kMmaIterations = MetaTile::kMmaIterations; + + using TensorRefScale = TensorRef; + using TensorRefOffset = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using FragmentScale = Array; + using FragmentOffset = typename std::conditional, + std::monostate>::type; + +private: + + ElementScale *pointer_; + Layout layout_; + + ElementOffset *pointer_offset_; + Layout layout_offset_; + + TensorCoord lane_position_; + +public: + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator() { } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator( + TensorRefScale const &ref, + TensorRefOffset const &ref_offset, + int lane_idx + ): + pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)) + {} + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(FragmentScale &frag, FragmentOffset &frag_offset) { + const int row = lane_position_.row() / BlockingShape::kRow; + const int column = lane_position_.column() / BlockingShape::kColumn; + static_assert(kTilesPerMma * kCoreTileFragementSize == 1, "Only support one meta data per core tile"); + + ElementScale* src_ptr = pointer_ + layout_({row, column}); + ElementScale* dst_ptr = frag.data(); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + dst_ptr[n_idx] = src_ptr[n_idx * kNStride]; + } + + if constexpr(kHasOffset){ + ElementOffset* src_ptr_offset = pointer_offset_ + layout_offset_({row, column}); + ElementOffset* dst_ptr_offset = frag_offset.data(); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + dst_ptr_offset[n_idx] = src_ptr_offset[n_idx * kNStride]; + } + } + } + + template + CUTLASS_HOST_DEVICE + static Array debug_expand(Array const &frag){ + Array ret; + + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int col = elem_idx + mma_tile_idx * kCoreTileFragementSize; + int idx = col * kMmaIterations + n_idx; + ret[out_idx] = frag[idx]; + out_idx++; + } + } + } + return ret; + } + + CUTLASS_HOST_DEVICE + static void dequant(FragmentScale const &scales, + FragmentOffset const &offsets, + Array const &weights, + Array& dest){ + static_assert(kNRepeats == 1, "This is implied by BlockingShape::kColumn == 1"); + static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm now."); + + // First convert 4b weight into fp16(weight + 16) + weights2Half(weights, dest); + + ElementScale addon[kMmaIterationsB]; + if constexpr (kMmaIterationsB % 4 == 0) { + const b64* scales_ptr = reinterpret_cast(scales.data()); + uint32_t* addon_ptr = reinterpret_cast(addon); + if constexpr(kHasOffset){ + const uint32_t* p = reinterpret_cast(offsets.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [d, c, b, a] --> [d, b, c, a] + " prmt.b32 rb2, %4, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, rb2, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(p[0])); +#else + assert(0); +#endif + scales_ptr++; + p++; + addon_ptr += 2; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); +#else + assert(0); +#endif + scales_ptr++; + addon_ptr += 2; + } + } + } else if constexpr (kMmaIterationsB % 2 == 0) { + const uint32_t* scales_ptr = reinterpret_cast(scales.data()); + uint32_t* addon_ptr = reinterpret_cast(addon); + + if constexpr (kHasOffset){ + // possible buffer over read 2 bytes here. + const uint32_t* p = reinterpret_cast(offsets.data()); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [?, ?, b, a] --> [?, b, ?, a] + " prmt.b32 rb2, %2, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - offset) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0]) + "r"(p[0])); +#else + assert(0); +#endif + } else { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - 8) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0])); +#else + assert(0); +#endif + } + } else { + // kMmaIterationsB == 1 + if constexpr(kHasOffset){ + uint8_t zp = offsets[0]; + addon[0] = scales[0] * static_cast(-16 - static_cast(zp)); + } else { + addon[0] = scales[0] * static_cast(-16-8); + } + } + + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + dest[out_idx] = scales[n_out] * dest[out_idx] + addon[n_out]; + dest[out_idx + 1] = scales[n_out] * dest[out_idx + 1] + addon[n_out]; + out_idx += 2; + } + } + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator++() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; + int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn; + lane_position_ += TensorCoord(rows, columns); + return *this; + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// +} // namespace warp +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h new file mode 100644 index 0000000000000..f29cedf326a44 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h @@ -0,0 +1,361 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_mma_tensor_op.h + * @brief Modified from cutlass/gemm/warp/mma_tensor_op.h + * Templates implementing warp-level matrix multiply-accumulate operations + * targeting tensor cores. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Data type of quant scales + typename ElementQScale_, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale_, + /// Data type of quant offsets + typename ElementQOffset_, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset_, + /// Blocking dimensions of quantization + typename QuantBlocking_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool +> +class QuantBMmaTensorOp { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + +public: + + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kA, ElementA, LayoutA, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kB, ElementB, LayoutB, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + // warp B MatrixShape<64, 64>, + // layout B cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<16, 64>, + // instruction op shape cutlass::MatrixShape<16, 8>, + // kPartitionsK 1 + // FragmentB::kElements 32 + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; // cutlass::Array + + /// Storage for transformed B tile + /// When loading weights, we packed 4 int4 weights into one 2-byte-element, when expanded + /// we multiply the number of elements by 4. + /// TODO: make sure ArchMmaOperator::ElementB same as dequantized ElementB + /// and change the transform function below to perform dequantization + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator< + MatrixShape, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + using ElementQScale = ElementQScale_; + using SmemLayoutQScale = SmemLayoutQScale_; + using QuantBlocking = QuantBlocking_; + + using ElementQOffset = ElementQOffset_; + using SmemLayoutQOffset = SmemLayoutQOffset_; + + /// Iterates over the quantization parameters in memory + using WarpQScaleShape = MatrixShape<(Shape::kK / QuantBlocking::kRow), (Shape::kN / QuantBlocking::kColumn)>; + static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow"); + static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn"); + static_assert(WarpQScaleShape::kCount > 0, "QuantBlocking too big to fit in a warp block!"); + + // TODO This is an expanding iterator, it needs to replicate the quantization parameters + // to all threads in the warp. + using IteratorQMeta = QuantBMetaMmaTensorOpTileIterator< + MatrixShape, QuantBlocking, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, + ArchMmaOperator, kThreadCount, kPartitionsK>; + + using FragmentQScale = typename IteratorQMeta::FragmentScale; + using FragmentQOffset = typename IteratorQMeta::FragmentOffset; + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN + >; + +public: + + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + QuantBMmaTensorOp() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + TransformedFragmentA const &A, + TransformedFragmentB const &B, + FragmentC const &C + ) const { + + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + D = C; + + MmaOperandA const *ptr_A = reinterpret_cast(&A); + MmaOperandB const *ptr_B = reinterpret_cast(&B); + MmaOperandC *ptr_D = reinterpret_cast(&D); + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + // The visitation order is like + // _ + // | | | | + // | | | | + // |_| |_| + // + // Down Up Down Up + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma( + ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } + #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + // The visitation order is like + // _________ + // _________| + // |_________ + // __________| + // + // Right Left Right Left + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } + #else + assert(0); + #endif + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentB &dst_B, + FragmentB const &B, + FragmentQScale const &scales, + FragmentQOffset const &offsets) const { + + Array const *ptr_B = + reinterpret_cast const *>(&B); + IteratorQMeta::dequant(scales, offsets, *ptr_B, dst_B); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h index a0405e32034ae..783a29d8a2055 100644 --- a/onnxruntime/core/util/matrix_layout.h +++ b/onnxruntime/core/util/matrix_layout.h @@ -17,7 +17,6 @@ #include #include "core/common/gsl.h" -// TODO!! Already have this in cuda, what about cpu code though? #if defined(_MSC_VER) #define ORT_FORCEINLINE __forceinline #else diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h new file mode 100644 index 0000000000000..6ea8b55505214 --- /dev/null +++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h @@ -0,0 +1,203 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_quant_sm80.h + * + * Abstract: + * Oracle computation for blockwise 4b quantization for fp16 + * gemm kernel specifically for Ampere GPUs. This is used for + * testing the cuda kernel implementation in + * (test/providers/cuda/test_cases) + * and for testing the cuda op prepack code in (test/optimizer) + */ + +#pragma once + +#include "core/util/matrix_layout.h" +#include "core/common/common.h" + +namespace onnxruntime { +namespace test { + +static inline void sm80_prepack_weights_ref( + int rows, + int columns, + const MatrixRef& tensor_weight, + const MatrixRef& tensor_weight_prepacked) { + ORT_ENFORCE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns, + "Unexpected tensor_weight shape! Expected: (", rows / 2, ", ", columns, "), Got: (", + tensor_weight.shape()[0], ", ", tensor_weight.shape()[1], ")."); + ORT_ENFORCE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2, + "tensor_weight_prepacked shape is not compatible with prepacked weight shape"); + + auto t0_base = make_Position(0, 0); + auto t1_base = make_Position(4, 0); + auto t2_base = make_Position(0, 8); + auto t3_base = make_Position(4, 8); + for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { + for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { + // Packing from a 8x16 tile to a 16x8 tile + auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); + auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); + for (int col = 0; col < 8; ++col) { + for (int row = 0; row < 4; ++row) { + auto cord = make_Position(row, col); + auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 + uint8_t buf[4]; + buf[0] = tensor_weight.at(dtile_base + t0_base + cord); + buf[1] = tensor_weight.at(dtile_base + t1_base + cord); + buf[2] = tensor_weight.at(dtile_base + t2_base + cord); + buf[3] = tensor_weight.at(dtile_base + t3_base + cord); + + // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights + // are in different b16 register at the same positions. This makes it easier to convert to + // fp16x2 format in a b32 register + + tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); + tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); + } + } + } + } +} + +template < + typename ScaleElementT, + typename Layout, + typename QuantBlocking> +inline void sm80_prepack_quant_scales_ref( + int rows, + int columns, + const MatrixRef& tensor_scale, + const MatrixRef& tensor_scale_prepacked) { + ORT_ENFORCE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn), + "Unexpected tensor_scale shape! Expected: (", + rows / QuantBlocking::kRow, ", ", columns / QuantBlocking::kColumn, ")"); + ORT_ENFORCE(tensor_scale_prepacked.shape() == tensor_scale.shape()); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (sizeof(ScaleElementT) != 2 || QuantBlocking::kRow != 1) { + ORT_THROW("sm80_prepack_quant_scales_ref should only be called for row-wise block quantization on 16b float values."); + } + + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); + tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); + tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); + tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); + } + } + } +} + +template +inline void sm80_prepack_quant_offsets_ref( + int rows, + int columns, + MatrixRef tensor_offset, + MatrixRef tensor_offset_prepacked) { + const auto meta_shape = make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + ORT_ENFORCE(tensor_offset_prepacked.shape() == meta_shape, + "Unexpected tensor_offset_prepacked shape (", + tensor_offset_prepacked.shape()[0], ",", tensor_offset_prepacked.shape()[1], + ")! Expected: (", meta_shape[0], ", ", meta_shape[1], ")"); + ORT_ENFORCE(tensor_offset.shape() == zp_shape, + "Unexpected tensor_offset shape (", + tensor_offset.shape()[0], ",", tensor_offset.shape()[1], + ")! Expected: (", zp_shape[0], ", ", zp_shape[1], ")"); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (QuantBlocking::kRow != 1) { + ORT_THROW("sm80_prepack_quant_offsets_ref should only be called for row-wise block quantization."); + } + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + if (tensor_offset_prepacked.good()) { + for (int col = 0; col < tensor_offset_prepacked.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_offset_prepacked.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own + // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to + // convert to fp16x2 format in a b32 register + uint8_t pair01 = tensor_offset.at(src_idx / 2, col); + uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col); + tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf; + tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf; + tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4; + tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4; + } + } + } + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h new file mode 100644 index 0000000000000..bbe370675fc48 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h @@ -0,0 +1,188 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_gemm_sm80.h + * + * Abstract: + * Bridge between gtest code and gemm kernel implementation. + * Gemm kernel requires CUTLASS header files, which causes strange + * compilation errors with RE2 header files, which are required + * by gtest. + */ + +#pragma once + +#include + +#include "core/util/matrix_layout.h" +#include "core/common/common.h" +#include "core/mickey/blk_q4/f16_prepack_sm80.h" +#include "test/cuda_host/blkq4_fp16_quant_sm80.h" + +namespace onnxruntime { +namespace cuda { +namespace test { + +Status sm80_supported(); + +/** + * @brief Generate a set of quantized weights, scales and offsets + * and dequantized weights for testing quantization and + * dequantization. All outputs are column major layout. + * + * @tparam ElementT The type of the dequantized weights. + * @tparam block_size The block size of the quantization. + * @tparam col_blocking Whether to use column blocking (all elements of + * a block comes from a single column) or row blocking + * @tparam has_offsets Whether to generate offsets. + * + * @param[in] rows The number of rows of the weight matrix. + * @param[in] columns The number of columns of the weight matrix. + * @param[out] dequants The dequantized weights, column major layout. + * @param[out] q_weights The quantized weights, column major layout. + * @param[out] q_scales The scales, column major layout. + * @param[out] q_zp The zero points, column major layout. + */ +template +inline void blkq4_weights_gen( + int rows, int columns, + std::vector& dequants, + std::vector& q_weights, + std::vector& q_scales, + std::vector& q_zp) { + using Base = onnxruntime::cuda::BlockwiseQuantization< + ElementT, + block_size, + 4, + col_blocking>; + + using QuantBlocking = typename Base::QuantBlocking; + using ElementW = typename Base::ElementW; + using LayoutWPack = typename Base::LayoutWPack; + using ElementQOffset = typename Base::ElementQOffset; + + static_assert(std::is_same::value); + static_assert(std::is_same::value); + static_assert(std::is_same::value); + + unsigned int seed = 28571; // Replace with desired seed value + std::seed_seq seq{seed}; + std::mt19937 gen(seq); + std::uniform_int_distribution dis(0, 8192); + + const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); + const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + + // + // For testing quantization and dequantization, it is not straight + // forward to avoid flaky tests due to rounding errors. The way we + // try to achieve this is to: + // 1. Generate a set of quantized weights, scales and offsets + // 2. Dequantize the weights + // 3. Quantize the dequantized weights + // 4. Compare the dequantied-and-then-quantized weights with + // the original quantized weights + // + // Random filling of the initial values are key to get this right. + // For weights, we must ensure each block gets a full range of + // values, i.e. must contain 0 and 15. And for scales, they must + // all be positive. + // + + q_weights.resize(q_weight_shape.product()); + MatrixRef tensor_q_weight( + q_weights, make_Position(rows / 2, columns)); + int v = 7; + for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { + for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + + tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); + } + } + + q_scales.resize(meta_shape.product()); + for (size_t i = 0; i < q_scales.size(); i++) { + uint32_t v = dis(gen); + uint32_t m = (v % 63) + 1; + uint32_t e = (v >> 6) % 4; + q_scales[i] = ElementT(m / static_cast(1 << (2 + e))); + } + MatrixRef tensor_scale( + q_scales, meta_shape); + + MatrixRef tensor_offset; + if constexpr (has_offsets) { + q_zp.resize(zp_shape.product()); + tensor_offset = MatrixRef( + q_zp, zp_shape); + for (int c = 0; c < zp_shape[1]; c++) { + for (int r = 0; r < zp_shape[0]; ++r) { + uint8_t v0 = dis(gen) % 16; + uint8_t v1 = 8; + if (r * 2 + 1 < meta_shape[0]) { + v1 = dis(gen) % 16; + } + tensor_offset.at(r, c) = static_cast(v0 | (v1 << 4)); + } + } + } + + dequants.resize(rows * columns); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + + // Dequantize weights and save into matrix B + for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { + for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { + auto weight_cord = make_Position(row / 2, col); + auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); + uint8_t offset = 8; + if constexpr (has_offsets) { + if (scale_cord[0] % 2 == 0) { + offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) & 0x0f; + } else { + offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) >> 4; + } + } + int w = 0; + if (row % 2 == 0) { + w = int(tensor_q_weight.at(weight_cord) & 0x0f); + } else { + w = int(tensor_q_weight.at(weight_cord) >> 4); + } + float scale = float(tensor_scale.at(scale_cord)); + float dequant = scale * float(w - offset); + tensor_dequant.at(row, col) = ElementT(dequant); + // Prints for help debugging in case of test failure + // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); + } + } +} + +template < + int block_size, + bool column_wise_blocking, + bool small_m, + bool has_offsets> +void run_blkq4_gemm(int m, int n, int k); + +} // namespace test +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc new file mode 100644 index 0000000000000..e687ae73e66f2 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -0,0 +1,330 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_gemm_sm80_test.cc + * + * Abstract: + * Test code for block-wise quantized 4b GEMM kernels. + * This part requires gtest header files, which do not play + * well with CUTLASS headers. + */ + +#include + +#include "core/framework/float16.h" +#include "core/mlas/inc/mlas_q4.h" + +#include "blkq4_fp16_gemm_sm80.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +template +void testPrepack(int rows, int columns) { + using ElementT = MLFloat16; + constexpr int block_size = 32; + using Base = onnxruntime::cuda::BlockwiseQuantization< + ElementT, + block_size, + 4, + col_blocking>; + + using QuantBlocking = typename Base::QuantBlocking; + using ElementW = typename Base::ElementW; + using LayoutWPack = typename Base::LayoutWPack; + using ElementQOffset = typename Base::ElementQOffset; + using LayoutQmeta = typename Base::LayoutQmeta; + + const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); + const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + + std::vector q_weights; + std::vector q_scales; + std::vector q_zp; + std::vector dequants; + onnxruntime::cuda::test::blkq4_weights_gen( + rows, columns, dequants, q_weights, q_scales, q_zp); + + // for quantization tool, the input is row major, all outputs are column major + MatrixRef tensor_q_weight( + q_weights, make_Position(rows / 2, columns)); + MatrixRef tensor_scale( + q_scales, meta_shape); + MatrixRef tensor_offset; + if constexpr (has_offset) { + tensor_offset = MatrixRef(q_zp, zp_shape); + } + + // for quantization tool, the input is row major, test weight gen output is column major + std::vector dequants_transposed(dequants.size()); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + MatrixRef tensor_dequant_transposed(dequants_transposed, make_Position(rows, columns)); + for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { + for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { + tensor_dequant_transposed.at(row, col) = tensor_dequant.at(row, col); + } + } + + int q_rows, q_cols; + MlasBlockwiseQuantizedShape( + block_size, col_blocking, rows, columns, q_rows, q_cols); + // to be exact, q_rows are padded to multiple of block_size, deal with it when we care about strange shapes + EXPECT_EQ(q_rows, q_weight_shape[0]); + EXPECT_EQ(q_cols, q_weight_shape[1]); + + // + // Quantization tool outputs: + // + std::vector o_elements(q_rows * q_cols); + MatrixRef tensor_o_elements(o_elements, q_weight_shape); + + std::vector o_scales(meta_shape.product()); + MatrixRef tensor_o_scales(o_scales, meta_shape); + + std::vector o_zp(zp_shape.product()); + MatrixRef tensor_o_zp(o_zp, zp_shape); + + MlasQuantizeBlockwise(o_elements.data(), o_scales.data(), has_offset ? o_zp.data() : nullptr, + dequants_transposed.data(), block_size, + col_blocking, rows, columns, columns, nullptr); + for (int col = 0; col < tensor_q_weight.shape()[1]; ++col) { + for (int row = 0; row < tensor_q_weight.shape()[0]; ++row) { + EXPECT_EQ(tensor_o_elements.at(row, col), tensor_q_weight.at(row, col)) + << "quantized value mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + if (has_offset) { + uint8_t pair01 = tensor_o_zp.at(row / 2, col); + uint8_t expected_pair01 = tensor_offset.at(row / 2, col); + EXPECT_EQ(expected_pair01 & 0xf, pair01 & 0xf) + << "quantized offset mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + if (row + 1 < meta_shape[0]) { + EXPECT_EQ(expected_pair01 >> 4, pair01 >> 4) + << "quantized offset mismatch at [" << row + 1 << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + EXPECT_EQ(tensor_scale.at(row + 0, col), tensor_o_scales.at(row + 0, col)) + << "quantized scale mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + if (row + 1 < meta_shape[0]) { + EXPECT_EQ(tensor_scale.at(row + 1, col), tensor_o_scales.at(row + 1, col)) + << "quantized scale mismatch at [" << row + 1 << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + } + + // + // Now we just setup quantized weights tensor_q_weight, quantization scale tensor_scale + // and quantization offset tensor_offset. The above tests just make sure our setup is + // consistent with quantization tool output. + // + // Next we test the prepack code + // + + std::vector packed_w_ref(q_weight_shape.product()); + MatrixRef tensor_packed_w_ref( + packed_w_ref, make_Position(rows, columns / 2)); + onnxruntime::test::sm80_prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref); + + std::vector packed_w(q_weight_shape.product()); + MatrixRef tensor_packed_w( + packed_w, make_Position(rows, columns / 2)); + Base::prepack_weights(rows, columns, o_elements, packed_w); + + for (int col = 0; col < tensor_packed_w.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_w.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_w_ref.at(row, col), tensor_packed_w.at(row, col)) + << "prepacked weights mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + std::vector packed_scales_ref(meta_shape.product()); + MatrixRef tensor_packed_s_ref = + make_MatrixRef(packed_scales_ref, meta_shape); + if constexpr (Base::ShouldRearrangeMeta) { + onnxruntime::test::sm80_prepack_quant_scales_ref( + rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref); + } else { + for (int col = 0; col < tensor_packed_s_ref.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_s_ref.shape()[0]; ++row) { + tensor_packed_s_ref.at(row, col) = tensor_scale.at(row, col); + } + } + } + + std::vector packed_scales(meta_shape.product()); + MatrixRef tensor_packed_s( + packed_scales, meta_shape); + Base::prepack_quant_scales(rows, columns, o_scales, packed_scales); + + for (int col = 0; col < tensor_packed_s.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_s.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_s_ref.at(row, col), tensor_packed_s.at(row, col)) + << "prepacked scales mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + if (has_offset) { + std::vector packed_zp_ref(meta_shape.product()); + MatrixRef tensor_packed_zp_ref = + make_MatrixRef(packed_zp_ref, meta_shape); + if constexpr (Base::ShouldRearrangeMeta) { + onnxruntime::test::sm80_prepack_quant_offsets_ref( + rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); + } else { + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + uint8_t pair01 = tensor_offset.at(row / 2, col); + tensor_packed_zp_ref.at(row, col) = pair01 & 0xf; + if (row + 1 < meta_shape[0]) { + tensor_packed_zp_ref.at(row + 1, col) = pair01 >> 4; + } + } + } + } + + std::vector packed_zp(meta_shape.product()); + MatrixRef tensor_packed_zp( + packed_zp, meta_shape); + Base::prepack_quant_offsets(rows, columns, o_zp, packed_zp); + + for (int col = 0; col < tensor_packed_zp.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_zp.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_zp_ref.at(row, col), tensor_packed_zp.at(row, col)) + << "prepacked offsets mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + } +} + +// TODO: code runs on CPU, but this is for sm80 only, maybe enable only when test on sm80 +TEST(BlkQ4_GEMM, PrepackSm80Test) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + + testPrepack(32, 32); + testPrepack(32, 32); + testPrepack(32, 32); + testPrepack(32, 32); + testPrepack(32, 64); + testPrepack(32, 128); + testPrepack(32, 256); + testPrepack(64, 32); + testPrepack(128, 32); + testPrepack(256, 32); + testPrepack(256, 256); + testPrepack(32, 128); + testPrepack(128, 32); + testPrepack(256, 256); + testPrepack(32, 64); + testPrepack(32, 128); + testPrepack(32, 256); + testPrepack(64, 32); + testPrepack(128, 32); + testPrepack(256, 32); + testPrepack(256, 256); + testPrepack(32, 128); + testPrepack(128, 32); + testPrepack(256, 256); +} + +TEST(BlkQ4_GEMM, Sm80RowBlockingTest) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 32, 64); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 32, 64); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 96, 64); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 96, 64); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 96, 192); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 96, 192); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(256, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(256, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(512, 2048 + 32, 960); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(512, 2048 + 32, 960); + + onnxruntime::cuda::test::run_blkq4_gemm<16, false, false, false>(256, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, false, false, true>(256, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, false>(256, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, true>(256, 1024, 576); +} + +TEST(BlkQ4_GEMM, Sm80ColBlockingTest) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, false>(64, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, true>(64, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, false>(256, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, true>(256, 1024, 576); +} + +TEST(BlkQ4_GEMM, Sm80SmallMTest) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + + // // small m + onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, false>(16, 704, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, true>(16, 704, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, false>(16, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, true>(16, 1024, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<16, true, true, false>(16, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, true, true, true>(16, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, true, true, false>(16, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, true, true, true>(16, 1024, 576); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu new file mode 100644 index 0000000000000..69c929d446ce4 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -0,0 +1,344 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_gemm_sm80_testcu.cu + * + * Abstract: + * Test code for invoking block-wise quantized 4b GEMM kernels. + * This part requires CUTLASS header files, which do not play + * well with gtest headers. + */ + +#include +#include +#include + +#include "core/mickey/blk_q4/f16_gemm_sm80.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "core/common/common.h" + +#include "blkq4_fp16_gemm_sm80.h" + +namespace onnxruntime { +namespace cuda{ +namespace test{ + +Status sm80_supported(){ + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::ostringstream ss; + ss << "Unable to obtain GPU device properties: " << cudaGetErrorString(error); + return Status(common::ONNXRUNTIME, common::ENGINE_ERROR, ss.str()); + } + + if (!((props.major * 10 + props.minor) >= 80)) { + std::ostringstream ss; + ss << "Device compute capability mismatch, desired 8.0, actual " << props.major << "." << props.minor; + return Status(common::ONNXRUNTIME, common::ENGINE_ERROR, ss.str()); + } + return Status::OK(); +} + +/** + * @brief Reference implementation of GEMM + * Copied directly from cutlass util/reference/device/gemm.h + * for the strange reason that compiler insists on asking + * for explicit stream argument in kernel launch. +*/ +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType +> +void compute_gemm_ref( + cutlass::gemm::GemmCoord problem_size, + ScalarType alpha, + cutlass::TensorRef tensor_a, + cutlass::TensorRef tensor_b, + ScalarType beta, + cutlass::TensorRef tensor_c, + cutlass::TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + + // Blocking structure potentially improves performance of reference implementation + // with a minor increase in complexity. + // + // Note, this reference implementation is NOT expected to approach peak performance. + using OutputTile = cutlass::MatrixShape<4, 4>; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) + ); + + // Launch a GEMM kernel + cutlass::reference::device::kernel::Gemm< + cutlass::TensorRef, + cutlass::TensorRef, + cutlass::TensorRef, + ScalarType, + AccumulatorType, + OutputTile, + cutlass::multiply_add, + cutlass::NumericConverter + ><<>>( + problem_size, + alpha, + tensor_a, + tensor_b, + beta, + tensor_c, + tensor_d, + initial_accum + ); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Converting cutlass tensor to MatrixRef +// + +template < + typename Element, + typename LayoutCutlass, + typename Layout = std::conditional_t::value, ColumnMajorLayout, RowMajorLayout> + > +__forceinline__ +MatrixRef make_MatrixRef(cutlass::HostTensor const& tensor) { + static_assert(std::is_same::value + || std::is_same::value); + auto shape = make_Position(tensor.extent().row(), tensor.extent().column()); + auto* ptr = const_cast::type *>(tensor.host_data()); + return MatrixRef(ptr, tensor.capacity(), shape); +} + +template < + typename Element, + typename LayoutCutlass, + typename Layout = std::conditional_t::value, ColumnMajorLayout, RowMajorLayout> + > +__forceinline__ +MatrixRef make_ConstMatrixRef(cutlass::HostTensor const& tensor) { + static_assert(std::is_same::value + || std::is_same::value); + auto shape = make_Position(tensor.extent().row(), tensor.extent().column()); + return MatrixRef(tensor.host_data(), tensor.capacity(), shape); +} + +// +// Invoking the kernel +// + +template< + int block_size, + bool column_wise_blocking, + bool small_m, + bool has_offsets> +void run_blkq4_gemm(int m, int n, int k) { + unsigned int seed = 28571; // Replace with desired seed value + std::seed_seq seq{seed}; + std::mt19937 gen(seq); + std::uniform_int_distribution<> dis(0, 8192); + + using ElementDequant = cutlass::half_t; + using QuantBlocking = + typename std::conditional, + cutlass::MatrixShape<1, block_size>>::type; + + using GemmRunner = BlkQ4F16GemmImpl; + + using ElementAccumulator = typename GemmRunner::ElementAccumulator; + using ElementComputeEpilogue = typename GemmRunner::ElementComputeEpilogue; + using ElementInputA = typename GemmRunner::ElementInputA; + using ElementOutput = typename GemmRunner::ElementOutput; + using ElementW = typename GemmRunner::ElementW; + using ElementWPack = typename GemmRunner::ElementWPack; + using ElementQScale = typename GemmRunner::ElementQScale; + using ElementQOffset = typename GemmRunner::ElementQOffset; + + using LayoutInputA = typename GemmRunner::LayoutInputA; + using LayoutOutput = typename GemmRunner::LayoutOutput; + using LayoutInputWPack = typename GemmRunner::LayoutInputWPack; + using LayoutInputQScale = typename GemmRunner::LayoutInputQScale; + + const cutlass::gemm::GemmCoord problem_size = {m, n, k}; + const auto q_weight_shape = cutlass::make_Coord(problem_size.k()/2, problem_size.n()); + const auto meta_shape = cutlass::make_Coord(problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn); + + // + // Generate quantized and dequantizeed input matrix B [K, N] + // + static_assert(std::is_same::value); + std::vector q_weights; + std::vector q_scales; + std::vector q_zp; + std::vector dequants; + onnxruntime::cuda::test::blkq4_weights_gen( + problem_size.k(), problem_size.n(), dequants, q_weights, q_scales, q_zp); + + using PrepackT = onnxruntime::cuda::BlockwiseQuantization< + ElementDequant, + block_size, + 4, + column_wise_blocking>; + + std::vector packed_w(q_weight_shape.product()); + PrepackT::prepack_weights(problem_size.k(), problem_size.n(), q_weights, packed_w); + std::vector packed_scales(meta_shape.product()); + PrepackT::prepack_quant_scales(problem_size.k(), problem_size.n(), q_scales, packed_scales); + std::vector packed_zp; + if constexpr (has_offsets) { + packed_zp.resize(meta_shape.product()); + PrepackT::prepack_quant_offsets(problem_size.k(), problem_size.n(), q_zp, packed_zp); + } + + cutlass::HostTensor tensor_a( + problem_size.mk()); // <- Create matrix A with dimensions M x K + cutlass::HostTensor tensor_c( + problem_size.mn()); // <- Create matrix C with dimensions M x N + cutlass::HostTensor tensor_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // CUTLASS kernel + + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(4), + ElementInputA(-4), + 2); // <- Fill matrix A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(4), + ElementOutput(-4), + 0); // <- Fill matrix C on host with uniform-distribution random data + cutlass::reference::host::TensorFill( + tensor_d.host_view()); // <- fill matrix D on host with zeros + + // + // Copy data from host to GPU... + // + thrust::device_vector d_packed_w(packed_w); + cutlass::TensorRef ref_W( + reinterpret_cast(d_packed_w.data().get()), + LayoutInputWPack::packed({problem_size.k()/2, problem_size.n()/2})); + + thrust::device_vector d_packed_scales(packed_scales); + cutlass::TensorRef ref_scales( + d_packed_scales.data().get(), LayoutInputQScale::packed(meta_shape)); + + thrust::device_vector d_packed_zp(packed_zp); + cutlass::TensorRef ref_zp( + d_packed_zp.data().get(), LayoutInputQScale::packed(meta_shape)); + + tensor_a.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + + // run GEMM + cutlass::Status status; + if constexpr (has_offsets){ + status = GemmRunner::run( + nullptr, problem_size, tensor_a.device_ref(), ref_W, + ref_scales, ref_zp, + tensor_c.device_ref(), tensor_d.device_ref()); + } else { + status = GemmRunner::run( + nullptr, problem_size, tensor_a.device_ref(), ref_W, + ref_scales, + tensor_c.device_ref(), tensor_d.device_ref()); + } + ORT_ENFORCE(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status)); + + // Running reference kernel + using ElementInputB = ElementInputA; + using LayoutInputB = cutlass::layout::ColumnMajor; + thrust::device_vector d_dequants(dequants); + cutlass::TensorRef ref_B( + d_dequants.data().get(), LayoutInputB::packed(problem_size.kn())); + cutlass::HostTensor tensor_ref_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // reference kernel + + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros + tensor_ref_d.sync_device(); + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + ElementComputeEpilogue beta = ElementComputeEpilogue(0); + + compute_gemm_ref( + problem_size, + alpha, + tensor_a.device_ref(), + ref_B, + beta, + tensor_c.device_ref(), + tensor_ref_d.device_ref()); + + // Wait for kernels to finish + cudaDeviceSynchronize(); + + // Copy output data from CUTLASS and reference kernel to host for comparison + tensor_d.sync_host(); + tensor_ref_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + ORT_ENFORCE(passed, "Gemm kernel result wrong!"); +} + +template void run_blkq4_gemm<16, true, false, true>(int m, int n, int k); +template void run_blkq4_gemm<16, true, false, false>(int m, int n, int k); +template void run_blkq4_gemm<32, true, false, true>(int m, int n, int k); +template void run_blkq4_gemm<32, true, false, false>(int m, int n, int k); +template void run_blkq4_gemm<64, true, false, true>(int m, int n, int k); +template void run_blkq4_gemm<64, true, false, false>(int m, int n, int k); +template void run_blkq4_gemm<16, false, false, true>(int m, int n, int k); +template void run_blkq4_gemm<16, false, false, false>(int m, int n, int k); +template void run_blkq4_gemm<32, false, false, true>(int m, int n, int k); +template void run_blkq4_gemm<32, false, false, false>(int m, int n, int k); +template void run_blkq4_gemm<64, false, false, true>(int m, int n, int k); +template void run_blkq4_gemm<64, false, false, false>(int m, int n, int k); +template void run_blkq4_gemm<16, true, true, true>(int m, int n, int k); +template void run_blkq4_gemm<16, true, true, false>(int m, int n, int k); +template void run_blkq4_gemm<32, true, true, true>(int m, int n, int k); +template void run_blkq4_gemm<32, true, true, false>(int m, int n, int k); +template void run_blkq4_gemm<64, true, true, true>(int m, int n, int k); +template void run_blkq4_gemm<64, true, true, false>(int m, int n, int k); +template void run_blkq4_gemm<16, false, true, true>(int m, int n, int k); +template void run_blkq4_gemm<16, false, true, false>(int m, int n, int k); +template void run_blkq4_gemm<32, false, true, true>(int m, int n, int k); +template void run_blkq4_gemm<32, false, true, false>(int m, int n, int k); +template void run_blkq4_gemm<64, false, true, true>(int m, int n, int k); +template void run_blkq4_gemm<64, false, true, false>(int m, int n, int k); + +} // namespace test +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc deleted file mode 100644 index aba2b0b2cb4a4..0000000000000 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc +++ /dev/null @@ -1,507 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include - -#include "core/framework/float16.h" -#include "core/mickey/blk_q4/prepack_sm80.h" -#include "core/mlas/inc/mlas_q4.h" - -#include "gtest/gtest.h" - -namespace onnxruntime { -namespace test { - -void prepack_weights_ref( - int rows, - int columns, - const MatrixRef& tensor_weight, - const MatrixRef& tensor_weight_prepacked) { - EXPECT_TRUE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns); - EXPECT_TRUE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2); - - auto t0_base = make_Position(0, 0); - auto t1_base = make_Position(4, 0); - auto t2_base = make_Position(0, 8); - auto t3_base = make_Position(4, 8); - for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { - for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { - // Packing from a 8x16 tile to a 16x8 tile - auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); - auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); - for (int col = 0; col < 8; ++col) { - for (int row = 0; row < 4; ++row) { - auto cord = make_Position(row, col); - auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 - uint8_t buf[4]; - buf[0] = tensor_weight.at(dtile_base + t0_base + cord); - buf[1] = tensor_weight.at(dtile_base + t1_base + cord); - buf[2] = tensor_weight.at(dtile_base + t2_base + cord); - buf[3] = tensor_weight.at(dtile_base + t3_base + cord); - - // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights - // are in different b16 register at the same positions. This makes it easier to convert to - // fp16x2 format in a b32 register - - tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); - tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); - tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); - tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); - } - } - } - } -} - -template < - typename ScaleElementT, - typename Layout, - typename QuantBlocking> -void prepack_quant_scales_ref( - int rows, - int columns, - const MatrixRef& tensor_scale, - const MatrixRef& tensor_scale_prepacked) { - EXPECT_TRUE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn)); - EXPECT_TRUE(tensor_scale_prepacked.shape() == tensor_scale.shape()); - - // Only prepacking scale and offset tensors for a often used special case: - // 16b gemm (2 elements per 32b register, operand tile shape 8x8) - // 2 B operand tiles per mma instruction stacked on k dimension - // (1,n) quantization blocking - if constexpr (sizeof(ScaleElementT) == 2 && QuantBlocking::kRow == 1) { - // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread - // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use - // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, - // as shown below (T stands for thread): - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // - // We need to deliver quantization scale and offset elements to the corresponding threads, - // so we can perform dequantization efficiently. With a column major layout, each thread - // needs two separate loads for a mma instruction, due to the tile fragment layout shown - // above. To reduce the number of loads, we rearrange each column as below, so we can use - // a single load to load fragments for two tiles: - // T0 T0 - // T1 T0 - // T2 T1 - // T3 => T1 - // T0 T2 - // T1 T2 - // T2 T3 - // T3 T3 - - for (int col = 0; col < tensor_scale.shape()[1]; ++col) { - for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { - for (int thread_id = 0; thread_id < 4; thread_id++) { - const int dst_idx = row_blk + thread_id * 4; - const int src_idx = row_blk + thread_id * 2; - tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); - tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); - tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); - tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); - } - } - } - } else { - // In all other cases, we don't prepack scale or offset - FAIL() << "Scale prepack only supported for 16b gemm with (1,n) quantization blocking"; - } -} - -template -void prepack_quant_offsets_ref( - size_t rows, - size_t columns, - MatrixRef tensor_offset, - MatrixRef tensor_offset_prepacked) { - // EXPECT_TRUE(tensor_offset.shape()[0] == (rows / QuantBlocking::kRow) && tensor_offset.shape()[1] == (columns / QuantBlocking::kColumn)); - EXPECT_TRUE(tensor_offset_prepacked.shape() == tensor_offset.shape()); - - // Only prepacking scale and offset tensors for a often used special case: - // 16b gemm (2 elements per 32b register, operand tile shape 8x8) - // 2 B operand tiles per mma instruction stacked on k dimension - // (1,n) quantization blocking - if constexpr (QuantBlocking::kRow != 1) { - FAIL() << "Offsets prepack only supported for 16b gemm with (1,n) quantization blocking"; - } - // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread - // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use - // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, - // as shown below (T stands for thread): - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // - // We need to deliver quantization scale and offset elements to the corresponding threads, - // so we can perform dequantization efficiently. With a column major layout, each thread - // needs two separate loads for a mma instruction, due to the tile fragment layout shown - // above. To reduce the number of loads, we rearrange each column as below, so we can use - // a single load to load fragments for two tiles: - // T0 T0 - // T1 T0 - // T2 T1 - // T3 => T1 - // T0 T2 - // T1 T2 - // T2 T3 - // T3 T3 - if (tensor_offset_prepacked.good()) { - for (int col = 0; col < tensor_offset.shape()[1]; ++col) { - for (int row_blk = 0; row_blk < tensor_offset.shape()[0]; row_blk += 16) { - for (int thread_id = 0; thread_id < 4; thread_id++) { - const int dst_idx = row_blk + thread_id * 4; - const int src_idx = row_blk + thread_id * 2; - // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own - // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to - // convert to fp16x2 format in a b32 register - tensor_offset_prepacked.at(dst_idx + 0, col) = tensor_offset.at(src_idx + 0, col); - tensor_offset_prepacked.at(dst_idx + 1, col) = tensor_offset.at(src_idx + 8, col); - tensor_offset_prepacked.at(dst_idx + 2, col) = tensor_offset.at(src_idx + 1, col); - tensor_offset_prepacked.at(dst_idx + 3, col) = tensor_offset.at(src_idx + 9, col); - } - } - } - } -} - -template -void testPrepack(int rows, int columns, bool has_offset = true) { - using ElementT = MLFloat16; - constexpr int block_size = 32; - using Base = onnxruntime::cuda::BlockwiseQuantization< - ElementT, - block_size, - 4, - ColumnMajorQuantBlocking>; - - using QuantBlocking = typename Base::QuantBlocking; - using ElementW = typename Base::ElementW; - using LayoutWPack = typename Base::LayoutWPack; - using ElementQOffset = typename Base::ElementQOffset; - using LayoutQmeta = typename Base::LayoutQmeta; - - unsigned int seed = 28571; // Replace with desired seed value - std::seed_seq seq{seed}; - std::mt19937 gen(seq); - std::uniform_int_distribution<> dis(0, 8192); - - const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); - const auto meta_shape = Base::get_quant_meta_shape(rows, columns); - - // - // For testing quantization and dequantization, it is not straight - // forward to avoid flaky tests due to rounding errors. The way we - // try to achieve this is to: - // 1. Generate a set of quantized weights, scales and offsets - // 2. Dequantize the weights - // 3. Quantize the dequantized weights - // 4. Compare the dequantied-and-then-quantized weights with - // the original quantized weights - // - // Random filling of the initial values are key to get this right. - // For weights, we must ensure each block gets a full range of - // values, i.e. must contain 0 and 15. And for scales, they must - // all be positive. - // - - std::vector q_weights(q_weight_shape.product()); - MatrixRef tensor_q_weight( - q_weights, make_Position(rows / 2, columns)); - int v = 7; - for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { - for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } - - tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); - } - } - - std::vector q_scales(meta_shape.product()); - for (size_t i = 0; i < q_scales.size(); i++) { - q_scales[i] = ElementT(((dis(gen) % 127) + 1) / 32.0f); - } - MatrixRef tensor_scale( - q_scales, meta_shape); - - std::vector q_zp(meta_shape.product()); - for (size_t i = 0; i < q_zp.size(); i++) { - q_zp[i] = dis(gen) % 16; - } - MatrixRef tensor_offset( - q_zp, meta_shape); - -#if 0 // debug - // Fill tensor_q_weight with the patterned data, easier to debug with print - int loop_val = 0; - int offset = 3; - for (int col_tile = 0; col_tile < tensor_q_weight.extent().column()/8; ++col_tile) { - for (int row_tile = 0; row_tile < tensor_q_weight.extent().row()/4; ++row_tile) { - for (int col = 0; col < 8; ++col) { - for (int row = 0; row < 4; ++row) { - auto weight_cord = cutlass::make_Coord(row_tile * 4 + row, col_tile * 8 + col); - auto val = (loop_val + offset) % 256; - tensor_q_weight.at(weight_cord) = ElementW(val); - loop_val++; - if (loop_val == 256) { - loop_val = 0; - offset += 11; - } - } - } - } - } - for (int col = 0; col < tensor_scale.extent().column(); ++col){ - int c = col * QuantBlocking::kColumn; - for (int row = 0; row < tensor_scale.extent().row(); ++row){ - int r = row * QuantBlocking::kRow; - auto weight_cord = cutlass::make_Coord(r/2, c); - int w = 0; - if (r % 2 == 0) { - w = int(tensor_q_weight.at(weight_cord) & 0x0f); - } else { - w = int(tensor_q_weight.at(weight_cord) >> 4); - } - tensor_scale.at({row, col}) = w; - tensor_offset.at({row, col}) = ElementQOffset(w); - } - } - - int fill_val = -512; - int factor = 1; - for (int col = 0; col < tensor_scale.extent().column(); ++col){ - for (int row = 0; row < tensor_scale.extent().row(); ++row){ - tensor_scale.at({row, col}) = ElementQScale((float)fill_val * float(factor)); - fill_val++; - if (fill_val == 512) { - fill_val = -512; - factor += 1; - } - } - } - -#endif // debug - - std::vector dequants(rows * columns); - MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); - - // Dequantize weights and save into matrix B for reference - for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { - for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { - auto weight_cord = make_Position(row / 2, col); - auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); - const uint8_t offset = has_offset ? tensor_offset.at(scale_cord) : 8; - int w = 0; - if (row % 2 == 0) { - w = int(tensor_q_weight.at(weight_cord) & 0x0f); - } else { - w = int(tensor_q_weight.at(weight_cord) >> 4); - } - float scale = float(tensor_scale.at(scale_cord)); - float dequant = scale * float(w - offset); - tensor_dequant.at(row, col) = ElementT(dequant); - // Prints for help debugging in case of test failure - // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); - } - } - - int q_rows, q_cols; - MlasBlockwiseQuantizedShape( - block_size, ColumnMajorQuantBlocking, rows, columns, q_rows, q_cols); - // to be exact, q_rows are padded to multiple of block_size, deal with it when we care about strange shapes - EXPECT_EQ(q_rows, q_weight_shape[0]); - EXPECT_EQ(q_cols, q_weight_shape[1]); - - // - // Quantization tool outputs: - // - std::vector o_elements(q_rows * q_cols); - MatrixRef tensor_o_elements(o_elements, q_weight_shape); - - std::vector o_scales(meta_shape.product()); - MatrixRef tensor_o_scales(o_scales, meta_shape); - - std::vector o_zp(((meta_shape[0] + 1) / 2) * meta_shape[1], true); - MatrixRef tensor_o_zp( - o_zp, make_Position((meta_shape[0] + 1) / 2, meta_shape[1])); - - MlasQuantizeBlockwise(o_elements.data(), o_scales.data(), has_offset ? o_zp.data() : nullptr, - tensor_dequant.data().data(), block_size, - ColumnMajorQuantBlocking, rows, columns, columns, nullptr); - for (int col = 0; col < tensor_q_weight.shape()[1]; ++col) { - for (int row = 0; row < tensor_q_weight.shape()[0]; ++row) { - EXPECT_EQ(tensor_o_elements.at(row, col), tensor_q_weight.at(row, col)) - << "quantized value mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - - for (int col = 0; col < meta_shape[1]; ++col) { - for (int row = 0; row < meta_shape[0]; row += 2) { - if (has_offset) { - uint8_t pair01 = tensor_o_zp.at(row / 2, col); - EXPECT_EQ(tensor_offset.at(row + 0, col), pair01 & 0xf) - << "quantized offset mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - if (row + 1 < meta_shape[0]) { - EXPECT_EQ(tensor_offset.at(row + 1, col), pair01 >> 4) - << "quantized offset mismatch at [" << row + 1 << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - - EXPECT_EQ(tensor_scale.at(row + 0, col), tensor_o_scales.at(row + 0, col)) - << "quantized scale mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - if (row + 1 < meta_shape[0]) { - EXPECT_EQ(tensor_scale.at(row + 1, col), tensor_o_scales.at(row + 1, col)) - << "quantized scale mismatch at [" << row + 1 << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - } - - // - // Now we just setup fp16 weights tensor_dequant, quantized weights tensor_q_weight, - // quantization scale tensor_scale and quantization offset tensor_offset. The above - // testing just make sure our test setup is consistent with quantization tool output. - // - // Next we test the prepack code - // - - std::vector packed_w_ref(q_weight_shape.product()); - MatrixRef tensor_packed_w_ref( - packed_w_ref, make_Position(rows, columns / 2)); - prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref); - - std::vector packed_w(q_weight_shape.product()); - MatrixRef tensor_packed_w( - packed_w, make_Position(rows, columns / 2)); - Base::prepack_weights(rows, columns, o_elements, packed_w); - - for (int col = 0; col < tensor_packed_w.shape()[1]; ++col) { - for (int row = 0; row < tensor_packed_w.shape()[0]; ++row) { - EXPECT_EQ(tensor_packed_w_ref.at(row, col), tensor_packed_w.at(row, col)) - << "prepacked weights mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - - std::vector packed_scales_ref(meta_shape.product()); - MatrixRef tensor_packed_s_ref = - Base::ShouldRearrangeMeta ? make_MatrixRef(packed_scales_ref, meta_shape) - : tensor_scale; - if (Base::ShouldRearrangeMeta) { - prepack_quant_scales_ref( - rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref); - } - - std::vector packed_scales(meta_shape.product()); - MatrixRef tensor_packed_s( - packed_scales, meta_shape); - Base::prepack_quant_scales(rows, columns, o_scales, packed_scales); - - for (int col = 0; col < tensor_packed_s.shape()[1]; ++col) { - for (int row = 0; row < tensor_packed_s.shape()[0]; ++row) { - EXPECT_EQ(tensor_packed_s_ref.at(row, col), tensor_packed_s.at(row, col)) - << "prepacked scales mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - - if (has_offset) { - std::vector packed_zp_ref(meta_shape.product()); - MatrixRef tensor_packed_zp_ref = - Base::ShouldRearrangeMeta ? make_MatrixRef(packed_zp_ref, meta_shape) - : tensor_offset; - if (Base::ShouldRearrangeMeta) { - prepack_quant_offsets_ref( - rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); - } - - std::vector packed_zp(meta_shape.product()); - MatrixRef tensor_packed_zp( - packed_zp, meta_shape); - Base::prepack_quant_offsets(rows, columns, o_zp, packed_zp); - - for (int col = 0; col < tensor_packed_zp.shape()[1]; ++col) { - for (int row = 0; row < tensor_packed_zp.shape()[0]; ++row) { - EXPECT_EQ(tensor_packed_zp_ref.at(row, col), tensor_packed_zp.at(row, col)) - << "prepacked offsets mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - } -} - -// TODO: code runs on CPU, but this is for sm80 only, maybe enable only when test on sm80 -TEST(BlkQ4_GEMM, PrepackSm80Test) { - testPrepack(32, 32); - testPrepack(32, 32, false); - testPrepack(32, 32); - testPrepack(32, 32, false); - testPrepack(32, 64); - testPrepack(32, 128); - testPrepack(32, 256); - testPrepack(64, 32); - testPrepack(128, 32); - testPrepack(256, 32); - testPrepack(256, 256); - testPrepack(32, 128, false); - testPrepack(128, 32, false); - testPrepack(256, 256, false); - testPrepack(32, 64); - testPrepack(32, 128); - testPrepack(32, 256); - testPrepack(64, 32); - testPrepack(128, 32); - testPrepack(256, 32); - testPrepack(256, 256); - testPrepack(32, 128, false); - testPrepack(128, 32, false); - testPrepack(256, 256, false); -} - -} // namespace test -} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc index 5505d689381c9..8dfaaedcbb378 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc @@ -29,7 +29,7 @@ TEST(TestDeferredRelease, WithArena) { AllocatorPtr cpu_pinned_alloc = ep.CreatePreferredAllocators()[1]; // let the CudaStream instance "own" the default stream, so we can avoid the // work to initialize cublas/cudnn/... It is ok since it is just a customized unit test. - CudaStream stream(nullptr, gpu_alloctor->Info().device, cpu_pinned_alloc, false, true, nullptr, nullptr); + CudaStream stream(nullptr, gpu_alloctor->Info().device, cpu_pinned_alloc, false, true, nullptr, nullptr, info); // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; @@ -71,7 +71,7 @@ TEST(TestDeferredRelease, WithoutArena) { // For details, see CUDAPinnedAllocator in cuda_allocator.cc. // let the CudaStream instance "own" the default stream, so we can avoid the // work to initialize cublas/cudnn/... It is ok since it is just a customized unit test. - CudaStream stream(nullptr, gpu_alloctor->Info().device, cuda_pinned_alloc, false, true, nullptr, nullptr); + CudaStream stream(nullptr, gpu_alloctor->Info().device, cuda_pinned_alloc, false, true, nullptr, nullptr, info); // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; From 1e78bcea6011ac43093bb08a647cf3717d73047a Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 5 Mar 2024 13:33:01 -0800 Subject: [PATCH 6/7] Implement CUDA IsInf-10,20 (#19772) ### Description Implment IsInf-10,20 for CUDA. Add FP16 types also on CPU. ### Motivation and Context Certain models lag in performance due to IsInf not available on CUDA. --- docs/OperatorKernels.md | 4 +- .../core/framework/data_types_internal.h | 2 +- .../core/providers/cpu/tensor/isinf.cc | 64 ++++++++++--- .../core/providers/cuda/cu_inc/common.cuh | 94 +++++++++++++++++++ onnxruntime/core/providers/cuda/cuda_common.h | 18 ++++ .../providers/cuda/cuda_execution_provider.cc | 5 + .../cuda/math/unary_elementwise_ops.cc | 38 ++++++++ .../cuda/math/unary_elementwise_ops.h | 12 +++ .../cuda/math/unary_elementwise_ops_impl.cu | 38 ++++++++ .../cuda/math/unary_elementwise_ops_impl.h | 15 +++ .../core/providers/rocm/cu_inc/common.cuh | 94 +++++++++++++++++++ .../providers/rocm/rocm_execution_provider.cc | 9 ++ .../test/providers/cpu/tensor/isinf_test.cc | 42 +++++++++ 13 files changed, 420 insertions(+), 15 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 71b0def659741..4514a85531d6b 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -160,7 +160,7 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)| -|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| |IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| @@ -631,6 +631,8 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index fbeee8a2aedc5..3a3b5cb6888f2 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -305,7 +305,7 @@ class CallableDispatchableHelper { return 0; } - void CheckCalledOnce() { + void CheckCalledOnce() const { ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_); } }; diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc index 1b449f46927a2..9d18d1fa62288 100644 --- a/onnxruntime/core/providers/cpu/tensor/isinf.cc +++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc @@ -23,7 +23,9 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( using IsInfTypesOpset20 = TypeList< float, - double + double, + MLFloat16, + BFloat16 #if !defined(DISABLE_FLOAT8_TYPES) , Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ @@ -76,10 +78,8 @@ ONNX_CPU_OPERATOR_KERNEL( IsInf); IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) { - Status status = info.GetAttr("detect_positive", &detect_positive_); - ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive"); - status = info.GetAttr("detect_negative", &detect_negative_); - ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative"); + detect_positive_ = info.GetAttrOrDefault("detect_positive", 1); + detect_negative_ = info.GetAttrOrDefault("detect_negative", 1); opset_ = info.node().SinceVersion(); } @@ -87,29 +87,67 @@ namespace isinf_internal { template struct ComputeDispatchTarget { void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { - const auto total_items = X.Shape().Size(); + auto input_data = X.DataAsSpan(); auto output_data = Y.MutableData(); if (detect_positive && detect_negative) { EigenMap(Y) = EigenMap(X).array().isInf(); } else if (detect_positive) { - auto input_data = X.Data(); - auto end_data = input_data + total_items; std::transform( - input_data, end_data, output_data, [](T v) { + input_data.begin(), input_data.end(), output_data, [](T v) { return (v == std::numeric_limits::infinity()); }); } else if (detect_negative) { - auto input_data = X.Data(); - auto end_data = input_data + total_items; std::transform( - input_data, end_data, output_data, [](T v) { + input_data.begin(), input_data.end(), output_data, [](T v) { return (v == -std::numeric_limits::infinity()); }); } else { // all false - memset(output_data, false, onnxruntime::narrow(total_items)); + memset(output_data, false, input_data.size()); + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { + auto output_data = Y.MutableData(); + auto input_data = X.DataAsSpan(); + if (detect_positive && detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsInfinity(); }); + } else if (detect_positive) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsPositiveInfinity(); }); + } else if (detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsNegativeInfinity(); }); + } else { + // all false + memset(output_data, false, input_data.size()); + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { + auto output_data = Y.MutableData(); + auto input_data = X.DataAsSpan(); + if (detect_positive && detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsInfinity(); }); + } else if (detect_positive) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsPositiveInfinity(); }); + } else if (detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsNegativeInfinity(); }); + } else { + // all false + memset(output_data, false, input_data.size()); } } }; diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index 66794f88d8670..bba9178348132 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -438,6 +438,100 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) { return fmodf((float)a, (float)b); } +namespace isinf_details { +template +struct IsInfTyped { + static __device__ __inline__ bool IsInf(T a) { + // cast is needed because on non MS compilers, + // because there isinf() returns int + // and we want to avoid stupid warnings + return static_cast(isinf(a)); + } + static __device__ __inline__ bool IsInfPos(T a) { + return a == std::numeric_limits::infinity(); + } + static __device__ __inline__ bool IsInfNeg(T a) { + return a == -std::numeric_limits::infinity(); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(half a) { + return MLFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(half a) { + return MLFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(half a) { + return MLFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(BFloat16 a) { + return BFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template +struct ReturnFalse { + constexpr static bool __device__ __inline__ IsInf(T) { return false; } + constexpr static bool __device__ __inline__ IsInfPos(T) { return false; } + constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(Float8E5M2 a) { + return a.val == 0b01111100 || a.val == 0b11111100; + } + static __device__ __inline__ bool IsInfPos(Float8E5M2 a) { + return a.val == 0b01111100; + } + static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) { + return a.val == 0b11111100; + } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +#endif +} // namespace isinf_details + +template +struct _IsInf { + __device__ __inline__ bool operator()(T a) const { + if constexpr (detect_positive && detect_negative) { + return isinf_details::IsInfTyped::IsInf(a); + } else if constexpr (detect_positive) { + return isinf_details::IsInfTyped::IsInfPos(a); + } else if constexpr (detect_negative) { + return isinf_details::IsInfTyped::IsInfNeg(a); + } else { + return false; + } + } +}; + // We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. #ifndef CUDA_LONG diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 41c999bacee13..61da125b40953 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -70,6 +70,15 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef Float8E4M3FNUZ MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + template <> class ToCudaType { public: @@ -79,6 +88,15 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef Float8E5M2FNUZ MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + #endif inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 8ba282031a5d4..3c0930638a205 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -830,6 +830,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 19, IsInf); // opset 11 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress); @@ -1342,6 +1343,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, S class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -1739,6 +1741,8 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 11 BuildKernelCreateInfo, @@ -2250,6 +2254,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index fd8b69d7bd2f5..00de1b37f3302 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -71,6 +71,44 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa return Status::OK(); \ } +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + IsInf, + kOnnxDomain, + 10, + 19, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsInf); + +ONNX_OPERATOR_KERNEL_EX( + IsInf, + kOnnxDomain, + 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsInf); + +IsInf::IsInf(const OpKernelInfo& info) : UnaryElementwise(info) { + detect_positive_ = static_cast(info.GetAttrOrDefault("detect_positive", 1)); + detect_negative_ = static_cast(info.GetAttrOrDefault("detect_negative", 1)); + opset_ = info.node().SinceVersion(); +} + +Status IsInf::ComputeInternal(OpKernelContext* context) const { + UnaryElementwisePreparation p; + ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); + + Explicit_Impl_IsInf(Stream(context), opset_, detect_positive_, detect_negative_, + p.input_tensor->GetElementType(), p.input_tensor->DataRaw(), + p.output_tensor->MutableData(), + p.input_tensor->Shape().Size()); + return Status::OK(); +} + #define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \ UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 775b78c43a736..3b7d6df7221b7 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once + #include "core/providers/cuda/cuda_kernel.h" namespace onnxruntime { @@ -119,5 +120,16 @@ class Sign final : public UnaryElementwise { Status ComputeInternal(OpKernelContext* context) const override; }; +class IsInf final : public UnaryElementwise { + public: + explicit IsInf(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + bool detect_positive_{true}; + bool detect_negative_{true}; + int opset_; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 73c5ac80756be..fd8f7929d4426 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -11,6 +11,7 @@ #endif namespace onnxruntime { + namespace cuda { #define OP(name, expr) \ @@ -284,5 +285,42 @@ EXPLICIT_IMPL_CASTSAT(__nv_bfloat16, Float8E5M2) #endif +namespace isinf_details { +template +struct IsInf_DispFunc { + void operator()(cudaStream_t stream, const void* input_raw, bool* output_data, + bool detect_positive, bool detect_negative, size_t count) const { + using CudaType = typename ToCudaType::MappedType; + const auto* input_data = reinterpret_cast(input_raw); + if (detect_positive && detect_negative) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } else if (detect_positive) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } else if (detect_negative) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } else { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } + } +}; + +} // namespace isinf_details + +void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, + bool detect_positive, bool detect_negative, + int32_t input_data_type, + const void* input_raw, bool* output_data, + size_t count) { + if (op_set < 20) { + utils::MLTypeCallDispatcher dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, + detect_positive, detect_negative, count); + } else { + utils::MLTypeCallDispatcher dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, + detect_positive, detect_negative, count); + } +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index 608a81a24cf4f..a606d479bc79b 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -137,5 +137,20 @@ void Impl_CastSat( #endif +// IsInf + +#if !defined(DISABLE_FLOAT8_TYPES) +#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \ + Float8E5M2FNUZ +#else +#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16 +#endif + +void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, + bool detect_positive, bool detect_negative, + int32_t input_data_type, + const void* input_raw, bool* output_data, + size_t count); } // namespace cuda + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index 5f966ac746fcb..f3685606c17f5 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -335,6 +335,100 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) { return fmodf((float)a, (float)b); } +namespace isinf_details { +template +struct IsInfTyped { + static __device__ __inline__ bool IsInf(T a) { + // cast is needed because on non MS compilers, + // because there isinf() returns int + // and we want to avoid stupid warnings + return static_cast(isinf(a)); + } + static __device__ __inline__ bool IsInfPos(T a) { + return a == std::numeric_limits::infinity(); + } + static __device__ __inline__ bool IsInfNeg(T a) { + return a == -std::numeric_limits::infinity(); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(half a) { + return MLFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(half a) { + return MLFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(half a) { + return MLFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(BFloat16 a) { + return BFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template +struct ReturnFalse { + constexpr static bool __device__ __inline__ IsInf(T) { return false; } + constexpr static bool __device__ __inline__ IsInfPos(T) { return false; } + constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(Float8E5M2 a) { + return a.val == 0b01111100 || a.val == 0b11111100; + } + static __device__ __inline__ bool IsInfPos(Float8E5M2 a) { + return a.val == 0b01111100; + } + static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) { + return a.val == 0b11111100; + } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +#endif +} // namespace isinf_details + +template +struct _IsInf { + __device__ __inline__ bool operator()(T a) const { + if constexpr (detect_positive && detect_negative) { + return isinf_details::IsInfTyped::IsInf(a); + } else if constexpr (detect_positive) { + return isinf_details::IsInfTyped::IsInfPos(a); + } else if constexpr (detect_negative) { + return isinf_details::IsInfTyped::IsInfNeg(a); + } else { + return false; + } + } +}; + // We would like to use 64-bit integer to support large matrices. However, ROCM seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. #ifndef HIP_LONG diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 0265c06b9a938..4a679b790ee40 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -793,6 +793,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, TopK); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, Mod); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 19, IsInf); // opset 11 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax); @@ -1342,6 +1343,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, R class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Scan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Shape); +// Opset 20 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf); + template <> KernelCreateInfo BuildKernelCreateInfo() { return {}; @@ -1738,6 +1742,8 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 11 BuildKernelCreateInfo, @@ -2294,6 +2300,9 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // opset 20 + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc index 2e583c5d2547b..bd97306142f18 100644 --- a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc @@ -99,6 +99,48 @@ TEST(IsInfTest, test_isinf_negative_double20) { run_is_inf_test(20, 0, 1, input, output); } +TEST(IsInfTest, test_isinf_mlfloat16) { + std::initializer_list input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16, + MLFloat16::NegativeInfinity, MLFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(20, 1, 1, input, output); +} + +TEST(IsInfTest, test_isinf_positive_mlfloat16) { + std::initializer_list input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16, + MLFloat16::NegativeInfinity, MLFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} + +TEST(IsInfTest, test_isinf_negative_mlfloat16) { + std::initializer_list input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16, + MLFloat16::NegativeInfinity, MLFloat16::Infinity}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(20, 0, 1, input, output); +} + +TEST(IsInfTest, test_isinf_bfloat16) { + std::initializer_list input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16, + BFloat16::NegativeInfinity, BFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(20, 1, 1, input, output); +} + +TEST(IsInfTest, test_isinf_positive_bfloat16) { + std::initializer_list input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16, + BFloat16::NegativeInfinity, BFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} + +TEST(IsInfTest, test_isinf_negative_bfloat16) { + std::initializer_list input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16, + BFloat16::NegativeInfinity, BFloat16::Infinity}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(20, 0, 1, input, output); +} + #if !defined(DISABLE_FLOAT8_TYPES) TEST(IsInfTest, test_Float8E4M3FN) { std::initializer_list input = { From d9730c7f43437070eba28d8dcdd9f94c102265ab Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Tue, 5 Mar 2024 14:39:36 -0800 Subject: [PATCH 7/7] [TensorRT EP] Fix bug for DDS output handling for empty tensor (#19575) When the DDS output is empty tensor (i.e. any of the dimension is 0), TRT EP won't perform either cudaMemcpyAsync() nor cuda::Impl_Cast(), to prevent accidentally overwriting other location that might belong to other tensors. This PR also refactors the code to only allocate single bytes for all empty tensors. #TODO: add unit tests to cover the DDS code paths or doing more testing with concurrent,sequential, threaded faster-rcnn using onnx_test_runner and verifying outputs --------- Co-authored-by: Chi Lo --- cmake/deps.txt | 4 +- .../tensorrt/tensorrt_execution_provider.cc | 465 ++++++------------ 2 files changed, 160 insertions(+), 309 deletions(-) diff --git a/cmake/deps.txt b/cmake/deps.txt index 9cba25b00157d..9630b6185fcf6 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -37,8 +37,8 @@ mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 -#use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 +#use the commit of Final DDS removal. DDS output is now supported by ORT TRT. +onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/bacfaaa951653cd4e72efe727a543567cb38f7de.zip;26434329612e804164ab7baa6ae629ada56c1b26 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 81346671f2aad..157cd0a200b35 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -717,6 +717,77 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector(); \ + if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ + data = const_cast(input_tensor_ptr); \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + data = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_CAST_INPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto input_tensor_ptr = input_tensor.GetTensorData(); \ + if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + data = scratch_buffers.back().get(); \ + cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + data = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + buffers[output_name] = output_tensor_ptr; \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + buffers[output_name] = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_CAST_OUTPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = static_cast(elem_cnt); \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = 1; \ + } \ + break; \ + } + +#define CASE_COPY_TENSOR(DATA_TYPE, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(DstT), cudaMemcpyDeviceToDevice, stream)); \ + } \ + break; \ + } + +#define CASE_CAST_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), elem_cnt); \ + } \ + break; \ + } + /* * Set TensorRT execution context input. * @@ -737,6 +808,17 @@ Status BindContextInput(Ort::KernelContext& ctx, auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shapes = tensor_info.GetShape(); const auto tensor_type = tensor_info.GetElementType(); + /* + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+ * [] = 1
+ * [1,3,4] = 12
+ * [2,0,4] = 0
+ * [-1,3,4] = -1
+ */ + const auto elem_cnt = tensor_info.GetElementCount(); if (trt_engine->isShapeInferenceIO(input_name)) { // Get the shape value of "shape tensor" @@ -765,113 +847,24 @@ Status BindContextInput(Ort::KernelContext& ctx, ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'")); } - // Bind "execution tensor" input buffers + + // Bind "execution tensor" input buffer + // + // Note: If an engine binding is an empty tensor, it still needs a non-null memory address, and different tensors should have different addresses. + // Therefore, in the case of empty tensor, TRT EP always allocates a dummy byte. + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#empty-tensors void* data = nullptr; switch (tensor_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint16_t))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(bool))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int8_t))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint8_t))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - data = scratch_buffers.back().get(); - } else { - SafeInt input_dim_size = 1; - for (int j = 0, end = nb_dims; j < end; ++j) { - if (tensor_shapes[j] == 0) { - input_dim_size = 1; - break; - } else { - input_dim_size *= tensor_shapes[j]; - } - } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, input_dim_size * sizeof(int32_t))); - data = scratch_buffers.back().get(); - cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), input_dim_size); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - // Cast DOUBLE input to FLOAT because TensorRT doesn't fully support INT64 - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - data = scratch_buffers.back().get(); - } else { - SafeInt input_dim_size = 1; - for (int j = 0, end = nb_dims; j < end; ++j) { - if (tensor_shapes[j] == 0) { - input_dim_size = 1; - break; - } else { - input_dim_size *= tensor_shapes[j]; - } - } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, input_dim_size * sizeof(float))); - data = scratch_buffers.back().get(); - cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), input_dim_size); - } - break; - } + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) + // Cast int64 input to int32 input because TensorRT doesn't support int64 + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) + // Cast double input to float because TensorRT doesn't support double + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); @@ -884,7 +877,7 @@ Status BindContextInput(Ort::KernelContext& ctx, } /* - * Set TensorRT execution context output. + * Bind TensorRT execution context output. * * Please note that the "data-depedent shape" output needs corresponding allocator provided. * @@ -912,7 +905,6 @@ Status BindContextOutput(Ort::KernelContext& ctx, size_t i, std::unordered_map& output_tensors, std::unordered_map& output_dim_sizes, - std::unordered_set& dds_output_set, DDSOutputAllocatorMap& dds_output_allocator_map, std::vector>& scratch_buffers, OrtAllocator* alloc, @@ -920,142 +912,47 @@ Status BindContextOutput(Ort::KernelContext& ctx, // Get output shape nvinfer1::Dims dims = trt_context->getTensorShape(output_name); int nb_dims = dims.nbDims; - bool is_dds_output = false; + bool is_DDS = false; std::vector output_shapes(nb_dims); for (int j = 0, end = nb_dims; j < end; ++j) { // data-dependent shape if (dims.d[j] == -1) { - is_dds_output = true; - dds_output_set.emplace(output_name); + is_DDS = true; break; } output_shapes[j] = dims.d[j]; } + auto known_DDS = dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end(); + // If the output tensor has data-dependent shape, TRT EP will provide an IOutputAllocator for enqueueV3 to dynamically allocate memory buffer. // Once enqueueV3 returns, TRT EP will then bind the output allocation to ORT kernel context output. // (Please note that we take strategy A mentioned in https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#dynamic-shaped-output, // which we defer allocation until the size is known and don't call IExecution::setTensorAddress) // // Otherwise, if the shape of the output tensor is known prior to the runtime, ORT will pre-allocate memory buffer for the output tensor for enqueueV3. - if (is_dds_output) { - if (dds_output_allocator_map.find(output_name) == dds_output_allocator_map.end()) { + if (is_DDS || known_DDS) { + if (!known_DDS) { auto allocatorPtr = std::make_unique(); trt_context->setOutputAllocator(output_name, allocatorPtr.get()); dds_output_allocator_map[output_name] = std::move(allocatorPtr); - } else { - trt_context->setOutputAllocator(output_name, dds_output_allocator_map[output_name].get()); } } else { output_tensors[i] = ctx.GetOutput(output_index, output_shapes); auto& output_tensor = output_tensors[i]; + const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + switch (output_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint16_t))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(bool))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int8_t))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint8_t))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - // Allocate INT32 CUDA memory for INT64 output type because TensorRT doesn't fully support INT64 - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - buffers[output_name] = scratch_buffers.back().get(); - output_dim_sizes[i] = 1; - } else { - SafeInt output_dim_size(1); - for (int j = 0, end = nb_dims; j < end; ++j) { - if (dims.d[j] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= dims.d[j]; - } - } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(int32_t))); - buffers[output_name] = scratch_buffers.back().get(); - output_dim_sizes[i] = output_dim_size; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - // Allocate FLOAT CUDA memory for DOUBLE output type because TensorRT doesn't fully support DOUBLE - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - buffers[output_name] = scratch_buffers.back().get(); - output_dim_sizes[i] = 1; - } else { - SafeInt output_dim_size(1); - for (int j = 0, end = nb_dims; j < end; ++j) { - if (dims.d[j] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= dims.d[j]; - } - } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(float))); - buffers[output_name] = scratch_buffers.back().get(); - output_dim_sizes[i] = output_dim_size; - } - break; - } + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) + // Allocate int32 CUDA memory for int64 output type because TensorRT doesn't support int64 + CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) + // Allocate float CUDA memory for double output type because TensorRT doesn't support double + CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported."); @@ -1068,10 +965,13 @@ Status BindContextOutput(Ort::KernelContext& ctx, } /* - * Set ORT kernel context Output. + * Bind ORT kernel context Output. * - * Note: In the case of DDS (data-dependent shape) output, TRT requires a provided allocator to allocate memory during runtime. + * In the case of DDS (data-dependent shape) output, TRT requires a provided allocator to allocate memory during runtime. * Once the output has been put in the allocation buffer, ORT calls this function to bind the allocation to ORT kernel context output. + * + * Note: Current approach of setting the ORT kernel context output is copying the output data from allocation buffer to ORT context output address which is not optimal, + * we are waiting for ORT core to support "assign" memory address to ORT context output. Some works need to be done in ORT memory planner to be aware of this memory support. */ Status BindKernelOutput(Ort::KernelContext& ctx, OrtMemoryInfo* mem_info, @@ -1083,93 +983,46 @@ Status BindKernelOutput(Ort::KernelContext& ctx, auto allocator = allocator_map[output_name].get(); auto& shape = allocator->getOutputShape(); auto output_tensor = ctx.GetOutput(output_index, shape); + + /* + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+ * [] = 1
+ * [1,3,4] = 12
+ * [2,0,4] = 0
+ * [-1,3,4] = -1
+ */ auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + /* + * Copy output data from allocation buffer to ORT kernel context output location or + * cast (int32 or float) -> (int64 or double) to ORT kernel context output location. + * + * Note: + * 1. If the output tensor is empty tensor (i.e. any of the dimension is 0) which means element count is 0, + * TRT EP does not perform cuda memory copy nor cuda cast to prevent overwriting other location that might belong to other tensors. + * 2. The cudaMemcpyAsync() and cuda::Impl_Cast() (implemented as _UnaryElementWise() in cuda ep) are all async, but we + * don't need to explicitly call cudaStreamSynchronize() after those APIs due to CUDA EP and TRT EP uses same stream, + * and within the same stream, operations are guaranteed to be executed in order. + */ switch (output_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(float), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(uint16_t), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(bool), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(int8_t), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(uint8_t), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(int32_t), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - // The allocation buffer holds the INT32 output data since TRT doesn't support INT64 but INT32. - // So, we need to cast the data from INT32 to INT64 and then set INT64 output data to kernel context. - SafeInt output_dim_size(1); - for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= shape[i]; - } - } - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), output_dim_size); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - // The allocation buffer holds the FLOAT output data since TRT doesn't support DOUBLE but FLOAT. - // So, we need to cast the data from FLOAT to DOUBEL and then set DOUBLE output data to kernel context. - SafeInt output_dim_size(1); - for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= shape[i]; - } - } - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), output_dim_size); - } - break; - } + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) + // The allocation buffer holds the int32 output data since TRT doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output. + CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int32_t, int64_t) + // The allocation buffer holds the float output data since TRT doesn't support double. So, we need to cast the data (float -> double) for ORT kernel output. + CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported."); } } - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); return Status::OK(); } @@ -3513,7 +3366,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView output_tensors.reserve(num_outputs); std::unordered_map output_dim_sizes; output_dim_sizes.reserve(num_outputs); - std::unordered_set dds_output_set; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -3531,7 +3383,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_set, dds_output_allocator_map, scratch_buffers, alloc, buffers); + dds_output_allocator_map, scratch_buffers, alloc, buffers); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } @@ -3590,7 +3442,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView output_type = iter->second; } - if (dds_output_set.find(output_name) != dds_output_set.end()) { + if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { size_t output_index = 0; const auto& index_iter = output_indexes.find(output_name); if (index_iter != output_indexes.end()) { @@ -3806,7 +3658,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con output_tensors.reserve(num_outputs); std::unordered_map output_dim_sizes; output_dim_sizes.reserve(num_outputs); - std::unordered_set dds_output_set; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -3824,7 +3675,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con } Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_set, dds_output_allocator_map, scratch_buffers, alloc, buffers); + dds_output_allocator_map, scratch_buffers, alloc, buffers); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } @@ -3883,7 +3734,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con output_type = iter->second; } - if (dds_output_set.find(output_name) != dds_output_set.end()) { + if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { size_t output_index = 0; const auto& index_iter = output_indexes.find(output_name); if (index_iter != output_indexes.end()) {