Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyems committed Jan 8, 2024
1 parent f238e0a commit 7930227
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,3 @@ struct CutlassGemmConfig {
int stages = -1;
};
} // namespace ort_fastertransformer

Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,3 @@ struct MixedGemmArchTraits<
} // namespace kernel
} // namespace gemm
} // namespace cutlass

Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,3 @@ struct GemmMoeProblemVisitor
} // namespace kernel
} // namespace gemm
} // namespace cutlass

Original file line number Diff line number Diff line change
Expand Up @@ -505,4 +505,3 @@ struct MoeFCGemm {
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

24 changes: 12 additions & 12 deletions onnxruntime/contrib_ops/cuda/moe/moe_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class MoEBase {
int64_t hidden_size = input_dims[input_dims.size() - 1];
int64_t local_num_experts = fc1_experts_weights_dims[0];
int64_t num_experts = router_probs_dims[1];
int64_t inter_size = fc1_experts_weights_dims[2] * (is_uint8_t?2:1);
int64_t inter_size = fc1_experts_weights_dims[2] * (is_uint8_t ? 2 : 1);

if (fc1_experts_weights_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ",
Expand All @@ -69,17 +69,17 @@ class MoEBase {
fc2_experts_weights_dims[1],
" and ", inter_size);
}
//if (fc1_experts_weights_dims[2] != inter_size) {
// return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
// "fc1_experts_weights_dims[2] must be equal to inter_size, got ",
// fc1_experts_weights_dims[2],
// " and ", inter_size);
//}
//if (fc2_experts_weights_dims[2] != hidden_size) {
// return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
// "fc2_experts_weights_dims[2] must be equal to hidden_size, got ",
// fc2_experts_weights_dims[2], " and ", hidden_size);
//}
// if (fc1_experts_weights_dims[2] != inter_size) {
// return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
// "fc1_experts_weights_dims[2] must be equal to inter_size, got ",
// fc1_experts_weights_dims[2],
// " and ", inter_size);
// }
// if (fc2_experts_weights_dims[2] != hidden_size) {
// return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
// "fc2_experts_weights_dims[2] must be equal to hidden_size, got ",
// fc2_experts_weights_dims[2], " and ", hidden_size);
// }
if (router_probs_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ",
router_probs_dims.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,24 @@ def print_tensor(name, numpy_array):
print(f"const std::vector<float> {name} = {value_string_of(numpy_array)};")


def quant_dequant(weights, quant_mode:bool=True):
def quant_dequant(weights, quant_mode: bool = True):
# use the test version `_symmetric_...` to get the non-interleaved weights
type = torch.quint4x2 if quant_mode else torch.int8
quant_weights, processed_q_weight, torch_weight_scales = torch.ops.fastertransformer._symmetric_quantize_last_axis_of_batched_matrix(
weights.T.cpu().contiguous(), type)
(
quant_weights,
processed_q_weight,
torch_weight_scales,
) = torch.ops.fastertransformer._symmetric_quantize_last_axis_of_batched_matrix(weights.T.cpu().contiguous(), type)

# Unpack the int4s into int8s
if quant_mode:
upper = (quant_weights >> 4)
upper = quant_weights >> 4
lower = (quant_weights << 4) >> 4 # Arithmetic right shift sign extends
dq_quant_weights = torch.stack((lower, upper), dim=2).view(weights.T.shape)

dq_quant_weights = dq_quant_weights.to(dtype=weights.dtype)

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable Error test

Local variable 'dq_quant_weights' may be used before it is initialized.
result = torch.multiply(dq_quant_weights,
torch_weight_scales.unsqueeze(0)).T.contiguous()
return torch_weight_scales,processed_q_weight, result.to(device=weights.device)
result = torch.multiply(dq_quant_weights, torch_weight_scales.unsqueeze(0)).T.contiguous()
return torch_weight_scales, processed_q_weight, result.to(device=weights.device)


def create_moe_onnx_graph(
Expand Down Expand Up @@ -102,9 +104,9 @@ def create_moe_onnx_graph(
),
]

fc1_shape = [num_experts, hidden_size, inter_size//2]
fc2_shape = [num_experts, inter_size, hidden_size//2]
fc3_shape = [num_experts, hidden_size, inter_size//2]
fc1_shape = [num_experts, hidden_size, inter_size // 2]
fc2_shape = [num_experts, inter_size, hidden_size // 2]
fc3_shape = [num_experts, hidden_size, inter_size // 2]

torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32
np_type = numpy.uint8
Expand Down

0 comments on commit 7930227

Please sign in to comment.