
Commit

bug fix: can't use A tensor shape[0] for M
chenfucn committed Mar 8, 2024
1 parent 6bd5c0c commit 01a6521
Showing 11 changed files with 121 additions and 121 deletions.
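
Why shape[0] is the wrong M: the activation A of MatMulNBits may have more than two dimensions, so a->Shape()[0] is only the leading batch dimension, not the number of GEMM rows. For a 2-D quantized weight of shape [N, K], M is effectively the product of every dimension of A except the last, which is what helper.M() supplies in the change below. A minimal sketch of that row count (illustrative only, not code from this commit):

#include <cstddef>
#include <cstdint>
#include <vector>

// Rows of the GEMM for an activation of shape a_shape multiplied by an [N, K] weight:
// fold every dimension except the innermost K.
int64_t GemmRows(const std::vector<int64_t>& a_shape) {
  int64_t m = 1;
  for (std::size_t i = 0; i + 1 < a_shape.size(); ++i) m *= a_shape[i];
  return m;
}

// Example: A of shape {2, 128, 768} has GemmRows == 256, while shape[0] is only 2.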
4 changes: 0 additions & 4 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -285,10 +285,6 @@
endif()
config_cuda_provider_shared_module(onnxruntime_providers_cuda)

# TODO only needed in DEBUG builds, need cmake expert advice on how to do that
set_source_files_properties(${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cu PROPERTIES COMPILE_FLAGS " -Wno-unknown-pragmas ")
set_source_files_properties(${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cc PROPERTIES COMPILE_FLAGS " -Wno-unknown-pragmas ")

install(TARGETS onnxruntime_providers_cuda
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
34 changes: 17 additions & 17 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
@@ -15,15 +15,15 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

template<>
template <>
Status MatMulNBits<MLFloat16>::PrepackedGemm(
cudaStream_t stream,
const Tensor* a,
const Tensor* b,
const Tensor* scales,
const Tensor* zero_points,
Tensor* Y) const {
int64_t M = a->Shape()[0];
cudaStream_t stream,
int M,
const Tensor* a,
const Tensor* b,
const Tensor* scales,
const Tensor* zero_points,
Tensor* Y) const {
uint8_t const* zero_points_ptr = nullptr;
size_t zero_points_size = 0;
if (zero_points != nullptr) {
@@ -32,12 +32,12 @@ Status MatMulNBits<MLFloat16>::PrepackedGemm(
}

return blkq4_fp16_gemm_sm80_dispatch<MLFloat16>(
int(block_size_), column_wise_quant_blk_, int(M), int(N_), int(K_), stream,
a->Data<MLFloat16>(), a->Shape().Size(),
b->Data<uint8_t>(), b->Shape().Size(),
scales->Data<MLFloat16>(), scales->Shape().Size(),
zero_points_ptr, zero_points_size,
Y->MutableData<MLFloat16>(), Y->Shape().Size());
int(block_size_), column_wise_quant_blk_, int(M), int(N_), int(K_), stream,
a->Data<MLFloat16>(), a->Shape().Size(),
b->Data<uint8_t>(), b->Shape().Size(),
scales->Data<MLFloat16>(), scales->Shape().Size(),
zero_points_ptr, zero_points_size,
Y->MutableData<MLFloat16>(), Y->Shape().Size());
}

template <typename T>
@@ -59,14 +59,14 @@ Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0) return Status::OK();

if (prepack_ > 0){
if (prepack_ > 0) {
ORT_RETURN_IF(reorder_idx != nullptr,
"Internal Error: Prepacked gemm does not support reorder index. Fix the prepacking logic!");
ORT_RETURN_IF(zero_points != nullptr && zero_points->IsDataType<T>(),
"Internal Error: Prepacked gemm does not support zero points of type T. Fix the prepacking logic!");
return PrepackedGemm(
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()),
a, b, scales, zero_points, Y);
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()),
helper.M(), a, b, scales, zero_points, Y);
}

const auto* a_data = a->Data<T>();
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu
@@ -390,7 +390,7 @@ Status blkq4_gemm_sm80(int m, int n, int k, cudaStream_t stream,

const cutlass::gemm::GemmCoord problem_size = {m, n, k};

ORT_RETURN_IF_NOT(a.size_bytes() == m * k * sizeof(ElementDequant), "Activation tensor size is not correct");
ORT_RETURN_IF_NOT(a.size_bytes() == m * k * sizeof(ElementDequant), "Activation tensor size is not correct: ", a.size_bytes(), " vs m: ", m, "k: ", k , " size: ", m * k * sizeof(ElementDequant));

cutlass::TensorRef<ElementDequant const, LayoutInputA> ref_a(
reinterpret_cast<ElementDequant const *>(a.data()),
LayoutInputA::packed({m, k}));
9 changes: 7 additions & 2 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
@@ -34,8 +34,13 @@ class MatMulNBits final : public CudaKernel {
info.GetAttrOrDefault<int64_t>("prepacked", &prepack_, int64_t(0));
}

Status PrepackedGemm(cudaStream_t stream, const Tensor* a, const Tensor* b,
const Tensor* scales, const Tensor* zero_points, Tensor* Y) const {
Status PrepackedGemm([[maybe_unused]] cudaStream_t stream,
[[maybe_unused]] int M,
[[maybe_unused]] const Tensor* a,
[[maybe_unused]] const Tensor* b,
[[maybe_unused]] const Tensor* scales,
[[maybe_unused]] const Tensor* zero_points,
[[maybe_unused]] Tensor* Y) const {
ORT_THROW("Prepacked gemm is not supported for MatMulNBits op.");
}

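The header change gives the unspecialized PrepackedGemm the same explicit M parameter and marks every parameter [[maybe_unused]], since the generic version only throws; the real implementation exists only in the MLFloat16 specialization shown in matmul_nbits.cc. A small self-contained sketch of that pattern, with hypothetical names standing in for the actual kernel classes:

#include <stdexcept>

template <typename T>
struct NBitsKernelSketch {
  // Generic fallback: the argument is never used, hence [[maybe_unused]], and the
  // call simply fails, mirroring the ORT_THROW in the header above.
  int Gemm([[maybe_unused]] int m) const {
    throw std::runtime_error("prepacked gemm is not supported for this type");
  }
};

// Only the explicit specialization provides a real implementation
// (stand-in for the MLFloat16 path in matmul_nbits.cc).
template <>
int NBitsKernelSketch<float>::Gemm(int m) const {
  return m;  // placeholder for the real GEMM dispatch
}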
4 changes: 1 addition & 3 deletions onnxruntime/core/graph/graph_utils.h
@@ -26,8 +26,7 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node,
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

/** Returns the attribute of a Node with a given name. */
static inline
const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) {
static inline const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) {
const auto& attrs = node.GetAttributes();
const auto iter = attrs.find(attr_name);
return iter == attrs.end() ? nullptr : &iter->second;
@@ -49,7 +48,6 @@ inline Status TryGetNodeAttribute<int64_t>(const Node& node, const std::string&
return Status::OK();
}


/** Add a new initializer to 'graph'.
Checks that new_initializer does not already exist in 'graph' before adding it.
@returns The NodeArg for the new initializer.
56 changes: 27 additions & 29 deletions onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h
@@ -187,8 +187,8 @@ struct BlockwiseQuantization {
static constexpr bool ShouldRearrangeMeta = sizeof(ElementT) == 2 && QuantBlocking::kRow == 1;

static void prepack_quant_scales(
size_t rows,
size_t columns,
int rows,
int columns,
const gsl::span<ElementT const>& scales, // <- quant scales, column major layout
const gsl::span<ElementT>& scales_prepacked // <- quant scales prepacked, same size buffer
) {
@@ -261,8 +261,8 @@
}

static void prepack_quant_offsets(
size_t rows,
size_t columns,
int rows,
int columns,
const gsl::span<uint8_t const>& offsets, // <- quant offsets, int4, column major layout
const gsl::span<uint8_t>& offsets_prepacked // <- quant offsets prepacked, double size buffer
) {
@@ -345,8 +345,8 @@
};

static inline bool IsSm80WithWholeBlocks(
int weight_rows, [[maybe_unused]] int weight_cols,
int major, [[maybe_unused]] int minor) {
int weight_rows, [[maybe_unused]] int weight_cols,
int major, [[maybe_unused]] int minor) {
if (major < 8) {
return false;
}
@@ -364,9 +364,8 @@ static inline bool IsSm80WithWholeBlocks(
return (weight_rows % 64 == 0);
}

template<typename ElementT, int block_size, bool col_blocking>
inline
bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int minor) {
template <typename ElementT, int block_size, bool col_blocking>
inline bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int minor) {
using Base = BlockwiseQuantization<ElementT, block_size, 4, col_blocking>;
if (!Base::weight_dimension_supported(weight_rows, weight_cols)) {
return false;
@@ -375,26 +374,25 @@ bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int
}

static inline bool BlkQuantGemmSm80Supported(int block_size, bool col_blocking, int weight_rows, int weight_cols, int major, int minor) {
switch (block_size)
{
case 16:
if (col_blocking) {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 16, true>(weight_rows, weight_cols, major, minor);
} else {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 16, false>(weight_rows, weight_cols, major, minor);
}
case 32:
if (col_blocking) {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 32, true>(weight_rows, weight_cols, major, minor);
} else {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 32, false>(weight_rows, weight_cols, major, minor);
}
case 64:
if (col_blocking) {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 64, true>(weight_rows, weight_cols, major, minor);
} else {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 64, false>(weight_rows, weight_cols, major, minor);
}
switch (block_size) {
case 16:
if (col_blocking) {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 16, true>(weight_rows, weight_cols, major, minor);
} else {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 16, false>(weight_rows, weight_cols, major, minor);
}
case 32:
if (col_blocking) {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 32, true>(weight_rows, weight_cols, major, minor);
} else {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 32, false>(weight_rows, weight_cols, major, minor);
}
case 64:
if (col_blocking) {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 64, true>(weight_rows, weight_cols, major, minor);
} else {
return onnxruntime::cuda::BlkQuantGemmSm80Supported<MLFloat16, 64, false>(weight_rows, weight_cols, major, minor);
}
}
return false;
}
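
The switch above exists because block_size and col_blocking are runtime values while the template parameters of BlkQuantGemmSm80Supported must be compile-time constants, so each supported combination is instantiated explicitly and anything else falls through to false. A stripped-down sketch of the same runtime-to-compile-time dispatch, using hypothetical names:

template <int BlockSize, bool ColBlocking>
bool SupportedImpl(int rows, int /*cols*/) {
  // Placeholder check standing in for the real per-instantiation logic.
  return rows % BlockSize == 0;
}

bool Supported(int block_size, bool col_blocking, int rows, int cols) {
  switch (block_size) {
    case 16: return col_blocking ? SupportedImpl<16, true>(rows, cols) : SupportedImpl<16, false>(rows, cols);
    case 32: return col_blocking ? SupportedImpl<32, true>(rows, cols) : SupportedImpl<32, false>(rows, cols);
    case 64: return col_blocking ? SupportedImpl<64, true>(rows, cols) : SupportedImpl<64, false>(rows, cols);
    default: return false;  // unsupported block size
  }
}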
78 changes: 39 additions & 39 deletions onnxruntime/core/optimizer/gpu_ops_prepack.cc
@@ -24,7 +24,6 @@
// 3. The logic of prepacking depends on underlying GPU
// hardware. Currently this part is hard-coded for SM80.


#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/gpu_ops_prepack.h"
@@ -43,17 +42,17 @@ extern ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
/**
* @brief Read initialized tensor from protobuf, and store it in ort_value.
* Keep in mind that ort_value is the owner of the tensor memory after calling this function.
*/
*/
inline Status GetOrtValue(const NodeArg* arg, const Graph& graph, OrtValue& ort_value) {
const ONNX_NAMESPACE::TensorProto* tensor_proto;
ORT_RETURN_IF_NOT(graph.GetInitializedTensor(arg->Name(), tensor_proto),
"Missing initializer for ", arg->Name());

const auto* path_c_str = graph.ModelPath().ToPathString().c_str();
const auto path_str = graph.ModelPath().ToPathString();

return utils::TensorProtoToOrtValue(
Env::Default(), path_c_str, *tensor_proto,
std::make_shared<CPUAllocator>(), ort_value);
Env::Default(), path_str.c_str(), *tensor_proto,
std::make_shared<CPUAllocator>(), ort_value);
}
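
The path handling in GetOrtValue above is the other notable fix in this file: the old code took c_str() of the temporary string returned by ToPathString(), leaving a dangling pointer once the statement finished, while the new code keeps the string alive in path_str first. A minimal sketch of the difference, where the hypothetical MakePath stands in for graph.ModelPath().ToPathString():

#include <string>

std::string MakePath() { return std::string("model/dir/model.onnx"); }  // hypothetical

int main() {
  const char* dangling = MakePath().c_str();  // the temporary is destroyed here; the pointer is dead
  const std::string path = MakePath();
  const char* valid = path.c_str();           // stays valid while `path` is in scope
  (void)dangling;
  (void)valid;
  return 0;
}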

template <typename T>
@@ -65,7 +64,7 @@ inline gsl::span<T> make_span(std::string& str) {
// Prepacking logic specific to MatMulNBits<float16> on sm80
//

static inline bool IsNodeMatMulNbitsFp16(const Node& node){
static inline bool IsNodeMatMulNbitsFp16(const Node& node) {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMulNBits", {1}, kMSDomain)) {
return false;
}
@@ -78,13 +77,13 @@

template <int block_size, bool column_quant_blk>
void Sm80BlkQ4PrepackT(
int rows, int columns,
gsl::span<const uint8_t> weights,
gsl::span<const MLFloat16> scales,
gsl::span<const uint8_t> zp,
std::string& packed_w,
std::string& packed_scales,
std::string& packed_zp) {
int rows, int columns,
gsl::span<const uint8_t> weights,
gsl::span<const MLFloat16> scales,
gsl::span<const uint8_t> zp,
std::string& packed_w,
std::string& packed_scales,
std::string& packed_zp) {
using Base = onnxruntime::cuda::BlockwiseQuantization<
MLFloat16,
block_size,
@@ -95,31 +94,31 @@ void Sm80BlkQ4PrepackT(

packed_w.resize(q_weight_shape.product() * sizeof(uint8_t));
Base::prepack_weights(
rows, columns, weights,
make_span<uint8_t>(packed_w));
rows, columns, weights,
make_span<uint8_t>(packed_w));

packed_scales.resize(meta_shape.product() * sizeof(MLFloat16));
Base::prepack_quant_scales(
rows, columns, scales,
make_span<MLFloat16>(packed_scales));
rows, columns, scales,
make_span<MLFloat16>(packed_scales));

if (!zp.empty()) {
packed_zp.resize(meta_shape.product() * sizeof(uint8_t));
Base::prepack_quant_offsets(
rows, columns, zp,
make_span<uint8_t>(packed_zp));
rows, columns, zp,
make_span<uint8_t>(packed_zp));
}
}

void Sm80BlkQ4Prepack(
int block_size, bool column_quant_blk,
int rows, int columns,
gsl::span<const uint8_t> weights,
gsl::span<const MLFloat16> scales,
gsl::span<const uint8_t> zp,
std::string& packed_w,
std::string& packed_scales,
std::string& packed_zp) {
int block_size, bool column_quant_blk,
int rows, int columns,
gsl::span<const uint8_t> weights,
gsl::span<const MLFloat16> scales,
gsl::span<const uint8_t> zp,
std::string& packed_w,
std::string& packed_scales,
std::string& packed_zp) {
switch (block_size) {
case 16:
if (column_quant_blk) {
@@ -161,21 +160,23 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) {
Status status = graph_utils::TryGetNodeAttribute(node, "prepacked", att_i);
bool prepacked = status.IsOK() ? att_i != 0 : false;
if (prepacked) {
return Status::OK(); // already prepacked, nothing to do
return Status::OK(); // already prepacked, nothing to do
}

ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute<int64_t>(node, "bits", att_i));
int nbits = static_cast<int>(att_i);
int nbits = SafeInt<int>(att_i);
if (nbits != 4) {
return Status::OK(); // only support 4 bits for now
return Status::OK(); // only support 4 bits for now
}

// A single dimension can not exceed 2G yet.
ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute<int64_t>(node, "K", att_i));
int k = static_cast<int>(att_i);
int k = SafeInt<int>(att_i);
ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute<int64_t>(node, "N", att_i));
int n = static_cast<int>(att_i);
int n = SafeInt<int>(att_i);

ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute<int64_t>(node, "block_size", att_i));
int block_size = static_cast<int>(att_i);
int block_size = SafeInt<int>(att_i);

status = graph_utils::TryGetNodeAttribute(node, "column_wise_blocking", att_i);
bool column_wise_quant_blk = status.IsOK() ? att_i != 0 : true;
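
The attribute values read above are int64_t narrowed to int; switching from static_cast<int> to SafeInt<int> makes that narrowing checked, so an out-of-range K, N, or block_size fails loudly instead of silently truncating (hence the "can not exceed 2G yet" comment). A small sketch of the difference, assuming the SafeInt header bundled with ONNX Runtime:

#include <cstdint>
#include "core/common/safeint.h"  // assumed include path for ORT's SafeInt wrapper

void NarrowAttribute() {
  int64_t att_i = int64_t{1} << 33;    // too large for a 32-bit int
  // int k = static_cast<int>(att_i);  // would silently wrap/truncate
  int k = SafeInt<int>(att_i);         // reports the overflow instead of truncating
  (void)k;
}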
@@ -184,10 +185,10 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) {
ORT_ENFORCE(provider_info != nullptr, "Failed to query CUDA provider info while prepacking cuda operators.");
int major, minor;
ORT_ENFORCE(provider_info->GetCurrentGpuDeviceVersion(&major, &minor) == nullptr,
"Failed to query CUDA device version while prepacking cuda operators.");
"Failed to query CUDA device version while prepacking cuda operators.");

if (!onnxruntime::cuda::BlkQuantGemmSm80Supported(block_size, column_wise_quant_blk, k, n, major, minor)) {
return Status::OK(); // not supported
return Status::OK(); // not supported
}

//
@@ -196,7 +197,7 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) {
auto& node_name = node.Name();
auto& mutable_input_defs = node.MutableInputDefs();
if (mutable_input_defs.size() < 3 || mutable_input_defs.size() > 4) {
return Status::OK(); // not supported
return Status::OK(); // not supported
}

NodeArg* old_weights_arg = mutable_input_defs[1];
@@ -227,7 +228,7 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) {
ORT_RETURN_IF_ERROR(GetOrtValue(old_zp_arg, graph, zp_val));
Tensor* zp_tensor_ptr = zp_val.GetMutable<Tensor>();
if (!zp_tensor_ptr->IsDataType<uint8_t>()) {
return Status::OK(); // not supported
return Status::OK(); // not supported
}
zp = zp_tensor_ptr->DataAsSpan<uint8_t>();
}
@@ -289,7 +290,6 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) {
return Status::OK();
}


Status GpuOpsPrepack::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
@@ -304,7 +304,7 @@ Status GpuOpsPrepack::ApplyImpl(Graph& graph, bool& modified, int graph_level, c
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

if (node.GetExecutionProviderType() != onnxruntime::kCudaExecutionProvider) {
continue; // only interested in CUDA nodes
continue; // only interested in CUDA nodes
}

// Run prepack if the node is MatMulNBits<float16>.
2 changes: 1 addition & 1 deletion onnxruntime/core/util/matrix_layout.h
@@ -378,7 +378,7 @@ class MatrixRef {
MatrixRef(
NonConstMatrixRef const& ref, ///< MatrixRef to non-const data
/// SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const
_Magic magic = (typename std::enable_if<!IsNonConstRef, _Magic>::type)0
[[maybe_unused]] _Magic magic = (typename std::enable_if<!IsNonConstRef, _Magic>::type)0
) : data_(ref.data()), shape_(ref.shape()), layout_(Layout::packed(ref.shape())) {}
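
The only change here marks the SFINAE-only magic parameter [[maybe_unused]]: it exists purely to keep this constructor from competing with the copy constructor when Element_ is already non-const, and the body never reads it, which would otherwise trip unused-parameter warnings. A tiny sketch of the attribute on such a selection-only parameter, with hypothetical types:

struct EnableTag {};  // parameter used only to steer overload selection

int Pick([[maybe_unused]] EnableTag tag, int value) {
  return value + 1;  // `tag` is never read, and no unused-parameter warning is emitted
}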

ORT_FORCEINLINE