fix fast foak configs
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Nov 18, 2024
1 parent 7abf93c commit 2462613
Showing 10 changed files with 18 additions and 48 deletions.
2 changes: 1 addition & 1 deletion plugins/framework/src/fms_acceleration/framework_plugin.py
@@ -206,7 +206,7 @@ def _check_config_and_maybe_check_values(
             t = list(t.keys())[0]  # otherwise take the first value

         if t not in values:
-            if default is None:
+            if t is not None or default is None:
                 raise AccelerationPluginConfigError(
                     f"{self.__class__.__name__}: Value at '{key}' was '{t}'. "
                     f"Not found in expected set '{values}'."
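The one-line fix above tightens the validation: an explicitly supplied value that is not in the expected set now always raises, and the default only covers a missing key. A minimal, self-contained sketch of the corrected behaviour (the check_value helper and its fall-back-to-default step are illustrative stand-ins for the plugin method, not its actual API):

# Standalone sketch; assumes the real method substitutes `default`
# when the key is absent from the config.
class AccelerationPluginConfigError(Exception):
    pass

def check_value(key, t, values, default=None):
    # `t` is the configured value, or None when the key was absent
    if t not in values:
        # a supplied-but-unrecognized value must always raise; only a
        # missing key (t is None) may fall back to the default
        if t is not None or default is None:
            raise AccelerationPluginConfigError(
                f"Value at '{key}' was '{t}'. Not found in expected set '{values}'."
            )
        t = default
    return t

check_value("fast_loss", None, [False, True], default=False)  # -> False
# check_value("fast_loss", "typo", [False, True], default=False)  # now raises;
# before the fix this did not raise because `default` was set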
3 changes: 0 additions & 3 deletions plugins/fused-ops-and-kernels/configs/fast_kernels.yaml
@@ -23,6 +23,3 @@ training:

     # fast RoPE embedding triton kernels
     fast_rope_embeddings: True
-
-    # fused linear cross entropy loss
-    fused_linear_loss: False
(next changed file; name not captured in this view)
@@ -16,13 +16,10 @@ training:
     # - the FastQuantized version is all-or-nothing

     # fast loss triton kernels
-    fast_loss: False
+    fast_loss: fused_ce_liger

     # fast rms norm triton kernels
     fast_rms_layernorm: True

     # fast RoPE embedding triton kernels
     fast_rope_embeddings: True
-
-    # fused linear cross entropy loss
-    fused_linear_loss: True
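Across the config files in this commit the pattern is the same: the separate fused_linear_loss flag is dropped and its behaviour folds into fast_loss as an extra accepted value. A before/after sketch of the affected YAML keys (parent nesting abbreviated; values per the diffs above):

    # before: two flags that had to be kept mutually exclusive
    fast_loss: False
    fused_linear_loss: True

    # after: one flag whose value selects the kernel
    fast_loss: fused_ce_liger  # or True for the plain cross-entropy kernel,
                               # or False to disable fast loss entirely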
(next changed file; name not captured in this view)
@@ -21,13 +21,10 @@ peft:
     fused_lora: True

     # fast loss triton kernels
-    fast_loss: False
+    fast_loss: fused_ce_liger

     # fast rms norm triton kernels
     fast_rsm_layernorm: True

     # fast RoPE embedding triton kernels
     fast_rope_embeddings: True
-
-    # fused linear cross entropy loss
-    fused_linear_loss: True
(next changed file; name not captured in this view)
@@ -26,14 +26,6 @@
 from .models.utils import filter_mp_rules
 from .utils import lora_adapters_switch_ddp_from_fsdp


-def validate_plugin_args(configurations):
-    # Consider making this a more graceful fallback?
-    assert (
-        configurations["fused_linear_loss"] != configurations["fast_loss"]
-    ), "If using `fused_linear_loss`, `fast_loss` must be set to False"
-
-
 # consider rewriting register_foak_model_patch_rules into something
 # like this also
 def register_foak_model_patch_rules(
@@ -80,10 +72,12 @@ def register_foak_model_patch_rules(
 # maybe this we should define envvars
 FILTER_MAP = {
     "fused_lora": {"qkvo", "mlp"},
-    "fast_loss": "cross-ent",
+    "fast_loss": {
+        True: "cross-ent",
+        "fused_ce_liger": "fused-lce",
+    },
     "fast_rms_layernorm": "rms",
     "fast_rope_embeddings": "rope",
-    "fused_linear_loss": "fused-lce",
 }


@@ -117,28 +111,21 @@ def __init__(self, configurations: Dict[str, Dict]):
             key="base_layer", values=["auto_gptq", "bitsandbytes"], default="auto_gptq"
         )
         self.configurations["fused_lora"] = self._check_config_and_maybe_check_values(
-            key="fused_lora", values=[False, True], default=True
+            key="fused_lora", values=[False, True], default=False
         )
         self.configurations["fast_loss"] = self._check_config_and_maybe_check_values(
-            key="fast_loss", values=[False, True], default=True
+            key="fast_loss", values=[False, True, 'fused_ce_liger'], default=False
         )
         self.configurations["fast_rms_layernorm"] = (
             self._check_config_and_maybe_check_values(
-                key="fast_rms_layernorm", values=[False, True], default=True
+                key="fast_rms_layernorm", values=[False, True], default=False
             )
         )
         self.configurations["fast_rope_embeddings"] = (
             self._check_config_and_maybe_check_values(
-                key="fast_rope_embeddings", values=[False, True], default=True
+                key="fast_rope_embeddings", values=[False, True], default=False
             )
         )
-        self.configurations["fused_linear_loss"] = (
-            self._check_config_and_maybe_check_values(
-                key="fused_linear_loss", values=[False, True], default=False
-            )
-        )
-
-        validate_plugin_args(self.configurations)

     @property
     def requires_agumentation(self):
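With the defaults flipped to False, no kernel is enabled unless the config asks for it, and fast_loss is no longer a plain boolean. A hedged sketch of a payload the new checks would accept (the flat key layout is a simplification; the real plugin receives these through its framework config parsing):

configurations = {
    "base_layer": "auto_gptq",       # or "bitsandbytes"
    "fused_lora": True,              # default is now False
    "fast_loss": "fused_ce_liger",   # False | True | 'fused_ce_liger'
    "fast_rms_layernorm": True,      # default is now False
    "fast_rope_embeddings": True,    # default is now False
}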
@@ -177,6 +164,8 @@ def augmentation(

             if k in FILTER_MAP and k not in omitted:
                 ts = FILTER_MAP[k]
+                if isinstance(ts, dict) and v in ts:
+                    ts = ts[v]
                 if isinstance(ts, str):
                     ts = {ts}
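These two added lines are what make the dict-valued FILTER_MAP entry work: when the entry is a dict, the configured value v selects the filter term before the usual string-to-set normalization. A runnable sketch of the lookup (resolve_filter_terms is an illustrative wrapper, not a function in the plugin; FILTER_MAP is copied from the hunk above):

FILTER_MAP = {
    "fused_lora": {"qkvo", "mlp"},
    "fast_loss": {True: "cross-ent", "fused_ce_liger": "fused-lce"},
    "fast_rms_layernorm": "rms",
    "fast_rope_embeddings": "rope",
}

def resolve_filter_terms(k, v):
    ts = FILTER_MAP[k]
    if isinstance(ts, dict) and v in ts:
        ts = ts[v]   # dict-valued entries branch on the config value
    if isinstance(ts, str):
        ts = {ts}    # normalize a single term to a set
    return ts

print(resolve_filter_terms("fast_loss", True))              # {'cross-ent'}
print(resolve_filter_terms("fast_loss", "fused_ce_liger"))  # {'fused-lce'}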

(next changed file; name not captured in this view)
@@ -43,13 +43,10 @@ plugins:
     fused_lora: true

     # fast loss triton kernels
-    fast_loss: false
+    fast_loss: fused_ce_liger

     # fast rms norm triton kernels
     fast_rsm_layernorm: true

     # fast RoPE embedding triton kernels
     fast_rope_embeddings: true
-
-    # fused linear cross entropy loss
-    fused_linear_loss: true
(next changed file; name not captured in this view)
@@ -38,13 +38,10 @@ plugins:
     fused_lora: true

     # fast loss triton kernels
-    fast_loss: false
+    fast_loss: fused_ce_liger

     # fast rms norm triton kernels
     fast_rsm_layernorm: true

     # fast RoPE embedding triton kernels
     fast_rope_embeddings: true
-
-    # fused linear cross entropy loss
-    fused_linear_loss: true
(next changed file; name not captured in this view)
@@ -21,13 +21,10 @@ plugins:
     # - the FastQuantized version is all-or-nothing

     # fast loss triton kernels
-    fast_loss: false
+    fast_loss: fused_ce_liger

     # fast rms norm triton kernels
     fast_rms_layernorm: true

     # fast RoPE embedding triton kernels
     fast_rope_embeddings: true
-
-    # fused linear cross entropy loss
-    fused_linear_loss: true
(next changed file; name not captured in this view)
@@ -28,6 +28,3 @@ plugins:

     # fast RoPE embedding triton kernels
     fast_rope_embeddings: true
-
-    # fused linear cross entropy loss
-    fused_linear_loss: false
2 changes: 2 additions & 0 deletions scripts/generate_sample_configurations.py
@@ -224,7 +224,9 @@ def read_configuration(path: str) -> Dict:
     ("accelerated-peft-bnb-nf4-foak-padding-free", (KEY_AADP_PADDING_FREE, KEY_BNB_NF4, KEY_BNB_NF4_FOAK)),
     ("aadp-padding-free-multipack", (KEY_AADP_PADDING_FREE, KEY_AADP_MULTIPACK)),
     ("foak-fast-kernels", (KEY_FAST_KERNELS,)),
+    ("foak-fast-kernels-liger", (KEY_FAST_KERNELS_LIGER,)),
     ("moe-scattermoe-granite-ep1", (KEY_SCATTERMOE_EP1,)),
     ("moe-scattermoe-granite-ep1-padding-free", (KEY_AADP_PADDING_FREE, KEY_SCATTERMOE_EP1,)),
     ("moe-scattermoe-granite-ep1-padding-free-foak", (KEY_AADP_PADDING_FREE, KEY_FAST_KERNELS, KEY_SCATTERMOE_EP1,)),
     ("moe-scattermoe-granite-ep2", (KEY_SCATTERMOE_EP2,)),
     ("moe-scattermoe-granite-ep2-padding-free", (KEY_AADP_PADDING_FREE, KEY_SCATTERMOE_EP2,)),
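Each tuple pairs a sample-configuration shortname with the config-fragment keys combined to generate it; the new entry wires up the liger variant. A hedged sketch of how such a mapping can be consumed (the KEY_* values and the lookup helper are hypothetical; only read_configuration is visible in the hunk header):

# Hypothetical fragment keys; in the real script they identify YAML
# fragments that are merged into one sample configuration.
KEY_FAST_KERNELS = "fast-kernels"
KEY_FAST_KERNELS_LIGER = "fast-kernels-liger"

COMBINATIONS = [
    ("foak-fast-kernels", (KEY_FAST_KERNELS,)),
    ("foak-fast-kernels-liger", (KEY_FAST_KERNELS_LIGER,)),
]

def fragments_for(shortname):
    # return the fragment keys registered under a sample-config name
    for name, keys in COMBINATIONS:
        if name == shortname:
            return list(keys)
    raise KeyError(shortname)

assert fragments_for("foak-fast-kernels-liger") == ["fast-kernels-liger"]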
