Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chenfucn committed Jan 25, 2024
1 parent 8aeb46c commit 9c92e1a
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 230 deletions.
16 changes: 8 additions & 8 deletions onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ template <
/// Element type for quant offsets
typename ElementQOffset_,
/// Layout type for quant scales and offsets
typename LayoutQScale_,
typename LayoutQMeta_,
/// Blocking dimensions for quantization
typename QuantBlocking_,
/// Element type for C and D matrix operands
Expand Down Expand Up @@ -180,7 +180,7 @@ class QuantBGemm {
"InstructionShape::kK must be a multiple of 16 (2 tiles), required by 4b weight packing layout.");
using ElementQScale = ElementQScale_;
using ElementQOffset = ElementQOffset_;
using LayoutQScale = LayoutQScale_;
using LayoutQMeta = LayoutQMeta_;
using QuantBlocking = QuantBlocking_;
static constexpr bool kHasQOffset = !(std::is_same<ElementQOffset, std::monostate>::value);

Expand All @@ -197,7 +197,7 @@ class QuantBGemm {
kAlignmentB,
ElementQScale,
ElementQOffset,
LayoutQScale,
LayoutQMeta,
QuantBlocking,
ElementC,
LayoutC,
Expand Down Expand Up @@ -230,8 +230,8 @@ class QuantBGemm {
TensorRef<ElementB const, LayoutB> ref_B;
TensorRef<ElementC const, LayoutC> ref_C;
TensorRef<ElementC, LayoutC> ref_D;
TensorRef<ElementQScale const, LayoutQScale> ref_Qscale;
TensorRef<ElementQOffset const, LayoutQScale> ref_Qoffset;
TensorRef<ElementQScale const, LayoutQMeta> ref_Qscale;
TensorRef<ElementQOffset const, LayoutQMeta> ref_Qoffset;

typename EpilogueOutputOp::Params epilogue;

Expand All @@ -258,7 +258,7 @@ class QuantBGemm {
GemmCoord problem_size_,
TensorRef<ElementA const, LayoutA> ref_A_,
TensorRef<ElementB const, LayoutB> ref_B_,
TensorRef<ElementQScale const, LayoutQScale> ref_Qscale_,
TensorRef<ElementQScale const, LayoutQMeta> ref_Qscale_,
TensorRef<ElementC const, LayoutC> ref_C_,
TensorRef<ElementC, LayoutC> ref_D_,
typename EpilogueOutputOp::Params epilogue_ =
Expand All @@ -279,8 +279,8 @@ class QuantBGemm {
GemmCoord problem_size_,
TensorRef<ElementA const, LayoutA> ref_A_,
TensorRef<ElementB const, LayoutB> ref_B_,
TensorRef<ElementQScale const, LayoutQScale> ref_Qscale_,
TensorRef<ElementQOffset const, LayoutQScale> ref_Qoffset_,
TensorRef<ElementQScale const, LayoutQMeta> ref_Qscale_,
TensorRef<ElementQOffset const, LayoutQMeta> ref_Qoffset_,
TensorRef<ElementC const, LayoutC> ref_C_,
TensorRef<ElementC, LayoutC> ref_D_,
typename EpilogueOutputOp::Params epilogue_ =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ template <
/// Element type for quant offsets
typename ElementQOffset_,
/// Layout type for quant scales and offsets
typename LayoutQScale_,
typename LayoutQMeta_,
/// Blocking dimensions for quantization
typename QuantBlocking_,
/// Access granularity of quant scales in units of elements
Expand Down Expand Up @@ -167,7 +167,7 @@ template <
/// Element type for quant offsets
typename ElementQOffset,
/// Layout type for quant scales
typename LayoutQScale,
typename LayoutQMeta,
/// Blocking dimensions for quantization
typename QuantBlocking,
/// Access granularity of quant scales in units of elements
Expand Down Expand Up @@ -207,7 +207,7 @@ template <
typename PermuteBLayout
>
struct DefaultQuantBGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementQScale, ElementQOffset, LayoutQScale, QuantBlocking,
ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking,
ElementC, LayoutC, ElementAccumulator,
arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape,
InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages,
Expand All @@ -221,7 +221,7 @@ struct DefaultQuantBGemm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAli
/// Define the threadblock-scoped matrix multiply-accumulate
using Mma = typename cutlass::gemm::threadblock::DefaultQuantBMma<
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
ElementQScale, ElementQOffset, LayoutQScale, QuantBlocking,
ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking,
ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80,
ThreadblockShape, WarpShape, InstructionShape, Stages,
Operator, false, GatherA, GatherB,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ template <
/// Element type for quant offsets
typename ElementQOffset_,
/// Layout for quant scales and offsets
typename LayoutQScale_,
typename LayoutQMeta_,
/// Blocking size for quantization
typename QuantBlocking_,
/// Element type for internal accumulation
Expand Down Expand Up @@ -138,7 +138,7 @@ template <
/// Element type for quant offsets
typename ElementQOffset,
/// Layout for quant scales and offsets
typename LayoutQScale,
typename LayoutQMeta,
/// Blocking size for quantization
typename QuantBlocking,
/// Element type for internal accumulation
Expand Down Expand Up @@ -168,7 +168,7 @@ template <
>
struct DefaultQuantBMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
kAlignmentB, ElementQScale, ElementQOffset,
LayoutQScale, QuantBlocking,
LayoutQMeta, QuantBlocking,
ElementAccumulator, LayoutC,
arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
InstructionShape, Stages, Operator, false,
Expand All @@ -191,7 +191,7 @@ struct DefaultQuantBMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
// Define the MmaCore components
using MmaCore = typename cutlass::gemm::threadblock::DefaultQuantBMmaCore<
ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
ElementB, LayoutB, ElementQScale, ElementQOffset, LayoutQScale, QuantBlocking,
ElementB, LayoutB, ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking,
ElementAccumulator, LayoutC, arch::OpClassTensorOp,
Stages, Operator, false, CacheOpA, CacheOpB>;

Expand All @@ -218,14 +218,14 @@ struct DefaultQuantBMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
using IteratorQScale =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
typename MmaCore::ThreadblockQShape,
ElementQScale, LayoutQScale, 0, ThreadMapQScale, AccessTypeQScale>;
ElementQScale, LayoutQMeta, 0, ThreadMapQScale, AccessTypeQScale>;

using ThreadMapQOffset = typename MmaCore::IteratorThreadMapQOffset;
using AccessTypeQOffset =
cutlass::Array<ElementQOffset, ThreadMapQOffset::kElementsPerAccess>;
using IteratorQOffset =
cutlass::transform::threadblock::OptionalPredicatedTileAccessIterator<
typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQScale,
typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQMeta,
0, ThreadMapQOffset, AccessTypeQOffset, MmaCore::kThreads>;

// Define the threadblock-scoped multistage matrix multiply
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ template <
/// Element data type of quant offset
typename ElementQOffset,
/// Layout of quant scale
typename LayoutQScale,
typename LayoutQMeta,
/// Blocking dimensions for quantization
typename QuantBlocking,
/// Data type of accumulator
Expand Down Expand Up @@ -157,7 +157,7 @@ template <
/// Element data type of quant offset
typename ElementQOffset_,
/// Layout of quant scale
typename LayoutQScale_,
typename LayoutQMeta_,
/// Blocking dimensions for quantization
typename QuantBlocking_,
/// Data type of accumulator
Expand All @@ -174,7 +174,7 @@ template <
cutlass::arch::CacheOperation::Kind CacheOpB>
struct DefaultQuantBMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
layout::RowMajor, ElementB_, layout::ColumnMajor,
ElementQScale_, ElementQOffset_, LayoutQScale_, QuantBlocking_,
ElementQScale_, ElementQOffset_, LayoutQMeta_, QuantBlocking_,
ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
Operator_, false, CacheOpA, CacheOpB> {
using Shape = Shape_;
Expand All @@ -187,7 +187,7 @@ struct DefaultQuantBMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,

using ElementQScale = ElementQScale_;
using ElementQOffset = ElementQOffset_;
using LayoutQScale = LayoutQScale_;
using LayoutQMeta = LayoutQMeta_;
using QuantBlocking = QuantBlocking_;

using ElementC = ElementC_;
Expand Down Expand Up @@ -269,8 +269,8 @@ struct DefaultQuantBMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
MatrixShape<Shape::kK/2, Shape::kN/2>, ElementB, SmemLayoutB, 1,
IteratorThreadMapB>;

using SmemLayoutQScale = LayoutQScale;
using SmemLayoutQOffset = LayoutQScale;
using SmemLayoutQScale = LayoutQMeta;
using SmemLayoutQOffset = LayoutQMeta;

/// Threadblock-level quantization meta data shape
using ThreadblockQShape = MatrixShape<Shape::kK / QuantBlocking::kRow, Shape::kN / QuantBlocking::kColumn>;
Expand All @@ -279,12 +279,12 @@ struct DefaultQuantBMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
static_assert(ThreadblockQShape::kCount > 0, "QuantBlocking too big to fit in a thread block!");
static_assert(QuantBlocking::kRow == 1 || QuantBlocking::kColumn == 1,
"Only support single column or row quantize blocking!");
static_assert(QuantBlocking::kColumn != 1 || std::is_same<LayoutQScale, layout::RowMajor>::value,
static_assert(QuantBlocking::kColumn != 1 || std::is_same<LayoutQMeta, layout::RowMajor>::value,
"Quant scale matrix's major dimension must have more elements, to facilitate fast loading!");

/// Threadblock-level quantization meta data shape in pitch-linear layout
using TBQPitchLinearShape = typename std::conditional<
std::is_same<LayoutQScale, layout::RowMajor>::value,
std::is_same<LayoutQMeta, layout::RowMajor>::value,
layout::PitchLinearShape<ThreadblockQShape::kColumn, ThreadblockQShape::kRow>,
layout::PitchLinearShape<ThreadblockQShape::kRow, ThreadblockQShape::kColumn>>::type;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ class QuantBMmaBase {
}

CUTLASS_HOST_DEVICE
static typename Operator::SmemLayoutQScale LayoutQScale() {
static typename Operator::SmemLayoutQScale LayoutQMeta() {
return Operator::SmemLayoutQScale::packed({ShapeQScale::kRow, ShapeQScale::kColumn});
}

Expand All @@ -301,7 +301,7 @@ class QuantBMmaBase {
/// Returns a TensorRef to the quantization scales
CUTLASS_HOST_DEVICE
TensorRefQScale operand_QScale_ref() {
return TensorRefQScale{operand_QScale.data(), LayoutQScale()};
return TensorRefQScale{operand_QScale.data(), LayoutQMeta()};
}

CUTLASS_HOST_DEVICE
Expand Down
Loading

0 comments on commit 9c92e1a

Please sign in to comment.