fix build warning && cpp lint
wejoncy committed Feb 26, 2024
1 parent b274b4d commit 1167ad7
Showing 10 changed files with 73 additions and 52 deletions.
30 changes: 14 additions & 16 deletions docs/ContribOperators.md
@@ -2808,22 +2808,14 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
And block_size is not an arbitrary number; it must be a power of 2 and not smaller than 16 (e.g. 16, 32, 64, 128, ...).
3. Input B's scale and zero point are specified by input scales and zero_points.

Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = block_size / 8 * bits

For a block blob, it is stored in the following format:
struct Blob {
uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization
uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization
uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization
}
Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which:
- n_blocks_per_col = (K + block_size - 1) / block_size
- blob_size = block_size / 8 * bits

Input scales is stored in the same type as the original type of B (float32, float16) with shape: [N * n_blocks_per_col]
Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one uint8_t. If bits > 4, one zero point is stored with one uint8_t. Thus, its shape is (a shape-computation sketch follows this list):
- [(N * n_blocks_per_col + 1) / 2] if bits <= 4
- [N * n_blocks_per_col] if bits > 4
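The shape rules above are easy to get wrong, so here is a minimal illustrative sketch (not part of the spec; the helper name and example values are made up) that computes the expected buffer shapes for given N, K, block_size and bits:

```python
# Illustrative sketch only: expected MatMulNBits buffer shapes.
def matmul_nbits_shapes(N: int, K: int, block_size: int, bits: int):
    assert block_size >= 16 and (block_size & (block_size - 1)) == 0, "power of 2, >= 16"
    n_blocks_per_col = (K + block_size - 1) // block_size
    blob_size = block_size // 8 * bits
    b_shape = (N, n_blocks_per_col, blob_size)                  # packed weight, uint8
    scales_shape = (N * n_blocks_per_col,)                      # same dtype as A
    if bits <= 4:
        zero_points_shape = ((N * n_blocks_per_col + 1) // 2,)  # two zero points per uint8
    else:
        zero_points_shape = (N * n_blocks_per_col,)             # one zero point per uint8
    return b_shape, scales_shape, zero_points_shape

# Example: N=4096, K=4096, block_size=32, bits=4
# -> B (4096, 128, 16), scales (524288,), zero_points (262144,)
print(matmul_nbits_shapes(4096, 4096, 32, 4))
```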


#### Version

@@ -2844,17 +2836,19 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>group size used for weight quantization (default 128). It needs to be a power of 2 and not smaller than 16.</dd>
</dl>

#### Inputs (3 - 4)
#### Inputs (3 - 5)

<dl>
<dt><tt>A</tt> : T1</dt>
<dd>The input tensor, not quantized</dd>
<dt><tt>B</tt> : T2</dt>
<dd>1-dimensional data blob</dd>
<dd>1 or 2 dimensional data blob</dd>
<dt><tt>scales</tt> : T1</dt>
<dd>quantization scale</dd>
<dt><tt>zero_points</tt> (optional) : T2</dt>
<dt><tt>zero_points</tt> (optional) : T3</dt>
<dd>quantization zero points</dd>
<dt><tt>g_idx</tt> (optional) : T4</dt>
<dd>group_idx</dd>
</dl>

#### Outputs
@@ -2869,8 +2863,12 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dl>
<dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float/half_float tensors.</dd>
<dt><tt>T2</tt> : tensor(uint8)</dt>
<dd>Constrain quantized weight types to uint8.</dd>
<dt><tt>T2</tt> : tensor(uint8), tensor(int32)</dt>
<dd>Constrain quantized weight types to uint8/int32.</dd>
<dt><tt>T3</tt> : tensor(uint8), tensor(int32), tensor(float16), tensor(float)</dt>
<dd>Constrain quantized zero point types to uint8/int32/float16/float.</dd>
<dt><tt>T4</tt> : tensor(int32)</dt>
<dd>the index tensor.</dd>
</dl>
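As a rough usage sketch (assuming the standard onnx.helper API; tensor names and attribute values are placeholders, not taken from this commit), a MatMulNBits node wired with the optional zero_points and g_idx inputs could be built like this:

```python
# Sketch: a com.microsoft MatMulNBits node with the optional g_idx input.
from onnx import helper

matmul_nbits = helper.make_node(
    "MatMulNBits",
    inputs=["A", "B_quant", "scales", "zero_points", "g_idx"],  # last two are optional
    outputs=["Y"],
    name="MatMulNBits_0",
    domain="com.microsoft",
    K=4096,          # input feature size (inner dimension of A x B)
    N=4096,          # output feature size
    bits=4,          # quantization bit width
    block_size=32,   # quantization group size: power of 2, >= 16
)
```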


4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -470,7 +470,7 @@ Do not modify directly.*
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
@@ -854,7 +854,7 @@ Do not modify directly.*
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
8 changes: 5 additions & 3 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -12,7 +12,7 @@
#include "core/mlas/inc/mlas_q4.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"
#include "matmul_nbits_impl.h"
#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"

#ifdef ORT_NEURAL_SPEED
#include "contrib_ops/cpu/quantization/neural_speed_gemm.h"
@@ -288,8 +288,10 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);

const bool has_single_b_matrix = (reorder_idx_data == nullptr) &&
(!zero_points || !zero_points->IsDataType<float>()) && std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; });
const bool has_single_b_matrix =
(reorder_idx_data == nullptr) &&
(!zero_points || !zero_points->IsDataType<float>()) &&
std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; });

Check warning on line 294 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc (GitHub Actions / Lint C++, cpplint via reviewdog): Lines should be <= 120 characters long [whitespace/line_length] [2]

if (has_single_b_matrix) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h"
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -54,7 +55,7 @@ void Dequantize4BitsKernelReOrder(

if constexpr (std::is_same_v<T, MLFloat16>) {
T zp_adjust = -scale * MLFloat16(zp_f);
output_i[i] = float((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust;
output_i[i] = static_cast<float>((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust;
} else {
T zp_adjust = -scale * zp_f;
output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust;
Expand Down Expand Up @@ -86,7 +87,7 @@ void DequantizeBlockwise(
for (int j = 0; j < 256; j++) {
Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points,
reorder_idx, block_size, groups_per_threadblock,
total_groups, N, K, block_id, j);
total_groups, N, K, static_cast<int>(block_id), j);
}
});
}
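For readers skimming the C++ above, here is a rough NumPy reference of what blockwise 4-bit dequantization with an optional reorder index (g_idx) computes. It is a simplified sketch, not the actual kernel: the low-nibble-first packing, the default zero point of 8, and the per-row zero-point indexing mirror the kernels shown in this commit, while the argument layout and output orientation are assumptions made for illustration.

```python
import numpy as np

def dequantize_blockwise_4bit(quant, scales, zero_points, g_idx, N, K, block_size):
    """Simplified reference. quant: uint8 [N, n_blocks, block_size // 2];
    scales: [N * n_blocks]; zero_points: packed uint8 (two 4-bit values per byte) or None;
    g_idx: optional int32 array mapping a column k to its quantization group."""
    n_blocks = (K + block_size - 1) // block_size
    zp_cols = (n_blocks + 1) // 2             # mirrors zero_point_shape_x in the kernels
    out = np.zeros((N, K), dtype=np.float32)
    for n in range(N):
        for k in range(K):
            group = int(g_idx[k]) if g_idx is not None else k // block_size
            scale = float(scales[n * n_blocks + group])
            if zero_points is not None:
                zp_byte = int(zero_points[n * zp_cols + group // 2])
                zp = (zp_byte >> 4) if group % 2 else (zp_byte & 0x0F)
            else:
                zp = 8                        # default zero point for 4-bit
            byte = int(quant[n, k // block_size, (k % block_size) // 2])
            q = (byte >> 4) if k % 2 else (byte & 0x0F)
            out[n, k] = (q - zp) * scale
    return out
```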
20 changes: 10 additions & 10 deletions onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu
@@ -72,9 +72,9 @@ __global__ void Dequantize4BitsKernelReOrder(
if (group_id >= total_groups) {
return;
}
//T __shared__ zero_points_after_reorder[];//K
//T __shared__ scales_after_reorder[]; // K
//const int num_r_per_thread = k / 256;
// T __shared__ zero_points_after_reorder[];//K
// T __shared__ scales_after_reorder[]; // K
// const int num_r_per_thread = k / 256;

const int zero_point_shape_x = (groups_per_K + 1) / 2;
const int scales_shape_x = groups_per_K;
@@ -102,7 +102,7 @@
}
}

template <class T, typename ZeroT=uint8_t>
template <class T, typename ZeroT = uint8_t>
__global__ void Dequantize4BitsKernel(
T* output,
const uint8_t* quant_data,
@@ -116,15 +116,15 @@ __global__ void Dequantize4BitsKernel(
if (block_id >= total_groups) {
return;
}
const int zero_point_shape_x = (groups_per_K + 1) / 2;
const int scales_shape_x = groups_per_K;
int n_idx = block_id / scales_shape_x;
int kb_idx = block_id % scales_shape_x;
int element_offset = block_id * block_size + ((threadIdx.x * 8) & (block_size - 1));
uint32_t quant_value = *(reinterpret_cast<const uint32_t*>(quant_data + element_offset / 2));
T scale = *(scale_data + block_id);
T zero_point_value;
if constexpr(std::is_same_v<ZeroT, uint8_t>) {
if constexpr (std::is_same_v<ZeroT, uint8_t>) {
const int scales_shape_x = groups_per_K;
const int zero_point_shape_x = (groups_per_K + 1) / 2;
int kb_idx = block_id % scales_shape_x;
int n_idx = block_id / scales_shape_x;
uint8_t zp = 8;
if (zero_points) {
zp = zero_points[n_idx * zero_point_shape_x + kb_idx / 2];
@@ -168,7 +168,7 @@ Status Dequantize4Bits(
groups_per_threadblock,
total_groups);
} else {
//static_assert(std::is_same_v<ZeroT, uint8_t>, "ZeroT must be uint8_t");
// static_assert(std::is_same_v<ZeroT, uint8_t>, "ZeroT must be uint8_t");
Dequantize4BitsKernelReOrder<<<groups_per_grid, GridDim::maxThreadsPerBlock, 0, stream>>>(
output,
quant_data,
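To make the index arithmetic in Dequantize4BitsKernel easier to follow, here is a small host-side sketch of how a group id maps to its (n, k-block) coordinates and where one thread's 8-element chunk starts; the function name and example values are illustrative, not part of the CUDA code.

```python
# Sketch of the index math used by Dequantize4BitsKernel (illustrative only).
def locate(block_id: int, thread_idx: int, groups_per_K: int, block_size: int):
    n_idx = block_id // groups_per_K      # output row (N dimension)
    kb_idx = block_id % groups_per_K      # quantization block along K
    # each thread handles 8 values, i.e. one uint32 of packed 4-bit data
    element_offset = block_id * block_size + ((thread_idx * 8) & (block_size - 1))
    byte_offset = element_offset // 2     # two 4-bit values per byte
    return n_idx, kb_idx, element_offset, byte_offset

# block_size=32, groups_per_K=128: group 130 is row n=1, k-block 2
print(locate(130, 3, groups_per_K=128, block_size=32))
```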
7 changes: 5 additions & 2 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
@@ -7,7 +7,7 @@
// pre-packed and block-compacted into int4
//

#include "matmul_nbits.h"
#include "contrib_ops/cuda/quantization/matmul_nbits.h"
#include <cstdint>
#include "core/common/status.h"
#include "core/framework/float16.h"
@@ -67,6 +67,9 @@ Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
IAllocatorUniquePtr<T> b_data_ptr = GetScratchBuffer<T>(N_ * K_padded, ctx->GetComputeStream());
auto* b_data = b_data_ptr.get();
if (column_wise_quant_blk_) {
if (reorder_idx) {
ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]");
}
// column-wise block
if ((zero_points && zero_points->IsDataType<T>())) {
ORT_RETURN_IF_ERROR(Dequantize4Bits(
@@ -159,7 +162,7 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", DataTypeImpl::GetTensorType<MLFloat16>())
.TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(), DataTypeImpl::GetTensorType<int32_t>()}),
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
MatMulNBits<MLFloat16>);

} // namespace cuda
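The new ORT_ENFORCE above requires g_idx to have exactly K_padded entries (K rounded up to a multiple of block_size), matching what the updated unit test later in this commit builds. A minimal sketch of a trivial, non-reordering g_idx of that length:

```python
# Sketch: trivial g_idx of length K_padded; entry k is the group that column k falls into.
def make_trivial_g_idx(K: int, block_size: int):
    k_padded = (K + block_size - 1) // block_size * block_size
    return [k // block_size for k in range(k_padded)]

print(len(make_trivial_g_idx(K=100, block_size=32)))  # 128 == K_padded
```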
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3375,7 +3375,7 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
.Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional)
.Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1")
.TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.")
.TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/uint32/int32/float16.")
.TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.")
.TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.")

Check warning on line 3379 in onnxruntime/core/graph/contrib_ops/contrib_defs.cc (GitHub Actions / Lint C++, cpplint via reviewdog): Lines should be <= 120 characters long [whitespace/line_length] [2]
.TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
26 changes: 18 additions & 8 deletions onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -15,7 +15,6 @@
import numpy as np
import numpy.typing as npt
import onnx
import torch
from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto
from packaging import version

@@ -155,14 +154,16 @@ def __init__(
# Proximal solver || weight - dequantize(quantize(weight))||_p^p
@staticmethod
def optimize_weights(
tensor: torch.Tensor,
scale: torch.Tensor,
zero: torch.Tensor,
tensor,
scale,
zero,
min_max: list[int],
axis: int = 0,
opt_params: dict = None, # noqa: RUF013
verbose=False,
):
import torch

opt_params = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} if opt_params is None else opt_params
lp_norm, beta, kappa, iters = (
opt_params["lp_norm"],
@@ -214,7 +215,7 @@ def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits):
ori_int_tensor = ori_int_tensor.T
pack_tensor = pack_tensor.T
if bits in [2, 4, 8]:
compress_ratio = pack_tensor.dtype.itemsize * 8 // bits
compress_ratio = pack_tensor.element_size() * 8 // bits
for j in range(0, compress_ratio):
pack_tensor[0:] |= ori_int_tensor[j::compress_ratio] << (bits * (j))
else:
@@ -224,6 +225,8 @@ def quantize_internal(
def quantize_internal(
self, tensor, bits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True, axis=1
):
import torch

weight = tensor.float()
ori_shape = weight.shape

@@ -288,6 +291,7 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]):
"""If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
if node.op_type != "MatMul":
return node # only care about MatMul for now
import torch

logger.info(f"start to quantize {node.name} ...")
inputB = node.input[1] # noqa: N806
@@ -458,6 +462,8 @@ class MatMul4BitsQuantizer:
def __init__(
self,
model: ModelProto | str,
block_size: int = 128,
is_symmetric: bool = False,
accuracy_level: int | None = None,
nodes_to_exclude=None,
algo_config: WeightOnlyQuantConfig = None,
@@ -466,11 +472,15 @@ def __init__(
nodes_to_exclude = []
self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model)
self.model_path = model if isinstance(model, str) else None
self.block_size = block_size
self.is_symmetric = is_symmetric
self.accuracy_level = accuracy_level
self.nodes_to_exclude = set(nodes_to_exclude)
self.node_quantizer = None
if algo_config is None:
algo_config = DefaultWeightOnlyQuantConfig(block_size=32, is_symmetric=False, accuracy_level=accuracy_level)
algo_config = DefaultWeightOnlyQuantConfig(
block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level
)
self.algo_config = algo_config
if algo_config.algorithm == "HQQ":
self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config)
Expand Down Expand Up @@ -527,8 +537,8 @@ def _generate_q4_node_config(self):
q4_node_config = {}
template_config_q4 = {
"bits": 4,
"group_size": self.algo_config.block_size,
"scheme": "sym" if self.algo_config.is_symmetric else "asym",
"group_size": self.block_size,
"scheme": "sym" if self.is_symmetric else "asym",
}
for node in self.model.model.graph.node:
if node.op_type in ["MatMul"]:
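With block_size and is_symmetric now accepted by the constructor and forwarded into DefaultWeightOnlyQuantConfig, a typical invocation of the tool might look like the sketch below; the file paths are placeholders and the save call assumes the ONNXModel wrapper's save_model_to_file helper.

```python
# Sketch: 4-bit weight-only quantization with the updated constructor arguments.
import onnx
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

model = onnx.load("model_fp32.onnx")                  # placeholder path
quantizer = MatMul4BitsQuantizer(
    model,
    block_size=128,          # power of 2, >= 16
    is_symmetric=False,
    accuracy_level=None,
    nodes_to_exclude=[],
)
quantizer.process()                                   # rewrites MatMul nodes to MatMulNBits
quantizer.model.save_model_to_file("model_int4.onnx", use_external_data_format=True)
```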
21 changes: 14 additions & 7 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -68,6 +68,7 @@ void QuantizeDequantize(std::vector<float>& raw_vals,

void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level,
bool has_zeropoint, bool use_float16, bool has_g_idx = false, bool zp_is_4bit = true, float fp16_abs_error = 0.02f) {

Check warning on line 70 in onnxruntime/test/contrib_ops/matmul_4bits_test.cc (GitHub Actions / Lint C++, cpplint via reviewdog): Lines should be <= 120 characters long [whitespace/line_length] [2]
std::cerr << M << " " << N << " " << K << " " << block_size << " " << has_zeropoint << " " << use_float16 << " " << has_g_idx << " " << zp_is_4bit << " " << std::endl;

Check warning on line 71 in onnxruntime/test/contrib_ops/matmul_4bits_test.cc (GitHub Actions / Lint C++, cpplint via reviewdog): Lines should be <= 120 characters long [whitespace/line_length] [2]
zp_is_4bit = zp_is_4bit | has_g_idx;
RandomValueGenerator random{1234};
std::vector<float> input0_vals(random.Gaussian<float>(std::vector<int64_t>({M, K}), 0.0f, 0.25f));
Expand Down Expand Up @@ -115,6 +116,8 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura
test.AddAttribute<int64_t>("block_size", block_size);
test.AddAttribute<int64_t>("bits", QBits);
test.AddAttribute<int64_t>("accuracy_level", accuracy_level);
auto ceildiv = [](int64_t a, int64_t b) { return (a + b - 1) / b; };

if (use_float16) {
test.AddInput<MLFloat16>("A", {M, K}, ToFloat16(input0_vals), false);
test.AddInput<uint8_t>("B", {q_cols, q_rows}, input1_vals, true);
Expand All @@ -137,15 +140,16 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura

test.AddInput<MLFloat16>("zero_points", {static_cast<int64_t>(q_scale_size)}, ToFloat16(zp_f), true);
}
} else {
} else if(has_g_idx) {

Check warning on line 143 in onnxruntime/test/contrib_ops/matmul_4bits_test.cc (GitHub Actions / Lint C++, cpplint via reviewdog): Missing space before ( in if( [whitespace/parens] [5]
test.AddInput<uint8_t>("", {0}, {});
}
if (has_g_idx) {
std::vector<int32_t> g_idx(K);
for (int64_t i = 0; i < K; i++) {
int K_pad = gsl::narrow<int32_t>(ceildiv(K, block_size) * block_size);
std::vector<int32_t> g_idx(K_pad);
for (int64_t i = 0; i < K_pad; i++) {
g_idx[i] = gsl::narrow<int32_t>(i / block_size);
}
test.AddInput<int32_t>("g_idx", {static_cast<int64_t>(K)}, g_idx, true);
test.AddInput<int32_t>("g_idx", {static_cast<int64_t>(K_pad)}, g_idx, true);
}

test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(expected_vals));
@@ -180,11 +184,12 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level,
test.AddInput<uint8_t>("", {0}, {});
}
if (has_g_idx) {
std::vector<int32_t> g_idx(K);
for (int64_t i = 0; i < K; i++) {
int K_pad = gsl::narrow<int32_t>(ceildiv(K, block_size) * block_size);
std::vector<int32_t> g_idx(K_pad);
for (int64_t i = 0; i < K_pad; i++) {
g_idx[i] = gsl::narrow<int32_t>(i / block_size);
}
test.AddInput<int32_t>("g_idx", {static_cast<int64_t>(K)}, g_idx, true);
test.AddInput<int32_t>("g_idx", {static_cast<int64_t>(K_pad)}, g_idx, true);
}
test.AddOutput<float>("Y", {M, N}, expected_vals);
if (accuracy_level == 4) {
@@ -209,7 +214,9 @@ TEST(MatMulNBits, Float32) {
for (auto accuracy_level : {0}) {
RunTest(M, N, K, block_size, accuracy_level, false, false);
RunTest(M, N, K, block_size, accuracy_level, true, false);
#if !defined(DISABLE_OPTIONAL_TYPE)
RunTest(M, N, K, block_size, accuracy_level, false, false, true);
#endif
RunTest(M, N, K, block_size, accuracy_level, true, false, false, false);
}
#endif
