From 7930227c785b44d29fa68bc8897f7ea3e96e2848 Mon Sep 17 00:00:00 2001 From: Your Date: Mon, 8 Jan 2024 21:41:01 +0000 Subject: [PATCH] lint --- .../cuda/cutlass_extensions/ft_gemm_configs.h | 1 - .../gemm/kernel/default_fpA_intB_traits.h | 1 - .../gemm/kernel/gemm_moe_problem_visitor.h | 1 - .../gemm/kernel/moe_cutlass_kernel.h | 1 - onnxruntime/contrib_ops/cuda/moe/moe_base.h | 24 +++++++++---------- .../test_parity_mixtral_moe_int4.py | 22 +++++++++-------- 6 files changed, 24 insertions(+), 26 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/cutlass_extensions/ft_gemm_configs.h b/onnxruntime/contrib_ops/cuda/cutlass_extensions/ft_gemm_configs.h index ec7d710749960..d21c21dcd5e16 100644 --- a/onnxruntime/contrib_ops/cuda/cutlass_extensions/ft_gemm_configs.h +++ b/onnxruntime/contrib_ops/cuda/cutlass_extensions/ft_gemm_configs.h @@ -54,4 +54,3 @@ struct CutlassGemmConfig { int stages = -1; }; } // namespace ort_fastertransformer - diff --git a/onnxruntime/contrib_ops/cuda/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/onnxruntime/contrib_ops/cuda/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h index 220735e828b0d..2811504daa3a0 100644 --- a/onnxruntime/contrib_ops/cuda/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +++ b/onnxruntime/contrib_ops/cuda/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -141,4 +141,3 @@ struct MixedGemmArchTraits< } // namespace kernel } // namespace gemm } // namespace cutlass - diff --git a/onnxruntime/contrib_ops/cuda/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h index f36ed47846659..9a285b4df927c 100644 --- a/onnxruntime/contrib_ops/cuda/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h +++ b/onnxruntime/contrib_ops/cuda/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h @@ -76,4 +76,3 @@ struct GemmMoeProblemVisitor } // namespace kernel } // namespace gemm } // namespace cutlass - diff --git a/onnxruntime/contrib_ops/cuda/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h index 239337e631e85..a64daa397a742 100644 --- a/onnxruntime/contrib_ops/cuda/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h +++ b/onnxruntime/contrib_ops/cuda/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -505,4 +505,3 @@ struct MoeFCGemm { } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index 0c84731951d64..b27d0b9980172 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -48,7 +48,7 @@ class MoEBase { int64_t hidden_size = input_dims[input_dims.size() - 1]; int64_t local_num_experts = fc1_experts_weights_dims[0]; int64_t num_experts = router_probs_dims[1]; - int64_t inter_size = fc1_experts_weights_dims[2] * (is_uint8_t?2:1); + int64_t inter_size = fc1_experts_weights_dims[2] * (is_uint8_t ? 2 : 1); if (fc1_experts_weights_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ", @@ -69,17 +69,17 @@ class MoEBase { fc2_experts_weights_dims[1], " and ", inter_size); } - //if (fc1_experts_weights_dims[2] != inter_size) { - // return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - // "fc1_experts_weights_dims[2] must be equal to inter_size, got ", - // fc1_experts_weights_dims[2], - // " and ", inter_size); - //} - //if (fc2_experts_weights_dims[2] != hidden_size) { - // return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - // "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", - // fc2_experts_weights_dims[2], " and ", hidden_size); - //} + // if (fc1_experts_weights_dims[2] != inter_size) { + // return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + // "fc1_experts_weights_dims[2] must be equal to inter_size, got ", + // fc1_experts_weights_dims[2], + // " and ", inter_size); + // } + // if (fc2_experts_weights_dims[2] != hidden_size) { + // return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + // "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", + // fc2_experts_weights_dims[2], " and ", hidden_size); + // } if (router_probs_dims.size() != 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ", router_probs_dims.size()); diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe_int4.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe_int4.py index 79aa058c2eac6..8184f2e598320 100644 --- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe_int4.py +++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe_int4.py @@ -45,22 +45,24 @@ def print_tensor(name, numpy_array): print(f"const std::vector {name} = {value_string_of(numpy_array)};") -def quant_dequant(weights, quant_mode:bool=True): +def quant_dequant(weights, quant_mode: bool = True): # use the test version `_symmetric_...` to get the non-interleaved weights type = torch.quint4x2 if quant_mode else torch.int8 - quant_weights, processed_q_weight, torch_weight_scales = torch.ops.fastertransformer._symmetric_quantize_last_axis_of_batched_matrix( - weights.T.cpu().contiguous(), type) + ( + quant_weights, + processed_q_weight, + torch_weight_scales, + ) = torch.ops.fastertransformer._symmetric_quantize_last_axis_of_batched_matrix(weights.T.cpu().contiguous(), type) # Unpack the int4s int int8s if quant_mode: - upper = (quant_weights >> 4) + upper = quant_weights >> 4 lower = (quant_weights << 4) >> 4 # Arithmetic right shift sign extends dq_quant_weights = torch.stack((lower, upper), dim=2).view(weights.T.shape) dq_quant_weights = dq_quant_weights.to(dtype=weights.dtype) - result = torch.multiply(dq_quant_weights, - torch_weight_scales.unsqueeze(0)).T.contiguous() - return torch_weight_scales,processed_q_weight, result.to(device=weights.device) + result = torch.multiply(dq_quant_weights, torch_weight_scales.unsqueeze(0)).T.contiguous() + return torch_weight_scales, processed_q_weight, result.to(device=weights.device) def create_moe_onnx_graph( @@ -102,9 +104,9 @@ def create_moe_onnx_graph( ), ] - fc1_shape = [num_experts, hidden_size, inter_size//2] - fc2_shape = [num_experts, inter_size, hidden_size//2] - fc3_shape = [num_experts, hidden_size, inter_size//2] + fc1_shape = [num_experts, hidden_size, inter_size // 2] + fc2_shape = [num_experts, inter_size, hidden_size // 2] + fc3_shape = [num_experts, hidden_size, inter_size // 2] torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 np_type = numpy.uint8