diff --git a/plugins/framework/src/fms_acceleration/framework_plugin.py b/plugins/framework/src/fms_acceleration/framework_plugin.py index cf1764d5..28fecebf 100644 --- a/plugins/framework/src/fms_acceleration/framework_plugin.py +++ b/plugins/framework/src/fms_acceleration/framework_plugin.py @@ -206,7 +206,7 @@ def _check_config_and_maybe_check_values( t = list(t.keys())[0] # otherwise take the first value if t not in values: - if default is None: + if t is not None or default is None: raise AccelerationPluginConfigError( f"{self.__class__.__name__}: Value at '{key}' was '{t}'. " f"Not found in expected set '{values}'." diff --git a/plugins/fused-ops-and-kernels/configs/fast_kernels.yaml b/plugins/fused-ops-and-kernels/configs/fast_kernels.yaml index 823af26f..45f0051e 100644 --- a/plugins/fused-ops-and-kernels/configs/fast_kernels.yaml +++ b/plugins/fused-ops-and-kernels/configs/fast_kernels.yaml @@ -23,6 +23,3 @@ training: # fast RoPE embedding triton kernels fast_rope_embeddings: True - - # fused linear cross entropy loss - fused_linear_loss: False \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml b/plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml index 8011db78..a154b95b 100644 --- a/plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml +++ b/plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml @@ -16,13 +16,10 @@ training: # - the FastQuantized version is all-or-nothing # fast loss triton kernels - fast_loss: False + fast_loss: fused_ce_liger # fast rms norm triton kernels fast_rms_layernorm: True # fast RoPE embedding triton kernels fast_rope_embeddings: True - - # fused linear cross entropy loss - fused_linear_loss: True \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml b/plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml index 7f239849..c6655d34 100644 --- a/plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml +++ b/plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml @@ -21,13 +21,10 @@ peft: fused_lora: True # fast loss triton kernels - fast_loss: False + fast_loss: fused_ce_liger # fast rms norm triton kernels fast_rsm_layernorm: True # fast RoPE embedding triton kernels fast_rope_embeddings: True - - # fused linear cross entropy loss - fused_linear_loss: True \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py index 049b26d4..4c215906 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py @@ -26,14 +26,6 @@ from .models.utils import filter_mp_rules from .utils import lora_adapters_switch_ddp_from_fsdp - -def validate_plugin_args(configurations): - # Consider making this a more graceful fallback? - assert ( - configurations["fused_linear_loss"] != configurations["fast_loss"] - ), "If using `fused_linear_loss`, `fast_loss` must be set to False" - - # consider rewriting register_foak_model_patch_rules into something # like this also def register_foak_model_patch_rules( @@ -80,10 +72,12 @@ def register_foak_model_patch_rules( # maybe this we should define envvars FILTER_MAP = { "fused_lora": {"qkvo", "mlp"}, - "fast_loss": "cross-ent", + "fast_loss": { + True: "cross-ent", + "fused_ce_liger": "fused-lce", + }, "fast_rms_layernorm": "rms", "fast_rope_embeddings": "rope", - "fused_linear_loss": "fused-lce", } @@ -117,28 +111,21 @@ def __init__(self, configurations: Dict[str, Dict]): key="base_layer", values=["auto_gptq", "bitsandbytes"], default="auto_gptq" ) self.configurations["fused_lora"] = self._check_config_and_maybe_check_values( - key="fused_lora", values=[False, True], default=True + key="fused_lora", values=[False, True], default=False ) self.configurations["fast_loss"] = self._check_config_and_maybe_check_values( - key="fast_loss", values=[False, True], default=True + key="fast_loss", values=[False, True, 'fused_ce_liger'], default=False ) self.configurations["fast_rms_layernorm"] = ( self._check_config_and_maybe_check_values( - key="fast_rms_layernorm", values=[False, True], default=True + key="fast_rms_layernorm", values=[False, True], default=False ) ) self.configurations["fast_rope_embeddings"] = ( self._check_config_and_maybe_check_values( - key="fast_rope_embeddings", values=[False, True], default=True + key="fast_rope_embeddings", values=[False, True], default=False ) ) - self.configurations["fused_linear_loss"] = ( - self._check_config_and_maybe_check_values( - key="fused_linear_loss", values=[False, True], default=False - ) - ) - - validate_plugin_args(self.configurations) @property def requires_agumentation(self): @@ -177,6 +164,8 @@ def augmentation( if k in FILTER_MAP and k not in omitted: ts = FILTER_MAP[k] + if isinstance(ts, dict) and v in ts: + ts = ts[v] if isinstance(ts, str): ts = {ts} diff --git a/sample-configurations/accelerated-peft-autogptq-foak-liger-sample-configuration.yaml b/sample-configurations/accelerated-peft-autogptq-foak-liger-sample-configuration.yaml index 1abc5a11..1126b4f8 100644 --- a/sample-configurations/accelerated-peft-autogptq-foak-liger-sample-configuration.yaml +++ b/sample-configurations/accelerated-peft-autogptq-foak-liger-sample-configuration.yaml @@ -43,13 +43,10 @@ plugins: fused_lora: true # fast loss triton kernels - fast_loss: false + fast_loss: fused_ce_liger # fast rms norm triton kernels fast_rsm_layernorm: true # fast RoPE embedding triton kernels fast_rope_embeddings: true - - # fused linear cross entropy loss - fused_linear_loss: true diff --git a/sample-configurations/accelerated-peft-bnb-nf4-foak-liger-sample-configuration.yaml b/sample-configurations/accelerated-peft-bnb-nf4-foak-liger-sample-configuration.yaml index 4376182e..71c305ac 100644 --- a/sample-configurations/accelerated-peft-bnb-nf4-foak-liger-sample-configuration.yaml +++ b/sample-configurations/accelerated-peft-bnb-nf4-foak-liger-sample-configuration.yaml @@ -38,13 +38,10 @@ plugins: fused_lora: true # fast loss triton kernels - fast_loss: false + fast_loss: fused_ce_liger # fast rms norm triton kernels fast_rsm_layernorm: true # fast RoPE embedding triton kernels fast_rope_embeddings: true - - # fused linear cross entropy loss - fused_linear_loss: true diff --git a/sample-configurations/foak-fast-kernels-liger-sample-configuration.yaml b/sample-configurations/foak-fast-kernels-liger-sample-configuration.yaml index 7002026a..1752755f 100644 --- a/sample-configurations/foak-fast-kernels-liger-sample-configuration.yaml +++ b/sample-configurations/foak-fast-kernels-liger-sample-configuration.yaml @@ -21,13 +21,10 @@ plugins: # - the FastQuantized version is all-or-nothing # fast loss triton kernels - fast_loss: false + fast_loss: fused_ce_liger # fast rms norm triton kernels fast_rms_layernorm: true # fast RoPE embedding triton kernels fast_rope_embeddings: true - - # fused linear cross entropy loss - fused_linear_loss: true diff --git a/sample-configurations/foak-fast-kernels-sample-configuration.yaml b/sample-configurations/foak-fast-kernels-sample-configuration.yaml index ba7669aa..b9d646b6 100644 --- a/sample-configurations/foak-fast-kernels-sample-configuration.yaml +++ b/sample-configurations/foak-fast-kernels-sample-configuration.yaml @@ -28,6 +28,3 @@ plugins: # fast RoPE embedding triton kernels fast_rope_embeddings: true - - # fused linear cross entropy loss - fused_linear_loss: false diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py index 157f55bb..6232dce6 100644 --- a/scripts/generate_sample_configurations.py +++ b/scripts/generate_sample_configurations.py @@ -224,7 +224,9 @@ def read_configuration(path: str) -> Dict: ("accelerated-peft-bnb-nf4-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_BNB_NF4, KEY_BNB_NF4_FOAK)), ("aadp-padding-free-multipack", (KEY_AADP_PADDING_FREE, KEY_AADP_MULTIPACK)), ("foak-fast-kernels", (KEY_FAST_KERNELS,)), + ("foak-fast-kernels-liger", (KEY_FAST_KERNELS_LIGER,)), ("moe-scattermoe-granite-ep1", (KEY_SCATTERMOE_EP1,)), + ("moe-scattermoe-granite-ep1-padding-free", (KEY_AADP_PADDING_FREE, KEY_SCATTERMOE_EP1,)), ("moe-scattermoe-granite-ep1-padding-free-foak", (KEY_AADP_PADDING_FREE, KEY_FAST_KERNELS, KEY_SCATTERMOE_EP1,)), ("moe-scattermoe-granite-ep2", (KEY_SCATTERMOE_EP2,)), ("moe-scattermoe-granite-ep2-padding-free", (KEY_AADP_PADDING_FREE, KEY_SCATTERMOE_EP2,)),