Skip to content

Commit

Permalink
Refactor the positional emebdding config code (#4920)
Browse files Browse the repository at this point in the history
The Mixtral PR #4828 has
introduced the positional embedding config class which is a required
argument of `make_attn_layer()` function. This has forced the user to
override and duplicate the `make_attn_layer()` call for new model
implementations using RoPE (This has also broken the Falcon model
implementations). This PR:

- refactors the inference transformer base class to avoid code
duplication by adding a new abstract `positional_embedding_config`
property
- Fixes the Falcon model implementation to use positional embedding
config.

The models `llama_v2`, `OPT`, `Mistral 7B`, `Mixtral`, `Falcon` and
`Phi-2` are tested with the PR!

---------

Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
arashb and loadams authored Jan 10, 2024
1 parent 16c265c commit c1e0205
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 130 deletions.
7 changes: 7 additions & 0 deletions deepspeed/inference/v2/model_implementations/falcon/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ def norm_type(self) -> NormTypeEnum:
def positional_embedding_type(self) -> PositionalEmbeddingType:
return PositionalEmbeddingType.rotate_half

@property
def positional_embedding_config(self) -> RotateHalfConfig:
"""
The positional embedding configuration for the model.
"""
return RotateHalfConfig()

"""
Forward implementations
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
DSUnembedConfig,
NormTypeEnum,
PositionalEmbeddingType,
RotateHalfConfig,
)
from ..modules import heuristics
from ..ragged import (
Expand Down Expand Up @@ -152,6 +153,14 @@ def norm_type(self) -> NormTypeEnum:
"""
...

@property
@abstractmethod
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
"""
The positional embedding configuration for the model.
"""
...

"""
Derived helpers
"""
Expand Down Expand Up @@ -319,7 +328,8 @@ def make_attn_layer(self) -> None:
scale_factor=softmax_scale,
input_dtype=self.activation_dtype,
output_dtype=self.activation_dtype,
positional_embedding_type=self.positional_embedding_type)
positional_embedding_type=self.positional_embedding_type,
positional_embedding_config=self.positional_embedding_config)

self.attn = heuristics.instantiate_attention(attn_config, self._engine_config)

Expand Down
24 changes: 3 additions & 21 deletions deepspeed/inference/v2/model_implementations/llama_v2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from .. import *
from ...modules.configs import *
from ...modules.interfaces import *
from ...modules import heuristics
from ...ragged import RaggedBatchWrapper

from .container import Llama2NonTransformerContainer, Llama2TransformerContainer
Expand Down Expand Up @@ -106,26 +105,9 @@ def norm_type(self) -> NormTypeEnum:
def positional_embedding_type(self) -> PositionalEmbeddingType:
return PositionalEmbeddingType.rotate_half

def make_attn_layer(self) -> None:
"""
Builds the attention layer for the model. This sets the `self.attn` attribute.
"""
softmax_scale = 1.0 / (self.head_size**0.5)

rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta)

attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
n_heads_q=self.n_heads_q_local,
n_heads_kv=self.n_heads_kv_local,
head_size=self.head_size,
max_sequences=self._engine_config.state_manager.max_ragged_sequence_count,
scale_factor=softmax_scale,
input_dtype=self.activation_dtype,
output_dtype=self.activation_dtype,
positional_embedding_type=self.positional_embedding_type,
positional_embedding_config=rotary_config)

self.attn = heuristics.instantiate_attention(attn_config, self._engine_config)
@property
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
return RotateHalfConfig(theta_base=self._config.rope_theta)

"""
Forward implementations
Expand Down
24 changes: 3 additions & 21 deletions deepspeed/inference/v2/model_implementations/mistral/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from ...model_implementations import *
from ...modules.configs import *
from ...modules.interfaces import *
from ...modules import heuristics
from ...ragged import RaggedBatchWrapper

from .container import MistralNonTransformerContainer, MistralTransformerContainer
Expand Down Expand Up @@ -105,26 +104,9 @@ def norm_type(self) -> NormTypeEnum:
def positional_embedding_type(self) -> PositionalEmbeddingType:
return PositionalEmbeddingType.rotate_half

def make_attn_layer(self) -> None:
"""
Builds the attention layer for the model. This sets the `self.attn` attribute.
"""
softmax_scale = 1.0 / (self.head_size**0.5)

rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta)

attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
n_heads_q=self.n_heads_q_local,
n_heads_kv=self.n_heads_kv_local,
head_size=self.head_size,
max_sequences=self._engine_config.state_manager.max_ragged_sequence_count,
scale_factor=softmax_scale,
input_dtype=self.activation_dtype,
output_dtype=self.activation_dtype,
positional_embedding_type=self.positional_embedding_type,
positional_embedding_config=rotary_config)

self.attn = heuristics.instantiate_attention(attn_config, self._engine_config)
@property
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
return RotateHalfConfig(theta_base=self._config.rope_theta)

"""
Forward implementations
Expand Down
29 changes: 7 additions & 22 deletions deepspeed/inference/v2/model_implementations/mixtral/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from ...model_implementations import *
from ...modules.configs import *
from ...modules.interfaces import *
from ...modules import heuristics
from ...ragged import RaggedBatchWrapper
from ..inference_model_base import (
DSModelImplementationConfig,
Expand Down Expand Up @@ -110,6 +109,13 @@ def norm_type(self) -> NormTypeEnum:
def positional_embedding_type(self) -> PositionalEmbeddingType:
return PositionalEmbeddingType.rotate_half

@property
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
"""
The positional embedding configuration for the model.
"""
return RotateHalfConfig(theta_base=self._config.rope_theta)

"""
Inherited from `DSMoETransformerModelBase`
"""
Expand Down Expand Up @@ -161,27 +167,6 @@ def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInf
self.make_unembedding_layer()
self._kv_cache_config = None

def make_attn_layer(self) -> None:
"""
Builds the attention layer for the model. This sets the `self.attn` attribute.
"""
softmax_scale = 1.0 / (self.head_size**0.5)

rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta)

attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
n_heads_q=self.n_heads_q_local,
n_heads_kv=self.n_heads_kv_local,
head_size=self.head_size,
max_sequences=self._engine_config.state_manager.max_ragged_sequence_count,
scale_factor=softmax_scale,
input_dtype=self.activation_dtype,
output_dtype=self.activation_dtype,
positional_embedding_type=self.positional_embedding_type,
positional_embedding_config=rotary_config)

self.attn = heuristics.instantiate_attention(attn_config, self._engine_config)

def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor:
"""
Performs the embedding lookup prior to running the transformer of the model.
Expand Down
10 changes: 5 additions & 5 deletions deepspeed/inference/v2/model_implementations/opt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@
from ...allocator import empty_from
from ...inference_utils import ActivationType, DtypeEnum
from ...model_implementations import *
from ...modules.configs import (
DSEmbeddingsConfig,
NormTypeEnum,
PositionalEmbeddingType,
)
from ...modules.configs import *
from ...ragged import RaggedBatchWrapper
from .container import OPTNonTransformerContainer, OPTTransformerContainer

Expand Down Expand Up @@ -94,6 +90,10 @@ def norm_type(self) -> NormTypeEnum:
def positional_embedding_type(self) -> PositionalEmbeddingType:
return PositionalEmbeddingType.none

@property
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
return None

"""
Overrides of ``DSTransformerModelBase`` methods
"""
Expand Down
63 changes: 3 additions & 60 deletions deepspeed/inference/v2/model_implementations/phi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,11 @@
import deepspeed.comm as dist

from ...allocator import empty_from
from ...config_v2 import RaggedInferenceEngineConfig
from ...inference_utils import ActivationType, DtypeEnum
from .. import *
from ...modules.configs import *
from ...modules.interfaces import *
from ...modules import heuristics
from ...ragged import RaggedBatchWrapper
from ..inference_model_base import (
DSModelImplementationConfig,
MPType,
)

from .containers import PhiNonTransformerContainer, PhiTransformerContainer

Expand Down Expand Up @@ -101,60 +95,9 @@ def norm_type(self) -> NormTypeEnum:
def positional_embedding_type(self) -> PositionalEmbeddingType:
return PositionalEmbeddingType.rotate_half

"""
Model implementation
"""

def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig,
base_mp_group: MPType) -> None:
"""
Base implementation for initialization. By default, this will initialize
the traditional components of a transformer model:
- Embedding
- QKV projection
- Self attention
- Attention output projection
- Feed forward network
- Normalization
- Unembedding
Arguments:
config (DSModelImplementationConfig): Model-specific configuration. No assumptions
should be made about this config that are not closely tied to the specific
model implementation.
engine_config (RaggedInferenceEngineConfig): Engine configuration.
base_mp_group (MPType): Base communication group for Tensor-parallel inference.
"""
super().__init__(config, engine_config, base_mp_group)

self.make_norm_layer()
self.make_qkv_layer()
self.make_attn_layer()
self.make_attn_out_layer()
self.make_embedding_layer()
self.make_unembedding_layer()
self._kv_cache_config = None

def make_attn_layer(self) -> None:
"""
Builds the attention layer for the model. This sets the `self.attn` attribute.
"""
softmax_scale = 1.0 / (self.head_size**0.5)

rotary_config = RotateHalfConfig(rotate_dim=self._config.rotary_dim)

attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
n_heads_q=self.n_heads_q_local,
n_heads_kv=self.n_heads_kv_local,
head_size=self.head_size,
max_sequences=self._engine_config.state_manager.max_ragged_sequence_count,
scale_factor=softmax_scale,
input_dtype=self.activation_dtype,
output_dtype=self.activation_dtype,
positional_embedding_type=self.positional_embedding_type,
positional_embedding_config=rotary_config)

self.attn = heuristics.instantiate_attention(attn_config, self._engine_config)
@property
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
return RotateHalfConfig(rotate_dim=self._config.rotary_dim)

"""
Forward implementations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[st
self._config.n_heads_kv, self._config.input_dtype)
elif embed_type == PositionalEmbeddingType.rotate_half:
rotary_config = config.positional_embedding_config
assert rotary_config is not None, "Rotary config must be provided if using rotate_half as Positional Embedding Type."

if rotary_config.use_trained_freqs:
# Theta and rotary dim are effectively embedded into either the values (theta) or the shape (rotary_dim)
# of the trained_freqs tensor.
Expand Down

0 comments on commit c1e0205

Please sign in to comment.