diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h index dfd1032b42c68..8b6bac8c5099a 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h @@ -327,7 +327,7 @@ class QuantBMmaBase { typename Operator::IteratorB warp_tile_iterator_B_; /// Iterator to load a warp-scoped tile of quant scales from shared memory - typename Operator::IteratorQScale warp_tile_iterator_QScale_; + typename Operator::IteratorQMeta warp_tile_iterator_QScale_; public: diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h index 5d05016b8693a..c142ddb132629 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -197,7 +197,7 @@ class QuantBMetaMmaTile{ // 16b gemm (kNumBsPerCoreTileFragement == 2) // 2 B operand tiles per mma (kBTilesPerMma == 2) // (1,n) quantization blocking - // The weight and offset tensor is prepacked to reduce load instructions. + // The scale and offset tensors are prepacked to reduce the number of load instructions. return make_Coord((lane_id % CoreTile::kContiguous) * 4, lane_id / CoreTile::kContiguous); } else { @@ -356,7 +356,7 @@ class QuantBMetaMmaTensorOpTileIterator -struct ConvertAndPack { - - using Converter = NumericArrayConverter; - - CUTLASS_HOST_DEVICE - Array operator()(Array const &source) { - Converter converter; - - return converter(source); - } -}; - -template -struct ConvertAndPack { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &source) { - return source; - } -}; - -template -struct ConvertAndPack { - - using Converter = NumericArrayConverter; - - CUTLASS_HOST_DEVICE - Array operator()(Array const &source) { - Converter converter; - - Array tmp; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); - tmp[i] = source[idx]; - } - - return converter(tmp); - } -}; - -template -struct ConvertAndPack { - - using Converter = NumericArrayConverter; - - CUTLASS_HOST_DEVICE - Array operator()(Array const &source) { - Converter converter; - - Array tmp; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); - tmp[i] = source[idx]; - } - - return converter(tmp); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace internal - -///////////////////////////////////////////////////////////////////////////////////////////////// - /// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. template < /// Size of the Gemm problem - concept: gemm::GemmShape<> @@ -292,13 +217,13 @@ class QuantBMmaTensorOp { // TODO This is an expanding iterator, it needs to replicate the quantization parameters // to all threads in the warp. - using IteratorQScale = QuantBMetaMmaTensorOpTileIterator< + using IteratorQMeta = QuantBMetaMmaTensorOpTileIterator< MatrixShape, QuantBlocking, ElementQScale, SmemLayoutQScale, ElementQOffset, SmemLayoutQOffset, ArchMmaOperator, kThreadCount, kPartitionsK>; - using FragmentQScale = typename IteratorQScale::FragmentScale; - using FragmentQOffset = typename IteratorQScale::FragmentOffset; + using FragmentQScale = typename IteratorQMeta::FragmentScale; + using FragmentQOffset = typename IteratorQMeta::FragmentOffset; /// Number of mma operations performed using MmaIterations = MatrixShape< @@ -419,7 +344,7 @@ class QuantBMmaTensorOp { Array const *ptr_B = reinterpret_cast const *>(&B); - IteratorQScale::dequant(scales, offsets, *ptr_B, dst_B); + IteratorQMeta::dequant(scales, offsets, *ptr_B, dst_B); } };