Skip to content

Commit

Permalink
ptx for row blocking no zero-point
Browse files Browse the repository at this point in the history
  • Loading branch information
chenfucn committed Feb 23, 2024
1 parent 7d5d5ca commit b9f9cb7
Showing 1 changed file with 20 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,9 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,

static constexpr bool kHasOffset = !(std::is_same<ElementOffset, std::monostate>::value);

static_assert(BlockingShape::kRow == 1 && BlockingShape::kColumn > 1,
"Only support row blocking for column major layout");

using MetaTile = QuantBMetaMmaTile<WarpShapeB, BlockingShape, ArchMmaOperator, Threads>;

/// Number of MMA instructions for this tile
Expand Down Expand Up @@ -350,12 +353,11 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
CUTLASS_HOST_DEVICE
void load(FragmentScale &frag, FragmentOffset &frag_offset) {
if constexpr(kNumBsPerCoreTileFragement == 2
&& kBTilesPerMma == 2
&& BlockingShape::kRow == 1){
&& kBTilesPerMma == 2){
// Optimize for a special case of:
// 16b gemm (kNumBsPerCoreTileFragement == 2)
// 2 B operand tiles per mma (kBTilesPerMma == 2)
// (1,n) quantization blocking
// (1,n) quantization blocking (BlockingShape::kRow == 1)
// The scale and offset tensors are prepacked to reduce the number of load instructions needed
const int row = lane_position_.row();
const int column = lane_position_.column() / BlockingShape::kColumn;
Expand Down Expand Up @@ -444,11 +446,10 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
// First convert 4b weight into fp16(weight + 16)
weights2Half(weights, dest);

if constexpr(kBTilesPerMma == 2
&& BlockingShape::kRow == 1){
if constexpr(kBTilesPerMma == 2){
// Optimize for a special case of:
// 2 B operand tiles per mma (kBTilesPerMma == 2)
// (1,n) quantization blocking
// (1,n) quantization blocking (BlockingShape::kRow == 1)

uint32_t* dest_pair = reinterpret_cast<uint32_t*>(dest.data());
const b64* scales_ptr = reinterpret_cast<const b64*>(scales.data());
Expand All @@ -475,7 +476,7 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
" shr.u32 rb1, %4, 2;\n" // rb1 = [x, d, x, c] << 6
" lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00
" lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n"
" mul.rn.f16x2 %0, %2, rb0;\n" // dest = scale * (16 + weight)
" mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset)
" mul.rn.f16x2 %1, %3, rb1;\n"
"}\n"
: "=r"(offsets.pair.a), "=r"(offsets.pair.b)
Expand All @@ -487,10 +488,22 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,

offsets_ptr += 4;
} else {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
asm volatile(
"{\n\t"
" .reg .b32 rb0;\n"
" mov.u32 rb0, 0xce00ce00;\n"
" mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8)
" mul.rn.f16x2 %1, %3, rb0;\n"
"}\n"
: "=r"(offsets.pair.a), "=r"(offsets.pair.b)
: "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b));
#else
offsets.fp16_quad.a = scales_ptr->fp16_quad.a * static_cast<cutlass::half_t>(-16-8);
offsets.fp16_quad.b = scales_ptr->fp16_quad.b * static_cast<cutlass::half_t>(-16-8);
offsets.fp16_quad.c = scales_ptr->fp16_quad.c * static_cast<cutlass::half_t>(-16-8);
offsets.fp16_quad.d = scales_ptr->fp16_quad.d * static_cast<cutlass::half_t>(-16-8);
#endif
}

CUTLASS_PRAGMA_UNROLL
Expand Down

0 comments on commit b9f9cb7

Please sign in to comment.