From c1e02052aca93a79e08ea5a9064f4353db03a34c Mon Sep 17 00:00:00 2001 From: Arash Bakhtiari Date: Wed, 10 Jan 2024 09:33:08 -0800 Subject: [PATCH] Refactor the positional embedding config code (#4920) The Mixtral PR https://github.com/microsoft/DeepSpeed/pull/4828 has introduced the positional embedding config class which is a required argument of `make_attn_layer()` function. This has forced the user to override and duplicate the `make_attn_layer()` call for new model implementations using RoPE (This has also broken the Falcon model implementations). This PR: - refactors the inference transformer base class to avoid code duplication by adding a new abstract `positional_embedding_config` property - Fixes the Falcon model implementation to use positional embedding config. The models `llama_v2`, `OPT`, `Mistral 7B`, `Mixtral`, `Falcon` and `Phi-2` are tested with the PR! --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- .../v2/model_implementations/falcon/model.py | 7 +++ .../inference_transformer_base.py | 12 +++- .../model_implementations/llama_v2/model.py | 24 +------ .../v2/model_implementations/mistral/model.py | 24 +------ .../v2/model_implementations/mixtral/model.py | 29 +++------ .../v2/model_implementations/opt/model.py | 10 +-- .../v2/model_implementations/phi/model.py | 63 +------------------ .../attention/dense_blocked_attention.py | 2 + 8 files changed, 41 insertions(+), 130 deletions(-) diff --git a/deepspeed/inference/v2/model_implementations/falcon/model.py b/deepspeed/inference/v2/model_implementations/falcon/model.py index d1ccc38280a0..b2830c80b562 100644 --- a/deepspeed/inference/v2/model_implementations/falcon/model.py +++ b/deepspeed/inference/v2/model_implementations/falcon/model.py @@ -95,6 +95,13 @@ def norm_type(self) -> NormTypeEnum: def positional_embedding_type(self) -> PositionalEmbeddingType: return PositionalEmbeddingType.rotate_half + @property + def positional_embedding_config(self) -> 
RotateHalfConfig: + """ + The positional embedding configuration for the model. + """ + return RotateHalfConfig() + """ Forward implementations """ diff --git a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py index e78a161b4cd0..0cc577451d51 100644 --- a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py +++ b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py @@ -22,6 +22,7 @@ DSUnembedConfig, NormTypeEnum, PositionalEmbeddingType, + RotateHalfConfig, ) from ..modules import heuristics from ..ragged import ( @@ -152,6 +153,14 @@ def norm_type(self) -> NormTypeEnum: """ ... + @property + @abstractmethod + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + """ + The positional embedding configuration for the model. + """ + ... + """ Derived helpers """ @@ -319,7 +328,8 @@ def make_attn_layer(self) -> None: scale_factor=softmax_scale, input_dtype=self.activation_dtype, output_dtype=self.activation_dtype, - positional_embedding_type=self.positional_embedding_type) + positional_embedding_type=self.positional_embedding_type, + positional_embedding_config=self.positional_embedding_config) self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/model.py b/deepspeed/inference/v2/model_implementations/llama_v2/model.py index 735e8f52cca3..a0c81f4d749e 100644 --- a/deepspeed/inference/v2/model_implementations/llama_v2/model.py +++ b/deepspeed/inference/v2/model_implementations/llama_v2/model.py @@ -14,7 +14,6 @@ from .. 
import * from ...modules.configs import * from ...modules.interfaces import * -from ...modules import heuristics from ...ragged import RaggedBatchWrapper from .container import Llama2NonTransformerContainer, Llama2TransformerContainer @@ -106,26 +105,9 @@ def norm_type(self) -> NormTypeEnum: def positional_embedding_type(self) -> PositionalEmbeddingType: return PositionalEmbeddingType.rotate_half - def make_attn_layer(self) -> None: - """ - Builds the attention layer for the model. This sets the `self.attn` attribute. - """ - softmax_scale = 1.0 / (self.head_size**0.5) - - rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta) - - attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, - n_heads_q=self.n_heads_q_local, - n_heads_kv=self.n_heads_kv_local, - head_size=self.head_size, - max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, - scale_factor=softmax_scale, - input_dtype=self.activation_dtype, - output_dtype=self.activation_dtype, - positional_embedding_type=self.positional_embedding_type, - positional_embedding_config=rotary_config) - - self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + return RotateHalfConfig(theta_base=self._config.rope_theta) """ Forward implementations diff --git a/deepspeed/inference/v2/model_implementations/mistral/model.py b/deepspeed/inference/v2/model_implementations/mistral/model.py index 9c707026f9dd..318d362f1a64 100644 --- a/deepspeed/inference/v2/model_implementations/mistral/model.py +++ b/deepspeed/inference/v2/model_implementations/mistral/model.py @@ -14,7 +14,6 @@ from ...model_implementations import * from ...modules.configs import * from ...modules.interfaces import * -from ...modules import heuristics from ...ragged import RaggedBatchWrapper from .container import MistralNonTransformerContainer, MistralTransformerContainer 
@@ -105,26 +104,9 @@ def norm_type(self) -> NormTypeEnum: def positional_embedding_type(self) -> PositionalEmbeddingType: return PositionalEmbeddingType.rotate_half - def make_attn_layer(self) -> None: - """ - Builds the attention layer for the model. This sets the `self.attn` attribute. - """ - softmax_scale = 1.0 / (self.head_size**0.5) - - rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta) - - attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, - n_heads_q=self.n_heads_q_local, - n_heads_kv=self.n_heads_kv_local, - head_size=self.head_size, - max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, - scale_factor=softmax_scale, - input_dtype=self.activation_dtype, - output_dtype=self.activation_dtype, - positional_embedding_type=self.positional_embedding_type, - positional_embedding_config=rotary_config) - - self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + return RotateHalfConfig(theta_base=self._config.rope_theta) """ Forward implementations diff --git a/deepspeed/inference/v2/model_implementations/mixtral/model.py b/deepspeed/inference/v2/model_implementations/mixtral/model.py index d0cae0ff307b..878cd8e31cec 100644 --- a/deepspeed/inference/v2/model_implementations/mixtral/model.py +++ b/deepspeed/inference/v2/model_implementations/mixtral/model.py @@ -15,7 +15,6 @@ from ...model_implementations import * from ...modules.configs import * from ...modules.interfaces import * -from ...modules import heuristics from ...ragged import RaggedBatchWrapper from ..inference_model_base import ( DSModelImplementationConfig, @@ -110,6 +109,13 @@ def norm_type(self) -> NormTypeEnum: def positional_embedding_type(self) -> PositionalEmbeddingType: return PositionalEmbeddingType.rotate_half + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + """ 
+ The positional embedding configuration for the model. + """ + return RotateHalfConfig(theta_base=self._config.rope_theta) + """ Inherited from `DSMoETransformerModelBase` """ @@ -161,27 +167,6 @@ def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInf self.make_unembedding_layer() self._kv_cache_config = None - def make_attn_layer(self) -> None: - """ - Builds the attention layer for the model. This sets the `self.attn` attribute. - """ - softmax_scale = 1.0 / (self.head_size**0.5) - - rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta) - - attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, - n_heads_q=self.n_heads_q_local, - n_heads_kv=self.n_heads_kv_local, - head_size=self.head_size, - max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, - scale_factor=softmax_scale, - input_dtype=self.activation_dtype, - output_dtype=self.activation_dtype, - positional_embedding_type=self.positional_embedding_type, - positional_embedding_config=rotary_config) - - self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) - def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: """ Performs the embedding lookup prior to running the transformer of the model. 
diff --git a/deepspeed/inference/v2/model_implementations/opt/model.py b/deepspeed/inference/v2/model_implementations/opt/model.py index 8bad12f10475..adf011d8f1a7 100644 --- a/deepspeed/inference/v2/model_implementations/opt/model.py +++ b/deepspeed/inference/v2/model_implementations/opt/model.py @@ -12,11 +12,7 @@ from ...allocator import empty_from from ...inference_utils import ActivationType, DtypeEnum from ...model_implementations import * -from ...modules.configs import ( - DSEmbeddingsConfig, - NormTypeEnum, - PositionalEmbeddingType, -) +from ...modules.configs import * from ...ragged import RaggedBatchWrapper from .container import OPTNonTransformerContainer, OPTTransformerContainer @@ -94,6 +90,10 @@ def norm_type(self) -> NormTypeEnum: def positional_embedding_type(self) -> PositionalEmbeddingType: return PositionalEmbeddingType.none + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + return None + """ Overrides of ``DSTransformerModelBase`` methods """ diff --git a/deepspeed/inference/v2/model_implementations/phi/model.py b/deepspeed/inference/v2/model_implementations/phi/model.py index a95b12bb119f..0127c87c7bff 100644 --- a/deepspeed/inference/v2/model_implementations/phi/model.py +++ b/deepspeed/inference/v2/model_implementations/phi/model.py @@ -10,17 +10,11 @@ import deepspeed.comm as dist from ...allocator import empty_from -from ...config_v2 import RaggedInferenceEngineConfig from ...inference_utils import ActivationType, DtypeEnum from .. 
import * from ...modules.configs import * from ...modules.interfaces import * -from ...modules import heuristics from ...ragged import RaggedBatchWrapper -from ..inference_model_base import ( - DSModelImplementationConfig, - MPType, -) from .containers import PhiNonTransformerContainer, PhiTransformerContainer @@ -101,60 +95,9 @@ def norm_type(self) -> NormTypeEnum: def positional_embedding_type(self) -> PositionalEmbeddingType: return PositionalEmbeddingType.rotate_half - """ - Model implementation - """ - - def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig, - base_mp_group: MPType) -> None: - """ - Base implementation for initialization. By default, this will initialize - the traditional components of a transformer model: - - Embedding - - QKV projection - - Self attention - - Attention output projection - - Feed forward network - - Normalization - - Unembedding - - Arguments: - config (DSModelImplementationConfig): Model-specific configuration. No assumptions - should be made about this config that are not closely tied to the specific - model implementation. - engine_config (RaggedInferenceEngineConfig): Engine configuration. - base_mp_group (MPType): Base communication group for Tensor-parallel inference. - """ - super().__init__(config, engine_config, base_mp_group) - - self.make_norm_layer() - self.make_qkv_layer() - self.make_attn_layer() - self.make_attn_out_layer() - self.make_embedding_layer() - self.make_unembedding_layer() - self._kv_cache_config = None - - def make_attn_layer(self) -> None: - """ - Builds the attention layer for the model. This sets the `self.attn` attribute. 
- """ - softmax_scale = 1.0 / (self.head_size**0.5) - - rotary_config = RotateHalfConfig(rotate_dim=self._config.rotary_dim) - - attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, - n_heads_q=self.n_heads_q_local, - n_heads_kv=self.n_heads_kv_local, - head_size=self.head_size, - max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, - scale_factor=softmax_scale, - input_dtype=self.activation_dtype, - output_dtype=self.activation_dtype, - positional_embedding_type=self.positional_embedding_type, - positional_embedding_config=rotary_config) - - self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + @property + def positional_embedding_config(self) -> Optional[RotateHalfConfig]: + return RotateHalfConfig(rotate_dim=self._config.rotary_dim) """ Forward implementations diff --git a/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py b/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py index 5f41b5ff6e13..3515b3c2b690 100644 --- a/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py +++ b/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py @@ -87,6 +87,8 @@ def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[st self._config.n_heads_kv, self._config.input_dtype) elif embed_type == PositionalEmbeddingType.rotate_half: rotary_config = config.positional_embedding_config + assert rotary_config is not None, "Rotary config must be provided if using rotate_half as Positional Embedding Type." + if rotary_config.use_trained_freqs: # Theta and rotary dim are effectively embedded into either the values (theta) or the shape (rotary_dim) # of the trained_freqs tensor.