fixes for different OPT model size variants
mrwyattii committed Nov 16, 2023
1 parent 25e9cd5 commit 6c6c575
Showing 4 changed files with 36 additions and 11 deletions.
4 changes: 4 additions & 0 deletions deepspeed/inference/v2/engine_factory.py
@@ -91,6 +91,10 @@ def build_hf_engine(path: str,
# get the policy
# TODO: generalize this to other models
if model_config.model_type == "opt":
if not model_config.do_layer_norm_before:
raise ValueError(
"Detected OPT-350m model. This model is not currently supported. If this is not the 350m model, please open an issue: https://github.com/microsoft/DeepSpeed-MII/issues"
)
policy = OPTPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "llama":
policy = Llama2Policy(model_config, checkpoint_engine=checkpoint_engine)
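The new check keys off do_layer_norm_before, which the 350m variant's Hugging Face config sets to False (its layer norm placement differs from the other OPT sizes). A minimal sketch of inspecting that flag before handing a path to the engine factory, assuming the transformers package is available; the model id is only an example:

from transformers import AutoConfig

# Load the Hugging Face config and look at the fields the factory checks above.
config = AutoConfig.from_pretrained("facebook/opt-1.3b")
print(config.model_type)            # "opt"
print(config.do_layer_norm_before)  # True here; False for opt-350m, which is rejected above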
@@ -3,6 +3,7 @@

# DeepSpeed Team

import re
from typing import Type

import torch
@@ -277,11 +278,32 @@ def set_dependency(self, dep_name: str, dep_value: torch.Tensor) -> None:
dep_name (str): The name of the dependency to set.
dep_value (torch.Tensor): The value to set the dependency to.
"""
if dep_name in self.mapping_params:
# If we have an exact match, it's a direct mapping and we can immediately set
# the value.
target = self.mapping_params[dep_name]

def get_dep_name_target(dep_name: str) -> str:
"""
Helper method for getting the target name for a dependency from the
mapping params. Tries to match exact string first, then looks for
wildcards and attempts regex matching. Will return empty string if
no match found.
"""
if dep_name in self.mapping_params:
# If we have an exact match, it's a direct mapping and we can
# immediately set the value.
return self.mapping_params[dep_name]

matched_targets = []
for key, target in self.mapping_params.items():
regex_key = key.replace("*", ".*")
if re.match(regex_key, dep_name):
matched_targets.append(target)
if len(matched_targets) > 1:
raise ValueError(f"Multiple targets matched for dependency {dep_name}: {matched_targets}")
if matched_targets:
return matched_targets[0]
return ""

target = get_dep_name_target(dep_name)
if target:
# Convert single targets to a list for consistency
if isinstance(target, str):
target = [target]
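The helper added above tries an exact key first and only then falls back to treating * in a mapping key as a regex wildcard, raising if more than one wildcard key matches. A standalone sketch of that matching order with toy mapping keys (illustrative names, not the real OPT mapping):

import re

# Toy mapping: one exact key and two wildcard keys (names are illustrative only).
mapping_params = {
    "embed.weight": "emb.params",
    "*.attn.qkv.weight": "qkv_w.params",
    "*.weight": "generic_w.params",
}

def get_target(dep_name):
    if dep_name in mapping_params:
        # Exact matches win outright.
        return mapping_params[dep_name]
    matched = [t for k, t in mapping_params.items()
               if re.match(k.replace("*", ".*"), dep_name)]
    if len(matched) > 1:
        # Ambiguous wildcard matches are treated as an error, as in the diff above.
        raise ValueError(f"Multiple targets matched for dependency {dep_name}: {matched}")
    return matched[0] if matched else ""

print(get_target("embed.weight"))              # 'emb.params' (exact match)
print(get_target("layers.0.mlp.fc.weight"))    # 'generic_w.params' (single wildcard match)
print(get_target("embed_positions.bias"))      # '' (no mapping found)
try:
    get_target("layers.0.attn.qkv.weight")     # matches both wildcard keys
except ValueError as err:
    print(err)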
9 changes: 4 additions & 5 deletions deepspeed/inference/v2/model_implementations/opt/container.py
@@ -87,9 +87,8 @@ class OPTNonTransformerContainer(LayerContainer):
final_norm_b: NormParameter

PARAM_MAPPING = {
"model.decoder.embed_tokens.weight": "word_emb.params",
"model.decoder.embed_positions.weight": "word_emb_pos.params",
"model.decoder.final_layer_norm.weight": "final_norm_w.params",
"model.decoder.final_layer_norm.bias": "final_norm_b.params",
"lm_head.weight": "word_unembed.params",
"*decoder.embed_tokens.weight": ["word_emb.params", "word_unembed.params"],
"*decoder.embed_positions.weight": "word_emb_pos.params",
"*decoder.final_layer_norm.weight": "final_norm_w.params",
"*decoder.final_layer_norm.bias": "final_norm_b.params",
}
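The leading * lets the same mapping serve checkpoints whose weight names carry the model. prefix as well as those that start at decoder. directly, matching the two prefixes the policy below registers for transformer layers; the output embedding (word_unembed) is now sourced from the tied embed_tokens weight, and lm_head.weight is instead listed as unmapped. A quick check of the two prefix styles against one of the new keys, using the same * to .* regex conversion as the matching helper (a sketch, not library code):

import re

pattern = "*decoder.embed_tokens.weight".replace("*", ".*")

# Both prefix styles seen across OPT checkpoints resolve against the wildcard key.
for name in ("model.decoder.embed_tokens.weight", "decoder.embed_tokens.weight"):
    print(name, "->", bool(re.match(pattern, name)))   # True for both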
4 changes: 2 additions & 2 deletions deepspeed/inference/v2/model_implementations/opt/policy.py
@@ -21,10 +21,10 @@ def build_container_map(self) -> ContainerMap:

transformer_containers = [OPTTransformerContainer(self.model) for _ in range(self.model.num_layers)]

map.set_transformer_params(['model.decoder.layers'], transformer_containers)
map.set_transformer_params(['model.decoder.layers', 'decoder.layers'], transformer_containers)

map.set_non_transformer_params(OPTNonTransformerContainer(self.model))

map.set_unmapped_params([])
map.set_unmapped_params(['lm_head.weight'])

return map
