
Commit

remove scalar mul helpers
jambayk committed Nov 17, 2023
1 parent ed7c4fb commit 21ddce3
Showing 3 changed files with 7 additions and 70 deletions.
@@ -3,7 +3,7 @@

#include <cub/cub.cuh>
#include <cuda_fp16.h>
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h"
#include "dequantize_blockwise_bnb4.cuh"

@@ -71,8 +71,8 @@ __global__ void kDequantizeBlockwise(

#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++) {
- vals[j * 2] = ScalarMul(quant_map[qvals[j] >> 4], local_abs_max);
- vals[j * 2 + 1] = ScalarMul(quant_map[qvals[j] & 0x0F], local_abs_max);
+ vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max;
+ vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max;
}

__syncthreads();
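The change above drops the ScalarMul indirection: each packed byte is split into its high and low 4-bit codes, looked up in the 16-entry quant map, and scaled by the block's absmax with a plain multiply. Below is a minimal sketch of that unpacking step, assuming a device operator* is in scope for T (presumably supplied for half and BFloat16 by the cu_inc/common.cuh include swapped in above); the helper name DequantizePair and the framing are illustrative, not code from this commit.

// Illustrative sketch, not the actual kernel: dequantize one packed byte
// into two values using the 16-entry quant map and the per-block absmax.
template <class T>
__device__ inline void DequantizePair(unsigned char packed,
                                      const T* quant_map,  // 16-entry code-to-value table
                                      T local_abs_max,     // per-block scale (absmax)
                                      T* out) {            // writes out[0] and out[1]
  out[0] = quant_map[packed >> 4] * local_abs_max;    // high nibble
  out[1] = quant_map[packed & 0x0F] * local_abs_max;  // low nibble
}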
@@ -11,38 +11,6 @@ namespace cuda {
template <class T>
Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream);

- // templated scalar multiply function
- template <class T>
- __device__ inline T ScalarMul(T a, T b);
-
- template <>
- __device__ inline float ScalarMul(float a, float b) {
-   return a * b;
- }
-
- template <>
- __device__ inline half ScalarMul(half a, half b) {
- #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
-   return a * b;
- #else
-   // half multiplication not supported
-   return static_cast<half>(static_cast<float>(a) * static_cast<float>(b));
- #endif
- }
-
- template <>
- __device__ inline BFloat16 ScalarMul(BFloat16 a, BFloat16 b) {
-   return a * b;
- }
-
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- // will use the native bfloat16 multiply instruction on sm_80+
- template <>
- __device__ inline nv_bfloat16 ScalarMul(nv_bfloat16 a, nv_bfloat16 b) {
-   return a * b;
- }
- #endif

template <class T>
Status DequantizeBnb4(
const T* quant_map,
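With these specializations gone, call sites depend on operator* being defined for half and BFloat16 on every targeted architecture; the guards the helpers used to carry (sm_53 for native fp16 multiply, sm_80 for native bfloat16) now have to live wherever those operators are defined, presumably in the cu_inc/common.cuh header added above. A hypothetical sketch of the resulting generic pattern; the name ScaleByAbsMax is illustrative and not part of this commit:

// Illustrative sketch: once operator* exists for every element type in scope,
// a single generic helper covers float, half, and BFloat16 with no
// per-type specializations like the ones removed above.
template <class T>
__device__ inline T ScaleByAbsMax(T value, T abs_max) {
  return value * abs_max;  // resolves to the native or emulated multiply for T
}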
39 changes: 4 additions & 35 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu
@@ -6,44 +6,13 @@
#include <cub/cub.cuh>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "matmul_bnb4.cuh"

namespace onnxruntime {
namespace contrib {
namespace cuda {

- template <class T>
- __device__ inline float ScalarMulFloatOut(T a, T b);
-
- template <>
- __device__ inline float ScalarMulFloatOut(float a, float b) {
-   return a * b;
- }
-
- template <>
- __device__ inline float ScalarMulFloatOut(half a, half b) {
- #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
-   return static_cast<float>(a * b);
- #else
-   // half multiplication not supported
-   return static_cast<float>(a) * static_cast<float>(b);
- #endif
- }
-
- template <>
- __device__ inline float ScalarMulFloatOut(BFloat16 a, BFloat16 b) {
-   return a * b;
- }
-
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- // will use the native bfloat16 multiply instruction on sm_80+
- template <>
- __device__ inline float ScalarMulFloatOut(nv_bfloat16 a, nv_bfloat16 b) {
-   return static_cast<float>(a * b);
- }
- #endif

#define num_values_4bit 32
template <class T, int THREADS, int BITS>
__global__ void kgemm_4bit_inference_naive(
@@ -110,8 +79,8 @@ __global__ void kgemm_4bit_inference_naive(
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int k = 0; k < num_values_8bit / 4; k++) {
- local_B[k * 2] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4], local_absmax);
- local_B[k * 2 + 1] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F], local_absmax);
+ local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax;
+ local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax;
}

if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) {
@@ -138,7 +107,7 @@ __global__ void kgemm_4bit_inference_naive(
// accumulate in float; small performance hit for Ampere, but lower error for outputs
#pragma unroll
for (int k = 0; k < num_values_4bit / 4; k++) {
- local_C += ScalarMulFloatOut(local_A[k], local_B[k]);
+ local_C += static_cast<float>(local_A[k] * local_B[k]);
}
}
}
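The final hunk keeps the mixed-precision accumulation pattern: the product is formed in the element type (half or BFloat16) and only the running sum is widened to float, which, per the comment in the kernel, costs a little on Ampere but lowers output error. A small illustrative sketch of that pattern for half inputs; the function name and framing are not from this commit.

#include <cuda_fp16.h>

// Illustrative sketch: multiply in half, accumulate the partial sums in float,
// mirroring `local_C += static_cast<float>(local_A[k] * local_B[k])` above.
// Assumes a device operator* for half is available (sm_53+ or an overload).
__device__ float DotAccumulateFloat(const half* a, const half* b, int n) {
  float acc = 0.0f;  // float accumulator avoids compounding fp16 rounding error
  for (int k = 0; k < n; k++) {
    acc += static_cast<float>(a[k] * b[k]);  // product in half, widened for the sum
  }
  return acc;
}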
