bug fix: can't use A tensor shape[0] for M
chenfucn committed Mar 13, 2024
1 parent 6bd5c0c commit 65573be
Showing 14 changed files with 171 additions and 144 deletions.
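The substance of the fix: the prepacked SM80 GEMM path computed M as a->Shape()[0], which is only valid when the activation is a plain 2-D [M, K] matrix. For a batched activation such as [batch, seq_len, K], M has to collapse every dimension except the inner K dimension, which is what the helper.M() value now passed from ComputeInternal provides. Below is a minimal, standalone sketch of the difference; CollapsedM is a hypothetical stand-in, not ORT's MatMulComputeHelper.

// Standalone illustration (not ORT code) of why a->Shape()[0] is not a safe M.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Hypothetical stand-in for the M value a matmul shape helper would compute
// when the right-hand operand is a 2-D [K, N] weight.
int64_t CollapsedM(const std::vector<int64_t>& a_shape) {
  return std::accumulate(a_shape.begin(), a_shape.end() - 1,
                         int64_t{1}, std::multiplies<int64_t>());
}

int main() {
  const std::vector<int64_t> a_shape = {4, 128, 4096};  // [batch, seq_len, K]
  std::cout << "a->Shape()[0] = " << a_shape[0] << "\n";           // 4, the buggy M
  std::cout << "collapsed M   = " << CollapsedM(a_shape) << "\n";  // 512, the correct M
  return 0;
}

The diff below threads exactly this kind of M (helper.M()) from ComputeInternal into PrepackedGemm instead of re-deriving it from the activation tensor's first dimension.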
4 changes: 0 additions & 4 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -285,10 +285,6 @@
endif()
config_cuda_provider_shared_module(onnxruntime_providers_cuda)

# TODO only needed in DEBUG builds, need cmake expert advice on how to do that
set_source_files_properties(${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cu PROPERTIES COMPILE_FLAGS " -Wno-unknown-pragmas ")
set_source_files_properties(${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cc PROPERTIES COMPILE_FLAGS " -Wno-unknown-pragmas ")

install(TARGETS onnxruntime_providers_cuda
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
34 changes: 17 additions & 17 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
@@ -15,15 +15,15 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

template<>
template <>
Status MatMulNBits<MLFloat16>::PrepackedGemm(
cudaStream_t stream,
const Tensor* a,
const Tensor* b,
const Tensor* scales,
const Tensor* zero_points,
Tensor* Y) const {
int64_t M = a->Shape()[0];
cudaStream_t stream,
int M,
const Tensor* a,
const Tensor* b,
const Tensor* scales,
const Tensor* zero_points,
Tensor* Y) const {
uint8_t const* zero_points_ptr = nullptr;
size_t zero_points_size = 0;
if (zero_points != nullptr) {
@@ -32,12 +32,12 @@ Status MatMulNBits<MLFloat16>::PrepackedGemm(
}

return blkq4_fp16_gemm_sm80_dispatch<MLFloat16>(
int(block_size_), column_wise_quant_blk_, int(M), int(N_), int(K_), stream,
a->Data<MLFloat16>(), a->Shape().Size(),
b->Data<uint8_t>(), b->Shape().Size(),
scales->Data<MLFloat16>(), scales->Shape().Size(),
zero_points_ptr, zero_points_size,
Y->MutableData<MLFloat16>(), Y->Shape().Size());
int(block_size_), column_wise_quant_blk_, int(M), int(N_), int(K_), stream,
a->Data<MLFloat16>(), a->Shape().Size(),
b->Data<uint8_t>(), b->Shape().Size(),
scales->Data<MLFloat16>(), scales->Shape().Size(),
zero_points_ptr, zero_points_size,
Y->MutableData<MLFloat16>(), Y->Shape().Size());
}

template <typename T>
@@ -59,14 +59,14 @@ Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0) return Status::OK();

if (prepack_ > 0){
if (prepack_ > 0) {
ORT_RETURN_IF(reorder_idx != nullptr,
"Internal Error: Prepacked gemm does not support reorder index. Fix the prepacking logic!");
ORT_RETURN_IF(zero_points != nullptr && zero_points->IsDataType<T>(),
"Internal Error: Prepacked gemm does not support zero points of type T. Fix the prepacking logic!");
return PrepackedGemm(
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()),
a, b, scales, zero_points, Y);
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()),
helper.M(), a, b, scales, zero_points, Y);
}

const auto* a_data = a->Data<T>();
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu
@@ -390,7 +390,7 @@ Status blkq4_gemm_sm80(int m, int n, int k, cudaStream_t stream,

const cutlass::gemm::GemmCoord problem_size = {m, n, k};

ORT_RETURN_IF_NOT(a.size_bytes() == m * k * sizeof(ElementDequant), "Activation tensor size is not correct");
ORT_RETURN_IF_NOT(a.size_bytes() == m * k * sizeof(ElementDequant), "Activation tensor size is not correct: ", a.size_bytes(), " vs m: ", m, "k: ", k , " size: ", m * k * sizeof(ElementDequant));

cutlass::TensorRef<ElementDequant const, LayoutInputA> ref_a(
reinterpret_cast<ElementDequant const *>(a.data()),
LayoutInputA::packed({m, k}));
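The only change in this hunk is diagnostic: the activation-size check now reports the observed byte count alongside m, k, and the expected size. With M taken from a->Shape()[0], this is the check that trips for activations with more than two dimensions, which is presumably how the bug surfaced. A hedged sketch of the same kind of check outside ORT; half_stub and CheckActivationSize are hypothetical names, not ORT's API.

#include <cstddef>
#include <cstdint>
#include <sstream>
#include <stdexcept>

using half_stub = std::uint16_t;  // 2-byte stand-in for the fp16 ElementDequant

// Verify the activation buffer really holds m * k elements and, on failure,
// report both sides of the comparison instead of a bare "size is not correct".
void CheckActivationSize(std::size_t size_bytes, int m, int k) {
  const std::size_t expected =
      static_cast<std::size_t>(m) * static_cast<std::size_t>(k) * sizeof(half_stub);
  if (size_bytes != expected) {
    std::ostringstream msg;
    msg << "Activation tensor size is not correct: " << size_bytes
        << " vs m: " << m << " k: " << k << " size: " << expected;
    throw std::runtime_error(msg.str());
  }
}

int main() {
  CheckActivationSize(std::size_t{512} * 4096 * sizeof(half_stub), 512, 4096);  // passes
  // CheckActivationSize(std::size_t{4} * 4096 * sizeof(half_stub), 512, 4096); // would throw
  return 0;
}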
9 changes: 7 additions & 2 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
@@ -34,8 +34,13 @@ class MatMulNBits final : public CudaKernel {
info.GetAttrOrDefault<int64_t>("prepacked", &prepack_, int64_t(0));
}

Status PrepackedGemm(cudaStream_t stream, const Tensor* a, const Tensor* b,
const Tensor* scales, const Tensor* zero_points, Tensor* Y) const {
Status PrepackedGemm([[maybe_unused]] cudaStream_t stream,
[[maybe_unused]] int M,
[[maybe_unused]] const Tensor* a,
[[maybe_unused]] const Tensor* b,
[[maybe_unused]] const Tensor* scales,
[[maybe_unused]] const Tensor* zero_points,
[[maybe_unused]] Tensor* Y) const {
ORT_THROW("Prepacked gemm is not supported for MatMulNBits op.");
}

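The header-side counterpart of the .cc change: the unspecialized PrepackedGemm stub now takes the new M argument and marks every parameter [[maybe_unused]] because it only throws, while the MLFloat16 specialization in matmul_nbits.cc does the real work. A minimal sketch of that pattern under hypothetical names (Kernel and float in place of MatMulNBits and MLFloat16):

#include <iostream>
#include <stdexcept>

template <typename T>
struct Kernel {
  // Default stub: parameters exist only for documentation, so [[maybe_unused]]
  // keeps -Wunused-parameter quiet for element types without a prepacked path.
  void PrepackedGemm([[maybe_unused]] int M, [[maybe_unused]] const T* a) const {
    throw std::runtime_error("Prepacked gemm is not supported for this type.");
  }
};

// Explicit specialization of the member for the one supported element type,
// defined out of line (in ORT this lives in matmul_nbits.cc).
template <>
void Kernel<float>::PrepackedGemm(int M, const float* a) const {
  std::cout << "prepacked gemm, M = " << M << ", a[0] = " << a[0] << "\n";
}

int main() {
  const float data[2] = {1.0f, 2.0f};
  Kernel<float>{}.PrepackedGemm(8, data);         // dispatches to the specialization
  // Kernel<double>{}.PrepackedGemm(8, nullptr);  // would throw the stub's error
  return 0;
}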
4 changes: 1 addition & 3 deletions onnxruntime/core/graph/graph_utils.h
@@ -26,8 +26,7 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node,
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

/** Returns the attribute of a Node with a given name. */
static inline
const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) {
static inline const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) {
const auto& attrs = node.GetAttributes();
const auto iter = attrs.find(attr_name);
return iter == attrs.end() ? nullptr : &iter->second;
@@ -49,7 +48,6 @@ inline Status TryGetNodeAttribute<int64_t>(const Node& node, const std::string&
return Status::OK();
}


/** Add a new initializer to 'graph'.
Checks that new_initializer does not already exist in 'graph' before adding it.
@returns The NodeArg for the new initializer.
56 changes: 27 additions & 29 deletions onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h
@@ -187,8 +187,8 @@ struct BlockwiseQuantization {
static constexpr bool ShouldRearrangeMeta = sizeof(ElementT) == 2 && QuantBlocking::kRow == 1;

static void prepack_quant_scales(
size_t rows,
size_t columns,
int rows,
int columns,
const gsl::span<ElementT const>& scales, // <- quant scales, column major layout
const gsl::span<ElementT>& scales_prepacked // <- quant scales prepacked, same size buffer
) {
@@ -261,8 +261,8 @@ struct BlockwiseQuantization {
}

static void prepack_quant_offsets(
size_t rows,
size_t columns,
int rows,
int columns,
const gsl::span<uint8_t const>& offsets, // <- quant offsets, int4, column major layout
const gsl::span<uint8_t>& offsets_prepacked // <- quant offsets prepacked, double size buffer
) {
@@ -345,8 +345,8 @@
};

static inline bool IsSm80WithWholeBlocks(
int weight_rows, [[maybe_unused]] int weight_cols,
int major, [[maybe_unused]] int minor) {
int weight_rows, [[maybe_unused]] int weight_cols,
int major, [[maybe_unused]] int minor) {
if (major < 8) {
return false;
}
@@ -364,9 +364,8 @@ static inline bool IsSm80WithWholeBlocks(
return (weight_rows % 64 == 0);
}

template<typename ElementT, int block_size, bool col_blocking>
inline
bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int minor) {
template <typename ElementT, int block_size, bool col_blocking>
inline bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int minor) {
using Base = BlockwiseQuantization<ElementT, block_size, 4, col_blocking>;
if (!Base::weight_dimension_supported(weight_rows, weight_cols)) {
return false;
@@ -375,26 +374,25 @@ bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int
}

static inline bool BlkQuantGemmSm80Supported(int block_size, bool col_blocking, int weight_rows, int weight_cols, int major, int minor) {
switch (block_size)
{
case 16:
if (col_blocking) {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 16, true>(weight_rows, weight_cols, major, minor);
} else {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 16, false>(weight_rows, weight_cols, major, minor);
}
case 32:
if (col_blocking) {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 32, true>(weight_rows, weight_cols, major, minor);
} else {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 32, false>(weight_rows, weight_cols, major, minor);
}
case 64:
if (col_blocking) {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 64, true>(weight_rows, weight_cols, major, minor);
} else {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 64, false>(weight_rows, weight_cols, major, minor);
}
switch (block_size) {
case 16:
if (col_blocking) {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 16, true>(weight_rows, weight_cols, major, minor);
} else {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 16, false>(weight_rows, weight_cols, major, minor);
}
case 32:
if (col_blocking) {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 32, true>(weight_rows, weight_cols, major, minor);
} else {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 32, false>(weight_rows, weight_cols, major, minor);
}
case 64:
if (col_blocking) {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 64, true>(weight_rows, weight_cols, major, minor);
} else {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 64, false>(weight_rows, weight_cols, major, minor);
}
}
return false;
}
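BlkQuantGemmSm80Supported (reindented above) is a runtime-to-compile-time dispatcher: a switch over block_size plus a branch on col_blocking selects one concrete template instantiation. A compact sketch of that shape of dispatch; SupportedImpl and its divisibility rules are hypothetical, not the real SM80 constraints.

#include <iostream>

// Hypothetical predicate: the real code checks SM80 tile and whole-block
// dimension constraints per (block_size, col_blocking) configuration.
template <int BlockSize, bool ColBlocking>
bool SupportedImpl(int rows, int cols) {
  return rows % BlockSize == 0 && cols % (ColBlocking ? 16 : 8) == 0;
}

// Runtime values are folded into template arguments case by case, mirroring
// the switch in BlkQuantGemmSm80Supported.
bool Supported(int block_size, bool col_blocking, int rows, int cols) {
  switch (block_size) {
    case 16:
      return col_blocking ? SupportedImpl<16, true>(rows, cols)
                          : SupportedImpl<16, false>(rows, cols);
    case 32:
      return col_blocking ? SupportedImpl<32, true>(rows, cols)
                          : SupportedImpl<32, false>(rows, cols);
    case 64:
      return col_blocking ? SupportedImpl<64, true>(rows, cols)
                          : SupportedImpl<64, false>(rows, cols);
    default:
      return false;
  }
}

int main() {
  std::cout << std::boolalpha << Supported(32, false, 4096, 4096) << "\n";  // true
  return 0;
}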
@@ -453,19 +453,19 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,

uint32_t* dest_pair = reinterpret_cast<uint32_t*>(dest.data());
const b64* scales_ptr = reinterpret_cast<const b64*>(scales.data());
const ElementOffset* offsets_ptr = nullptr;
[[maybe_unused]] const ElementOffset* offsets_ptr = nullptr;
if constexpr(kHasOffset) { offsets_ptr = offsets.data(); }

CUTLASS_PRAGMA_UNROLL
for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){
// dequantize: d = scale * (weight - offset)
// to use FMA, d = scale * weight + (scale * (-offset))

b64 offsets;
[[maybe_unused]] b64 offsets;
if constexpr(kHasOffset){
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
const uint32_t* p = reinterpret_cast<const uint32_t*>(offsets_ptr);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
asm volatile(
"{\n\t"
" .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands
@@ -796,12 +796,12 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
}
} else if constexpr (kMmaIterationsB % 2 == 0) {
const uint32_t* scales_ptr = reinterpret_cast<const uint32_t*>(scales.data());
uint32_t* addon_ptr = reinterpret_cast<uint32_t*>(addon);

if constexpr (kHasOffset){
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
// possible buffer over read 2 bytes here.
const uint32_t* p = reinterpret_cast<const uint32_t*>(offsets.data());
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
uint32_t* addon_ptr = reinterpret_cast<uint32_t*>(addon);

asm volatile(
"{\n\t"
" .reg .b32 rb0, rb1, rb2;\n"
@@ -823,6 +823,8 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
#endif
} else {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
uint32_t* addon_ptr = reinterpret_cast<uint32_t*>(addon);

asm volatile(
"{\n\t"
" .reg .b32 rb0;\n"
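The tile-iterator edits above are warning hygiene: values consumed only inside the #if __CUDA_ARCH__ >= 800 blocks are now either marked [[maybe_unused]] or declared inside the guard, so builds where the guarded asm is compiled out no longer trip unused-variable warnings. A small host-side sketch of the idea; FAST_PATH is a hypothetical stand-in for the architecture check.

#include <iostream>

#define FAST_PATH 1  // flip to 0: the function still compiles without warnings

int Accumulate(const int* data, int n) {
  // Only read inside the guarded block, so the attribute keeps -Wunused quiet
  // when FAST_PATH is 0 and that block disappears.
  [[maybe_unused]] const int* cursor = data;
  int sum = 0;
#if FAST_PATH
  // A pointer needed only by the fast path is declared inside the guard,
  // matching how addon_ptr moved into the __CUDA_ARCH__ blocks above.
  const int* end = data + n;
  for (; cursor != end; ++cursor) sum += *cursor;
#else
  for (int i = 0; i < n; ++i) sum += data[i];
#endif
  return sum;
}

int main() {
  const int v[4] = {1, 2, 3, 4};
  std::cout << Accumulate(v, 4) << "\n";  // 10
  return 0;
}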