Commit

Merge branch 'master' into lyj/chatglm2
loadams authored Jul 20, 2024
2 parents b026e4a + 879c6cd commit 7beff94
Showing 16 changed files with 356 additions and 14 deletions.
1 change: 1 addition & 0 deletions blogs/deepspeed-fastgen/README.md
@@ -231,6 +231,7 @@ We currently support the following model architectures in this alpha release of
* [Falcon](https://huggingface.co/models?other=falcon)
* [Mixtral](https://huggingface.co/models?other=mixtral)
* [Phi-2](https://huggingface.co/models?other=phi-msft)
* [Phi-3](https://huggingface.co/models?other=phi3)
* [Qwen](https://huggingface.co/models?other=qwen)

All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer.
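Once a supported architecture (such as the newly added Phi-3) is available on the Hugging Face Hub, it can be served through DeepSpeed-MII in the same way as the other FastGen models. A minimal sketch, with a representative Phi-3 checkpoint name assumed for illustration:

```python
import mii

# Checkpoint name is illustrative; any Hugging Face model with a supported
# architecture can be passed here.
pipe = mii.pipeline("microsoft/Phi-3-mini-4k-instruct")
response = pipe(["DeepSpeed is"], max_new_tokens=64)
print(response)
```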
3 changes: 3 additions & 0 deletions deepspeed/inference/v2/engine_factory.py
@@ -20,6 +20,7 @@
MixtralPolicy,
FalconPolicy,
PhiPolicy,
Phi3Policy,
QwenPolicy,
Qwen2Policy,
)
@@ -119,6 +120,8 @@ def build_hf_engine(path: str,
policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "phi":
policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "phi3":
policy = Phi3Policy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "qwen":
policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "qwen2":
@@ -18,7 +18,7 @@ class BlockedRotaryEmbeddings(DSKernelBase):
"""

supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_head_sizes = [64, 80, 128]
supported_head_sizes = [64, 80, 96, 128]
supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71]

def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype, rotary_dim: int,
@@ -223,6 +223,8 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache,
LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 64); \
} else if (head_size == 80) { \
LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 80); \
} else if (head_size == 96) { \
LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 96); \
} else if (head_size == 128) { \
LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 128); \
} else { \
@@ -326,6 +328,8 @@ INSTANTIATE_KV_ROTARY_KERNEL(__nv_bfloat16)
LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 64); \
} else if (head_size == 80) { \
LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 80); \
} else if (head_size == 96) { \
LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 96); \
} else if (head_size == 128) { \
LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 128); \
} else { \
@@ -23,7 +23,7 @@ class BlockedTrainedRotaryEmbeddings(DSKernelBase):
"""

supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_head_sizes = [64, 80, 128]
supported_head_sizes = [64, 80, 96, 128]
supported_q_ratios = [1, 2, 4, 5, 8]

def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None:
@@ -23,7 +23,7 @@ class LinearBlockedKVCopy(DSKernelBase):
"""

supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
supported_head_sizes = [64, 80, 128]
supported_head_sizes = [64, 80, 96, 128]
supported_q_ratios = [1, 2, 4, 5, 8]

def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None:
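For context on the new 96 entries above: Phi-3-mini's attention geometry lands exactly on a head size of 96, which the blocked KV rotary/copy kernels previously did not support. A quick sanity check of the arithmetic, assuming the published Phi-3-mini configuration of 32 query heads and 32 KV heads (only the 3072 and 9216 dimensions appear in the model printout later in this commit):

```python
# Assumed values (published Phi-3-mini config), not taken from this diff:
hidden_size = 3072
n_q_heads = 32
n_kv_heads = 32

head_size = hidden_size // n_q_heads                 # 96 -> the newly supported head size
qkv_out = (n_q_heads + 2 * n_kv_heads) * head_size   # 9216 -> matches qkv_proj's out_features
assert (head_size, qkv_out) == (96, 9216)
```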
1 change: 1 addition & 0 deletions deepspeed/inference/v2/model_implementations/__init__.py
@@ -15,5 +15,6 @@
from .mixtral import *
from .falcon import *
from .phi import *
from .phi3 import *
from .qwen import *
from .qwen_v2 import *
@@ -66,6 +66,24 @@ def finalize(self) -> torch.Tensor:
return self.inference_model.transform_mlp_1_param(fused_param)


class FusedGatedMLPParameter(ParameterBase):
"""
Gated MLP projection container.
"""

params: torch.Tensor
"""
Weight parameter for the fused gating and non-gating weight parameters.
"""

def finalize(self) -> torch.Tensor:
gate_params = self.params[:self.params.shape[0] // 2]
up_params = self.params[self.params.shape[0] // 2:]
total_neurons = gate_params.shape[0] + up_params.shape[0]
fused_param = torch.cat([gate_params, up_params], dim=-1).reshape(total_neurons, -1)
return self.inference_model.transform_mlp_1_param(fused_param)


class MLP2Parameter(ParameterBase):
"""
Second MLP projection weight container. This performs a straight pass-through to the
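As a toy illustration of the `FusedGatedMLPParameter.finalize` logic added above: the top half of the fused weight's rows is treated as the gate projection, the bottom half as the up projection, and the cat-plus-reshape interleaves them row by row before the result is handed to the model's `transform_mlp_1_param` (omitted here). This is a minimal sketch with made-up shapes, not DeepSpeed API usage:

```python
import torch

# Toy fused gate_up weight: 4 output rows (2 gate + 2 up), model_dim = 3.
params = torch.arange(12, dtype=torch.float32).reshape(4, 3)

gate_params = params[:params.shape[0] // 2]    # rows 0-1: gate projection
up_params = params[params.shape[0] // 2:]      # rows 2-3: up (non-gating) projection
total_neurons = gate_params.shape[0] + up_params.shape[0]

fused = torch.cat([gate_params, up_params], dim=-1).reshape(total_neurons, -1)
# Rows are now interleaved: gate[0], up[0], gate[1], up[1].
print(fused)
```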
6 changes: 6 additions & 0 deletions deepspeed/inference/v2/model_implementations/phi3/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .policy import Phi3Policy
75 changes: 75 additions & 0 deletions deepspeed/inference/v2/model_implementations/phi3/containers.py
@@ -0,0 +1,75 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Create a container object to save model-specific tensors using the policy file above.

from ..common_parameters import *
from ..layer_container_base import LayerContainer
'''
# HF Phi-3 model looks like this:
Phi3ForCausalLM(
(model): Phi3Model(
(embed_tokens): Embedding(32064, 3072)
(embed_dropout): Dropout(p=0.0, inplace=False)
(layers): ModuleList(
(0-31): 32 x Phi3DecoderLayer(
(self_attn): Phi3Attention(
(o_proj): Linear(in_features=3072, out_features=3072, bias=False)
(qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
(rotary_emb): Phi3RotaryEmbedding()
)
        (mlp): Phi3MLP(
(gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
(down_proj): Linear(in_features=16384, out_features=3072, bias=False)
(activation_fn): SiLU()
)
(input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
(resid_attn_dropout): Dropout(p=0.0)
(resid_mlp_dropout): Dropout(p=0.0)
(post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
)
)
(final_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
)
(lm_head): Linear(in_features=3072, out_features=32064, bias=False)
)
'''


class Phi3TransformerContainer(LayerContainer):
"""
    Transformer layer container for the Phi-3 model.
"""
qkv_w: FusedQKVParameter
attn_out_w: AttentionOutputParameter
mlp_1_w: FusedGatedMLPParameter
mlp_2_w: MLP2Parameter
attn_norm_gamma: NormParameter
mlp_norm_gamma: NormParameter

PARAM_MAPPING = {
"self_attn.qkv_proj.weight": "qkv_w.params",
"self_attn.o_proj.weight": "attn_out_w.params",
"mlp.gate_up_proj.weight": "mlp_1_w.params",
"mlp.down_proj.weight": "mlp_2_w.params",
"input_layernorm.weight": "attn_norm_gamma.params",
"post_attention_layernorm.weight": "mlp_norm_gamma.params",
}


class Phi3NonTransformerContainer(LayerContainer):
"""
    Non-Transformer layer container for the Phi-3 model.
"""
word_emb: EmbeddingParameter
word_unembed_w: UnembedParameter
final_norm_gamma: NormParameter

PARAM_MAPPING = {
"model.embed_tokens.weight": "word_emb.params",
"model.norm.weight": "final_norm_gamma.params",
"lm_head.weight": "word_unembed_w.params",
}
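The module printout at the top of containers.py above can be reproduced with a plain transformers load along these lines; the checkpoint id and the `trust_remote_code` flag are assumptions for illustration, not part of this commit:

```python
from transformers import AutoModelForCausalLM

# Any Phi-3 checkpoint with this architecture should print a similar module tree.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",  # representative checkpoint, assumed
    trust_remote_code=True,
)
print(model)
```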
204 changes: 204 additions & 0 deletions deepspeed/inference/v2/model_implementations/phi3/model.py
@@ -0,0 +1,204 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Iterable, Optional, Tuple

import torch

import deepspeed.comm as dist

from ...allocator import empty_from
from ...inference_utils import ActivationType, DtypeEnum
from .. import *
from ...modules.configs import *
from ...modules.interfaces import *
from ...ragged import RaggedBatchWrapper

from .containers import Phi3NonTransformerContainer, Phi3TransformerContainer


class Phi3InferenceModel(DSTransformerModelBase):
"""
    Inference model implementation for ragged batching for Phi-3 models.
"""

_non_transformer: Optional[Phi3NonTransformerContainer]
"""
Embed + unembed container. Specializing the type annotation.
"""

_transformer: Optional[Iterable[Phi3TransformerContainer]]
"""
Per-layer transformer container. Specializing the type annotation.
"""
"""
Properties inherited from `DSInferenceModelBase`
"""

@property
def max_sequence_length(self) -> int:
return self._config.max_seq_length

"""
Properties inherited from `DSTransformerModelBase`
"""

@property
def num_layers(self) -> int:
return self._config.num_hidden_layers

@property
def model_dim(self) -> int:
return self._config.hidden_size

@property
def vocab_size(self) -> int:
return self._config.vocab_size

@property
def head_size(self) -> int:
return self.model_dim // self.n_heads

@property
def n_heads(self) -> int:
return self._config.num_attention_heads

@property
def intermediate_dim(self) -> int:
return self._config.intermediate_size

@property
def n_heads_kv(self) -> int:
return self._config.num_key_value_heads

@property
def activation_dtype(self) -> DtypeEnum:
if self._config.torch_dtype == torch.float16:
return DtypeEnum.fp16
elif self._config.torch_dtype == torch.bfloat16:
return DtypeEnum.bf16
else:
raise NotImplementedError("Only fp16 and bf16 are supported")

@property
def mlp_activation_fn(self) -> ActivationType:
activation = self._config.hidden_act.lower()
if activation == "gelu":
return ActivationType.GEGLU
elif activation == "relu":
return ActivationType.ReGLU
elif activation == "gegelu":
return ActivationType.GEGLU
elif activation == "silu":
return ActivationType.SiGLU
else:
raise NotImplementedError(f"Activation {activation} not supported")

@property
def norm_type(self) -> NormTypeEnum:
return NormTypeEnum.RMSNorm

@property
def positional_embedding_type(self) -> PositionalEmbeddingType:
return PositionalEmbeddingType.rotate_half

@property
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
return RotateHalfConfig(theta_base=self._config.rope_theta)

"""
Forward implementations
"""

def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor:
"""
Performs the embedding lookup prior to running the transformer of the model.
Arguments:
ragged_batch (RaggedBatchWrapper): The batch to embed.
Returns:
torch.Tensor: The embedded batch.
"""
embed = self.embed(ragged_batch, self._non_transformer.word_emb)

if embed.shape[-1] != self.model_dim:
raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}")

return embed

def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor,
ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]:
"""
        Executes one (slightly offset) layer of the transformer. This implementation does a peek-ahead
optimization to fuse the layer norm of the next layer into the current layer.
Arguments:
layer_idx (int): The index of the layer to execute.
residual (torch.Tensor): The residual tensor from the previous layer.
hidden_states (torch.Tensor): The hidden states from the previous layer. This is the
hidden states after pre normalization.
ragged_batch_info (RaggedBatchWrapper): The batch metadata.
"""
cur_params = self._transformer[layer_idx]
kv_cache = self.state_manager.get_cache(layer_idx)

hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None)
hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info)
hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None)

if self.tp_size > 1:
dist.all_reduce(hidden_states, group=self._base_mp_group)

residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None)

hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None)
hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None)

if self.tp_size > 1:
dist.all_reduce(hidden_states, group=self._base_mp_group)

if layer_idx != self.num_layers - 1:
next_params = self._transformer[layer_idx + 1]
residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None)
else:
# On last layer, we just need to perform the residual add. Adding into the residual
# here is safe.
residual.add_(hidden_states)

return residual, hidden_states

def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor:
"""
Performs unembedding of the hidden states to logits. This will only sample the final
token of each sequence.
"""
logits = self.unembed(hidden_states,
self._non_transformer.word_unembed_w,
ragged_batch_info,
gamma=self._non_transformer.final_norm_gamma)

if self.tp_size > 1:
comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size))

dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group)

full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size))

return full_logits
else:
return logits

def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
residual = self._forward_embed(wrapped_batch)

residual, hidden_states = self.norm(residual, None, gamma=self._transformer[0].attn_norm_gamma, beta=None)

for layer_idx in range(self.num_layers):
residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
wrapped_batch)

return self._forward_unembed(residual, wrapped_batch)
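To summarize the control flow implemented by `forward` and `_forward_transformer_layer` above: each layer receives already-normalized hidden states, and on its way out it applies the next layer's attention norm as part of its own residual add (the peek-ahead fusion mentioned in the docstring). The sketch below is pure pseudocode with hypothetical helpers (`norm`, `add_and_norm`, `layer.attention`, `layer.mlp`), not the DeepSpeed module API:

```python
# Pseudocode sketch of the peek-ahead pre-norm schedule (hypothetical helpers).
def forward_sketch(layers, norm, add_and_norm, embed_out):
    residual = embed_out
    hidden = norm(residual, layers[0].attn_norm_gamma)         # layer 0's input norm, done up front
    for i, layer in enumerate(layers):
        hidden = layer.attention(hidden)                        # qkv -> attn -> attn_out
        residual, hidden = add_and_norm(residual, hidden, layer.mlp_norm_gamma)
        hidden = layer.mlp(hidden)                              # mlp_1 -> mlp_2
        if i + 1 < len(layers):
            # Fuse the next layer's input norm into this layer's residual add.
            residual, hidden = add_and_norm(residual, hidden, layers[i + 1].attn_norm_gamma)
        else:
            residual = residual + hidden                        # last layer: residual add only
    return residual                                             # unembedding consumes the residual
```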