Commit

Merge branch 'master' into gma/launch_opbuilder_detection
tjruwase authored Mar 2, 2024
2 parents 87367e1 + bcc617a commit 52fc101
Showing 7 changed files with 353 additions and 12 deletions.
3 changes: 3 additions & 0 deletions deepspeed/inference/v2/engine_factory.py
@@ -21,6 +21,7 @@
FalconPolicy,
PhiPolicy,
QwenPolicy,
Qwen2Policy,
)
from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata
@@ -120,6 +121,8 @@ def build_hf_engine(path: str,
policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "qwen":
policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine)
elif model_config.model_type == "qwen2":
policy = Qwen2Policy(model_config, checkpoint_engine=checkpoint_engine)
else:
raise ValueError(f"Unsupported model type {model_config.model_type}")

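With this branch in place, HF checkpoints whose config.json declares model_type == "qwen2" resolve to Qwen2Policy. A minimal usage sketch, assuming a local HF-format Qwen2 checkpoint directory (the path is a placeholder) and a default-constructed RaggedInferenceEngineConfig:

from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig
from deepspeed.inference.v2.engine_factory import build_hf_engine

checkpoint_path = "/models/qwen2-example"                  # hypothetical local HF checkpoint directory
engine_config = RaggedInferenceEngineConfig()              # default ragged-inference settings
engine = build_hf_engine(checkpoint_path, engine_config)   # dispatches to Qwen2Policy via model_type
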
1 change: 1 addition & 0 deletions deepspeed/inference/v2/model_implementations/__init__.py
@@ -16,3 +16,4 @@
from .falcon import *
from .phi import *
from .qwen import *
from .qwen_v2 import *
6 changes: 6 additions & 0 deletions deepspeed/inference/v2/model_implementations/qwen_v2/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .policy import Qwen2Policy
82 changes: 82 additions & 0 deletions deepspeed/inference/v2/model_implementations/qwen_v2/container.py
@@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Create a container object to save model-specific tensors using the policy file above.

from ..common_parameters import *
from ..layer_container_base import LayerContainer
'''
# HF Qwen2 model looks like this:
Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(151936, 1024)
(layers): ModuleList(
(0-23): 24 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): Linear(in_features=1024, out_features=1024, bias=True)
(k_proj): Linear(in_features=1024, out_features=1024, bias=True)
(v_proj): Linear(in_features=1024, out_features=1024, bias=True)
(o_proj): Linear(in_features=1024, out_features=1024, bias=False)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=1024, out_features=2816, bias=False)
(up_proj): Linear(in_features=1024, out_features=2816, bias=False)
(down_proj): Linear(in_features=2816, out_features=1024, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm()
(post_attention_layernorm): Qwen2RMSNorm()
)
)
(norm): Qwen2RMSNorm()
)
(lm_head): Linear(in_features=1024, out_features=151936, bias=False)
)
'''


class Qwen2TransformerContainer(LayerContainer):
"""
Transformer layer container for the Qwen2 model.
"""
qkv_w: UnfusedQKVParameter
qkv_b: UnfusedQKVParameter
attn_out_w: AttentionOutputParameter
mlp_1_w: GatedMLPParameter
mlp_2_w: MLP2Parameter
attn_norm_gamma: NormParameter
mlp_norm_gamma: NormParameter

PARAM_MAPPING = {
"self_attn.q_proj.weight": "qkv_w.q_params",
"self_attn.k_proj.weight": "qkv_w.k_params",
"self_attn.v_proj.weight": "qkv_w.v_params",
"self_attn.q_proj.bias": "qkv_b.q_params",
"self_attn.k_proj.bias": "qkv_b.k_params",
"self_attn.v_proj.bias": "qkv_b.v_params",
"self_attn.o_proj.weight": "attn_out_w.params",
"mlp.gate_proj.weight": "mlp_1_w.gate_params",
"mlp.up_proj.weight": "mlp_1_w.up_params",
"mlp.down_proj.weight": "mlp_2_w.params",
"input_layernorm.weight": "attn_norm_gamma.params",
"post_attention_layernorm.weight": "mlp_norm_gamma.params",
}


class Qwen2NonTransformerContainer(LayerContainer):
"""
Non-Transformer layer container for the Qwen2 model.
"""
word_emb: EmbeddingParameter
word_unembed: UnembedParameter
final_norm: NormParameter

PARAM_MAPPING = {
"model.embed_tokens.weight": "word_emb.params",
"model.norm.weight": "final_norm.params",
"lm_head.weight": "word_unembed.params",
}
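
The qkv_w/qkv_b mappings above gather the separate HF q/k/v projections into one fused QKV parameter so attention can run a single projection GEMM. A conceptual sketch of that fusion in plain torch (the concatenated layout is an assumption for illustration, not the exact DeepSpeed parameter layout):

import torch

hidden = 1024  # matches the projection shapes printed in the module tree above
q_w, k_w, v_w = (torch.randn(hidden, hidden) for _ in range(3))

# One possible fused layout: q rows, then k rows, then v rows.
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)            # (3 * hidden, hidden)

x = torch.randn(4, hidden)                            # 4 token embeddings
q, k, v = (x @ qkv_w.t()).split(hidden, dim=-1)       # one GEMM yields all three projections
assert torch.allclose(q, x @ q_w.t(), atol=1e-5)
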
221 changes: 221 additions & 0 deletions deepspeed/inference/v2/model_implementations/qwen_v2/model.py
@@ -0,0 +1,221 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Iterable, Optional, Tuple

import torch

import deepspeed.comm as dist

from ...allocator import empty_from
from ...inference_utils import ActivationType, DtypeEnum
from .. import *
from ...modules.configs import *
from ...modules.interfaces import *
from ...modules import heuristics
from ...ragged import RaggedBatchWrapper

from .container import Qwen2NonTransformerContainer, Qwen2TransformerContainer


class Qwen2InferenceModel(DSTransformerModelBase):
"""
Inference model implementation for ragged batching for Qwen2 models.
"""

_non_transformer: Optional[Qwen2NonTransformerContainer]
"""
Embed + unembed container. Specializing the type annotation.
"""

_transformer: Optional[Iterable[Qwen2TransformerContainer]]
"""
Per-layer transformer container. Specializing the type annotation.
"""
"""
Properties inherited from `DSInferenceModelBase`
"""

@property
def max_sequence_length(self) -> int:
return self._config.max_seq_length

"""
Properties inherited from `DSTransformerModelBase`
"""

@property
def num_layers(self) -> int:
return self._config.num_hidden_layers

@property
def model_dim(self) -> int:
return self._config.hidden_size

@property
def vocab_size(self) -> int:
return self._config.vocab_size

@property
def head_size(self) -> int:
return self.model_dim // self.n_heads

@property
def n_heads(self) -> int:
return self._config.num_attention_heads

@property
def intermediate_dim(self) -> int:
return self._config.intermediate_size

@property
def n_heads_kv(self) -> int:
return self._config.num_key_value_heads

@property
def activation_dtype(self) -> DtypeEnum:
# TODO(ZonePG): bf16 inference results may differ from HuggingFace bf16,
# because in rms_norm, Qwen still uses float() instead of bf16
# if self._config.torch_dtype == torch.float16:
# return DtypeEnum.fp16
# elif self._config.torch_dtype == torch.bfloat16:
# return DtypeEnum.bf16
# else:
# raise NotImplementedError("Only fp16 and bf16 are supported")
return DtypeEnum.fp16

@property
def mlp_activation_fn(self) -> ActivationType:
return ActivationType.SiGLU

@property
def norm_type(self) -> NormTypeEnum:
return NormTypeEnum.RMSNorm

@property
def positional_embedding_type(self) -> PositionalEmbeddingType:
return PositionalEmbeddingType.rotate_half

@property
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
return RotateHalfConfig(theta_base=self._config.rope_theta)

def make_norm_layer(self) -> None:
"""
Instantiates the normalization layer for the model. This sets the `self.norm` attribute.
TODO(cmikeh2): In the future we'll distinguish between the different norm objects,
but for now we'll just use the same one for all of them.
"""
norm_config = DSNormConfig(
max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
type=self.norm_type,
channels=self.model_dim,
residual_dtype=self.activation_dtype,
input_dtype=self.activation_dtype,
output_dtype=self.activation_dtype,
eps=self._config.rms_norm_eps,
)

self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config)

"""
Forward implementations
"""

def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor:
"""
Performs the embedding lookup prior to running the transformer of the model.
Arguments:
ragged_batch (RaggedBatchWrapper): The batch to embed.
Returns:
torch.Tensor: The embedded batch.
"""
embed = self.embed(ragged_batch, self._non_transformer.word_emb)

if embed.shape[-1] != self.model_dim:
raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}")

return embed

def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor,
ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Executes one (slightly offset) layer of the transformer. This implementation does a peek-ahead
optimization to fuse the layer norm of the next layer into the current layer.
Arguments:
layer_idx (int): The index of the layer to execute.
residual (torch.Tensor): The residual tensor from the previous layer.
hidden_states (torch.Tensor): The hidden states from the previous layer. This is the
hidden states after pre-normalization.
ragged_batch_info (RaggedBatchWrapper): The batch metadata.
"""
# TODO(cmikeh2): Distribute ragged_batch_info to all modules

cur_params = self._transformer[layer_idx]
kv_cache = self.state_manager.get_cache(layer_idx)

hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b)
hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info)
hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None)

if self.tp_size > 1:
dist.all_reduce(hidden_states, group=self._base_mp_group)

residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None)

# Should be configurable in the future
hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None)
hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None)

if self.tp_size > 1:
dist.all_reduce(hidden_states, group=self._base_mp_group)

if layer_idx != self.num_layers - 1:
next_params = self._transformer[layer_idx + 1]
residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None)
else:
# On the last layer, we just need to perform the residual add. Adding into the residual
# here is safe.
residual.add_(hidden_states)

return residual, hidden_states

def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor:
"""
Performs unembedding of the hidden states to logits. This will only sample the final
token of each sequence.
"""
logits = self.unembed(hidden_states,
self._non_transformer.word_unembed,
ragged_batch_info,
gamma=self._non_transformer.final_norm)

if self.tp_size > 1:
comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size))

dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group)

full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size))

return full_logits
else:
return logits

def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:

residual = self._forward_embed(wrapped_batch)

residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None)

for layer_idx in range(self.num_layers):
residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
wrapped_batch)

return self._forward_unembed(residual, wrapped_batch)
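
The forward path keeps two tensors in flight: residual accumulates each block's output, while hidden_states always holds the pre-normalized input of the block about to run, which is why each layer also computes the next layer's input norm (the peek-ahead fusion). A simplified dataflow sketch in plain torch, with identity stand-ins for attention/MLP and the learned norm weights omitted:

import torch

def rms_norm(x, eps=1e-6):
    # gamma omitted; only the dataflow matters here
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

num_layers, tokens, dim = 3, 4, 8
residual = torch.randn(tokens, dim)       # result of _forward_embed
hidden = rms_norm(residual)               # layer 0 input norm, done once in forward()

for layer_idx in range(num_layers):
    attn_out = hidden                     # stand-in for qkv -> attn -> attn_out
    residual = residual + attn_out
    hidden = rms_norm(residual)           # post-attention norm
    mlp_out = hidden                      # stand-in for mlp_1 -> mlp_2
    if layer_idx != num_layers - 1:
        residual = residual + mlp_out
        hidden = rms_norm(residual)       # next layer's input norm, fused into this layer
    else:
        residual = residual + mlp_out     # last layer: plain residual add

# residual now feeds _forward_unembed, which applies the final norm itself.
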
31 changes: 31 additions & 0 deletions deepspeed/inference/v2/model_implementations/qwen_v2/policy.py
@@ -0,0 +1,31 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Any

from ...config_v2 import RaggedInferenceEngineConfig
from ..inference_policy_base import ContainerMap, InferenceV2Policy
from .container import Qwen2NonTransformerContainer, Qwen2TransformerContainer
from .model import Qwen2InferenceModel


class Qwen2Policy(InferenceV2Policy):

def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Qwen2InferenceModel:
return Qwen2InferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group)

def build_container_map(self) -> ContainerMap:
map = ContainerMap()

transformer_containers = [Qwen2TransformerContainer(self.model) for _ in range(self.model.num_layers)]

map.set_transformer_params(['model.layers'], transformer_containers)

map.set_non_transformer_params(Qwen2NonTransformerContainer(self.model))

map.set_unmapped_params(
[f'model.layers.{i}.self_attn.rotary_emb.inv_freq' for i in range(self.model.num_layers)])

return map
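
build_container_map registers one transformer container per decoder layer under the 'model.layers' prefix, routes the remaining weights (embedding, final norm, lm_head) to the non-transformer container, and marks the rotary inv_freq buffers as unmapped so they are ignored at load time. A rough sketch of that routing idea (hypothetical helper, not DeepSpeed's ContainerMap):

import re

num_layers = 24
unmapped = {f"model.layers.{i}.self_attn.rotary_emb.inv_freq" for i in range(num_layers)}

def route(param_name: str) -> str:
    """Report which container a checkpoint parameter would land in."""
    if param_name in unmapped:
        return "ignored"
    match = re.match(r"model\.layers\.(\d+)\.(.+)", param_name)
    if match:
        return f"transformer_containers[{match.group(1)}] <- {match.group(2)}"
    return "non_transformer_container"

print(route("model.layers.3.self_attn.q_proj.weight"))        # transformer_containers[3] <- self_attn.q_proj.weight
print(route("model.layers.3.self_attn.rotary_emb.inv_freq"))  # ignored
print(route("lm_head.weight"))                                 # non_transformer_container
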
21 changes: 9 additions & 12 deletions deepspeed/runtime/zero/partition_parameters.py
Expand Up @@ -1635,19 +1635,16 @@ def _partition_param_sec(self, param, buffer=None, has_been_updated=False):
secondary_end = secondary_start + secondary_partition_size

one_dim_param = param.contiguous().view(-1)
- start = partition_size * self.rank
- end = start + partition_size
- if start < param.ds_numel and end <= param.ds_numel:
- if secondary_start < param.ds_numel and secondary_end <= param.ds_numel:
- sec_src_tensor = one_dim_param.narrow(0, secondary_start, secondary_partition_size)
- param.ds_secondary_tensor.copy_(sec_src_tensor)

- else:
- if start < param.ds_numel:
- elements_to_copy = param.ds_numel - start
- elements_to_copy_sec = elements_to_copy * param.ds_secondary_tensor_num_of_groups
- param.ds_secondary_tensor.narrow(0, 0, elements_to_copy_sec).copy_(
- one_dim_param.narrow(0, secondary_start, elements_to_copy_sec))
+ # ds_numel is unpadded, so the last chunk of the secondary tensor might not be secondary_partition_size
+ sec_numel = param.ds_numel - secondary_start if secondary_end > param.ds_numel else secondary_partition_size
+
+ # copy from full tensor to secondary tensor
+ param.ds_secondary_tensor.narrow(0, 0,
+ sec_numel).copy_(one_dim_param.narrow(0, secondary_start, sec_numel))
+
+ # TODO: This is a temporary fix to avoid the issue that 2nd tensor all-gather happens before 2nd tensor partition is done
+ get_accelerator().current_stream().synchronize()

print_rank_0(f"{param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}",
force=False)
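
A worked example of the sec_numel computation above with made-up sizes (and assuming secondary_start advances by secondary_partition_size per rank) shows why the last, unpadded chunk can be shorter than secondary_partition_size:

ds_numel = 10                  # unpadded parameter size
secondary_partition_size = 4   # padded chunk size per rank

for rank in range(3):
    secondary_start = rank * secondary_partition_size
    secondary_end = secondary_start + secondary_partition_size
    sec_numel = ds_numel - secondary_start if secondary_end > ds_numel else secondary_partition_size
    print(rank, secondary_start, sec_numel)
# rank 0: copies 4 elements starting at 0
# rank 1: copies 4 elements starting at 4
# rank 2: copies only 2 elements starting at 8 (the unpadded tail)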
