diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h
index ab59cc2c59b75..6ea8b55505214 100644
--- a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h
+++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h
@@ -83,76 +83,10 @@ inline void sm80_prepack_quant_scales_ref(
   // 16b gemm (2 elements per 32b register, operand tile shape 8x8)
   // 2 B operand tiles per mma instruction stacked on k dimension
   // (1,n) quantization blocking
-  if constexpr (sizeof(ScaleElementT) == 2 && QuantBlocking::kRow == 1) {
-    // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread
-    // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use
-    // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension,
-    // as shown below (T stands for thread):
-    // T0, T4, T8, T12
-    // T1, T5, T9, T13
-    // T2, T6, T10, T14
-    // T3, T7, T11, T15
-    // T0, T4, T8, T12
-    // T1, T5, T9, T13
-    // T2, T6, T10, T14
-    // T3, T7, T11, T15
-    //
-    // We need to deliver quantization scale and offset elements to the corresponding threads,
-    // so we can perform dequantization efficiently. With a column major layout, each thread
-    // needs two seperate loads for a mma instruction, due to the tile fragment layout shown
-    // above. To reduce the number of loads, we rearrange each column as below, so we can use
-    // a single load to load fragments for two tiles:
-    // T0        T0
-    // T1        T0
-    // T2        T1
-    // T3   =>   T1
-    // T0        T2
-    // T1        T2
-    // T2        T3
-    // T3        T3
-
-    for (int col = 0; col < tensor_scale.shape()[1]; ++col) {
-      for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) {
-        for (int thread_id = 0; thread_id < 4; thread_id++) {
-          const int dst_idx = row_blk + thread_id * 4;
-          const int src_idx = row_blk + thread_id * 2;
-          tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col);
-          tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col);
-          tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col);
-          tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col);
-        }
-      }
-    }
-  } else {
-    // In all other cases, we don't prepack scale or offset
-    std::copy(tensor_scale.data().begin(), tensor_scale.data().end(), tensor_scale_prepacked.data().begin());
+  if constexpr (sizeof(ScaleElementT) != 2 || QuantBlocking::kRow != 1) {
+    ORT_THROW("sm80_prepack_quant_scales_ref should only be called for row-wise block quantization on 16b float values.");
   }
-}
 
-template
-inline void sm80_expand_prepack_quant_offsets_ref(
-    int rows,
-    int columns,
-    MatrixRef tensor_offset,
-    MatrixRef tensor_offset_prepacked) {
-  const auto meta_shape = make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn);
-  const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]);
-  ORT_ENFORCE(tensor_offset_prepacked.shape() == meta_shape,
-              "Unexpected tensor_offset_prepacked shape (",
-              tensor_offset_prepacked.shape()[0], ",", tensor_offset_prepacked.shape()[1],
-              ")! Expected: (", meta_shape[0], ", ", meta_shape[1], ")");
-  ORT_ENFORCE(tensor_offset.shape() == zp_shape,
-              "Unexpected tensor_offset shape (",
-              tensor_offset.shape()[0], ",", tensor_offset.shape()[1],
-              ")! Expected: (", zp_shape[0], ", ", zp_shape[1], ")");
-
-  // Only prepacking scale and offset tensors for a often used special case:
-  // 16b gemm (2 elements per 32b register, operand tile shape 8x8)
-  // 2 B operand tiles per mma instruction stacked on k dimension
-  // (1,n) quantization blocking
-  if constexpr (QuantBlocking::kRow != 1) {
-    return;
-  }
   // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread
   // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use
   // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension,
@@ -168,7 +102,7 @@ inline void sm80_expand_prepack_quant_offsets_ref(
   //
   // We need to deliver quantization scale and offset elements to the corresponding threads,
   // so we can perform dequantization efficiently. With a column major layout, each thread
-  // needs two seperate loads for a mma instruction, due to the tile fragment layout shown
+  // needs two separate loads for a mma instruction, due to the tile fragment layout shown
   // above. To reduce the number of loads, we rearrange each column as below, so we can use
   // a single load to load fragments for two tiles:
   // T0        T0
@@ -179,22 +113,16 @@ inline void sm80_expand_prepack_quant_offsets_ref(
   // T1        T2
   // T2        T3
   // T3        T3
-  if (tensor_offset_prepacked.good()) {
-    for (int col = 0; col < tensor_offset_prepacked.shape()[1]; ++col) {
-      for (int row_blk = 0; row_blk < tensor_offset_prepacked.shape()[0]; row_blk += 16) {
-        for (int thread_id = 0; thread_id < 4; thread_id++) {
-          const int dst_idx = row_blk + thread_id * 4;
-          const int src_idx = row_blk + thread_id * 2;
-          // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own
-          // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to
-          // convert to fp16x2 format in a b32 register
-          uint8_t pair01 = tensor_offset.at(src_idx / 2, col);
-          uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col);
-          tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf;
-          tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf;
-          tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4;
-          tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4;
-        }
+
+  for (int col = 0; col < tensor_scale.shape()[1]; ++col) {
+    for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) {
+      for (int thread_id = 0; thread_id < 4; thread_id++) {
+        const int dst_idx = row_blk + thread_id * 4;
+        const int src_idx = row_blk + thread_id * 2;
+        tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col);
+        tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col);
+        tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col);
+        tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col);
       }
     }
   }
@@ -206,18 +134,23 @@ inline void sm80_prepack_quant_offsets_ref(
     int columns,
     MatrixRef tensor_offset,
     MatrixRef tensor_offset_prepacked) {
-  ORT_ENFORCE(tensor_offset.shape()[0] == (rows / QuantBlocking::kRow) && tensor_offset.shape()[1] == (columns / QuantBlocking::kColumn),
-              "Unexpected tensor_offset shape! Expected: (",
-              rows / QuantBlocking::kRow, ", ", columns / QuantBlocking::kColumn, ")");
-  ORT_ENFORCE(tensor_offset_prepacked.shape() == tensor_offset.shape());
+  const auto meta_shape = make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn);
+  const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]);
+  ORT_ENFORCE(tensor_offset_prepacked.shape() == meta_shape,
+              "Unexpected tensor_offset_prepacked shape (",
+              tensor_offset_prepacked.shape()[0], ",", tensor_offset_prepacked.shape()[1],
+              ")! Expected: (", meta_shape[0], ", ", meta_shape[1], ")");
+  ORT_ENFORCE(tensor_offset.shape() == zp_shape,
+              "Unexpected tensor_offset shape (",
+              tensor_offset.shape()[0], ",", tensor_offset.shape()[1],
+              ")! Expected: (", zp_shape[0], ", ", zp_shape[1], ")");
 
   // Only prepacking scale and offset tensors for a often used special case:
   // 16b gemm (2 elements per 32b register, operand tile shape 8x8)
   // 2 B operand tiles per mma instruction stacked on k dimension
   // (1,n) quantization blocking
   if constexpr (QuantBlocking::kRow != 1) {
-    std::copy(tensor_offset.data().begin(), tensor_offset.data().end(), tensor_offset_prepacked.data().begin());
-    return;
+    ORT_THROW("sm80_prepack_quant_offsets_ref should only be called for row-wise block quantization.");
   }
   // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread
   // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use
@@ -234,7 +167,7 @@ inline void sm80_prepack_quant_offsets_ref(
   //
   // We need to deliver quantization scale and offset elements to the corresponding threads,
   // so we can perform dequantization efficiently. With a column major layout, each thread
-  // needs two seperate loads for a mma instruction, due to the tile fragment layout shown
+  // needs two separate loads for a mma instruction, due to the tile fragment layout shown
   // above. To reduce the number of loads, we rearrange each column as below, so we can use
   // a single load to load fragments for two tiles:
   // T0        T0
@@ -246,18 +179,20 @@ inline void sm80_prepack_quant_offsets_ref(
   // T2        T3
   // T3        T3
   if (tensor_offset_prepacked.good()) {
-    for (int col = 0; col < tensor_offset.shape()[1]; ++col) {
-      for (int row_blk = 0; row_blk < tensor_offset.shape()[0]; row_blk += 16) {
+    for (int col = 0; col < tensor_offset_prepacked.shape()[1]; ++col) {
+      for (int row_blk = 0; row_blk < tensor_offset_prepacked.shape()[0]; row_blk += 16) {
         for (int thread_id = 0; thread_id < 4; thread_id++) {
           const int dst_idx = row_blk + thread_id * 4;
          const int src_idx = row_blk + thread_id * 2;
           // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own
           // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to
           // convert to fp16x2 format in a b32 register
-          tensor_offset_prepacked.at(dst_idx + 0, col) = tensor_offset.at(src_idx + 0, col);
-          tensor_offset_prepacked.at(dst_idx + 1, col) = tensor_offset.at(src_idx + 8, col);
-          tensor_offset_prepacked.at(dst_idx + 2, col) = tensor_offset.at(src_idx + 1, col);
-          tensor_offset_prepacked.at(dst_idx + 3, col) = tensor_offset.at(src_idx + 9, col);
+          uint8_t pair01 = tensor_offset.at(src_idx / 2, col);
+          uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col);
+          tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf;
+          tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf;
+          tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4;
+          tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4;
         }
       }
     }
diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc
index 897cf3fc774d3..f987c4a7c507d 100644
--- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc
+++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc
@@ -198,7 +198,7 @@ void testPrepack(int rows, int columns) {
   MatrixRef tensor_packed_zp_ref = make_MatrixRef(packed_zp_ref, meta_shape);
 
   if constexpr (Base::ShouldRearrangeMeta) {
-    onnxruntime::test::sm80_expand_prepack_quant_offsets_ref(
+    onnxruntime::test::sm80_prepack_quant_offsets_ref(
         rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref);
   } else {
     for (int col = 0; col < meta_shape[1]; ++col) {
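Note: the rearrangement performed by the two reference helpers in the patch can be hard to follow through the MatrixRef test utilities. The sketch below restates the same per-column logic on plain std::vector buffers; it is illustrative only, and the names interleave_scales_col, expand_zero_points_col, and block_rows do not exist in the repository.

#include <cstddef>
#include <cstdint>
#include <vector>

// Reorder one column of per-block scales so each 16-row group is interleaved as
// [0,1,8,9, 2,3,10,11, 4,5,12,13, 6,7,14,15]: thread t of the quad then loads the
// fragments of two stacked 8x8 B tiles with a single contiguous load.
std::vector<float> interleave_scales_col(const std::vector<float>& scale_col) {
  std::vector<float> out(scale_col.size());
  for (std::size_t row_blk = 0; row_blk + 16 <= scale_col.size(); row_blk += 16) {
    for (int thread_id = 0; thread_id < 4; ++thread_id) {
      const std::size_t dst = row_blk + thread_id * 4;
      const std::size_t src = row_blk + thread_id * 2;
      out[dst + 0] = scale_col[src + 0];
      out[dst + 1] = scale_col[src + 1];
      out[dst + 2] = scale_col[src + 8];
      out[dst + 3] = scale_col[src + 9];
    }
  }
  return out;
}

// Expand one column of packed 4-bit zero points (two per byte, low nibble first)
// into one byte each, applying the same interleave plus the [a, b, c, d] =>
// [a, c, b, d] swap so adjacent values land in separate 16-bit lanes.
std::vector<uint8_t> expand_zero_points_col(const std::vector<uint8_t>& packed_zp_col,
                                            std::size_t block_rows) {
  std::vector<uint8_t> out(block_rows);
  for (std::size_t row_blk = 0; row_blk + 16 <= block_rows; row_blk += 16) {
    for (int thread_id = 0; thread_id < 4; ++thread_id) {
      const std::size_t dst = row_blk + thread_id * 4;
      const std::size_t src = row_blk + thread_id * 2;
      const uint8_t pair01 = packed_zp_col[src / 2];        // zero points src, src+1
      const uint8_t pair89 = packed_zp_col[(src + 8) / 2];  // zero points src+8, src+9
      out[dst + 0] = pair01 & 0xf;
      out[dst + 1] = pair89 & 0xf;
      out[dst + 2] = pair01 >> 4;
      out[dst + 3] = pair89 >> 4;
    }
  }
  return out;
}

For a 16-row block this reproduces the patch's behavior: the scale path copies source rows 0,1,8,9 / 2,3,10,11 / ... into consecutive destination rows, matching the thread fragment diagram in the comments, and the zero-point path additionally unpacks two nibbles per byte so the prepacked tensor holds one byte per quantization block, as the new zp_shape check in sm80_prepack_quant_offsets_ref expects.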