Use graph optimizer for gpu tensor prepack #19814
Changes from 3 commits
onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu
@@ -10,8 +10,7 @@
#include "core/providers/cuda/cuda_common.h"
#include "matmul_nbits.cuh"

using namespace onnxruntime::cuda;
using namespace cub;
#include "blk_q4/f16_gemm_sm80.h"

namespace onnxruntime {
namespace contrib {
@@ -348,6 +347,226 @@
    int shared_mem_per_block,
    cudaStream_t stream);

/**
 * @brief Helper function to run the GEMM kernel for 4-bit quantized gemm on SM80.
 * Only supports fp16 for now.
 */
template <
    typename ElementT,
    int block_size,
    bool column_wise_blocking,
    bool small_m,
    bool has_offsets>
Status blkq4_gemm_sm80(int m, int n, int k, cudaStream_t stream,
                       gsl::span<ElementT const> a,
                       gsl::span<uint8_t const> weights,
                       gsl::span<ElementT const> scales,
                       gsl::span<uint8_t const> offsets,
                       gsl::span<ElementT> output) {
  static_assert(std::is_same<ElementT, half>::value
                    || std::is_same<ElementT, MLFloat16>::value
                    || std::is_same<ElementT, cutlass::half_t>::value,
                "Only support fp16 for now");
  using ElementDequant = cutlass::half_t;
  using QuantBlocking =
      typename std::conditional<column_wise_blocking,
                                cutlass::MatrixShape<block_size, 1>,
                                cutlass::MatrixShape<1, block_size>>::type;

  using GemmRunner = onnxruntime::cuda::BlkQ4F16GemmImpl<ElementDequant, QuantBlocking, small_m, has_offsets>;

  using ElementAccumulator = typename GemmRunner::ElementAccumulator;
  using ElementComputeEpilogue = typename GemmRunner::ElementComputeEpilogue;
  using ElementOutput = typename GemmRunner::ElementOutput;
  using ElementW = typename GemmRunner::ElementW;
  using ElementWPack = typename GemmRunner::ElementWPack;
  using ElementQScale = typename GemmRunner::ElementQScale;
  using ElementQOffset = typename GemmRunner::ElementQOffset;

  using LayoutInputA = typename GemmRunner::LayoutInputA;
  using LayoutOutput = typename GemmRunner::LayoutOutput;
  using LayoutInputWPack = typename GemmRunner::LayoutInputWPack;
  using LayoutInputQScale = typename GemmRunner::LayoutInputQScale;

  const cutlass::gemm::GemmCoord problem_size = {m, n, k};

  ORT_RETURN_IF_NOT(a.size_bytes() == m * k * sizeof(ElementDequant), "Activation tensor size is not correct");
  cutlass::TensorRef<ElementDequant const, LayoutInputA> ref_a(
      reinterpret_cast<ElementDequant const*>(a.data()),
      LayoutInputA::packed({m, k}));

  ORT_RETURN_IF_NOT(weights.size_bytes() == k / 2 * n / 2 * sizeof(ElementWPack), "weights size is not correct");
  cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_W(
      reinterpret_cast<ElementWPack const*>(weights.data()),
      LayoutInputWPack::packed({k / 2, n / 2}));

  ORT_RETURN_IF_NOT(scales.size_bytes() == (k / QuantBlocking::kRow) * (n / QuantBlocking::kColumn) * sizeof(ElementQScale),
                    "scales size is not correct");
  cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_scales(
      reinterpret_cast<ElementQScale const*>(scales.data()),
      LayoutInputQScale::packed({k / QuantBlocking::kRow, n / QuantBlocking::kColumn}));

  ORT_RETURN_IF_NOT(output.size_bytes() == m * n * sizeof(ElementOutput), "output size is not correct");
  cutlass::TensorRef<ElementOutput, LayoutOutput> ref_output(
      reinterpret_cast<ElementOutput*>(output.data()),
      LayoutOutput::packed({m, n}));

  // run GEMM
  cutlass::Status status;
  if constexpr (has_offsets) {
    ORT_RETURN_IF_NOT(offsets.size_bytes() == (k / QuantBlocking::kRow) * (n / QuantBlocking::kColumn) * sizeof(ElementQOffset),
                      "offsets size is not correct");
    cutlass::TensorRef<ElementQOffset const, LayoutInputQScale> ref_offsets(
        reinterpret_cast<ElementQOffset const*>(offsets.data()),
        LayoutInputQScale::packed({k / QuantBlocking::kRow, n / QuantBlocking::kColumn}));
    status = GemmRunner::run(
        stream, problem_size, ref_a, ref_W, ref_scales, ref_offsets,
        ref_output, ref_output);
  } else {
    status = GemmRunner::run(
        stream, problem_size, ref_a, ref_W, ref_scales,
        ref_output, ref_output);
  }
  ORT_RETURN_IF_NOT(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status));
  return Status::OK();
}
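// A worked instance of the size checks above, with hypothetical shapes not
// taken from this PR: m = 16, n = 4096, k = 4096, block_size = 64, column-wise
// blocking. QuantBlocking is then MatrixShape<64, 1> (kRow = 64, kColumn = 1),
// so the spans must satisfy:
//   a.size_bytes()       == 16 * 4096 * sizeof(ElementDequant)
//   weights.size_bytes() == (4096 / 2) * (4096 / 2) * sizeof(ElementWPack)
//   scales.size_bytes()  == (4096 / 64) * (4096 / 1) * sizeof(ElementQScale)
//   output.size_bytes()  == 16 * 4096 * sizeof(ElementOutput)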
template <typename ElementT>
Status blkq4_fp16_gemm_sm80_dispatch(
    int block_size, bool column_wise_blocking, int m, int n, int k, cudaStream_t stream,
    ElementT const* a_ptr, size_t a_size,
    uint8_t const* weights_ptr, size_t weights_size,
    ElementT const* scales_ptr, size_t scales_size,
    uint8_t const* offsets_ptr, size_t offsets_size,
    ElementT* output_ptr, size_t output_size) {
  auto a = gsl::make_span(a_ptr, a_size);
  auto weights = gsl::make_span(weights_ptr, weights_size);
  auto scales = gsl::make_span(scales_ptr, scales_size);
  auto offsets = gsl::make_span(offsets_ptr, offsets_size);
  auto output = gsl::make_span(output_ptr, output_size);

  // Map the runtime parameters onto the compile-time template arguments
  // <block_size, column_wise_blocking, small_m, has_offsets>; m <= 16 selects
  // the kernel shape tuned for small batch sizes.
  switch (block_size) {
    case 16:
      if (column_wise_blocking) {
        if (m > 16) {
          if (offsets.empty())
            return blkq4_gemm_sm80<ElementT, 16, true, false, false>(m, n, k, stream, a, weights, scales, offsets, output);
          else
            return blkq4_gemm_sm80<ElementT, 16, true, false, true>(m, n, k, stream, a, weights, scales, offsets, output);
        } else {
          if (offsets.empty())
            return blkq4_gemm_sm80<ElementT, 16, true, true, false>(m, n, k, stream, a, weights, scales, offsets, output);
          else
            return blkq4_gemm_sm80<ElementT, 16, true, true, true>(m, n, k, stream, a, weights, scales, offsets, output);
        }
      } else {
        if (m > 16) {
          if (offsets.empty())
            return blkq4_gemm_sm80<ElementT, 16, false, false, false>(m, n, k, stream, a, weights, scales, offsets, output);
          else
            return blkq4_gemm_sm80<ElementT, 16, false, false, true>(m, n, k, stream, a, weights, scales, offsets, output);
        } else {
          if (offsets.empty())
            return blkq4_gemm_sm80<ElementT, 16, false, true, false>(m, n, k, stream, a, weights, scales, offsets, output);
          else
            return blkq4_gemm_sm80<ElementT, 16, false, true, true>(m, n, k, stream, a, weights, scales, offsets, output);
        }
      }
      break;

    case 32:
      if (column_wise_blocking) {
        if (m > 16) {
          if (offsets.empty())
            return blkq4_gemm_sm80<ElementT, 32, true, false, false>(m, n, k, stream, a, weights, scales, offsets, output);
          else
            return blkq4_gemm_sm80<ElementT, 32, true, false, true>(m, n, k, stream, a, weights, scales, offsets, output);
        } else {
          if (offsets.empty())
            return blkq4_gemm_sm80<ElementT, 32, true, true, false>(m, n, k, stream, a, weights, scales, offsets, output);
          else
            return blkq4_gemm_sm80<ElementT, 32, true, true, true>(m, n, k, stream, a, weights, scales, offsets, output);
        }
      } else {
        if (m > 16) {
          if (offsets.empty())
            return blkq4_gemm_sm80<ElementT, 32, false, false, false>(m, n, k, stream, a, weights, scales, offsets, output);
          else
            return blkq4_gemm_sm80<ElementT, 32, false, false, true>(m, n, k, stream, a, weights, scales, offsets, output);
        } else {
          if (offsets.empty())
            return blkq4_gemm_sm80<ElementT, 32, false, true, false>(m, n, k, stream, a, weights, scales, offsets, output);
          else
            return blkq4_gemm_sm80<ElementT, 32, false, true, true>(m, n, k, stream, a, weights, scales, offsets, output);
        }
      }
      break;
[Review comment on the `case 64:` line below] Would you support case=128? It is widely used.
[Author reply] It's difficult for this kernel. I am working on another version that will hopefully support it.

    case 64:
      if (column_wise_blocking) {
        if (m > 16) {
          if (offsets.empty())
            return blkq4_gemm_sm80<ElementT, 64, true, false, false>(m, n, k, stream, a, weights, scales, offsets, output);
          else
            return blkq4_gemm_sm80<ElementT, 64, true, false, true>(m, n, k, stream, a, weights, scales, offsets, output);
        } else {
          if (offsets.empty())
            return blkq4_gemm_sm80<ElementT, 64, true, true, false>(m, n, k, stream, a, weights, scales, offsets, output);
          else
            return blkq4_gemm_sm80<ElementT, 64, true, true, true>(m, n, k, stream, a, weights, scales, offsets, output);
        }
      } else {
        if (m > 16) {
          if (offsets.empty())
            return blkq4_gemm_sm80<ElementT, 64, false, false, false>(m, n, k, stream, a, weights, scales, offsets, output);
          else
            return blkq4_gemm_sm80<ElementT, 64, false, false, true>(m, n, k, stream, a, weights, scales, offsets, output);
        } else {
          if (offsets.empty())
            return blkq4_gemm_sm80<ElementT, 64, false, true, false>(m, n, k, stream, a, weights, scales, offsets, output);
          else
            return blkq4_gemm_sm80<ElementT, 64, false, true, true>(m, n, k, stream, a, weights, scales, offsets, output);
        }
      }
      break;
  }
  return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported block size: ", block_size);
}

template
Status blkq4_fp16_gemm_sm80_dispatch<half>(
    int block_size,
    bool column_wise_blocking,
    int m, int n, int k, cudaStream_t stream,
    half const* a_ptr, size_t a_size,
    uint8_t const* weights_ptr, size_t weights_size,
    half const* scales_ptr, size_t scales_size,
    uint8_t const* offsets_ptr, size_t offsets_size,
    half* output_ptr, size_t output_size);

template
Status blkq4_fp16_gemm_sm80_dispatch<cutlass::half_t>(
    int block_size,
    bool column_wise_blocking,
    int m, int n, int k, cudaStream_t stream,
    cutlass::half_t const* a_ptr, size_t a_size,
    uint8_t const* weights_ptr, size_t weights_size,
    cutlass::half_t const* scales_ptr, size_t scales_size,
    uint8_t const* offsets_ptr, size_t offsets_size,
    cutlass::half_t* output_ptr, size_t output_size);

template
Status blkq4_fp16_gemm_sm80_dispatch<onnxruntime::MLFloat16>(
    int block_size, bool column_wise_blocking, int m, int n, int k, cudaStream_t stream,
    onnxruntime::MLFloat16 const* a_ptr, size_t a_size,
    uint8_t const* weights_ptr, size_t weights_size,
    onnxruntime::MLFloat16 const* scales_ptr, size_t scales_size,
    uint8_t const* offsets_ptr, size_t offsets_size,
    onnxruntime::MLFloat16* output_ptr, size_t output_size);

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
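For reviewers who want to try the new entry point: below is a minimal calling sketch, not part of this PR. The function and buffer names are illustrative; it assumes the dispatch declaration is visible (e.g. through matmul_nbits.cuh), that all pointers are device memory, and that the weight blob was already prepacked into the layout this kernel expects. Span sizes passed in are element counts of the corresponding tensors; the dispatch routine re-validates them in bytes via the ORT_RETURN_IF_NOT checks above.

#include <cuda_fp16.h>

// Hypothetical call-site sketch for blkq4_fp16_gemm_sm80_dispatch<half>.
onnxruntime::Status RunPrepackedMatMulNBits(
    cudaStream_t stream, int m, int n, int k,
    int block_size, bool column_wise_blocking,
    const half* a_data,           // m * k fp16 activations
    const uint8_t* packed_w,      // prepacked 4-bit weight blob
    size_t packed_w_bytes,
    const half* scales_data,      // one fp16 scale per quantization block
    size_t scales_elems,
    const uint8_t* offsets_data,  // zero points; may be nullptr
    size_t offsets_bytes,
    half* out_data) {             // m * n fp16 outputs
  return onnxruntime::contrib::cuda::blkq4_fp16_gemm_sm80_dispatch<half>(
      block_size, column_wise_blocking, m, n, k, stream,
      a_data, static_cast<size_t>(m) * k,
      packed_w, packed_w_bytes,
      scales_data, scales_elems,
      // An empty offsets span selects the has_offsets == false kernels.
      offsets_data, offsets_data == nullptr ? 0 : offsets_bytes,
      out_data, static_cast<size_t>(m) * n);
}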
[Review comment] I assume this can also be supported by sm86 and sm89?
[Author reply] It should.
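On that question: SASS compiled for sm_80 is binary-compatible with later minor revisions of the same major compute capability, so SM 8.6 (consumer Ampere) and SM 8.9 (Ada) can run this kernel unchanged. A minimal sketch of a runtime guard a caller might use follows; the helper name is illustrative, not from this PR.

#include <cuda_runtime.h>

// Returns true when the device can execute the SM80 kernel path.
// Compute capabilities 8.0/8.6/8.9 qualify; a newer major architecture
// (e.g. 9.x) would need its own kernel build or PTX JIT, so it is
// conservatively excluded here.
inline bool DeviceSupportsBlkQ4Sm80(int device_id) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
    return false;  // be conservative if the query fails
  }
  return prop.major == 8;
}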