Add support for Phi-3 small to FastGen #5614

Merged · 10 commits · Nov 15, 2024
3 changes: 3 additions & 0 deletions deepspeed/inference/v2/engine_factory.py
@@ -21,6 +21,7 @@
FalconPolicy,
PhiPolicy,
Phi3Policy,
Phi3SmallPolicy,
QwenPolicy,
Qwen2Policy,
Qwen2MoePolicy,
@@ -123,6 +124,8 @@ def build_hf_engine(path: str,
policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "phi3":
policy = Phi3Policy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "phi3small":
policy = Phi3SmallPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "qwen":
policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "qwen2":
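For orientation, a minimal sketch of how the new dispatch branch is exercised; the checkpoint path and the default-constructed RaggedInferenceEngineConfig are assumptions, and build_hf_engine may take additional arguments:

from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig
from deepspeed.inference.v2.engine_factory import build_hf_engine

# A checkpoint whose HF config reports model_type == "phi3small" now resolves
# to Phi3SmallPolicy instead of falling through to the unsupported-model error.
engine = build_hf_engine(
    path="/path/to/phi-3-small-checkpoint",       # assumed local checkpoint path
    engine_config=RaggedInferenceEngineConfig(),  # assumed: default config is sufficient
)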
1 change: 1 addition & 0 deletions deepspeed/inference/v2/model_implementations/__init__.py
@@ -16,6 +16,7 @@
from .falcon import *
from .phi import *
from .phi3 import *
from .phi3small import *
from .qwen import *
from .qwen_v2 import *
from .qwen_v2_moe import *
6 changes: 6 additions & 0 deletions deepspeed/inference/v2/model_implementations/phi3small/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .policy import Phi3SmallPolicy
88 changes: 88 additions & 0 deletions deepspeed/inference/v2/model_implementations/phi3small/containers.py
@@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Create a container object to save model-specific tensors using the policy file above.

from ..common_parameters import *
from ..layer_container_base import LayerContainer
'''
# HF Phi-3 model looks like this:
Phi3SmallForCausalLM(
(model): Phi3Model(
(embed_tokens): Embedding(32064, 3072)
(embed_dropout): Dropout(p=0.0, inplace=False)
(layers): ModuleList(
(0-31): 32 x Phi3DecoderLayer(
(self_attn): Phi3Attention(
(o_proj): Linear(in_features=3072, out_features=3072, bias=False)
(qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
(rotary_emb): Phi3RotaryEmbedding()
)
(mlp): PhiMLP(
(gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
(down_proj): Linear(in_features=16384, out_features=3072, bias=False)
(activation_fn): SiLU()
)
(input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
(resid_attn_dropout): Dropout(p=0.0)
(resid_mlp_dropout): Dropout(p=0.0)
(post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
)
)
(final_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
)
(lm_head): Linear(in_features=3072, out_features=32064, bias=False)
)
'''


class Phi3SmallTransformerContainer(LayerContainer):
"""
Transformer layer container for the Phi-3-small model.
"""
qkv_w: FusedQKVParameter
qkv_b: FusedQKVParameter
attn_out_w: AttentionOutputParameter
attn_out_b: AttentionOutputParameter
mlp_1_w: MLP1Parameter
mlp_1_b: MLP1Parameter
mlp_2_w: MLP2Parameter
mlp_2_b: MLP2Parameter
attn_norm_gamma: NormParameter
attn_norm_beta: NormParameter
mlp_norm_gamma: NormParameter
mlp_norm_beta: NormParameter

PARAM_MAPPING = {
"self_attn.query_key_value.weight": "qkv_w.params",
"self_attn.query_key_value.bias": "qkv_b.params",
"self_attn.dense.weight": "attn_out_w.params",
"self_attn.dense.bias": "attn_out_b.params",
"mlp.up_proj.weight": "mlp_1_w.params",
"mlp.up_proj.bias": "mlp_1_b.params",
"mlp.down_proj.weight": "mlp_2_w.params",
"mlp.down_proj.bias": "mlp_2_b.params",
"input_layernorm.weight": "attn_norm_gamma.params",
"input_layernorm.bias": "attn_norm_beta.params",
"post_attention_layernorm.weight": "mlp_norm_gamma.params",
"post_attention_layernorm.bias": "mlp_norm_beta.params",
}


class Phi3SmallNonTransformerContainer(LayerContainer):
"""
Non-Transformer layer container for the Phi-3-small model.
"""
word_emb: EmbeddingParameter
final_norm_gamma: NormParameter
final_norm_beta: NormParameter
word_unembed: UnembedParameter

PARAM_MAPPING = {
"model.embed_tokens.weight": ["word_emb.params", "word_unembed.params"],
"model.final_layernorm.weight": "final_norm_gamma.params",
"model.final_layernorm.bias": "final_norm_beta.params",
}
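As a quick cross-check of the mapping above, the per-layer parameter names of an HF Phi-3-small checkpoint can be listed directly; the model id is an assumption, and trust_remote_code is needed because the checkpoint ships custom modeling code:

from transformers import AutoModelForCausalLM

# Assumed checkpoint id; any Phi-3-small variant should expose the same parameter names.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-small-8k-instruct",
    trust_remote_code=True,
)

# Expect self_attn.query_key_value.*, self_attn.dense.*, mlp.up_proj.*,
# mlp.down_proj.*, input_layernorm.* and post_attention_layernorm.* per layer,
# matching the keys in PARAM_MAPPING above.
layer0 = sorted(name for name, _ in model.named_parameters() if name.startswith("model.layers.0."))
print("\n".join(layer0))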
222 changes: 222 additions & 0 deletions deepspeed/inference/v2/model_implementations/phi3small/model.py
@@ -0,0 +1,222 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Iterable, Optional, Tuple

import torch

import deepspeed.comm as dist

from ...allocator import empty_from
from ...inference_utils import ActivationType, DtypeEnum
from .. import *
from ...modules.configs import *
from ...modules.interfaces import *
from ...ragged import RaggedBatchWrapper

from .containers import Phi3SmallNonTransformerContainer, Phi3SmallTransformerContainer


class Phi3SmallInferenceModel(DSTransformerModelBase):
"""
Inference model implementation for ragged batching for Phi-3-small models.
"""

_non_transformer: Optional[Phi3SmallNonTransformerContainer]
"""
Embed + unembed container. Specializing the type annotation.
"""

_transformer: Optional[Iterable[Phi3SmallTransformerContainer]]
"""
Per-layer transformer container. Specializing the type annotation.
"""
"""
Properties inherited from `DSInferenceModelBase`
"""

@property
def max_sequence_length(self) -> int:
return self._config.max_seq_length

"""
Properties inherited from `DSTransformerModelBase`
"""

@property
def num_layers(self) -> int:
return self._config.num_hidden_layers

@property
def model_dim(self) -> int:
return self._config.hidden_size

@property
def vocab_size(self) -> int:
return self._config.vocab_size

@property
def head_size(self) -> int:
return self.model_dim // self.n_heads

@property
def n_heads(self) -> int:
return self._config.num_attention_heads

@property
def intermediate_dim(self) -> int:
return self._config.intermediate_size

@property
def n_heads_kv(self) -> int:
return self._config.num_key_value_heads

@property
def activation_dtype(self) -> DtypeEnum:
if self._config.torch_dtype == torch.float16:
return DtypeEnum.fp16
elif self._config.torch_dtype == torch.bfloat16:
return DtypeEnum.bf16
else:
raise NotImplementedError("Only fp16 and bf16 are supported")

@property
def mlp_activation_fn(self) -> ActivationType:
activation = self._config.hidden_act.lower()
if activation == "gelu":
return ActivationType.GEGLU
elif activation == "relu":
return ActivationType.ReGLU
elif activation == "gegelu":
return ActivationType.GEGLU
elif activation == "silu":
return ActivationType.SiGLU
else:
raise NotImplementedError(f"Activation {activation} not supported")

@property
def norm_type(self) -> NormTypeEnum:
return NormTypeEnum.LayerNorm

@property
def positional_embedding_type(self) -> PositionalEmbeddingType:
return PositionalEmbeddingType.rotate_half

@property
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
return RotateHalfConfig(theta_base=self._config.rope_embedding_base)

@property
def mup_embedding_multiplier(self) -> float:
return 10.0

"""
Forward implementations
"""

def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor:
"""
Performs the embedding lookup prior to running the transformer of the model.
Arguments:
ragged_batch (RaggedBatchWrapper): The batch to embed.
Returns:
torch.Tensor: The embedded batch.
"""
embed = self.embed(ragged_batch, self._non_transformer.word_emb)

if embed.shape[-1] != self.model_dim:
raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}")

if self.mup_embedding_multiplier > 0.0:
embed = embed * self.mup_embedding_multiplier

return embed

def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor,
ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Executes one (slightly offset) layer of the transformer. This implementation does a peek-ahead
optimization to fuse the layer norm of the next layer into the current layer.
Arguments:
layer_idx (int): The index of the layer to execute.
residual (torch.Tensor): The residual tensor from the previous layer.
hidden_states (torch.Tensor): The hidden states from the previous layer. This is the
hidden states after pre normalization.
ragged_batch_info (RaggedBatchWrapper): The batch metadata.
"""
cur_params = self._transformer[layer_idx]
kv_cache = self.state_manager.get_cache(layer_idx)

hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b)
hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info)
hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=cur_params.attn_out_b)

if self.tp_size > 1:
dist.all_reduce(hidden_states, group=self._base_mp_group)

residual, hidden_states = self.norm(residual,
hidden_states,
cur_params.mlp_norm_gamma,
beta=cur_params.mlp_norm_beta)

hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None)
hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None)

if self.tp_size > 1:
dist.all_reduce(hidden_states, group=self._base_mp_group)

if layer_idx != self.num_layers - 1:
next_params = self._transformer[layer_idx + 1]
residual, hidden_states = self.norm(residual,
hidden_states,
next_params.attn_norm_gamma,
beta=next_params.attn_norm_beta)
else:
# On last layer, we just need to perform the residual add. Adding into the residual
# here is safe.
residual.add_(hidden_states)

return residual, hidden_states

def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor:
"""
Performs unembedding of the hidden states to logits. This will only sample the final
token of each sequence.
"""

logits = self.unembed(hidden_states,
self._non_transformer.word_unembed,
ragged_batch_info,
gamma=self._non_transformer.final_norm_gamma,
beta=self._non_transformer.final_norm_beta)

if self.tp_size > 1:
comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size))

dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group)

full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size))

return full_logits
else:
return logits

def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
residual = self._forward_embed(wrapped_batch)

residual, hidden_states = self.norm(residual,
None,
gamma=self._transformer[0].attn_norm_gamma,
beta=self._transformer[0].attn_norm_beta)

for layer_idx in range(self.num_layers):
residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
wrapped_batch)

return self._forward_unembed(residual, wrapped_batch)
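For readers unfamiliar with the fused norm contract relied on in _forward_transformer_layer and forward above, the following is a conceptual sketch in plain PyTorch of what self.norm is assumed to return: the updated residual stream and the normalized activations that feed the next block. This is illustrative only, not DeepSpeed's actual kernel:

import torch
import torch.nn.functional as F

def fused_residual_norm(residual, hidden_states, gamma, beta, eps=1e-5):
    # Fold the block output into the residual stream (skipped on the very first
    # call in forward(), where hidden_states is None), then layer-normalize to
    # produce the input for the next attention or MLP block.
    if hidden_states is not None:
        residual = residual + hidden_states
    normed = F.layer_norm(residual, residual.shape[-1:], weight=gamma, bias=beta, eps=eps)
    return residual, normed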
30 changes: 30 additions & 0 deletions deepspeed/inference/v2/model_implementations/phi3small/policy.py
@@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Any

from ...config_v2 import RaggedInferenceEngineConfig
from ..inference_policy_base import ContainerMap, InferenceV2Policy
from .containers import Phi3SmallNonTransformerContainer, Phi3SmallTransformerContainer
from .model import Phi3SmallInferenceModel


class Phi3SmallPolicy(InferenceV2Policy):

def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Phi3SmallInferenceModel:
return Phi3SmallInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group)

def build_container_map(self) -> ContainerMap:
map = ContainerMap()

transformer_containers = [Phi3SmallTransformerContainer(self.model) for _ in range(self.model.num_layers)]

map.set_transformer_params(['model.layers'], transformer_containers)

map.set_non_transformer_params(Phi3SmallNonTransformerContainer(self.model))

map.set_unmapped_params([])

return map
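End to end, a Phi-3-small checkpoint should be servable through FastGen via DeepSpeed-MII once this policy is picked up; a minimal sketch, assuming the microsoft/Phi-3-small-8k-instruct model id and default generation settings:

import mii

# Sketch only: the model id and generation arguments are illustrative.
pipe = mii.pipeline("microsoft/Phi-3-small-8k-instruct")
responses = pipe(["What does ragged batching do?"], max_new_tokens=64)
print(responses[0].generated_text)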