From 64944c2e17e19643e963ae7e44666e88c130e638 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Tue, 12 Nov 2024 07:41:53 +0000
Subject: [PATCH] fix readme and update GraniteMoE to FOAK

Signed-off-by: Yu Chin Fabian Lim
---
 .../utils/scattermoe_state_dict.py            |   4 +-
 .../tests/test_scattermoe_state_dict.py       |  10 +-
 .../framework_plugin_fast_kernels.py          |   3 +
 .../framework_plugin_fast_quantized_peft.py   |   6 +-
 .../models/granitemoe.py                      | 116 ++++++++++++++++++
 sample-configurations/CONTENTS.yaml           |   6 +-
 scripts/benchmarks/README.md                  |   3 +-
 scripts/benchmarks/scenarios-moe.yaml         |   3 +
 8 files changed, 138 insertions(+), 13 deletions(-)
 create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granitemoe.py

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py
index 1b74beb7..e13f6ba5 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py
@@ -132,8 +132,8 @@ def _insert(L: List, i: int, v):
     _names = expert_name.split("|")
     _n, _n2 = len(_names), len(PARAM_NAME_WEIGHT_SCATTERMOE)
     assert (
-        2 <= _n < _n2
-    ), f"If expert_name has |, expect between 2 and {_n2} entries."
+        2 <= _n <= _n2
+    ), f"If expert_name has |, expect between 2 and {_n2} entries, but got {_n}."
 
     for i, n in enumerate(_names):
         if n not in expert_map:
diff --git a/plugins/accelerated-moe/tests/test_scattermoe_state_dict.py b/plugins/accelerated-moe/tests/test_scattermoe_state_dict.py
index 8798a4ad..ff8965ba 100644
--- a/plugins/accelerated-moe/tests/test_scattermoe_state_dict.py
+++ b/plugins/accelerated-moe/tests/test_scattermoe_state_dict.py
@@ -119,13 +119,13 @@ def build_dummy_weight_map_non_sharded_moe(
 
 @pytest.mark.parametrize(
     (
-        "sharded_ckpt,prefix,module_name,router_name,expert_name,",
-        "num_layers,num_experts,expert_keys,sharded_ckpt",
+        "sharded_ckpt,prefix,module_name,router_name,expert_name,"
+        "num_layers,num_experts,expert_keys"
     ),
     PARAMETERS,
 )
 def test_get_metadata_from_sharded_safetensor_correctly(
-    sharded_cpkt: bool,
+    sharded_ckpt: bool,
     prefix: str,
     module_name: str,
     router_name: str,
@@ -135,7 +135,7 @@ def test_get_metadata_from_sharded_safetensor_correctly(
     expert_keys: List[str],
 ):
 
-    if sharded_cpkt:
+    if sharded_ckpt:
         weight_map = build_dummy_weight_map_sharded_moe(
             prefix,
             module_name,
@@ -170,7 +170,7 @@ def test_get_metadata_from_sharded_safetensor_correctly(
         assert _key in ckpt_metadata, f"unable top map scattermoe expert weight {n}."
         _n = len(ckpt_metadata[_key])
 
-        if sharded_cpkt:
+        if sharded_ckpt:
             assert (
                 _n == num_experts
             ), f"missing expert weights, only mapped {_n} weights out of {num_experts}."
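The relaxed upper bound above now accepts an `expert_name` that names every scattermoe weight parameter. Below is a minimal, standalone sketch of the corrected check; the value of PARAM_NAME_WEIGHT_SCATTERMOE and the example expert names are assumptions for illustration only, not taken from the plugin.

# Sketch of the corrected bound check in scattermoe_state_dict.py (illustration only).
# PARAM_NAME_WEIGHT_SCATTERMOE is assumed here to hold three scattermoe weight names;
# the real constant is defined inside the accelerated-moe plugin.
from typing import List

PARAM_NAME_WEIGHT_SCATTERMOE: List[str] = ["w1", "w2", "w3"]  # assumed for illustration


def split_expert_name(expert_name: str) -> List[str]:
    _names = expert_name.split("|")
    _n, _n2 = len(_names), len(PARAM_NAME_WEIGHT_SCATTERMOE)
    # the upper bound is now inclusive, so naming all weight parameters passes
    assert (
        2 <= _n <= _n2
    ), f"If expert_name has |, expect between 2 and {_n2} entries, but got {_n}."
    return _names


split_expert_name("input_linear|output_linear")         # 2 entries: accepted before and after the fix
split_expert_name("input_linear|output_linear|router")  # 3 entries: rejected by the old strict `<` bound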
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
index 16ea64b7..906a4668 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py
@@ -39,6 +39,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
     from .models import (  # pylint: disable=import-outside-toplevel
         gpt_bigcode,
         granite,
+        granitemoe,
         llama,
         mistral,
         mixtral,
@@ -47,6 +48,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
     rules = [
         *gpt_bigcode.get_mp_rules(base_type),
         *granite.get_mp_rules(base_type),
+        *granitemoe.get_mp_rules(base_type),
         *llama.get_mp_rules(base_type),
         *mistral.get_mp_rules(base_type),
         *mixtral.get_mp_rules(base_type),
@@ -76,6 +78,7 @@ class FastKernelsAccelerationPlugin(AccelerationPlugin):
     # NOTE: may remove this when we have generic model rules
     restricted_model_archs = [
         "GraniteForCausalLM",
+        "GraniteMoeForCausalLM",
         "GPTBigCodeForCausalLM",
         "MixtralForCausalLM",
         "LlamaForCausalLM",
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
index c825ebb9..9fbab69f 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
@@ -25,6 +25,7 @@
 import torch
 import torch.distributed as dist
 
+
 # consider moving this somewhere else later
 def lora_adapters_switch_ddp_from_fsdp(modules, fsdp_plugin):
     """
@@ -56,7 +57,7 @@ def _all_reduce_hook(grad):
         if not A.weight.is_cuda:
             value = A.weight
 
-            if is_fsdp_enabled() and value.device == torch.device('meta'):
+            if is_fsdp_enabled() and value.device == torch.device("meta"):
                 # if low_cpu_mem_mode
                 value = torch.empty(*value.size(), dtype=value.dtype)
 
@@ -68,7 +69,7 @@ def _all_reduce_hook(grad):
         if not B.weight.is_cuda:
             value = B.weight
 
-            if is_fsdp_enabled() and value.device == torch.device('meta'):
+            if is_fsdp_enabled() and value.device == torch.device("meta"):
                 value = torch.empty(*value.size(), dtype=value.dtype)
 
             set_module_tensor_to_device(B, "weight", "cuda", value)
@@ -81,6 +82,7 @@ def _all_reduce_hook(grad):
         A.weight.register_hook(_all_reduce_hook)
         B.weight.register_hook(_all_reduce_hook)
 
+
 def register_foak_model_patch_rules(base_type):
     # Third Party
     from fms_acceleration.model_patcher import (  # pylint: disable=import-outside-toplevel
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granitemoe.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granitemoe.py
new file mode 100644
index 00000000..6da14682
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granitemoe.py
@@ -0,0 +1,116 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from functools import partial
+
+# Third Party
+from fms_acceleration.model_patcher import (
+    ModelPatcherRule,
+    ModelPatcherTrigger,
+    combine_functions,
+    combine_triggers,
+)
+
+# Local
+from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
+from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
+from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
+
+
+def get_mp_rules(base_type: str):
+    """
+    Function to access all patch rules in this module.
+    If it is a forward_builder rule with `base_type` in
+    its forward builder argument, wrap the forward_builder
+    function as a partial function with the base_type argument
+    """
+    try:
+        # Third Party
+        from transformers.models.granitemoe.modeling_granitemoe import (  # pylint: disable=import-outside-toplevel
+            GraniteMoeAttention,
+            GraniteMoeRMSNorm,
+        )
+    except ImportError:
+        return []
+
+    return [
+        # TODO: have a generic version of this rule
+        # - do regex on RMSNorm class name
+        # - check on the tensors required for fast_rms_layernorm
+        ModelPatcherRule(
+            rule_id="granitemoe-rms",
+            trigger=ModelPatcherTrigger(check=GraniteMoeRMSNorm),
+            forward=fast_rms_layernorm,
+        ),
+        # TODO: have a generic version of this rule
+        # - do regex on Attention class name
+        # - have a set of qkv / o module names and check on that
+        ModelPatcherRule(
+            rule_id="granitemoe-qkvo",
+            trigger=combine_triggers(
+                ModelPatcherTrigger(
+                    check=partial(
+                        trigger_fused_ops,
+                        attn_cls=GraniteMoeAttention,
+                        submodule_names=["q_proj", "k_proj", "v_proj"],
+                    )
+                ),
+                ModelPatcherTrigger(
+                    check=partial(
+                        trigger_fused_ops,
+                        attn_cls=GraniteMoeAttention,
+                        submodule_names=["o_proj"],
+                    )
+                ),
+                logic="OR",
+            ),
+            forward_builder=combine_functions(
+                partial(
+                    build_lora_fused_ops,
+                    submodule_names=["q_proj", "k_proj", "v_proj"],
+                    fused_op=KEY_QKV,
+                    base_type=base_type,
+                ),
+                partial(
+                    build_lora_fused_ops,
+                    submodule_names=["o_proj"],
+                    fused_op=KEY_O,
+                    base_type=base_type,
+                ),
+                logic="APPEND",
+            ),
+        ),
+        ModelPatcherRule(
+            rule_id="granitemoe-cross-ent",
+            import_and_maybe_reload=(
+                "torch.nn.CrossEntropyLoss",
+                FastCrossEntropyLoss,
+                "transformers.models.granitemoe.modeling_granitemoe",
+            ),
+        ),
+        # TODO: have a generic version of this rule
+        # - get the module name
+        # - check if "apply_rotary_pos_emb" exists
+        # - patch
+        ModelPatcherRule(
+            rule_id="granitemoe-rope",
+            import_and_maybe_reload=(
+                "transformers.models.granitemoe.modeling_granitemoe.apply_rotary_pos_emb",
+                fast_rope_embedding,
+                None,
+            ),
+        ),
+    ]
diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml
index dd393633..b3b1deec 100644
--- a/sample-configurations/CONTENTS.yaml
+++ b/sample-configurations/CONTENTS.yaml
@@ -91,7 +91,7 @@ framework_configs:
       - accelerated-moe
       - attention-and-distributed-packing
       - fused-ops-and-kernels
-    filename: moe-scattermoe-granite-ep1-padding-foak-free-sample-configuration.yaml
+    filename: moe-scattermoe-granite-ep1-padding-free-foak-sample-configuration.yaml
 
   - shortname: moe-scattermoe-granite-ep2
     plugins:
@@ -109,7 +109,7 @@ framework_configs:
       - accelerated-moe
       - attention-and-distributed-packing
       - fused-ops-and-kernels
-    filename: moe-scattermoe-granite-ep2-padding-foak-free-sample-configuration.yaml
+    filename: moe-scattermoe-granite-ep2-padding-free-foak-sample-configuration.yaml
 
   - shortname: moe-scattermoe-granite-ep4
     plugins:
@@ -127,7 +127,7 @@ framework_configs:
      - accelerated-moe
      - attention-and-distributed-packing
      - fused-ops-and-kernels
-    filename: moe-scattermoe-granite-ep4-padding-foak-free-sample-configuration.yaml
+    filename: moe-scattermoe-granite-ep4-padding-free-foak-sample-configuration.yaml
 
   - shortname: moe-scattermoe-granite-ep8
     plugins:
diff --git a/scripts/benchmarks/README.md b/scripts/benchmarks/README.md
index 269d3ead..4795efee 100644
--- a/scripts/benchmarks/README.md
+++ b/scripts/benchmarks/README.md
@@ -76,13 +76,14 @@ bash run_benchmarks.sh NUM_GPUS_MATRIX RESULT_DIR SCENARIOS_CONFIG SCENARIOS_FIL
 ```
 where:
 - `NUM_GPUS_MATRIX`: list of `num_gpu` settings to bench for, e.g. `"1 2"` will bench for 1 and 2 gpus.
+- `EFFECTIVE_BS_MATRIX`: list of effective batch sizes, e.g., `"4 8"` will bench for effective batch sizes 4 and 8.
 - `RESULT_DIR`: where the benchmark results will be placed.
 - `SCENARIOS_CONFIG`: the `scenarios.yaml` file.
 - `SCENARIOS_CONFIG`: specify to run only a specific `scenario` by providing the specific `scenario` name.
 
 The recommended way to run `benchmarks.sh` is using `tox` which handles the dependencies:
 ```
-tox -e run-benches -- NUM_GPUS_MATRIX RESULT_DIR SCENARIOS_CONFIG SCENARIOS_FILTER
+tox -e run-benches -- NUM_GPUS_MATRIX EFFECTIVE_BS_MATRIX RESULT_DIR SCENARIOS_CONFIG SCENARIOS_FILTER
 ```
 
 Alternatively run [`benchmark.py`](./benchmark.py) directly. To see the help do:
diff --git a/scripts/benchmarks/scenarios-moe.yaml b/scripts/benchmarks/scenarios-moe.yaml
index a1fd4c48..efa2725e 100644
--- a/scripts/benchmarks/scenarios-moe.yaml
+++ b/scripts/benchmarks/scenarios-moe.yaml
@@ -43,8 +43,11 @@ scenarios:
       - moe-scattermoe-granite-ep2
       - moe-scattermoe-granite-ep4
       - moe-scattermoe-granite-ep1-padding-free
+      - moe-scattermoe-granite-ep1-padding-free-foak
       - moe-scattermoe-granite-ep2-padding-free
+      - moe-scattermoe-granite-ep2-padding-free-foak
       - moe-scattermoe-granite-ep4-padding-free
+      - moe-scattermoe-granite-ep4-padding-free-foak
     arguments:
       learning_rate: 5e-5
       torch_dtype: bfloat16
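For context, here is a small sketch of how the new GraniteMoE rules surface once `granitemoe` is added to the FOAK model list above; the `base_type` value is an assumption for illustration, and the rule ids simply echo the ones defined in granitemoe.py.

# Sketch only: enumerate the GraniteMoE patch rules added by this change.
# The base_type value ("auto_gptq") is an assumed example; the plugin passes in
# whatever quantized base type it was actually configured with.
from fms_acceleration_foak.models import granitemoe

# get_mp_rules returns [] when the installed transformers has no granitemoe
# module, so the registration in framework_plugin_fast_kernels.py degrades
# gracefully on older transformers versions.
rules = granitemoe.get_mp_rules(base_type="auto_gptq")

for rule in rules:
    print(rule.rule_id)
# expected ids per the rules above: granitemoe-rms, granitemoe-qkvo,
# granitemoe-cross-ent, granitemoe-rope

The three new moe-scattermoe-granite-ep{1,2,4}-padding-free-foak entries in scenarios-moe.yaml then exercise this FOAK path together with scattermoe and padding-free in the benchmark matrix.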