diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu
index 098e3618beddd..226f5bf04482c 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu
+++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu
@@ -53,7 +53,7 @@ __global__ void kgemm_4bit_inference_naive(
     const T* __restrict__ A,
     const uint8_t* B,
     const T* absmax,
-    const T* datatype,
+    const T* quant_map,
     T* out,
     int lda,
     int ldb,
@@ -75,12 +75,8 @@ __global__ void kgemm_4bit_inference_naive(
   uint8_t local_B_4bit[num_values_8bit];
   T local_B[num_values_4bit / 4];
   T local_A[num_values_4bit / 4];
-  __shared__ T quant_map[16];
   T local_absmax = T(0.0f);
 
-  for (int i = threadIdx.x; i < 16; i++) quant_map[i] = T(datatype[i]);
-  __syncthreads();
-
   // A: [1, K]
   // B: [N, K]
   for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) {
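
The diff renames the `datatype` parameter to `quant_map` and drops the per-block shared-memory staging of the 16-entry quantization lookup table (along with its `__syncthreads()` barrier), so the kernel now reads the table directly through the `quant_map` pointer. For context, below is a minimal sketch of how each packed byte of `B` is dequantized against that table; the nibble order (high nibble first) matches the kernel's layout, while the helper name `dequant_pair` and its exact signature are hypothetical, not part of the patched file:

  // Hedged sketch: dequantize one byte of 4-bit-packed B into two values.
  // `quant_map` is the 16-entry lookup table and `scale` the per-block
  // absmax, as in kgemm_4bit_inference_naive; `dequant_pair` is illustrative.
  template <typename T>
  __device__ inline void dequant_pair(uint8_t packed, const T* quant_map,
                                      T scale, T* out2) {
    out2[0] = quant_map[packed >> 4] * scale;    // high nibble decoded first
    out2[1] = quant_map[packed & 0x0F] * scale;  // then the low nibble
  }

Reading the 16-entry table straight from global memory is cheap here since it is tiny and hot in cache, and removing the copy loop also eliminates the redundant overlapping writes that every thread with `threadIdx.x < 16` performed in the deleted code.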