From a3926bbbf6d0025b5c6076a280e6b91ebd08aada Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Thu, 16 Nov 2023 16:16:58 -0800 Subject: [PATCH] infV2 fix for OPT size variants (#4694) Co-authored-by: Jeff Rasley --- .github/workflows/nv-a6000.yml | 4 +- deepspeed/inference/v2/engine_factory.py | 4 ++ .../layer_container_base.py | 41 +++++++++++++++++++ .../v2/model_implementations/opt/container.py | 9 ++-- .../v2/model_implementations/opt/policy.py | 4 +- 5 files changed, 53 insertions(+), 9 deletions(-) diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml index e9c33c55f8ff..a2b99de488d5 100644 --- a/.github/workflows/nv-a6000.yml +++ b/.github/workflows/nv-a6000.yml @@ -36,7 +36,7 @@ jobs: python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - name: Install transformers run: | - git clone https://github.com/huggingface/transformers + git clone --depth=1 https://github.com/huggingface/transformers cd transformers git rev-parse --short HEAD python -m pip install . @@ -56,7 +56,7 @@ jobs: python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2_ops' unit/ --torch_ver="2.0" --cuda_ver="12" - name: MII unit tests run: | - git clone https://github.com/microsoft/DeepSpeed-MII.git + git clone --depth=1 https://github.com/microsoft/DeepSpeed-MII.git cd DeepSpeed-MII pip install .[dev] cd tests diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py index 8ff75cc52213..ecca9f3c1b34 100644 --- a/deepspeed/inference/v2/engine_factory.py +++ b/deepspeed/inference/v2/engine_factory.py @@ -91,6 +91,10 @@ def build_hf_engine(path: str, # get the policy # TODO: generalize this to other models if model_config.model_type == "opt": + if not model_config.do_layer_norm_before: + raise ValueError( + "Detected OPT-350m model. This model is not currently supported. If this is not the 350m model, please open an issue: https://github.com/microsoft/DeepSpeed-MII/issues" + ) policy = OPTPolicy(model_config, checkpoint_engine=checkpoint_engine) elif model_config.model_type == "llama": policy = Llama2Policy(model_config, checkpoint_engine=checkpoint_engine) diff --git a/deepspeed/inference/v2/model_implementations/layer_container_base.py b/deepspeed/inference/v2/model_implementations/layer_container_base.py index 98e3e0bb31ed..f26c87556665 100644 --- a/deepspeed/inference/v2/model_implementations/layer_container_base.py +++ b/deepspeed/inference/v2/model_implementations/layer_container_base.py @@ -3,6 +3,7 @@ # DeepSpeed Team +import re from typing import Type import torch @@ -277,6 +278,30 @@ def set_dependency(self, dep_name: str, dep_value: torch.Tensor) -> None: dep_name (str): The name of the dependency to set. dep_value (torch.Tensor): The value to set the dependency to. """ + + def get_dep_name_target(dep_name: str) -> str: + """ + Helper method for getting the target name for a dependency from the + mapping params. Tries to match exact string first, then looks for + wildcards and attempts regex matching. Will return empty string if + no match found. + """ + if dep_name in self.mapping_params: + # If we have an exact match, it's a direct mapping and we can + # immediately set the value. + return self.mapping_params[dep_name] + + matched_targets = [] + for key, target in self.mapping_params.items(): + regex_key = key.replace("*", ".*") + if re.match(regex_key, dep_name): + matched_targets.append(target) + if len(matched_targets) > 1: + raise ValueError(f"Multiple targets matched for dependency {dep_name}: {matched_targets}") + if matched_targets: + return matched_targets[0] + return "" + if dep_name in self.mapping_params: # If we have an exact match, it's a direct mapping and we can immediately set # the value. @@ -309,6 +334,22 @@ def set_dependency(self, dep_name: str, dep_value: torch.Tensor) -> None: target_dependency = getattr(target_param, target_dependency_name) target_dependency[target_idx] = dep_value return + + # TODO: Refactor this with the help of cmikeh2 + # We should be able to combine this with the wildcard matching above. + target = get_dep_name_target(dep_name) + if target: + # Convert single targets to a list for consistency + if isinstance(target, str): + target = [target] + + for target_name in target: + # Double setting doesn't set the attribute correctly, so we do a getattr then setattr + target_param_name, target_dependency_name = target_name.split(".") + target_param = getattr(self, target_param_name) + setattr(target_param, target_dependency_name, dep_value) + return + raise ValueError( "Could not find a mapping for dependency \"{}\". Check that it is included in the ``MAPPING_PARAMS``. See docstring for more on ``MAPPING_PARAMS``" .format(dep_name)) diff --git a/deepspeed/inference/v2/model_implementations/opt/container.py b/deepspeed/inference/v2/model_implementations/opt/container.py index 5b1c9ce4c8a3..5ddbbde3f141 100644 --- a/deepspeed/inference/v2/model_implementations/opt/container.py +++ b/deepspeed/inference/v2/model_implementations/opt/container.py @@ -87,9 +87,8 @@ class OPTNonTransformerContainer(LayerContainer): final_norm_b: NormParameter PARAM_MAPPING = { - "model.decoder.embed_tokens.weight": "word_emb.params", - "model.decoder.embed_positions.weight": "word_emb_pos.params", - "model.decoder.final_layer_norm.weight": "final_norm_w.params", - "model.decoder.final_layer_norm.bias": "final_norm_b.params", - "lm_head.weight": "word_unembed.params", + "*decoder.embed_tokens.weight": ["word_emb.params", "word_unembed.params"], + "*decoder.embed_positions.weight": "word_emb_pos.params", + "*decoder.final_layer_norm.weight": "final_norm_w.params", + "*decoder.final_layer_norm.bias": "final_norm_b.params", } diff --git a/deepspeed/inference/v2/model_implementations/opt/policy.py b/deepspeed/inference/v2/model_implementations/opt/policy.py index 002fab93b462..af5750260ead 100644 --- a/deepspeed/inference/v2/model_implementations/opt/policy.py +++ b/deepspeed/inference/v2/model_implementations/opt/policy.py @@ -21,10 +21,10 @@ def build_container_map(self) -> ContainerMap: transformer_containers = [OPTTransformerContainer(self.model) for _ in range(self.model.num_layers)] - map.set_transformer_params(['model.decoder.layers'], transformer_containers) + map.set_transformer_params(['model.decoder.layers', 'decoder.layers'], transformer_containers) map.set_non_transformer_params(OPTNonTransformerContainer(self.model)) - map.set_unmapped_params([]) + map.set_unmapped_params(['lm_head.weight']) return map