diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake
index 3bae1b8a48e0f..a2156f2c0e9f5 100644
--- a/cmake/onnxruntime_optimizer.cmake
+++ b/cmake/onnxruntime_optimizer.cmake
@@ -111,6 +111,10 @@ onnxruntime_add_static_library(onnxruntime_optimizer ${onnxruntime_optimizer_src
onnxruntime_add_include_to_target(onnxruntime_optimizer onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface)
target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT})
+
+# using optimizer as cuda prepacking, so extra headers are needed
+target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
+
if (onnxruntime_ENABLE_TRAINING)
target_include_directories(onnxruntime_optimizer PRIVATE ${ORTTRAINING_ROOT})
onnxruntime_add_include_to_target(onnxruntime_optimizer nlohmann_json::nlohmann_json)
diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake
index 3b48a40bf1166..10e88308cea53 100644
--- a/cmake/onnxruntime_providers_cuda.cmake
+++ b/cmake/onnxruntime_providers_cuda.cmake
@@ -217,6 +217,7 @@
include(cutlass)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include)
+ target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 711a9f77f9094..1ade7f4109457 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -272,6 +272,13 @@ if(NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD)
"${TEST_SRC_DIR}/optimizer/*.h"
)
+ if (MSVC AND ((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC")))
+ set_source_files_properties("${TEST_SRC_DIR}/optimizer/graph_transform_test.cc" PROPERTIES COMPILE_FLAGS "/bigobj")
+ list(REMOVE_ITEM onnxruntime_test_optimizer_src
+ "${TEST_SRC_DIR}/optimizer/gpu_op_prepack_test.cc"
+ )
+ endif()
+
set(onnxruntime_test_framework_src_patterns
"${TEST_SRC_DIR}/framework/*.cc"
"${TEST_SRC_DIR}/framework/*.h"
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 45306c852a906..9f2fb44b9756d 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2915,6 +2915,14 @@ This version of the operator has been available since version 1 of the 'com.micr
number of bits used for weight quantization (default 4)
block_size : int (required)
number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.
+column_wise_blocking : int
+whether to quantize weight columnwise (value 1, default), or rowwise (value 0)
+prepacked : int
+
+Indicates whether the weight matrix is prepacked (value 1), or not (value 0, default).
+This property should NEVER be set by user. It is set by ONNX Runtime internally during
+model loading time.
+
#### Inputs (3 - 6)
diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h
index e609745b5e03f..56853e3229030 100644
--- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h
+++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h
@@ -22,6 +22,7 @@
namespace onnxruntime {
class IExecutionProvider;
+class ExecutionProviders;
namespace optimizer_utils {
@@ -48,7 +49,7 @@ std::unique_ptr GenerateRuleBasedGraphTransformer(
InlinedVector> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
- const IExecutionProvider& execution_provider /*required by constant folding*/,
+ const ExecutionProviders& execution_providers /* cpu ep required by constant folding*/,
const InlinedHashSet& rules_and_transformers_to_disable = {});
#endif // !defined(ORT_MINIMAL_BUILD)
@@ -77,7 +78,7 @@ InlinedVector> GenerateTransformersForMinimalB
TransformerLevel level,
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
- const IExecutionProvider& cpu_execution_provider,
+ const ExecutionProviders& execution_providers,
const InlinedHashSet& rules_and_transformers_to_disable = {});
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
index 1cec6f6a12f1c..2c3db2a3d2b47 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
@@ -14,24 +14,44 @@
namespace onnxruntime {
namespace contrib {
namespace cuda {
-using namespace onnxruntime::cuda;
+
+#ifndef USE_ROCM
+template <>
+Status MatMulNBits::PrepackedGemm(
+ cudaStream_t stream,
+ int M,
+ const Tensor* a,
+ const Tensor* b,
+ const Tensor* scales,
+ const Tensor* zero_points,
+ Tensor* Y) const {
+ uint8_t const* zero_points_ptr = nullptr;
+ size_t zero_points_size = 0;
+ if (zero_points != nullptr) {
+ zero_points_ptr = zero_points->Data();
+ zero_points_size = zero_points->Shape().Size();
+ }
+
+ return blkq4_fp16_gemm_sm80_dispatch(
+ int(block_size_), column_wise_quant_blk_, int(M), int(N_), int(K_), stream,
+ a->Data(), a->Shape().Size(),
+ b->Data(), b->Shape().Size(),
+ scales->Data(), scales->Shape().Size(),
+ zero_points_ptr, zero_points_size,
+ Y->MutableData(), Y->Shape().Size());
+}
+#endif // !USE_ROCM
template
Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const {
+ using CudaT = typename ToCudaType::MappedType;
+
const Tensor* a = ctx->Input(0);
const Tensor* b = ctx->Input(1);
const Tensor* scales = ctx->Input(2);
const Tensor* zero_points = ctx->Input(3);
const Tensor* reorder_idx = ctx->Input(4);
- const auto* a_data = a->Data();
- const uint8_t* blob_data = b->Data();
- const auto* scales_data = scales->Data();
- const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
- const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data();
-
- typedef typename ToCudaType::MappedType CudaT;
-
constexpr bool transa = false;
constexpr bool transb = true;
MatMulComputeHelper helper;
@@ -43,6 +63,26 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const {
// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0) return Status::OK();
+ if (prepack_ > 0) {
+#ifndef USE_ROCM
+ ORT_RETURN_IF(reorder_idx != nullptr,
+ "Internal Error: Prepacked gemm does not support reorder index. Fix the prepacking logic!");
+ ORT_RETURN_IF(zero_points != nullptr && zero_points->IsDataType(),
+ "Internal Error: Prepacked gemm does not support zero points of type T. Fix the prepacking logic!");
+ return PrepackedGemm(
+ static_cast(ctx->GetComputeStream()->GetHandle()),
+ static_cast(helper.M()), a, b, scales, zero_points, Y);
+#else
+ ORT_RETURN_IF(true, "Prepacked gemm is not supported for MatMulNBits op.");
+#endif // !USE_ROCM
+ }
+
+ const auto* a_data = a->Data();
+ const uint8_t* blob_data = b->Data();
+ const auto* scales_data = scales->Data();
+ const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
+ const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data();
+
bool is_4bit_done = (reorder_idx_data == nullptr) &&
(!zero_points || !zero_points->IsDataType()) &&
TryMatMul4Bits(
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu
index af9e87eaf225d..2309ef4fb7a27 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu
@@ -10,6 +10,9 @@
#include "core/providers/cuda/cuda_common.h"
#include "matmul_nbits.cuh"
+#include "blk_q4/f16_gemm_sm80.h"
+#include "gemm/device/quant_b4_gemm.h"
+
using namespace onnxruntime::cuda;
using namespace cub;
@@ -349,6 +352,342 @@ template bool TryMatMul4Bits(
int shared_mem_per_block,
cudaStream_t stream);
+/**
+ * @brief Helper function to run the GEMM kernel for 4bits quantized gemm on SM80.
+ * Only support fp16 for now.
+*/
+template<
+ typename ElementT,
+ int block_size,
+ bool column_wise_blocking,
+ bool small_m,
+ bool has_offsets>
+Status blkq4_gemm_sm80(int m, int n, int k, cudaStream_t stream,
+ gsl::span a,
+ gsl::span weights,
+ gsl::span scales,
+ gsl::span offsets,
+ gsl::span output) {
+ static_assert(std::is_same::value
+ || std::is_same::value
+ || std::is_same::value,
+ "Only support fp16 for now");
+ using ElementDequant = cutlass::half_t;
+ using QuantBlocking =
+ typename std::conditional,
+ cutlass::MatrixShape<1, block_size>>::type;
+
+ using GemmRunner = onnxruntime::cuda::BlkQ4F16GemmImpl;
+
+ using ElementAccumulator = typename GemmRunner::ElementAccumulator;
+ using ElementComputeEpilogue = typename GemmRunner::ElementComputeEpilogue;
+ using ElementOutput = typename GemmRunner::ElementOutput;
+ using ElementW = typename GemmRunner::ElementW;
+ using ElementWPack = typename GemmRunner::ElementWPack;
+ using ElementQScale = typename GemmRunner::ElementQScale;
+ using ElementQOffset = typename GemmRunner::ElementQOffset;
+
+ using LayoutInputA = typename GemmRunner::LayoutInputA;
+ using LayoutOutput = typename GemmRunner::LayoutOutput;
+ using LayoutInputWPack = typename GemmRunner::LayoutInputWPack;
+ using LayoutInputQScale = typename GemmRunner::LayoutInputQScale;
+
+ const cutlass::gemm::GemmCoord problem_size = {m, n, k};
+
+ ORT_RETURN_IF_NOT(a.size_bytes() == m * k * sizeof(ElementDequant), "Activation tensor size is not correct: ", a.size_bytes(), " vs m: ", m, "k: ", k , " size: ", m * k * sizeof(ElementDequant));
+ cutlass::TensorRef ref_a(
+ reinterpret_cast(a.data()),
+ LayoutInputA::packed({m, k}));
+
+ ORT_RETURN_IF_NOT(weights.size_bytes() == k/2 * n/2 * sizeof(ElementWPack), "weights size is not correct");
+ cutlass::TensorRef ref_W(
+ reinterpret_cast(weights.data()),
+ LayoutInputWPack::packed({k/2, n/2}));
+
+ ORT_RETURN_IF_NOT(scales.size_bytes() == (k/QuantBlocking::kRow) * (n/QuantBlocking::kColumn) * sizeof(ElementQScale),
+ "scales size is not correct");
+ cutlass::TensorRef ref_scales(
+ reinterpret_cast(scales.data()),
+ LayoutInputQScale::packed({k/QuantBlocking::kRow, n/QuantBlocking::kColumn}));
+
+ ORT_RETURN_IF_NOT(output.size_bytes() == m * n * sizeof(ElementOutput), "output size is not correct");
+ cutlass::TensorRef ref_output(
+ reinterpret_cast(output.data()),
+ LayoutOutput::packed({m, n}));
+
+ // run GEMM
+ cutlass::Status status;
+ if constexpr (has_offsets) {
+ ORT_RETURN_IF_NOT(offsets.size_bytes() == (k/QuantBlocking::kRow) * (n/QuantBlocking::kColumn) * sizeof(ElementQOffset),
+ "offsets size is not correct");
+ cutlass::TensorRef ref_offsets(
+ reinterpret_cast(offsets.data()),
+ LayoutInputQScale::packed({k/QuantBlocking::kRow, n/QuantBlocking::kColumn}));
+ status = GemmRunner::run(
+ stream, problem_size, ref_a, ref_W, ref_scales, ref_offsets,
+ ref_output, ref_output);
+ } else {
+ status = GemmRunner::run(
+ stream, problem_size, ref_a, ref_W, ref_scales,
+ ref_output, ref_output);
+ }
+ ORT_RETURN_IF_NOT(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status));
+ return Status::OK();
+}
+
+/**
+ * @brief The GEMM kernel for 4bits quantized gemm on SM80 -- small size gemm version.
+ */
+template <
+ typename ElementT,
+ int block_size,
+ bool column_wise_blocking,
+ bool has_offsets>
+Status blkq4_small_gemm_sm80(
+ int m, int n, int k, cudaStream_t stream,
+ const ElementT* ptr_a,
+ size_t lda,
+ const uint8_t* ptr_packed_b,
+ const ElementT* ptr_scales,
+ const uint8_t* ptr_offsets,
+ ElementT* ptr_c,
+ size_t ldc) {
+ using QuantBlocking =
+ typename std::conditional,
+ cutlass::MatrixShape<1, block_size>>::type;
+ using LayoutQmeta =
+ typename std::conditional::type;
+ using WarpShape = cutlass::gemm::GemmShape<16, 16, 64>;
+
+ const cutlass::gemm::GemmCoord problem_size = {m, n, k};
+ const auto meta_shape = cutlass::make_Coord(problem_size.k() / QuantBlocking::kRow,
+ problem_size.n() / QuantBlocking::kColumn);
+ if ((problem_size.k() % QuantBlocking::kRow != 0) ||
+ (problem_size.n() % QuantBlocking::kColumn) != 0){
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Partial quantization block in B not supported!");
+ }
+
+ // run GEMM
+ size_t meta_stride = static_cast(LayoutQmeta::packed(meta_shape).stride(0));
+ const void* ptr_zp = has_offsets ? ptr_offsets : nullptr;
+ size_t zp_byte_stride = has_offsets ? meta_stride * sizeof(uint8_t) : size_t(0);
+
+ cutlass::Status status;
+
+ if (k <= 384) {
+ status = mickey::gemm::device::QuantB4Gemm::run(
+ stream, problem_size,
+ ptr_c, ldc * sizeof(ElementT),
+ ptr_a, lda * sizeof(ElementT),
+ ptr_packed_b, k * sizeof(uint8_t),
+ ptr_scales, meta_stride * sizeof(ElementT),
+ ptr_zp, zp_byte_stride);
+ } else if (k <= 768) {
+ status = mickey::gemm::device::QuantB4Gemm::run(
+ stream, problem_size,
+ ptr_c, ldc * sizeof(ElementT),
+ ptr_a, lda * sizeof(ElementT),
+ ptr_packed_b, k * sizeof(uint8_t),
+ ptr_scales, meta_stride * sizeof(ElementT),
+ ptr_zp, zp_byte_stride);
+ } else if (k < 1536) {
+ status = mickey::gemm::device::QuantB4Gemm::run(
+ stream, problem_size,
+ ptr_c, ldc * sizeof(ElementT),
+ ptr_a, lda * sizeof(ElementT),
+ ptr_packed_b, k * sizeof(uint8_t),
+ ptr_scales, meta_stride * sizeof(ElementT),
+ ptr_zp, zp_byte_stride);
+ } else {
+ status = mickey::gemm::device::QuantB4Gemm::run(
+ stream, problem_size,
+ ptr_c, ldc * sizeof(ElementT),
+ ptr_a, lda * sizeof(ElementT),
+ ptr_packed_b, k * sizeof(uint8_t),
+ ptr_scales, meta_stride * sizeof(ElementT),
+ ptr_zp, zp_byte_stride);
+ }
+
+ ORT_RETURN_IF_NOT(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status));
+ return Status::OK();
+}
+
+
+template
+Status
+blkq4_fp16_gemm_sm80_dispatch(
+ int block_size, bool column_wise_blocking, int m, int n, int k, cudaStream_t stream,
+ ElementT const* a_ptr, size_t a_size,
+ uint8_t const* weights_ptr, size_t weights_size,
+ ElementT const* scales_ptr, size_t scales_size,
+ uint8_t const* offsets_ptr, size_t offsets_size,
+ ElementT* output_ptr, size_t output_size) {
+ auto a = gsl::make_span(a_ptr, a_size);
+ auto weights = gsl::make_span(weights_ptr, weights_size);
+ auto scales = gsl::make_span(scales_ptr, scales_size);
+ auto offsets = gsl::make_span(offsets_ptr, offsets_size);
+ auto output = gsl::make_span(output_ptr, output_size);
+
+ switch (block_size)
+ {
+ case 16:
+ if (column_wise_blocking) {
+ if (m <= 64 && n < 16384) {
+ if (offsets.empty())
+ return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n);
+ else
+ return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n);
+ }
+ if (m > 32) {
+ if (offsets.empty())
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ else
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ } else {
+ if (offsets.empty())
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ else
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ }
+ } else {
+ if (m <= 64 && n < 16384) {
+ if (offsets.empty())
+ return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n);
+ else
+ return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n);
+ }
+ if (m > 32) {
+ if (offsets.empty())
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ else
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ } else {
+ if (offsets.empty())
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ else
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ }
+ }
+ break;
+
+ case 32:
+ if (column_wise_blocking) {
+ if (m <= 64 && n < 16384) {
+ if (offsets.empty())
+ return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n);
+ else
+ return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n);
+ }
+ if (m > 32) {
+ if (offsets.empty())
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ else
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ } else {
+ if (offsets.empty())
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ else
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ }
+ } else {
+ if (m <= 64 && n < 16384) {
+ if (offsets.empty())
+ return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n);
+ else
+ return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n);
+ }
+ if (m > 32) {
+ if (offsets.empty())
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ else
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ } else {
+ if (offsets.empty())
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ else
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ }
+ }
+ break;
+
+ case 64:
+ if (column_wise_blocking) {
+ if (m <= 64 && n < 16384) {
+ if (offsets.empty())
+ return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n);
+ else
+ return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n);
+ }
+ if (m > 32) {
+ if (offsets.empty())
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ else
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ } else {
+ if (offsets.empty())
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ else
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ }
+ } else {
+ if (m <= 64 && n < 16384) {
+ if (offsets.empty())
+ return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n);
+ else
+ return blkq4_small_gemm_sm80(m, n, k, stream, a_ptr, k, weights_ptr, scales_ptr, offsets_ptr, output_ptr, n);
+ }
+ if (m > 32) {
+ if (offsets.empty())
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ else
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ } else {
+ if (offsets.empty())
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ else
+ return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output);
+ }
+ }
+ break;
+ }
+
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported block size: ", block_size);
+}
+
+template
+Status blkq4_fp16_gemm_sm80_dispatch(
+ int block_size,
+ bool column_wise_blocking,
+ int m, int n, int k, cudaStream_t stream,
+ half const* a_ptr, size_t a_size,
+ uint8_t const* weights_ptr, size_t weights_size,
+ half const* scales_ptr, size_t scales_size,
+ uint8_t const* offsets_ptr, size_t offsets_size,
+ half* output_ptr, size_t output_size);
+
+template
+Status blkq4_fp16_gemm_sm80_dispatch(
+ int block_size,
+ bool column_wise_blocking,
+ int m, int n, int k, cudaStream_t stream,
+ cutlass::half_t const* a_ptr, size_t a_size,
+ uint8_t const* weights_ptr, size_t weights_size,
+ cutlass::half_t const* scales_ptr, size_t scales_size,
+ uint8_t const* offsets_ptr, size_t offsets_size,
+ cutlass::half_t* output_ptr, size_t output_size);
+
+template
+Status blkq4_fp16_gemm_sm80_dispatch(
+ int block_size, bool column_wise_blocking, int m, int n, int k, cudaStream_t stream,
+ onnxruntime::MLFloat16 const* a_ptr, size_t a_size,
+ uint8_t const* weights_ptr, size_t weights_size,
+ onnxruntime::MLFloat16 const* scales_ptr, size_t scales_size,
+ uint8_t const* offsets_ptr, size_t offsets_size,
+ onnxruntime::MLFloat16* output_ptr, size_t output_size);
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh
index 9ccbe4c4d97a8..b3a325bf806b4 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh
@@ -22,6 +22,17 @@ bool TryMatMul4Bits(
int shared_mem_per_block,
cudaStream_t stream);
+template
+Status blkq4_fp16_gemm_sm80_dispatch(
+ int block_size,
+ bool column_wise_blocking,
+ int m, int n, int k, cudaStream_t stream,
+ ElementT const* a_ptr, size_t a_size,
+ uint8_t const* weights_ptr, size_t weights_size,
+ ElementT const* scales_ptr, size_t scales_size,
+ uint8_t const* offsets_ptr, size_t offsets_size,
+ ElementT* output_ptr, size_t output_size);
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
index f5c2c6c4e4fdf..a8008e9cdcfa7 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
@@ -14,7 +14,6 @@
namespace onnxruntime {
namespace contrib {
namespace cuda {
-using namespace onnxruntime::cuda;
template
class MatMulNBits final : public CudaKernel {
@@ -24,8 +23,27 @@ class MatMulNBits final : public CudaKernel {
ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_));
ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_));
ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_));
+ ORT_ENFORCE(nbits_ == 4,
+ "Only 4b quantization is supported for MatMulNBits op,"
+ " additional bits support is planned.");
+ int64_t column_wise_quant_blk = 1;
+ info.GetAttrOrDefault("column_wise_blocking", &column_wise_quant_blk, int64_t(1));
+ column_wise_quant_blk_ = column_wise_quant_blk != 0;
+ info.GetAttrOrDefault("prepacked", &prepack_, int64_t(0));
}
+#ifndef USE_ROCM
+ Status PrepackedGemm([[maybe_unused]] cudaStream_t stream,
+ [[maybe_unused]] int M,
+ [[maybe_unused]] const Tensor* a,
+ [[maybe_unused]] const Tensor* b,
+ [[maybe_unused]] const Tensor* scales,
+ [[maybe_unused]] const Tensor* zero_points,
+ [[maybe_unused]] Tensor* Y) const {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Prepacked gemm is not supported for MatMulNBits op.");
+ }
+#endif // !USE_ROCM
+
Status ComputeInternal(OpKernelContext* context) const override;
private:
@@ -34,6 +52,7 @@ class MatMulNBits final : public CudaKernel {
int64_t block_size_;
int64_t nbits_;
bool column_wise_quant_blk_{true};
+ int64_t prepack_{0};
};
} // namespace cuda
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index dea8775c89a30..91cd1fe37d31b 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3431,6 +3431,15 @@ Input zero_points is stored as uint8_t or same as type(A). It has the same packi
If zero_points has same type as A, it's not packed and has the same shape as Scales.
)DOC";
+ // This is really bad as we expose a property to the user that should never be set by user.
+ // We have to use this as we perform cuda prepacking during graph optimization phase and
+ // this is the only way to pass the information to the runtime.
+ static const char* PrepackProperty_doc = R"DOC(
+Indicates whether the weight matrix is prepacked (value 1), or not (value 0, default).
+This property should NEVER be set by user. It is set by ONNX Runtime internally during
+model loading time.
+)DOC";
+
ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits)
.SetDomain(kMSDomain)
.SinceVersion(1)
@@ -3446,6 +3455,8 @@ Input zero_points is stored as uint8_t or same as type(A). It has the same packi
"computation. 4 means input A can be quantized with the same block_size to int8 internally from "
"type T1.",
AttributeProto::INT, static_cast(0))
+ .Attr("column_wise_blocking", "whether to quantize weight columnwise (value 1, default), or rowwise (value 0)", AttributeProto::INT, static_cast(1))
+ .Attr("prepacked", PrepackProperty_doc, AttributeProto::INT, static_cast(0))
.Input(0, "A", "The input tensor, not quantized", "T1")
.Input(1, "B", "1 or 2 dimensional data blob", "T2")
.Input(2, "scales", "quantization scale", "T1")
diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc
index 13620f4d8b3bb..6b2138f39e1b3 100644
--- a/onnxruntime/core/graph/graph_utils.cc
+++ b/onnxruntime/core/graph/graph_utils.cc
@@ -240,12 +240,6 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node, std::string_view op_typ
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
-const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) {
- const auto& attrs = node.GetAttributes();
- const auto iter = attrs.find(attr_name);
- return iter == attrs.end() ? nullptr : &iter->second;
-}
-
NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) {
// sanity check as AddInitializedTensor silently ignores attempts to add a duplicate initializer
const ONNX_NAMESPACE::TensorProto* existing = nullptr;
diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h
index 319e055200cca..e27007a6eecfc 100644
--- a/onnxruntime/core/graph/graph_utils.h
+++ b/onnxruntime/core/graph/graph_utils.h
@@ -26,7 +26,27 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node,
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
/** Returns the attribute of a Node with a given name. */
-const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name);
+static inline const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) {
+ const auto& attrs = node.GetAttributes();
+ const auto iter = attrs.find(attr_name);
+ return iter == attrs.end() ? nullptr : &iter->second;
+}
+
+template
+inline Status TryGetNodeAttribute(const Node& node, const std::string& attr_name, T& value);
+
+template <>
+inline Status TryGetNodeAttribute(const Node& node, const std::string& attr_name, int64_t& value) {
+ const auto* attr = GetNodeAttribute(node, attr_name);
+ if (!attr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node attribute '", attr_name, "' is not set.");
+ }
+ if (!attr->has_i()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node attribute '", attr_name, "' is not an integer.");
+ }
+ value = attr->i();
+ return Status::OK();
+}
/** Add a new initializer to 'graph'.
Checks that new_initializer does not already exist in 'graph' before adding it.
diff --git a/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h
index 52bff7e40dbe3..0b3cf3c0e75fa 100644
--- a/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h
+++ b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h
@@ -11,9 +11,19 @@
#pragma once
+// Ignore CUTLASS warning C4100: unreferenced formal parameter
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4100)
+#endif
+
#include "cutlass/cutlass.h"
#include "cutlass_ext/q4gemm/device/quantb_gemm.h"
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
namespace onnxruntime {
namespace cuda {
@@ -127,14 +137,6 @@ struct BlkQ4F16GemmImpl {
if constexpr (!kHasQuantOffset) {
return cutlass::Status::kErrorNotSupported;
} else {
- if constexpr (ShapeMMAThreadBlock::kM == 16) {
- if (problem_size_.m() > 16) {
- // For M > 16, the caller should have picked the
- // kernel with bigger M
- return cutlass::Status::kErrorNotSupported;
- }
- }
-
// Construct Gemm arguments
Arguments args{
problem_size_,
@@ -172,14 +174,6 @@ struct BlkQ4F16GemmImpl {
if constexpr (kHasQuantOffset) {
return cutlass::Status::kErrorNotSupported;
} else {
- if constexpr (ShapeMMAThreadBlock::kM == 16) {
- if (problem_size_.m() > 16) {
- // For M > 16, the caller should have picked the
- // kernel with bigger M
- return cutlass::Status::kErrorNotSupported;
- }
- }
-
// Construct Gemm arguments
Arguments args{
problem_size_,
diff --git a/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h
index c81b4967d2719..1492d9846a08e 100644
--- a/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h
+++ b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h
@@ -19,6 +19,7 @@
#pragma once
#include "core/common/common.h"
+#include "core/framework/float16.h"
#include "core/util/matrix_layout.h"
namespace onnxruntime {
@@ -79,6 +80,23 @@ struct BlockwiseQuantization {
return make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn);
}
+ static inline bool weight_dimension_supported(int rows, int columns) {
+ // prepacking works on a 16x16 block, so the dimensions must be multiples of 16
+ if (((rows % 16) != 0) || ((columns % 16) != 0)) {
+ return false;
+ }
+
+ // verify the dimensions are multiples of the block size
+ if (((rows % QuantBlocking::kRow) != 0) || ((columns % QuantBlocking::kColumn) != 0)) {
+ return false;
+ }
+
+ // All the above restrictions can be relaxed by adding more logic to the prepack
+ // and gemm implementation to support edge cases. But for now, these restrictions
+ // does not affect LLM weight dimensions, so we keep it simple.
+ return true;
+ }
+
/**
* @brief Prepack weight matrix to facilitate matrix loading, depending on MMA
* instruction layout.
@@ -113,10 +131,10 @@ struct BlockwiseQuantization {
gsl::span weights, // <- int4 weights, column major
gsl::span weights_prepacked // <- int4 prepacked weights tensor, same size buffer
) {
- ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0 &&
- (rows % QuantBlocking::kRow) == 0 &&
- (columns % QuantBlocking::kColumn) == 0,
- "Does not support odd number of rows or columns!");
+#ifndef NDEBUG
+ ORT_ENFORCE(weight_dimension_supported(rows, columns),
+ "This function must be guarded by weight_dimension_supported()!");
+#endif
ORT_ENFORCE(weights.size() == size_t(rows * columns / 2),
"Weight tensor shape mismatch!");
ORT_ENFORCE(weights_prepacked.size() == weights.size(),
@@ -174,6 +192,10 @@ struct BlockwiseQuantization {
gsl::span scales, // <- quant scales, column major layout
gsl::span scales_prepacked // <- quant scales prepacked, same size buffer
) {
+#ifndef NDEBUG
+ ORT_ENFORCE(weight_dimension_supported(static_cast(rows), static_cast(columns)),
+ "This function must be guarded by weight_dimension_supported()!");
+#endif
auto meta_shape = get_quant_meta_shape(static_cast(rows), static_cast(columns));
ORT_ENFORCE(scales.size() == size_t(meta_shape.product()),
"Quantization scale tensor shape mismatch!");
@@ -244,8 +266,11 @@ struct BlockwiseQuantization {
gsl::span offsets, // <- quant offsets, int4, column major layout
gsl::span offsets_prepacked // <- quant offsets prepacked, double size buffer
) {
+#ifndef NDEBUG
+ ORT_ENFORCE(weight_dimension_supported(static_cast(rows), static_cast(columns)),
+ "This function must be guarded by weight_dimension_supported()!");
+#endif
auto meta_shape = get_quant_meta_shape(static_cast(rows), static_cast(columns));
-
ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0,
"Does not support odd number of rows or columns!");
ORT_ENFORCE(offsets_prepacked.size() == size_t(meta_shape.product()),
@@ -321,5 +346,58 @@ struct BlockwiseQuantization {
}
};
+static inline bool IsSm80WithWholeBlocks(
+ int weight_rows, [[maybe_unused]] int weight_cols,
+ int major, [[maybe_unused]] int minor) {
+ if (major < 8) {
+ return false;
+ }
+
+ // Kernel implementation detail:
+ // K must be aligned with thread block tile (64) due to the way
+ // predicate iterator works, it loads the partial tile
+ // in the first iteration and then the full tile in the
+ // remaining iterations. This will cause the blockwise
+ // quantization parameters to go out of step with the
+ // weights.
+ // To fix this, we need to write our own predicate iterator
+ // that loads the full tile in the first iterations and
+ // then the partial tile in the last iteration.
+ return (weight_rows % 64 == 0);
+}
+
+template
+inline bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int minor) {
+ using Base = BlockwiseQuantization;
+ if (!Base::weight_dimension_supported(weight_rows, weight_cols)) {
+ return false;
+ }
+ return IsSm80WithWholeBlocks(weight_rows, weight_cols, major, minor);
+}
+
+static inline bool BlkQuantGemmSm80Supported(int block_size, bool col_blocking, int weight_rows, int weight_cols, int major, int minor) {
+ switch (block_size) {
+ case 16:
+ if (col_blocking) {
+ return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor);
+ } else {
+ return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor);
+ }
+ case 32:
+ if (col_blocking) {
+ return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor);
+ } else {
+ return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor);
+ }
+ case 64:
+ if (col_blocking) {
+ return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor);
+ } else {
+ return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor);
+ }
+ }
+ return false;
+}
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h
index 6e281241a3427..26e878f6120d9 100644
--- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h
+++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h
@@ -132,7 +132,7 @@ struct DummyType{
}
CUTLASS_HOST_DEVICE
- std::monostate& operator[](int /*idx */) {
+ std::monostate& operator[]([[maybe_unused]] int idx) {
return dummy_;
}
};
diff --git a/onnxruntime/core/mickey/gemm/kernel/quant_b4_gemm.h b/onnxruntime/core/mickey/gemm/kernel/quant_b4_gemm.h
index a0695dbbfd347..97f5773617cf8 100644
--- a/onnxruntime/core/mickey/gemm/kernel/quant_b4_gemm.h
+++ b/onnxruntime/core/mickey/gemm/kernel/quant_b4_gemm.h
@@ -386,9 +386,9 @@ struct QuantB4Gemm {
if constexpr (kSplitK > 1){
// TODO! Use thread block shape
- if (params.gemm_k_size_ < WarpShape::kK * kStages * 2) {
+ if (params.gemm_k_size_ < WarpShape::kK * kStages) {
// spliting too small, may not get enough iterations to rampup pipeline
- std::cerr << "QuantB4Gemm validation fail: split k too big, each k segment: " << params.gemm_k_size_ << " is smaller than " << (WarpShape::kK * kStages * 2) << std::endl;
+ std::cerr << "QuantB4Gemm validation fail: split k too big, each k segment: " << params.gemm_k_size_ << " is smaller than " << (WarpShape::kK * kStages) << std::endl;
return cutlass::Status::kErrorNotSupported;
}
}
@@ -534,21 +534,6 @@ struct QuantB4Gemm {
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
cutlass::arch::cp_async_wait();
- //__syncthreads(); is this necessary since the loader is warp based?
- // if constexpr(kDebugPrintA) {
- // if (lane_idx == 0) {
- // printf("Prologue, warp: %d, WarpPtr: %p\n",
- // warp_idx, a_shared_ptr);
- // printf("\n********Dumping the shared memory of Warp %d*******\n\n", warp_idx);
-
- // for (int i = 0; i < MainLoopSharedBuffer::kASize; i += 8) {
- // for (int j = 0; j < 8; ++j) {
- // printf("%f, ", float(a_shared_ptr[i + j]));
- // }
- // printf("\n");
- // }
- // }
- // }
//
// Prefix of the Mainloop, pre-loading the double buffer in registers
@@ -678,11 +663,10 @@ struct QuantB4Gemm {
a_tile_loader.load_to_smem_split(lane_idx, a_smem_write_ptr, next_iter * kAGloadsPerIter + i);
}
if constexpr (kDebugPrintA) {
- const int lane_id = threadIdx.x % 32;
- if (lane_id == 0) {
- printf("==== A tiles =======\n");
+ if (lane_idx == 0) {
+ printf("===== Warp %d A tiles =======\n", warp_idx);
}
- const char* const format = (lane_id == 31) ? "%f, %f\n\n" : ((lane_id % 4) == 3) ? "%f, %f\n" : "%f, %f, ";
+ const char* const format = (lane_idx == 31) ? "%f, %f\n\n" : ((lane_idx % 4) == 3) ? "%f, %f\n" : "%f, %f, ";
const ElementT* a_ptr = fragment_a[iter2 % 2].data();
for (int m2_tile = 0; m2_tile < (WarpShape::kM / InstructionShape::kM); ++m2_tile, a_ptr += 8) {
printf(format, float(a_ptr[0]), float(a_ptr[1]));
diff --git a/onnxruntime/core/mickey/gemm/warp/quantb_meta_loader.h b/onnxruntime/core/mickey/gemm/warp/quantb_meta_loader.h
index 4a784a1a49109..2422f1b25a366 100644
--- a/onnxruntime/core/mickey/gemm/warp/quantb_meta_loader.h
+++ b/onnxruntime/core/mickey/gemm/warp/quantb_meta_loader.h
@@ -47,10 +47,9 @@ void weightsMinuEight2Half(uint32_t const &weights,
//
// 1.125 instruction per weight, 9 instructions in total.
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 500))
uint32_t* b32s = reinterpret_cast(dest.data());
const uint32_t high_8s = weights >> 8;
-
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 500))
asm volatile(
" lop3.b32 %0, %4, 0x000f000f, %6, 0xea;\n"
" lop3.b32 %1, %4, 0x00f000f0, %7, 0xea;\n"
@@ -67,6 +66,8 @@ void weightsMinuEight2Half(uint32_t const &weights,
"r"(0x64086408));
#else
assert(false);
+ (void)(weights);
+ (void)(dest);
#endif
}
@@ -331,6 +332,7 @@ struct QuantBScaleLoader, WarpShape_, Eleme
// only one scale/offset, so the block size cannot be smaller than 16.
static_assert(QuantBlocking::kRow % 16 == 0);
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
const int meta_k = k_iter / (QuantBlocking::kRow / 16);
half const* scales = reinterpret_cast(frag_scales.data() + meta_k * kMetaFragSize);
[[maybe_unused]] half const* offsets = nullptr;
@@ -391,6 +393,14 @@ struct QuantBScaleLoader, WarpShape_, Eleme
}
}
}
+#else
+ assert(false);
+ (void)(k_iter);
+ (void)(frag_pack_b);
+ (void)(frag_scales);
+ (void)(frag_offsets);
+ (void)(frag_b);
+#endif // __CUDA_ARCH__
}
};
@@ -611,6 +621,7 @@ struct QuantBScaleLoader, WarpShape_, Eleme
constexpr int kPackedBKStride = PackedBSize / kPackedBNTiles;
static_assert(kPackedBKStride * kPackedBNTiles == PackedBSize);
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// Row-wise quantization, every row has its own scale/offset
CUTLASS_PRAGMA_UNROLL
for (int nn = 0; nn < (WarpShape::kN / 16); ++nn) {
@@ -680,6 +691,14 @@ struct QuantBScaleLoader, WarpShape_, Eleme
}
}
}
+#else
+ assert(false);
+ (void)(k_iter);
+ (void)(frag_pack_b);
+ (void)(frag_scales);
+ (void)(frag_offsets);
+ (void)(frag_b);
+#endif // __CUDA_ARCH__
}
};
diff --git a/onnxruntime/core/optimizer/gpu_ops_prepack.cc b/onnxruntime/core/optimizer/gpu_ops_prepack.cc
new file mode 100644
index 0000000000000..aed5817c21868
--- /dev/null
+++ b/onnxruntime/core/optimizer/gpu_ops_prepack.cc
@@ -0,0 +1,326 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+// Module Abstract:
+// This module defines the logic for prepacking weights
+// (aka Initializers in onnxruntime) of GPU operators.
+// Unlike CPU operators, overriding the PrePack() method
+// of class OpKernel results in GPU memory fragmentation.
+// So we try to rewrite the weight tensors during graph
+// optimization phase to avoid this problem
+//
+// Unfortunately, there are still some seriouse problems
+// with this approach:
+// 1. Rewriting of the initializer tensors is restricted
+// by operator shape inferencing rules. For example,
+// there are 3 initializers for MatMulNBits,
+// we can't combine them into a single initializer.
+// And we have to make sure the operator's shape inference
+// logic does NOT verify the initializer's shape.
+// 2. These rewriting logic is tightly coupled to each GPU
+// operators. It really should be defined together with
+// these operators, instead of defining them in a complete
+// different module.
+// 3. The logic of prepacking depends on underlying GPU
+// hardware. Currently this part is hard-coded for SM80.
+
+#if defined(USE_CUDA) && !defined(USE_ROCM)
+
+#include "core/graph/graph_utils.h"
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/gpu_ops_prepack.h"
+#include "core/optimizer/utils.h"
+#include "core/graph/graph_utils.h"
+
+#include "blk_q4/f16_prepack_sm80.h"
+
+#include "core/providers/cuda/cuda_provider_factory.h"
+#include "core/providers/cuda/cuda_execution_provider_info.h"
+
+namespace onnxruntime {
+
+extern ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
+
+/**
+ * @brief Read initialized tensor from protobuf, and store it in ort_value.
+ * Keep in mind that ort_value is the owner of the tensor memory after calling this function.
+ */
+inline Status GetOrtValue(const NodeArg* arg, const Graph& graph, OrtValue& ort_value) {
+ const ONNX_NAMESPACE::TensorProto* tensor_proto;
+ ORT_RETURN_IF_NOT(graph.GetInitializedTensor(arg->Name(), tensor_proto),
+ "Missing initializer for ", arg->Name());
+
+ return utils::TensorProtoToOrtValue(
+ Env::Default(), graph.ModelPath(), *tensor_proto,
+ std::make_shared(), ort_value);
+}
+
+template
+inline gsl::span make_span(std::string& str) {
+ return gsl::make_span(reinterpret_cast(str.data()), str.size() / sizeof(T));
+}
+
+//
+// Prepacking logic specific to MatMulNBits on sm80
+//
+
+static inline bool IsNodeMatMulNbitsFp16(const Node& node) {
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMulNBits", {1}, kMSDomain)) {
+ return false;
+ }
+ const auto* acts = node.InputDefs()[0];
+ if (acts == nullptr || acts->Type() == nullptr || acts->Type()->find("float16") == std::string::npos) {
+ return false;
+ }
+ return true;
+}
+
+template
+void Sm80BlkQ4PrepackT(
+ int rows, int columns,
+ gsl::span weights,
+ gsl::span scales,
+ gsl::span zp,
+ std::string& packed_w,
+ std::string& packed_scales,
+ std::string& packed_zp) {
+ using Base = onnxruntime::cuda::BlockwiseQuantization<
+ MLFloat16,
+ block_size,
+ 4,
+ column_quant_blk>;
+ const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns);
+ const auto meta_shape = Base::get_quant_meta_shape(rows, columns);
+
+ packed_w.resize(SafeInt(q_weight_shape.product() * sizeof(uint8_t)));
+ Base::prepack_weights(
+ rows, columns, weights,
+ make_span(packed_w));
+
+ packed_scales.resize(SafeInt(meta_shape.product() * sizeof(MLFloat16)));
+ Base::prepack_quant_scales(
+ rows, columns, scales,
+ make_span(packed_scales));
+
+ if (!zp.empty()) {
+ packed_zp.resize(SafeInt(meta_shape.product() * sizeof(uint8_t)));
+ Base::prepack_quant_offsets(
+ rows, columns, zp,
+ make_span(packed_zp));
+ }
+}
+
+void Sm80BlkQ4Prepack(
+ int block_size, bool column_quant_blk,
+ int rows, int columns,
+ gsl::span weights,
+ gsl::span scales,
+ gsl::span zp,
+ std::string& packed_w,
+ std::string& packed_scales,
+ std::string& packed_zp) {
+ switch (block_size) {
+ case 16:
+ if (column_quant_blk) {
+ Sm80BlkQ4PrepackT<16, true>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp);
+ } else {
+ Sm80BlkQ4PrepackT<16, false>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp);
+ }
+ break;
+ case 32:
+ if (column_quant_blk) {
+ Sm80BlkQ4PrepackT<32, true>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp);
+ } else {
+ Sm80BlkQ4PrepackT<32, false>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp);
+ }
+ break;
+ case 64:
+ if (column_quant_blk) {
+ Sm80BlkQ4PrepackT<64, true>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp);
+ } else {
+ Sm80BlkQ4PrepackT<64, false>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp);
+ }
+ break;
+ default:
+ ORT_THROW("Unsupported block size: ", block_size);
+ }
+}
+
+/**
+ *@brief Prepack weights of the operator MatMulNBits.
+ * The caller should make sure the node is of type MatMulNBits.
+ */
+Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) {
+ modified = false;
+ int64_t att_i;
+
+ //
+ // Verify prepacking is needed and supported
+ //
+ Status status = graph_utils::TryGetNodeAttribute(node, "prepacked", att_i);
+ bool prepacked = status.IsOK() ? att_i != 0 : false;
+ if (prepacked) {
+ return Status::OK(); // already prepacked, nothing to do
+ }
+
+ ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "bits", att_i));
+ int nbits = SafeInt(att_i);
+ if (nbits != 4) {
+ return Status::OK(); // only support 4 bits for now
+ }
+
+ // A single dimension can not exceed 2G yet.
+ ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "K", att_i));
+ int k = SafeInt(att_i);
+ ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "N", att_i));
+ int n = SafeInt(att_i);
+
+ ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "block_size", att_i));
+ int block_size = SafeInt(att_i);
+
+ status = graph_utils::TryGetNodeAttribute(node, "column_wise_blocking", att_i);
+ bool column_wise_quant_blk = status.IsOK() ? att_i != 0 : true;
+
+ auto* provider_info = TryGetProviderInfo_CUDA();
+ ORT_ENFORCE(provider_info != nullptr, "Failed to query CUDA provider info while prepacking cuda operators.");
+ int major, minor;
+ ORT_ENFORCE(provider_info->GetCurrentGpuDeviceVersion(&major, &minor) == nullptr,
+ "Failed to query CUDA device version while prepacking cuda operators.");
+
+ if (!onnxruntime::cuda::BlkQuantGemmSm80Supported(block_size, column_wise_quant_blk, k, n, major, minor)) {
+ return Status::OK(); // not supported
+ }
+
+ //
+ // Verification passed, start prepacking
+ //
+ auto& node_name = node.Name();
+ auto& mutable_input_defs = node.MutableInputDefs();
+ if (mutable_input_defs.size() < 3 || mutable_input_defs.size() > 4) {
+ return Status::OK(); // not supported
+ }
+
+ NodeArg* old_weights_arg = mutable_input_defs[1];
+ NodeArg* old_scales_arg = mutable_input_defs[2];
+ NodeArg* old_zp_arg = nullptr;
+
+ // holders of the packed weight tensor memory
+ std::string packed_weights;
+ std::string packed_scales;
+ std::string packed_zp;
+
+ {
+ // owners of the weight tensor memory, keep around until consumed by the prepacking function
+ OrtValue weights_val;
+ OrtValue scales_val;
+ OrtValue zp_val;
+
+ ORT_RETURN_IF_ERROR(GetOrtValue(old_weights_arg, graph, weights_val));
+ const gsl::span weights = weights_val.GetMutable()->DataAsSpan();
+
+ ORT_RETURN_IF_ERROR(GetOrtValue(old_scales_arg, graph, scales_val));
+ const gsl::span scales = scales_val.GetMutable()->DataAsSpan();
+
+ gsl::span zp;
+ if (mutable_input_defs.size() > 3) {
+ old_zp_arg = mutable_input_defs[3];
+ if (old_zp_arg != nullptr && old_zp_arg->Exists()) {
+ ORT_RETURN_IF_ERROR(GetOrtValue(old_zp_arg, graph, zp_val));
+ Tensor* zp_tensor_ptr = zp_val.GetMutable();
+ if (!zp_tensor_ptr->IsDataType()) {
+ return Status::OK(); // not supported
+ }
+ zp = zp_tensor_ptr->DataAsSpan();
+ }
+ }
+
+ Sm80BlkQ4Prepack(block_size, column_wise_quant_blk, k, n, weights, scales, zp, packed_weights, packed_scales, packed_zp);
+
+#if 0
+ // debug print if prepacked tests fail
+ std::cout << " ====== packed weight ====== " << std::endl << std::hex;
+ const gsl::span packed_weights_span(reinterpret_cast(packed_weights.data()), packed_weights.size());
+ for (int r = 0; r < k; r++) {
+ for (int c = 0; c < n/2; c++) {
+ std::cout << std::setw(2) << std::setfill('0') << static_cast(packed_weights_span[c * k + r]) << " ";
+ }
+ std::cout << std::endl;
+ }
+ std::cout << std::dec;
+#endif
+ }
+
+ //
+ // write packed weight tensor to node parameters
+ //
+ ONNX_NAMESPACE::TensorProto packed_weights_proto;
+ packed_weights_proto.set_name(graph.GenerateNodeArgName(node_name + "_prepacked_weight"));
+ packed_weights_proto.add_dims(packed_weights.size() / sizeof(uint8_t));
+ packed_weights_proto.set_data_type(onnxruntime::utils::ToTensorProtoElementType());
+ packed_weights_proto.set_raw_data(std::move(packed_weights));
+ NodeArg& packed_weights_arg = graph_utils::AddInitializer(graph, packed_weights_proto);
+ graph.RemoveConsumerNode(old_weights_arg->Name(), &node);
+ mutable_input_defs[1] = &packed_weights_arg;
+ graph.AddConsumerNode(packed_weights_arg.Name(), &node);
+
+ ONNX_NAMESPACE::TensorProto packed_scales_proto;
+ packed_scales_proto.set_name(graph.GenerateNodeArgName(node_name + "_prepacked_scales"));
+ packed_scales_proto.add_dims(packed_scales.size() / sizeof(MLFloat16));
+ packed_scales_proto.set_data_type(onnxruntime::utils::ToTensorProtoElementType());
+ packed_scales_proto.set_raw_data(std::move(packed_scales));
+ NodeArg& packed_scales_arg = graph_utils::AddInitializer(graph, packed_scales_proto);
+ graph.RemoveConsumerNode(old_scales_arg->Name(), &node);
+ mutable_input_defs[2] = &packed_scales_arg;
+ graph.AddConsumerNode(packed_scales_arg.Name(), &node);
+
+ if (!packed_zp.empty()) {
+ ONNX_NAMESPACE::TensorProto packed_zp_proto;
+ packed_zp_proto.set_name(graph.GenerateNodeArgName(node_name + "_prepacked_zp"));
+ packed_zp_proto.add_dims(packed_zp.size() / sizeof(uint8_t));
+ packed_zp_proto.set_data_type(onnxruntime::utils::ToTensorProtoElementType());
+ packed_zp_proto.set_raw_data(std::move(packed_zp));
+ NodeArg& packed_zp_arg = graph_utils::AddInitializer(graph, packed_zp_proto);
+ graph.RemoveConsumerNode(old_zp_arg->Name(), &node);
+ mutable_input_defs[3] = &packed_zp_arg;
+ graph.AddConsumerNode(packed_zp_arg.Name(), &node);
+ }
+
+ node.AddAttribute("prepacked", static_cast(1));
+ modified = true;
+ return Status::OK();
+}
+
+Status GpuOpsPrepack::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+ GraphViewer graph_viewer(graph);
+ const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+
+ // int fused_count = 0;
+ for (auto node_index : node_topology_list) {
+ auto* p_node = graph.GetNode(node_index);
+ if (p_node == nullptr)
+ continue; // node was removed as part of an earlier fusion
+
+ Node& node = *p_node;
+ ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
+
+ if (node.GetExecutionProviderType() != onnxruntime::kCudaExecutionProvider) {
+ continue; // only interested in CUDA nodes
+ }
+
+ // Run prepack if the node is MatMulNBits.
+ // When we have more operators to support, we should use a map to dispatch the prepack function
+ // instead of adding a whole bunch of if branches here.
+ if (IsNodeMatMulNbitsFp16(node)) {
+ bool packed = false;
+ ORT_RETURN_IF_ERROR(PackMatMulNBitsFp16(node, graph, packed));
+ modified |= packed;
+ continue;
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace onnxruntime
+
+#endif // USE_CUDA && !USE_ROCM
diff --git a/onnxruntime/core/optimizer/gpu_ops_prepack.h b/onnxruntime/core/optimizer/gpu_ops_prepack.h
new file mode 100644
index 0000000000000..0beecde4021d9
--- /dev/null
+++ b/onnxruntime/core/optimizer/gpu_ops_prepack.h
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if defined(USE_CUDA) && !defined(USE_ROCM)
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+/**
+@Class AttentionFusion
+Rewrite graph fusing attention subgraph to a single Attention node.
+*/
+class GpuOpsPrepack : public GraphTransformer {
+ public:
+ GpuOpsPrepack() noexcept
+ : GraphTransformer("GpuOpsPrepack", InlinedHashSet{onnxruntime::kCudaExecutionProvider}) {}
+
+ Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+
+ private:
+};
+
+} // namespace onnxruntime
+
+#endif // USE_CUDA && !USE_ROCM
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 4298551aec412..a8404a6587859 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -6,6 +6,7 @@
#include
#include
+#include "core/framework/execution_providers.h"
#include "core/optimizer/conv_activation_fusion.h"
#include "core/optimizer/matmul_nbits_fusion.h"
#include "core/optimizer/nhwc_transformer.h"
@@ -43,6 +44,7 @@
#include "core/optimizer/gemm_activation_fusion.h"
#include "core/optimizer/gemm_sum_fusion.h"
#include "core/optimizer/gemm_transpose_fusion.h"
+#include "core/optimizer/gpu_ops_prepack.h"
#include "core/optimizer/identical_children_consolidation.h"
#include "core/optimizer/identity_elimination.h"
#include "core/optimizer/label_encoder_fusion.h"
@@ -186,8 +188,11 @@ std::unique_ptr GenerateRuleBasedGraphTransformer(
InlinedVector> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
- const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/
+ const ExecutionProviders& execution_providers, /*required by constant folding*/
const InlinedHashSet& rules_and_transformers_to_disable) {
+ // CPU EP required to run constant folding
+ const IExecutionProvider& cpu_execution_provider = *execution_providers.Get(onnxruntime::kCpuExecutionProvider);
+
InlinedVector> transformers;
const bool disable_quant_qdq =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1";
@@ -387,7 +392,16 @@ InlinedVector> GenerateTransformers(
// PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu,
// while we can fuse more activation.
transformers.emplace_back(std::make_unique(cpu_ep));
-#endif
+
+#if defined(USE_CUDA) && !defined(USE_ROCM)
+ // Cuda weight prepacking.
+ auto* cuda_ep = execution_providers.Get(onnxruntime::kCudaExecutionProvider);
+ if (cuda_ep != nullptr) {
+ transformers.emplace_back(std::make_unique());
+ }
+#endif // USE_CUDA && !USE_ROCM
+
+#endif // !defined(DISABLE_CONTRIB_OPS)
} break;
@@ -408,8 +422,9 @@ InlinedVector> GenerateTransformersForMinimalB
TransformerLevel level,
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
- const IExecutionProvider& cpu_execution_provider,
+ const ExecutionProviders& execution_providers,
const InlinedHashSet& rules_and_transformers_to_disable) {
+ const auto& cpu_execution_provider = *execution_providers.Get(onnxruntime::kCpuExecutionProvider);
InlinedVector> transformers;
const bool saving = std::holds_alternative(apply_context);
diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc
index 103c79c93b2ca..ca794108d43c6 100644
--- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc
+++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc
@@ -71,7 +71,7 @@ struct ProviderInfo_CUDA_Impl final : ProviderInfo_CUDA {
return nullptr;
}
- OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) override {
+ OrtStatus* GetCurrentGpuDeviceId(_Out_ int* device_id) override {
auto cuda_err = cudaGetDevice(device_id);
if (cuda_err != cudaSuccess) {
return CreateStatus(ORT_FAIL, "Failed to get device id.");
@@ -79,6 +79,23 @@ struct ProviderInfo_CUDA_Impl final : ProviderInfo_CUDA {
return nullptr;
}
+ OrtStatus* GetCurrentGpuDeviceVersion(_Out_ int* major, _Out_ int* minor) override {
+ int device_id;
+ auto cuda_err = cudaGetDevice(&device_id);
+ if (cuda_err != cudaSuccess) {
+ return CreateStatus(ORT_FAIL, "Failed to get device id.");
+ }
+ cudaDeviceProp prop;
+ cuda_err = cudaGetDeviceProperties(&prop, device_id);
+ if (cuda_err != cudaSuccess) {
+ return CreateStatus(ORT_FAIL, "Failed to get device properties.");
+ }
+ *major = prop.major;
+ *minor = prop.minor;
+
+ return nullptr;
+ }
+
std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name) override {
return std::make_unique(device_id, name);
}
diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.h b/onnxruntime/core/providers/cuda/cuda_provider_factory.h
index 4d5ef658f6be0..320eb4ed82cfc 100644
--- a/onnxruntime/core/providers/cuda/cuda_provider_factory.h
+++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.h
@@ -22,7 +22,8 @@ class NvtxRangeCreator;
struct ProviderInfo_CUDA {
virtual OrtStatus* SetCurrentGpuDeviceId(_In_ int device_id) = 0;
- virtual OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) = 0;
+ virtual OrtStatus* GetCurrentGpuDeviceId(_Out_ int* device_id) = 0;
+ virtual OrtStatus* GetCurrentGpuDeviceVersion(_Out_ int* major, _Out_ int* minor) = 0;
virtual std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name) = 0;
virtual std::unique_ptr CreateCUDAPinnedAllocator(const char* name) = 0;
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 3ef6490a56ded..1e81f5b2bd592 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -1613,15 +1613,15 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph,
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Status ApplyOrtFormatModelRuntimeOptimizations(
onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options,
- const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep) {
+ const InlinedHashSet& optimizers_to_disable, const ExecutionProviders& execution_providers) {
bool modified = false;
for (int level = static_cast(TransformerLevel::Level2);
level <= static_cast(session_options.graph_optimization_level);
++level) {
const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild(
- static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep,
- optimizers_to_disable);
+ static_cast(level), session_options, SatRuntimeOptimizationLoadContext{},
+ execution_providers, optimizers_to_disable);
for (const auto& transformer : transformers) {
ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger));
@@ -2007,9 +2007,8 @@ common::Status InferenceSession::Initialize() {
*session_state_, session_options_.config_options, *session_logger_));
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
- const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider);
ORT_RETURN_IF_ERROR_SESSIONID_(
- ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, cpu_ep));
+ ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, execution_providers_));
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
}
@@ -3159,7 +3158,6 @@ common::Status InferenceSession::AddPredefinedTransformers(
TransformerLevel graph_optimization_level,
MinimalBuildOptimizationHandling minimal_build_optimization_handling,
RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const {
- const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider);
for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) {
TransformerLevel level = static_cast(i);
if (graph_optimization_level >= level) {
@@ -3170,7 +3168,7 @@ common::Status InferenceSession::AddPredefinedTransformers(
minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations;
if (use_full_build_optimizations) {
- return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep,
+ return optimizer_utils::GenerateTransformers(level, session_options_, execution_providers_,
optimizers_to_disable_);
} else {
const auto sat_context =
@@ -3179,8 +3177,8 @@ common::Status InferenceSession::AddPredefinedTransformers(
? SatApplyContextVariant{SatRuntimeOptimizationSaveContext{
record_runtime_optimization_produced_op_schema_fn}}
: SatApplyContextVariant{SatDirectApplicationContext{}};
- return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep,
- optimizers_to_disable_);
+ return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context,
+ execution_providers_, optimizers_to_disable_);
}
}();
diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h
index 43843da3fb96e..b4cf8cf518564 100644
--- a/onnxruntime/core/util/matrix_layout.h
+++ b/onnxruntime/core/util/matrix_layout.h
@@ -428,18 +428,18 @@ class MatrixRef {
/// Returns a reference to the element at a given Coord
ORT_FORCEINLINE
Reference at(MatCoord const& coord) const {
- return data_[offset(coord)];
+ return data_[static_cast(offset(coord))];
}
ORT_FORCEINLINE
Reference at(int row, int col) const {
- return data_[offset(make_Position(row, col))];
+ return data_[static_cast(offset(make_Position(row, col)))];
}
/// Returns a reference to the element at a given Coord
ORT_FORCEINLINE
Reference operator[](MatCoord const& coord) const {
- return data_[offset(coord)];
+ return data_[static_cast(offset(coord))];
}
};
diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h
index 942b1c4d2c2ad..4e9df0369c011 100644
--- a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h
+++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h
@@ -15,12 +15,163 @@
#pragma once
-#include "core/util/matrix_layout.h"
+#include
+#include "core/mickey/blk_q4/f16_prepack_sm80.h"
#include "core/common/common.h"
namespace onnxruntime {
namespace test {
+/**
+ * @brief Generate a set of quantized weights, scales and offsets
+ * and dequantized weights for testing quantization and
+ * dequantization. All outputs are column major layout.
+ *
+ * @tparam ElementT The type of the dequantized weights.
+ * @tparam block_size The block size of the quantization.
+ * @tparam col_blocking Whether to use column blocking (all elements of
+ * a block comes from a single column) or row blocking
+ * @tparam has_offsets Whether to generate offsets.
+ *
+ * @param[in] rows The number of rows of the weight matrix.
+ * @param[in] columns The number of columns of the weight matrix.
+ * @param[out] dequants The dequantized weights, column major layout.
+ * @param[out] q_weights The quantized weights, column major layout.
+ * @param[out] q_scales The scales, column major layout.
+ * @param[out] q_zp The zero points, column major layout.
+ */
+template
+inline void blkq4_weights_gen(
+ int rows, int columns,
+ std::vector& dequants,
+ std::vector& q_weights,
+ std::vector& q_scales,
+ std::vector& q_zp) {
+ using Base = onnxruntime::cuda::BlockwiseQuantization<
+ ElementT,
+ block_size,
+ 4,
+ col_blocking>;
+
+ using QuantBlocking = typename Base::QuantBlocking;
+ using ElementW = typename Base::ElementW;
+ using LayoutWPack = typename Base::LayoutWPack;
+ using ElementQOffset = typename Base::ElementQOffset;
+
+ static_assert(std::is_same::value);
+ static_assert(std::is_same::value);
+ static_assert(std::is_same::value);
+
+ unsigned int seed = 28571; // Replace with desired seed value
+ std::seed_seq seq{seed};
+ std::mt19937 gen(seq);
+ std::uniform_int_distribution dis(0, 8192);
+
+ const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns);
+ const auto meta_shape = Base::get_quant_meta_shape(rows, columns);
+
+ //
+ // For testing quantization and dequantization, it is not straight
+ // forward to avoid flaky tests due to rounding errors. The way we
+ // try to achieve this is to:
+ // 1. Generate a set of quantized weights, scales and offsets
+ // 2. Dequantize the weights
+ // 3. Quantize the dequantized weights
+ // 4. Compare the dequantied-and-then-quantized weights with
+ // the original quantized weights
+ //
+ // Random filling of the initial values are key to get this right.
+ // For weights, we must ensure each block gets a full range of
+ // values, i.e. must contain 0 and 15. And for scales, they must
+ // all be positive.
+ //
+
+ q_weights.resize(q_weight_shape.product());
+ MatrixRef tensor_q_weight(
+ q_weights, make_Position(rows / 2, columns));
+ int v = 7;
+ for (int c = 0; c < tensor_q_weight.shape()[1]; c++) {
+ for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) {
+ uint8_t v0 = static_cast(v);
+ v = (v + 5) % 16;
+ if (v == 11 || v == 7 || v == 3) {
+ // making the cycle 13 instead of 16, avoiding same values in a row
+ v = (v + 5) % 16;
+ }
+ uint8_t v1 = 0;
+ if (r + 1 < rows) {
+ v1 = static_cast(v);
+ v = (v + 5) % 16;
+ if (v == 11 || v == 7 || v == 3) {
+ // making the cycle 13 instead of 16, avoiding same values in a row
+ v = (v + 5) % 16;
+ }
+ }
+
+ tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0);
+ }
+ }
+
+ q_scales.resize(meta_shape.product());
+ for (size_t i = 0; i < q_scales.size(); i++) {
+ uint32_t vscale = dis(gen);
+ uint32_t m = (vscale % 63) + 1;
+ uint32_t e = (vscale >> 6) % 4;
+ q_scales[i] = ElementT(m / static_cast(1 << (2 + e)));
+ }
+ MatrixRef tensor_scale(
+ q_scales, meta_shape);
+
+ MatrixRef tensor_offset;
+ if constexpr (has_offsets) {
+ const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]);
+ q_zp.resize(zp_shape.product());
+ tensor_offset = MatrixRef(
+ q_zp, zp_shape);
+ for (int c = 0; c < zp_shape[1]; c++) {
+ for (int r = 0; r < zp_shape[0]; ++r) {
+ uint8_t v0 = dis(gen) % 16;
+ uint8_t v1 = 8;
+ if (r * 2 + 1 < meta_shape[0]) {
+ v1 = dis(gen) % 16;
+ }
+ tensor_offset.at(r, c) = static_cast(v0 | (v1 << 4));
+ }
+ }
+ }
+
+ dequants.resize(rows * columns);
+ MatrixRef tensor_dequant(dequants, make_Position(rows, columns));
+
+ // Dequantize weights and save into matrix B
+ for (int col = 0; col < tensor_dequant.shape()[1]; ++col) {
+ for (int row = 0; row < tensor_dequant.shape()[0]; ++row) {
+ auto weight_cord = make_Position(row / 2, col);
+ auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn);
+ uint8_t offset = 8;
+ if constexpr (has_offsets) {
+ if (scale_cord[0] % 2 == 0) {
+ offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) & 0x0f;
+ } else {
+ offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) >> 4;
+ }
+ }
+ int w = 0;
+ if (row % 2 == 0) {
+ w = int(tensor_q_weight.at(weight_cord) & 0x0f);
+ } else {
+ w = int(tensor_q_weight.at(weight_cord) >> 4);
+ }
+ float scale = float(tensor_scale.at(scale_cord));
+ float dequant = scale * float(w - offset);
+ tensor_dequant.at(row, col) = ElementT(dequant);
+ // Prints for help debugging in case of test failure
+ // fprintf(stderr, "%f~%d-%d|%f, ", dequant, w, offset, scale);
+ }
+ // fprintf(stderr, "\n");
+ }
+}
+
static inline void sm80_prepack_weights_ref(
int rows,
int columns,
diff --git a/onnxruntime/test/optimizer/gpu_op_prepack_test.cc b/onnxruntime/test/optimizer/gpu_op_prepack_test.cc
new file mode 100644
index 0000000000000..949fbef3da689
--- /dev/null
+++ b/onnxruntime/test/optimizer/gpu_op_prepack_test.cc
@@ -0,0 +1,393 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+#include
+
+#include "gtest/gtest.h"
+#include "graph_transform_test_builder.h"
+#include "core/mlas/inc/mlas.h"
+#include "core/graph/graph.h"
+#include "core/optimizer/initializer.h"
+
+#include "core/mlas/inc/mlas_q4.h"
+#include "core/providers/cuda/cuda_provider_factory_creator.h"
+#include "core/mickey/blk_q4/f16_prepack_sm80.h"
+#include "core/providers/cuda/cuda_provider_factory.h"
+#include "core/providers/cuda/cuda_execution_provider_info.h"
+
+#include "test/cuda_host/blkq4_fp16_quant_sm80.h"
+#include "test/compare_ortvalue.h"
+#include "test/test_environment.h"
+#include "test/util/include/asserts.h"
+#include "test/util/include/inference_session_wrapper.h"
+
+namespace onnxruntime {
+
+extern ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
+
+namespace test {
+#ifndef DISABLE_CONTRIB_OPS
+
+std::shared_ptr LoadCudaEp() {
+ try {
+ OrtCUDAProviderOptions cuda_options;
+ auto factory = onnxruntime::CudaProviderFactoryCreator::Create(&cuda_options);
+ if (!factory) {
+ return nullptr;
+ }
+ return factory->CreateProvider();
+ } catch (const ::onnxruntime::OnnxRuntimeException& e) {
+ std::cerr << "LoadCudaEp: " << e.what() << std::endl;
+ return nullptr;
+ }
+}
+
+/**
+ * @brief Testing helper for GPU prepacking logic in the graph transformer.
+ * This is an modification of the TransformerTester function from
+ * onnxruntime/test/optimizer/graph_transform_test_builder.cc
+ * with:
+ * - the addition of cuda execution provider in the session.
+ * - a different location for the model checker, right after session initialization
+ * as the initializers will be deleted during session run.
+ */
+void GpuPrepackTester(
+ const std::shared_ptr& cuda_ep,
+ const std::function& build_test_case,
+ const std::function& check_transformed_graph,
+ TransformerLevel baseline_level,
+ TransformerLevel target_level,
+ int opset_version = 12,
+ double per_sample_tolerance = 0.001,
+ double relative_per_sample_tolerance = 0.001,
+ const std::function& add_session_options = {},
+ const InlinedHashSet& disabled_optimizers = {}) {
+ // Build the model for this test.
+ std::unordered_map domain_to_version;
+ domain_to_version[kOnnxDomain] = opset_version;
+ domain_to_version[kMSDomain] = 1;
+ Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
+ domain_to_version, {}, DefaultLoggingManager().DefaultLogger());
+ Graph& graph = model.MainGraph();
+ ModelTestBuilder helper(graph);
+ ASSERT_TRUE(build_test_case);
+ build_test_case(helper);
+ helper.SetGraphOutputs();
+ ASSERT_STATUS_OK(model.MainGraph().Resolve());
+
+ // Serialize the model to a string.
+ std::string model_data;
+ model.ToProto().SerializeToString(&model_data);
+
+ auto run_model = [&](TransformerLevel level, std::vector& fetches) {
+ SessionOptions session_options;
+ session_options.graph_optimization_level = level;
+ if (level == target_level) {
+ // we don't really need the model file, but it seems to be the only way to keep the
+ // transformed initializers so that they can be verified.
+ session_options.optimized_model_filepath =
+ ToPathString("gpu_prepack_test_model_opt_level_" + std::to_string(static_cast(level)) + ".onnx");
+ }
+ if (add_session_options) {
+ add_session_options(session_options);
+ }
+ InferenceSessionWrapper session{session_options, GetEnvironment()};
+ ASSERT_STATUS_OK(session.RegisterExecutionProvider(cuda_ep));
+
+ ASSERT_STATUS_OK(session.Load(model_data.data(), static_cast(model_data.size())));
+ if (!disabled_optimizers.empty()) {
+ ASSERT_STATUS_OK(session.FilterEnabledOptimizers(InlinedHashSet{disabled_optimizers}));
+ }
+
+ ASSERT_STATUS_OK(session.Initialize());
+
+ RunOptions run_options;
+ ASSERT_STATUS_OK(session.Run(run_options,
+ helper.feeds_,
+ helper.output_names_,
+ &fetches));
+
+ if (level == target_level) {
+ if (check_transformed_graph) {
+ check_transformed_graph(session);
+ }
+ }
+ };
+
+ std::vector baseline_fetches;
+ ASSERT_NO_FATAL_FAILURE(run_model(baseline_level, baseline_fetches));
+
+ std::vector