Connecting fp16xq4 gemm kernels (optimized for A100) to MatMulNBits<fp16> operator #21083

Open · wants to merge 5 commits into base: main
4 changes: 4 additions & 0 deletions cmake/onnxruntime_optimizer.cmake
@@ -111,6 +111,10 @@ onnxruntime_add_static_library(onnxruntime_optimizer ${onnxruntime_optimizer_src

onnxruntime_add_include_to_target(onnxruntime_optimizer onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface)
target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT})

# the optimizer performs CUDA prepacking, so extra headers are needed
target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)

if (onnxruntime_ENABLE_TRAINING)
target_include_directories(onnxruntime_optimizer PRIVATE ${ORTTRAINING_ROOT})
onnxruntime_add_include_to_target(onnxruntime_optimizer nlohmann_json::nlohmann_json)
1 change: 1 addition & 0 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -217,6 +217,7 @@

include(cutlass)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include)
target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)

target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
7 changes: 7 additions & 0 deletions cmake/onnxruntime_unittests.cmake
@@ -272,6 +272,13 @@ if(NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD)
"${TEST_SRC_DIR}/optimizer/*.h"
)

if (MSVC AND ((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC")))
set_source_files_properties("${TEST_SRC_DIR}/optimizer/graph_transform_test.cc" PROPERTIES COMPILE_FLAGS "/bigobj")
list(REMOVE_ITEM onnxruntime_test_optimizer_src
"${TEST_SRC_DIR}/optimizer/gpu_op_prepack_test.cc"
)
endif()

set(onnxruntime_test_framework_src_patterns
"${TEST_SRC_DIR}/framework/*.cc"
"${TEST_SRC_DIR}/framework/*.h"
8 changes: 8 additions & 0 deletions docs/ContribOperators.md
@@ -2915,6 +2915,14 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>number of bits used for weight quantization (default 4)</dd>
<dt><tt>block_size</tt> : int (required)</dt>
<dd>group size used for weight quantization (default 128). It needs to be a power of 2 and not smaller than 16.</dd>
<dt><tt>column_wise_blocking</tt> : int</dt>
<dd>whether to quantize the weight column-wise (value 1, default) or row-wise (value 0)</dd>
<dt><tt>prepacked</tt> : int</dt>
<dd>
Indicates whether the weight matrix is prepacked (value 1) or not (value 0, default).
This attribute should NEVER be set by the user; it is set internally by ONNX Runtime
at model loading time.
</dd>
</dl>
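
As an illustration (not part of this PR), here is a minimal sketch of how a MatMulNBits node carrying the new `column_wise_blocking` attribute could be authored with the ONNX Python helper. The tensor names, shapes, and attribute values are hypothetical:

```python
from onnx import helper

# Hypothetical shapes: A is [M, K] fp16, B_quant is the packed 4-bit weight blob.
# K, N, bits and block_size are the required MatMulNBits attributes;
# column_wise_blocking is the optional attribute added here (1 = column-wise, default).
node = helper.make_node(
    "MatMulNBits",
    inputs=["A", "B_quant", "scales"],
    outputs=["Y"],
    domain="com.microsoft",
    K=4096, N=4096, bits=4, block_size=64,
    column_wise_blocking=1,
)
# `prepacked` is deliberately omitted: it is set internally by ONNX Runtime
# at model loading time and should never be authored into a model.
```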

#### Inputs (3 - 6)
5 changes: 3 additions & 2 deletions include/onnxruntime/core/optimizer/graph_transformer_utils.h
@@ -22,6 +22,7 @@

namespace onnxruntime {
class IExecutionProvider;
class ExecutionProviders;

namespace optimizer_utils {

@@ -48,7 +49,7 @@ std::unique_ptr<RuleBasedGraphTransformer> GenerateRuleBasedGraphTransformer(
InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& execution_provider /*required by constant folding*/,
const ExecutionProviders& execution_providers /* cpu ep required by constant folding*/,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});

#endif // !defined(ORT_MINIMAL_BUILD)
@@ -77,7 +78,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalBuild(
TransformerLevel level,
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
const ExecutionProviders& execution_providers,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
62 changes: 51 additions & 11 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
@@ -14,24 +14,44 @@
namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;

#if !defined(USE_MIGRAPHX) && !defined(USE_ROCM)
template <>
Status MatMulNBits<MLFloat16>::PrepackedGemm(
cudaStream_t stream,
int M,
const Tensor* a,
const Tensor* b,
const Tensor* scales,
const Tensor* zero_points,
Tensor* Y) const {
uint8_t const* zero_points_ptr = nullptr;
size_t zero_points_size = 0;
if (zero_points != nullptr) {
zero_points_ptr = zero_points->Data<uint8_t>();
zero_points_size = zero_points->Shape().Size();
}

return blkq4_fp16_gemm_sm80_dispatch<MLFloat16>(
int(block_size_), column_wise_quant_blk_, int(M), int(N_), int(K_), stream,
a->Data<MLFloat16>(), a->Shape().Size(),
b->Data<uint8_t>(), b->Shape().Size(),
scales->Data<MLFloat16>(), scales->Shape().Size(),
zero_points_ptr, zero_points_size,
Y->MutableData<MLFloat16>(), Y->Shape().Size());
}
#endif // !defined(USE_MIGRAPHX) && !defined(USE_ROCM)

template <typename T>
Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
using CudaT = typename onnxruntime::cuda::ToCudaType<T>::MappedType;

const Tensor* a = ctx->Input<Tensor>(0);
const Tensor* b = ctx->Input<Tensor>(1);
const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);
const Tensor* reorder_idx = ctx->Input<Tensor>(4);

const auto* a_data = a->Data<T>();
const uint8_t* blob_data = b->Data<uint8_t>();
const auto* scales_data = scales->Data<T>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data<int32_t>();

typedef typename ToCudaType<T>::MappedType CudaT;

constexpr bool transa = false;
constexpr bool transb = true;
MatMulComputeHelper helper;
@@ -43,6 +63,26 @@ Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0) return Status::OK();

if (prepack_ > 0) {
#if !defined(USE_MIGRAPHX) && !defined(USE_ROCM)
ORT_RETURN_IF(reorder_idx != nullptr,
"Internal Error: Prepacked gemm does not support reorder index. Fix the prepacking logic!");
ORT_RETURN_IF(zero_points != nullptr && zero_points->IsDataType<T>(),
"Internal Error: Prepacked gemm does not support zero points of type T. Fix the prepacking logic!");
return PrepackedGemm(
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()),
static_cast<int>(helper.M()), a, b, scales, zero_points, Y);
#else
ORT_RETURN_IF(true, "Prepacked gemm is not supported for MatMulNBits op.");
#endif // !defined(USE_MIGRAPHX) && !defined(USE_ROCM)
}

const auto* a_data = a->Data<T>();
const uint8_t* blob_data = b->Data<uint8_t>();
const auto* scales_data = scales->Data<T>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data<int32_t>();

bool is_4bit_done = (reorder_idx_data == nullptr) &&
(!zero_points || !zero_points->IsDataType<T>()) &&
TryMatMul4Bits(
@@ -115,8 +155,8 @@ cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost);
delete[] b_data_cpu;
#endif

const CudaT alpha = ToCudaType<T>::FromFloat(1.f);
const CudaT zero = ToCudaType<T>::FromFloat(0.f);
const CudaT alpha = onnxruntime::cuda::ToCudaType<T>::FromFloat(1.f);
const CudaT zero = onnxruntime::cuda::ToCudaType<T>::FromFloat(0.f);

if (helper.OutputOffsets().size() == 1) {
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
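
As background rather than part of the diff, the following NumPy sketch shows the block-wise dequantization that a 4-bit MatMulNBits logically performs before (or fused with) the fp16 matrix multiply. The `dequant_4bit_columnwise` helper, the byte/nibble layout, and the array shapes are assumptions for illustration only, not the layout used by the A100 kernels in this change:

```python
import numpy as np

def dequant_4bit_columnwise(packed, scales, zero_points, K, N, block_size):
    # packed:      uint8 array with two 4-bit weights per byte, K*N/2 bytes,
    #              assumed here to be laid out column by column (column-wise blocking).
    # scales:      one fp16 scale per (block, column), K // block_size * N values.
    # zero_points: one 4-bit zero point per (block, column), or None (default of 8 assumed).
    nibbles = np.stack([packed & 0x0F, packed >> 4], axis=-1).reshape(N, K).T  # [K, N]
    w = nibbles.astype(np.float32)
    n_blocks = K // block_size
    scales = scales.reshape(n_blocks, N).astype(np.float32)
    zp = (np.full((n_blocks, N), 8.0) if zero_points is None
          else zero_points.reshape(n_blocks, N).astype(np.float32))
    # Expand per-block parameters along K and dequantize: w_fp = (q - zp) * scale
    scale_full = np.repeat(scales, block_size, axis=0)  # [K, N]
    zp_full = np.repeat(zp, block_size, axis=0)         # [K, N]
    return (w - zp_full) * scale_full

# Usage sketch: Y = A @ dequant_4bit_columnwise(B_quant, scales, zero_points, K, N, block_size)
```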