From 0406dc1c04ba971f5d3a29ad3955e44e2bcc96c9 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Mon, 11 Nov 2024 08:52:06 +0000
Subject: [PATCH] improved documentation on state dict inference

Signed-off-by: Yu Chin Fabian Lim
---
 plugins/accelerated-moe/README.md      |  3 +-
 .../framework_plugin_scattermoe.py     |  8 ++---
 .../utils/checkpoint_utils.py          |  4 +--
 .../utils/scattermoe_state_dict.py     | 34 +++++++++++++++----
 4 files changed, 35 insertions(+), 14 deletions(-)

diff --git a/plugins/accelerated-moe/README.md b/plugins/accelerated-moe/README.md
index 2bb304ec..1d6af169 100644
--- a/plugins/accelerated-moe/README.md
+++ b/plugins/accelerated-moe/README.md
@@ -85,8 +85,9 @@ pip install -r requirements-khd.txt
 ### Known Issues
 
 These are currently some known issues not yet resolved:
-- The design currently does a swap for the mixture-of-expert module with [ScatterMoE](./src/fms_acceleration_moe/utils/scattermoe.py). This affects the `state_dict` of the model, so any saved checkpoint may need to be converted back to original.
 - should eventually remove the dependency on an external `kernel-hyperdrive` repository.
 - now support only loading *sharded* `safetensor` non-GGUF MoE checkpoints. This is a reasonable assumption since MoE checkpoints are typically above the size limit that prevents it being saved into a single checkpoint filed.
+- when used together with FSDP, FSDP's `clip_grad_norm` will not be properly computed for `ScatterMoE`, see [issue here](https://github.com/foundation-model-stack/fms-acceleration/issues/109).
+
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
index c84df07b..148a5488 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
@@ -31,12 +31,12 @@
 # pylint: disable=too-many-instance-attributes
 class ScatterMoEAccelerationPlugin(AccelerationPlugin):
 
-    # NOTE: we cannot do 
+    # NOTE: we cannot do
     # - require_packages = {"khd"}
     # this is because the khd fork is not properly packaged as a PyPI project, and so
-    # - "importlib.util.find_spec('khd')" returns, but 
-    # - "importlib.metadata.version('kernel-hyperdrive')" does not return 
+    # - "importlib.util.find_spec('khd')" returns, but
+    # - "importlib.metadata.version('kernel-hyperdrive')" does not return
-    # if we decide to extract the kernels, then we do not need to anymore, 
+    # if we decide to extract the kernels, then we do not need to anymore,
     # https://github.com/foundation-model-stack/fms-acceleration/issues/105
 
     restricted_model_archs = ["GraniteMoeForCausalLM", "MixtralForCausalLM"]
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index c4298bd4..d8d33b18 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -217,7 +217,7 @@ def _dict_from_json_file(resolved_config_file):
 # - if the original pretrained_model_name_or_path is specified, will use the checkpoint as hints
 #   to map the ScatterMoE checkpoint to that of the original model. This is useful so that we
 #   can restore the checkpoint to be loaded by the original architecture.
-def get_scattermoe_state_dict(
+def recover_original_state_dict_from_dcp_checkpoint(
     dcp_checkpoint_dir: str,
     pretrained_model_name_or_path: str = None,
 ):
@@ -460,7 +460,7 @@ def _infer_prefixes_and_module_names(
     checkpoint_dir = os.path.join(args.dcp_checkpoint_dir, checkpoint_dir[0])
 
     # get the converted statedict
-    state_dict = get_scattermoe_state_dict(
+    state_dict = recover_original_state_dict_from_dcp_checkpoint(
         checkpoint_dir, args.pretrained_model_name_or_path
     )
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py
index b397c1cf..2d3ddf83 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py
@@ -24,7 +24,11 @@
 import torch
 
 # Local
-from .scattermoe_constants import DIM_EXPERT, KEY_SCATTERMOE_ROUTER
+from .scattermoe_constants import (
+    DIM_EXPERT,
+    KEY_SCATTERMOE_ROUTER,
+    PARAM_NAME_WEIGHT_SCATTERMOE,
+)
 
 # This function creates a dictionary of keys and paths into the the sharded
 # safetensors checkpoint file, that are relevant to the "prefix" and "instance_name"
@@ -84,8 +88,24 @@ def get_checkpoint_meta_from_sharded_safetensor(
     expert_map: Dict = None,  # map -> [w1,w2,w3]
 ) -> Dict[str, List[Tuple]]:
     """
-    utilty function to infer the mapping from ScatterMoe to incoming model
-    based on the weight_map.
+    utility function to infer the mapping of ScatterMoe parameters
+    from that of an incoming model, based on a weight_map from a
+    sharded safetensor checkpoint.
+
+    Parameters:
+        weight_map (dict): The weight map read in from a safetensor checkpoint.
+        prefix (str): the prefix where the MoE module lives (with respect to the orig model).
+        instance_name (str): the name of the MoE module in the orig model.
+        router_name (str): name of the router module as it is called in the MoE module
+            in the original model.
+        expert_name (str): name of the experts as they are called in the MoE module in
+            the original model. There are two patterns to use this:
+            i) specify a single string, and map the experts by weight name,
+               e.g., experts.w1 -> w1
+            ii) specify multiple strings, separated by "|", in the order of w1, w2, ...,
+               e.g., input_linear|output_linear|input_linear
+        expert_map (dict): This is used with pattern ii) described above in expert_name.
+            If not specified, it will be the identity map, e.g., w1 -> w1.
     """
 
     # insert in order
@@ -110,14 +130,14 @@ def _insert(L: List, i: int, v):
     if "|" in expert_name:
         expert_map = {}
         _names = expert_name.split("|")
-        assert len(_names) in {2, 3}, "expert name map has to be length 2/3"
+        assert len(_names) >= 2, "expert name map has to be at least length 2."
 
         for i, n in enumerate(_names):
             if n not in expert_map:
                 expert_map[n] = []
-            expert_map[n].append(f"w{i+1}")
+            expert_map[n].append(PARAM_NAME_WEIGHT_SCATTERMOE[i])
     else:
-        expert_map = {x: [x] for x in ["w1", "w2", "w3"]}
+        expert_map = {x: [x] for x in PARAM_NAME_WEIGHT_SCATTERMOE}
 
     # state dict -> weights
     # 'router.weight': [(k, file),...]
@@ -166,7 +186,7 @@ def _maybe_reshape_scattermoe_expert_weights(
     intermediate_size: int,
 ):
     (_is_w1, _is_w2, _is_w3) = [
-        f"{x}.weight" in scatter_key for x in ["w1", "w2", "w3"]
+        f"{x}.weight" in scatter_key for x in PARAM_NAME_WEIGHT_SCATTERMOE
     ]
 
     if _is_w1 or _is_w2 or _is_w3:
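
For reference, recovering an original-format state dict with the renamed helper could look roughly like the sketch below. This is a minimal sketch, not part of the patch: it assumes the plugin is importable as `fms_acceleration_moe`, and the checkpoint path, the model id used as mapping hints, and the `safetensors` save step are illustrative choices.

```python
# Minimal sketch: recover an original-format state dict from a ScatterMoE DCP
# checkpoint. The paths and model id below are hypothetical placeholders.
from fms_acceleration_moe.utils.checkpoint_utils import (
    recover_original_state_dict_from_dcp_checkpoint,
)
from safetensors.torch import save_file

state_dict = recover_original_state_dict_from_dcp_checkpoint(
    "outputs/checkpoint-100/pytorch_model_fsdp_0",  # hypothetical DCP checkpoint dir
    "mistralai/Mixtral-8x7B-Instruct-v0.1",         # original model used as mapping hints
)

# Illustrative save step; the recovered state dict follows the original
# architecture's parameter names, so it can be reloaded by that architecture.
save_file(state_dict, "model.safetensors")
```

Passing `pretrained_model_name_or_path` lets the helper use the original checkpoint's weight map as hints, which is what allows the recovered state dict to be loaded back by the original architecture.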
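
The two `expert_name` patterns documented in the new docstring can be illustrated with a small standalone snippet that mirrors the mapping rule. It assumes `PARAM_NAME_WEIGHT_SCATTERMOE` resolves to `("w1", "w2", "w3")`, i.e. the literal names this patch replaces.

```python
# Standalone illustration of the expert_name -> ScatterMoE weight mapping
# described in the docstring; the constant is assumed to match the plugin's.
PARAM_NAME_WEIGHT_SCATTERMOE = ("w1", "w2", "w3")


def build_expert_map(expert_name: str) -> dict:
    if "|" in expert_name:
        # pattern ii): one entry per w1, w2, ... in order; a repeated name
        # accumulates several ScatterMoE weights under the same module.
        names = expert_name.split("|")
        assert len(names) >= 2, "expert name map has to be at least length 2."
        expert_map = {}
        for i, n in enumerate(names):
            expert_map.setdefault(n, []).append(PARAM_NAME_WEIGHT_SCATTERMOE[i])
        return expert_map
    # pattern i): identity map, e.g. experts.w1 -> w1
    return {x: [x] for x in PARAM_NAME_WEIGHT_SCATTERMOE}


print(build_expert_map("experts"))
# {'w1': ['w1'], 'w2': ['w2'], 'w3': ['w3']}
print(build_expert_map("input_linear|output_linear|input_linear"))
# {'input_linear': ['w1', 'w3'], 'output_linear': ['w2']}
```

With the repeated `input_linear` in the second example, that module supplies both `w1` and `w3`, which is why the derived map keeps a list of ScatterMoE weight names per original module name.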