Commit
improved documentation on state dict inference
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Nov 11, 2024
1 parent 005300b commit 0406dc1
Showing 4 changed files with 35 additions and 14 deletions.
3 changes: 2 additions & 1 deletion plugins/accelerated-moe/README.md
@@ -85,8 +85,9 @@ pip install -r requirements-khd.txt
### Known Issues

These are currently some known issues not yet resolved:
- The design currently swaps the mixture-of-experts module with [ScatterMoE](./src/fms_acceleration_moe/utils/scattermoe.py). This affects the `state_dict` of the model, so any saved checkpoint may need to be converted back to the original format (see the sketch after this list).
- should eventually remove the dependency on an external `kernel-hyperdrive` repository.
- only loading *sharded* `safetensor` non-GGUF MoE checkpoints is currently supported. This is a reasonable assumption since MoE checkpoints are typically larger than the size limit that allows saving into a single checkpoint file.
- when used together with FSDP, the FSDP's `clip_grad_norm` will not properly compute for `ScatterMoE`, see [issue here](https://github.com/foundation-model-stack/fms-acceleration/issues/109).
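
As a rough illustration of the conversion mentioned in the first item, here is a minimal sketch using the recovery helper introduced in this commit. The import path, checkpoint directory, and model name below are assumptions, not part of the diff; only the function name and its two arguments come from the changes shown here.

```python
# Hedged sketch: recover an original-format state dict from a DCP checkpoint
# that was saved while ScatterMoE was swapped in, then write it out as
# safetensors so the original architecture can load it again.
# NOTE: the module path below is an assumption; only the function name comes
# from this commit.
from fms_acceleration_moe.utils.checkpoint_utils import (
    recover_original_state_dict_from_dcp_checkpoint,
)
from safetensors.torch import save_file

state_dict = recover_original_state_dict_from_dcp_checkpoint(
    "outputs/checkpoint-100/pytorch_model_fsdp_0",  # hypothetical DCP checkpoint dir
    "mistralai/Mixtral-8x7B-Instruct-v0.1",         # original model, used as mapping hints
)

# save with contiguous tensors, as required by safetensors
save_file(
    {k: v.contiguous() for k, v in state_dict.items()},
    "model.safetensors",
)
```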



@@ -31,12 +31,12 @@
# pylint: disable=too-many-instance-attributes
class ScatterMoEAccelerationPlugin(AccelerationPlugin):

# NOTE: we cannot do
# - require_packages = {"khd"}
# this is because the khd fork is not properly packaged as a PyPI project, and so
# - "importlib.util.find_spec('khd')" returns, but
# - "importlib.metadata.version('kernel-hyperdrive')" does not return
# if we decide to extract the kernels, then we do not need to anymore,
# - "importlib.util.find_spec('khd')" returns, but
# - "importlib.metadata.version('kernel-hyperdrive')" does not return
# if we decide to extract the kernels, then we do not need to anymore,
# https://github.com/foundation-model-stack/fms-acceleration/issues/105

restricted_model_archs = ["GraniteMoeForCausalLM", "MixtralForCausalLM"]
@@ -217,7 +217,7 @@ def _dict_from_json_file(resolved_config_file):
# - if the original pretrained_model_name_or_path is specified, will use the checkpoint as hints
# to map the ScatterMoE checkpoint to that of the original model. This is useful so that we
# can restore the checkpoint to be loaded by the original architecture.
def get_scattermoe_state_dict(
def recover_original_state_dict_from_dcp_checkpoint(
dcp_checkpoint_dir: str,
pretrained_model_name_or_path: str = None,
):
@@ -460,7 +460,7 @@ def _infer_prefixes_and_module_names(
checkpoint_dir = os.path.join(args.dcp_checkpoint_dir, checkpoint_dir[0])

# get the converted statedict
state_dict = get_scattermoe_state_dict(
state_dict = recover_original_state_dict_from_dcp_checkpoint(
checkpoint_dir, args.pretrained_model_name_or_path
)

@@ -24,7 +24,11 @@
import torch

# Local
from .scattermoe_constants import DIM_EXPERT, KEY_SCATTERMOE_ROUTER
from .scattermoe_constants import (
DIM_EXPERT,
KEY_SCATTERMOE_ROUTER,
PARAM_NAME_WEIGHT_SCATTERMOE,
)

# This function creates a dictionary of keys and paths into the sharded
# safetensors checkpoint file that are relevant to the "prefix" and "instance_name"
@@ -84,8 +88,24 @@ def get_checkpoint_meta_from_sharded_safetensor(
expert_map: Dict = None, # map -> [w1,w2,w3]
) -> Dict[str, List[Tuple]]:
"""
utility function to infer the mapping from ScatterMoe to the incoming model
based on the weight_map.
utility function to infer the mapping of ScatterMoe parameters
from that of an incoming model, based on a weight_map from a
sharded safetensor.

Parameters:
    weight_map (dict): The weight map read in from a safetensor checkpoint.
    prefix (str): the prefix where the MoE module lives (with respect to the orig model).
    instance_name (str): the name of the MoE module in the orig model.
    router_name (str): name of the router module as it is called in the MoE module
        in the original model.
    expert_name (str): name of the experts as they are called in the MoE module in
        the original model. There are two patterns to use this:
        i) specify a single string, and the experts are mapped based on their names,
           e.g., experts.w1 -> w1
        ii) specify multiple strings in order of w1, w2, ...,
           e.g., input_linear|output_linear|input_linear
    expert_map (dict): This is used with pattern ii) described above in expert_name.
        If not specified, it will be the identity map, e.g., w1 -> w1.
"""

# insert in order
@@ -110,14 +130,14 @@ def _insert(L: List, i: int, v):
if "|" in expert_name:
expert_map = {}
_names = expert_name.split("|")
assert len(_names) in {2, 3}, "expert name map has to be length 2/3"
assert len(_names) >= 2, "expert name map has to be at least length 2."

for i, n in enumerate(_names):
if n not in expert_map:
expert_map[n] = []
expert_map[n].append(f"w{i+1}")
expert_map[n].append(PARAM_NAME_WEIGHT_SCATTERMOE[i])
else:
expert_map = {x: [x] for x in ["w1", "w2", "w3"]}
expert_map = {x: [x] for x in PARAM_NAME_WEIGHT_SCATTERMOE}

# state dict -> weights
# 'router.weight': [(k, file),...]
@@ -166,7 +186,7 @@ def _maybe_reshape_scattermoe_expert_weights(
intermediate_size: int,
):
(_is_w1, _is_w2, _is_w3) = [
f"{x}.weight" in scatter_key for x in ["w1", "w2", "w3"]
f"{x}.weight" in scatter_key for x in PARAM_NAME_WEIGHT_SCATTERMOE
]

if _is_w1 or _is_w2 or _is_w3:
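
For context on the `expert_name` patterns documented in the docstring above, here is a hedged usage sketch of `get_checkpoint_meta_from_sharded_safetensor`. The import path, weight-map entries, layer prefix, and module names are hypothetical; only the parameter names and the shape of the returned mapping come from the docstring and comments in this commit.

```python
# NOTE: the import path is an assumption; the function name and parameters are
# taken from the docstring added in this commit.
from fms_acceleration_moe.utils.scattermoe_state_dict import (
    get_checkpoint_meta_from_sharded_safetensor,
)

# Pattern i): a single expert_name string; expert weights such as
# experts.<n>.w1 are mapped by their own names (w1 -> w1, ...).
mixtral_weight_map = {  # hypothetical Mixtral-style entries
    "model.layers.0.block_sparse_moe.gate.weight": "model-00001.safetensors",
    "model.layers.0.block_sparse_moe.experts.0.w1.weight": "model-00001.safetensors",
    "model.layers.0.block_sparse_moe.experts.0.w2.weight": "model-00001.safetensors",
    "model.layers.0.block_sparse_moe.experts.0.w3.weight": "model-00001.safetensors",
}
meta = get_checkpoint_meta_from_sharded_safetensor(
    weight_map=mixtral_weight_map,
    prefix="model.layers.0",
    instance_name="block_sparse_moe",
    router_name="gate",
    expert_name="experts",
)
# meta maps ScatterMoE keys to lists of (original key, shard file) tuples,
# e.g. {"router.weight": [(k, file)], "w1.weight": [...], ...}

# Pattern ii): multiple names in w1, w2, ... order; per the code above this
# builds expert_map = {"input_linear": ["w1", "w3"], "output_linear": ["w2"]}.
granite_weight_map = {  # hypothetical GraniteMoe-style entries
    "model.layers.0.block_sparse_moe.router.weight": "model-00001.safetensors",
    "model.layers.0.block_sparse_moe.input_linear.weight": "model-00001.safetensors",
    "model.layers.0.block_sparse_moe.output_linear.weight": "model-00001.safetensors",
}
meta = get_checkpoint_meta_from_sharded_safetensor(
    weight_map=granite_weight_map,
    prefix="model.layers.0",
    instance_name="block_sparse_moe",
    router_name="router",
    expert_name="input_linear|output_linear|input_linear",
)
```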
