hipify int4 gemv (#20666)
Hipify MatMulNBits to accommodate the needs of the Phi3 ONNX release.
cloudhan authored May 18, 2024
1 parent 72a3bde commit 5d07291
Showing 8 changed files with 126 additions and 25 deletions.
5 changes: 0 additions & 5 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -51,16 +51,11 @@ set(contrib_ops_excluded_files
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"
"quantization/attention_quantization_impl.cuh"
"quantization/dequantize_blockwise.cuh"
"quantization/dequantize_blockwise.cu"
"quantization/dequantize_blockwise_bnb4.cuh"
"quantization/dequantize_blockwise_bnb4.cu"
"quantization/matmul_bnb4.cc"
"quantization/matmul_bnb4.cuh"
"quantization/matmul_bnb4.cu"
"quantization/matmul_nbits.cc"
"quantization/matmul_nbits.cuh"
"quantization/matmul_nbits.cu"
"quantization/moe_quantization.h"
"quantization/moe_quantization.cc"
"quantization/quantize_dequantize_linear.cc"
@@ -72,9 +72,6 @@ __global__ void Dequantize4BitsKernelReOrder(
if (group_id >= total_groups) {
return;
}
// T __shared__ zero_points_after_reorder[];//K
// T __shared__ scales_after_reorder[]; // K
// const int num_r_per_thread = k / 256;

const int zero_point_shape_x = (groups_per_K + 1) / 2;
const int scales_shape_x = groups_per_K;
@@ -361,7 +358,6 @@ template <
static void dequantize(ElementT* dst, const uint8_t* weights, const ElementT* scales,
const uint8_t* zero_points, int32_t rows, int32_t columns,
cudaStream_t stream) {
using QuantBlk = typename BlkQuantTraits<ElementT, block_size, qbits, Columnwise>::QuantBlk;
using ThreadBlk = typename BlkQuantTraits<ElementT, block_size, qbits, Columnwise>::ThreadBlk;

// Thread partitioning
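For context, the Dequantize4BitsKernelReOrder hunk above computes `zero_point_shape_x = (groups_per_K + 1) / 2` because 4-bit zero points are packed two per byte along each column. Below is a minimal standalone sketch of that lookup, assuming the low nibble holds the even-numbered group; the helper name and nibble order are illustrative, not taken from this file:

```cpp
// Sketch: recover the 4-bit zero point for quantization group `g` of column `n`,
// given zero points packed two-per-byte with (groups_per_K + 1) / 2 bytes per column.
// Nibble order (low nibble = even group) is an assumption for illustration only.
#include <cstdint>
#include <cstdio>

inline int UnpackZeroPoint4Bit(const uint8_t* zero_points, int n, int g, int groups_per_K) {
  const int zero_point_shape_x = (groups_per_K + 1) / 2;  // two 4-bit values per byte
  const uint8_t pair = zero_points[n * zero_point_shape_x + g / 2];
  return (g & 1) ? (pair >> 4) : (pair & 0x0F);
}

int main() {
  const uint8_t zps[] = {0x87, 0x06};  // column 0 holds groups {7, 8, 6}
  for (int g = 0; g < 3; ++g) {
    printf("group %d -> %d\n", g, UnpackZeroPoint4Bit(zps, /*n=*/0, g, /*groups_per_K=*/3));
  }
  return 0;
}
```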
15 changes: 8 additions & 7 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu
@@ -26,11 +26,11 @@ __device__ __forceinline__ T WarpUniform(T value) {
};
} p;
p.value = value;
p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0);
p.asInt = WARP_SHFL((unsigned)p.asInt, 0);
return p.value;
}

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530) && !defined(__HIPCC__)
// Convert 8 4-bit integers stored in one uint32_t to 8 halfs.
// 8 4-bit values with order 0,1,2,3,4,5,6,7 will be converted to 8 halfs with order 0,4,1,5,2,6,3,7
__device__ __forceinline__ void Convert8xInt4To8xHalfs(uint32_t value, half2* half_2x4) {
@@ -169,15 +169,16 @@ __device__ __forceinline__ void AccumulateEightElements(uint32_t values_quant, f
}

constexpr int kColsPerThreadBlock = 8;
constexpr int kWarpSize = 32;
constexpr int kElementsPerThreadPerIteration = 8;
constexpr int kWarpSize = GPU_WARP_SIZE;

// kernel for 4bits quantized gemv, i.e., computing A(1,K) x B(K, N)
// B(K, N) is quantized blockwise with 4bits and stored as [N, (K + block_size - 1)/block_size, blob]
// The thread block size is (kWarpSize, kColsPerThreadBlock) and grid size is (N/kColsPerThreadBlock, 1)
// Each thread block computes [1, K] x [kColsPerThreadBlock, (K + block_size - 1)/block_size, blob],
// i.e., computing kColsPerThreadBlock per block and a warp reduce (1, K) x (K)
template <class T, int block_size, bool has_zero_point>
__global__ void __launch_bounds__(kWarpSize* kColsPerThreadBlock) MatMulFloatInt4Kernel(
__global__ void __launch_bounds__(kWarpSize * kColsPerThreadBlock) MatMulFloatInt4Kernel(
T* output,
const T* a_data,
const uint8_t* b_data_quant,
@@ -192,7 +193,7 @@ __global__ void __launch_bounds__(kWarpSize* kColsPerThreadBlock) MatMulFloatInt
const int lane_id = threadIdx.x;
const int warp_id = WarpUniform(threadIdx.y);
const int n_id = n_block_id * kColsPerThreadBlock + warp_id;
constexpr int k_per_iter = 256;
constexpr int k_per_iter = kWarpSize * kElementsPerThreadPerIteration;

extern __shared__ char shared_buffer[];
// load scale to shared buffer
@@ -262,8 +263,8 @@ __global__ void __launch_bounds__(kWarpSize* kColsPerThreadBlock) MatMulFloatInt

float sum = (float)(sums[0] + sums[1] + sums[2] + sums[3] + sums[4] + sums[5] + sums[6] + sums[7]);
// warp reduction
for (int i = 16; i > 0; i = i / 2) {
sum += __shfl_down_sync(0xffffffff, sum, i);
for (int i = kWarpSize / 2; i > 0; i = i / 2) {
sum += WARP_SHFL_DOWN(sum, i);
}

if (lane_id == 0) {
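The substantive edits in this file swap CUDA-only warp intrinsics and hard-coded warp geometry (`__shfl_sync`, a fixed warp size of 32, 256 elements per iteration, and a shuffle offset starting at 16) for the portability macros `WARP_SHFL`, `WARP_SHFL_DOWN`, and `GPU_WARP_SIZE`, so the same kernel also builds for AMD's 64-lane wavefronts under HIP. Below is a minimal, self-contained CUDA sketch of the resulting warp-width-agnostic reduction pattern; the constants and the direct `__shfl_down_sync` call are stand-ins for the macros, not onnxruntime code:

```cuda
// Sketch of the warp-width-agnostic reduction the diff switches to.
// kWarpSize stands in for GPU_WARP_SIZE (32 on NVIDIA, 64 on AMD wave64),
// and __shfl_down_sync stands in for the WARP_SHFL_DOWN portability macro.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;                      // GPU_WARP_SIZE stand-in
constexpr int kElementsPerThreadPerIteration = 8;  // each lane accumulates 8 values
constexpr int k_per_iter = kWarpSize * kElementsPerThreadPerIteration;  // 256 on NVIDIA, 512 on AMD

__global__ void WarpSumKernel(const float* in, float* out) {
  float sum = in[threadIdx.x];
  // Tree reduction starting at half the warp width instead of a hard-coded 16,
  // so the loop also covers 64-lane wavefronts.
  for (int i = kWarpSize / 2; i > 0; i /= 2) {
    sum += __shfl_down_sync(0xffffffff, sum, i);  // WARP_SHFL_DOWN(sum, i) in the hunk
  }
  if (threadIdx.x == 0) *out = sum;
}

int main() {
  float h_in[kWarpSize], *d_in, *d_out, h_out = 0.f;
  for (int i = 0; i < kWarpSize; ++i) h_in[i] = 1.f;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  WarpSumKernel<<<1, kWarpSize>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f (expect %d), k_per_iter = %d\n", h_out, kWarpSize, k_per_iter);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```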
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
@@ -113,6 +113,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Trilu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear);
@@ -267,6 +269,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Inverse)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasDropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BitmaskDropout)>,
@@ -1,17 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// This file serve as a simple example for adding a tunable op to onnxruntime.

#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <pybind11/pybind11.h>

#include <string>

#include "core/providers/cuda/tunable/cuda_tunable.h"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh"
#include "contrib_ops/cuda/quantization/matmul_nbits.cuh"

namespace py = pybind11;
@@ -71,7 +71,7 @@ def profile_matmul_fp_int4_func(m, n, k, dtype, func, is_symmetric):
else f(output_d, a_d, b_d, scales_d, zeropoints_d, m, n, k)
)
duration_ms = my_op.Profile()
total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype))
total_bytes = (m * k + m * n) * (dtype_to_bytes(dtype)) + n * k / 2

ke.report(MatrixFpInt4Metric(func, dtype, duration_ms, total_bytes, m, n, k, is_symmetric))

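The kernel explorer script's bandwidth estimate is corrected here: the activations A (m*k elements) and the output (m*n elements) move at the activation dtype's width, while the 4-bit-packed B matrix contributes roughly n*k/2 bytes regardless of dtype. A small sketch of that arithmetic, with a hypothetical helper name and example shapes that are not part of the script:

```cpp
// Back-of-the-envelope check of the revised traffic estimate (illustrative only).
// A (m*k) and the output (m*n) are counted at sizeof(dtype) bytes per element;
// the 4-bit-packed B contributes n*k/2 bytes.
#include <cstdio>

double TotalBytesInt4Gemv(long long m, long long n, long long k, int dtype_bytes) {
  return double(m * k + m * n) * dtype_bytes + double(n * k) / 2.0;
}

int main() {
  // Example gemv shape: m=1, n=3072, k=3072, fp16 activations (2 bytes each).
  const double bytes = TotalBytesInt4Gemv(1, 3072, 3072, 2);
  const double duration_ms = 0.05;  // hypothetical Profile() result
  printf("traffic: %.2f MB, est. bandwidth: %.1f GB/s\n",
         bytes / 1e6, bytes / (duration_ms * 1e-3) / 1e9);
  return 0;
}
```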
100 changes: 100 additions & 0 deletions onnxruntime/python/tools/kernel_explorer/kernels/rocm/matmul_4bits.cu
@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <pybind11/pybind11.h>

#include <string>

#include "core/providers/rocm/tunable/rocm_tunable.h"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
#include "contrib_ops/rocm/quantization/matmul_nbits.cuh"

namespace py = pybind11;

namespace onnxruntime {

// Extend the OpParams so that all specializations have the same parameter passing interface
template <typename T>
struct MatrixFloatInt4Params : rocm::tunable::OpParams {
std::string Signature() const override { return std::to_string(n_); }

T* output_;
const T* a_;
const uint8_t* b_;
const T* scales_;
const uint8_t* zero_points_;
int m_;
int n_;
int k_;
};

template <typename T>
class MatrixFloatInt4 : public IKernelExplorer {
public:
MatrixFloatInt4(DeviceArray& output,
DeviceArray& a,
DeviceArray& b,
DeviceArray& scales,
int m, int n, int k) {
params_.tuning_ctx = TuningContext();
params_.stream = Stream();
params_.output_ = static_cast<T*>(output.ptr());
params_.a_ = static_cast<T*>(a.ptr());
params_.b_ = static_cast<uint8_t*>(b.ptr());
params_.scales_ = static_cast<T*>(scales.ptr());
params_.zero_points_ = nullptr;
params_.m_ = m;
params_.n_ = n;
params_.k_ = k;

HIP_CALL_THROW(hipGetDeviceProperties(&device_prop_, 0));
}

MatrixFloatInt4(DeviceArray& output,
DeviceArray& a,
DeviceArray& b,
DeviceArray& scales,
DeviceArray& zeropoints,
int m, int n, int k) : MatrixFloatInt4(output, a, b, scales, m, n, k) {
params_.zero_points_ = static_cast<uint8_t*>(zeropoints.ptr());
}

void Run() override {
contrib::rocm::TryMatMul4Bits<T>(
params_.output_,
params_.a_,
params_.b_,
params_.scales_,
params_.zero_points_,
params_.m_,
params_.n_,
params_.k_,
32,
static_cast<int>(device_prop_.sharedMemPerBlock),
params_.StreamHandle());
}

private:
// A VectorAddOp<T> is a callable that can process const VectorAddParams<T>*
using ParamsT = MatrixFloatInt4Params<T>;
ParamsT params_{};
hipDeviceProp_t device_prop_;
};

#define REGISTER_OP(name, type) \
py::class_<name<type>>(m, #name "_" #type) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int, int>()) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int, int>()) \
.def("SetRepeats", &name<type>::SetRepeats) \
.def("Profile", &name<type>::Profile) \
.def("Run", &name<type>::Run);

KE_REGISTER(m) {
REGISTER_OP(MatrixFloatInt4, half);
REGISTER_OP(MatrixFloatInt4, float);
}

} // namespace onnxruntime
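The REGISTER_OP macro at the end of this new ROCm file mirrors the CUDA kernel explorer binding: for each element type it exposes a pybind11 class with both the symmetric (no zero points) and asymmetric constructors plus SetRepeats/Profile/Run. As a sketch, `REGISTER_OP(MatrixFloatInt4, half)` expands to roughly the following; it is not standalone and relies on the kernel explorer module object `m`:

```cpp
// Approximate expansion of REGISTER_OP(MatrixFloatInt4, half) inside KE_REGISTER(m).
py::class_<MatrixFloatInt4<half>>(m, "MatrixFloatInt4_half")
    .def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int, int>())                // output, a, b, scales, m, n, k
    .def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int, int>())  // ... plus zero points
    .def("SetRepeats", &MatrixFloatInt4<half>::SetRepeats)
    .def("Profile", &MatrixFloatInt4<half>::Profile)
    .def("Run", &MatrixFloatInt4<half>::Run);
```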
17 changes: 12 additions & 5 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -319,7 +319,7 @@ TEST(MatMulNBits, Float32) {
}
}

#if defined(USE_CUDA) || defined(USE_DML)
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)

namespace {
// Legacy test function.
@@ -343,24 +343,31 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura
opts.output_abs_error = fp16_abs_error;
}

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if (use_float16) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
#endif
#ifdef USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
#ifdef USE_DML
execution_providers.push_back(DefaultDmlExecutionProvider());
#endif

RunTest<MLFloat16>(opts, std::move(execution_providers));
} else {
RunTest<float>(opts);
#ifdef USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif

RunTest<float>(opts, std::move(execution_providers));
}
}
} // namespace

TEST(MatMulNBits, Float16) {
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
auto has_gidx_options = {true, false};
#else
auto has_gidx_options = {false};
@@ -404,7 +411,7 @@ TEST(MatMulNBits, Float16Large) {
}
}

#endif // defined(USE_CUDA) || defined(USE_DML)
#endif // defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)

#if defined(ORT_NEURAL_SPEED)
namespace {
