Skip to content

Commit

Permalink
c++ linter
Browse files Browse the repository at this point in the history
  • Loading branch information
jambayk committed Oct 23, 2023
1 parent cca2da8 commit 7f1d345
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,67 +44,79 @@ FORCEINLINE uint8_t QuantizeOneFP4(float x) {

int sign = x < 0 ? 0b1000 : 0b0000;
x = fabsf(x);
if (x > 0.29166667f)
if (x > 0.583333f)
if (x > 0.8333333f)
if (x > 0.29166667f) {
if (x > 0.583333f) {
if (x > 0.8333333f) {
return 0b0011 + sign;
else
} else {
return 0b0010 + sign;
else if (x > 0.4166667f)
}
} else if (x > 0.4166667f) {
return 0b101 + sign;
else
} else {
return 0b100 + sign;
else if (x > 0.0859375f)
if (x > 0.20833333f)
}
} else if (x > 0.0859375f) {
if (x > 0.20833333f) {
return 0b0111 + sign;
else
} else {
return 0b0110 + sign;
else if (x > 0.00260417f)
}
} else if (x > 0.00260417f) {
return 0b0001 + sign;
else
} else {
return 0b0000 + sign;
}
}

// Maps a float to a 4-bit NF4 code by walking a fixed binary decision tree of
// thresholds (exactly four comparisons per call).
//
// The returned code grows monotonically with x: 0b0000 is produced for the
// smallest inputs and 0b1111 for the largest. Every test is a strict `>`, so
// an input exactly equal to a threshold lands in the lower bucket, and NaN
// (for which every comparison is false) yields 0b0000.
//
// The trailing `// 1`, `// 11`, ... comments give the bit prefix of the result
// that has already been decided when that comparison is evaluated.
//
// NOTE(review): the thresholds look like midpoints between adjacent NF4
// quantization levels for inputs pre-scaled to [-1, 1] -- confirm against the
// NF4 quant map used on the dequantization side.
FORCEINLINE uint8_t QuantizeOneNF4(float x) {
  if (x > 0.03979014977812767f) {
    if (x > 0.3893125355243683f) {      // 1
      if (x > 0.6427869200706482f) {    // 11
        if (x > 0.8614784181118011f) {  // 111
          return 0b1111;
        } else {
          return 0b1110;
        }
      } else if (x > 0.5016634166240692f) {  // 110
        return 0b1101;
      } else {
        return 0b1100;
      }
    } else if (x > 0.2035212516784668f) {  // 10
      if (x > 0.2920137718319893f) {       // 101
        return 0b1011;
      } else {
        return 0b1010;
      }
    } else if (x > 0.1202552504837513f) {  // 100
      return 0b1001;
    } else {
      return 0b1000;
    }
  } else if (x > -0.33967943489551544f) {  // 0
    if (x > -0.13791173323988914f) {       // 01
      if (x > -0.045525018125772476f) {    // 011
        return 0b0111;
      } else {
        return 0b0110;
      }
    } else if (x > -0.23460740596055984f) {  // 010
      return 0b0101;
    } else {
      return 0b0100;
    }
  } else if (x > -0.6106329262256622f) {  // 00
    if (x > -0.4599952697753906f) {       // 001
      return 0b0011;
    } else {
      return 0b0010;
    }
  } else if (x > -0.8480964004993439f) {  // 000
    return 0b0001;
  } else {
    return 0b0000;
  }
}

template <int32_t DATA_TYPE>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,17 @@ namespace contrib {
namespace cuda {

template<class T>
Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream)
{
Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream) {
ORT_ENFORCE(quant_type == FP4 || quant_type == NF4, "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");

T host_quant_map[16];
switch (quant_type) {
case FP4:
for(int i = 0; i < 16; i++)
for (int i = 0; i < 16; i++)
host_quant_map[i] = static_cast<T>(fp4_qaunt_map[i]);
break;
case NF4:
for(int i = 0; i < 16; i++)
for (int i = 0; i < 16; i++)
host_quant_map[i] = static_cast<T>(nf4_qaunt_map[i]);
break;
}
Expand All @@ -38,25 +37,29 @@ template Status SetBnbQuantMap<half>(int quant_type, half* quant_map_buffer, cud


template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH>
__global__ void kDequantizeBlockwise(const T *quant_map, T *output, const unsigned char *quant_data, const T *absmax, const int block_size, const int n)
{
__global__ void kDequantizeBlockwise(
const T *quant_map,
T *output,
const uint8_t *quant_data,
const T *absmax,
const int block_size,
const int n) {
const int n_load = (gridDim.x * TILE_SIZE);
int valid_items_load = 0;
int valid_items_store = 0;
const int base_idx = (blockIdx.x * TILE_SIZE);

T vals[NUM_PER_TH*2];
unsigned char qvals[NUM_PER_TH];
uint8_t qvals[NUM_PER_TH];
T local_abs_max = T(0.0f);

typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef cub::BlockLoad<uint8_t, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef cub::BlockStore<T, THREADS, NUM_PER_TH*2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;

__shared__ typename LoadChar::TempStorage loadchar;
__shared__ typename StoreT::TempStorage storet;

for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE)
{
for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) {
valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i;
valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2;

Expand All @@ -66,8 +69,7 @@ __global__ void kDequantizeBlockwise(const T *quant_map, T *output, const unsign
LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128);

#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
{
for (int j = 0; j < NUM_PER_TH; j++) {
vals[j*2] = quant_map[qvals[j] >> 4] * local_abs_max;
vals[j*2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max;
}
Expand All @@ -79,17 +81,43 @@ __global__ void kDequantizeBlockwise(const T *quant_map, T *output, const unsign


// Enqueues kDequantizeBlockwise on `stream` to expand packed 4-bit codes into
// values of type T.
//
// Parameters:
//   quant_map   lookup table indexed by a 4-bit code; forwarded to the kernel.
//   output      destination buffer for the dequantized values.
//   quant_data  packed input; the kernel reads two 4-bit codes from each byte
//               (high nibble first).
//   absmax      scale values forwarded to the kernel -- presumably one per
//               quantization block; confirm against the kernel body.
//   block_size  quantization block size; the kernel receives block_size / 2.
//   numel       number of elements to dequantize. A non-positive value is a
//               no-op.
//   stream      CUDA stream the kernel is launched on.
//
// Returns Status::OK(). The launch is asynchronous and its result is not
// inspected here.
template <class T>
Status DequantizeBnb4(
    const T* quant_map,
    T* output,
    const uint8_t* quant_data,
    const T* absmax,
    int block_size,
    int numel,
    cudaStream_t stream) {
  // Kernel tiling: one thread block consumes kTileBytes packed bytes, which is
  // 2 * kTileBytes output elements because every byte carries two codes.
  constexpr int kTileBytes = 512;
  constexpr int kThreads = 64;
  constexpr int kItemsPerThread = 8;
  constexpr int kElementsPerTile = 2 * kTileBytes;
  static_assert(kThreads * kItemsPerThread == kTileBytes,
                "threads * items-per-thread must cover one tile of packed bytes");

  // A grid with zero blocks is an invalid launch configuration, so there is
  // nothing to enqueue for an empty tensor.
  if (numel <= 0) {
    return Status::OK();
  }

  // Ceil-div written so that it cannot overflow int for numel close to INT_MAX
  // (the usual `numel + tile - 1` form can).
  const int num_tiles = numel / kElementsPerTile + (numel % kElementsPerTile != 0 ? 1 : 0);

  kDequantizeBlockwise<T, kTileBytes, kThreads, kItemsPerThread><<<num_tiles, kThreads, 0, stream>>>(
      quant_map,
      output,
      quant_data,
      absmax,
      block_size / 2,
      numel);

  return Status::OK();
}

// Explicit instantiations of DequantizeBnb4 for the two supported element
// types. Each specialization is instantiated exactly once: spelling the packed
// input as `unsigned char` in one instantiation and `uint8_t` in another names
// the same specialization on platforms where uint8_t aliases unsigned char,
// which makes the second one a duplicate explicit instantiation.
template Status DequantizeBnb4<float>(
    const float* quant_map,
    float* output,
    const uint8_t* quant_data,
    const float* absmax,
    int block_size,
    int numel,
    cudaStream_t stream);

template Status DequantizeBnb4<half>(
    const half* quant_map,
    half* output,
    const uint8_t* quant_data,
    const half* absmax,
    int block_size,
    int numel,
    cudaStream_t stream);

} // namespace cuda
} // namespace contrib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ template <class T>
Status DequantizeBnb4(
const T* quant_map,
T* output,
const unsigned char* quant_data,
const uint8_t* quant_data,
const T* absmax,
int block_size,
int numel,
Expand Down
Loading

0 comments on commit 7f1d345

Please sign in to comment.