
optimize column block dequant
chenfucn committed Feb 28, 2024
1 parent b9f9cb7 commit 31a602f
Showing 2 changed files with 128 additions and 13 deletions.
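The kernel diff below restructures the dequantization of 4-bit weights. A weight w with per-block scale s and zero point zp dequantizes to s * (w - zp); since weights2Half() produces fp16(w + 16) via bit manipulation, the kernel instead computes s * fp16(w + 16) + s * (-16 - zp). This change computes the second term (the per-column "addon") for all kMmaIterationsB columns up front with packed f16x2 PTX, four columns per asm block, instead of scalar half arithmetic inside the column loop. The scalar sketch below only illustrates the identity; the default zero point of 8 for the no-offset path follows the diff, while the function names are illustrative and not taken from the source.

// Scalar sketch (illustrative only) of the identity behind the "addon" term:
//   s * (w - zp) == s * (w + 16) + s * (-16 - zp)
// where w is a 4-bit weight, s the per-block scale, and zp the zero point
// (fixed at 8 when no offsets buffer is supplied).
#include <cassert>

static float dequant_reference(int w, float s, int zp) {
  return s * static_cast<float>(w - zp);
}

static float dequant_fused(int w, float s, int zp) {
  const float w_plus_16 = static_cast<float>(w + 16);    // what weights2Half() yields
  const float addon = s * static_cast<float>(-16 - zp);  // precomputed per column
  return s * w_plus_16 + addon;                          // applied as an FMA per element
}

int main() {
  for (int w = 0; w < 16; ++w) {
    for (int zp = 0; zp < 16; ++zp) {
      assert(dequant_reference(w, 1.5f, zp) == dequant_fused(w, 1.5f, zp));
    }
  }
  return 0;
}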
@@ -739,20 +739,119 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
// First convert 4b weight into fp16(weight + 16)
weights2Half(weights, dest);

int out_idx = 0;
CUTLASS_PRAGMA_UNROLL
for (int n_out = 0; n_out < kMmaIterationsB; n_out++){
ElementScale s = scales[n_out];
ElementScale offset;
ElementScale addon[kMmaIterationsB];
if constexpr (kMmaIterationsB % 4 == 0) {
const b64* scales_ptr = reinterpret_cast<const b64*>(scales.data());
uint32_t* addon_ptr = reinterpret_cast<uint32_t*>(addon);
if constexpr(kHasOffset){
offset = s * static_cast<ElementScale>(-16 - int(offsets[n_out]));
const uint32_t* p = reinterpret_cast<const uint32_t*>(offsets.data());
CUTLASS_PRAGMA_UNROLL
for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
asm volatile(
"{\n\t"
" .reg .b32 rb0, rb1, rb2;\n"

// offset from [d, c, b, a] --> [d, b, c, a]
" prmt.b32 rb2, %4, rb0, 0x3120;\n"

// static_cast<cutlass::half_t>(-16 - offset)
// input [d, b, c, a],
" shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6
" shr.u32 rb1, rb2, 2;\n" // rb1 = [x, d, x, c] << 6
" lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00
" lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n"
" mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset)
" mul.rn.f16x2 %1, %3, rb1;\n"
"}\n"
: "=r"(addon_ptr[0]), "=r"(addon_ptr[1])
: "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b),
"r"(p[0]));
#else
assert(0);
#endif
scales_ptr++;
p++;
addon_ptr += 2;
}
} else {
offset = s * static_cast<ElementScale>(-16-8);
CUTLASS_PRAGMA_UNROLL
for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
asm volatile(
"{\n\t"
" .reg .b32 rb0;\n"
" mov.u32 rb0, 0xce00ce00;\n"
" mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8)
" mul.rn.f16x2 %1, %3, rb0;\n"
"}\n"
: "=r"(addon_ptr[0]), "=r"(addon_ptr[1])
: "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b));
#else
assert(0);
#endif
scales_ptr++;
addon_ptr += 2;
}
}
} else if constexpr (kMmaIterationsB % 2 == 0) {
const uint32_t* scales_ptr = reinterpret_cast<const uint32_t*>(scales.data());
uint32_t* addon_ptr = reinterpret_cast<uint32_t*>(addon);

if constexpr (kHasOffset){
// possible 2-byte buffer over-read here.
const uint32_t* p = reinterpret_cast<const uint32_t*>(offsets.data());
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
asm volatile(
"{\n\t"
" .reg .b32 rb0, rb1, rb2;\n"

// offset from [?, ?, b, a] --> [?, b, ?, a]
" prmt.b32 rb2, %2, rb0, 0x3120;\n"

// static_cast<cutlass::half_t>(-16 - offset)
// input [?, b, ?, a],
" shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6
" lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00
" mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - offset)
"}\n"
: "=r"(addon_ptr[0])
: "r"(scales_ptr[0])
"r"(p[0]));
#else
assert(0);
#endif
} else {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
asm volatile(
"{\n\t"
" .reg .b32 rb0;\n"
" mov.u32 rb0, 0xce00ce00;\n"
" mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - 8)
"}\n"
: "=r"(addon_ptr[0])
: "r"(scales_ptr[0]));
#else
assert(0);
#endif
}
} else {
// kMmaIterationsB == 1
if constexpr(kHasOffset){
uint8_t zp = offsets[0];
addon[0] = scales[0] * static_cast<ElementScale>(-16 - static_cast<int>(zp));
} else {
addon[0] = scales[0] * static_cast<ElementScale>(-16-8);
}
}

int out_idx = 0;
CUTLASS_PRAGMA_UNROLL
for (int n_out = 0; n_out < kMmaIterationsB; n_out++){
CUTLASS_PRAGMA_UNROLL
for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){
dest[out_idx] = s * dest[out_idx] + offset;
dest[out_idx + 1] = s * dest[out_idx + 1] + offset;
dest[out_idx] = scales[n_out] * dest[out_idx] + addon[n_out];
dest[out_idx + 1] = scales[n_out] * dest[out_idx + 1] + addon[n_out];
out_idx += 2;
}
}
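A note on the constants in the PTX above, as I read them: the prmt control 0x3120 reorders the four packed zero points [d, c, b, a] into [d, b, c, a] so that, after the shifts, the {a, b} pair lines up with one packed fp16 scale register and {c, d} with the other; the lop3 mask 0x03c003c0 keeps each 4-bit zero point at bit 6 of its fp16 lane (the top mantissa bits) while 0xcc00 supplies sign and exponent, so each lane decodes to -(16 + zp); and 0xce00 is simply fp16 -24 = -(16 + 8) for the fixed-zero-point path. The host-side check below reproduces that bit math in plain C++ under IEEE binary16 assumptions; it is a sanity sketch, not code from the commit.

#include <cassert>
#include <cmath>
#include <cstdint>

// Decode an IEEE binary16 bit pattern to float (normal numbers only,
// which suffices here: the exponent field of these constants is 0b10011).
static float half_bits_to_float(uint16_t h) {
  const int sign = (h >> 15) & 1;
  const int exp  = (h >> 10) & 0x1f;
  const int man  = h & 0x3ff;
  const float v = (1.0f + man / 1024.0f) * std::ldexp(1.0f, exp - 15);
  return sign ? -v : v;
}

int main() {
  // 0xce00 -> -24.0, i.e. -(16 + 8), used on the no-offset path.
  assert(half_bits_to_float(0xce00) == -24.0f);

  // 0xcc00 | (zp << 6) -> -(16 + zp) for any 4-bit zero point zp:
  // zp lands in mantissa bits 6..9 (worth zp/16) and the exponent scales by 16.
  for (uint16_t zp = 0; zp < 16; ++zp) {
    const uint16_t bits = 0xcc00 | static_cast<uint16_t>(zp << 6);
    assert(half_bits_to_float(bits) == -(16.0f + zp));
  }
  return 0;
}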
@@ -263,7 +263,7 @@ TEST(BlkQ4_GEMM, PrepackSm80Test) {
testPrepack<true, false>(256, 256);
}

TEST(BlkQ4_GEMM, Sm80Test) {
TEST(BlkQ4_GEMM, Sm80RowBlockingTest) {
Status status = onnxruntime::cuda::test::sm80_supported();
if (!status.IsOK()) {
// skip the test if sm80 is not supported
@@ -290,14 +290,30 @@ TEST(BlkQ4_GEMM, Sm80Test) {

onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, false>(256, 1024, 576);
onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, true>(256, 1024, 576);
}

onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, false>(256, 672, 576);
onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, true>(256, 672, 576);
TEST(BlkQ4_GEMM, Sm80ColBlockingTest) {
Status status = onnxruntime::cuda::test::sm80_supported();
if (!status.IsOK()) {
// skip the test if sm80 is not supported
return;
}
onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, false>(64, 672, 576);
onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, true>(64, 672, 576);

onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, false>(256, 1024, 576);
onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, true>(256, 1024, 576);

// small m
}

TEST(BlkQ4_GEMM, Sm80SmallMTest) {
Status status = onnxruntime::cuda::test::sm80_supported();
if (!status.IsOK()) {
// skip the test if sm80 is not supported
return;
}

// small m
onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, false>(16, 704, 576);
onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, true>(16, 704, 576);

