diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py index ba238cba7508..cf087c16da8a 100644 --- a/deepspeed/module_inject/fusedqkv_utils.py +++ b/deepspeed/module_inject/fusedqkv_utils.py @@ -113,7 +113,7 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): raise ValueError("unknown fused_qkv_type") - module_name_matches = [k for k in fused_type_dict.keys() if module_str in k] + module_name_matches = [k for k in fused_type_dict.keys() if k in module_str] if module_name_matches: # There can be overlap with matches (e.g., "DecoderLayer" and "FalconDecoderLayer"). # We take the longest matching module_name