Skip to content

Commit

Permalink
Add fp16xq4 matmul sm80 cuda kernel to ORT operator
Browse files Browse the repository at this point in the history
And to use graph transformer as prepack
  • Loading branch information
chenfucn committed Jul 3, 2024
1 parent f39ee14 commit d9fae0c
Show file tree
Hide file tree
Showing 34 changed files with 1,611 additions and 118 deletions.
4 changes: 4 additions & 0 deletions cmake/onnxruntime_optimizer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ onnxruntime_add_static_library(onnxruntime_optimizer ${onnxruntime_optimizer_src

onnxruntime_add_include_to_target(onnxruntime_optimizer onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface)
target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT})

# The optimizer is used for CUDA prepacking, so the extra CUDA kernel headers are needed
target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)

if (onnxruntime_ENABLE_TRAINING)
target_include_directories(onnxruntime_optimizer PRIVATE ${ORTTRAINING_ROOT})
onnxruntime_add_include_to_target(onnxruntime_optimizer nlohmann_json::nlohmann_json)
Expand Down
1 change: 1 addition & 0 deletions cmake/onnxruntime_providers_cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@

include(cutlass)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include)
target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)

target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
Expand Down
7 changes: 7 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,13 @@ if(NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD)
"${TEST_SRC_DIR}/optimizer/*.h"
)

if (MSVC AND ((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC")))
set_source_files_properties("${TEST_SRC_DIR}/optimizer/graph_transform_test.cc" PROPERTIES COMPILE_FLAGS "/bigobj")
list(REMOVE_ITEM onnxruntime_test_optimizer_src
"${TEST_SRC_DIR}/optimizer/gpu_op_prepack_test.cc"
)
endif()

set(onnxruntime_test_framework_src_patterns
"${TEST_SRC_DIR}/framework/*.cc"
"${TEST_SRC_DIR}/framework/*.h"
Expand Down
8 changes: 8 additions & 0 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2915,6 +2915,14 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>number of bits used for weight quantization (default 4)</dd>
<dt><tt>block_size</tt> : int (required)</dt>
<dd>size of the quantization group (group size) used for weight quantization (default 128). It needs to be a power of 2 and not smaller than 16.</dd>
<dt><tt>column_wise_blocking</tt> : int</dt>
<dd>whether to quantize weight columnwise (value 1, default), or rowwise (value 0)</dd>
<dt><tt>prepacked</tt> : int</dt>
<dd>
Indicates whether the weight matrix is prepacked (value 1) or not (value 0, default).
This property should NEVER be set by the user. It is set internally by ONNX Runtime
during model loading.
</dd>
</dl>

#### Inputs (3 - 6)
Expand Down
5 changes: 3 additions & 2 deletions include/onnxruntime/core/optimizer/graph_transformer_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

namespace onnxruntime {
class IExecutionProvider;
class ExecutionProviders;

namespace optimizer_utils {

Expand All @@ -48,7 +49,7 @@ std::unique_ptr<RuleBasedGraphTransformer> GenerateRuleBasedGraphTransformer(
InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& execution_provider /*required by constant folding*/,
const ExecutionProviders& execution_providers /* cpu ep required by constant folding*/,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});

#endif // !defined(ORT_MINIMAL_BUILD)
Expand Down Expand Up @@ -77,7 +78,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
TransformerLevel level,
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
const ExecutionProviders& execution_providers,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Expand Down
58 changes: 49 additions & 9 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,44 @@
namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;

#ifndef USE_ROCM
// Runs the fp16 x int4 block-quantized GEMM using the SM80 CUTLASS-based
// kernel, for weights that are already in the kernel's prepacked layout.
// NOTE(review): per the surrounding change, prepacking is presumably done by a
// graph transformer at model load time — confirm against the prepack pass.
//
// Arguments:
//   stream       CUDA stream to launch on.
//   M            number of rows of activation `a` (N_ and K_ come from attrs).
//   a            fp16 activation tensor.
//   b            prepacked 4-bit quantized weight blob (uint8 storage).
//   scales       fp16 per-block quantization scales.
//   zero_points  optional uint8 zero points; may be nullptr.
//   Y            fp16 output tensor.
template <>
Status MatMulNBits<MLFloat16>::PrepackedGemm(
    cudaStream_t stream,
    int M,
    const Tensor* a,
    const Tensor* b,
    const Tensor* scales,
    const Tensor* zero_points,
    Tensor* Y) const {
  // Zero points are optional: pass a null pointer and zero length when absent.
  uint8_t const* zero_points_ptr = nullptr;
  size_t zero_points_size = 0;
  if (zero_points != nullptr) {
    zero_points_ptr = zero_points->Data<uint8_t>();
    zero_points_size = zero_points->Shape().Size();
  }

  // Use static_cast instead of deprecated C-style casts (cpplint
  // readability/casting); M is already an int, so no cast is needed.
  return blkq4_fp16_gemm_sm80_dispatch<MLFloat16>(
      static_cast<int>(block_size_), column_wise_quant_blk_, M,
      static_cast<int>(N_), static_cast<int>(K_), stream,
      a->Data<MLFloat16>(), a->Shape().Size(),
      b->Data<uint8_t>(), b->Shape().Size(),
      scales->Data<MLFloat16>(), scales->Shape().Size(),
      zero_points_ptr, zero_points_size,
      Y->MutableData<MLFloat16>(), Y->Shape().Size());
}
#endif  // !USE_ROCM

template <typename T>
Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
using CudaT = typename ToCudaType<T>::MappedType;

const Tensor* a = ctx->Input<Tensor>(0);
const Tensor* b = ctx->Input<Tensor>(1);
const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);
const Tensor* reorder_idx = ctx->Input<Tensor>(4);

const auto* a_data = a->Data<T>();
const uint8_t* blob_data = b->Data<uint8_t>();
const auto* scales_data = scales->Data<T>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data<int32_t>();

typedef typename ToCudaType<T>::MappedType CudaT;

constexpr bool transa = false;
constexpr bool transb = true;
MatMulComputeHelper helper;
Expand All @@ -43,6 +63,26 @@ Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0) return Status::OK();

if (prepack_ > 0) {
#ifndef USE_ROCM
ORT_RETURN_IF(reorder_idx != nullptr,
"Internal Error: Prepacked gemm does not support reorder index. Fix the prepacking logic!");
ORT_RETURN_IF(zero_points != nullptr && zero_points->IsDataType<T>(),
"Internal Error: Prepacked gemm does not support zero points of type T. Fix the prepacking logic!");
return PrepackedGemm(
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()),
static_cast<int>(helper.M()), a, b, scales, zero_points, Y);
#else
ORT_RETURN_IF(true, "Prepacked gemm is not supported for MatMulNBits op.");
#endif // !USE_ROCM
}

const auto* a_data = a->Data<T>();
const uint8_t* blob_data = b->Data<uint8_t>();
const auto* scales_data = scales->Data<T>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data<int32_t>();

bool is_4bit_done = (reorder_idx_data == nullptr) &&
(!zero_points || !zero_points->IsDataType<T>()) &&
TryMatMul4Bits(
Expand Down
Loading

0 comments on commit d9fae0c

Please sign in to comment.