Commit: Merge branch 'master' into gma/launch_opbuilder_detection
Showing 7 changed files with 353 additions and 12 deletions.
@@ -16,3 +16,4 @@
 from .falcon import *
 from .phi import *
 from .qwen import *
+from .qwen_v2 import *
deepspeed/inference/v2/model_implementations/qwen_v2/__init__.py
6 changes: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .policy import Qwen2Policy
deepspeed/inference/v2/model_implementations/qwen_v2/container.py
82 changes: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Create a container object to save model-specific tensors using the policy file above.

from ..common_parameters import *
from ..layer_container_base import LayerContainer
'''
# HF Qwen2 model looks like this:
Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (down_proj): Linear(in_features=2816, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (lm_head): Linear(in_features=1024, out_features=151936, bias=False)
)
'''


class Qwen2TransformerContainer(LayerContainer):
    """
    Transformer layer container for the Qwen2 model.
    """
    qkv_w: UnfusedQKVParameter
    qkv_b: UnfusedQKVParameter
    attn_out_w: AttentionOutputParameter
    mlp_1_w: GatedMLPParameter
    mlp_2_w: MLP2Parameter
    attn_norm_gamma: NormParameter
    mlp_norm_gamma: NormParameter

    PARAM_MAPPING = {
        "self_attn.q_proj.weight": "qkv_w.q_params",
        "self_attn.k_proj.weight": "qkv_w.k_params",
        "self_attn.v_proj.weight": "qkv_w.v_params",
        "self_attn.q_proj.bias": "qkv_b.q_params",
        "self_attn.k_proj.bias": "qkv_b.k_params",
        "self_attn.v_proj.bias": "qkv_b.v_params",
        "self_attn.o_proj.weight": "attn_out_w.params",
        "mlp.gate_proj.weight": "mlp_1_w.gate_params",
        "mlp.up_proj.weight": "mlp_1_w.up_params",
        "mlp.down_proj.weight": "mlp_2_w.params",
        "input_layernorm.weight": "attn_norm_gamma.params",
        "post_attention_layernorm.weight": "mlp_norm_gamma.params",
    }


class Qwen2NonTransformerContainer(LayerContainer):
    """
    Non-Transformer layer container for the Qwen2 model.
    """
    word_emb: EmbeddingParameter
    word_unembed: UnembedParameter
    final_norm: NormParameter

    PARAM_MAPPING = {
        "model.embed_tokens.weight": "word_emb.params",
        "model.norm.weight": "final_norm.params",
        "lm_head.weight": "word_unembed.params",
    }
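For orientation, the PARAM_MAPPING values above follow a `<container attribute>.<sub-parameter>` convention: each Hugging Face tensor name (with its `model.layers.<i>.` prefix stripped) is routed into one slot of a possibly fused inference parameter, e.g. the three unfused Q/K/V projections all land in `qkv_w`. A minimal, hypothetical sketch of that routing; the `route` helper below is illustrative only and not part of this commit or of DeepSpeed:

# Illustrative only: mimic how a PARAM_MAPPING entry names its destination.
# Keys are per-layer tensor-name suffixes, values are "<container attr>.<sub-param>".
PARAM_MAPPING = {
    "self_attn.q_proj.weight": "qkv_w.q_params",
    "self_attn.k_proj.weight": "qkv_w.k_params",
    "self_attn.v_proj.weight": "qkv_w.v_params",
    "mlp.down_proj.weight": "mlp_2_w.params",
}


def route(checkpoint_name: str) -> tuple:
    """Return (container_attribute, sub_parameter) for a per-layer tensor name."""
    # Strip the "model.layers.<i>." prefix carried by real checkpoint keys.
    suffix = checkpoint_name.split(".", 3)[-1]
    attr, sub = PARAM_MAPPING[suffix].split(".")
    return attr, sub


print(route("model.layers.0.self_attn.k_proj.weight"))  # ('qkv_w', 'k_params')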
deepspeed/inference/v2/model_implementations/qwen_v2/model.py
221 changes: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Iterable, Optional, Tuple

import torch

import deepspeed.comm as dist

from ...allocator import empty_from
from ...inference_utils import ActivationType, DtypeEnum
from .. import *
from ...modules.configs import *
from ...modules.interfaces import *
from ...modules import heuristics
from ...ragged import RaggedBatchWrapper

from .container import Qwen2NonTransformerContainer, Qwen2TransformerContainer


class Qwen2InferenceModel(DSTransformerModelBase):
    """
    Inference model implementation for ragged batching for Qwen2 models.
    """

    _non_transformer: Optional[Qwen2NonTransformerContainer]
    """
    Embed + unembed container. Specializing the type annotation.
    """

    _transformer: Optional[Iterable[Qwen2TransformerContainer]]
    """
    Per-layer transformer container. Specializing the type annotation.
    """
    """
    Properties inherited from `DSInferenceModelBase`
    """

    @property
    def max_sequence_length(self) -> int:
        return self._config.max_seq_length

    """
    Properties inherited from `DSTransformerModelBase`
    """

    @property
    def num_layers(self) -> int:
        return self._config.num_hidden_layers

    @property
    def model_dim(self) -> int:
        return self._config.hidden_size

    @property
    def vocab_size(self) -> int:
        return self._config.vocab_size

    @property
    def head_size(self) -> int:
        return self.model_dim // self.n_heads

    @property
    def n_heads(self) -> int:
        return self._config.num_attention_heads

    @property
    def intermediate_dim(self) -> int:
        return self._config.intermediate_size

    @property
    def n_heads_kv(self) -> int:
        return self._config.num_key_value_heads

    @property
    def activation_dtype(self) -> DtypeEnum:
        # TODO(ZonePG): bf16 inference results may differ from Hugging Face bf16,
        # because in rms_norm Qwen2 still uses float() instead of bf16.
        # if self._config.torch_dtype == torch.float16:
        #     return DtypeEnum.fp16
        # elif self._config.torch_dtype == torch.bfloat16:
        #     return DtypeEnum.bf16
        # else:
        #     raise NotImplementedError("Only fp16 and bf16 are supported")
        return DtypeEnum.fp16

    @property
    def mlp_activation_fn(self) -> ActivationType:
        return ActivationType.SiGLU

    @property
    def norm_type(self) -> NormTypeEnum:
        return NormTypeEnum.RMSNorm

    @property
    def positional_embedding_type(self) -> PositionalEmbeddingType:
        return PositionalEmbeddingType.rotate_half

    @property
    def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
        return RotateHalfConfig(theta_base=self._config.rope_theta)

    def make_norm_layer(self) -> None:
        """
        Instantiates the normalization layer for the model. This sets the `self.norm` attribute.

        TODO(cmikeh2): In the future we'll distinguish between the different norm objects,
        but for now we'll just use the same one for all of them.
        """
        norm_config = DSNormConfig(
            max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
            type=self.norm_type,
            channels=self.model_dim,
            residual_dtype=self.activation_dtype,
            input_dtype=self.activation_dtype,
            output_dtype=self.activation_dtype,
            eps=self._config.rms_norm_eps,
        )

        self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config)

    """
    Forward implementations
    """

    def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor:
        """
        Performs the embedding lookup prior to running the transformer of the model.

        Arguments:
            ragged_batch (RaggedBatchWrapper): The batch to embed.

        Returns:
            torch.Tensor: The embedded batch.
        """
        embed = self.embed(ragged_batch, self._non_transformer.word_emb)

        if embed.shape[-1] != self.model_dim:
            raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}")

        return embed

    def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor,
                                   ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Executes one (slightly offset) layer of the transformer. This implementation does a peek-ahead
        optimization to fuse the layer norm of the next layer into the current layer.

        Arguments:
            layer_idx (int): The index of the layer to execute.
            residual (torch.Tensor): The residual tensor from the previous layer.
            hidden_states (torch.Tensor): The hidden states from the previous layer. This is the
                hidden states after pre normalization.
            ragged_batch_info (RaggedBatchWrapper): The batch metadata.
        """
        # TODO(cmikeh2): Distribute ragged_batch_info to all modules

        cur_params = self._transformer[layer_idx]
        kv_cache = self.state_manager.get_cache(layer_idx)

        hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b)
        hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info)
        hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None)

        if self.tp_size > 1:
            dist.all_reduce(hidden_states, group=self._base_mp_group)

        residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None)

        # Should be configurable in the future
        hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None)
        hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None)

        if self.tp_size > 1:
            dist.all_reduce(hidden_states, group=self._base_mp_group)

        if layer_idx != self.num_layers - 1:
            next_params = self._transformer[layer_idx + 1]
            residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None)
        else:
            # On the last layer, we just need to perform the residual add. Adding into the residual
            # here is safe.
            residual.add_(hidden_states)

        return residual, hidden_states

    def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor:
        """
        Performs unembedding of the hidden states to logits. This will only sample the final
        token of each sequence.
        """
        logits = self.unembed(hidden_states,
                              self._non_transformer.word_unembed,
                              ragged_batch_info,
                              gamma=self._non_transformer.final_norm)

        if self.tp_size > 1:
            comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
            full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size))

            dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group)

            full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size))

            return full_logits
        else:
            return logits

    def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:

        residual = self._forward_embed(wrapped_batch)

        residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None)

        for layer_idx in range(self.num_layers):
            residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
                                                                      wrapped_batch)

        return self._forward_unembed(residual, wrapped_batch)
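The forward pass above keeps a (residual, hidden_states) pair: the embedding output becomes the residual stream, layer 0's input norm is applied up front, and each layer finishes by applying the next layer's input norm ("peek-ahead"), so every layer body receives already-normalized activations and the final norm is folded into the unembedding step. Below is a toy, self-contained sketch of that scheduling in plain PyTorch; it separates the residual add from the norm, whereas DeepSpeed's fused pre-norm kernel does both in one call, and the block callables here are placeholders rather than real attention/MLP modules:

# Toy illustration of the peek-ahead pre-norm scheduling in forward() above.
# Not DeepSpeed code: blocks are plain callables and the residual add is done
# explicitly instead of inside a fused norm kernel.
import torch


def rms_norm(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * gamma


def toy_forward(embeddings: torch.Tensor, layers: list) -> torch.Tensor:
    # Each layer is a dict: {"attn": fn, "mlp": fn, "attn_norm": gamma, "mlp_norm": gamma}.
    residual = embeddings
    hidden = rms_norm(residual, layers[0]["attn_norm"])  # layer 0's input norm, applied up front

    for i, layer in enumerate(layers):
        residual = residual + layer["attn"](hidden)      # attention block (qkv -> attn -> attn_out)
        hidden = rms_norm(residual, layer["mlp_norm"])   # post-attention norm
        residual = residual + layer["mlp"](hidden)       # gated MLP block

        if i + 1 < len(layers):
            # Peek ahead: apply the *next* layer's input norm at the tail of this layer.
            hidden = rms_norm(residual, layers[i + 1]["attn_norm"])

    return residual  # the final RMSNorm is applied during unembedding


# Smoke test with no-op blocks and unit norm weights.
d = 8
layers = [{
    "attn": lambda x: torch.zeros_like(x),
    "mlp": lambda x: torch.zeros_like(x),
    "attn_norm": torch.ones(d),
    "mlp_norm": torch.ones(d),
} for _ in range(3)]
print(toy_forward(torch.randn(4, d), layers).shape)  # torch.Size([4, 8])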
deepspeed/inference/v2/model_implementations/qwen_v2/policy.py
31 changes: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Any

from ...config_v2 import RaggedInferenceEngineConfig
from ..inference_policy_base import ContainerMap, InferenceV2Policy
from .container import Qwen2NonTransformerContainer, Qwen2TransformerContainer
from .model import Qwen2InferenceModel


class Qwen2Policy(InferenceV2Policy):

    def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Qwen2InferenceModel:
        return Qwen2InferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group)

    def build_container_map(self) -> ContainerMap:
        map = ContainerMap()

        transformer_containers = [Qwen2TransformerContainer(self.model) for _ in range(self.model.num_layers)]

        map.set_transformer_params(['model.layers'], transformer_containers)

        map.set_non_transformer_params(Qwen2NonTransformerContainer(self.model))

        map.set_unmapped_params(
            [f'model.layers.{i}.self_attn.rotary_emb.inv_freq' for i in range(self.model.num_layers)])

        return map
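With `from .qwen_v2 import *` added to the model implementations package (first hunk above), the policy is picked up alongside the other supported architectures, so a Qwen2-architecture checkpoint can be served through the usual front ends. A minimal usage sketch, assuming DeepSpeed-MII is installed and exposes the `mii.pipeline` API described in its README; the checkpoint name is just a placeholder for any Qwen2-architecture Hugging Face model:

# Hypothetical usage sketch (not part of this commit): serve a Qwen2-architecture
# checkpoint through DeepSpeed-MII, which builds the ragged inference engine and
# selects the matching policy from the model's architecture.
import mii

pipe = mii.pipeline("Qwen/Qwen1.5-0.5B-Chat")  # placeholder checkpoint name
responses = pipe(["DeepSpeed is"], max_new_tokens=64)
print(responses)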