fix readme and update GraniteMoE to FOAK
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Nov 12, 2024
1 parent d3fb653 commit 64944c2
Showing 8 changed files with 138 additions and 13 deletions.
@@ -132,8 +132,8 @@ def _insert(L: List, i: int, v):
_names = expert_name.split("|")
_n, _n2 = len(_names), len(PARAM_NAME_WEIGHT_SCATTERMOE)
assert (
- 2 <= _n < _n2
- ), f"If expert_name has |, expect between 2 and {_n2} entries."
+ 2 <= _n <= _n2
+ ), f"If expert_name has |, expect between 2 and {_n2} entries, but got {_n}."

for i, n in enumerate(_names):
if n not in expert_map:
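As a quick aside (not part of the diff): the change relaxes the upper bound from `<` to `<=`, so an `expert_name` that splits into exactly `len(PARAM_NAME_WEIGHT_SCATTERMOE)` parts is now accepted, and the error message reports the offending count. A minimal sketch, with made-up values for `PARAM_NAME_WEIGHT_SCATTERMOE` and `expert_name`:

```python
# Minimal illustration of the corrected bound (not part of the diff).
# PARAM_NAME_WEIGHT_SCATTERMOE and expert_name below are made-up stand-ins.
PARAM_NAME_WEIGHT_SCATTERMOE = ["proj_a", "proj_b", "proj_c"]

expert_name = "proj_a|proj_b|proj_c"  # splits into 3 names
_names = expert_name.split("|")
_n, _n2 = len(_names), len(PARAM_NAME_WEIGHT_SCATTERMOE)

# With the old `2 <= _n < _n2`, the case _n == _n2 (3 == 3) would have failed;
# the new `<=` accepts it.
assert (
    2 <= _n <= _n2
), f"If expert_name has |, expect between 2 and {_n2} entries, but got {_n}."
```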
10 changes: 5 additions & 5 deletions plugins/accelerated-moe/tests/test_scattermoe_state_dict.py
@@ -119,13 +119,13 @@ def build_dummy_weight_map_non_sharded_moe(

@pytest.mark.parametrize(
(
"sharded_ckpt,prefix,module_name,router_name,expert_name,",
"num_layers,num_experts,expert_keys,sharded_ckpt",
"sharded_ckpt,prefix,module_name,router_name,expert_name,"
"num_layers,num_experts,expert_keys"
),
PARAMETERS,
)
def test_get_metadata_from_sharded_safetensor_correctly(
- sharded_cpkt: bool,
+ sharded_ckpt: bool,
prefix: str,
module_name: str,
router_name: str,
@@ -135,7 +135,7 @@ def test_get_metadata_from_sharded_safetensor_correctly(
expert_keys: List[str],
):

- if sharded_cpkt:
+ if sharded_ckpt:
weight_map = build_dummy_weight_map_sharded_moe(
prefix,
module_name,
@@ -170,7 +170,7 @@ def test_get_metadata_from_sharded_safetensor_correctly(
assert _key in ckpt_metadata, f"unable top map scattermoe expert weight {n}."

_n = len(ckpt_metadata[_key])
- if sharded_cpkt:
+ if sharded_ckpt:
assert (
_n == num_experts
), f"missing expert weights, only mapped {_n} weights out of {num_experts}."
@@ -39,6 +39,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
from .models import ( # pylint: disable=import-outside-toplevel
gpt_bigcode,
granite,
+ granitemoe,
llama,
mistral,
mixtral,
@@ -47,6 +48,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
rules = [
*gpt_bigcode.get_mp_rules(base_type),
*granite.get_mp_rules(base_type),
+ *granitemoe.get_mp_rules(base_type),
*llama.get_mp_rules(base_type),
*mistral.get_mp_rules(base_type),
*mixtral.get_mp_rules(base_type),
@@ -76,6 +78,7 @@ class FastKernelsAccelerationPlugin(AccelerationPlugin):
# NOTE: may remove this when we have generic model rules
restricted_model_archs = [
"GraniteForCausalLM",
"GraniteMoeForCausalLM",
"GPTBigCodeForCausalLM",
"MixtralForCausalLM",
"LlamaForCausalLM",
@@ -25,6 +25,7 @@
import torch
import torch.distributed as dist


# consider moving this somewhere else later
def lora_adapters_switch_ddp_from_fsdp(modules, fsdp_plugin):
"""
Expand Down Expand Up @@ -56,7 +57,7 @@ def _all_reduce_hook(grad):
if not A.weight.is_cuda:
value = A.weight

- if is_fsdp_enabled() and value.device == torch.device('meta'):
+ if is_fsdp_enabled() and value.device == torch.device("meta"):
# if low_cpu_mem_mode
value = torch.empty(*value.size(), dtype=value.dtype)

Expand All @@ -68,7 +69,7 @@ def _all_reduce_hook(grad):
if not B.weight.is_cuda:
value = B.weight

- if is_fsdp_enabled() and value.device == torch.device('meta'):
+ if is_fsdp_enabled() and value.device == torch.device("meta"):
value = torch.empty(*value.size(), dtype=value.dtype)

set_module_tensor_to_device(B, "weight", "cuda", value)
@@ -81,6 +82,7 @@ def _all_reduce_hook(grad):
A.weight.register_hook(_all_reduce_hook)
B.weight.register_hook(_all_reduce_hook)


def register_foak_model_patch_rules(base_type):
# Third Party
from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel
@@ -0,0 +1,116 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from functools import partial

# Third Party
from fms_acceleration.model_patcher import (
ModelPatcherRule,
ModelPatcherTrigger,
combine_functions,
combine_triggers,
)

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops


def get_mp_rules(base_type: str):
"""
Function to access all patch rules in this module.
If it is a forward_builder rule with `base_type` in
its forward builder argument, wrap the forward_builder
function as a partial function with the base_type argument
"""
try:
# Third Party
from transformers.models.granitemoe.modeling_granitemoe import ( # pylint: disable=import-outside-toplevel
GraniteMoeAttention,
GraniteMoeRMSNorm,
)
except ImportError:
return []

return [
# TODO: have a generic version of this rule
# - do regex on RMSNorm class name
# - check on the tensors required for fast_rms_layernorm
ModelPatcherRule(
rule_id="granitemoe-rms",
trigger=ModelPatcherTrigger(check=GraniteMoeRMSNorm),
forward=fast_rms_layernorm,
),
# TODO: have a generic version of this rule
# - do regex on Attention class name
# - have a set of qkv / o module names and check on that
ModelPatcherRule(
rule_id="granitemoe-qkvo",
trigger=combine_triggers(
ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=GraniteMoeAttention,
submodule_names=["q_proj", "k_proj", "v_proj"],
)
),
ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=GraniteMoeAttention,
submodule_names=["o_proj"],
)
),
logic="OR",
),
forward_builder=combine_functions(
partial(
build_lora_fused_ops,
submodule_names=["q_proj", "k_proj", "v_proj"],
fused_op=KEY_QKV,
base_type=base_type,
),
partial(
build_lora_fused_ops,
submodule_names=["o_proj"],
fused_op=KEY_O,
base_type=base_type,
),
logic="APPEND",
),
),
ModelPatcherRule(
rule_id="granitemoe-cross-ent",
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
FastCrossEntropyLoss,
"transformers.models.granitemoe.modeling_granitemoe",
),
),
# TODO: have a generic version of this rule
# - get the module name
# - check if "apply_rotary_pos_emb" exists
# - patch
ModelPatcherRule(
rule_id="granitemoe-rope",
import_and_maybe_reload=(
"transformers.models.granitemoe.modeling_granitemoe.apply_rotary_pos_emb",
fast_rope_embedding,
None,
),
),
]
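
As an aside (not part of this commit), a small smoke check of the new GraniteMoE rules could look like the sketch below; the import path and the `base_type` value are assumptions, not taken from the diff:

```python
# Illustrative only: the module path below is inferred from the relative
# imports above and may differ in the actual package layout.
from fms_acceleration_foak.models import granitemoe

# base_type feeds the forward_builder partials; "auto_gptq" is a placeholder value.
rules = granitemoe.get_mp_rules(base_type="auto_gptq")

# If transformers provides modeling_granitemoe, four rules are returned:
# granitemoe-rms, granitemoe-qkvo, granitemoe-cross-ent, granitemoe-rope;
# otherwise get_mp_rules returns an empty list.
print([r.rule_id for r in rules])
```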
6 changes: 3 additions & 3 deletions sample-configurations/CONTENTS.yaml
@@ -91,7 +91,7 @@ framework_configs:
- accelerated-moe
- attention-and-distributed-packing
- fused-ops-and-kernels
- filename: moe-scattermoe-granite-ep1-padding-foak-free-sample-configuration.yaml
+ filename: moe-scattermoe-granite-ep1-padding-free-foak-sample-configuration.yaml

- shortname: moe-scattermoe-granite-ep2
plugins:
@@ -109,7 +109,7 @@ framework_configs:
- accelerated-moe
- attention-and-distributed-packing
- fused-ops-and-kernels
- filename: moe-scattermoe-granite-ep2-padding-foak-free-sample-configuration.yaml
+ filename: moe-scattermoe-granite-ep2-padding-free-foak-sample-configuration.yaml

- shortname: moe-scattermoe-granite-ep4
plugins:
@@ -127,7 +127,7 @@ framework_configs:
- accelerated-moe
- attention-and-distributed-packing
- fused-ops-and-kernels
- filename: moe-scattermoe-granite-ep4-padding-foak-free-sample-configuration.yaml
+ filename: moe-scattermoe-granite-ep4-padding-free-foak-sample-configuration.yaml

- shortname: moe-scattermoe-granite-ep8
plugins:
3 changes: 2 additions & 1 deletion scripts/benchmarks/README.md
@@ -76,13 +76,14 @@ bash run_benchmarks.sh NUM_GPUS_MATRIX RESULT_DIR SCENARIOS_CONFIG SCENARIOS_FIL
```
where:
- `NUM_GPUS_MATRIX`: list of `num_gpu` settings to bench for, e.g. `"1 2"` will bench for 1 and 2 gpus.
+ - `EFFECTIVE_BS_MATRIX`: list of effective batch sizes, e.g., `"4 8"` will bench for effective batch sizes 4 and 8.
- `RESULT_DIR`: where the benchmark results will be placed.
- `SCENARIOS_CONFIG`: the `scenarios.yaml` file.
- `SCENARIOS_FILTER`: specify to run only a specific `scenario` by providing the specific `scenario` name.

The recommended way to run `benchmarks.sh` is using `tox` which handles the dependencies:
```
- tox -e run-benches -- NUM_GPUS_MATRIX RESULT_DIR SCENARIOS_CONFIG SCENARIOS_FILTER
+ tox -e run-benches -- NUM_GPUS_MATRIX EFFECTIVE_BS_MATRIX RESULT_DIR SCENARIOS_CONFIG SCENARIOS_FILTER
```
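For example (an illustrative invocation, not from the README; the result directory and scenario name are placeholders): `tox -e run-benches -- "1 2" "4 8" benchmark_outputs scenarios.yaml <scenario-name>` benches 1 and 2 GPUs at effective batch sizes 4 and 8.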

Alternatively run [`benchmark.py`](./benchmark.py) directly. To see the help do:
3 changes: 3 additions & 0 deletions scripts/benchmarks/scenarios-moe.yaml
@@ -43,8 +43,11 @@ scenarios:
- moe-scattermoe-granite-ep2
- moe-scattermoe-granite-ep4
- moe-scattermoe-granite-ep1-padding-free
+ - moe-scattermoe-granite-ep1-padding-free-foak
- moe-scattermoe-granite-ep2-padding-free
+ - moe-scattermoe-granite-ep2-padding-free-foak
- moe-scattermoe-granite-ep4-padding-free
+ - moe-scattermoe-granite-ep4-padding-free-foak
arguments:
learning_rate: 5e-5
torch_dtype: bfloat16
