diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py index a0dc050bbbf9a..a17fa9fefbaa3 100644 --- a/deepspeed/inference/v2/engine_factory.py +++ b/deepspeed/inference/v2/engine_factory.py @@ -19,6 +19,7 @@ MistralPolicy, MixtralPolicy, FalconPolicy, + PhiPolicy, ) from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata @@ -114,6 +115,8 @@ def build_hf_engine(path: str, policy = MixtralPolicy(model_config, checkpoint_engine=checkpoint_engine) elif model_config.model_type == "falcon": policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "phi-msft": + policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine) else: raise ValueError(f"Unsupported model type {model_config.model_type}") diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu index 5dd79f0c636a0..807f3c1b3d631 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu @@ -21,7 +21,7 @@ constexpr int threads = 256; Supports head size 32, 64, 128, 256 */ -template +template __global__ void kv_rotary_pos_kernel(T* kv_cache, T* q, T* k, @@ -36,28 +36,31 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, { // Derived constexpr constexpr int vector_T = kv_rot::granularity / sizeof(T); - constexpr int threads_per_head = headSize / vector_T; + constexpr int real_threads_per_head = headSize / vector_T; + constexpr int threads_per_head = paddedHeadSize / vector_T; constexpr int half_head_size = headSize >> 1; + constexpr int tokens_per_block = kv_rot::threads / threads_per_head; // CG helpers cg::thread_block tb = cg::this_thread_block(); cg::thread_block_tile warp = cg::tiled_partition(tb); - cg::thread_block_tile head_group = - cg::tiled_partition(warp); + cg::thread_block_tile head_group = cg::tiled_partition(tb); // Parallelize on the head dimension for X blocks const int head_idx = blockIdx.x; const int block_seq_idx = threadIdx.x / threads_per_head; - const int base_neuron_idx = (threadIdx.x * vector_T) % headSize; + const int base_neuron_idx = head_group.thread_rank() * vector_T; const int half_idx = base_neuron_idx % half_head_size; - const int half_head_lanes = threads_per_head / 2; + const int half_head_lanes = real_threads_per_head / 2; // Multiple tokens processed by the same threadblock const int token_idx = blockIdx.y * tokens_per_block + block_seq_idx; const bool valid_token = token_idx < batch_desc.batch_metadata->n_tokens; - const bool load_inv_freq = (inv_freq != nullptr) && valid_token; + + const bool valid_thread = valid_token && (head_group.thread_rank() < real_threads_per_head); + const bool load_inv_freq = (inv_freq != nullptr) && valid_thread; // If we have GQA, then only one of the Q heads needs to do rotary + copy // for each of the heads in the group. @@ -68,9 +71,9 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, const int kv_head_idx = head_idx / qRatio; // Ensure we don't access invalid portions of the seq_metadata - const int32_t seq_id = (valid_token) ? batch_desc.tokens_to_seq[token_idx] : 0; + const int32_t seq_id = (valid_thread) ? batch_desc.tokens_to_seq[token_idx] : 0; const InflightSeqDescriptor seq_desc = batch_desc.seq_metadata[seq_id]; - // This will give an invalid index if valid_token is false, but should never affect memory. + // This will give an invalid index if valid_thread is false, but should never affect memory. const int32_t global_token_idx = seq_desc.seen_tokens + (token_idx - seq_desc.start_idx); T* q_row = q + token_idx * qkv_stride + head_idx * headSize; @@ -82,7 +85,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, const KVCacheDescriptor kv_desc = batch_desc.kv_desc; const int32_t seq_kv_block_idx = global_token_idx / kv_desc.block_size; const int32_t mapped_kv_block_idx = - (valid_token) ? kv_desc.block_lists[seq_id][seq_kv_block_idx] : 0; + (valid_thread) ? kv_desc.block_lists[seq_id][seq_kv_block_idx] : 0; const int32_t kv_block_offset = global_token_idx % kv_desc.block_size; const int32_t kv_offset = @@ -95,12 +98,11 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, T k_reg[vector_T], v_reg[vector_T], inv_freq_reg[vector_T]; - mem_access::load_global(q_reg, q_row + base_neuron_idx, valid_token); - mem_access::load_global(k_reg, k_row + base_neuron_idx, valid_token); - mem_access::load_global(v_reg, v_row + base_neuron_idx, valid_token); + mem_access::load_global(q_reg, q_row + base_neuron_idx, valid_thread); + mem_access::load_global(k_reg, k_row + base_neuron_idx, valid_thread); + mem_access::load_global(v_reg, v_row + base_neuron_idx, valid_thread); mem_access::load_global( inv_freq_reg, inv_freq + half_idx, load_inv_freq); - if constexpr (doRotary) { #pragma unroll for (int i = 0; i < vector_T; i++) { @@ -125,8 +127,12 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, float q_rot = q_f * rotary_sign; float k_rot = k_f * rotary_sign; - const float q_rot_temp = head_group.shfl_xor(q_rot, half_head_lanes); - const float k_rot_temp = head_group.shfl_xor(k_rot, half_head_lanes); + const int target_lane = (head_neuron_idx < half_head_size) + ? head_group.thread_rank() + half_head_lanes + : head_group.thread_rank() - half_head_lanes; + + const float q_rot_temp = head_group.shfl(q_rot, target_lane); + const float k_rot_temp = head_group.shfl(k_rot, target_lane); q_reg[i] = conversion::to(q_f * cosf(inv_freq_flt) + q_rot_temp * sinf(inv_freq_flt)); @@ -135,7 +141,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, } } - if (valid_token) { + if (valid_thread) { mem_access::store_global(kv_cache + kv_offset + base_neuron_idx, k_reg); mem_access::store_global( @@ -144,7 +150,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, } else { T inv_freq_reg[vector_T]; - mem_access::load_global(q_reg, q_row + base_neuron_idx, valid_token); + mem_access::load_global(q_reg, q_row + base_neuron_idx, valid_thread); mem_access::load_global( inv_freq_reg, inv_freq + half_idx, load_inv_freq); @@ -166,7 +172,11 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, float q_f = conversion::to(q_reg[i]); float q_rot = q_f * rotary_sign; - const float q_rot_temp = head_group.shfl_xor(q_rot, half_head_lanes); + const int target_lane = (head_neuron_idx < half_head_size) + ? head_group.thread_rank() + half_head_lanes + : head_group.thread_rank() - half_head_lanes; + + const float q_rot_temp = head_group.shfl(q_rot, target_lane); q_reg[i] = conversion::to(q_f * cosf(inv_freq_flt) + q_rot_temp * sinf(inv_freq_flt)); @@ -174,26 +184,46 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, } } - if (valid_token && doRotary) { + if (valid_thread && doRotary) { mem_access::store_global(q_row + base_neuron_idx, q_reg); } } -#define DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE) \ - if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ - kv_rotary_pos_kernel \ - <<>>(kv_cache, \ - q, \ - k, \ - v, \ - inv_freq, \ - theta_base, \ - batch_desc, \ - qkv_stride, \ - kv_cache_stride, \ - v_offset, \ +#define DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE, PADDED_HEAD_SIZE) \ + if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ + kv_rotary_pos_kernel \ + <<>>(kv_cache, \ + q, \ + k, \ + v, \ + inv_freq, \ + theta_base, \ + batch_desc, \ + qkv_stride, \ + kv_cache_stride, \ + v_offset, \ inv_freq_stride); +#define LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, HEAD_SIZE) \ + if (padded_head_size == 64) { \ + DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE, 64); \ + } else if (padded_head_size == 128) { \ + DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE, 128); \ + } else { \ + assert(false); \ + } + +#define LAUNCH_KV_ROTARY_FOR_Q_RATIO(Q_RATIO) \ + if (head_size == 64) { \ + LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 64); \ + } else if (head_size == 80) { \ + LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 80); \ + } else if (head_size == 128) { \ + LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 128); \ + } else { \ + assert(false); \ + } + template void launch_kv_rotary_kernel(T* kv_cache, T* q, @@ -213,33 +243,26 @@ void launch_kv_rotary_kernel(T* kv_cache, cudaStream_t stream) { constexpr int vector_T = kv_rot::granularity / sizeof(T); - const int threads_per_head = head_size / vector_T; + + const int padded_head_size = next_pow2(head_size); + const int threads_per_head = padded_head_size / vector_T; + const int tokens_per_block = kv_rot::threads / threads_per_head; const dim3 block(kv_rot::threads); const int token_blocks = (n_tokens + tokens_per_block - 1) / tokens_per_block; const dim3 grid(n_q_heads, token_blocks); - DISPATCH_KV_ROTARY_IMPL(1, 64) - DISPATCH_KV_ROTARY_IMPL(1, 128) - DISPATCH_KV_ROTARY_IMPL(2, 64) - DISPATCH_KV_ROTARY_IMPL(2, 128) - DISPATCH_KV_ROTARY_IMPL(4, 64) - DISPATCH_KV_ROTARY_IMPL(4, 128) - DISPATCH_KV_ROTARY_IMPL(5, 64) - DISPATCH_KV_ROTARY_IMPL(5, 128) - DISPATCH_KV_ROTARY_IMPL(8, 64) - DISPATCH_KV_ROTARY_IMPL(8, 128) - DISPATCH_KV_ROTARY_IMPL(16, 64) - DISPATCH_KV_ROTARY_IMPL(16, 128) - DISPATCH_KV_ROTARY_IMPL(29, 64) - DISPATCH_KV_ROTARY_IMPL(29, 128) - DISPATCH_KV_ROTARY_IMPL(35, 64) - DISPATCH_KV_ROTARY_IMPL(35, 128) - DISPATCH_KV_ROTARY_IMPL(36, 64) - DISPATCH_KV_ROTARY_IMPL(36, 128) - DISPATCH_KV_ROTARY_IMPL(71, 64) - DISPATCH_KV_ROTARY_IMPL(71, 128) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(1) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(2) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(4) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(5) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(8) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(16) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(29) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(35) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(36) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(71) } #define INSTANTIATE_KV_ROTARY_KERNEL(TYPE) \ @@ -266,21 +289,41 @@ INSTANTIATE_KV_ROTARY_KERNEL(__half) INSTANTIATE_KV_ROTARY_KERNEL(__nv_bfloat16) #endif -#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE) \ - if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ - kv_rotary_pos_kernel \ - <<>>(kv_cache, \ - q, \ - k, \ - v, \ - nullptr, \ - 0.f, \ - batch_desc, \ - qkv_stride, \ - kv_cache_stride, \ - v_offset, \ +#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE, PADDED_HEAD_SIZE) \ + if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ + kv_rotary_pos_kernel \ + <<>>(kv_cache, \ + q, \ + k, \ + v, \ + nullptr, \ + 0.f, \ + batch_desc, \ + qkv_stride, \ + kv_cache_stride, \ + v_offset, \ 0); +#define LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, HEAD_SIZE) \ + if (padded_head_size == 64) { \ + DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE, 64); \ + } else if (padded_head_size == 128) { \ + DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE, 128); \ + } else { \ + assert(false); \ + } + +#define LAUNCH_KV_COPY_FOR_Q_RATIO(Q_RATIO) \ + if (head_size == 64) { \ + LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 64); \ + } else if (head_size == 80) { \ + LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 80); \ + } else if (head_size == 128) { \ + LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 128); \ + } else { \ + assert(false); \ + } + template void launch_kv_copy_kernel(T* kv_cache, T* q, @@ -297,23 +340,19 @@ void launch_kv_copy_kernel(T* kv_cache, cudaStream_t stream) { constexpr int vector_T = kv_rot::granularity / sizeof(T); - const int threads_per_head = head_size / vector_T; + const int padded_head_size = next_pow2(head_size); + const int threads_per_head = padded_head_size / vector_T; const int tokens_per_block = kv_rot::threads / threads_per_head; const dim3 block(kv_rot::threads); const int token_blocks = (n_tokens + tokens_per_block - 1) / tokens_per_block; const dim3 grid(n_q_heads, token_blocks); - DISPATCH_KV_COPY_IMPL(1, 64) - DISPATCH_KV_COPY_IMPL(1, 128) - DISPATCH_KV_COPY_IMPL(2, 64) - DISPATCH_KV_COPY_IMPL(2, 128) - DISPATCH_KV_COPY_IMPL(4, 64) - DISPATCH_KV_COPY_IMPL(4, 128) - DISPATCH_KV_COPY_IMPL(5, 64) - DISPATCH_KV_COPY_IMPL(5, 128) - DISPATCH_KV_COPY_IMPL(8, 64) - DISPATCH_KV_COPY_IMPL(8, 128) + LAUNCH_KV_COPY_FOR_Q_RATIO(1) + LAUNCH_KV_COPY_FOR_Q_RATIO(2) + LAUNCH_KV_COPY_FOR_Q_RATIO(4) + LAUNCH_KV_COPY_FOR_Q_RATIO(5) + LAUNCH_KV_COPY_FOR_Q_RATIO(8) } #define INSTANTIATE_KV_COPY_KERNEL(TYPE) \ diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py index f206a4f5d28c7..25d3bac3ab081 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py @@ -18,7 +18,7 @@ class BlockedRotaryEmbeddings(DSKernelBase): """ supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] - supported_head_sizes = [64, 128] + supported_head_sizes = [64, 80, 128] supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71] def __init__(self, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py index 59da1db0f5d66..f6c0bb6c6b360 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py @@ -23,7 +23,7 @@ class BlockedTrainedRotaryEmbeddings(DSKernelBase): """ supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] - supported_head_sizes = [64, 128] + supported_head_sizes = [64, 80, 128] supported_q_ratios = [1, 2, 4, 5, 8] def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None: diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py index c9f6ffd37b3e6..a885eadd78a19 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py @@ -23,7 +23,7 @@ class LinearBlockedKVCopy(DSKernelBase): """ supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] - supported_head_sizes = [64, 128] + supported_head_sizes = [64, 80, 128] supported_q_ratios = [1, 2, 4, 5, 8] def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None: diff --git a/deepspeed/inference/v2/model_implementations/__init__.py b/deepspeed/inference/v2/model_implementations/__init__.py index ab1f984fba7e6..a08d966ac6d05 100644 --- a/deepspeed/inference/v2/model_implementations/__init__.py +++ b/deepspeed/inference/v2/model_implementations/__init__.py @@ -14,3 +14,4 @@ from .mistral import * from .mixtral import * from .falcon import * +from .phi import * diff --git a/deepspeed/inference/v2/model_implementations/phi/__init__.py b/deepspeed/inference/v2/model_implementations/phi/__init__.py new file mode 100644 index 0000000000000..032377792cc3f --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .phi_policy import PhiPolicy diff --git a/deepspeed/inference/v2/model_implementations/phi/phi_containers.py b/deepspeed/inference/v2/model_implementations/phi/phi_containers.py new file mode 100644 index 0000000000000..ab6d0181611c9 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi/phi_containers.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from ..common_parameters import * +from ..layer_container_base import LayerContainer +''' + # HF Phi-2 model looks like this: + +PhiForCausalLM( + (transformer): PhiModel( + (embd): Embedding( + (wte): Embedding(51200, 2560) + (drop): Dropout(p=0.0, inplace=False) + ) + (h): ModuleList( + (0-31): 32 x ParallelBlock( + (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) + (resid_dropout): Dropout(p=0.1, inplace=False) + (mixer): MHA( + (rotary_emb): RotaryEmbedding() + (Wqkv): Linear(in_features=2560, out_features=7680, bias=True) + (out_proj): Linear(in_features=2560, out_features=2560, bias=True) + (inner_attn): SelfAttention( + (drop): Dropout(p=0.0, inplace=False) + ) + (inner_cross_attn): CrossAttention( + (drop): Dropout(p=0.0, inplace=False) + ) + ) + (mlp): MLP( + (fc1): Linear(in_features=2560, out_features=10240, bias=True) + (fc2): Linear(in_features=10240, out_features=2560, bias=True) + (act): NewGELUActivation() + ) + ) + ) + ) + (lm_head): CausalLMHead( + (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) + (linear): Linear(in_features=2560, out_features=51200, bias=True) + ) + (loss): CausalLMLoss( + (loss_fct): CrossEntropyLoss() + ) +) +''' + + +class PhiTransformerContainer(LayerContainer): + """ + Transformer layer container for the Phi model. + """ + qkv_w: FusedQKVParameter + qkv_b: FusedQKVParameter + attn_out_w: AttentionOutputParameter + attn_out_b: AttentionOutputParameter + mlp_1_w: MLP1Parameter + mlp_1_b: MLP1Parameter + mlp_2_w: MLP2Parameter + mlp_2_b: MLP2Parameter + ln_gamma: NormParameter + ln_beta: NormParameter + + PARAM_MAPPING = { + "mixer.Wqkv.weight": "qkv_w.params", + "mixer.Wqkv.bias": "qkv_b.params", + "mixer.out_proj.weight": "attn_out_w.params", + "mixer.out_proj.bias": "attn_out_b.params", + "mlp.fc1.weight": "mlp_1_w.params", + "mlp.fc1.bias": "mlp_1_b.params", + "mlp.fc2.weight": "mlp_2_w.params", + "mlp.fc2.bias": "mlp_2_b.params", + "ln.weight": "ln_gamma.params", + "ln.bias": "ln_beta.params", + } + + +class PhiNonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Phi model. + """ + word_emb: EmbeddingParameter + word_unembed_w: UnembedParameter + word_unembed_b: UnembedParameter + final_norm_gamma: NormParameter + final_norm_beta: NormParameter + + PARAM_MAPPING = { + "transformer.embd.wte.weight": "word_emb.params", + "lm_head.ln.weight": "final_norm_gamma.params", + "lm_head.ln.bias": "final_norm_beta.params", + "lm_head.linear.weight": "word_unembed_w.params", + "lm_head.linear.bias": "word_unembed_b.params", + } diff --git a/deepspeed/inference/v2/model_implementations/phi/phi_model.py b/deepspeed/inference/v2/model_implementations/phi/phi_model.py new file mode 100644 index 0000000000000..0e5fb8613044f --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi/phi_model.py @@ -0,0 +1,258 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...config_v2 import RaggedInferenceEngineConfig +from ...inference_utils import ActivationType, DtypeEnum +from .. import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...modules import heuristics +from ...ragged import RaggedBatchWrapper +from ..inference_model_base import ( + DSModelImplementationConfig, + MPType, +) + +from .phi_containers import PhiNonTransformerContainer, PhiTransformerContainer + + +class PhiInferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for Llama-2 models. + """ + + _non_transformer: Optional[PhiNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[PhiTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties inherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties inherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.n_layer + + @property + def model_dim(self) -> int: + return self._config.n_embd + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.n_head + + @property + def intermediate_dim(self) -> int: + n_inner = getattr(self._config, "n_inner", None) + return n_inner if n_inner is not None else 4 * self.model_dim + + @property + def n_heads_kv(self) -> int: + return getattr(self._config, "n_head_kv", None) or self.n_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + return ActivationType.GELU + # TODO: Add support for New GeLU activation functions + # 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.LayerNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + """ + Model implementation + """ + + def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig, + base_mp_group: MPType) -> None: + """ + Base implementation for initialization. By default, this will initialize + the traditional components of a transformer model: + - Embedding + - QKV projection + - Self attention + - Attention output projection + - Feed forward network + - Normalization + - Unembedding + + Arguments: + config (DSModelImplementationConfig): Model-specific configuration. No assumptions + should be made about this config that are not closely tied to the specific + model implementation. + engine_config (RaggedInferenceEngineConfig): Engine configuration. + base_mp_group (MPType): Base communication group for Tensor-parallel inference. + """ + super().__init__(config, engine_config, base_mp_group) + + self.make_norm_layer() + self.make_qkv_layer() + self.make_attn_layer() + self.make_attn_out_layer() + self.make_embedding_layer() + self.make_unembedding_layer() + self._kv_cache_config = None + + def make_attn_layer(self) -> None: + """ + Builds the attention layer for the model. This sets the `self.attn` attribute. + """ + softmax_scale = 1.0 / (self.head_size**0.5) + + rotary_config = RotateHalfConfig() + + attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + n_heads_q=self.n_heads_q_local, + n_heads_kv=self.n_heads_kv_local, + head_size=self.head_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + scale_factor=softmax_scale, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + positional_embedding_type=self.positional_embedding_type, + positional_embedding_config=rotary_config) + + self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + attn_ln_out = hidden_states + attn_hidden_state = self.qkv(attn_ln_out, cur_params.qkv_w, b=cur_params.qkv_b) + attn_hidden_state = self.attn(attn_hidden_state, kv_cache, ragged_batch_info) + attention_output = self.attn_out(attn_hidden_state, cur_params.attn_out_w, b=cur_params.attn_out_b) + + mlp_ln_out = hidden_states + mlp_hidden_state = self.mlp_1(mlp_ln_out, cur_params.mlp_1_w, b=cur_params.mlp_1_b) + mlp_output = self.mlp_2(mlp_hidden_state, cur_params.mlp_2_w, b=cur_params.mlp_2_b) + + mlp_output.add_(attention_output) + + if self.tp_size > 1: + dist.all_reduce(mlp_output, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, mlp_output = self.norm(residual, mlp_output, next_params.ln_gamma, beta=next_params.ln_beta) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(mlp_output) + + return residual, mlp_output + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed_w, + ragged_batch_info, + bias=self._non_transformer.word_unembed_b, + gamma=self._non_transformer.final_norm_gamma, + beta=self._non_transformer.final_norm_beta) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, + None, + gamma=self._transformer[0].ln_gamma, + beta=self._transformer[0].ln_beta) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/phi/phi_policy.py b/deepspeed/inference/v2/model_implementations/phi/phi_policy.py new file mode 100644 index 0000000000000..ca2a1aae70003 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi/phi_policy.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from ..phi.phi_containers import PhiNonTransformerContainer, PhiTransformerContainer +from ..phi.phi_model import PhiInferenceModel + + +class PhiPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> PhiInferenceModel: + return PhiInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + trans_container_cls = PhiTransformerContainer + transformer_containers = [trans_container_cls(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['transformer.h'], transformer_containers) + + map.set_non_transformer_params(PhiNonTransformerContainer(self.model)) + + map.set_unmapped_params( + [f'model.layers.{i}.self_attn.rotary_emb.inv_freq' for i in range(self.model.num_layers)]) + + return map diff --git a/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py b/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py index 40d70cbd4df76..36130902c665c 100644 --- a/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py +++ b/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py @@ -9,8 +9,8 @@ from deepspeed.accelerator import get_accelerator from ....allocator import empty_from -from ....inference_utils import DtypeEnum -from ....kernels.core_ops import CUDAFPLN, BlasLibLinear, CUDARMSNorm +from ....inference_utils import DtypeEnum, ActivationType +from ....kernels.core_ops import CUDAFPLN, BlasLibLinear, CUDARMSNorm, CUDABiasActivation from ....kernels.ragged_ops import RaggedLogitsGather from ....ragged import RaggedBatchWrapper from ...interfaces import DSUnembedBase, DSUnembedRegistry @@ -65,6 +65,8 @@ def __init__(self, config: DSUnembedConfig, implementation_config: Dict[str, Any self._norm = None self._linear = BlasLibLinear(self._config.dtype) + # Here the activation kernel is being used to apply bias, hence the identity activation type! + self._act_fn = CUDABiasActivation(self._config.vocab_size, self._config.dtype, ActivationType.IDENTITY) self._intermediate = torch.empty((self._config.max_sequences, self._config.model_dim), dtype=self._config.dtype, @@ -82,6 +84,7 @@ def forward(self, hidden_states: torch.Tensor, vocab_embedding: torch.Tensor, ragged_metadata: RaggedBatchWrapper, + bias: Optional[torch.Tensor] = None, gamma: Optional[torch.Tensor] = None, beta: Optional[torch.Tensor] = None) -> torch.Tensor: """ @@ -111,5 +114,7 @@ def forward(self, output = empty_from(self._output, (ragged_metadata.current_sequences, self._config.vocab_size)) self._linear(output, cut_down_hidden_states, vocab_embedding) + if bias is not None: + self._act_fn(output, bias) return output diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py index 90fe26eb4490d..18b5d3925ef57 100644 --- a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py @@ -17,7 +17,7 @@ def test_single_sequence_single_block(n_tokens: int, history_size: int): """ Validate that the copy works correctly """ - head_size = 64 + head_size = 80 n_heads_q = 16 n_heads_kv = 16 kv_block_size = 64 diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py index 618c2d3b87ec6..c9ad982701607 100644 --- a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py @@ -72,11 +72,11 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor: @pytest.mark.inference_v2_ops @pytest.mark.parametrize("n_tokens, history_size", [(1, 0), (17, 0), (33, 15), (1, 63)]) @pytest.mark.parametrize("trained_emb", [False, True]) -def test_single_sequence_single_block(n_tokens: int, history_size: int, trained_emb: bool): +@pytest.mark.parametrize("head_size", [64, 80]) +def test_single_sequence_single_block(n_tokens: int, history_size: int, trained_emb: bool, head_size: int): """ Validate that the copy works correctly """ - head_size = 64 n_heads_q = 16 n_heads_kv = 16 kv_block_size = 64 @@ -116,11 +116,11 @@ def test_single_sequence_single_block(n_tokens: int, history_size: int, trained_ @pytest.mark.inference_v2_ops @pytest.mark.parametrize("n_tokens, history_size", [(128, 0), (177, 0), (169, 8), (117, 88)]) @pytest.mark.parametrize("trained_emb", [False, True]) -def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int, trained_emb: bool): +@pytest.mark.parametrize("head_size", [64, 80]) +def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int, trained_emb: bool, head_size: int): """ Validate that the copy works correctly """ - head_size = 64 n_heads_q = 16 n_heads_kv = 16 kv_block_size = 64 @@ -159,8 +159,8 @@ def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int, train @pytest.mark.inference_v2_ops @pytest.mark.parametrize("trained_emb", [False, True]) -def test_multi_sequences(trained_emb: bool) -> None: - head_size = 64 +@pytest.mark.parametrize("head_size", [64, 80]) +def test_multi_sequences(trained_emb: bool, head_size: int) -> None: n_heads_q = 16 n_heads_kv = 16 kv_block_size = 64