From 990759969dbca0af2c631078fa274a2fb1a47178 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Mon, 20 Nov 2023 17:23:32 +0000 Subject: [PATCH 01/13] adding cuda kernel with tests --- .lintrunner.toml | 1 + cmake/onnxruntime_providers_cuda.cmake | 6 +- cmake/onnxruntime_unittests.cmake | 1 + onnxruntime/core/mickey/README.md | 4 + .../core/mickey/blk_q4/f16_gemm_sm80.h | 208 +++ .../{prepack_sm80.h => f16_prepack_sm80.h} | 2 +- .../cutlass_ext/q4gemm/device/quantb_gemm.h | 489 +++++++ .../q4gemm/kernel/default_quantb_gemm.h | 255 ++++ .../cutlass_ext/q4gemm/kernel/quantb_gemm.h | 470 ++++++ .../q4gemm/threadblock/default_quantb_mma.h | 248 ++++ .../threadblock/default_quantb_mma_core.h | 340 +++++ .../optional_predicated_tile_access_iter.h | 314 ++++ .../optional_regular_tile_access_iter.h | 224 +++ .../threadblock/quantb_mma_multistage.h | 1290 +++++++++++++++++ .../warp/default_quantb_mma_tensor_op.h | 112 ++ .../quantb_meta_mma_tensor_op_tile_iterator.h | 787 ++++++++++ .../q4gemm/warp/quantb_mma_tensor_op.h | 436 ++++++ onnxruntime/core/util/matrix_layout.h | 1 - .../cuda/test_cases/blkq4_fp16_gemm_sm80.h | 204 +++ ...k_test.cc => blkq4_fp16_gemm_sm80_test.cc} | 242 +--- .../test_cases/blkq4_fp16_gemm_sm80_testcu.cu | 489 +++++++ 21 files changed, 5949 insertions(+), 174 deletions(-) create mode 100644 onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h rename onnxruntime/core/mickey/blk_q4/{prepack_sm80.h => f16_prepack_sm80.h} (99%) create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h create mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h rename onnxruntime/test/providers/cuda/test_cases/{blkq4_fp16_sm80_prepack_test.cc => blkq4_fp16_gemm_sm80_test.cc} (61%) create mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu diff --git a/.lintrunner.toml b/.lintrunner.toml index 4e5d077b08ff4..be95e03479cf9 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -132,6 +132,7 @@ exclude_patterns = [ 'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code 'onnxruntime/core/graph/contrib_ops/quantization_defs.cc', 'onnxruntime/core/mlas/**', # Contains assembly code + 'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS lib recommends NO automatic code formatting 'winml/lib/Api.Image/shaders/**', # Contains data chunks ] command = [ diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 9887d615c92d7..33441a94bd324 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -199,8 +199,10 @@ target_link_libraries(${target} PRIVATE cuda) endif() - include(cutlass) - target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) + if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) + include(cutlass) + target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) + endif() target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 0987d6d164dbd..4e38f85975025 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -783,6 +783,7 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock) + target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey) target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut) endif() diff --git a/onnxruntime/core/mickey/README.md b/onnxruntime/core/mickey/README.md index 7e8d30cd1805b..735ec4b80daf3 100644 --- a/onnxruntime/core/mickey/README.md +++ b/onnxruntime/core/mickey/README.md @@ -4,3 +4,7 @@ Playful name for a template library of high performance cuda code that are often shared by various AI operators. The intention is to make this header files only, with no binary impact unless it is instantiated where it is needed. + +Currently cuda code are scattered in multiple locations in the repo. +Hopefully this can be the starting point of consolidating all cuda +code. diff --git a/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h new file mode 100644 index 0000000000000..52bff7e40dbe3 --- /dev/null +++ b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h @@ -0,0 +1,208 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blk_q4/f16_gemm_sm80.h + * + * Abstract: + * Entry point for Q4F16 GEMM kernel for SM80 devices. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass_ext/q4gemm/device/quantb_gemm.h" + +namespace onnxruntime { +namespace cuda { + +// +// This is the implementation of the quantized GEMM kernel for 16b float x blocked quantized 4b data type +// +template < + typename ElementDequant_, // <- data type of dequantized elements for gemm, fp16 or bf16 + typename QuantBlocking_, // <- weights block per scale, cutlass::MatrixShape + bool SmallM, // <- true if M <= 16 + bool kHasQuantOffset> +struct BlkQ4F16GemmImpl { + // + // Type definitions + // + + using ElementDequant = ElementDequant_; + using QuantBlocking = QuantBlocking_; + + static_assert(sizeof(ElementDequant) == 2, "q4f16gemm kerenl only support 16b operands!"); + + // Data types that are fixed for this kernel + using ElementAccumulator = float; + using ElementComputeEpilogue = ElementAccumulator; + using ElementInputA = ElementDequant; + using ElementOutput = ElementDequant; + + using ElementW = uint8_t; // <- Weight is int4, uint8 for two of them + + // We pack 4 weights into one 16b element, so as to leverage cutlass tile iterators + // for async shared memory loading and minimize bank conflict + using ElementWPack = ElementDequant; + + using ElementQScale = ElementDequant; // <- data type of quantization scale + using ElementQOffset = uint8_t; + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputWPack = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + // Layout of quantization scale and offset, oriented to be loaded using less instructions + // in a warp tile + using LayoutInputQScale = + typename std::conditional::type; // <- layout of quantization scale + + using ShapeMMAThreadBlock = + typename std::conditional, + cutlass::gemm::GemmShape<128, 256, 64>>::type; + + static constexpr int MinN = QuantBlocking::kColumn > 32 ? QuantBlocking::kColumn : 32; + using ShapeMMAWarp = + typename std::conditional, + cutlass::gemm::GemmShape<64, 64, 64>>::type; + + using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + + // This code section describes the epilogue part of the kernel + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + + // Number of pipelines you want to use + static constexpr int NumStages = 3; + + using Gemm = cutlass::gemm::device::QuantBGemm< + ElementInputA, + LayoutInputA, + ElementWPack, + LayoutInputWPack, + ElementQScale, + typename std::conditional::type, + LayoutInputQScale, + QuantBlocking, + ElementOutput, + LayoutOutput, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ShapeMMAThreadBlock, + ShapeMMAWarp, + ShapeMMAOp, + EpilogueOp, + SwizzleThreadBlock, + NumStages>; + + using Arguments = typename Gemm::Arguments; + + // Invoke gemm kernel (the version with quantization offset) + static cutlass::Status run( + cudaStream_t stream, + const cutlass::gemm::GemmCoord& problem_size_, + cutlass::TensorRef ref_A_, + cutlass::TensorRef ref_B_, + cutlass::TensorRef ref_Qscale_, + cutlass::TensorRef ref_Qoffset_, + cutlass::TensorRef ref_C_, + cutlass::TensorRef ref_D_, + typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) { + if constexpr (!kHasQuantOffset) { + return cutlass::Status::kErrorNotSupported; + } else { + if constexpr (ShapeMMAThreadBlock::kM == 16) { + if (problem_size_.m() > 16) { + // For M > 16, the caller should have picked the + // kernel with bigger M + return cutlass::Status::kErrorNotSupported; + } + } + + // Construct Gemm arguments + Arguments args{ + problem_size_, + ref_A_, + ref_B_, + ref_Qscale_, + ref_Qoffset_, + ref_C_, + ref_D_, + epilogue_}; + + Gemm gemm_op; + + // Check if this GEMM can be run or not + cutlass::Status status = gemm_op.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Launch the CUTLASS GEMM kernel. + return gemm_op(args, nullptr, stream); + } + } + + // Invoke gemm kernel (the version without quantization offset) + static cutlass::Status run( + cudaStream_t stream, + const cutlass::gemm::GemmCoord& problem_size_, + cutlass::TensorRef ref_A_, + cutlass::TensorRef ref_B_, + cutlass::TensorRef ref_Qscale_, + cutlass::TensorRef ref_C_, + cutlass::TensorRef ref_D_, + typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) { + if constexpr (kHasQuantOffset) { + return cutlass::Status::kErrorNotSupported; + } else { + if constexpr (ShapeMMAThreadBlock::kM == 16) { + if (problem_size_.m() > 16) { + // For M > 16, the caller should have picked the + // kernel with bigger M + return cutlass::Status::kErrorNotSupported; + } + } + + // Construct Gemm arguments + Arguments args{ + problem_size_, + ref_A_, + ref_B_, + ref_Qscale_, + ref_C_, + ref_D_, + epilogue_}; + + Gemm gemm_op; + + // Check if this GEMM can be run or not + cutlass::Status status = gemm_op.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Launch the CUTLASS GEMM kernel. + return gemm_op(args, nullptr, stream); + } + } +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h similarity index 99% rename from onnxruntime/core/mickey/blk_q4/prepack_sm80.h rename to onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h index e291ab39e8aa3..a08cfb97eed4a 100644 --- a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h +++ b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h @@ -3,7 +3,7 @@ * Licensed under the MIT License. * * Module Name: - * prepack_sm80.h + * blk_q4/f16_prepack_sm80.h * * Abstract: * Prepack weights and quantization parameters (scales and offsets) for diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h new file mode 100644 index 0000000000000..2e9b04d93b12e --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h @@ -0,0 +1,489 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_gemm.h + * @brief Modified from cutlass/gemm/device/gemm.h, boilerplate code passing input pointers to the kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm.h" + +#include "cutlass_ext/q4gemm/kernel/default_quantb_gemm.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! A specialized GEMM operator for quantized B GEMM. + + It is modified from cutlass::gemm::device::Gemm. Both this class and the original Gemm class + are pretty much boilerplate code that construct the Gemm kernel class, and pass parameters + and controls to it. The only difference is that this class has a few more template parameters + to support quantization. + + This implementation pretty much follows the design of cutlass. But this class seems to be + just a wrapper of the Gemm kernel class. Is this really necessary? + +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout type for quant scales and offsets + typename LayoutQScale_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute> +class QuantBGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + // Quantization Parameters + static_assert(std::is_same::value, + "LayoutB, i.e. packed weights must appear ColumnMajor."); + static_assert(InstructionShape::kK == 16, + "InstructionShape::kK must be a multiple of 16 (2 tiles), required by 4b weight packing layout."); + using ElementQScale = ElementQScale_; + using ElementQOffset = ElementQOffset_; + using LayoutQScale = LayoutQScale_; + using QuantBlocking = QuantBlocking_; + static constexpr bool kHasQOffset = !(std::is_same::value); + + // TODO enable uint4_t or smaller for QOffset + static_assert(!kHasQOffset || std::is_same::value, "QOffset must be uint8_t"); + + /// Define the kernel + using GemmKernel = typename kernel::DefaultQuantBGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementQScale, + ElementQOffset, + LayoutQScale, + QuantBlocking, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + GatherA, + GatherB, + ScatterD, + PermuteDLayout + >::GemmKernel; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + TensorRef ref_Qscale; + TensorRef ref_Qoffset; + + typename EpilogueOutputOp::Params epilogue; + + // split-K parallelism (etc.) are not yet supported, keeping this for future extension + int split_k_slices{1}; + // For gather+scatter operations + int const *gather_A_indices{nullptr}; + int const *gather_B_indices{nullptr}; + int const *scatter_D_indices{nullptr}; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params() + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(!kHasQOffset); + } + + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_Qoffset_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params() + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_Qoffset(ref_Qoffset_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(kHasQOffset); + } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + QuantBGemm() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.epilogue, + static_cast(workspace), + args.gather_A_indices, + args.gather_B_indices, + args.scatter_D_indices + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_Qscale.reset(args.ref_Qscale.non_const_ref().data()); + params_.ref_Qoffset.reset(args.ref_Qoffset.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.output_op = args.epilogue; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + std::cerr << "Failed to obtain maximum shared memory size " << smem_size << " for kernel: " + << cudaGetErrorString(result) << "\n"; + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h new file mode 100644 index 0000000000000..3860a241395a6 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h @@ -0,0 +1,255 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_gemm.h + * @brief Modified from cutlass/gemm/kernel/default_gemm.h. templates for combining + * threadblock-scoped matrix multiply-add with the appropriate + * threadblock-scoped epilogue. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass_ext/q4gemm/kernel/quantb_gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +#include "cutlass/layout/permute.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" +#endif //CUTLASS_ARCH_WMMA_ENABLED + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout type for quant scales and offsets + typename LayoutQScale_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Access granularity of quant scales in units of elements + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute, + /// + typename Enable = void +> +struct DefaultQuantBGemm; + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale, + /// Element type for quant offsets + typename ElementQOffset, + /// Layout type for quant scales + typename LayoutQScale, + /// Blocking dimensions for quantization + typename QuantBlocking, + /// Access granularity of quant scales in units of elements + typename ElementC, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout +> +struct DefaultQuantBGemm { + + static_assert((platform::is_same::value + || platform::is_same>::value), + "Epilogue in the kernel level must be row major"); + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultQuantBMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementQScale, ElementQOffset, LayoutQScale, QuantBlocking, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator, false, GatherA, GatherB, + PermuteALayout, PermuteBLayout>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using RegularEpilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; + + using Affine2Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpAffineRankN< + 2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + using Epilogue = typename platform::conditional::value, + RegularEpilogue, + Affine2Epilogue>::type; + + /// Define the kernel-level GEMM operator. + using GemmKernel = kernel::QuantBGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h new file mode 100644 index 0000000000000..1f781b37b98b8 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h @@ -0,0 +1,470 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_gemm.h + * @brief Modified from cutlass/gemm/kernel/gemm.h. + * Template for a pipelined GEMM kernel. Does not compute batching or support split-K. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. +> +struct QuantBGemm { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using OutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + static constexpr bool kHasQOffset = Mma::kHasQOffset; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorQScale::Params params_QScale; + typename Mma::IteratorQScale::TensorRef ref_QScale; + typename Mma::IteratorQOffset::Params params_QOffset; + typename Mma::IteratorQOffset::TensorRef ref_QOffset; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename OutputOp::Params output_op; + int *semaphore; + int gemm_k_size; // how many k vectors are processed by this threadblock + // For gather+scatter operations + int const *gather_A_indices; + int const *gather_B_indices; + int const *scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + typename OutputOp::Params output_op = typename OutputOp::Params(), + int *workspace = nullptr, + int const *gather_A_indices = nullptr, + int const *gather_B_indices = nullptr, + int const *scatter_D_indices = nullptr + ): + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(ref_A.layout()), + ref_A(ref_A), + params_B(ref_B.layout()), + ref_B(ref_B), + params_QScale(ref_QScale.layout()), + ref_QScale(ref_QScale), + params_QOffset(ref_QOffset.layout()), + ref_QOffset(ref_QOffset), + params_C(ref_C.layout()), + ref_C(ref_C), + params_D(ref_D.layout()), + ref_D(ref_D), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) { + int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + QuantBGemm() { } + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D) { + + // TODO check problem_size K, N must be multiple of QuantBlocking + + static int const kAlignmentA = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (problem_size.k() % Mma::Shape::kK != 0) { + // Currently we don't support this case due to the way + // predicate iterator works, it loads the partial tile + // in the first iteration and then the full tile in the + // remaining iterations. This will cause the blockwise + // quantization parameters to go out of step with the + // weights. We can fix this by adding a predicate iterator + // that loads the full tile in the first iterations and + // then the partial tile in the last iteration. + return Status::kErrorInvalidProblem; + } + + int qscale_k = problem_size.k() / Mma::QuantBlocking::kRow; + int qscale_n = problem_size.n() / Mma::QuantBlocking::kColumn; + if ((qscale_k == 0) || (qscale_k * Mma::QuantBlocking::kRow != problem_size.k())) { + // partial block not supported + return Status::kErrorInvalidProblem; + } + if ((qscale_n == 0) || (qscale_n * Mma::QuantBlocking::kColumn != problem_size.n())) { + // partial block not supported + return Status::kErrorInvalidProblem; + } + + if (!TensorRef_aligned(ref_QScale, Mma::IteratorQScale::AccessType::kElements)) { + return Status::kErrorMisalignedOperand; + } + + if constexpr(kHasQOffset) { + if (!TensorRef_aligned(ref_QOffset, Mma::IteratorQOffset::AccessType::kElements)) { + return Status::kErrorMisalignedOperand; + } + } + + if (!TensorRef_aligned(ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{ + (threadblock_tile_offset.k() * params.gemm_k_size) / 2, + (threadblock_tile_offset.n() * Mma::Shape::kN) / 2 + }; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min( + params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, + params.ref_B.data(), + {problem_size_k/2, params.problem_size.n()/2}, + thread_idx, + tb_offset_B, + params.gather_B_indices); + + int qscale_k = problem_size_k / Mma::QuantBlocking::kRow; + int qscale_n = params.problem_size.n() / Mma::QuantBlocking::kColumn; + if (qscale_k == 0) { + printf("qscale_k is 0! can_implement() should have returned false!\n"); + } + if (qscale_k * Mma::QuantBlocking::kRow != problem_size_k) { + printf("qscale_k * Mma::QuantBlocking::kK != problem_size_k! can_implement() should have returned false!\n"); + } + if (qscale_n == 0) { + printf("qscale_n is 0! can_implement() should have returned false!\n"); + } + if (qscale_n * Mma::QuantBlocking::kColumn != params.problem_size.n()) { + printf("qscale_n * Mma::QuantBlocking::kN != params.problem_size.n()! can_implement() should have returned false!\n"); + } + + cutlass::MatrixCoord tb_offset_QScale{ + threadblock_tile_offset.k() * (params.gemm_k_size/Mma::QuantBlocking::kRow), + threadblock_tile_offset.n() * (Mma::Shape::kN/Mma::QuantBlocking::kColumn) + }; + + typename Mma::IteratorQScale iterator_QScale( + params.params_QScale, + params.ref_QScale.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale, + nullptr); + + typename Mma::IteratorQOffset iterator_QOffset( + params.params_QOffset, + params.ref_QOffset.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx(); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_QScale, iterator_QOffset, accumulators); + } + + // + // Epilogue + // + + OutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + params.ref_C.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + params.ref_D.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h new file mode 100644 index 0000000000000..be03ba60fe15f --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h @@ -0,0 +1,248 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma.h + * @brief Modified from cutlass/gemm/threadblock/default_mma.h. + * Defining global memory data layout and iterators, combinging with mma core and + * pipelined GEMM kernel. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/permute.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h" +#include "cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout for quant scales and offsets + typename LayoutQScale_, + /// Blocking size for quantization + typename QuantBlocking_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute + > +struct DefaultQuantBMma; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale, + /// Element type for quant offsets + typename ElementQOffset, + /// Layout for quant scales and offsets + typename LayoutQScale, + /// Blocking size for quantization + typename QuantBlocking, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout + > +struct DefaultQuantBMma { + + static_assert(platform::is_same::value + || platform::is_same>::value, + "simt epilogue must be row major"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultQuantBMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementQScale, ElementQOffset, LayoutQScale, QuantBlocking, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA, PermuteALayout>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, PermuteBLayout>; + + // Define iterators over tiles from the quant scales + using ThreadMapQScale = typename MmaCore::IteratorThreadMapQScale; + using AccessTypeQScale = + cutlass::Array; + using IteratorQScale = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + typename MmaCore::ThreadblockQShape, + ElementQScale, LayoutQScale, 0, ThreadMapQScale, AccessTypeQScale>; + + using ThreadMapQOffset = typename MmaCore::IteratorThreadMapQOffset; + using AccessTypeQOffset = + cutlass::Array; + using IteratorQOffset = + cutlass::transform::threadblock::OptionalPredicatedTileAccessIterator< + typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQScale, + 0, ThreadMapQOffset, AccessTypeQOffset, MmaCore::kThreads>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::QuantBMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, IteratorQScale, typename MmaCore::SmemIteratorQScale, + cutlass::arch::CacheOperation::Global, IteratorQOffset, + typename MmaCore::SmemIteratorQOffset, cutlass::arch::CacheOperation::Global, + ElementAccumulator, LayoutC, + typename MmaCore::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h new file mode 100644 index 0000000000000..060d2134cfcef --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma_core.h + * @brief Modified from cutlass/gemm/threadblock/default_mma_core.h. + * Defining data layout in shared memory, and its iterators. + */ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" + +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" +#include "cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template defininng default matrix multiply operators inferred from threadblock tile size, +/// global memory data layout, and target math instruction. +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Element data type of quant scale + typename ElementQScale, + /// Element data type of quant offset + typename ElementQOffset, + /// Layout of quant scale + typename LayoutQScale, + /// Blocking dimensions for quantization + typename QuantBlocking, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DefaultQuantBMmaCore; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Element data type of quant scale + typename ElementQScale_, + /// Element data type of quant offset + typename ElementQOffset_, + /// Layout of quant scale + typename LayoutQScale_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultQuantBMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + + using ElementQScale = ElementQScale_; + using ElementQOffset = ElementQOffset_; + using LayoutQScale = LayoutQScale_; + using QuantBlocking = QuantBlocking_; + + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + static int const kWarpThreadArrangementContiguousB = + (Shape::kK / 2) / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK/2>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + IteratorThreadMapB>; + + using SmemLayoutQScale = LayoutQScale; + using SmemLayoutQOffset = LayoutQScale; + + /// Threadblock-level quantization meta data shape + using ThreadblockQShape = MatrixShape; + static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow"); + static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn"); + static_assert(ThreadblockQShape::kCount > 0, "QuantBlocking too big to fit in a thread block!"); + static_assert(QuantBlocking::kRow == 1 || QuantBlocking::kColumn == 1, + "Only support single column or row quantize blocking!"); + static_assert(QuantBlocking::kColumn != 1 || std::is_same::value, + "Quant scale matrix's major dimension must have more elements, to facilitate fast loading!"); + + /// Threadblock-level quantization meta data shape in pitch-linear layout + using TBQPitchLinearShape = typename std::conditional< + std::is_same::value, + layout::PitchLinearShape, + layout::PitchLinearShape>::type; + + /// By default we would like to use 128b load. However, we can't load more than + /// a column at a time in a column major layout. + static int const kElementsPerAccessQScale = + (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous + ? TBQPitchLinearShape::kContiguous + : (kAccessSizeInBits / sizeof_bits::value); + + /// quant scale is tiny. Not all threads are needed. + static int const kAccessCntQScale = ThreadblockQShape::kCount / kElementsPerAccessQScale; + static int const kThreadsQScale = (kAccessCntQScale > kThreads) ? kThreads : kAccessCntQScale; + + using IteratorThreadMapQScale = transform::PitchLinearStripminedThreadMap< + TBQPitchLinearShape, kThreadsQScale, kElementsPerAccessQScale>; + + using SmemIteratorQScale = transform::threadblock::RegularTileAccessIterator< + ThreadblockQShape, ElementQScale, SmemLayoutQScale, 1, IteratorThreadMapQScale>; + + static int const kElementsPerAccessQOffset = + (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous + ? TBQPitchLinearShape::kContiguous + : (kAccessSizeInBits / sizeof_bits::value); + static int const kAccessCntQOffset = ThreadblockQShape::kCount / kElementsPerAccessQOffset; + static int const kThreadsQOffset = (kAccessCntQOffset > kThreads) ? kThreads : kAccessCntQOffset; + + using IteratorThreadMapQOffset = transform::PitchLinearStripminedThreadMap< + TBQPitchLinearShape, kThreadsQOffset, kElementsPerAccessQOffset>; + + using SmemIteratorQOffset = transform::threadblock::OptionalRegularTileAccessIterator< + ThreadblockQShape, ElementQOffset, SmemLayoutQOffset, 1, IteratorThreadMapQOffset, kThreads>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultQuantBMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementQScale, SmemLayoutQScale, ElementQOffset, SmemLayoutQScale, QuantBlocking, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h new file mode 100644 index 0000000000000..6f27a692a3a2e --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h @@ -0,0 +1,314 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file optional_predicated_tile_access_iter.h + * @brief Templates for loading and storing optional tiles of matrix data. + * This iterator is just a wrapper of PredicatedTileAccessIterator, with + * the option to turn it off at compile time and minimize its runtime + * footprint. Also, it utilize the higher numbered threads in the + * threadblock when the iterator can not utilize all the threads. + */ + +#pragma once + +#include + +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + + +//////////////////////////////////////////////////////////////////////////////// + +/// Optional 2-D matrix data loader, when element is std::monostate, the +/// iterator becomes no-op with minimal runtime footprint. Also, it utilize the +/// higher numbered threads in the threadblock when the iterator can not utilize +/// all the threads. +/// +template < + /// Tile shape of the iterator + typename Shape_, + /// Element data type of the iterator, no-op when it is std::monostate + typename Element_, + /// Layout of the source matrix + typename Layout_, + int AdvanceRank_, + typename ThreadMap_, + typename AccessType_, + /// Number of threads in the threadblock, when provided, the iterator + /// will utilize the higher numbered threads + int kThreadBlockSize_ = -1> +class OptionalPredicatedTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + static constexpr int kAdvanceRank = AdvanceRank_; + static constexpr int kThreadblockSize = kThreadBlockSize_; + + static_assert(!std::is_same::value, + "Disabled Iterator failed to match the specialized version below."); + static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + + using Base = PredicatedTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using Mask = typename Base::Mask; + using TensorCoord = typename Base::TensorCoord; + using TensorRef = typename Base::TensorRef; + using Params = typename Base::Params; + using Pointer = typename Base::Pointer; + + static constexpr int kAccessesPerVector = Base::kAccessesPerVector; + + CUTLASS_HOST_DEVICE + static int flip_thread_id(int thread_id){ + if constexpr (kThreadblockSize > 0) { + return kThreadblockSize - 1 - thread_id; + } + return thread_id; + } + + public: + Base base_; + + /// Default constructor + OptionalPredicatedTileAccessIterator(): base_() {}; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : base_(params, pointer, extent, flip_thread_id(thread_id), threadblock_offset) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : OptionalPredicatedTileAccessIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + base_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + base_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + base_.add_tile_offset(tile_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return base_.get(); + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator &operator++() { + ++base_; + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator operator++(int) { + OptionalPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + base_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + base_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + base_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + base_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return base_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for the disabled version +/// Reduce runtime overhead +/// +template < + /// Tile shape of the iterator + typename Shape_, + typename Layout_, + int AdvanceRank_, + typename ThreadMap_, + typename AccessType_, + int kThreadBlockSize_> +class OptionalPredicatedTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = std::monostate; + using Layout = Layout_; + static int const kAdvanceRank = AdvanceRank_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + static constexpr int kThreadblockSize = kThreadBlockSize_; + + using Base = PredicatedTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using Mask = typename Base::Mask; + using TensorCoord = typename Base::TensorCoord; + using TensorRef = typename Base::TensorRef; + using Params = typename Base::Params; + using Pointer = typename Base::Pointer; + + static constexpr int kAccessesPerVector = Base::kAccessesPerVector; + + public: + std::monostate base_; + + /// Default constructor + OptionalPredicatedTileAccessIterator(): base_() {}; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : base_() {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : base_() {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) {} + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return nullptr; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator &operator++() { + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator operator++(int) { + return *this; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) {} + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() {} + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) {} + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) {} + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return false; } +}; + +//////////////////////////////////////////////////////////////////////////////// +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h new file mode 100644 index 0000000000000..4b0ae5317f8bb --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h @@ -0,0 +1,224 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file optional_regular_tile_access_iter.h + * @brief Templates implementing the address computation of storing of tiles + * from pitch-linear rank=2 tensors. + * + * This iterator is just a wrapper of RegularTileAccessIterator, with the + * option to turn it off at compile time and minimize its runtime footprint. + * Also, it utilize the higher numbered threads in the threadblock when the + * iterator can not utilize all the threads. + * + * Must be used in conjunction with OptionalPredicatedTileAccessIterator, + * with the same template parameters. + */ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Optional 2-D tile iterator, when element is std::monostate, the iterator +/// becomes no-op with minimal runtime footprint. Also, it utilize the higher +/// numbered threads in the threadblock when the iterator can not utilize all +/// the threads. +/// +template < + /// Tile shape of the iterator + typename Shape_, + typename Element_, + typename Layout_, + int AdvanceRank, + typename ThreadMap_, + /// Number of threads in the threadblock, when not -1, the iterator + /// will utilize the higher numbered threads + int ThreadblockSize_ = -1, + int Alignment = + sizeof_bits::value * ThreadMap_::kElementsPerAccess / 8> +class OptionalRegularTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + static constexpr int kAlignment = Alignment; + static constexpr int kThreadblockSize = ThreadblockSize_; + + static_assert(!std::is_same::value, + "Disabled Iterator failed to match the specialized template"); + static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + + using Base = RegularTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using TensorRef = typename Base::TensorRef; + using TensorCoord = typename Base::TensorCoord; + using AccessType = typename Base::AccessType; + + CUTLASS_HOST_DEVICE + static int flip_thread_id(int thread_id){ + if constexpr (kThreadblockSize > 0) { + return kThreadblockSize - 1 - thread_id; + } + return thread_id; + } + + private: + + Base base_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : base_(ref, flip_thread_id(thread_id)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + base_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + base_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + return base_.get(); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator &operator++() { + ++base_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + base_.add_tile_offset(coord); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization when Element is std::monostate, the iterator becomes no-op +/// +template < + typename Shape_, + typename Layout_, + int AdvanceRank, + typename ThreadMap_, + int ThreadblockSize_, + int Alignment> +class OptionalRegularTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = std::monostate; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + static constexpr int kAlignment = Alignment; + static constexpr int kThreadblockSize = ThreadblockSize_; + + using Base = RegularTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using TensorRef = typename Base::TensorRef; + using TensorCoord = typename Base::TensorCoord; + using AccessType = typename Base::AccessType; + + private: + + std::monostate base_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : base_() {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) {} + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + return nullptr; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator &operator++() { + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator operator++(int) { + return *this; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) {} +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h new file mode 100644 index 0000000000000..c8aff17151f29 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h @@ -0,0 +1,1290 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_mma_multistage.h + * @brief Modified from cutlass/gemm/threadblock/mma_multistage.h. + * Added the quantized data memory pipeline, dequantization, and feeding + * to tensor cores. Mainloop pipeline is heavily modified. + */ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// +namespace{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Utilities for printing layout for the prepacked weights and quantization parameters +/// +template< + /// Data type of the prepacked weights + typename ElementWeight, + /// Data type of the quant scales + typename ElementQScale, + /// Data type of the quant offsets + typename ElementQOffset> +struct QuantBLayoutDebug{ + static constexpr bool debug_smem = true; + static constexpr bool debug_fragment = true; + ElementWeight* smem_b_ptr_; + ElementQScale* smem_qscale_ptr_; + ElementQOffset* smem_qoffset_ptr_; + int warp_id_; + int lane_id_; + int block_id_; + + template + CUTLASS_DEVICE + static void print_fragment(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + static_assert(Size % 4 == 0, "Size must be multiple of 4"); + if constexpr (debug_fragment){ + if (block_id == 1 && warp_id == 0){ + const Element* ptr = reinterpret_cast(&frag); + for (int i = 0; i < Size/4; i++, ptr+=4){ + if constexpr(std::is_integral::value){ + printf("T%.2d%c%d, %3d, %3d, %3d, %3d\n", + threadIdx.x, label, i, + ptr[0], ptr[1], ptr[2], ptr[3]); + } else { + printf("T%.2d%c%d, %.3f, %.3f, %.3f, %.3f\n", + threadIdx.x, label, i, + float(ptr[0]), float(ptr[1]), float(ptr[2]), float(ptr[3])); + } + } + } + } + } + + template + CUTLASS_DEVICE + static void print_as_int4(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + constexpr int I8Size = Size * cutlass::sizeof_bits::value / 8; + static_assert(I8Size % 2 == 0, "Size must be multiple of 4"); + if constexpr (debug_fragment){ + if (block_id == 1 && warp_id == 0){ + const uint8_t* ptr = reinterpret_cast(&frag); + for (int i = 0; i < I8Size/2; i++, ptr+=2){ + printf("T%.2dW%d, %d, %d, %d, %d\n", threadIdx.x, i, ptr[0] & 0x0f, ptr[0] >> 4, ptr[1] & 0x0f, ptr[1] >> 4); + } + } + } + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dummy type when quant offset is not used, to avoid compilation error, +/// and reduce runtime footprint +/// +struct DummyType{ + std::monostate dummy_; + public: + DummyType() = default; + + CUTLASS_HOST_DEVICE + void* data() const { + return nullptr; + } + + CUTLASS_HOST_DEVICE + std::monostate& operator[](int idx) { + return dummy_; + } +}; + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class QuantBMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + static constexpr bool kHasQOffset = !std::is_same::value; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the prepacked weights + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // Tensor reference to the quantization scales + using TensorRefQScale = TensorRef; + using TensorRefQOffset = TensorRef; + + // Block size of the quantization (one set of quantization parameters per block of weights) + using QuantBlocking = typename Operator::QuantBlocking; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the prepacked weights in shared memory + using ShapeB = + MatrixShape; + + /// Shape of the quantization parameter matrix in shared memory + /// Validation done in mma core class ThreadblockQShape + using ShapeQScale = + MatrixShape<(Shape::kK / QuantBlocking::kRow) * kStages, + Shape::kN / QuantBlocking::kColumn>; + + using BufTypeQOffset = std::conditional_t, + DummyType>; + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for prepacked weights + AlignedBuffer operand_B; + + /// Buffer for quantization scales + AlignedBuffer operand_QScale; + + /// Buffer for quantization offsets + BufTypeQOffset operand_QOffset; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + CUTLASS_HOST_DEVICE + static typename Operator::SmemLayoutQScale LayoutQScale() { + return Operator::SmemLayoutQScale::packed({ShapeQScale::kRow, ShapeQScale::kColumn}); + } + + CUTLASS_HOST_DEVICE + static typename Operator::SmemLayoutQOffset LayoutQOffset() { + return Operator::SmemLayoutQOffset::packed({ShapeQScale::kRow, ShapeQScale::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the prepacked weights + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + + /// Returns a TensorRef to the quantization scales + CUTLASS_HOST_DEVICE + TensorRefQScale operand_QScale_ref() { + return TensorRefQScale{operand_QScale.data(), LayoutQScale()}; + } + + CUTLASS_HOST_DEVICE + TensorRefQOffset operand_QOffset_ref() { + if constexpr (!kHasQOffset){ + return TensorRefQOffset(); + } else { + return TensorRefQOffset{operand_QOffset.data(), LayoutQOffset()}; + } + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + /// Iterator to load a warp-scoped tile of quant scales from shared memory + typename Operator::IteratorQScale warp_tile_iterator_QScale_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + QuantBMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), + warp_tile_iterator_QScale_(shared_storage.operand_QScale_ref(), + shared_storage.operand_QOffset_ref(), lane_idx) + {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over tiles of quant scales in global memory + typename IteratorQScale_, + /// Iterators over tiles of quant scales in shared memory + typename SmemIteratorQScale_, + /// Cache operation for quant scales + cutlass::arch::CacheOperation::Kind CacheOpQScale, + /// Iterators over tiles of quant scales in global memory + typename IteratorQOffset_, + /// Iterators over tiles of quant scales in shared memory + typename SmemIteratorQOffset_, + /// Cache operation for quant scales + cutlass::arch::CacheOperation::Kind CacheOpQOffset, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class QuantBMmaMultistage : + public QuantBMmaBase { +public: + ///< Base class + using Base = QuantBMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using IteratorQScale = IteratorQScale_; + using IteratorQOffset = IteratorQOffset_; + using SmemIteratorQScale = SmemIteratorQScale_; + using SmemIteratorQOffset = SmemIteratorQOffset_; + using QuantBlocking = typename Base::QuantBlocking; + + static cutlass::arch::CacheOperation::Kind const kCacheOpQScale = CacheOpQScale; + static cutlass::arch::CacheOperation::Kind const kCacheOpQOffset = CacheOpQOffset; + static constexpr bool kHasQOffset = Base::kHasQOffset; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of packed weights + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const AsyncCopyIterationsPerStageQScale = + IteratorQScale::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of quant scale + static int const kAccessesPerGroupQScale = + (AsyncCopyIterationsPerStageQScale + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const AsyncCopyIterationsPerStageQOffset = + IteratorQOffset::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of quant offset + static int const kAccessesPerGroupQOffset = + (AsyncCopyIterationsPerStageQOffset + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical + // accuracy, where each mainloop iteration first accumulates into a temporary + // set of freshly-cleared accumulators, which are subsequently added to the + // final accumulator set. + static bool const kStagedAccumulation = arch::UseStagedAccumulation::value; + }; + + private: + + + // Structure encapsulating pipeline state live from one iteration to the next + struct PipeState { + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + /// Temporary accumulator to facilitate staged-accumulation + FragmentC tmp_accum_; + + /// Pair of A fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentA warp_loaded_frag_A_[2]; + + /// Pair of B fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentB warp_loaded_frag_B_; + WarpTransformedFragmentB warp_transformed_frag_B_[2]; + + using WarpLoadedFragmentQScale = typename Operator::FragmentQScale; + WarpLoadedFragmentQScale warp_loaded_frag_QScale_; + + using WarpLoadedFragmentQOffset = typename std::conditional::type; + WarpLoadedFragmentQOffset warp_loaded_frag_QOffset_; + }; + + + private: + + // + // Data members + // + + /// Warp-level MMA operator + Operator warp_mma_; + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of quant meta data to shared memory + SmemIteratorQScale smem_iterator_QScale_; + SmemIteratorQOffset smem_iterator_QOffset_; + + /// Shared memory write stage index + int smem_write_stage_idx_; + + /// Shared memory read stage index + int smem_read_stage_idx_; + + /// very small meta data tensor require less threads to load + bool const should_load_qscale_; + bool const should_load_qoffset_; + + /// Shared memory pointers for debug dumping + static constexpr bool debug_layout = false; + using LayoutDebugType = typename std::conditional, + std::monostate>::type; + LayoutDebugType layout_debug_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + QuantBMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_QScale_(shared_storage.operand_QScale_ref(), thread_idx), + smem_iterator_QOffset_(shared_storage.operand_QOffset_ref(), thread_idx), + should_load_qscale_(thread_idx < IteratorQScale::ThreadMap::kThreads), + should_load_qoffset_(thread_idx >= IteratorQOffset::kThreadblockSize - IteratorQOffset::ThreadMap::kThreads), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + if constexpr(debug_layout){ + layout_debug_.smem_b_ptr_ = shared_storage.operand_B_ref().data(); + layout_debug_.smem_qscale_ptr_ = shared_storage.operand_QScale_ref().data(); + if constexpr(kHasQOffset){ + layout_debug_.smem_qoffset_ptr_ = shared_storage.operand_QOffset_ref().data(); + } else { + layout_debug_.smem_qoffset_ptr_ = nullptr; + } + layout_debug_.warp_id_ = warp_idx; + layout_debug_.lane_id_ = lane_idx; + layout_debug_.block_id_ = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + } + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_QScale_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Advance shared memory read-iterators to the next stage + CUTLASS_DEVICE + void advance_smem_read_stage() + { + ++smem_read_stage_idx_; + + if (smem_read_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + this->warp_tile_iterator_QScale_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + + smem_read_stage_idx_ = 0; + } + } + + /// Advance global memory read-iterators and shared memory write-iterators to the stage + CUTLASS_DEVICE + void advance_smem_write_stage( + IteratorA &iterator_A, + IteratorB &iterator_B, + IteratorQScale &iterator_QScale, + IteratorQOffset &iterator_QOffset) + { + // Advance global iterators + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + iterator_QScale.add_tile_offset({1, 0}); + + // Advance shared iterators + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + smem_iterator_QScale_.add_tile_offset({1, 0}); + + if constexpr (kHasQOffset){ + iterator_QOffset.add_tile_offset({1, 0}); + smem_iterator_QOffset_.add_tile_offset({1, 0}); + } + + // Increment shared memory write stage index + ++smem_write_stage_idx_; + + if (smem_write_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_iterator_QScale_.add_tile_offset({-Base::kStages, 0}); + if constexpr (kHasQOffset){ + smem_iterator_QOffset_.add_tile_offset({-Base::kStages, 0}); + } + smem_write_stage_idx_ = 0; + } + } + + CUTLASS_DEVICE + void copy_qscale_tiles(IteratorQScale &iterator_QScale){ + // Quant scale matrix is 1/block_size of the B matrix, for a 64x64 warp tile, + // it's only 64x64/block_size elements. For blocking size 16 ~ 64, it only + // takes 4 ~ 16 cp.async instructions to load. One warp has 32 threads, so + // it should be loaded in less than one cp.async instruction per thread. + // Even less for quant offset matrix. + static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, + "Quant scale should be loaded in one shot!"); + static_assert(IteratorQScale::kAccessesPerVector == 1, + "Quant scale should 1 access per vector!"); + + // Async Copy for quantization scale + typename IteratorQScale::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QScale_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQScale::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QScale.get(), iterator_QScale.valid()); + } + + CUTLASS_DEVICE + void copy_qoffset_tiles(IteratorQOffset & iterator_QOffset) { + static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, + "Quant offset should be loaded in one shot!"); + static_assert(IteratorQOffset::kAccessesPerVector == 1, + "Quant offset should 1 access per vector!"); + + if constexpr(kHasQOffset){ + // Async Copy for quantization offset + typename IteratorQOffset::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QOffset_.get()); + + constexpr int kSrcBytes = sizeof_bits::value * + IteratorQOffset::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, + int group_start = 0) { + auto group_start_A = group_start * Detail::kAccessesPerGroupA; + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + auto group_start_B = group_start * Detail::kAccessesPerGroupB; + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching + /// the global fragments needed by the first kStages-1 threadblock mainloop iterations + CUTLASS_DEVICE + void prologue( + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Async Copy for quantization scale + static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, "Quant scale should be loaded in one shot!"); + static_assert(IteratorQScale::kAccessesPerVector == 1, "Quant scale should 1 access per vector!"); + + typename IteratorQScale::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QScale_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQScale::ThreadMap::kElementsPerAccess / 8; + + auto gmem_ptr = iterator_QScale.get(); + + cutlass::arch::cp_async( + dst_ptr, gmem_ptr, iterator_QScale.valid()); + + if constexpr (kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + + // Async Copy for quantization offset + static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, "Quant offset should be loaded in one shot!"); + static_assert(IteratorQOffset::kAccessesPerVector == 1, "Quant offset should 1 access per vector!"); + typename IteratorQOffset::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QOffset_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQOffset::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + } + + // Move to the next write stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + + /// Wait until we have at least one completed global fetch stage + CUTLASS_DEVICE + void gmem_wait() + { + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + if constexpr(debug_layout){ + if (LayoutDebugType::debug_smem && layout_debug_.block_id_ == 1){ + if (threadIdx.x == 0){ + printf("stage: %d\n", smem_write_stage_idx_); + } + cutlass::debug::dump_shmem(layout_debug_.smem_qscale_ptr_, Base::SharedStorage::ShapeQScale::kCount); + if constexpr(kHasQOffset){ + cutlass::debug::dump_shmem(layout_debug_.smem_qoffset_ptr_, Base::SharedStorage::ShapeQScale::kCount); + } + } + } + } + + /// Perform a threadblock mainloop iteration of matrix multiply-accumulate + CUTLASS_DEVICE + void mac_loop_iter( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Loading next warp-level tiles from shared memory. This can be skipped on the very + // last iteration where: + // (gemm_k_iterations == (1 - Base::kStages)) && (warp_mma_k == (Base::kWarpGemmIterations - 1)) + // However, evaluating this condition seems more expensive than simply loading the tiles + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // All warp-tiles issue their share of global->shared fragment copies + copy_tiles_and_advance( + iterator_A, + iterator_B, + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + if constexpr(debug_layout){ + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout){ + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_( + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } + + if (warp_mma_k == 0) { + copy_qscale_tiles(iterator_QScale); + } + if (warp_mma_k == 1) { + copy_qoffset_tiles(iterator_QOffset); + } + + // The second-to-last warp-tile also moves to the next global fetch stage + if (warp_mma_k == Base::kWarpGemmIterations - 2) { + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + } + + } + } + + /// Specialized mainloop iteration of matrix multiply-accumulate, for small M + CUTLASS_DEVICE + void mac_loop_iter_small_m( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // In the case of small M, memory latency dominates. We try to move uses far + // from their definitions to hide latency. + if constexpr(debug_layout){ + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout){ + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + + // Loading next warp-level tiles from shared memory. + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // All warp-tiles issue their share of global->shared fragment copies + copy_tiles_and_advance( + iterator_A, + iterator_B, + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_( + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } + + // The second-to-last warp-tile also moves to the next global fetch stage + if (warp_mma_k == Base::kWarpGemmIterations - 2) { + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + copy_qscale_tiles(iterator_QScale); + copy_qoffset_tiles(iterator_QOffset); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + } + + } + } + + + /// Perform the specified number of threadblock mainloop iterations of matrix + /// multiply-accumulate. Assumes prologue has been initiated. + CUTLASS_DEVICE + void gemm_iters( + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC &accum, ///< [in|out] accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over QScale operand in global memory + IteratorQOffset &iterator_QOffset) ///< [in|out] iterator over QOffset operand in global memory + { + PipeState pipe_state; + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + // Load first warp-tile's B fragment from shared memory + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + // Load first warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); + ++this->warp_tile_iterator_A_; + + copy_tiles_and_advance(iterator_A, iterator_B, 0); + + if constexpr(Shape::kM > 32){ + // the case of bigger m + if constexpr(debug_layout){ + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, 0); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[0], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout){ + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[0], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } else { + // the case of small m + copy_qscale_tiles(iterator_QScale); + copy_qoffset_tiles(iterator_QOffset); + } + + if (Detail::kStagedAccumulation) { + pipe_state.tmp_accum_.clear(); + } + + // Mainloop + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + if constexpr(Shape::kM > 32){ + mac_loop_iter( + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); + } else { + mac_loop_iter_small_m( + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); + } + } + + if (Detail::kStagedAccumulation) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + } + + // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } + + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over quant scales in global memory + IteratorQScale iterator_QScale, + ///< Iterator over quant offsets in global memory + IteratorQOffset iterator_QOffset, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // Prologue (start fetching iterations of global fragments into shared memory) + prologue(iterator_A, iterator_B, iterator_QScale, iterator_QOffset, gemm_k_iterations); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Initialize destination accumulators with source accumulators + accum = src_accum; + + // Perform the MAC-iterations + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h new file mode 100644 index 0000000000000..2c49888c94504 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h @@ -0,0 +1,112 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma_tensor_op.h + * @brief Modified from cutlass/gemm/warp/default_mma_tensor_op.h + * Default warp-level GEMM operators selected by data type, size, and layouts of operands. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Data type of quant scales + typename ElementQScale, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale, + /// Data type of quant offsets + typename ElementQOffset, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset, + /// Blocking size of quantization + typename QuantBlocking, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Operator describing the tensor operation + typename Operator_ = arch::OpMultiplyAdd, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false> +struct DefaultQuantBMmaTensorOp { + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::QuantBMmaTensorOp< + WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, QuantBlocking, ElementC, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h new file mode 100644 index 0000000000000..107db414c23bc --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -0,0 +1,787 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file quantb_meta_mma_tensor_op_tile_iterator.h + * @brief Templates for loading quantization meta data for operand B + * from shared memory to fragments. This is meant to be used in + * lock step with the operand B tile iterator. Containing logic + * to figure out the operand B layout in the tensor core, + * and deliver each meta data element to its corresponding + * operand B element for dequantization. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" + +#include "cutlass/platform/platform.h" +#include "cutlass/fast_math.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace{ + +struct b32_pair{ + uint32_t a; + uint32_t b; +}; + +struct fp16_quard{ + cutlass::half_t a; + cutlass::half_t b; + cutlass::half_t c; + cutlass::half_t d; +}; + +struct b16_quard{ + int16_t a; + int16_t b; + int16_t c; + int16_t d; +}; + +union b64 { + uint64_t single; + b32_pair pair; + b16_quard quard; + fp16_quard fp16_quard; +}; + +static_assert(sizeof(b64) == 8, "b64 should be 64 bits"); + +/// Convert packed 4b weights into fp16(weight + 16) +/// Current bit hacking only supports fp16, need to add bf16 later. +/// +template +CUTLASS_DEVICE +void weights2Half(cutlass::Array const &weights, + cutlass::Array& dest) +{ + static_assert(Size % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); + uint32_t* dest_pair = reinterpret_cast(dest.data()); + const uint32_t* w_oct = reinterpret_cast(weights.data()); + + CUTLASS_PRAGMA_UNROLL + for (int oct_idx = 0; oct_idx < Size/8; oct_idx++, w_oct++, dest_pair += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + // static_cast(16 + weight) + // 4b weights are prepacked into [0, 2, 4, 6, 1, 3, 5, 7], so that adjacent weights + // are in different 16b half words, making it easier to convert to fp16. + asm volatile( + "{\n\t" + " shl.b32 %0, %4, 6;\n" + " shl.b32 %1, %4, 2;\n" + " shr.u32 %2, %4, 2;\n" + " shr.u32 %3, %4, 6;\n" + " lop3.b32 %0, %0, 0x03c003c0, 0x4c004c00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 %1, %1, 0x03c003c0, 0x4c004c00, 0xea;\n" + " lop3.b32 %2, %2, 0x03c003c0, 0x4c004c00, 0xea;\n" + " lop3.b32 %3, %3, 0x03c003c0, 0x4c004c00, 0xea;\n" + "}\n" + : "=r"(dest_pair[0]), "=r"(dest_pair[1]), + "=r"(dest_pair[2]), "=r"(dest_pair[3]) + : "r"(*w_oct)); +#else + assert(0); +#endif + } + +} + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +// Traits to describe the layout of quantization meta data layout in a MMA fragment +// Since operand B is quantized on a per block basis, it's one meta data per block. + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTile{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ArchMmaOperator = ArchMmaOperator_; + + static_assert(Threads == 32, "This iterator should work in a warp only."); + + /// Shape of the curresponding operand B tile iterator + using TileShapeB = MatrixShape; + + // Tensor core operand B layout is a column major 4x8 tile, divided + // into 32 threads (T0 ~ T31) as shown below. Each element of the tile is 32b, + // so for fp16 it becomes 8 x 8, and int8 it becomes 16 x 8. + // T0 | T4 | T8 | T12 | T16 | T20 | T24 | T28 + // T1 | T5 | T9 | T13 | T17 | T21 | T25 | T29 + // T2 | T6 | T10 | T14 | T18 | T22 | T26 | T30 + // T3 | T7 | T11 | T15 | T19 | T23 | T27 | T31 + using CoreTile = layout::PitchLinearShape<4, 8>; + + /// Each thread holds a 32b fragement per tile: for half precision, it's 2 elements, 4 elements for int8 + static int const kNumBsPerCoreTileFragement = 32 / sizeof_bits::value; + + /// Each mma instruction can process either 1 or 2 tensor core operand B tiles (stacked on the k dimension) + static int const kBTilesPerMma = + sizeof_bits::value * ArchMmaOperator::FragmentB::kElements / 32; + static_assert(kBTilesPerMma == 1 || kBTilesPerMma == 2, "Only support 1 or 2 operand B tiles per mma."); + + /// Each operand B tile iterator load covers a number of mma instructions + static int const kMmaIterationsB = WarpShapeB::kColumn / ArchMmaOperator::Shape::kN; + + /// Number of B elements a fragment of meta data should cover + static int const kExpandedSize = kNumBsPerCoreTileFragement * kBTilesPerMma * kMmaIterationsB; + + // Now we figure out how many meta data elements to load for each TileShapeB + + /// Number of meta elements per CoreTile. + static int const kCoreTileFragementSize = (kNumBsPerCoreTileFragement + BlockingShape::kRow - 1) / BlockingShape::kRow; + + /// Number of core tiles per mma instruction, different from kBTilesPerMma when blocking size on K dimension + /// exceeds the tile depth, so two tiles share the same meta data + static int const kTilesPerMma = ((kBTilesPerMma == 2) && + (BlockingShape::kRow <= kNumBsPerCoreTileFragement * CoreTile::kContiguous)) + ? 2 : 1; + + /// stride to reach the meta data for the next CoreTile on the K dimension + static int const kKTileStride = (kNumBsPerCoreTileFragement * CoreTile::kContiguous + BlockingShape::kRow - 1) / BlockingShape::kRow; + + /// Stride on N dimention should be the tile width, shrunk by blocking size on this dimension. + static int const kNStride = (CoreTile::kStrided + BlockingShape::kColumn - 1) / BlockingShape::kColumn; + + /// On N dimension, how many tiles share the same meta data + static int const kNRepeats = (BlockingShape::kColumn + CoreTile::kStrided - 1) / CoreTile::kStrided; + + /// Each fragement should cover kMmaIterationsB number of mma intructions on the N dimension. + /// When blocking size on this dimension exceeds the tile width, multiple iterations + /// would share the same data. + static int const kMmaIterations = (kMmaIterationsB + kNRepeats - 1) / kNRepeats; + + static int const kFragementSize = kCoreTileFragementSize * kTilesPerMma * kMmaIterations; + + CUTLASS_DEVICE + static MatrixCoord lane_position(int lane_id) { + if constexpr(kNumBsPerCoreTileFragement == 2 + && kBTilesPerMma == 2 + && BlockingShape::kRow == 1){ + // Optimize for a special case of: + // 16b gemm (kNumBsPerCoreTileFragement == 2) + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking + // The weight and offset tensor is prepacked to reduce load instructions. + return make_Coord((lane_id % CoreTile::kContiguous) * 4, + lane_id / CoreTile::kContiguous); + } else { + return make_Coord((lane_id % CoreTile::kContiguous) * kNumBsPerCoreTileFragement, + lane_id / CoreTile::kContiguous); + } + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is to load quantization meta data for operand B from +/// shared memory to fragments (hopefully allocated to registers by compilers). +/// Examples of meta data include scale or offsets. The operand B matrix is +/// quantized on a per block basis, meaning one element of meta data per block. +/// +/// This is meant to be used in lock step with the operand B tile iterator. +/// So all parameters are logical positions in the operand B tiles. +/// The goal here is to deliver each meta data element to its corresponding +/// operand B element for dequantization. As a result, we need to figure +/// out the operand B layout in the tensor core. +/// +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the quant scales + typename ElementScale_, + /// Layout of the quant scales + typename LayoutScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Layout of quant offsets + typename LayoutOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads, + /// Number of partitions along K dimension + int PartitionsK_ = 1> +class QuantBMetaMmaTensorOpTileIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column major layout + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTensorOpTileIterator{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ElementScale = ElementScale_; + using Layout = cutlass::layout::ColumnMajor; + using ElementOffset = ElementOffset_; + using ArchMmaOperator = ArchMmaOperator_; + + static constexpr bool kHasOffset = !(std::is_same::value); + + using MetaTile = QuantBMetaMmaTile; + + /// Number of MMA instructions for this tile + static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB; + + /// Number of B elements per mma tile fragment (32b), 2 for half precision, 4 for int8 + static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement; + + /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension) + static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma; + + /// Number of B elements a fragment of meta data should cover + static constexpr int kExpandedSize = MetaTile::kExpandedSize; + + /// Number of meta elements per core tile fragment + static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize; + + /// stride for reaching the next core tile (if there is one) on the K dimension + static constexpr int kKTileStride = MetaTile::kKTileStride; + + /// do we need to load meta data for the next core tile on the K dimension? + static constexpr int kTilesPerMma = MetaTile::kTilesPerMma; + + static constexpr int kNStride = MetaTile::kNStride; + static constexpr int kNRepeats = MetaTile::kNRepeats; + static constexpr int kMmaIterations = MetaTile::kMmaIterations; + + using TensorRefScale = TensorRef; + using TensorRefOffset = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using FragmentScale = Array; + using FragmentOffset = typename std::conditional, + std::monostate>::type; + + using AccessTypeScale = Array; + using AccessTypeOffset = Array; + +private: + + ElementScale *pointer_; + Layout layout_; + + ElementOffset *pointer_offset_; + Layout layout_offset_; + + TensorCoord lane_position_; + +public: + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator() { } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator( + TensorRefScale const &ref, + TensorRefOffset const &ref_offset, + int lane_idx + ): + pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)){} + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(FragmentScale &frag, FragmentOffset &frag_offset) { + if constexpr(kNumBsPerCoreTileFragement == 2 + && kBTilesPerMma == 2 + && BlockingShape::kRow == 1){ + // Optimize for a special case of: + // 16b gemm (kNumBsPerCoreTileFragement == 2) + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking + // The weight and offset tensor is prepacked to reduce load instructions. + const int row = lane_position_.row(); + const int column = lane_position_.column() / BlockingShape::kColumn; + + Array *dst_ptr = reinterpret_cast*>(frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + Array *src_ptr = reinterpret_cast*>(pointer_ + layout_({row, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + + if constexpr(kHasOffset){ + Array *dst_ptr_offset = reinterpret_cast*>(frag_offset.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + Array *src_ptr_offset = reinterpret_cast*>(pointer_offset_ + layout_offset_({row, c})); + *dst_ptr_offset = *src_ptr_offset; + dst_ptr_offset++; + } + } + + } else { + // Other cases, offsets and scales are not prepacked. + + const int row = lane_position_.row() / BlockingShape::kRow; + const int column = lane_position_.column() / BlockingShape::kColumn; + + AccessTypeScale* dst_ptr = reinterpret_cast(frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + AccessTypeScale* src_ptr = reinterpret_cast(pointer_ + layout_({r, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + } + + if constexpr(kHasOffset){ + AccessTypeOffset* dst_ptr = reinterpret_cast(frag_offset.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + AccessTypeOffset* src_ptr = reinterpret_cast(pointer_offset_ + layout_offset_({r, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + } + } + } + } + + template + CUTLASS_HOST_DEVICE + static Array debug_expand(Array const &frag){ + Array ret; + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; + ret[out_idx] = frag[idx]; + out_idx++; + } + } + } + return ret; + } + + CUTLASS_HOST_DEVICE + static void dequant(FragmentScale const &scales, + FragmentOffset const &offsets, + Array const &weights, + Array& dest){ + static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm."); + static_assert(kExpandedSize % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); + + // First convert 4b weight into fp16(weight + 16) + weights2Half(weights, dest); + + if constexpr(kBTilesPerMma == 2 + && BlockingShape::kRow == 1){ + // Optimize for a special case of: + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking + + uint32_t* dest_pair = reinterpret_cast(dest.data()); + const b64* scales_ptr = reinterpret_cast(scales.data()); + const ElementOffset* offsets_ptr = nullptr; + if constexpr(kHasOffset) { offsets_ptr = offsets.data(); } + + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + // dequantize: d = scale * (weight - offset) + // to use FMA, d = scale * weight + (scale * (-offset)) + + b64 offsets; + if constexpr(kHasOffset){ + const uint32_t* p = reinterpret_cast(offsets_ptr); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, %4, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, %4, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // dest = scale * (16 + weight) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(offsets.pair.a), "=r"(offsets.pair.b) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(p[0])); +#else + assert(0); +#endif + + offsets_ptr += 4; + } else { + offsets.fp16_quard.a = scales_ptr->fp16_quard.a * static_cast(-16-8); + offsets.fp16_quard.b = scales_ptr->fp16_quard.b * static_cast(-16-8); + offsets.fp16_quard.c = scales_ptr->fp16_quard.c * static_cast(-16-8); + offsets.fp16_quard.d = scales_ptr->fp16_quard.d * static_cast(-16-8); + } + + CUTLASS_PRAGMA_UNROLL + for (int n_r = 0; n_r < kNRepeats; n_r++){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " fma.rn.f16x2 %0, %2, %0, %4;\n" // dest = scale * (16 + weight) + (scale * (-16 - offset)) + " fma.rn.f16x2 %1, %3, %1, %5;\n" + "}\n" + : "+r"(dest_pair[0]), "+r"(dest_pair[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(offsets.pair.a), "r"(offsets.pair.b)); +#else + assert(0); +#endif + dest_pair += 2; + } + scales_ptr++; + } + + } else { + // unoptiomized path for other cases, very slow + int out_idx = 0; + ElementScale offset; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; + ElementScale s = scales[idx]; + if constexpr(kHasOffset){ + offset = s * static_cast(-16 - int(offsets[idx])); + } else { + offset = s * static_cast(-16-8); + } + dest[out_idx] = s * dest[out_idx] + offset; + out_idx++; + } + } + } + + } + + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator++() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator--() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; + int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn; + lane_position_ += TensorCoord(rows, columns); + return *this; + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row major layout + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTensorOpTileIterator{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ElementScale = ElementScale_; + using ElementOffset = ElementOffset_; + using Layout = cutlass::layout::RowMajor; + using ArchMmaOperator = ArchMmaOperator_; + + static constexpr bool kHasOffset = !(std::is_same::value); + + static_assert(BlockingShape::kColumn == 1 && BlockingShape::kRow > 1, + "Only support column blocking for row major layout"); + + using MetaTile = QuantBMetaMmaTile; + + /// Number of MMA instructions for this tile + static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB; + + /// Number of B elements per mma tile fragment (32b), 2 for half precision, 4 for int8 + static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement; + + /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension) + static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma; + + /// Number of B elements a fragment of meta data should cover + static constexpr int kExpandedSize = MetaTile::kExpandedSize; + + /// Number of meta elements per core tile fragment + static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize; + + /// stride for reaching the next core tile (if there is one) on the K dimension + static constexpr int kKTileStride = MetaTile::kKTileStride; + + /// do we need to load meta data for the next core tile on the K dimension? + static constexpr int kTilesPerMma = MetaTile::kTilesPerMma; + + static constexpr int kNStride = MetaTile::kNStride; + static constexpr int kNRepeats = MetaTile::kNRepeats; + static constexpr int kMmaIterations = MetaTile::kMmaIterations; + + using TensorRefScale = TensorRef; + using TensorRefOffset = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using FragmentScale = Array; + using FragmentOffset = typename std::conditional, + std::monostate>::type; + +private: + + ElementScale *pointer_; + Layout layout_; + + ElementOffset *pointer_offset_; + Layout layout_offset_; + + TensorCoord lane_position_; + +public: + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator() { } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator( + TensorRefScale const &ref, + TensorRefOffset const &ref_offset, + int lane_idx + ): + pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)) + {} + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(FragmentScale &frag, FragmentOffset &frag_offset) { + const int row = lane_position_.row() / BlockingShape::kRow; + const int column = lane_position_.column() / BlockingShape::kColumn; + static_assert(kTilesPerMma * kCoreTileFragementSize == 1, "Only support one meta data per core tile"); + + ElementScale* src_ptr = pointer_ + layout_({row, column}); + ElementScale* dst_ptr = frag.data(); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + dst_ptr[n_idx] = src_ptr[n_idx * kNStride]; + } + + if constexpr(kHasOffset){ + ElementOffset* src_ptr_offset = pointer_offset_ + layout_offset_({row, column}); + ElementOffset* dst_ptr_offset = frag_offset.data(); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + dst_ptr_offset[n_idx] = src_ptr_offset[n_idx * kNStride]; + } + } + } + + template + CUTLASS_HOST_DEVICE + static Array debug_expand(Array const &frag){ + Array ret; + + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int col = elem_idx + mma_tile_idx * kCoreTileFragementSize; + int idx = col * kMmaIterations + n_idx; + ret[out_idx] = frag[idx]; + out_idx++; + } + } + } + return ret; + } + + CUTLASS_HOST_DEVICE + static void dequant(FragmentScale const &scales, + FragmentOffset const &offsets, + Array const &weights, + Array& dest){ + static_assert(kNRepeats == 1, "This is implied by BlockingShape::kColumn == 1"); + static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm now."); + + // First convert 4b weight into fp16(weight + 16) + weights2Half(weights, dest); + + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + ElementScale s = scales[n_out]; + ElementScale offset; + if constexpr(kHasOffset){ + offset = s * static_cast(-16 - int(offsets[n_out])); + } else { + offset = s * static_cast(-16-8); + } + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + dest[out_idx] = s * dest[out_idx] + offset; + dest[out_idx + 1] = s * dest[out_idx + 1] + offset; + out_idx += 2; + } + } + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator++() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator--() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; + int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn; + lane_position_ += TensorCoord(rows, columns); + return *this; + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// +} // namespace warp +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h new file mode 100644 index 0000000000000..a88be45952857 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h @@ -0,0 +1,436 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_mma_tensor_op.h + * @brief Modified from cutlass/gemm/warp/mma_tensor_op.h + * Templates implementing warp-level matrix multiply-accumulate operations + * targeting tensor cores. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace internal { + +template +struct ConvertAndPack { + + using Converter = NumericArrayConverter; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &source) { + Converter converter; + + return converter(source); + } +}; + +template +struct ConvertAndPack { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &source) { + return source; + } +}; + +template +struct ConvertAndPack { + + using Converter = NumericArrayConverter; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &source) { + Converter converter; + + Array tmp; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); + tmp[i] = source[idx]; + } + + return converter(tmp); + } +}; + +template +struct ConvertAndPack { + + using Converter = NumericArrayConverter; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &source) { + Converter converter; + + Array tmp; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); + tmp[i] = source[idx]; + } + + return converter(tmp); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace internal + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Data type of quant scales + typename ElementQScale_, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale_, + /// Data type of quant offsets + typename ElementQOffset_, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset_, + /// Blocking dimensions of quantization + typename QuantBlocking_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool +> +class QuantBMmaTensorOp { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + +public: + + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kA, ElementA, LayoutA, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kB, ElementB, LayoutB, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + // warp B MatrixShape<64, 64>, + // layout B cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<16, 64>, + // instruction op shape cutlass::MatrixShape<16, 8>, + // kPartitionsK 1 + // FragmentB::kElements 32 + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; // cutlass::Array + + /// Storage for transformed B tile + /// When loading weights, we packed 4 int4 weights into one 2-byte-element, when expanded + /// we multiply the number of elements by 4. + /// TODO: make sure ArchMmaOperator::ElementB same as dequantized ElementB + /// and change the transform function below to perform dequantization + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator< + MatrixShape, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + using ElementQScale = ElementQScale_; + using SmemLayoutQScale = SmemLayoutQScale_; + using QuantBlocking = QuantBlocking_; + + using ElementQOffset = ElementQOffset_; + using SmemLayoutQOffset = SmemLayoutQOffset_; + + /// Iterates over the quantization parameters in memory + using WarpQScaleShape = MatrixShape<(Shape::kK / QuantBlocking::kRow), (Shape::kN / QuantBlocking::kColumn)>; + static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow"); + static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn"); + static_assert(WarpQScaleShape::kCount > 0, "QuantBlocking too big to fit in a warp block!"); + + // TODO This is an expanding iterator, it needs to replicate the quantization parameters + // to all threads in the warp. + using IteratorQScale = QuantBMetaMmaTensorOpTileIterator< + MatrixShape, QuantBlocking, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, + ArchMmaOperator, kThreadCount, kPartitionsK>; + + using FragmentQScale = typename IteratorQScale::FragmentScale; + using FragmentQOffset = typename IteratorQScale::FragmentOffset; + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN + >; + +public: + + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + QuantBMmaTensorOp() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + TransformedFragmentA const &A, + TransformedFragmentB const &B, + FragmentC const &C + ) const { + + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + D = C; + + MmaOperandA const *ptr_A = reinterpret_cast(&A); + MmaOperandB const *ptr_B = reinterpret_cast(&B); + MmaOperandC *ptr_D = reinterpret_cast(&D); + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + // The visitation order is like + // _ + // | | | | + // | | | | + // |_| |_| + // + // Down Up Down Up + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma( + ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } + #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + // The visitation order is like + // _________ + // _________| + // |_________ + // __________| + // + // Right Left Right Left + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } + #else + assert(0); + #endif + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentB &dst_B, + FragmentB const &B, + FragmentQScale const &scales, + FragmentQOffset const &offsets) const { + + Array const *ptr_B = + reinterpret_cast const *>(&B); + IteratorQScale::dequant(scales, offsets, *ptr_B, dst_B); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h index a0405e32034ae..783a29d8a2055 100644 --- a/onnxruntime/core/util/matrix_layout.h +++ b/onnxruntime/core/util/matrix_layout.h @@ -17,7 +17,6 @@ #include #include "core/common/gsl.h" -// TODO!! Already have this in cuda, what about cpu code though? #if defined(_MSC_VER) #define ORT_FORCEINLINE __forceinline #else diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h new file mode 100644 index 0000000000000..42a118a89ff54 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h @@ -0,0 +1,204 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_gemm_sm80.h + * + * Abstract: + * Bridge between gtest code and gemm kernel implementation. + * Gemm kernel requires CUTLASS header files, which causes strange + * compilation errors with RE2 header files, which are required + * by gtest. + */ + +#pragma once + +#include "core/util/matrix_layout.h" +#include "core/common/common.h" + +namespace onnxruntime { +namespace cuda { +namespace test { + +Status sm80_supported(); + +static inline void prepack_weights_ref( + int rows, + int columns, + const MatrixRef& tensor_weight, + const MatrixRef& tensor_weight_prepacked) { + ORT_ENFORCE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns, + "Unexpected tensor_weight shape! Expected: (", rows / 2, ", ", columns, "), Got: (", + tensor_weight.shape()[0], ", ", tensor_weight.shape()[1], ")."); + ORT_ENFORCE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2, + "tensor_weight_prepacked shape is not compatible with prepacked weight shape"); + + auto t0_base = make_Position(0, 0); + auto t1_base = make_Position(4, 0); + auto t2_base = make_Position(0, 8); + auto t3_base = make_Position(4, 8); + for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { + for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { + // Packing from a 8x16 tile to a 16x8 tile + auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); + auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); + for (int col = 0; col < 8; ++col) { + for (int row = 0; row < 4; ++row) { + auto cord = make_Position(row, col); + auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 + uint8_t buf[4]; + buf[0] = tensor_weight.at(dtile_base + t0_base + cord); + buf[1] = tensor_weight.at(dtile_base + t1_base + cord); + buf[2] = tensor_weight.at(dtile_base + t2_base + cord); + buf[3] = tensor_weight.at(dtile_base + t3_base + cord); + + // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights + // are in different b16 register at the same positions. This makes it easier to convert to + // fp16x2 format in a b32 register + + tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); + tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); + } + } + } + } +} + +template < + typename ScaleElementT, + typename Layout, + typename QuantBlocking> +void prepack_quant_scales_ref( + int rows, + int columns, + const MatrixRef& tensor_scale, + const MatrixRef& tensor_scale_prepacked) { + ORT_ENFORCE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn), + "Unexpected tensor_scale shape! Expected: (", + rows / QuantBlocking::kRow, ", ", columns / QuantBlocking::kColumn, ")"); + ORT_ENFORCE(tensor_scale_prepacked.shape() == tensor_scale.shape()); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (sizeof(ScaleElementT) == 2 && QuantBlocking::kRow == 1) { + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two seperate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); + tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); + tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); + tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); + } + } + } + } else { + // In all other cases, we don't prepack scale or offset + std::copy(tensor_scale.data().begin(), tensor_scale.data().end(), tensor_scale_prepacked.data().begin()); + } +} + +template +void prepack_quant_offsets_ref( + size_t rows, + size_t columns, + MatrixRef tensor_offset, + MatrixRef tensor_offset_prepacked) { + ORT_ENFORCE(tensor_offset_prepacked.shape() == tensor_offset.shape()); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (QuantBlocking::kRow != 1) { + std::copy(tensor_offset.data().begin(), tensor_offset.data().end(), tensor_offset_prepacked.data().begin()); + return; + } + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two seperate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + if (tensor_offset_prepacked.good()) { + for (int col = 0; col < tensor_offset.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_offset.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own + // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to + // convert to fp16x2 format in a b32 register + tensor_offset_prepacked.at(dst_idx + 0, col) = tensor_offset.at(src_idx + 0, col); + tensor_offset_prepacked.at(dst_idx + 1, col) = tensor_offset.at(src_idx + 8, col); + tensor_offset_prepacked.at(dst_idx + 2, col) = tensor_offset.at(src_idx + 1, col); + tensor_offset_prepacked.at(dst_idx + 3, col) = tensor_offset.at(src_idx + 9, col); + } + } + } + } +} + +template < + int block_size, + bool column_wise_blocking, + bool small_m, + bool has_offsets> +void run_blkq4_gemm(int m, int n, int k); + +} // namespace test +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc similarity index 61% rename from onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc rename to onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index aba2b0b2cb4a4..ff6c38bb56d32 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -1,181 +1,29 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_gemm_sm80_test.cc + * + * Abstract: + * Test code for block-wise quantized 4b GEMM kernels. + * This part requires gtest header files, which do not play + * well with CUTLASS headers. + */ #include #include "core/framework/float16.h" -#include "core/mickey/blk_q4/prepack_sm80.h" +#include "core/mickey/blk_q4/f16_prepack_sm80.h" #include "core/mlas/inc/mlas_q4.h" +#include "blkq4_fp16_gemm_sm80.h" + #include "gtest/gtest.h" namespace onnxruntime { namespace test { -void prepack_weights_ref( - int rows, - int columns, - const MatrixRef& tensor_weight, - const MatrixRef& tensor_weight_prepacked) { - EXPECT_TRUE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns); - EXPECT_TRUE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2); - - auto t0_base = make_Position(0, 0); - auto t1_base = make_Position(4, 0); - auto t2_base = make_Position(0, 8); - auto t3_base = make_Position(4, 8); - for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { - for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { - // Packing from a 8x16 tile to a 16x8 tile - auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); - auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); - for (int col = 0; col < 8; ++col) { - for (int row = 0; row < 4; ++row) { - auto cord = make_Position(row, col); - auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 - uint8_t buf[4]; - buf[0] = tensor_weight.at(dtile_base + t0_base + cord); - buf[1] = tensor_weight.at(dtile_base + t1_base + cord); - buf[2] = tensor_weight.at(dtile_base + t2_base + cord); - buf[3] = tensor_weight.at(dtile_base + t3_base + cord); - - // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights - // are in different b16 register at the same positions. This makes it easier to convert to - // fp16x2 format in a b32 register - - tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); - tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); - tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); - tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); - } - } - } - } -} - -template < - typename ScaleElementT, - typename Layout, - typename QuantBlocking> -void prepack_quant_scales_ref( - int rows, - int columns, - const MatrixRef& tensor_scale, - const MatrixRef& tensor_scale_prepacked) { - EXPECT_TRUE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn)); - EXPECT_TRUE(tensor_scale_prepacked.shape() == tensor_scale.shape()); - - // Only prepacking scale and offset tensors for a often used special case: - // 16b gemm (2 elements per 32b register, operand tile shape 8x8) - // 2 B operand tiles per mma instruction stacked on k dimension - // (1,n) quantization blocking - if constexpr (sizeof(ScaleElementT) == 2 && QuantBlocking::kRow == 1) { - // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread - // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use - // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, - // as shown below (T stands for thread): - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // - // We need to deliver quantization scale and offset elements to the corresponding threads, - // so we can perform dequantization efficiently. With a column major layout, each thread - // needs two separate loads for a mma instruction, due to the tile fragment layout shown - // above. To reduce the number of loads, we rearrange each column as below, so we can use - // a single load to load fragments for two tiles: - // T0 T0 - // T1 T0 - // T2 T1 - // T3 => T1 - // T0 T2 - // T1 T2 - // T2 T3 - // T3 T3 - - for (int col = 0; col < tensor_scale.shape()[1]; ++col) { - for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { - for (int thread_id = 0; thread_id < 4; thread_id++) { - const int dst_idx = row_blk + thread_id * 4; - const int src_idx = row_blk + thread_id * 2; - tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); - tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); - tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); - tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); - } - } - } - } else { - // In all other cases, we don't prepack scale or offset - FAIL() << "Scale prepack only supported for 16b gemm with (1,n) quantization blocking"; - } -} - -template -void prepack_quant_offsets_ref( - size_t rows, - size_t columns, - MatrixRef tensor_offset, - MatrixRef tensor_offset_prepacked) { - // EXPECT_TRUE(tensor_offset.shape()[0] == (rows / QuantBlocking::kRow) && tensor_offset.shape()[1] == (columns / QuantBlocking::kColumn)); - EXPECT_TRUE(tensor_offset_prepacked.shape() == tensor_offset.shape()); - - // Only prepacking scale and offset tensors for a often used special case: - // 16b gemm (2 elements per 32b register, operand tile shape 8x8) - // 2 B operand tiles per mma instruction stacked on k dimension - // (1,n) quantization blocking - if constexpr (QuantBlocking::kRow != 1) { - FAIL() << "Offsets prepack only supported for 16b gemm with (1,n) quantization blocking"; - } - // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread - // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use - // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, - // as shown below (T stands for thread): - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // - // We need to deliver quantization scale and offset elements to the corresponding threads, - // so we can perform dequantization efficiently. With a column major layout, each thread - // needs two separate loads for a mma instruction, due to the tile fragment layout shown - // above. To reduce the number of loads, we rearrange each column as below, so we can use - // a single load to load fragments for two tiles: - // T0 T0 - // T1 T0 - // T2 T1 - // T3 => T1 - // T0 T2 - // T1 T2 - // T2 T3 - // T3 T3 - if (tensor_offset_prepacked.good()) { - for (int col = 0; col < tensor_offset.shape()[1]; ++col) { - for (int row_blk = 0; row_blk < tensor_offset.shape()[0]; row_blk += 16) { - for (int thread_id = 0; thread_id < 4; thread_id++) { - const int dst_idx = row_blk + thread_id * 4; - const int src_idx = row_blk + thread_id * 2; - // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own - // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to - // convert to fp16x2 format in a b32 register - tensor_offset_prepacked.at(dst_idx + 0, col) = tensor_offset.at(src_idx + 0, col); - tensor_offset_prepacked.at(dst_idx + 1, col) = tensor_offset.at(src_idx + 8, col); - tensor_offset_prepacked.at(dst_idx + 2, col) = tensor_offset.at(src_idx + 1, col); - tensor_offset_prepacked.at(dst_idx + 3, col) = tensor_offset.at(src_idx + 9, col); - } - } - } - } -} - template void testPrepack(int rows, int columns, bool has_offset = true) { using ElementT = MLFloat16; @@ -407,7 +255,7 @@ void testPrepack(int rows, int columns, bool has_offset = true) { std::vector packed_w_ref(q_weight_shape.product()); MatrixRef tensor_packed_w_ref( packed_w_ref, make_Position(rows, columns / 2)); - prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref); + onnxruntime::cuda::test::prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref); std::vector packed_w(q_weight_shape.product()); MatrixRef tensor_packed_w( @@ -429,7 +277,7 @@ void testPrepack(int rows, int columns, bool has_offset = true) { Base::ShouldRearrangeMeta ? make_MatrixRef(packed_scales_ref, meta_shape) : tensor_scale; if (Base::ShouldRearrangeMeta) { - prepack_quant_scales_ref( + onnxruntime::cuda::test::prepack_quant_scales_ref( rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref); } @@ -454,7 +302,7 @@ void testPrepack(int rows, int columns, bool has_offset = true) { Base::ShouldRearrangeMeta ? make_MatrixRef(packed_zp_ref, meta_shape) : tensor_offset; if (Base::ShouldRearrangeMeta) { - prepack_quant_offsets_ref( + onnxruntime::cuda::test::prepack_quant_offsets_ref( rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); } @@ -477,6 +325,12 @@ void testPrepack(int rows, int columns, bool has_offset = true) { // TODO: code runs on CPU, but this is for sm80 only, maybe enable only when test on sm80 TEST(BlkQ4_GEMM, PrepackSm80Test) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + testPrepack(32, 32); testPrepack(32, 32, false); testPrepack(32, 32); @@ -503,5 +357,53 @@ TEST(BlkQ4_GEMM, PrepackSm80Test) { testPrepack(256, 256, false); } +TEST(BlkQ4_GEMM, Sm80Test) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 32, 64); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 32, 64); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 96, 64); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 96, 64); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 96, 192); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 96, 192); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(256, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(256, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(512, 2048 + 32, 960); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(512, 2048 + 32, 960); + + onnxruntime::cuda::test::run_blkq4_gemm<16, false, false, false>(256, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, false, false, true>(256, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, false>(256, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, true>(256, 1024, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, false>(256, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, true>(256, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, false>(256, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, true>(256, 1024, 576); + + // small m + onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, false>(16, 704, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, true>(16, 704, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, false>(16, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, true>(16, 1024, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<16, true, true, false>(16, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, true, true, true>(16, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, true, true, false>(16, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, true, true, true>(16, 1024, 576); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu new file mode 100644 index 0000000000000..6dcdad67e9511 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -0,0 +1,489 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_gemm_sm80_testcu.cu + * + * Abstract: + * Test code for invoking block-wise quantized 4b GEMM kernels. + * This part requires CUTLASS header files, which do not play + * well with gtest headers. + */ + +#include "core/mickey/blk_q4/f16_gemm_sm80.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "core/common/common.h" + +#include "blkq4_fp16_gemm_sm80.h" + +namespace onnxruntime { +namespace cuda{ +namespace test{ + +Status sm80_supported(){ + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::ostringstream ss; + ss << "Unable to obtain GPU device properties: " << cudaGetErrorString(error); + return Status(common::ONNXRUNTIME, common::ENGINE_ERROR, ss.str()); + } + + if (!((props.major * 10 + props.minor) >= 80)) { + std::ostringstream ss; + ss << "Device compute capability mismatch, desired 8.0, actual " << props.major << "." << props.minor; + return Status(common::ONNXRUNTIME, common::ENGINE_ERROR, ss.str()); + } + return Status::OK(); +} + +/** + * @brief Reference implementation of GEMM + * Copied directly from cutlass util/reference/device/gemm.h + * for the strange reason that compiler insists on asking + * for explicit stream argument in kernel launch. +*/ +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType +> +void compute_gemm_ref( + cutlass::gemm::GemmCoord problem_size, + ScalarType alpha, + cutlass::TensorRef tensor_a, + cutlass::TensorRef tensor_b, + ScalarType beta, + cutlass::TensorRef tensor_c, + cutlass::TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + + // Blocking structure potentially improves performance of reference implementation + // with a minor increase in complexity. + // + // Note, this reference implementation is NOT expected to approach peak performance. + using OutputTile = cutlass::MatrixShape<4, 4>; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) + ); + + // Launch a GEMM kernel + cutlass::reference::device::kernel::Gemm< + cutlass::TensorRef, + cutlass::TensorRef, + cutlass::TensorRef, + ScalarType, + AccumulatorType, + OutputTile, + cutlass::multiply_add, + cutlass::NumericConverter + ><<>>( + problem_size, + alpha, + tensor_a, + tensor_b, + beta, + tensor_c, + tensor_d, + initial_accum + ); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Converting cutlass tensor to MatrixRef +// + +template < + typename Element, + typename LayoutCutlass, + typename Layout = std::conditional_t::value, ColumnMajorLayout, RowMajorLayout> + > +__forceinline__ +MatrixRef make_MatrixRef(cutlass::HostTensor const& tensor) { + static_assert(std::is_same::value + || std::is_same::value); + auto shape = make_Position(tensor.extent().row(), tensor.extent().column()); + auto* ptr = const_cast::type *>(tensor.host_data()); + return MatrixRef(ptr, tensor.capacity(), shape); +} + +template < + typename Element, + typename LayoutCutlass, + typename Layout = std::conditional_t::value, ColumnMajorLayout, RowMajorLayout> + > +__forceinline__ +MatrixRef make_ConstMatrixRef(cutlass::HostTensor const& tensor) { + static_assert(std::is_same::value + || std::is_same::value); + auto shape = make_Position(tensor.extent().row(), tensor.extent().column()); + return MatrixRef(tensor.host_data(), tensor.capacity(), shape); +} + +// +// Invoking the kernel +// + +template< + int block_size, + bool column_wise_blocking, + bool small_m, + bool has_offsets> +void run_blkq4_gemm(int m, int n, int k) { + + using ElementDequant = cutlass::half_t; + using QuantBlocking = + typename std::conditional, + cutlass::MatrixShape<1, block_size>>::type; + + using GemmRunner = BlkQ4F16GemmImpl; + + using ElementAccumulator = typename GemmRunner::ElementAccumulator; + using ElementComputeEpilogue = typename GemmRunner::ElementComputeEpilogue; + using ElementInputA = typename GemmRunner::ElementInputA; + using ElementOutput = typename GemmRunner::ElementOutput; + using ElementW = typename GemmRunner::ElementW; + using ElementWPack = typename GemmRunner::ElementWPack; + using ElementQScale = typename GemmRunner::ElementQScale; + using ElementQOffset = typename GemmRunner::ElementQOffset; + + using LayoutInputA = typename GemmRunner::LayoutInputA; + using LayoutOutput = typename GemmRunner::LayoutOutput; + using LayoutInputWPack = typename GemmRunner::LayoutInputWPack; + using LayoutInputQScale = typename GemmRunner::LayoutInputQScale; + + const cutlass::gemm::GemmCoord problem_size = {m, n, k}; + + // Initialize tensors using CUTLASS helper functions + cutlass::HostTensor tensor_a( + problem_size.mk()); // <- Create matrix A with dimensions M x K + + // Create weight matrix with dimensions K x N. + // Actual weight type is int4, we use ElementW = uint8 to avoid possible compilation + // troubles. Since the layout is column major, we are packing 2 weights in a column + // into one int8 + cutlass::HostTensor tensor_weight( + {problem_size.k()/2, problem_size.n()}); + // Create weight quantization scale and offset with dimensions K x N + cutlass::HostTensor tensor_scale( + {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); + cutlass::HostTensor tensor_offset( + {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); + + cutlass::HostTensor tensor_c( + problem_size.mn()); // <- Create matrix C with dimensions M x N + cutlass::HostTensor tensor_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // CUTLASS kernel + + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(4), + ElementInputA(-4), + 2); // <- Fill matrix A on host with uniform-distribution random data + if constexpr (has_offsets) { + cutlass::reference::host::TensorFillRandomUniform( + tensor_offset.host_view(), + 1, + ElementQOffset(0), + ElementQOffset(15), + 0); // <- Fill weight offsets on host with uniform-distribution random data + } + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(4), + ElementOutput(-4), + 0); // <- Fill matrix C on host with uniform-distribution random data + cutlass::reference::host::TensorFill( + tensor_d.host_view()); // <- fill matrix D on host with zeros + + // + // For testing quantization and dequantization, it is not straight + // forward to avoid flaky tests due to rounding errors. The way we + // try to achieve this is to: + // 1. Generate a set of quantized weights, scales and offsets + // 2. Dequantize the weights + // 3. Quantize the dequantized weights + // 4. Compare the dequantied-and-then-quantized weights with + // the original quantized weights + // + // Random filling of the initial values are key to get this right. + // For weights, we must ensure each block gets a full range of + // values, i.e. must contain 0 and 15. And for scales, they must + // all be positive. + // + + int v = 7; + for (int c = 0; c < tensor_weight.extent()[1]; c++) { + for (int r = 0; r < tensor_weight.extent()[0]; ++r) { + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + + tensor_weight.at({r, c}) = ElementW((v1 << 4) | v0); + } + } + + for (int c = 0; c < tensor_scale.extent()[1]; c++) { + for (int r = 0; r < tensor_scale.extent()[0]; ++r) { + int f = (((c * v + r + v / 3 ) % 63) + 1); + v += 41; + int m = (c * v + r + v / 8 ) % 4; + tensor_scale.at({r, c}) = ElementQScale(static_cast(f) / static_cast(1 << (2 + m))); + } + } + +// // Fill tensor_weight with the patterned data, so that we can use +// // print to make sure the layout matches after loaded to registers +// int loop_val = 0; +// int offset = 3; +// for (int col_tile = 0; col_tile < tensor_weight.extent().column()/8; ++col_tile) { +// for (int row_tile = 0; row_tile < tensor_weight.extent().row()/4; ++row_tile) { +// for (int col = 0; col < 8; ++col) { +// for (int row = 0; row < 4; ++row) { +// auto weight_cord = cutlass::make_Coord(row_tile * 4 + row, col_tile * 8 + col); +// auto val = (loop_val + offset) % 256; +// tensor_weight.at(weight_cord) = ElementW(val); +// loop_val++; +// if (loop_val == 256) { +// loop_val = 0; +// offset += 11; +// } +// } +// } +// } +// } +// for (int col = 0; col < tensor_scale.extent().column(); ++col){ +// int c = col * QuantBlocking::kColumn; +// for (int row = 0; row < tensor_scale.extent().row(); ++row){ +// int r = row * QuantBlocking::kRow; +// auto weight_cord = cutlass::make_Coord(r/2, c); +// int w = 0; +// if (r % 2 == 0) { +// w = int(tensor_weight.at(weight_cord) & 0x0f); +// } else { +// w = int(tensor_weight.at(weight_cord) >> 4); +// } +// tensor_scale.at({row, col}) = w; +// #ifdef USE_QUANT_OFFSET +// tensor_offset.at({row, col}) = ElementQOffset(w); +// #endif +// } +// } + + // int fill_val = -512; + // int factor = 1; + // for (int col = 0; col < tensor_scale.extent().column(); ++col){ + // for (int row = 0; row < tensor_scale.extent().row(); ++row){ + // tensor_scale.at({row, col}) = ElementQScale((float)fill_val * float(factor)); + // fill_val++; + // if (fill_val == 512) { + // fill_val = -512; + // factor += 1; + // } + // } + // } + + // std::cout << "Matrix Weight:\n" << tensor_weight.host_view() << "\n"; + + // Prepacking weight matrix and quantization meta data ... + + cutlass::HostTensor tensor_weight_prepacked( + cutlass::make_Coord(problem_size.k(), problem_size.n()/2)); + prepack_weights_ref(problem_size.k(), problem_size.n(), + make_ConstMatrixRef(tensor_weight), + make_MatrixRef(tensor_weight_prepacked)); + + // std::cout << "Matrix Weight Prepacked:\n" << tensor_weight_prepacked.host_view() << "\n"; + + cutlass::HostTensor tensor_scale_prepacked( + {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); + cutlass::HostTensor tensor_offset_prepacked( + {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); + + auto scale_ref = make_ConstMatrixRef(tensor_scale); + prepack_quant_scales_ref( + problem_size.k(), problem_size.n(), scale_ref, + make_MatrixRef(tensor_scale_prepacked)); + if constexpr (has_offsets) { + auto offset_ref = make_ConstMatrixRef(tensor_offset); + prepack_quant_offsets_ref( + problem_size.k(), problem_size.n(), offset_ref, + make_MatrixRef(tensor_offset_prepacked)); + } + + // std::cout << "================== Matrix Scale ==========================\n"; + // for (int row = 0; row < tensor_scale_prepacked.extent().row(); ++row){ + // for (int col = 0; col < tensor_scale_prepacked.extent().column(); ++col){ + // printf("%.0f, ", float(tensor_scale_prepacked.at({row, col}))); + // } + // printf("\n"); + // } + + // Copy data from host to GPU... + tensor_a.sync_device(); + tensor_weight_prepacked.sync_device(); + tensor_scale_prepacked.sync_device(); + if constexpr (has_offsets) { + tensor_offset_prepacked.sync_device(); + } + tensor_c.sync_device(); + tensor_d.sync_device(); + cutlass::TensorRef ref_W( + reinterpret_cast(tensor_weight_prepacked.device_data()), + LayoutInputWPack::packed({problem_size.k()/2, problem_size.n()/2})); + + // Construct events + cudaEvent_t finish_gemm_event; + auto cuda_err = cudaEventCreate(&finish_gemm_event); + ORT_ENFORCE(cuda_err == cudaSuccess, "Failed to create CUDA event."); + + // run GEMM + cutlass::Status status; + if constexpr (has_offsets){ + status = GemmRunner::run( + nullptr, problem_size, tensor_a.device_ref(), ref_W, + tensor_scale_prepacked.device_ref(), tensor_offset_prepacked.device_ref(), + tensor_c.device_ref(), tensor_d.device_ref()); + } else { + status = GemmRunner::run( + nullptr, problem_size, tensor_a.device_ref(), ref_W, + tensor_scale_prepacked.device_ref(), + tensor_c.device_ref(), tensor_d.device_ref()); + } + ORT_ENFORCE(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status)); + + // Record an event when the GEMMs are complete + cuda_err = cudaEventRecord(finish_gemm_event); + ORT_ENFORCE(cuda_err == cudaSuccess, "Failed to record CUDA event: ", cudaGetErrorString(cuda_err)); + + // Wait for work on the device to complete. + cuda_err = cudaEventSynchronize(finish_gemm_event); + ORT_ENFORCE(cuda_err == cudaSuccess, "Failure during sync CUDA event: ", cudaGetErrorString(cuda_err)); + + cudaEventDestroy(finish_gemm_event); + + // Preparing reference kernel arguments + // Dequantizing weights and running reference kernel + + using ElementInputB = ElementInputA; + using LayoutInputB = cutlass::layout::ColumnMajor; + cutlass::HostTensor tensor_b( + problem_size.kn()); // <- Create dequantized matrix B with dimensions K x N + cutlass::HostTensor tensor_ref_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // reference kernel + + // Dequantize weights and save into matrix B for reference + for (int col = 0; col < tensor_b.extent().column(); ++col){ + for (int row = 0; row < tensor_b.extent().row(); ++row) { + auto weight_cord = cutlass::make_Coord(row/2, col); + auto scale_cord = cutlass::make_Coord(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); + const uint8_t offset = has_offsets ? tensor_offset.at(scale_cord) : 8; + int w = 0; + if (row % 2 == 0) { + w = int(tensor_weight.at(weight_cord) & 0x0f) - offset; + } else { + w = int(tensor_weight.at(weight_cord) >> 4) - offset; + } + auto scale = tensor_scale.at(scale_cord); + tensor_b.at({row, col}) = scale * float(w); + } + } + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros + + tensor_b.sync_device(); + tensor_ref_d.sync_device(); + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + ElementComputeEpilogue beta = ElementComputeEpilogue(0); + + compute_gemm_ref( + problem_size, + alpha, + tensor_a.device_ref(), + tensor_b.device_ref(), + beta, + tensor_c.device_ref(), + tensor_ref_d.device_ref()); + + // Wait for kernels to finish + cudaDeviceSynchronize(); + + // Copy output data from CUTLASS and reference kernel to host for comparison + tensor_d.sync_host(); + tensor_ref_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + ORT_ENFORCE(passed, "Gemm kernel result wrong!"); +} + +template void run_blkq4_gemm<16, true, false, true>(int m, int n, int k); +template void run_blkq4_gemm<16, true, false, false>(int m, int n, int k); +template void run_blkq4_gemm<32, true, false, true>(int m, int n, int k); +template void run_blkq4_gemm<32, true, false, false>(int m, int n, int k); +template void run_blkq4_gemm<64, true, false, true>(int m, int n, int k); +template void run_blkq4_gemm<64, true, false, false>(int m, int n, int k); +template void run_blkq4_gemm<16, false, false, true>(int m, int n, int k); +template void run_blkq4_gemm<16, false, false, false>(int m, int n, int k); +template void run_blkq4_gemm<32, false, false, true>(int m, int n, int k); +template void run_blkq4_gemm<32, false, false, false>(int m, int n, int k); +template void run_blkq4_gemm<64, false, false, true>(int m, int n, int k); +template void run_blkq4_gemm<64, false, false, false>(int m, int n, int k); +template void run_blkq4_gemm<16, true, true, true>(int m, int n, int k); +template void run_blkq4_gemm<16, true, true, false>(int m, int n, int k); +template void run_blkq4_gemm<32, true, true, true>(int m, int n, int k); +template void run_blkq4_gemm<32, true, true, false>(int m, int n, int k); +template void run_blkq4_gemm<64, true, true, true>(int m, int n, int k); +template void run_blkq4_gemm<64, true, true, false>(int m, int n, int k); +template void run_blkq4_gemm<16, false, true, true>(int m, int n, int k); +template void run_blkq4_gemm<16, false, true, false>(int m, int n, int k); +template void run_blkq4_gemm<32, false, true, true>(int m, int n, int k); +template void run_blkq4_gemm<32, false, true, false>(int m, int n, int k); +template void run_blkq4_gemm<64, false, true, true>(int m, int n, int k); +template void run_blkq4_gemm<64, false, true, false>(int m, int n, int k); + +} // namespace test +} // namespace cuda +} // namespace onnxruntime From 7ca652c38d433b2e02e85b54ab330d9e39a8cc7d Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Thu, 30 Nov 2023 01:21:04 +0000 Subject: [PATCH 02/13] add compilation flag --- cmake/CMakeLists.txt | 12 ++++++++++- cmake/external/cutlass.cmake | 20 ++++++++++--------- cmake/onnxruntime_providers_cuda.cmake | 2 +- .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 4 ++++ .../test_cases/blkq4_fp16_gemm_sm80_testcu.cu | 4 ++++ 5 files changed, 31 insertions(+), 11 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0eb224623f678..8e2c01f79e4ca 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -699,6 +699,7 @@ set(ONNXRUNTIME_PROVIDER_NAMES cpu) set(ORT_PROVIDER_FLAGS) set(ORT_PROVIDER_CMAKE_FLAGS) +set(onnxruntime_USE_CUTLASS ON) if (onnxruntime_USE_CUDA) if (onnxruntime_USE_CUDA_NHWC_OPS) add_compile_definitions(ENABLE_CUDA_NHWC_OPS) @@ -715,6 +716,10 @@ if (onnxruntime_USE_CUDA) set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) + message( STATUS "Turn off cutlass since CUDA compiler version < 11.6") + set(onnxruntime_USE_CUTLASS OFF) + endif() else() set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) @@ -735,8 +740,13 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1) endif() - + if (onnxruntime_USE_CUTLASS) + message( STATUS "Enable CUTLASS extension") + list(APPEND ORT_PROVIDER_FLAGS -DUSE_CUTLASS=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CUTLASS=1) + endif() endif() + if (onnxruntime_USE_VITISAI) list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_VITISAI=1) diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index f04f4bec76cd5..efc708bd681c0 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -1,11 +1,13 @@ -include(FetchContent) -FetchContent_Declare( - cutlass - URL ${DEP_URL_cutlass} - URL_HASH SHA1=${DEP_SHA1_cutlass} -) +if (onnxruntime_USE_CUTLASS) + include(FetchContent) + FetchContent_Declare( + cutlass + URL ${DEP_URL_cutlass} + URL_HASH SHA1=${DEP_SHA1_cutlass} + ) -FetchContent_GetProperties(cutlass) -if(NOT cutlass_POPULATED) - FetchContent_Populate(cutlass) + FetchContent_GetProperties(cutlass) + if(NOT cutlass_POPULATED) + FetchContent_Populate(cutlass) + endif() endif() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 33441a94bd324..6861052c22751 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -199,7 +199,7 @@ target_link_libraries(${target} PRIVATE cuda) endif() - if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) + if (onnxruntime_USE_CUTLASS) include(cutlass) target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) endif() diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index ff6c38bb56d32..eab0aa19d82f2 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -11,6 +11,8 @@ * well with CUTLASS headers. */ +#if USE_CUTLASS + #include #include "core/framework/float16.h" @@ -407,3 +409,5 @@ TEST(BlkQ4_GEMM, Sm80Test) { } // namespace test } // namespace onnxruntime + +#endif // USE_CUTLASS diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu index 6dcdad67e9511..7c2f99c62370a 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -11,6 +11,8 @@ * well with gtest headers. */ +#if USE_CUTLASS + #include "core/mickey/blk_q4/f16_gemm_sm80.h" #include "cutlass/util/host_tensor.h" @@ -487,3 +489,5 @@ template void run_blkq4_gemm<64, false, true, false>(int m, int n, int k); } // namespace test } // namespace cuda } // namespace onnxruntime + +#endif // USE_CUTLASS From 93ac7e3387b45334aeba19b6a7c58d13bd03fb50 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Thu, 30 Nov 2023 17:50:51 +0000 Subject: [PATCH 03/13] require cuda 11.4 for cutlass --- cmake/CMakeLists.txt | 9 +----- cmake/external/cutlass.cmake | 14 ++++----- cmake/onnxruntime_providers_cuda.cmake | 6 ++-- .../cutlass_ext/q4gemm/device/quantb_gemm.h | 5 ++-- .../threadblock/quantb_mma_multistage.h | 30 +++++++++---------- .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 4 --- .../test_cases/blkq4_fp16_gemm_sm80_testcu.cu | 4 --- 7 files changed, 26 insertions(+), 46 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 8e2c01f79e4ca..b2d68d78ae47a 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -699,7 +699,6 @@ set(ONNXRUNTIME_PROVIDER_NAMES cpu) set(ORT_PROVIDER_FLAGS) set(ORT_PROVIDER_CMAKE_FLAGS) -set(onnxruntime_USE_CUTLASS ON) if (onnxruntime_USE_CUDA) if (onnxruntime_USE_CUDA_NHWC_OPS) add_compile_definitions(ENABLE_CUDA_NHWC_OPS) @@ -717,8 +716,7 @@ if (onnxruntime_USE_CUDA) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) - message( STATUS "Turn off cutlass since CUDA compiler version < 11.6") - set(onnxruntime_USE_CUTLASS OFF) + message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4") endif() else() set(onnxruntime_USE_FLASH_ATTENTION OFF) @@ -740,11 +738,6 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1) endif() - if (onnxruntime_USE_CUTLASS) - message( STATUS "Enable CUTLASS extension") - list(APPEND ORT_PROVIDER_FLAGS -DUSE_CUTLASS=1) - list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CUTLASS=1) - endif() endif() if (onnxruntime_USE_VITISAI) diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index efc708bd681c0..f4c55ae105560 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -1,13 +1,11 @@ -if (onnxruntime_USE_CUTLASS) - include(FetchContent) - FetchContent_Declare( +include(FetchContent) +FetchContent_Declare( cutlass URL ${DEP_URL_cutlass} URL_HASH SHA1=${DEP_SHA1_cutlass} - ) +) - FetchContent_GetProperties(cutlass) - if(NOT cutlass_POPULATED) - FetchContent_Populate(cutlass) - endif() +FetchContent_GetProperties(cutlass) +if(NOT cutlass_POPULATED) + FetchContent_Populate(cutlass) endif() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 6861052c22751..aee02c0e304ac 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -199,10 +199,8 @@ target_link_libraries(${target} PRIVATE cuda) endif() - if (onnxruntime_USE_CUTLASS) - include(cutlass) - target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) - endif() + include(cutlass) + target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h index 2e9b04d93b12e..5088d918cb53a 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h @@ -184,7 +184,7 @@ class QuantBGemm { using QuantBlocking = QuantBlocking_; static constexpr bool kHasQOffset = !(std::is_same::value); - // TODO enable uint4_t or smaller for QOffset + // TODO(chenfucn): consider moving to uint4_t or smaller for QOffset static_assert(!kHasQOffset || std::is_same::value, "QOffset must be uint8_t"); /// Define the kernel @@ -378,8 +378,7 @@ class QuantBGemm { return Status::kErrorInternal; } } - } - else { + } else { if (args.split_k_slices > 1) { return Status::kErrorInvalidProblem; diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h index c8aff17151f29..b4e7f1eb732d1 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h @@ -651,7 +651,7 @@ class QuantBMmaMultistage : smem_iterator_B_.add_tile_offset({1, 0}); smem_iterator_QScale_.add_tile_offset({1, 0}); - if constexpr (kHasQOffset){ + if constexpr (kHasQOffset) { iterator_QOffset.add_tile_offset({1, 0}); smem_iterator_QOffset_.add_tile_offset({1, 0}); } @@ -664,7 +664,7 @@ class QuantBMmaMultistage : smem_iterator_A_.add_tile_offset({0, -Base::kStages}); smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); smem_iterator_QScale_.add_tile_offset({-Base::kStages, 0}); - if constexpr (kHasQOffset){ + if constexpr (kHasQOffset) { smem_iterator_QOffset_.add_tile_offset({-Base::kStages, 0}); } smem_write_stage_idx_ = 0; @@ -703,7 +703,7 @@ class QuantBMmaMultistage : static_assert(IteratorQOffset::kAccessesPerVector == 1, "Quant offset should 1 access per vector!"); - if constexpr(kHasQOffset){ + if constexpr(kHasQOffset) { // Async Copy for quantization offset typename IteratorQOffset::AccessType *dst_ptr = reinterpret_cast( @@ -872,7 +872,7 @@ class QuantBMmaMultistage : cutlass::arch::cp_async( dst_ptr, gmem_ptr, iterator_QScale.valid()); - if constexpr (kHasQOffset){ + if constexpr (kHasQOffset) { iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); // Async Copy for quantization offset @@ -907,8 +907,8 @@ class QuantBMmaMultistage : cutlass::arch::cp_async_wait(); __syncthreads(); - if constexpr(debug_layout){ - if (LayoutDebugType::debug_smem && layout_debug_.block_id_ == 1){ + if constexpr(debug_layout) { + if (LayoutDebugType::debug_smem && layout_debug_.block_id_ == 1) { if (threadIdx.x == 0){ printf("stage: %d\n", smem_write_stage_idx_); } @@ -957,7 +957,7 @@ class QuantBMmaMultistage : iterator_B, (warp_mma_k + 1) % Base::kWarpGemmIterations); - if constexpr(debug_layout){ + if constexpr(debug_layout) { if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); } @@ -974,7 +974,7 @@ class QuantBMmaMultistage : pipe_state.warp_loaded_frag_QScale_, pipe_state.warp_loaded_frag_QOffset_); - if constexpr(debug_layout){ + if constexpr(debug_layout) { LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } @@ -1049,7 +1049,7 @@ class QuantBMmaMultistage : for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { // In the case of small M, memory latency dominates. We try to move uses far // from their definitions to hide latency. - if constexpr(debug_layout){ + if constexpr(debug_layout) { if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); } @@ -1066,7 +1066,7 @@ class QuantBMmaMultistage : pipe_state.warp_loaded_frag_QScale_, pipe_state.warp_loaded_frag_QOffset_); - if constexpr(debug_layout){ + if constexpr(debug_layout) { LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } @@ -1159,7 +1159,7 @@ class QuantBMmaMultistage : iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); - if constexpr(kHasQOffset){ + if constexpr(kHasQOffset) { iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); } @@ -1180,9 +1180,9 @@ class QuantBMmaMultistage : copy_tiles_and_advance(iterator_A, iterator_B, 0); - if constexpr(Shape::kM > 32){ + if constexpr(Shape::kM > 32) { // the case of bigger m - if constexpr(debug_layout){ + if constexpr(debug_layout) { if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, 0); } @@ -1199,7 +1199,7 @@ class QuantBMmaMultistage : pipe_state.warp_loaded_frag_QScale_, pipe_state.warp_loaded_frag_QOffset_); - if constexpr(debug_layout){ + if constexpr(debug_layout) { LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[0], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); } } else { @@ -1215,7 +1215,7 @@ class QuantBMmaMultistage : // Mainloop CUTLASS_GEMM_LOOP for (; gemm_k_iterations > (-Base::kStages + 1);) { - if constexpr(Shape::kM > 32){ + if constexpr(Shape::kM > 32) { mac_loop_iter( pipe_state, accum, diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index eab0aa19d82f2..ff6c38bb56d32 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -11,8 +11,6 @@ * well with CUTLASS headers. */ -#if USE_CUTLASS - #include #include "core/framework/float16.h" @@ -409,5 +407,3 @@ TEST(BlkQ4_GEMM, Sm80Test) { } // namespace test } // namespace onnxruntime - -#endif // USE_CUTLASS diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu index 7c2f99c62370a..6dcdad67e9511 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -11,8 +11,6 @@ * well with gtest headers. */ -#if USE_CUTLASS - #include "core/mickey/blk_q4/f16_gemm_sm80.h" #include "cutlass/util/host_tensor.h" @@ -489,5 +487,3 @@ template void run_blkq4_gemm<64, false, true, false>(int m, int n, int k); } // namespace test } // namespace cuda } // namespace onnxruntime - -#endif // USE_CUTLASS From cf3975772a17bab0424f0afc9638b9da2be0d7cc Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Thu, 25 Jan 2024 05:09:37 +0000 Subject: [PATCH 04/13] fix comments and rebase on main --- .../cutlass_ext/q4gemm/device/quantb_gemm.h | 16 +- .../q4gemm/kernel/default_quantb_gemm.h | 8 +- .../q4gemm/threadblock/default_quantb_mma.h | 12 +- .../threadblock/default_quantb_mma_core.h | 16 +- .../threadblock/quantb_mma_multistage.h | 4 +- .../test/cuda_host/blkq4_fp16_quant_sm80.h | 199 ++++++++++++++++++ .../cuda/test_cases/blkq4_fp16_gemm_sm80.h | 170 +-------------- .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 6 +- .../test_cases/blkq4_fp16_gemm_sm80_testcu.cu | 36 +--- .../cuda_execution_provider_test.cc | 4 +- 10 files changed, 239 insertions(+), 232 deletions(-) create mode 100644 onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h index 5088d918cb53a..36b52199362d5 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h @@ -85,7 +85,7 @@ template < /// Element type for quant offsets typename ElementQOffset_, /// Layout type for quant scales and offsets - typename LayoutQScale_, + typename LayoutQMeta_, /// Blocking dimensions for quantization typename QuantBlocking_, /// Element type for C and D matrix operands @@ -180,7 +180,7 @@ class QuantBGemm { "InstructionShape::kK must be a multiple of 16 (2 tiles), required by 4b weight packing layout."); using ElementQScale = ElementQScale_; using ElementQOffset = ElementQOffset_; - using LayoutQScale = LayoutQScale_; + using LayoutQMeta = LayoutQMeta_; using QuantBlocking = QuantBlocking_; static constexpr bool kHasQOffset = !(std::is_same::value); @@ -197,7 +197,7 @@ class QuantBGemm { kAlignmentB, ElementQScale, ElementQOffset, - LayoutQScale, + LayoutQMeta, QuantBlocking, ElementC, LayoutC, @@ -230,8 +230,8 @@ class QuantBGemm { TensorRef ref_B; TensorRef ref_C; TensorRef ref_D; - TensorRef ref_Qscale; - TensorRef ref_Qoffset; + TensorRef ref_Qscale; + TensorRef ref_Qoffset; typename EpilogueOutputOp::Params epilogue; @@ -258,7 +258,7 @@ class QuantBGemm { GemmCoord problem_size_, TensorRef ref_A_, TensorRef ref_B_, - TensorRef ref_Qscale_, + TensorRef ref_Qscale_, TensorRef ref_C_, TensorRef ref_D_, typename EpilogueOutputOp::Params epilogue_ = @@ -279,8 +279,8 @@ class QuantBGemm { GemmCoord problem_size_, TensorRef ref_A_, TensorRef ref_B_, - TensorRef ref_Qscale_, - TensorRef ref_Qoffset_, + TensorRef ref_Qscale_, + TensorRef ref_Qoffset_, TensorRef ref_C_, TensorRef ref_D_, typename EpilogueOutputOp::Params epilogue_ = diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h index 3860a241395a6..2f4460bb59e9f 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h @@ -96,7 +96,7 @@ template < /// Element type for quant offsets typename ElementQOffset_, /// Layout type for quant scales and offsets - typename LayoutQScale_, + typename LayoutQMeta_, /// Blocking dimensions for quantization typename QuantBlocking_, /// Access granularity of quant scales in units of elements @@ -167,7 +167,7 @@ template < /// Element type for quant offsets typename ElementQOffset, /// Layout type for quant scales - typename LayoutQScale, + typename LayoutQMeta, /// Blocking dimensions for quantization typename QuantBlocking, /// Access granularity of quant scales in units of elements @@ -207,7 +207,7 @@ template < typename PermuteBLayout > struct DefaultQuantBGemm struct DefaultQuantBMma; @@ -218,14 +218,14 @@ struct DefaultQuantBMma; + ElementQScale, LayoutQMeta, 0, ThreadMapQScale, AccessTypeQScale>; using ThreadMapQOffset = typename MmaCore::IteratorThreadMapQOffset; using AccessTypeQOffset = cutlass::Array; using IteratorQOffset = cutlass::transform::threadblock::OptionalPredicatedTileAccessIterator< - typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQScale, + typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQMeta, 0, ThreadMapQOffset, AccessTypeQOffset, MmaCore::kThreads>; // Define the threadblock-scoped multistage matrix multiply diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h index 060d2134cfcef..ad322f6505200 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h @@ -93,7 +93,7 @@ template < /// Element data type of quant offset typename ElementQOffset, /// Layout of quant scale - typename LayoutQScale, + typename LayoutQMeta, /// Blocking dimensions for quantization typename QuantBlocking, /// Data type of accumulator @@ -157,7 +157,7 @@ template < /// Element data type of quant offset typename ElementQOffset_, /// Layout of quant scale - typename LayoutQScale_, + typename LayoutQMeta_, /// Blocking dimensions for quantization typename QuantBlocking_, /// Data type of accumulator @@ -174,7 +174,7 @@ template < cutlass::arch::CacheOperation::Kind CacheOpB> struct DefaultQuantBMmaCore { using Shape = Shape_; @@ -187,7 +187,7 @@ struct DefaultQuantBMmaCore, ElementB, SmemLayoutB, 1, IteratorThreadMapB>; - using SmemLayoutQScale = LayoutQScale; - using SmemLayoutQOffset = LayoutQScale; + using SmemLayoutQScale = LayoutQMeta; + using SmemLayoutQOffset = LayoutQMeta; /// Threadblock-level quantization meta data shape using ThreadblockQShape = MatrixShape; @@ -279,12 +279,12 @@ struct DefaultQuantBMmaCore 0, "QuantBlocking too big to fit in a thread block!"); static_assert(QuantBlocking::kRow == 1 || QuantBlocking::kColumn == 1, "Only support single column or row quantize blocking!"); - static_assert(QuantBlocking::kColumn != 1 || std::is_same::value, + static_assert(QuantBlocking::kColumn != 1 || std::is_same::value, "Quant scale matrix's major dimension must have more elements, to facilitate fast loading!"); /// Threadblock-level quantization meta data shape in pitch-linear layout using TBQPitchLinearShape = typename std::conditional< - std::is_same::value, + std::is_same::value, layout::PitchLinearShape, layout::PitchLinearShape>::type; diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h index b4e7f1eb732d1..dfd1032b42c68 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h @@ -277,7 +277,7 @@ class QuantBMmaBase { } CUTLASS_HOST_DEVICE - static typename Operator::SmemLayoutQScale LayoutQScale() { + static typename Operator::SmemLayoutQScale LayoutQMeta() { return Operator::SmemLayoutQScale::packed({ShapeQScale::kRow, ShapeQScale::kColumn}); } @@ -301,7 +301,7 @@ class QuantBMmaBase { /// Returns a TensorRef to the quantization scales CUTLASS_HOST_DEVICE TensorRefQScale operand_QScale_ref() { - return TensorRefQScale{operand_QScale.data(), LayoutQScale()}; + return TensorRefQScale{operand_QScale.data(), LayoutQMeta()}; } CUTLASS_HOST_DEVICE diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h new file mode 100644 index 0000000000000..6b1c883f96041 --- /dev/null +++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h @@ -0,0 +1,199 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_quant_sm80.h + * + * Abstract: + * Oracle computation for blockwise 4b quantization for fp16 + * gemm kernel specifically for Ampere GPUs. This is used for + * testing the cuda kernel implementation in + * (test/providers/cuda/test_cases) + * and for testing the cuda op prepack code in (test/optimizer) + */ + +#pragma once + +#include "core/util/matrix_layout.h" +#include "core/common/common.h" + +namespace onnxruntime { +namespace test { + +static inline void sm80_prepack_weights_ref( + int rows, + int columns, + const MatrixRef& tensor_weight, + const MatrixRef& tensor_weight_prepacked) { + ORT_ENFORCE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns, + "Unexpected tensor_weight shape! Expected: (", rows / 2, ", ", columns, "), Got: (", + tensor_weight.shape()[0], ", ", tensor_weight.shape()[1], ")."); + ORT_ENFORCE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2, + "tensor_weight_prepacked shape is not compatible with prepacked weight shape"); + + auto t0_base = make_Position(0, 0); + auto t1_base = make_Position(4, 0); + auto t2_base = make_Position(0, 8); + auto t3_base = make_Position(4, 8); + for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { + for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { + // Packing from a 8x16 tile to a 16x8 tile + auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); + auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); + for (int col = 0; col < 8; ++col) { + for (int row = 0; row < 4; ++row) { + auto cord = make_Position(row, col); + auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 + uint8_t buf[4]; + buf[0] = tensor_weight.at(dtile_base + t0_base + cord); + buf[1] = tensor_weight.at(dtile_base + t1_base + cord); + buf[2] = tensor_weight.at(dtile_base + t2_base + cord); + buf[3] = tensor_weight.at(dtile_base + t3_base + cord); + + // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights + // are in different b16 register at the same positions. This makes it easier to convert to + // fp16x2 format in a b32 register + + tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); + tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); + } + } + } + } +} + +template < + typename ScaleElementT, + typename Layout, + typename QuantBlocking> +inline +void sm80_prepack_quant_scales_ref( + int rows, + int columns, + const MatrixRef& tensor_scale, + const MatrixRef& tensor_scale_prepacked) { + ORT_ENFORCE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn), + "Unexpected tensor_scale shape! Expected: (", + rows / QuantBlocking::kRow, ", ", columns / QuantBlocking::kColumn, ")"); + ORT_ENFORCE(tensor_scale_prepacked.shape() == tensor_scale.shape()); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (sizeof(ScaleElementT) == 2 && QuantBlocking::kRow == 1) { + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two seperate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); + tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); + tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); + tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); + } + } + } + } else { + // In all other cases, we don't prepack scale or offset + std::copy(tensor_scale.data().begin(), tensor_scale.data().end(), tensor_scale_prepacked.data().begin()); + } +} + +template +inline +void sm80_prepack_quant_offsets_ref( + int rows, + int columns, + MatrixRef tensor_offset, + MatrixRef tensor_offset_prepacked) { + ORT_ENFORCE(tensor_offset.shape()[0] == (rows / QuantBlocking::kRow) && tensor_offset.shape()[1] == (columns / QuantBlocking::kColumn), + "Unexpected tensor_offset shape! Expected: (", + rows / QuantBlocking::kRow, ", ", columns / QuantBlocking::kColumn, ")"); + ORT_ENFORCE(tensor_offset_prepacked.shape() == tensor_offset.shape()); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (QuantBlocking::kRow != 1) { + std::copy(tensor_offset.data().begin(), tensor_offset.data().end(), tensor_offset_prepacked.data().begin()); + return; + } + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two seperate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + if (tensor_offset_prepacked.good()) { + for (int col = 0; col < tensor_offset.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_offset.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own + // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to + // convert to fp16x2 format in a b32 register + tensor_offset_prepacked.at(dst_idx + 0, col) = tensor_offset.at(src_idx + 0, col); + tensor_offset_prepacked.at(dst_idx + 1, col) = tensor_offset.at(src_idx + 8, col); + tensor_offset_prepacked.at(dst_idx + 2, col) = tensor_offset.at(src_idx + 1, col); + tensor_offset_prepacked.at(dst_idx + 3, col) = tensor_offset.at(src_idx + 9, col); + } + } + } + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h index 42a118a89ff54..4db2a6340ed75 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h @@ -16,6 +16,7 @@ #include "core/util/matrix_layout.h" #include "core/common/common.h" +#include "test/cuda_host/blkq4_fp16_quant_sm80.h" namespace onnxruntime { namespace cuda { @@ -23,175 +24,6 @@ namespace test { Status sm80_supported(); -static inline void prepack_weights_ref( - int rows, - int columns, - const MatrixRef& tensor_weight, - const MatrixRef& tensor_weight_prepacked) { - ORT_ENFORCE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns, - "Unexpected tensor_weight shape! Expected: (", rows / 2, ", ", columns, "), Got: (", - tensor_weight.shape()[0], ", ", tensor_weight.shape()[1], ")."); - ORT_ENFORCE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2, - "tensor_weight_prepacked shape is not compatible with prepacked weight shape"); - - auto t0_base = make_Position(0, 0); - auto t1_base = make_Position(4, 0); - auto t2_base = make_Position(0, 8); - auto t3_base = make_Position(4, 8); - for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { - for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { - // Packing from a 8x16 tile to a 16x8 tile - auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); - auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); - for (int col = 0; col < 8; ++col) { - for (int row = 0; row < 4; ++row) { - auto cord = make_Position(row, col); - auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 - uint8_t buf[4]; - buf[0] = tensor_weight.at(dtile_base + t0_base + cord); - buf[1] = tensor_weight.at(dtile_base + t1_base + cord); - buf[2] = tensor_weight.at(dtile_base + t2_base + cord); - buf[3] = tensor_weight.at(dtile_base + t3_base + cord); - - // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights - // are in different b16 register at the same positions. This makes it easier to convert to - // fp16x2 format in a b32 register - - tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); - tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); - tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); - tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); - } - } - } - } -} - -template < - typename ScaleElementT, - typename Layout, - typename QuantBlocking> -void prepack_quant_scales_ref( - int rows, - int columns, - const MatrixRef& tensor_scale, - const MatrixRef& tensor_scale_prepacked) { - ORT_ENFORCE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn), - "Unexpected tensor_scale shape! Expected: (", - rows / QuantBlocking::kRow, ", ", columns / QuantBlocking::kColumn, ")"); - ORT_ENFORCE(tensor_scale_prepacked.shape() == tensor_scale.shape()); - - // Only prepacking scale and offset tensors for a often used special case: - // 16b gemm (2 elements per 32b register, operand tile shape 8x8) - // 2 B operand tiles per mma instruction stacked on k dimension - // (1,n) quantization blocking - if constexpr (sizeof(ScaleElementT) == 2 && QuantBlocking::kRow == 1) { - // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread - // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use - // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, - // as shown below (T stands for thread): - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // - // We need to deliver quantization scale and offset elements to the corresponding threads, - // so we can perform dequantization efficiently. With a column major layout, each thread - // needs two seperate loads for a mma instruction, due to the tile fragment layout shown - // above. To reduce the number of loads, we rearrange each column as below, so we can use - // a single load to load fragments for two tiles: - // T0 T0 - // T1 T0 - // T2 T1 - // T3 => T1 - // T0 T2 - // T1 T2 - // T2 T3 - // T3 T3 - - for (int col = 0; col < tensor_scale.shape()[1]; ++col) { - for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { - for (int thread_id = 0; thread_id < 4; thread_id++) { - const int dst_idx = row_blk + thread_id * 4; - const int src_idx = row_blk + thread_id * 2; - tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); - tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); - tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); - tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); - } - } - } - } else { - // In all other cases, we don't prepack scale or offset - std::copy(tensor_scale.data().begin(), tensor_scale.data().end(), tensor_scale_prepacked.data().begin()); - } -} - -template -void prepack_quant_offsets_ref( - size_t rows, - size_t columns, - MatrixRef tensor_offset, - MatrixRef tensor_offset_prepacked) { - ORT_ENFORCE(tensor_offset_prepacked.shape() == tensor_offset.shape()); - - // Only prepacking scale and offset tensors for a often used special case: - // 16b gemm (2 elements per 32b register, operand tile shape 8x8) - // 2 B operand tiles per mma instruction stacked on k dimension - // (1,n) quantization blocking - if constexpr (QuantBlocking::kRow != 1) { - std::copy(tensor_offset.data().begin(), tensor_offset.data().end(), tensor_offset_prepacked.data().begin()); - return; - } - // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread - // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use - // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, - // as shown below (T stands for thread): - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // - // We need to deliver quantization scale and offset elements to the corresponding threads, - // so we can perform dequantization efficiently. With a column major layout, each thread - // needs two seperate loads for a mma instruction, due to the tile fragment layout shown - // above. To reduce the number of loads, we rearrange each column as below, so we can use - // a single load to load fragments for two tiles: - // T0 T0 - // T1 T0 - // T2 T1 - // T3 => T1 - // T0 T2 - // T1 T2 - // T2 T3 - // T3 T3 - if (tensor_offset_prepacked.good()) { - for (int col = 0; col < tensor_offset.shape()[1]; ++col) { - for (int row_blk = 0; row_blk < tensor_offset.shape()[0]; row_blk += 16) { - for (int thread_id = 0; thread_id < 4; thread_id++) { - const int dst_idx = row_blk + thread_id * 4; - const int src_idx = row_blk + thread_id * 2; - // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own - // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to - // convert to fp16x2 format in a b32 register - tensor_offset_prepacked.at(dst_idx + 0, col) = tensor_offset.at(src_idx + 0, col); - tensor_offset_prepacked.at(dst_idx + 1, col) = tensor_offset.at(src_idx + 8, col); - tensor_offset_prepacked.at(dst_idx + 2, col) = tensor_offset.at(src_idx + 1, col); - tensor_offset_prepacked.at(dst_idx + 3, col) = tensor_offset.at(src_idx + 9, col); - } - } - } - } -} - template < int block_size, bool column_wise_blocking, diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index ff6c38bb56d32..60c9b16f4cf88 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -255,7 +255,7 @@ void testPrepack(int rows, int columns, bool has_offset = true) { std::vector packed_w_ref(q_weight_shape.product()); MatrixRef tensor_packed_w_ref( packed_w_ref, make_Position(rows, columns / 2)); - onnxruntime::cuda::test::prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref); + onnxruntime::test::sm80_prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref); std::vector packed_w(q_weight_shape.product()); MatrixRef tensor_packed_w( @@ -277,7 +277,7 @@ void testPrepack(int rows, int columns, bool has_offset = true) { Base::ShouldRearrangeMeta ? make_MatrixRef(packed_scales_ref, meta_shape) : tensor_scale; if (Base::ShouldRearrangeMeta) { - onnxruntime::cuda::test::prepack_quant_scales_ref( + onnxruntime::test::sm80_prepack_quant_scales_ref( rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref); } @@ -302,7 +302,7 @@ void testPrepack(int rows, int columns, bool has_offset = true) { Base::ShouldRearrangeMeta ? make_MatrixRef(packed_zp_ref, meta_shape) : tensor_offset; if (Base::ShouldRearrangeMeta) { - onnxruntime::cuda::test::prepack_quant_offsets_ref( + onnxruntime::test::sm80_prepack_quant_offsets_ref( rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); } diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu index 6dcdad67e9511..733e88da9fc89 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -323,11 +323,10 @@ void run_blkq4_gemm(int m, int n, int k) { cutlass::HostTensor tensor_weight_prepacked( cutlass::make_Coord(problem_size.k(), problem_size.n()/2)); - prepack_weights_ref(problem_size.k(), problem_size.n(), - make_ConstMatrixRef(tensor_weight), - make_MatrixRef(tensor_weight_prepacked)); - - // std::cout << "Matrix Weight Prepacked:\n" << tensor_weight_prepacked.host_view() << "\n"; + onnxruntime::test::sm80_prepack_weights_ref( + problem_size.k(), problem_size.n(), + make_ConstMatrixRef(tensor_weight), + make_MatrixRef(tensor_weight_prepacked)); cutlass::HostTensor tensor_scale_prepacked( {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); @@ -335,24 +334,16 @@ void run_blkq4_gemm(int m, int n, int k) { {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); auto scale_ref = make_ConstMatrixRef(tensor_scale); - prepack_quant_scales_ref( + onnxruntime::test::sm80_prepack_quant_scales_ref( problem_size.k(), problem_size.n(), scale_ref, make_MatrixRef(tensor_scale_prepacked)); if constexpr (has_offsets) { auto offset_ref = make_ConstMatrixRef(tensor_offset); - prepack_quant_offsets_ref( + onnxruntime::test::sm80_prepack_quant_offsets_ref( problem_size.k(), problem_size.n(), offset_ref, make_MatrixRef(tensor_offset_prepacked)); } - // std::cout << "================== Matrix Scale ==========================\n"; - // for (int row = 0; row < tensor_scale_prepacked.extent().row(); ++row){ - // for (int col = 0; col < tensor_scale_prepacked.extent().column(); ++col){ - // printf("%.0f, ", float(tensor_scale_prepacked.at({row, col}))); - // } - // printf("\n"); - // } - // Copy data from host to GPU... tensor_a.sync_device(); tensor_weight_prepacked.sync_device(); @@ -366,11 +357,6 @@ void run_blkq4_gemm(int m, int n, int k) { reinterpret_cast(tensor_weight_prepacked.device_data()), LayoutInputWPack::packed({problem_size.k()/2, problem_size.n()/2})); - // Construct events - cudaEvent_t finish_gemm_event; - auto cuda_err = cudaEventCreate(&finish_gemm_event); - ORT_ENFORCE(cuda_err == cudaSuccess, "Failed to create CUDA event."); - // run GEMM cutlass::Status status; if constexpr (has_offsets){ @@ -386,16 +372,6 @@ void run_blkq4_gemm(int m, int n, int k) { } ORT_ENFORCE(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status)); - // Record an event when the GEMMs are complete - cuda_err = cudaEventRecord(finish_gemm_event); - ORT_ENFORCE(cuda_err == cudaSuccess, "Failed to record CUDA event: ", cudaGetErrorString(cuda_err)); - - // Wait for work on the device to complete. - cuda_err = cudaEventSynchronize(finish_gemm_event); - ORT_ENFORCE(cuda_err == cudaSuccess, "Failure during sync CUDA event: ", cudaGetErrorString(cuda_err)); - - cudaEventDestroy(finish_gemm_event); - // Preparing reference kernel arguments // Dequantizing weights and running reference kernel diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc index a70e439cdf755..53a5d9b5921fb 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc @@ -27,7 +27,7 @@ TEST(TestDeferredRelease, WithArena) { AllocatorPtr cpu_pinned_alloc = ep.CreatePreferredAllocators()[1]; // let the CudaStream instance "own" the default stream, so we can avoid the // work to initialize cublas/cudnn/... It is ok since it is just a customized unit test. - CudaStream stream(nullptr, gpu_alloctor->Info().device, cpu_pinned_alloc, false, true, nullptr, nullptr); + CudaStream stream(nullptr, gpu_alloctor->Info().device, cpu_pinned_alloc, false, true, nullptr, nullptr, info); // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; @@ -66,7 +66,7 @@ TEST(TestDeferredRelease, WithoutArena) { // For details, see CUDAPinnedAllocator in cuda_allocator.cc. // let the CudaStream instance "own" the default stream, so we can avoid the // work to initialize cublas/cudnn/... It is ok since it is just a customized unit test. - CudaStream stream(nullptr, gpu_alloctor->Info().device, cuda_pinned_alloc, false, true, nullptr, nullptr); + CudaStream stream(nullptr, gpu_alloctor->Info().device, cuda_pinned_alloc, false, true, nullptr, nullptr, info); // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; From 73679d3fcc198c1739ba32bffec2ade139738ad6 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Mon, 29 Jan 2024 00:04:00 +0000 Subject: [PATCH 05/13] refactor blkq4 gemm quant input generation --- .../test/cuda_host/blkq4_fp16_quant_sm80.h | 74 +++++- .../cuda/test_cases/blkq4_fp16_gemm_sm80.h | 154 ++++++++++++ .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 232 ++++++------------ .../test_cases/blkq4_fp16_gemm_sm80_testcu.cu | 231 +++++------------ 4 files changed, 350 insertions(+), 341 deletions(-) diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h index 6b1c883f96041..d49484a072be1 100644 --- a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h +++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h @@ -69,8 +69,7 @@ template < typename ScaleElementT, typename Layout, typename QuantBlocking> -inline -void sm80_prepack_quant_scales_ref( +inline void sm80_prepack_quant_scales_ref( int rows, int columns, const MatrixRef& tensor_scale, @@ -130,6 +129,77 @@ void sm80_prepack_quant_scales_ref( } } +template +inline void sm80_expand_prepack_quant_offsets_ref( + int rows, + int columns, + MatrixRef tensor_offset, + MatrixRef tensor_offset_prepacked) { + const auto meta_shape = make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + ORT_ENFORCE(tensor_offset_prepacked.shape() == meta_shape, + "Unexpected tensor_offset_prepacked shape (", + tensor_offset_prepacked.shape()[0], ",", tensor_offset_prepacked.shape()[1], + ")! Expected: (", meta_shape[0], ", ", meta_shape[1], ")"); + ORT_ENFORCE(tensor_offset.shape() == zp_shape, + "Unexpected tensor_offset shape (", + tensor_offset.shape()[0], ",", tensor_offset.shape()[1], + ")! Expected: (", zp_shape[0], ", ", zp_shape[1], ")"); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (QuantBlocking::kRow != 1) { + return; + } + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two seperate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + if (tensor_offset_prepacked.good()) { + for (int col = 0; col < tensor_offset_prepacked.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_offset_prepacked.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own + // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to + // convert to fp16x2 format in a b32 register + uint8_t pair01 = tensor_offset.at(src_idx / 2, col); + uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col); + tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf; + tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf; + tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4; + tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4; + } + } + } + } +} + template inline void sm80_prepack_quant_offsets_ref( diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h index 4db2a6340ed75..4cfb074e7df7d 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h @@ -14,8 +14,11 @@ #pragma once +#include + #include "core/util/matrix_layout.h" #include "core/common/common.h" +#include "core/mickey/blk_q4/f16_prepack_sm80.h" #include "test/cuda_host/blkq4_fp16_quant_sm80.h" namespace onnxruntime { @@ -24,6 +27,157 @@ namespace test { Status sm80_supported(); +/** + * @brief Generate a set of quantized weights, scales and offsets + * and dequantized weights for testing quantization and + * dequantization. All outputs are column major layout. + * + * @tparam ElementT The type of the dequantized weights. + * @tparam block_size The block size of the quantization. + * @tparam col_blocking Whether to use column blocking (all elements of + * a block comes from a single column) or row blocking + * @tparam has_offsets Whether to generate offsets. + * + * @param[in] rows The number of rows of the weight matrix. + * @param[in] columns The number of columns of the weight matrix. + * @param[out] dequants The dequantized weights, column major layout. + * @param[out] q_weights The quantized weights, column major layout. + * @param[out] q_scales The scales, column major layout. + * @param[out] q_zp The zero points, column major layout. + */ +template +inline +void blkq4_weights_gen( + int rows, int columns, + std::vector& dequants, + std::vector& q_weights, + std::vector& q_scales, + std::vector& q_zp) { + using Base = onnxruntime::cuda::BlockwiseQuantization< + ElementT, + block_size, + 4, + col_blocking>; + + using QuantBlocking = typename Base::QuantBlocking; + using ElementW = typename Base::ElementW; + using LayoutWPack = typename Base::LayoutWPack; + using ElementQOffset = typename Base::ElementQOffset; + + static_assert(std::is_same::value); + static_assert(std::is_same::value); + static_assert(std::is_same::value); + + unsigned int seed = 28571; // Replace with desired seed value + std::seed_seq seq{seed}; + std::mt19937 gen(seq); + std::uniform_int_distribution dis(0, 8192); + + const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); + const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + + // + // For testing quantization and dequantization, it is not straight + // forward to avoid flaky tests due to rounding errors. The way we + // try to achieve this is to: + // 1. Generate a set of quantized weights, scales and offsets + // 2. Dequantize the weights + // 3. Quantize the dequantized weights + // 4. Compare the dequantied-and-then-quantized weights with + // the original quantized weights + // + // Random filling of the initial values are key to get this right. + // For weights, we must ensure each block gets a full range of + // values, i.e. must contain 0 and 15. And for scales, they must + // all be positive. + // + + q_weights.resize(q_weight_shape.product()); + MatrixRef tensor_q_weight( + q_weights, make_Position(rows / 2, columns)); + int v = 7; + for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { + for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + + tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); + } + } + + q_scales.resize(meta_shape.product()); + for (size_t i = 0; i < q_scales.size(); i++) { + uint32_t v = dis(gen); + uint32_t m = (v % 63) + 1; + uint32_t e = (v >> 6) % 4; + q_scales[i] = ElementT(m / static_cast(1 << (2 + e))); + } + MatrixRef tensor_scale( + q_scales, meta_shape); + + MatrixRef tensor_offset; + if constexpr(has_offsets) { + q_zp.resize(zp_shape.product()); + tensor_offset = MatrixRef( + q_zp, zp_shape); + for (int c = 0; c < zp_shape[1]; c++) { + for (int r = 0; r < zp_shape[0]; ++r) { + uint8_t v0 = dis(gen) % 16; + uint8_t v1 = 8; + if (r * 2 + 1 < meta_shape[0]) { + v1 = dis(gen) % 16; + } + tensor_offset.at(r, c) = static_cast(v0 | (v1 << 4)); + } + } + } + + dequants.resize(rows * columns); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + + // Dequantize weights and save into matrix B + for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { + for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { + auto weight_cord = make_Position(row / 2, col); + auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); + uint8_t offset = 8; + if constexpr(has_offsets) { + if (scale_cord[0] % 2 == 0) { + offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) & 0x0f; + } else { + offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) >> 4; + } + } + int w = 0; + if (row % 2 == 0) { + w = int(tensor_q_weight.at(weight_cord) & 0x0f); + } else { + w = int(tensor_q_weight.at(weight_cord) >> 4); + } + float scale = float(tensor_scale.at(scale_cord)); + float dequant = scale * float(w - offset); + tensor_dequant.at(row, col) = ElementT(dequant); + // Prints for help debugging in case of test failure + // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); + } + } + +} + template < int block_size, bool column_wise_blocking, diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index 60c9b16f4cf88..148055bd046e2 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -14,7 +14,6 @@ #include #include "core/framework/float16.h" -#include "core/mickey/blk_q4/f16_prepack_sm80.h" #include "core/mlas/inc/mlas_q4.h" #include "blkq4_fp16_gemm_sm80.h" @@ -24,15 +23,15 @@ namespace onnxruntime { namespace test { -template -void testPrepack(int rows, int columns, bool has_offset = true) { +template +void testPrepack(int rows, int columns) { using ElementT = MLFloat16; constexpr int block_size = 32; using Base = onnxruntime::cuda::BlockwiseQuantization< ElementT, block_size, 4, - ColumnMajorQuantBlocking>; + col_blocking>; using QuantBlocking = typename Base::QuantBlocking; using ElementW = typename Base::ElementW; @@ -40,147 +39,40 @@ void testPrepack(int rows, int columns, bool has_offset = true) { using ElementQOffset = typename Base::ElementQOffset; using LayoutQmeta = typename Base::LayoutQmeta; - unsigned int seed = 28571; // Replace with desired seed value - std::seed_seq seq{seed}; - std::mt19937 gen(seq); - std::uniform_int_distribution<> dis(0, 8192); - const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); - // - // For testing quantization and dequantization, it is not straight - // forward to avoid flaky tests due to rounding errors. The way we - // try to achieve this is to: - // 1. Generate a set of quantized weights, scales and offsets - // 2. Dequantize the weights - // 3. Quantize the dequantized weights - // 4. Compare the dequantied-and-then-quantized weights with - // the original quantized weights - // - // Random filling of the initial values are key to get this right. - // For weights, we must ensure each block gets a full range of - // values, i.e. must contain 0 and 15. And for scales, they must - // all be positive. - // + std::vector q_weights; + std::vector q_scales; + std::vector q_zp; + std::vector dequants; + onnxruntime::cuda::test::blkq4_weights_gen( + rows, columns, dequants, q_weights, q_scales, q_zp); - std::vector q_weights(q_weight_shape.product()); - MatrixRef tensor_q_weight( + // for quantization tool, the input is row major, all outputs are column major + MatrixRef tensor_q_weight( q_weights, make_Position(rows / 2, columns)); - int v = 7; - for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { - for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } - - tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); - } - } - - std::vector q_scales(meta_shape.product()); - for (size_t i = 0; i < q_scales.size(); i++) { - q_scales[i] = ElementT(((dis(gen) % 127) + 1) / 32.0f); - } - MatrixRef tensor_scale( + MatrixRef tensor_scale( q_scales, meta_shape); - - std::vector q_zp(meta_shape.product()); - for (size_t i = 0; i < q_zp.size(); i++) { - q_zp[i] = dis(gen) % 16; - } - MatrixRef tensor_offset( - q_zp, meta_shape); - -#if 0 // debug - // Fill tensor_q_weight with the patterned data, easier to debug with print - int loop_val = 0; - int offset = 3; - for (int col_tile = 0; col_tile < tensor_q_weight.extent().column()/8; ++col_tile) { - for (int row_tile = 0; row_tile < tensor_q_weight.extent().row()/4; ++row_tile) { - for (int col = 0; col < 8; ++col) { - for (int row = 0; row < 4; ++row) { - auto weight_cord = cutlass::make_Coord(row_tile * 4 + row, col_tile * 8 + col); - auto val = (loop_val + offset) % 256; - tensor_q_weight.at(weight_cord) = ElementW(val); - loop_val++; - if (loop_val == 256) { - loop_val = 0; - offset += 11; - } - } - } - } - } - for (int col = 0; col < tensor_scale.extent().column(); ++col){ - int c = col * QuantBlocking::kColumn; - for (int row = 0; row < tensor_scale.extent().row(); ++row){ - int r = row * QuantBlocking::kRow; - auto weight_cord = cutlass::make_Coord(r/2, c); - int w = 0; - if (r % 2 == 0) { - w = int(tensor_q_weight.at(weight_cord) & 0x0f); - } else { - w = int(tensor_q_weight.at(weight_cord) >> 4); - } - tensor_scale.at({row, col}) = w; - tensor_offset.at({row, col}) = ElementQOffset(w); - } - } - - int fill_val = -512; - int factor = 1; - for (int col = 0; col < tensor_scale.extent().column(); ++col){ - for (int row = 0; row < tensor_scale.extent().row(); ++row){ - tensor_scale.at({row, col}) = ElementQScale((float)fill_val * float(factor)); - fill_val++; - if (fill_val == 512) { - fill_val = -512; - factor += 1; - } - } + MatrixRef tensor_offset; + if constexpr(has_offset) { + tensor_offset = MatrixRef(q_zp, zp_shape); } -#endif // debug - - std::vector dequants(rows * columns); - MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); - - // Dequantize weights and save into matrix B for reference + // for quantization tool, the input is row major, test weight gen output is column major + std::vector dequants_transposed(dequants.size()); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + MatrixRef tensor_dequant_transposed(dequants_transposed, make_Position(rows, columns)); for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { - auto weight_cord = make_Position(row / 2, col); - auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); - const uint8_t offset = has_offset ? tensor_offset.at(scale_cord) : 8; - int w = 0; - if (row % 2 == 0) { - w = int(tensor_q_weight.at(weight_cord) & 0x0f); - } else { - w = int(tensor_q_weight.at(weight_cord) >> 4); - } - float scale = float(tensor_scale.at(scale_cord)); - float dequant = scale * float(w - offset); - tensor_dequant.at(row, col) = ElementT(dequant); - // Prints for help debugging in case of test failure - // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); + tensor_dequant_transposed.at(row, col) = tensor_dequant.at(row, col); } } int q_rows, q_cols; MlasBlockwiseQuantizedShape( - block_size, ColumnMajorQuantBlocking, rows, columns, q_rows, q_cols); + block_size, col_blocking, rows, columns, q_rows, q_cols); // to be exact, q_rows are padded to multiple of block_size, deal with it when we care about strange shapes EXPECT_EQ(q_rows, q_weight_shape[0]); EXPECT_EQ(q_cols, q_weight_shape[1]); @@ -194,19 +86,18 @@ void testPrepack(int rows, int columns, bool has_offset = true) { std::vector o_scales(meta_shape.product()); MatrixRef tensor_o_scales(o_scales, meta_shape); - std::vector o_zp(((meta_shape[0] + 1) / 2) * meta_shape[1], true); - MatrixRef tensor_o_zp( - o_zp, make_Position((meta_shape[0] + 1) / 2, meta_shape[1])); + std::vector o_zp(zp_shape.product()); + MatrixRef tensor_o_zp(o_zp, zp_shape); MlasQuantizeBlockwise(o_elements.data(), o_scales.data(), has_offset ? o_zp.data() : nullptr, - tensor_dequant.data().data(), block_size, - ColumnMajorQuantBlocking, rows, columns, columns, nullptr); + dequants_transposed.data(), block_size, + col_blocking, rows, columns, columns, nullptr); for (int col = 0; col < tensor_q_weight.shape()[1]; ++col) { for (int row = 0; row < tensor_q_weight.shape()[0]; ++row) { EXPECT_EQ(tensor_o_elements.at(row, col), tensor_q_weight.at(row, col)) << "quantized value mismatch at [" << row << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; } } @@ -215,16 +106,17 @@ void testPrepack(int rows, int columns, bool has_offset = true) { for (int row = 0; row < meta_shape[0]; row += 2) { if (has_offset) { uint8_t pair01 = tensor_o_zp.at(row / 2, col); - EXPECT_EQ(tensor_offset.at(row + 0, col), pair01 & 0xf) + uint8_t expected_pair01 = tensor_offset.at(row / 2, col); + EXPECT_EQ(expected_pair01 & 0xf, pair01 & 0xf) << "quantized offset mismatch at [" << row << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; if (row + 1 < meta_shape[0]) { - EXPECT_EQ(tensor_offset.at(row + 1, col), pair01 >> 4) + EXPECT_EQ(expected_pair01 >> 4, pair01 >> 4) << "quantized offset mismatch at [" << row + 1 << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; } } @@ -232,22 +124,22 @@ void testPrepack(int rows, int columns, bool has_offset = true) { EXPECT_EQ(tensor_scale.at(row + 0, col), tensor_o_scales.at(row + 0, col)) << "quantized scale mismatch at [" << row << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; if (row + 1 < meta_shape[0]) { EXPECT_EQ(tensor_scale.at(row + 1, col), tensor_o_scales.at(row + 1, col)) << "quantized scale mismatch at [" << row + 1 << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; } } } // - // Now we just setup fp16 weights tensor_dequant, quantized weights tensor_q_weight, - // quantization scale tensor_scale and quantization offset tensor_offset. The above - // testing just make sure our test setup is consistent with quantization tool output. + // Now we just setup quantized weights tensor_q_weight, quantization scale tensor_scale + // and quantization offset tensor_offset. The above tests just make sure our setup is + // consistent with quantization tool output. // // Next we test the prepack code // @@ -267,18 +159,23 @@ void testPrepack(int rows, int columns, bool has_offset = true) { EXPECT_EQ(tensor_packed_w_ref.at(row, col), tensor_packed_w.at(row, col)) << "prepacked weights mismatch at [" << row << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; } } std::vector packed_scales_ref(meta_shape.product()); MatrixRef tensor_packed_s_ref = - Base::ShouldRearrangeMeta ? make_MatrixRef(packed_scales_ref, meta_shape) - : tensor_scale; - if (Base::ShouldRearrangeMeta) { + make_MatrixRef(packed_scales_ref, meta_shape); + if constexpr(Base::ShouldRearrangeMeta) { onnxruntime::test::sm80_prepack_quant_scales_ref( rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref); + } else { + for (int col = 0; col < tensor_packed_s_ref.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_s_ref.shape()[0]; ++row) { + tensor_packed_s_ref.at(row, col) = tensor_scale.at(row, col); + } + } } std::vector packed_scales(meta_shape.product()); @@ -291,7 +188,7 @@ void testPrepack(int rows, int columns, bool has_offset = true) { EXPECT_EQ(tensor_packed_s_ref.at(row, col), tensor_packed_s.at(row, col)) << "prepacked scales mismatch at [" << row << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; } } @@ -299,11 +196,20 @@ void testPrepack(int rows, int columns, bool has_offset = true) { if (has_offset) { std::vector packed_zp_ref(meta_shape.product()); MatrixRef tensor_packed_zp_ref = - Base::ShouldRearrangeMeta ? make_MatrixRef(packed_zp_ref, meta_shape) - : tensor_offset; - if (Base::ShouldRearrangeMeta) { - onnxruntime::test::sm80_prepack_quant_offsets_ref( + make_MatrixRef(packed_zp_ref, meta_shape); + if constexpr(Base::ShouldRearrangeMeta) { + onnxruntime::test::sm80_expand_prepack_quant_offsets_ref( rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); + } else { + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + uint8_t pair01 = tensor_offset.at(row / 2, col); + tensor_packed_zp_ref.at(row, col) = pair01 & 0xf; + if (row + 1 < meta_shape[0]) { + tensor_packed_zp_ref.at(row + 1, col) = pair01 >> 4; + } + } + } } std::vector packed_zp(meta_shape.product()); @@ -316,7 +222,7 @@ void testPrepack(int rows, int columns, bool has_offset = true) { EXPECT_EQ(tensor_packed_zp_ref.at(row, col), tensor_packed_zp.at(row, col)) << "prepacked offsets mismatch at [" << row << "," << col << "]" << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << (col_blocking ? "Column-wise-block" : "Row-wise-block") << std::endl; } } @@ -332,9 +238,9 @@ TEST(BlkQ4_GEMM, PrepackSm80Test) { } testPrepack(32, 32); - testPrepack(32, 32, false); + testPrepack(32, 32); testPrepack(32, 32); - testPrepack(32, 32, false); + testPrepack(32, 32); testPrepack(32, 64); testPrepack(32, 128); testPrepack(32, 256); @@ -342,9 +248,9 @@ TEST(BlkQ4_GEMM, PrepackSm80Test) { testPrepack(128, 32); testPrepack(256, 32); testPrepack(256, 256); - testPrepack(32, 128, false); - testPrepack(128, 32, false); - testPrepack(256, 256, false); + testPrepack(32, 128); + testPrepack(128, 32); + testPrepack(256, 256); testPrepack(32, 64); testPrepack(32, 128); testPrepack(32, 256); @@ -352,9 +258,9 @@ TEST(BlkQ4_GEMM, PrepackSm80Test) { testPrepack(128, 32); testPrepack(256, 32); testPrepack(256, 256); - testPrepack(32, 128, false); - testPrepack(128, 32, false); - testPrepack(256, 256, false); + testPrepack(32, 128); + testPrepack(128, 32); + testPrepack(256, 256); } TEST(BlkQ4_GEMM, Sm80Test) { diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu index 733e88da9fc89..69c929d446ce4 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -11,6 +11,10 @@ * well with gtest headers. */ +#include +#include +#include + #include "core/mickey/blk_q4/f16_gemm_sm80.h" #include "cutlass/util/host_tensor.h" @@ -149,6 +153,10 @@ template< bool small_m, bool has_offsets> void run_blkq4_gemm(int m, int n, int k) { + unsigned int seed = 28571; // Replace with desired seed value + std::seed_seq seq{seed}; + std::mt19937 gen(seq); + std::uniform_int_distribution<> dis(0, 8192); using ElementDequant = cutlass::half_t; using QuantBlocking = @@ -173,23 +181,38 @@ void run_blkq4_gemm(int m, int n, int k) { using LayoutInputQScale = typename GemmRunner::LayoutInputQScale; const cutlass::gemm::GemmCoord problem_size = {m, n, k}; + const auto q_weight_shape = cutlass::make_Coord(problem_size.k()/2, problem_size.n()); + const auto meta_shape = cutlass::make_Coord(problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn); + + // + // Generate quantized and dequantizeed input matrix B [K, N] + // + static_assert(std::is_same::value); + std::vector q_weights; + std::vector q_scales; + std::vector q_zp; + std::vector dequants; + onnxruntime::cuda::test::blkq4_weights_gen( + problem_size.k(), problem_size.n(), dequants, q_weights, q_scales, q_zp); + + using PrepackT = onnxruntime::cuda::BlockwiseQuantization< + ElementDequant, + block_size, + 4, + column_wise_blocking>; + + std::vector packed_w(q_weight_shape.product()); + PrepackT::prepack_weights(problem_size.k(), problem_size.n(), q_weights, packed_w); + std::vector packed_scales(meta_shape.product()); + PrepackT::prepack_quant_scales(problem_size.k(), problem_size.n(), q_scales, packed_scales); + std::vector packed_zp; + if constexpr (has_offsets) { + packed_zp.resize(meta_shape.product()); + PrepackT::prepack_quant_offsets(problem_size.k(), problem_size.n(), q_zp, packed_zp); + } - // Initialize tensors using CUTLASS helper functions cutlass::HostTensor tensor_a( problem_size.mk()); // <- Create matrix A with dimensions M x K - - // Create weight matrix with dimensions K x N. - // Actual weight type is int4, we use ElementW = uint8 to avoid possible compilation - // troubles. Since the layout is column major, we are packing 2 weights in a column - // into one int8 - cutlass::HostTensor tensor_weight( - {problem_size.k()/2, problem_size.n()}); - // Create weight quantization scale and offset with dimensions K x N - cutlass::HostTensor tensor_scale( - {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); - cutlass::HostTensor tensor_offset( - {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); - cutlass::HostTensor tensor_c( problem_size.mn()); // <- Create matrix C with dimensions M x N cutlass::HostTensor tensor_d( @@ -203,14 +226,6 @@ void run_blkq4_gemm(int m, int n, int k) { ElementInputA(4), ElementInputA(-4), 2); // <- Fill matrix A on host with uniform-distribution random data - if constexpr (has_offsets) { - cutlass::reference::host::TensorFillRandomUniform( - tensor_offset.host_view(), - 1, - ElementQOffset(0), - ElementQOffset(15), - 0); // <- Fill weight offsets on host with uniform-distribution random data - } cutlass::reference::host::TensorFillRandomUniform( tensor_c.host_view(), 1, @@ -221,188 +236,52 @@ void run_blkq4_gemm(int m, int n, int k) { tensor_d.host_view()); // <- fill matrix D on host with zeros // - // For testing quantization and dequantization, it is not straight - // forward to avoid flaky tests due to rounding errors. The way we - // try to achieve this is to: - // 1. Generate a set of quantized weights, scales and offsets - // 2. Dequantize the weights - // 3. Quantize the dequantized weights - // 4. Compare the dequantied-and-then-quantized weights with - // the original quantized weights - // - // Random filling of the initial values are key to get this right. - // For weights, we must ensure each block gets a full range of - // values, i.e. must contain 0 and 15. And for scales, they must - // all be positive. + // Copy data from host to GPU... // + thrust::device_vector d_packed_w(packed_w); + cutlass::TensorRef ref_W( + reinterpret_cast(d_packed_w.data().get()), + LayoutInputWPack::packed({problem_size.k()/2, problem_size.n()/2})); - int v = 7; - for (int c = 0; c < tensor_weight.extent()[1]; c++) { - for (int r = 0; r < tensor_weight.extent()[0]; ++r) { - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - - tensor_weight.at({r, c}) = ElementW((v1 << 4) | v0); - } - } - - for (int c = 0; c < tensor_scale.extent()[1]; c++) { - for (int r = 0; r < tensor_scale.extent()[0]; ++r) { - int f = (((c * v + r + v / 3 ) % 63) + 1); - v += 41; - int m = (c * v + r + v / 8 ) % 4; - tensor_scale.at({r, c}) = ElementQScale(static_cast(f) / static_cast(1 << (2 + m))); - } - } + thrust::device_vector d_packed_scales(packed_scales); + cutlass::TensorRef ref_scales( + d_packed_scales.data().get(), LayoutInputQScale::packed(meta_shape)); -// // Fill tensor_weight with the patterned data, so that we can use -// // print to make sure the layout matches after loaded to registers -// int loop_val = 0; -// int offset = 3; -// for (int col_tile = 0; col_tile < tensor_weight.extent().column()/8; ++col_tile) { -// for (int row_tile = 0; row_tile < tensor_weight.extent().row()/4; ++row_tile) { -// for (int col = 0; col < 8; ++col) { -// for (int row = 0; row < 4; ++row) { -// auto weight_cord = cutlass::make_Coord(row_tile * 4 + row, col_tile * 8 + col); -// auto val = (loop_val + offset) % 256; -// tensor_weight.at(weight_cord) = ElementW(val); -// loop_val++; -// if (loop_val == 256) { -// loop_val = 0; -// offset += 11; -// } -// } -// } -// } -// } -// for (int col = 0; col < tensor_scale.extent().column(); ++col){ -// int c = col * QuantBlocking::kColumn; -// for (int row = 0; row < tensor_scale.extent().row(); ++row){ -// int r = row * QuantBlocking::kRow; -// auto weight_cord = cutlass::make_Coord(r/2, c); -// int w = 0; -// if (r % 2 == 0) { -// w = int(tensor_weight.at(weight_cord) & 0x0f); -// } else { -// w = int(tensor_weight.at(weight_cord) >> 4); -// } -// tensor_scale.at({row, col}) = w; -// #ifdef USE_QUANT_OFFSET -// tensor_offset.at({row, col}) = ElementQOffset(w); -// #endif -// } -// } - - // int fill_val = -512; - // int factor = 1; - // for (int col = 0; col < tensor_scale.extent().column(); ++col){ - // for (int row = 0; row < tensor_scale.extent().row(); ++row){ - // tensor_scale.at({row, col}) = ElementQScale((float)fill_val * float(factor)); - // fill_val++; - // if (fill_val == 512) { - // fill_val = -512; - // factor += 1; - // } - // } - // } - - // std::cout << "Matrix Weight:\n" << tensor_weight.host_view() << "\n"; - - // Prepacking weight matrix and quantization meta data ... - - cutlass::HostTensor tensor_weight_prepacked( - cutlass::make_Coord(problem_size.k(), problem_size.n()/2)); - onnxruntime::test::sm80_prepack_weights_ref( - problem_size.k(), problem_size.n(), - make_ConstMatrixRef(tensor_weight), - make_MatrixRef(tensor_weight_prepacked)); - - cutlass::HostTensor tensor_scale_prepacked( - {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); - cutlass::HostTensor tensor_offset_prepacked( - {problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn}); - - auto scale_ref = make_ConstMatrixRef(tensor_scale); - onnxruntime::test::sm80_prepack_quant_scales_ref( - problem_size.k(), problem_size.n(), scale_ref, - make_MatrixRef(tensor_scale_prepacked)); - if constexpr (has_offsets) { - auto offset_ref = make_ConstMatrixRef(tensor_offset); - onnxruntime::test::sm80_prepack_quant_offsets_ref( - problem_size.k(), problem_size.n(), offset_ref, - make_MatrixRef(tensor_offset_prepacked)); - } + thrust::device_vector d_packed_zp(packed_zp); + cutlass::TensorRef ref_zp( + d_packed_zp.data().get(), LayoutInputQScale::packed(meta_shape)); - // Copy data from host to GPU... tensor_a.sync_device(); - tensor_weight_prepacked.sync_device(); - tensor_scale_prepacked.sync_device(); - if constexpr (has_offsets) { - tensor_offset_prepacked.sync_device(); - } tensor_c.sync_device(); tensor_d.sync_device(); - cutlass::TensorRef ref_W( - reinterpret_cast(tensor_weight_prepacked.device_data()), - LayoutInputWPack::packed({problem_size.k()/2, problem_size.n()/2})); // run GEMM cutlass::Status status; if constexpr (has_offsets){ status = GemmRunner::run( nullptr, problem_size, tensor_a.device_ref(), ref_W, - tensor_scale_prepacked.device_ref(), tensor_offset_prepacked.device_ref(), + ref_scales, ref_zp, tensor_c.device_ref(), tensor_d.device_ref()); } else { status = GemmRunner::run( nullptr, problem_size, tensor_a.device_ref(), ref_W, - tensor_scale_prepacked.device_ref(), + ref_scales, tensor_c.device_ref(), tensor_d.device_ref()); } ORT_ENFORCE(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status)); - // Preparing reference kernel arguments - // Dequantizing weights and running reference kernel - + // Running reference kernel using ElementInputB = ElementInputA; using LayoutInputB = cutlass::layout::ColumnMajor; - cutlass::HostTensor tensor_b( - problem_size.kn()); // <- Create dequantized matrix B with dimensions K x N + thrust::device_vector d_dequants(dequants); + cutlass::TensorRef ref_B( + d_dequants.data().get(), LayoutInputB::packed(problem_size.kn())); cutlass::HostTensor tensor_ref_d( problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from // reference kernel - // Dequantize weights and save into matrix B for reference - for (int col = 0; col < tensor_b.extent().column(); ++col){ - for (int row = 0; row < tensor_b.extent().row(); ++row) { - auto weight_cord = cutlass::make_Coord(row/2, col); - auto scale_cord = cutlass::make_Coord(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); - const uint8_t offset = has_offsets ? tensor_offset.at(scale_cord) : 8; - int w = 0; - if (row % 2 == 0) { - w = int(tensor_weight.at(weight_cord) & 0x0f) - offset; - } else { - w = int(tensor_weight.at(weight_cord) >> 4) - offset; - } - auto scale = tensor_scale.at(scale_cord); - tensor_b.at({row, col}) = scale * float(w); - } - } cutlass::reference::host::TensorFill( tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros - - tensor_b.sync_device(); tensor_ref_d.sync_device(); // Initialize alpha and beta for dot product computation @@ -416,7 +295,7 @@ void run_blkq4_gemm(int m, int n, int k) { problem_size, alpha, tensor_a.device_ref(), - tensor_b.device_ref(), + ref_B, beta, tensor_c.device_ref(), tensor_ref_d.device_ref()); From 423aa1fe60b4285e9c8106d93e2fd815c5d67cce Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Tue, 30 Jan 2024 17:42:38 +0000 Subject: [PATCH 06/13] lint --- cmake/external/emsdk | 2 +- onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h | 5 ++--- .../test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h | 8 +++----- .../cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc | 6 +++--- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/cmake/external/emsdk b/cmake/external/emsdk index 4e2496141eda1..a896e3d066448 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit 4e2496141eda15040c44e9bbf237a1326368e34c +Subproject commit a896e3d066448b3530dbcaa48869fafefd738f57 diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h index d49484a072be1..ab59cc2c59b75 100644 --- a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h +++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h @@ -140,7 +140,7 @@ inline void sm80_expand_prepack_quant_offsets_ref( ORT_ENFORCE(tensor_offset_prepacked.shape() == meta_shape, "Unexpected tensor_offset_prepacked shape (", tensor_offset_prepacked.shape()[0], ",", tensor_offset_prepacked.shape()[1], - ")! Expected: (", meta_shape[0], ", ", meta_shape[1], ")"); + ")! Expected: (", meta_shape[0], ", ", meta_shape[1], ")"); ORT_ENFORCE(tensor_offset.shape() == zp_shape, "Unexpected tensor_offset shape (", tensor_offset.shape()[0], ",", tensor_offset.shape()[1], @@ -201,8 +201,7 @@ inline void sm80_expand_prepack_quant_offsets_ref( } template -inline -void sm80_prepack_quant_offsets_ref( +inline void sm80_prepack_quant_offsets_ref( int rows, int columns, MatrixRef tensor_offset, diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h index 4cfb074e7df7d..bbe370675fc48 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h @@ -46,8 +46,7 @@ Status sm80_supported(); * @param[out] q_zp The zero points, column major layout. */ template -inline -void blkq4_weights_gen( +inline void blkq4_weights_gen( int rows, int columns, std::vector& dequants, std::vector& q_weights, @@ -130,7 +129,7 @@ void blkq4_weights_gen( q_scales, meta_shape); MatrixRef tensor_offset; - if constexpr(has_offsets) { + if constexpr (has_offsets) { q_zp.resize(zp_shape.product()); tensor_offset = MatrixRef( q_zp, zp_shape); @@ -155,7 +154,7 @@ void blkq4_weights_gen( auto weight_cord = make_Position(row / 2, col); auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); uint8_t offset = 8; - if constexpr(has_offsets) { + if constexpr (has_offsets) { if (scale_cord[0] % 2 == 0) { offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) & 0x0f; } else { @@ -175,7 +174,6 @@ void blkq4_weights_gen( // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); } } - } template < diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index 148055bd046e2..897cf3fc774d3 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -56,7 +56,7 @@ void testPrepack(int rows, int columns) { MatrixRef tensor_scale( q_scales, meta_shape); MatrixRef tensor_offset; - if constexpr(has_offset) { + if constexpr (has_offset) { tensor_offset = MatrixRef(q_zp, zp_shape); } @@ -167,7 +167,7 @@ void testPrepack(int rows, int columns) { std::vector packed_scales_ref(meta_shape.product()); MatrixRef tensor_packed_s_ref = make_MatrixRef(packed_scales_ref, meta_shape); - if constexpr(Base::ShouldRearrangeMeta) { + if constexpr (Base::ShouldRearrangeMeta) { onnxruntime::test::sm80_prepack_quant_scales_ref( rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref); } else { @@ -197,7 +197,7 @@ void testPrepack(int rows, int columns) { std::vector packed_zp_ref(meta_shape.product()); MatrixRef tensor_packed_zp_ref = make_MatrixRef(packed_zp_ref, meta_shape); - if constexpr(Base::ShouldRearrangeMeta) { + if constexpr (Base::ShouldRearrangeMeta) { onnxruntime::test::sm80_expand_prepack_quant_offsets_ref( rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); } else { From 40de1a14742108cbab55a575b0f136c00188dd1a Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Tue, 30 Jan 2024 18:36:45 +0000 Subject: [PATCH 07/13] conflict with main --- cmake/external/cutlass.cmake | 6 +++--- cmake/external/emsdk | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index f4c55ae105560..f04f4bec76cd5 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -1,8 +1,8 @@ include(FetchContent) FetchContent_Declare( - cutlass - URL ${DEP_URL_cutlass} - URL_HASH SHA1=${DEP_SHA1_cutlass} + cutlass + URL ${DEP_URL_cutlass} + URL_HASH SHA1=${DEP_SHA1_cutlass} ) FetchContent_GetProperties(cutlass) diff --git a/cmake/external/emsdk b/cmake/external/emsdk index a896e3d066448..4e2496141eda1 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit a896e3d066448b3530dbcaa48869fafefd738f57 +Subproject commit 4e2496141eda15040c44e9bbf237a1326368e34c From 2d67beaa6784d71b18fc5f1c4a58afbe21f48f9d Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Tue, 30 Jan 2024 19:54:16 +0000 Subject: [PATCH 08/13] remove redundent test function --- .../test/cuda_host/blkq4_fp16_quant_sm80.h | 131 +++++------------- .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 2 +- 2 files changed, 34 insertions(+), 99 deletions(-) diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h index ab59cc2c59b75..6ea8b55505214 100644 --- a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h +++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h @@ -83,76 +83,10 @@ inline void sm80_prepack_quant_scales_ref( // 16b gemm (2 elements per 32b register, operand tile shape 8x8) // 2 B operand tiles per mma instruction stacked on k dimension // (1,n) quantization blocking - if constexpr (sizeof(ScaleElementT) == 2 && QuantBlocking::kRow == 1) { - // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread - // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use - // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, - // as shown below (T stands for thread): - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // - // We need to deliver quantization scale and offset elements to the corresponding threads, - // so we can perform dequantization efficiently. With a column major layout, each thread - // needs two seperate loads for a mma instruction, due to the tile fragment layout shown - // above. To reduce the number of loads, we rearrange each column as below, so we can use - // a single load to load fragments for two tiles: - // T0 T0 - // T1 T0 - // T2 T1 - // T3 => T1 - // T0 T2 - // T1 T2 - // T2 T3 - // T3 T3 - - for (int col = 0; col < tensor_scale.shape()[1]; ++col) { - for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { - for (int thread_id = 0; thread_id < 4; thread_id++) { - const int dst_idx = row_blk + thread_id * 4; - const int src_idx = row_blk + thread_id * 2; - tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); - tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); - tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); - tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); - } - } - } - } else { - // In all other cases, we don't prepack scale or offset - std::copy(tensor_scale.data().begin(), tensor_scale.data().end(), tensor_scale_prepacked.data().begin()); + if constexpr (sizeof(ScaleElementT) != 2 || QuantBlocking::kRow != 1) { + ORT_THROW("sm80_prepack_quant_scales_ref should only be called for row-wise block quantization on 16b float values."); } -} -template -inline void sm80_expand_prepack_quant_offsets_ref( - int rows, - int columns, - MatrixRef tensor_offset, - MatrixRef tensor_offset_prepacked) { - const auto meta_shape = make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn); - const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); - ORT_ENFORCE(tensor_offset_prepacked.shape() == meta_shape, - "Unexpected tensor_offset_prepacked shape (", - tensor_offset_prepacked.shape()[0], ",", tensor_offset_prepacked.shape()[1], - ")! Expected: (", meta_shape[0], ", ", meta_shape[1], ")"); - ORT_ENFORCE(tensor_offset.shape() == zp_shape, - "Unexpected tensor_offset shape (", - tensor_offset.shape()[0], ",", tensor_offset.shape()[1], - ")! Expected: (", zp_shape[0], ", ", zp_shape[1], ")"); - - // Only prepacking scale and offset tensors for a often used special case: - // 16b gemm (2 elements per 32b register, operand tile shape 8x8) - // 2 B operand tiles per mma instruction stacked on k dimension - // (1,n) quantization blocking - if constexpr (QuantBlocking::kRow != 1) { - return; - } // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, @@ -168,7 +102,7 @@ inline void sm80_expand_prepack_quant_offsets_ref( // // We need to deliver quantization scale and offset elements to the corresponding threads, // so we can perform dequantization efficiently. With a column major layout, each thread - // needs two seperate loads for a mma instruction, due to the tile fragment layout shown + // needs two separate loads for a mma instruction, due to the tile fragment layout shown // above. To reduce the number of loads, we rearrange each column as below, so we can use // a single load to load fragments for two tiles: // T0 T0 @@ -179,22 +113,16 @@ inline void sm80_expand_prepack_quant_offsets_ref( // T1 T2 // T2 T3 // T3 T3 - if (tensor_offset_prepacked.good()) { - for (int col = 0; col < tensor_offset_prepacked.shape()[1]; ++col) { - for (int row_blk = 0; row_blk < tensor_offset_prepacked.shape()[0]; row_blk += 16) { - for (int thread_id = 0; thread_id < 4; thread_id++) { - const int dst_idx = row_blk + thread_id * 4; - const int src_idx = row_blk + thread_id * 2; - // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own - // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to - // convert to fp16x2 format in a b32 register - uint8_t pair01 = tensor_offset.at(src_idx / 2, col); - uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col); - tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf; - tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf; - tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4; - tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4; - } + + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); + tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); + tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); + tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); } } } @@ -206,18 +134,23 @@ inline void sm80_prepack_quant_offsets_ref( int columns, MatrixRef tensor_offset, MatrixRef tensor_offset_prepacked) { - ORT_ENFORCE(tensor_offset.shape()[0] == (rows / QuantBlocking::kRow) && tensor_offset.shape()[1] == (columns / QuantBlocking::kColumn), - "Unexpected tensor_offset shape! Expected: (", - rows / QuantBlocking::kRow, ", ", columns / QuantBlocking::kColumn, ")"); - ORT_ENFORCE(tensor_offset_prepacked.shape() == tensor_offset.shape()); + const auto meta_shape = make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + ORT_ENFORCE(tensor_offset_prepacked.shape() == meta_shape, + "Unexpected tensor_offset_prepacked shape (", + tensor_offset_prepacked.shape()[0], ",", tensor_offset_prepacked.shape()[1], + ")! Expected: (", meta_shape[0], ", ", meta_shape[1], ")"); + ORT_ENFORCE(tensor_offset.shape() == zp_shape, + "Unexpected tensor_offset shape (", + tensor_offset.shape()[0], ",", tensor_offset.shape()[1], + ")! Expected: (", zp_shape[0], ", ", zp_shape[1], ")"); // Only prepacking scale and offset tensors for a often used special case: // 16b gemm (2 elements per 32b register, operand tile shape 8x8) // 2 B operand tiles per mma instruction stacked on k dimension // (1,n) quantization blocking if constexpr (QuantBlocking::kRow != 1) { - std::copy(tensor_offset.data().begin(), tensor_offset.data().end(), tensor_offset_prepacked.data().begin()); - return; + ORT_THROW("sm80_prepack_quant_offsets_ref should only be called for row-wise block quantization."); } // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use @@ -234,7 +167,7 @@ inline void sm80_prepack_quant_offsets_ref( // // We need to deliver quantization scale and offset elements to the corresponding threads, // so we can perform dequantization efficiently. With a column major layout, each thread - // needs two seperate loads for a mma instruction, due to the tile fragment layout shown + // needs two separate loads for a mma instruction, due to the tile fragment layout shown // above. To reduce the number of loads, we rearrange each column as below, so we can use // a single load to load fragments for two tiles: // T0 T0 @@ -246,18 +179,20 @@ inline void sm80_prepack_quant_offsets_ref( // T2 T3 // T3 T3 if (tensor_offset_prepacked.good()) { - for (int col = 0; col < tensor_offset.shape()[1]; ++col) { - for (int row_blk = 0; row_blk < tensor_offset.shape()[0]; row_blk += 16) { + for (int col = 0; col < tensor_offset_prepacked.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_offset_prepacked.shape()[0]; row_blk += 16) { for (int thread_id = 0; thread_id < 4; thread_id++) { const int dst_idx = row_blk + thread_id * 4; const int src_idx = row_blk + thread_id * 2; // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to // convert to fp16x2 format in a b32 register - tensor_offset_prepacked.at(dst_idx + 0, col) = tensor_offset.at(src_idx + 0, col); - tensor_offset_prepacked.at(dst_idx + 1, col) = tensor_offset.at(src_idx + 8, col); - tensor_offset_prepacked.at(dst_idx + 2, col) = tensor_offset.at(src_idx + 1, col); - tensor_offset_prepacked.at(dst_idx + 3, col) = tensor_offset.at(src_idx + 9, col); + uint8_t pair01 = tensor_offset.at(src_idx / 2, col); + uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col); + tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf; + tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf; + tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4; + tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4; } } } diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index 897cf3fc774d3..f987c4a7c507d 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -198,7 +198,7 @@ void testPrepack(int rows, int columns) { MatrixRef tensor_packed_zp_ref = make_MatrixRef(packed_zp_ref, meta_shape); if constexpr (Base::ShouldRearrangeMeta) { - onnxruntime::test::sm80_expand_prepack_quant_offsets_ref( + onnxruntime::test::sm80_prepack_quant_offsets_ref( rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); } else { for (int col = 0; col < meta_shape[1]; ++col) { From 18bf4636fc36cad019315982502aca6c605c1b4d Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Thu, 15 Feb 2024 18:19:32 +0000 Subject: [PATCH 09/13] fix mis-spell and comments --- .../cutlass_ext/q4gemm/device/quantb_gemm.h | 19 +++++---------- .../cutlass_ext/q4gemm/kernel/quantb_gemm.h | 24 +++++++------------ .../quantb_meta_mma_tensor_op_tile_iterator.h | 24 +++++++++---------- .../q4gemm/warp/quantb_mma_tensor_op.h | 2 +- 4 files changed, 27 insertions(+), 42 deletions(-) diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h index 36b52199362d5..38795291b0328 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h @@ -68,7 +68,7 @@ namespace device { to support quantization. This implementation pretty much follows the design of cutlass. But this class seems to be - just a wrapper of the Gemm kernel class. Is this really necessary? + just a wrapper of the Gemm kernel class. Consider combining them in future iterations. */ template < @@ -220,7 +220,6 @@ class QuantBGemm { /// Argument structure struct Arguments { - // // Data members // @@ -248,9 +247,7 @@ class QuantBGemm { /// Default ctor CUTLASS_HOST_DEVICE - Arguments(): problem_size(0, 0, 0) { - - } + Arguments(): problem_size(0, 0, 0) {} /// Constructs an Arguments structure CUTLASS_HOST_DEVICE @@ -262,8 +259,7 @@ class QuantBGemm { TensorRef ref_C_, TensorRef ref_D_, typename EpilogueOutputOp::Params epilogue_ = - typename EpilogueOutputOp::Params() - ): + typename EpilogueOutputOp::Params()): problem_size(problem_size_), ref_A(ref_A_), ref_B(ref_B_), @@ -284,8 +280,7 @@ class QuantBGemm { TensorRef ref_C_, TensorRef ref_D_, typename EpilogueOutputOp::Params epilogue_ = - typename EpilogueOutputOp::Params() - ): + typename EpilogueOutputOp::Params()): problem_size(problem_size_), ref_A(ref_A_), ref_B(ref_B_), @@ -298,13 +293,11 @@ class QuantBGemm { } }; -private: - + private: /// Kernel parameters object typename GemmKernel::Params params_; -public: - + public: /// Constructs the GEMM. QuantBGemm() { } diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h index 1f781b37b98b8..6e5ad8f406147 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h @@ -310,20 +310,12 @@ struct QuantBGemm { tb_offset_B, params.gather_B_indices); - int qscale_k = problem_size_k / Mma::QuantBlocking::kRow; - int qscale_n = params.problem_size.n() / Mma::QuantBlocking::kColumn; - if (qscale_k == 0) { - printf("qscale_k is 0! can_implement() should have returned false!\n"); - } - if (qscale_k * Mma::QuantBlocking::kRow != problem_size_k) { - printf("qscale_k * Mma::QuantBlocking::kK != problem_size_k! can_implement() should have returned false!\n"); - } - if (qscale_n == 0) { - printf("qscale_n is 0! can_implement() should have returned false!\n"); - } - if (qscale_n * Mma::QuantBlocking::kColumn != params.problem_size.n()) { - printf("qscale_n * Mma::QuantBlocking::kN != params.problem_size.n()! can_implement() should have returned false!\n"); - } + const int qscale_k = problem_size_k / Mma::QuantBlocking::kRow; + const int qscale_n = params.problem_size.n() / Mma::QuantBlocking::kColumn; + + // should have been verified by can_implement() + assert((qscale_k > 0) && (qscale_k * Mma::QuantBlocking::kRow == problem_size_k)); + assert((qscale_n > 0) && (qscale_n * Mma::QuantBlocking::kColumn == params.problem_size.n())); cutlass::MatrixCoord tb_offset_QScale{ threadblock_tile_offset.k() * (params.gemm_k_size/Mma::QuantBlocking::kRow), @@ -347,8 +339,8 @@ struct QuantBGemm { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); - int lane_idx = threadIdx.x % 32; + const int warp_idx = canonical_warp_idx(); + const int lane_idx = threadIdx.x % 32; // // Main loop diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h index 107db414c23bc..5d05016b8693a 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -40,14 +40,14 @@ struct b32_pair{ uint32_t b; }; -struct fp16_quard{ +struct fp16_quad{ cutlass::half_t a; cutlass::half_t b; cutlass::half_t c; cutlass::half_t d; }; -struct b16_quard{ +struct b16_quad{ int16_t a; int16_t b; int16_t c; @@ -57,8 +57,8 @@ struct b16_quard{ union b64 { uint64_t single; b32_pair pair; - b16_quard quard; - fp16_quard fp16_quard; + b16_quad quard; + fp16_quad fp16_quad; }; static_assert(sizeof(b64) == 8, "b64 should be 64 bits"); @@ -88,7 +88,7 @@ void weights2Half(cutlass::Array const &weights, " shl.b32 %1, %4, 2;\n" " shr.u32 %2, %4, 2;\n" " shr.u32 %3, %4, 6;\n" - " lop3.b32 %0, %0, 0x03c003c0, 0x4c004c00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 %0, %0, 0x03c003c0, 0x4c004c00, 0xea;\n" // a & 0x03c0 | 0x4c00 " lop3.b32 %1, %1, 0x03c003c0, 0x4c004c00, 0xea;\n" " lop3.b32 %2, %2, 0x03c003c0, 0x4c004c00, 0xea;\n" " lop3.b32 %3, %3, 0x03c003c0, 0x4c004c00, 0xea;\n" @@ -147,7 +147,7 @@ class QuantBMetaMmaTile{ // T3 | T7 | T11 | T15 | T19 | T23 | T27 | T31 using CoreTile = layout::PitchLinearShape<4, 8>; - /// Each thread holds a 32b fragement per tile: for half precision, it's 2 elements, 4 elements for int8 + /// Each thread holds a 32b fragment per tile: for half precision, it's 2 elements, 4 elements for int8 static int const kNumBsPerCoreTileFragement = 32 / sizeof_bits::value; /// Each mma instruction can process either 1 or 2 tensor core operand B tiles (stacked on the k dimension) @@ -175,13 +175,13 @@ class QuantBMetaMmaTile{ /// stride to reach the meta data for the next CoreTile on the K dimension static int const kKTileStride = (kNumBsPerCoreTileFragement * CoreTile::kContiguous + BlockingShape::kRow - 1) / BlockingShape::kRow; - /// Stride on N dimention should be the tile width, shrunk by blocking size on this dimension. + /// Stride on N dimension should be the tile width, shrunk by blocking size on this dimension. static int const kNStride = (CoreTile::kStrided + BlockingShape::kColumn - 1) / BlockingShape::kColumn; /// On N dimension, how many tiles share the same meta data static int const kNRepeats = (BlockingShape::kColumn + CoreTile::kStrided - 1) / CoreTile::kStrided; - /// Each fragement should cover kMmaIterationsB number of mma intructions on the N dimension. + /// Each fragment should cover kMmaIterationsB number of mma intructions on the N dimension. /// When blocking size on this dimension exceeds the tile width, multiple iterations /// would share the same data. static int const kMmaIterations = (kMmaIterationsB + kNRepeats - 1) / kNRepeats; @@ -487,10 +487,10 @@ class QuantBMetaMmaTensorOpTileIteratorfp16_quard.a * static_cast(-16-8); - offsets.fp16_quard.b = scales_ptr->fp16_quard.b * static_cast(-16-8); - offsets.fp16_quard.c = scales_ptr->fp16_quard.c * static_cast(-16-8); - offsets.fp16_quard.d = scales_ptr->fp16_quard.d * static_cast(-16-8); + offsets.fp16_quad.a = scales_ptr->fp16_quad.a * static_cast(-16-8); + offsets.fp16_quad.b = scales_ptr->fp16_quad.b * static_cast(-16-8); + offsets.fp16_quad.c = scales_ptr->fp16_quad.c * static_cast(-16-8); + offsets.fp16_quad.d = scales_ptr->fp16_quad.d * static_cast(-16-8); } CUTLASS_PRAGMA_UNROLL diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h index a88be45952857..1b99c8b909fe8 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h @@ -88,7 +88,7 @@ struct ConvertAndPack { CUTLASS_HOST_DEVICE Array operator()(Array const &source) { - return source; + return source; } }; From 7d5d5ca465481d619a55078d285e4d9f1cd28fe9 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Fri, 23 Feb 2024 18:28:00 +0000 Subject: [PATCH 10/13] variable and type names --- .../threadblock/quantb_mma_multistage.h | 2 +- .../quantb_meta_mma_tensor_op_tile_iterator.h | 20 +---- .../q4gemm/warp/quantb_mma_tensor_op.h | 83 +------------------ 3 files changed, 7 insertions(+), 98 deletions(-) diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h index dfd1032b42c68..8b6bac8c5099a 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h @@ -327,7 +327,7 @@ class QuantBMmaBase { typename Operator::IteratorB warp_tile_iterator_B_; /// Iterator to load a warp-scoped tile of quant scales from shared memory - typename Operator::IteratorQScale warp_tile_iterator_QScale_; + typename Operator::IteratorQMeta warp_tile_iterator_QScale_; public: diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h index 5d05016b8693a..c142ddb132629 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -197,7 +197,7 @@ class QuantBMetaMmaTile{ // 16b gemm (kNumBsPerCoreTileFragement == 2) // 2 B operand tiles per mma (kBTilesPerMma == 2) // (1,n) quantization blocking - // The weight and offset tensor is prepacked to reduce load instructions. + // The scale and offset tensors are prepacked to reduce the number of load instructions. return make_Coord((lane_id % CoreTile::kContiguous) * 4, lane_id / CoreTile::kContiguous); } else { @@ -356,7 +356,7 @@ class QuantBMetaMmaTensorOpTileIterator -struct ConvertAndPack { - - using Converter = NumericArrayConverter; - - CUTLASS_HOST_DEVICE - Array operator()(Array const &source) { - Converter converter; - - return converter(source); - } -}; - -template -struct ConvertAndPack { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &source) { - return source; - } -}; - -template -struct ConvertAndPack { - - using Converter = NumericArrayConverter; - - CUTLASS_HOST_DEVICE - Array operator()(Array const &source) { - Converter converter; - - Array tmp; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); - tmp[i] = source[idx]; - } - - return converter(tmp); - } -}; - -template -struct ConvertAndPack { - - using Converter = NumericArrayConverter; - - CUTLASS_HOST_DEVICE - Array operator()(Array const &source) { - Converter converter; - - Array tmp; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - int idx = (((i << 1) & 2) | ((i >> 1) & 1) | (i & 0xfffffffc)); - tmp[i] = source[idx]; - } - - return converter(tmp); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace internal - -///////////////////////////////////////////////////////////////////////////////////////////////// - /// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. template < /// Size of the Gemm problem - concept: gemm::GemmShape<> @@ -292,13 +217,13 @@ class QuantBMmaTensorOp { // TODO This is an expanding iterator, it needs to replicate the quantization parameters // to all threads in the warp. - using IteratorQScale = QuantBMetaMmaTensorOpTileIterator< + using IteratorQMeta = QuantBMetaMmaTensorOpTileIterator< MatrixShape, QuantBlocking, ElementQScale, SmemLayoutQScale, ElementQOffset, SmemLayoutQOffset, ArchMmaOperator, kThreadCount, kPartitionsK>; - using FragmentQScale = typename IteratorQScale::FragmentScale; - using FragmentQOffset = typename IteratorQScale::FragmentOffset; + using FragmentQScale = typename IteratorQMeta::FragmentScale; + using FragmentQOffset = typename IteratorQMeta::FragmentOffset; /// Number of mma operations performed using MmaIterations = MatrixShape< @@ -419,7 +344,7 @@ class QuantBMmaTensorOp { Array const *ptr_B = reinterpret_cast const *>(&B); - IteratorQScale::dequant(scales, offsets, *ptr_B, dst_B); + IteratorQMeta::dequant(scales, offsets, *ptr_B, dst_B); } }; From b9f9cb768e0b0c8b593727cef865fb51c6442448 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Fri, 23 Feb 2024 19:52:31 +0000 Subject: [PATCH 11/13] ptx for row blocking no zero-point --- .../quantb_meta_mma_tensor_op_tile_iterator.h | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h index c142ddb132629..0b4b786bddee8 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -276,6 +276,9 @@ class QuantBMetaMmaTensorOpTileIterator::value); + static_assert(BlockingShape::kRow == 1 && BlockingShape::kColumn > 1, + "Only support row blocking for column major layout"); + using MetaTile = QuantBMetaMmaTile; /// Number of MMA instructions for this tile @@ -350,12 +353,11 @@ class QuantBMetaMmaTensorOpTileIterator(dest.data()); const b64* scales_ptr = reinterpret_cast(scales.data()); @@ -475,7 +476,7 @@ class QuantBMetaMmaTensorOpTileIterator= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(offsets.pair.a), "=r"(offsets.pair.b) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); +#else offsets.fp16_quad.a = scales_ptr->fp16_quad.a * static_cast(-16-8); offsets.fp16_quad.b = scales_ptr->fp16_quad.b * static_cast(-16-8); offsets.fp16_quad.c = scales_ptr->fp16_quad.c * static_cast(-16-8); offsets.fp16_quad.d = scales_ptr->fp16_quad.d * static_cast(-16-8); +#endif } CUTLASS_PRAGMA_UNROLL From 31a602f467815607e667d14cf9e93a377d012b32 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Wed, 28 Feb 2024 21:00:51 +0000 Subject: [PATCH 12/13] optimize column block dequant --- .../quantb_meta_mma_tensor_op_tile_iterator.h | 117 ++++++++++++++++-- .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 24 +++- 2 files changed, 128 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h index 0b4b786bddee8..4ba39dda3db8d 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -739,20 +739,119 @@ class QuantBMetaMmaTensorOpTileIterator(scales.data()); + uint32_t* addon_ptr = reinterpret_cast(addon); if constexpr(kHasOffset){ - offset = s * static_cast(-16 - int(offsets[n_out])); + const uint32_t* p = reinterpret_cast(offsets.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [d, c, b, a] --> [d, b, c, a] + " prmt.b32 rb2, %4, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, rb2, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(p[0])); +#else + assert(0); +#endif + scales_ptr++; + p++; + addon_ptr += 2; + } } else { - offset = s * static_cast(-16-8); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); +#else + assert(0); +#endif + scales_ptr++; + addon_ptr += 2; + } } + } else if constexpr (kMmaIterationsB % 2 == 0) { + const uint32_t* scales_ptr = reinterpret_cast(scales.data()); + uint32_t* addon_ptr = reinterpret_cast(addon); + + if constexpr (kHasOffset){ + // possible buffer over read 2 bytes here. + const uint32_t* p = reinterpret_cast(offsets.data()); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [?, ?, b, a] --> [?, b, ?, a] + " prmt.b32 rb2, %2, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - offset) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0]) + "r"(p[0])); +#else + assert(0); +#endif + } else { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - 8) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0])); +#else + assert(0); +#endif + } + } else { + // kMmaIterationsB == 1 + if constexpr(kHasOffset){ + uint8_t zp = offsets[0]; + addon[0] = scales[0] * static_cast(-16 - static_cast(zp)); + } else { + addon[0] = scales[0] * static_cast(-16-8); + } + } + + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ CUTLASS_PRAGMA_UNROLL for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ - dest[out_idx] = s * dest[out_idx] + offset; - dest[out_idx + 1] = s * dest[out_idx + 1] + offset; + dest[out_idx] = scales[n_out] * dest[out_idx] + addon[n_out]; + dest[out_idx + 1] = scales[n_out] * dest[out_idx + 1] + addon[n_out]; out_idx += 2; } } diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index f987c4a7c507d..6e5739c9ee647 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -263,7 +263,7 @@ TEST(BlkQ4_GEMM, PrepackSm80Test) { testPrepack(256, 256); } -TEST(BlkQ4_GEMM, Sm80Test) { +TEST(BlkQ4_GEMM, Sm80RowBlockingTest) { Status status = onnxruntime::cuda::test::sm80_supported(); if (!status.IsOK()) { // skip the test if sm80 is not supported @@ -290,14 +290,30 @@ TEST(BlkQ4_GEMM, Sm80Test) { onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, false>(256, 1024, 576); onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, true>(256, 1024, 576); +} - onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, false>(256, 672, 576); - onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, true>(256, 672, 576); +TEST(BlkQ4_GEMM, Sm80ColBlockingTest) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, false>(64, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, true>(64, 672, 576); onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, false>(256, 1024, 576); onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, true>(256, 1024, 576); - // small m +} + +TEST(BlkQ4_GEMM, Sm80SmallMTest) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + + // // small m onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, false>(16, 704, 576); onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, true>(16, 704, 576); From 1477c011d11dce3361ed02f590ac5fa9fc51d930 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Wed, 28 Feb 2024 23:40:30 +0000 Subject: [PATCH 13/13] lint --- .../test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index 6e5739c9ee647..e687ae73e66f2 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -303,7 +303,6 @@ TEST(BlkQ4_GEMM, Sm80ColBlockingTest) { onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, false>(256, 1024, 576); onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, true>(256, 1024, 576); - } TEST(BlkQ4_GEMM, Sm80SmallMTest) {