Skip to content

Commit

Permalink
don't use __shared__ quant_map
Browse files Browse the repository at this point in the history
  • Loading branch information
jambayk committed Nov 21, 2023
1 parent ac8598a commit 18c476b
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ __global__ void kgemm_4bit_inference_naive(
const T* __restrict__ A,
const uint8_t* B,
const T* absmax,
-const T* datatype,
+const T* quant_map,
T* out,
int lda,
int ldb,
Expand All @@ -75,12 +75,8 @@ __global__ void kgemm_4bit_inference_naive(
uint8_t local_B_4bit[num_values_8bit];
T local_B[num_values_4bit / 4];
T local_A[num_values_4bit / 4];
-__shared__ T quant_map[16];
T local_absmax = T(0.0f);

-for (int i = threadIdx.x; i < 16; i++) quant_map[i] = T(datatype[i]);
-__syncthreads();

// A: [1, K]
// B: [N, K]
for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) {
Expand Down

0 comments on commit 18c476b

Please sign in to comment.