From 3c4fca940aa99b5231f3a80a76bad56a81ea428c Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Tue, 4 Jun 2024 18:19:12 +0000 Subject: [PATCH 1/6] Add support for the Phi-3 small model --- deepspeed/inference/v2/engine_factory.py | 3 + .../v2/model_implementations/__init__.py | 1 + .../phi3small/__init__.py | 6 + .../phi3small/containers.py | 87 ++++++++ .../model_implementations/phi3small/model.py | 207 ++++++++++++++++++ .../model_implementations/phi3small/policy.py | 30 +++ 6 files changed, 334 insertions(+) create mode 100644 deepspeed/inference/v2/model_implementations/phi3small/__init__.py create mode 100644 deepspeed/inference/v2/model_implementations/phi3small/containers.py create mode 100644 deepspeed/inference/v2/model_implementations/phi3small/model.py create mode 100644 deepspeed/inference/v2/model_implementations/phi3small/policy.py diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py index c21affb9a0de..895d8bbfae15 100644 --- a/deepspeed/inference/v2/engine_factory.py +++ b/deepspeed/inference/v2/engine_factory.py @@ -21,6 +21,7 @@ FalconPolicy, PhiPolicy, Phi3Policy, + Phi3SmallPolicy, QwenPolicy, Qwen2Policy, ) @@ -122,6 +123,8 @@ def build_hf_engine(path: str, policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine) elif model_config.model_type == "phi3": policy = Phi3Policy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "phi3small": + policy = Phi3SmallPolicy(model_config, checkpoint_engine=checkpoint_engine) elif model_config.model_type == "qwen": policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine) elif model_config.model_type == "qwen2": diff --git a/deepspeed/inference/v2/model_implementations/__init__.py b/deepspeed/inference/v2/model_implementations/__init__.py index e4160ab94949..cd27bf495d94 100644 --- a/deepspeed/inference/v2/model_implementations/__init__.py +++ b/deepspeed/inference/v2/model_implementations/__init__.py @@ -16,5 +16,6 @@ from .falcon import * from .phi import * from .phi3 import * +from .phi3small import * from .qwen import * from .qwen_v2 import * diff --git a/deepspeed/inference/v2/model_implementations/phi3small/__init__.py b/deepspeed/inference/v2/model_implementations/phi3small/__init__.py new file mode 100644 index 000000000000..71df721cf135 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi3small/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import Phi3SmallPolicy diff --git a/deepspeed/inference/v2/model_implementations/phi3small/containers.py b/deepspeed/inference/v2/model_implementations/phi3small/containers.py new file mode 100644 index 000000000000..deb31d311627 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi3small/containers.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. 
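+#
+# Each annotated attribute below is a typed parameter slot (e.g. FusedQKVParameter),
+# and PARAM_MAPPING routes checkpoint tensor names (relative to a decoder layer for
+# the transformer container, fully qualified for the non-transformer container) into
+# those slots as the checkpoint engine streams weights in.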
+
+from ..common_parameters import *
+from ..layer_container_base import LayerContainer
+'''
+ # The HF Phi-3 (mini) model prints like this. The Phi-3 small checkpoint follows the
+ # same overall layout, but it names its attention projections `query_key_value` and
+ # `dense` and uses LayerNorm rather than RMSNorm (see the PARAM_MAPPING below):
+
+Phi3SmallForCausalLM(
+  (model): Phi3Model(
+    (embed_tokens): Embedding(32064, 3072)
+    (embed_dropout): Dropout(p=0.0, inplace=False)
+    (layers): ModuleList(
+      (0-31): 32 x Phi3DecoderLayer(
+        (self_attn): Phi3Attention(
+          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
+          (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
+          (rotary_emb): Phi3RotaryEmbedding()
+        )
+        (mlp): Phi3MLP(
+          (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
+          (down_proj): Linear(in_features=16384, out_features=3072, bias=False)
+          (activation_fn): SiLU()
+        )
+        (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
+        (resid_attn_dropout): Dropout(p=0.0)
+        (resid_mlp_dropout): Dropout(p=0.0)
+        (post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
+      )
+    )
+    (final_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
+  )
+  (lm_head): Linear(in_features=3072, out_features=32064, bias=False)
+)
+'''
+
+
+class Phi3SmallTransformerContainer(LayerContainer):
+    """
+    Transformer layer container for the Phi-3 small model.
+    """
+    qkv_w: FusedQKVParameter
+    qkv_b: FusedQKVParameter
+    attn_out_w: AttentionOutputParameter
+    attn_out_b: AttentionOutputParameter
+    mlp_1_w: MLP1Parameter
+    mlp_1_b: MLP1Parameter
+    mlp_2_w: MLP2Parameter
+    mlp_2_b: MLP2Parameter
+    attn_norm_gamma: NormParameter
+    attn_norm_beta: NormParameter
+    mlp_norm_gamma: NormParameter
+    mlp_norm_beta: NormParameter
+
+    PARAM_MAPPING = {
+        "self_attn.query_key_value.weight": "qkv_w.params",
+        "self_attn.query_key_value.bias": "qkv_b.params",
+        "self_attn.dense.weight": "attn_out_w.params",
+        "self_attn.dense.bias": "attn_out_b.params",
+        "mlp.up_proj.weight": "mlp_1_w.params",
+        "mlp.up_proj.bias": "mlp_1_b.params",
+        "mlp.down_proj.weight": "mlp_2_w.params",
+        "mlp.down_proj.bias": "mlp_2_b.params",
+        "input_layernorm.weight": "attn_norm_gamma.params",
+        "input_layernorm.bias": "attn_norm_beta.params",
+        "post_attention_layernorm.weight": "mlp_norm_gamma.params",
+        "post_attention_layernorm.bias": "mlp_norm_beta.params",
+    }
+
+
+class Phi3SmallNonTransformerContainer(LayerContainer):
+    """
+    Non-Transformer layer container for the Phi-3 small model.
+    """
+    word_emb: EmbeddingParameter
+    final_norm_gamma: NormParameter
+    final_norm_beta: NormParameter
+
+    PARAM_MAPPING = {
+        "model.embed_tokens.weight": "word_emb.params",
+        "model.final_layernorm.weight": "final_norm_gamma.params",
+        "model.final_layernorm.bias": "final_norm_beta.params",
+    }
diff --git a/deepspeed/inference/v2/model_implementations/phi3small/model.py b/deepspeed/inference/v2/model_implementations/phi3small/model.py
new file mode 100644
index 000000000000..e8c22a108611
--- /dev/null
+++ b/deepspeed/inference/v2/model_implementations/phi3small/model.py
@@ -0,0 +1,207 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+
+import deepspeed.comm as dist
+
+from ...allocator import empty_from
+from ...inference_utils import ActivationType, DtypeEnum
+from .. import *
+from ...modules.configs import *
+from ...modules.interfaces import *
+from ...ragged import RaggedBatchWrapper
+
+from .containers import Phi3SmallNonTransformerContainer, Phi3SmallTransformerContainer
+
+
+class Phi3SmallInferenceModel(DSTransformerModelBase):
+    """
+    Inference model implementation for ragged batching for Phi-3 small models.
+ """ + + _non_transformer: Optional[Phi3SmallNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[Phi3SmallTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties inherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties inherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + activation = self._config.hidden_act.lower() + if activation == "gelu": + return ActivationType.GEGLU + elif activation == "relu": + return ActivationType.ReGLU + elif activation == "gegelu": + return ActivationType.GEGLU + elif activation == "silu": + return ActivationType.SiGLU + else: + raise NotImplementedError(f"Activation {activation} not supported") + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.LayerNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + return RotateHalfConfig(theta_base=self._config.rope_embedding_base) + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. 
+ """ + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=cur_params.attn_out_b) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=cur_params.mlp_norm_beta) + + hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) + hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=next_params.attn_norm_beta) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + word_unembed = torch.empty(self.vocab_size, self.model_dim, dtype=hidden_states.dtype, device=hidden_states.device) + torch.nn.init.xavier_uniform_(word_unembed) + logits = self.unembed(hidden_states, + word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm_gamma, + beta=self._non_transformer.final_norm_beta) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, gamma=self._transformer[0].attn_norm_gamma, beta=self._transformer[0].attn_norm_beta) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/phi3small/policy.py b/deepspeed/inference/v2/model_implementations/phi3small/policy.py new file mode 100644 index 000000000000..235fa41ac608 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi3small/policy.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .containers import Phi3SmallNonTransformerContainer, Phi3SmallTransformerContainer +from .model import Phi3SmallInferenceModel + + +class Phi3SmallPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Phi3SmallInferenceModel: + return Phi3SmallInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [Phi3SmallTransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(Phi3SmallNonTransformerContainer(self.model)) + + map.set_unmapped_params([]) + + return map From 0ce02a9f068a012f1207ad5de9655905e339b44d Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Fri, 7 Jun 2024 21:18:05 +0000 Subject: [PATCH 2/6] Tie lm_head to embedding weights --- .../v2/model_implementations/phi3small/containers.py | 3 ++- .../inference/v2/model_implementations/phi3small/model.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepspeed/inference/v2/model_implementations/phi3small/containers.py b/deepspeed/inference/v2/model_implementations/phi3small/containers.py index deb31d311627..fcdc17e6cd4d 100644 --- a/deepspeed/inference/v2/model_implementations/phi3small/containers.py +++ b/deepspeed/inference/v2/model_implementations/phi3small/containers.py @@ -79,9 +79,10 @@ class Phi3SmallNonTransformerContainer(LayerContainer): word_emb: EmbeddingParameter final_norm_gamma: NormParameter final_norm_beta: NormParameter + word_unembed: UnembedParameter PARAM_MAPPING = { - "model.embed_tokens.weight": "word_emb.params", + "model.embed_tokens.weight": ["word_emb.params", "word_unembed.params"], "model.final_layernorm.weight": "final_norm_gamma.params", "model.final_layernorm.bias": "final_norm_beta.params", } diff --git a/deepspeed/inference/v2/model_implementations/phi3small/model.py b/deepspeed/inference/v2/model_implementations/phi3small/model.py index e8c22a108611..9d5e6a599365 100644 --- a/deepspeed/inference/v2/model_implementations/phi3small/model.py +++ b/deepspeed/inference/v2/model_implementations/phi3small/model.py @@ -175,10 +175,9 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge Performs unembedding of the hidden states to logits. This will only sample the final token of each sequence. 
""" - word_unembed = torch.empty(self.vocab_size, self.model_dim, dtype=hidden_states.dtype, device=hidden_states.device) - torch.nn.init.xavier_uniform_(word_unembed) + logits = self.unembed(hidden_states, - word_unembed, + self._non_transformer.word_unembed, ragged_batch_info, gamma=self._non_transformer.final_norm_gamma, beta=self._non_transformer.final_norm_beta) From 230615bd402150ad6dbe68e53b96ac9ffa090122 Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Tue, 11 Jun 2024 22:57:04 +0000 Subject: [PATCH 3/6] Add mup_embedding_multiplier --- .../inference/v2/model_implementations/phi3small/model.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/deepspeed/inference/v2/model_implementations/phi3small/model.py b/deepspeed/inference/v2/model_implementations/phi3small/model.py index 9d5e6a599365..1f1e853fc167 100644 --- a/deepspeed/inference/v2/model_implementations/phi3small/model.py +++ b/deepspeed/inference/v2/model_implementations/phi3small/model.py @@ -108,6 +108,11 @@ def positional_embedding_type(self) -> PositionalEmbeddingType: def positional_embedding_config(self) -> Optional[RotateHalfConfig]: return RotateHalfConfig(theta_base=self._config.rope_embedding_base) + @property + def mup_embedding_multiplier(self) -> float: + return 10.0 + + """ Forward implementations """ @@ -127,6 +132,9 @@ def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: if embed.shape[-1] != self.model_dim: raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + if self.mup_embedding_multiplier > 0.0: + embed = embed * self.mup_embedding_multiplier + return embed def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, From e95086f3bd33441f3f8bd409ad4ec115b0ee3b78 Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Mon, 28 Oct 2024 11:44:20 -0700 Subject: [PATCH 4/6] Formatting fixes --- csrc/aio/common/deepspeed_aio_common.cpp | 5 ++--- csrc/aio/py_lib/deepspeed_py_aio.cpp | 10 ++++------ csrc/aio/py_lib/deepspeed_py_aio_handle.cpp | 10 ++++------ csrc/aio/py_lib/deepspeed_py_copy.cpp | 2 +- .../evoformer_attn/gemm_kernel_utils.h | 9 ++++----- csrc/includes/simd.h | 2 +- csrc/xpu/includes/simd.h | 2 +- csrc/xpu/includes/type_shim.h | 10 +++++----- .../v2/model_implementations/phi3small/model.py | 16 ++++++++++++---- 9 files changed, 34 insertions(+), 32 deletions(-) mode change 100755 => 100644 csrc/xpu/includes/simd.h diff --git a/csrc/aio/common/deepspeed_aio_common.cpp b/csrc/aio/common/deepspeed_aio_common.cpp index 0f2895dfa328..a65cc500cc82 100644 --- a/csrc/aio/common/deepspeed_aio_common.cpp +++ b/csrc/aio/common/deepspeed_aio_common.cpp @@ -301,9 +301,8 @@ int regular_read(const char* filename, std::vector& buffer) } while (r > 0); if (read_bytes != num_bytes) { - std::cerr << "read error " - << " read_bytes (read) = " << read_bytes << " num_bytes (fstat) = " << num_bytes - << std::endl; + std::cerr << "read error " << " read_bytes (read) = " << read_bytes + << " num_bytes (fstat) = " << num_bytes << std::endl; } assert(read_bytes == num_bytes); close(fd); diff --git a/csrc/aio/py_lib/deepspeed_py_aio.cpp b/csrc/aio/py_lib/deepspeed_py_aio.cpp index 387b713f2bfc..0556f5aa8168 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio.cpp @@ -72,9 +72,8 @@ int deepspeed_py_aio_write(const torch::Tensor& buffer, const std::chrono::duration fn_time = std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - 
<< "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; return 0; } @@ -118,8 +117,7 @@ int deepspeed_py_aio_read(torch::Tensor& buffer, const std::chrono::duration fn_time = std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; return 0; } diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp index c21e92de9449..23ddabe260d4 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp @@ -93,9 +93,8 @@ int deepspeed_aio_handle_t::read(torch::Tensor& buffer, const char* filename, co if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } const std::chrono::duration fn_time = std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; return 0; } @@ -128,9 +127,8 @@ int deepspeed_aio_handle_t::write(const torch::Tensor& buffer, const std::chrono::duration fn_time = std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; + std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 + << " call = " << fn_time.count() * 1e6 << std::endl; return 0; } diff --git a/csrc/aio/py_lib/deepspeed_py_copy.cpp b/csrc/aio/py_lib/deepspeed_py_copy.cpp index 8a59107dd347..c597b91d05c9 100644 --- a/csrc/aio/py_lib/deepspeed_py_copy.cpp +++ b/csrc/aio/py_lib/deepspeed_py_copy.cpp @@ -10,7 +10,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. 
#include "deepspeed_py_copy.h" #include -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) #if defined(__AVX512__) or defined(__AVX256__) union AVX_Data { diff --git a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h index 2a4300c5cac1..c102234a4dfb 100644 --- a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h +++ b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h @@ -125,11 +125,10 @@ struct CheckArch { std::cerr << #PTR " is not correctly aligned\n"; \ return false; \ } -#define EVOFORMER_CHECK(COND, ERR) \ - if (!(COND)) { \ - std::cerr << "[Evoformer Attention]" \ - << "'" #COND "' failed: " << ERR << "\n"; \ - return false; \ +#define EVOFORMER_CHECK(COND, ERR) \ + if (!(COND)) { \ + std::cerr << "[Evoformer Attention]" << "'" #COND "' failed: " << ERR << "\n"; \ + return false; \ } #endif diff --git a/csrc/includes/simd.h b/csrc/includes/simd.h index f5bfb45dd2e2..a205026ec7c1 100644 --- a/csrc/includes/simd.h +++ b/csrc/includes/simd.h @@ -27,7 +27,7 @@ inline void writeAs(void* dst, const T& val) std::memcpy(dst, &val, sizeof(T)); } -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) #if defined(__AVX512__) #define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) diff --git a/csrc/xpu/includes/simd.h b/csrc/xpu/includes/simd.h old mode 100755 new mode 100644 index f77568be7835..097e2d8585cc --- a/csrc/xpu/includes/simd.h +++ b/csrc/xpu/includes/simd.h @@ -13,7 +13,7 @@ #define TILE (128 * 1024 * 1024) #if defined(__AVX512__) or defined(__AVX256__) -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) #if defined(__AVX512__) #define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) diff --git a/csrc/xpu/includes/type_shim.h b/csrc/xpu/includes/type_shim.h index fa41757c895b..1897afd1fea2 100644 --- a/csrc/xpu/includes/type_shim.h +++ b/csrc/xpu/includes/type_shim.h @@ -82,11 +82,11 @@ } template -__inline__ __attribute__((always_inline)) T reduce_block_into_lanes( - T* x, - T val, - int lanes = 1, - bool share_result = false) // lanes is intended to be <= 32. +__inline__ __attribute__((always_inline)) T +reduce_block_into_lanes(T* x, + T val, + int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. 
{ auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); int tid = item_ct1.get_local_id(2) + item_ct1.get_local_id(1) * item_ct1.get_local_range(2); diff --git a/deepspeed/inference/v2/model_implementations/phi3small/model.py b/deepspeed/inference/v2/model_implementations/phi3small/model.py index 1f1e853fc167..532cb49c80c4 100644 --- a/deepspeed/inference/v2/model_implementations/phi3small/model.py +++ b/deepspeed/inference/v2/model_implementations/phi3small/model.py @@ -112,7 +112,6 @@ def positional_embedding_config(self) -> Optional[RotateHalfConfig]: def mup_embedding_multiplier(self) -> float: return 10.0 - """ Forward implementations """ @@ -160,7 +159,10 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid if self.tp_size > 1: dist.all_reduce(hidden_states, group=self._base_mp_group) - residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=cur_params.mlp_norm_beta) + residual, hidden_states = self.norm(residual, + hidden_states, + cur_params.mlp_norm_gamma, + beta=cur_params.mlp_norm_beta) hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) @@ -170,7 +172,10 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid if layer_idx != self.num_layers - 1: next_params = self._transformer[layer_idx + 1] - residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=next_params.attn_norm_beta) + residual, hidden_states = self.norm(residual, + hidden_states, + next_params.attn_norm_gamma, + beta=next_params.attn_norm_beta) else: # On last layer, we just need to perform the residual add. Adding into the residual # here is safe. @@ -205,7 +210,10 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: residual = self._forward_embed(wrapped_batch) - residual, hidden_states = self.norm(residual, None, gamma=self._transformer[0].attn_norm_gamma, beta=self._transformer[0].attn_norm_beta) + residual, hidden_states = self.norm(residual, + None, + gamma=self._transformer[0].attn_norm_gamma, + beta=self._transformer[0].attn_norm_beta) for layer_idx in range(self.num_layers): residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, From 6f1e1fc4954f198b757ca82417cdab4c55e5a2e5 Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Mon, 28 Oct 2024 11:48:59 -0700 Subject: [PATCH 5/6] Revert "Formatting fixes" This reverts commit e95086f3bd33441f3f8bd409ad4ec115b0ee3b78. 
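For reference, the ROUND_DOWN macro that keeps getting reflowed in these
formatting commits is behavior-identical in both spellings; it is the classic
power-of-two round-down. A minimal Python sketch of the same computation
(illustrative only, not part of this patch):

    def round_down(size: int, step: int) -> int:
        # Equivalent to the C macro ROUND_DOWN(size, step) == ((size) & ~((step)-1)).
        # Clearing the low bits rounds size down to a multiple of step; this is
        # only valid when step is a power of two.
        assert step > 0 and step & (step - 1) == 0, "step must be a power of two"
        return size & ~(step - 1)

    assert round_down(100, 32) == 96   # 96 is the largest multiple of 32 <= 100
    assert round_down(128, 32) == 128  # already-aligned values are unchanged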
--- csrc/aio/common/deepspeed_aio_common.cpp | 5 +++-- csrc/aio/py_lib/deepspeed_py_aio.cpp | 10 ++++++---- csrc/aio/py_lib/deepspeed_py_aio_handle.cpp | 10 ++++++---- csrc/aio/py_lib/deepspeed_py_copy.cpp | 2 +- .../evoformer_attn/gemm_kernel_utils.h | 9 +++++---- csrc/includes/simd.h | 2 +- csrc/xpu/includes/simd.h | 2 +- csrc/xpu/includes/type_shim.h | 10 +++++----- .../v2/model_implementations/phi3small/model.py | 16 ++++------------ 9 files changed, 32 insertions(+), 34 deletions(-) mode change 100644 => 100755 csrc/xpu/includes/simd.h diff --git a/csrc/aio/common/deepspeed_aio_common.cpp b/csrc/aio/common/deepspeed_aio_common.cpp index a65cc500cc82..0f2895dfa328 100644 --- a/csrc/aio/common/deepspeed_aio_common.cpp +++ b/csrc/aio/common/deepspeed_aio_common.cpp @@ -301,8 +301,9 @@ int regular_read(const char* filename, std::vector& buffer) } while (r > 0); if (read_bytes != num_bytes) { - std::cerr << "read error " << " read_bytes (read) = " << read_bytes - << " num_bytes (fstat) = " << num_bytes << std::endl; + std::cerr << "read error " + << " read_bytes (read) = " << read_bytes << " num_bytes (fstat) = " << num_bytes + << std::endl; } assert(read_bytes == num_bytes); close(fd); diff --git a/csrc/aio/py_lib/deepspeed_py_aio.cpp b/csrc/aio/py_lib/deepspeed_py_aio.cpp index 0556f5aa8168..387b713f2bfc 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio.cpp @@ -72,8 +72,9 @@ int deepspeed_py_aio_write(const torch::Tensor& buffer, const std::chrono::duration fn_time = std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 - << " call = " << fn_time.count() * 1e6 << std::endl; + std::cout << "Elapsed time(usec): " + << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 + << std::endl; return 0; } @@ -117,7 +118,8 @@ int deepspeed_py_aio_read(torch::Tensor& buffer, const std::chrono::duration fn_time = std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 - << " call = " << fn_time.count() * 1e6 << std::endl; + std::cout << "Elapsed time(usec): " + << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 + << std::endl; return 0; } diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp index 23ddabe260d4..c21e92de9449 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp @@ -93,8 +93,9 @@ int deepspeed_aio_handle_t::read(torch::Tensor& buffer, const char* filename, co if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } const std::chrono::duration fn_time = std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 - << " call = " << fn_time.count() * 1e6 << std::endl; + std::cout << "Elapsed time(usec): " + << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 + << std::endl; return 0; } @@ -127,8 +128,9 @@ int deepspeed_aio_handle_t::write(const torch::Tensor& buffer, const std::chrono::duration fn_time = std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " << "aio = " << aio_time.count() * 1e6 - << " call = " << fn_time.count() * 1e6 << std::endl; + std::cout << "Elapsed time(usec): " + << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 + << std::endl; return 0; 
} diff --git a/csrc/aio/py_lib/deepspeed_py_copy.cpp b/csrc/aio/py_lib/deepspeed_py_copy.cpp index c597b91d05c9..8a59107dd347 100644 --- a/csrc/aio/py_lib/deepspeed_py_copy.cpp +++ b/csrc/aio/py_lib/deepspeed_py_copy.cpp @@ -10,7 +10,7 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices. #include "deepspeed_py_copy.h" #include -#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) #if defined(__AVX512__) or defined(__AVX256__) union AVX_Data { diff --git a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h index c102234a4dfb..2a4300c5cac1 100644 --- a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h +++ b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h @@ -125,10 +125,11 @@ struct CheckArch { std::cerr << #PTR " is not correctly aligned\n"; \ return false; \ } -#define EVOFORMER_CHECK(COND, ERR) \ - if (!(COND)) { \ - std::cerr << "[Evoformer Attention]" << "'" #COND "' failed: " << ERR << "\n"; \ - return false; \ +#define EVOFORMER_CHECK(COND, ERR) \ + if (!(COND)) { \ + std::cerr << "[Evoformer Attention]" \ + << "'" #COND "' failed: " << ERR << "\n"; \ + return false; \ } #endif diff --git a/csrc/includes/simd.h b/csrc/includes/simd.h index a205026ec7c1..f5bfb45dd2e2 100644 --- a/csrc/includes/simd.h +++ b/csrc/includes/simd.h @@ -27,7 +27,7 @@ inline void writeAs(void* dst, const T& val) std::memcpy(dst, &val, sizeof(T)); } -#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) #if defined(__AVX512__) #define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) diff --git a/csrc/xpu/includes/simd.h b/csrc/xpu/includes/simd.h old mode 100644 new mode 100755 index 097e2d8585cc..f77568be7835 --- a/csrc/xpu/includes/simd.h +++ b/csrc/xpu/includes/simd.h @@ -13,7 +13,7 @@ #define TILE (128 * 1024 * 1024) #if defined(__AVX512__) or defined(__AVX256__) -#define ROUND_DOWN(size, step) ((size) & ~((step) - 1)) +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) #if defined(__AVX512__) #define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) diff --git a/csrc/xpu/includes/type_shim.h b/csrc/xpu/includes/type_shim.h index 1897afd1fea2..fa41757c895b 100644 --- a/csrc/xpu/includes/type_shim.h +++ b/csrc/xpu/includes/type_shim.h @@ -82,11 +82,11 @@ } template -__inline__ __attribute__((always_inline)) T -reduce_block_into_lanes(T* x, - T val, - int lanes = 1, - bool share_result = false) // lanes is intended to be <= 32. +__inline__ __attribute__((always_inline)) T reduce_block_into_lanes( + T* x, + T val, + int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. 
{ auto item_ct1 = sycl::ext::oneapi::experimental::this_nd_item<3>(); int tid = item_ct1.get_local_id(2) + item_ct1.get_local_id(1) * item_ct1.get_local_range(2); diff --git a/deepspeed/inference/v2/model_implementations/phi3small/model.py b/deepspeed/inference/v2/model_implementations/phi3small/model.py index 532cb49c80c4..1f1e853fc167 100644 --- a/deepspeed/inference/v2/model_implementations/phi3small/model.py +++ b/deepspeed/inference/v2/model_implementations/phi3small/model.py @@ -112,6 +112,7 @@ def positional_embedding_config(self) -> Optional[RotateHalfConfig]: def mup_embedding_multiplier(self) -> float: return 10.0 + """ Forward implementations """ @@ -159,10 +160,7 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid if self.tp_size > 1: dist.all_reduce(hidden_states, group=self._base_mp_group) - residual, hidden_states = self.norm(residual, - hidden_states, - cur_params.mlp_norm_gamma, - beta=cur_params.mlp_norm_beta) + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=cur_params.mlp_norm_beta) hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) @@ -172,10 +170,7 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid if layer_idx != self.num_layers - 1: next_params = self._transformer[layer_idx + 1] - residual, hidden_states = self.norm(residual, - hidden_states, - next_params.attn_norm_gamma, - beta=next_params.attn_norm_beta) + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=next_params.attn_norm_beta) else: # On last layer, we just need to perform the residual add. Adding into the residual # here is safe. 
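The hunks above and below only rewrap these self.norm calls; the call pattern
itself is the peek-ahead schedule introduced in patch 1/6. A minimal sketch of
that control flow, with illustrative names rather than the DeepSpeed API:

    # Each layer ends by applying the *next* layer's input layernorm, so the
    # driver seeds the loop with layer 0's norm and computes logits from the
    # residual stream rather than the last hidden states.
    residual = forward_embed(batch)
    residual, hidden = norm(residual, None, layers[0].attn_norm_gamma, beta=layers[0].attn_norm_beta)
    for i in range(num_layers):
        residual, hidden = forward_transformer_layer(i, residual, hidden, batch)
    logits = forward_unembed(residual, batch)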
@@ -210,10 +205,7 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: residual = self._forward_embed(wrapped_batch) - residual, hidden_states = self.norm(residual, - None, - gamma=self._transformer[0].attn_norm_gamma, - beta=self._transformer[0].attn_norm_beta) + residual, hidden_states = self.norm(residual, None, gamma=self._transformer[0].attn_norm_gamma, beta=self._transformer[0].attn_norm_beta) for layer_idx in range(self.num_layers): residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, From 857c779889173896ac4f769632279f83fec10a59 Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Mon, 28 Oct 2024 12:22:03 -0700 Subject: [PATCH 6/6] Format fixes --- .../v2/model_implementations/phi3small/model.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/deepspeed/inference/v2/model_implementations/phi3small/model.py b/deepspeed/inference/v2/model_implementations/phi3small/model.py index 1f1e853fc167..532cb49c80c4 100644 --- a/deepspeed/inference/v2/model_implementations/phi3small/model.py +++ b/deepspeed/inference/v2/model_implementations/phi3small/model.py @@ -112,7 +112,6 @@ def positional_embedding_config(self) -> Optional[RotateHalfConfig]: def mup_embedding_multiplier(self) -> float: return 10.0 - """ Forward implementations """ @@ -160,7 +159,10 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid if self.tp_size > 1: dist.all_reduce(hidden_states, group=self._base_mp_group) - residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=cur_params.mlp_norm_beta) + residual, hidden_states = self.norm(residual, + hidden_states, + cur_params.mlp_norm_gamma, + beta=cur_params.mlp_norm_beta) hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) @@ -170,7 +172,10 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid if layer_idx != self.num_layers - 1: next_params = self._transformer[layer_idx + 1] - residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=next_params.attn_norm_beta) + residual, hidden_states = self.norm(residual, + hidden_states, + next_params.attn_norm_gamma, + beta=next_params.attn_norm_beta) else: # On last layer, we just need to perform the residual add. Adding into the residual # here is safe. @@ -205,7 +210,10 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: residual = self._forward_embed(wrapped_batch) - residual, hidden_states = self.norm(residual, None, gamma=self._transformer[0].attn_norm_gamma, beta=self._transformer[0].attn_norm_beta) + residual, hidden_states = self.norm(residual, + None, + gamma=self._transformer[0].attn_norm_gamma, + beta=self._transformer[0].attn_norm_beta) for layer_idx in range(self.num_layers): residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states,
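With the series applied end to end, the new model is reachable through the
standard v2 factory path. A hedged usage sketch follows; only the path
parameter and the "phi3small" model_type dispatch are confirmed by the diffs
above, while the engine-config construction is an assumption:

    from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig
    from deepspeed.inference.v2.engine_factory import build_hf_engine

    # Assumes a local Hugging Face checkpoint whose config.json declares
    # model_type == "phi3small"; build_hf_engine then selects Phi3SmallPolicy.
    engine = build_hf_engine(
        path="/path/to/phi3small/checkpoint",
        engine_config=RaggedInferenceEngineConfig(),  # assumed keyword; see engine_factory.py
    )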