diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py index 7c1f463009b2..8ba652a8de92 100644 --- a/accelerator/abstract_accelerator.py +++ b/accelerator/abstract_accelerator.py @@ -193,6 +193,19 @@ def communication_backend_name(self): def is_triton_supported(self): ... + # Graph operations + @abc.abstractmethod + def create_graph(self): + ... + + @abc.abstractmethod + def capture_to_graph(self, graph, pool=None, stream=None): + ... + + @abc.abstractmethod + def replay_graph(self, graph): + ... + # Tensor operations @property @abc.abstractmethod diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index 05c4f6c650f0..a02777f5223b 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -198,8 +198,18 @@ def is_fp16_supported(self): def supported_dtypes(self): return [torch.float, torch.bfloat16] - # Tensor operations + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + # Tensor operations @property def BFloat16Tensor(self): return torch.BFloat16Tensor diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 4b94ddb6865c..2d74daecf3df 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -222,6 +222,17 @@ def is_triton_supported(self): else: return False + # Graph operations + def create_graph(self): + return torch.cuda.CUDAGraph() + + def capture_to_graph(self, graph, pool=None, stream=None): + return torch.cuda.graph(graph, pool, stream) + + def replay_graph(self, graph): + graph.replay() + return + # Tensor operations @property diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py index f0d4cac721b5..f6303cf9890f 100644 --- a/accelerator/mps_accelerator.py +++ b/accelerator/mps_accelerator.py @@ -166,6 +166,17 @@ def communication_backend_name(self): def is_triton_supported(self): return False + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + # Tensor operations @property def BFloat16Tensor(self): diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py index 8bfd59cd2ad6..4e20445d9d32 100644 --- a/accelerator/npu_accelerator.py +++ b/accelerator/npu_accelerator.py @@ -176,6 +176,17 @@ def communication_backend_name(self): def is_triton_supported(self): return False + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + # Tensor operations @property diff --git a/blogs/deepspeed-fastgen/README.md b/blogs/deepspeed-fastgen/README.md index 4081c780e09a..89fdefd22b0b 100644 --- a/blogs/deepspeed-fastgen/README.md +++ b/blogs/deepspeed-fastgen/README.md @@ -229,6 +229,8 @@ We currently support the following model architectures in this alpha release of * [Mistral](https://huggingface.co/models?other=mistral) * [OPT](https://huggingface.co/models?other=opt) * [Falcon](https://huggingface.co/models?other=falcon) +* [Mixtral](https://huggingface.co/models?other=mixtral) +* [Phi-2](https://huggingface.co/models?other=phi-msft) All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer. diff --git a/csrc/includes/cpu_adagrad.h b/csrc/includes/cpu_adagrad.h index 59888adf17c3..e60984d64b76 100644 --- a/csrc/includes/cpu_adagrad.h +++ b/csrc/includes/cpu_adagrad.h @@ -194,7 +194,7 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, #elif defined(__ENABLE_CANN__) if (dev_params) { size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]); - if (half_precision) memoryCopySize /= 2; + if (half_precision) memcpy_size /= 2; aclrtMemcpy(dev_params + t, memcpy_size, _doubled_buffer[_buf_index], @@ -202,6 +202,7 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE); _buf_index = !_buf_index; + } #endif } *rounded_size = new_rounded_size; diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index 44d3ed3cac61..b1a104b2571d 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -215,8 +215,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, #if defined(__ENABLE_CUDA__) if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } #elif defined(__ENABLE_CANN__) - if ((t / TILE) >= 2) { aclrtSynchronizeStream((_streams[_buf_index].stream()); - } + if ((t / TILE) >= 2) { aclrtSynchronizeStream(_streams[_buf_index].stream()); } #endif #pragma omp parallel for for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { @@ -274,7 +273,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, #elif defined(__ENABLE_CANN__) if (dev_params) { size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]); - if (half_precision) memoryCopySize /= 2; + if (half_precision) memcpy_size /= 2; aclrtMemcpy(dev_params + t, memcpy_size, _doubled_buffer[_buf_index], @@ -282,6 +281,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE); _buf_index = !_buf_index; + } #endif } *rounded_size = new_rounded_size; diff --git a/csrc/includes/cpu_lion.h b/csrc/includes/cpu_lion.h index d83fe9473332..34c29eec47db 100644 --- a/csrc/includes/cpu_lion.h +++ b/csrc/includes/cpu_lion.h @@ -223,7 +223,7 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size, #elif defined(__ENABLE_CANN__) if (dev_params) { size_t memcpy_size = copy_size * sizeof(_doubled_buffer[_buf_index][0]); - if (half_precision) memoryCopySize /= 2; + if (half_precision) memcpy_size /= 2; aclrtMemcpy(dev_params + t, memcpy_size, _doubled_buffer[_buf_index], @@ -231,6 +231,7 @@ void Lion_Optimizer::Step_AVX(size_t* rounded_size, aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE); _buf_index = !_buf_index; + } #endif } *rounded_size = new_rounded_size; diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index 568211645f40..fb92c1e98421 100644 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -595,7 +595,7 @@ def get_all_ranks_from_group(group=None): while True: group_ranks.append(cdb.get_global_rank(group, rank)) rank += 1 - except RuntimeError: + except (RuntimeError, ValueError): pass return group_ranks diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index ccecc8376ad6..5cdd99ff0b90 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -531,11 +531,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): get_accelerator().current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._cuda_graphs = torch.cuda.CUDAGraph() + self._cuda_graphs = get_accelerator().create_graph() self.static_inputs = inputs self.static_kwargs = kwargs - with torch.cuda.graph(self._cuda_graphs): + with get_accelerator().capture_to_graph(self._cuda_graphs): self.static_output = self.module(*self.static_inputs, **self.static_kwargs) self.cuda_graph_created = True @@ -547,7 +547,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[k].copy_(kwargs[k]) - self._cuda_graphs.replay() + get_accelerator().replay_graph(self._cuda_graphs) return self.static_output def model_times(self): diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py index 6b64ed3185a2..ca9fb113b15a 100644 --- a/deepspeed/inference/v2/checkpoint/huggingface_engine.py +++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py @@ -61,7 +61,7 @@ def model_has_safetensors(model_name_or_path: str) -> bool: # We need to download the checkpoint files from HF if model_has_safetensors(self.model_name_or_path): # Prioritize downloading safetensors if they are available - allow_patterns = ["*.safetensors", "*.json", "*.pt"] + allow_patterns = ["*.safetensors", "*.json"] else: # Fallback to bin files when safetensors are not present allow_patterns = ["*.bin", "*.json", "*.pt"] diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py index 9558125ff934..a17fa9fefbaa 100644 --- a/deepspeed/inference/v2/engine_factory.py +++ b/deepspeed/inference/v2/engine_factory.py @@ -17,7 +17,9 @@ OPTPolicy, Llama2Policy, MistralPolicy, + MixtralPolicy, FalconPolicy, + PhiPolicy, ) from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata @@ -105,8 +107,16 @@ def build_hf_engine(path: str, assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \ f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}" policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "mixtral": + # Ensure we're using the correct version of transformers for mistral + import transformers + assert version.parse(transformers.__version__) >= version.parse("4.36.1"), \ + f"Mistral requires transformers >= 4.36.1, you have version {transformers.__version__}" + policy = MixtralPolicy(model_config, checkpoint_engine=checkpoint_engine) elif model_config.model_type == "falcon": policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "phi-msft": + policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine) else: raise ValueError(f"Unsupported model type {model_config.model_type}") diff --git a/deepspeed/inference/v2/kernels/ragged_ops/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/__init__.py index 988152b2e7c0..38a4ebd6fba3 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/__init__.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/__init__.py @@ -10,4 +10,4 @@ from .logits_gather import * from .moe_gather import * from .moe_scatter import * -from .top_1_gating import * +from .top_k_gating import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h new file mode 100644 index 000000000000..abb9e15f8f6f --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#define TOP_K_SWITCH(N_TOP_K, ...) \ + [&] { \ + if (1 == N_TOP_K) { \ + constexpr int CONST_TOP_K = 1; \ + __VA_ARGS__(); \ + } else if (2 == N_TOP_K) { \ + constexpr int CONST_TOP_K = 2; \ + __VA_ARGS__(); \ + } \ + }() diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp index 8493bbf4b9af..634a63b81a31 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp @@ -13,6 +13,8 @@ (C_TYPE*)k.data_ptr(), \ (C_TYPE*)v.data_ptr(), \ (C_TYPE*)inv_freq_ptr, \ + rotary_dim, \ + theta_base, \ batch_wrapper, \ qkv_stride, \ kv_cache_stride, \ @@ -51,6 +53,9 @@ void kv_trained_rotary_embeddings(torch::Tensor& kv_cache, TORCH_CHECK(n_tokens == k.size(0)); TORCH_CHECK(n_tokens == v.size(0)); + const float theta_base = 0.f; + const int32_t rotary_dim = inv_freq.size(0) * 2; + // Dimensions const int32_t block_size = kv_cache.size(1); const int32_t n_kv_heads = kv_cache.size(3); @@ -91,6 +96,8 @@ void kv_rotary_embeddings(torch::Tensor& kv_cache, torch::Tensor& q, torch::Tensor& k, torch::Tensor& v, + const int32_t rotary_dim, + const float theta_base, torch::Tensor& batch_metadata, torch::Tensor& seq_metadata, torch::Tensor& tokens_to_seq, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu index 980334f02b0b..c295b85a246e 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu @@ -21,12 +21,14 @@ constexpr int threads = 256; Supports head size 32, 64, 128, 256 */ -template +template __global__ void kv_rotary_pos_kernel(T* kv_cache, T* q, T* k, T* v, const T* inv_freq, + const int32_t rotary_dim, + const float theta_base, const BatchWrapperCPP batch_desc, const int qkv_stride, const int kv_cache_stride, @@ -35,28 +37,31 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, { // Derived constexpr constexpr int vector_T = kv_rot::granularity / sizeof(T); - constexpr int threads_per_head = headSize / vector_T; - constexpr int half_head_size = headSize >> 1; + constexpr int real_threads_per_head = headSize / vector_T; + constexpr int threads_per_head = paddedHeadSize / vector_T; + constexpr int tokens_per_block = kv_rot::threads / threads_per_head; // CG helpers cg::thread_block tb = cg::this_thread_block(); cg::thread_block_tile warp = cg::tiled_partition(tb); - cg::thread_block_tile head_group = - cg::tiled_partition(warp); + cg::thread_block_tile head_group = cg::tiled_partition(tb); // Parallelize on the head dimension for X blocks const int head_idx = blockIdx.x; const int block_seq_idx = threadIdx.x / threads_per_head; - const int base_neuron_idx = (threadIdx.x * vector_T) % headSize; - const int half_idx = base_neuron_idx % half_head_size; - const int half_head_lanes = threads_per_head / 2; + const int base_neuron_idx = head_group.thread_rank() * vector_T; + const int half_rotary_size = rotary_dim / 2; + const int half_dim_lanes = half_rotary_size / vector_T; + const int half_idx = base_neuron_idx % half_rotary_size; // Multiple tokens processed by the same threadblock const int token_idx = blockIdx.y * tokens_per_block + block_seq_idx; const bool valid_token = token_idx < batch_desc.batch_metadata->n_tokens; - const bool load_inv_freq = (inv_freq != nullptr) && valid_token; + + const bool valid_thread = valid_token && (head_group.thread_rank() < real_threads_per_head); + const bool load_inv_freq = (inv_freq != nullptr) && valid_thread; // If we have GQA, then only one of the Q heads needs to do rotary + copy // for each of the heads in the group. @@ -67,9 +72,9 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, const int kv_head_idx = head_idx / qRatio; // Ensure we don't access invalid portions of the seq_metadata - const int32_t seq_id = (valid_token) ? batch_desc.tokens_to_seq[token_idx] : 0; + const int32_t seq_id = (valid_thread) ? batch_desc.tokens_to_seq[token_idx] : 0; const InflightSeqDescriptor seq_desc = batch_desc.seq_metadata[seq_id]; - // This will give an invalid index if valid_token is false, but should never affect memory. + // This will give an invalid index if valid_thread is false, but should never affect memory. const int32_t global_token_idx = seq_desc.seen_tokens + (token_idx - seq_desc.start_idx); T* q_row = q + token_idx * qkv_stride + head_idx * headSize; @@ -81,7 +86,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, const KVCacheDescriptor kv_desc = batch_desc.kv_desc; const int32_t seq_kv_block_idx = global_token_idx / kv_desc.block_size; const int32_t mapped_kv_block_idx = - (valid_token) ? kv_desc.block_lists[seq_id][seq_kv_block_idx] : 0; + (valid_thread) ? kv_desc.block_lists[seq_id][seq_kv_block_idx] : 0; const int32_t kv_block_offset = global_token_idx % kv_desc.block_size; const int32_t kv_offset = @@ -94,12 +99,11 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, T k_reg[vector_T], v_reg[vector_T], inv_freq_reg[vector_T]; - mem_access::load_global(q_reg, q_row + base_neuron_idx, valid_token); - mem_access::load_global(k_reg, k_row + base_neuron_idx, valid_token); - mem_access::load_global(v_reg, v_row + base_neuron_idx, valid_token); + mem_access::load_global(q_reg, q_row + base_neuron_idx, valid_thread); + mem_access::load_global(k_reg, k_row + base_neuron_idx, valid_thread); + mem_access::load_global(v_reg, v_row + base_neuron_idx, valid_thread); mem_access::load_global( inv_freq_reg, inv_freq + half_idx, load_inv_freq); - if constexpr (doRotary) { #pragma unroll for (int i = 0; i < vector_T; i++) { @@ -110,31 +114,37 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, inv_freq_flt = conversion::to(inv_freq_reg[i]) * (float)global_token_idx; } else { inv_freq_flt = - (float)((head_neuron_idx % half_head_size) * 2) / (float)headSize; + (float)((head_neuron_idx % half_rotary_size) * 2) / (float)rotary_dim; // Conversion to T and back means that both branches of this if statement // will produce the same results if using the same algo for producing the // freqs. - T trunc_freq = conversion::to(1.0 / powf(10000.0, inv_freq_flt)); + T trunc_freq = conversion::to(1.0 / powf(theta_base, inv_freq_flt)); inv_freq_flt = conversion::to(trunc_freq) * (float)global_token_idx; } - float rotary_sign = (head_neuron_idx >= half_head_size) ? -1.0f : 1.0f; + float rotary_sign = (head_neuron_idx >= half_rotary_size) ? -1.0f : 1.0f; float q_f = conversion::to(q_reg[i]); float k_f = conversion::to(k_reg[i]); float q_rot = q_f * rotary_sign; float k_rot = k_f * rotary_sign; - const float q_rot_temp = head_group.shfl_xor(q_rot, half_head_lanes); - const float k_rot_temp = head_group.shfl_xor(k_rot, half_head_lanes); + const int target_lane = (head_neuron_idx < half_rotary_size) + ? head_group.thread_rank() + half_dim_lanes + : head_group.thread_rank() - half_dim_lanes; + + const float q_rot_temp = head_group.shfl(q_rot, target_lane); + const float k_rot_temp = head_group.shfl(k_rot, target_lane); - q_reg[i] = - conversion::to(q_f * cosf(inv_freq_flt) + q_rot_temp * sinf(inv_freq_flt)); - k_reg[i] = - conversion::to(k_f * cosf(inv_freq_flt) + k_rot_temp * sinf(inv_freq_flt)); + if (base_neuron_idx < rotary_dim) { + q_reg[i] = conversion::to(q_f * cosf(inv_freq_flt) + + q_rot_temp * sinf(inv_freq_flt)); + k_reg[i] = conversion::to(k_f * cosf(inv_freq_flt) + + k_rot_temp * sinf(inv_freq_flt)); + } } } - if (valid_token) { + if (valid_thread) { mem_access::store_global(kv_cache + kv_offset + base_neuron_idx, k_reg); mem_access::store_global( @@ -143,7 +153,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, } else { T inv_freq_reg[vector_T]; - mem_access::load_global(q_reg, q_row + base_neuron_idx, valid_token); + mem_access::load_global(q_reg, q_row + base_neuron_idx, valid_thread); mem_access::load_global( inv_freq_reg, inv_freq + half_idx, load_inv_freq); @@ -157,47 +167,75 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, inv_freq_flt = conversion::to(inv_freq_reg[i]) * (float)global_token_idx; } else { inv_freq_flt = - (float)((head_neuron_idx % half_head_size) * 2) / (float)headSize; - inv_freq_flt = 1.0 / powf(10000.0, inv_freq_flt) * (float)global_token_idx; + (float)((head_neuron_idx % half_rotary_size) * 2) / (float)rotary_dim; + inv_freq_flt = 1.0 / powf(theta_base, inv_freq_flt) * (float)global_token_idx; } - float rotary_sign = (head_neuron_idx >= half_head_size) ? -1.0f : 1.0f; + float rotary_sign = (head_neuron_idx >= half_rotary_size) ? -1.0f : 1.0f; float q_f = conversion::to(q_reg[i]); float q_rot = q_f * rotary_sign; - const float q_rot_temp = head_group.shfl_xor(q_rot, half_head_lanes); + const int target_lane = (head_neuron_idx < half_rotary_size) + ? head_group.thread_rank() + half_dim_lanes + : head_group.thread_rank() - half_dim_lanes; - q_reg[i] = - conversion::to(q_f * cosf(inv_freq_flt) + q_rot_temp * sinf(inv_freq_flt)); + const float q_rot_temp = head_group.shfl(q_rot, target_lane); + if (base_neuron_idx < rotary_dim) + q_reg[i] = conversion::to(q_f * cosf(inv_freq_flt) + + q_rot_temp * sinf(inv_freq_flt)); } } } - if (valid_token && doRotary) { + if (valid_thread && doRotary) { mem_access::store_global(q_row + base_neuron_idx, q_reg); } } -#define DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE) \ - if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ - kv_rotary_pos_kernel \ - <<>>(kv_cache, \ - q, \ - k, \ - v, \ - inv_freq, \ - batch_desc, \ - qkv_stride, \ - kv_cache_stride, \ - v_offset, \ +#define DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE, PADDED_HEAD_SIZE) \ + if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ + kv_rotary_pos_kernel \ + <<>>(kv_cache, \ + q, \ + k, \ + v, \ + inv_freq, \ + rotary_dim, \ + theta_base, \ + batch_desc, \ + qkv_stride, \ + kv_cache_stride, \ + v_offset, \ inv_freq_stride); +#define LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, HEAD_SIZE) \ + if (padded_head_size == 64) { \ + DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE, 64); \ + } else if (padded_head_size == 128) { \ + DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE, 128); \ + } else { \ + assert(false); \ + } + +#define LAUNCH_KV_ROTARY_FOR_Q_RATIO(Q_RATIO) \ + if (head_size == 64) { \ + LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 64); \ + } else if (head_size == 80) { \ + LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 80); \ + } else if (head_size == 128) { \ + LAUNCH_KV_ROTARY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 128); \ + } else { \ + assert(false); \ + } + template void launch_kv_rotary_kernel(T* kv_cache, T* q, T* k, T* v, T* inv_freq, + const int32_t rotary_dim, + const float theta_base, const BatchWrapperCPP batch_desc, const int qkv_stride, const int kv_cache_stride, @@ -210,33 +248,26 @@ void launch_kv_rotary_kernel(T* kv_cache, cudaStream_t stream) { constexpr int vector_T = kv_rot::granularity / sizeof(T); - const int threads_per_head = head_size / vector_T; + + const int padded_head_size = next_pow2(head_size); + const int threads_per_head = padded_head_size / vector_T; + const int tokens_per_block = kv_rot::threads / threads_per_head; const dim3 block(kv_rot::threads); const int token_blocks = (n_tokens + tokens_per_block - 1) / tokens_per_block; const dim3 grid(n_q_heads, token_blocks); - DISPATCH_KV_ROTARY_IMPL(1, 64) - DISPATCH_KV_ROTARY_IMPL(1, 128) - DISPATCH_KV_ROTARY_IMPL(2, 64) - DISPATCH_KV_ROTARY_IMPL(2, 128) - DISPATCH_KV_ROTARY_IMPL(4, 64) - DISPATCH_KV_ROTARY_IMPL(4, 128) - DISPATCH_KV_ROTARY_IMPL(5, 64) - DISPATCH_KV_ROTARY_IMPL(5, 128) - DISPATCH_KV_ROTARY_IMPL(8, 64) - DISPATCH_KV_ROTARY_IMPL(8, 128) - DISPATCH_KV_ROTARY_IMPL(16, 64) - DISPATCH_KV_ROTARY_IMPL(16, 128) - DISPATCH_KV_ROTARY_IMPL(29, 64) - DISPATCH_KV_ROTARY_IMPL(29, 128) - DISPATCH_KV_ROTARY_IMPL(35, 64) - DISPATCH_KV_ROTARY_IMPL(35, 128) - DISPATCH_KV_ROTARY_IMPL(36, 64) - DISPATCH_KV_ROTARY_IMPL(36, 128) - DISPATCH_KV_ROTARY_IMPL(71, 64) - DISPATCH_KV_ROTARY_IMPL(71, 128) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(1) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(2) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(4) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(5) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(8) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(16) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(29) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(35) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(36) + LAUNCH_KV_ROTARY_FOR_Q_RATIO(71) } #define INSTANTIATE_KV_ROTARY_KERNEL(TYPE) \ @@ -245,6 +276,8 @@ void launch_kv_rotary_kernel(T* kv_cache, TYPE * k, \ TYPE * v, \ TYPE * inv_freq, \ + const int32_t rotary_dim, \ + const float theta_base, \ const BatchWrapperCPP batch_desc, \ const int qkv_stride, \ const int kv_cache_stride, \ @@ -262,10 +295,41 @@ INSTANTIATE_KV_ROTARY_KERNEL(__half) INSTANTIATE_KV_ROTARY_KERNEL(__nv_bfloat16) #endif -#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE) \ - if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ - kv_rotary_pos_kernel<<>>( \ - kv_cache, q, k, v, nullptr, batch_desc, qkv_stride, kv_cache_stride, v_offset, 0); +#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE, PADDED_HEAD_SIZE) \ + if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ + kv_rotary_pos_kernel \ + <<>>(kv_cache, \ + q, \ + k, \ + v, \ + nullptr, \ + -1, \ + 0.f, \ + batch_desc, \ + qkv_stride, \ + kv_cache_stride, \ + v_offset, \ + 0); + +#define LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, HEAD_SIZE) \ + if (padded_head_size == 64) { \ + DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE, 64); \ + } else if (padded_head_size == 128) { \ + DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE, 128); \ + } else { \ + assert(false); \ + } + +#define LAUNCH_KV_COPY_FOR_Q_RATIO(Q_RATIO) \ + if (head_size == 64) { \ + LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 64); \ + } else if (head_size == 80) { \ + LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 80); \ + } else if (head_size == 128) { \ + LAUNCH_KV_COPY_FOR_Q_RATIO_HEAD_SIZE(Q_RATIO, 128); \ + } else { \ + assert(false); \ + } template void launch_kv_copy_kernel(T* kv_cache, @@ -283,23 +347,19 @@ void launch_kv_copy_kernel(T* kv_cache, cudaStream_t stream) { constexpr int vector_T = kv_rot::granularity / sizeof(T); - const int threads_per_head = head_size / vector_T; + const int padded_head_size = next_pow2(head_size); + const int threads_per_head = padded_head_size / vector_T; const int tokens_per_block = kv_rot::threads / threads_per_head; const dim3 block(kv_rot::threads); const int token_blocks = (n_tokens + tokens_per_block - 1) / tokens_per_block; const dim3 grid(n_q_heads, token_blocks); - DISPATCH_KV_COPY_IMPL(1, 64) - DISPATCH_KV_COPY_IMPL(1, 128) - DISPATCH_KV_COPY_IMPL(2, 64) - DISPATCH_KV_COPY_IMPL(2, 128) - DISPATCH_KV_COPY_IMPL(4, 64) - DISPATCH_KV_COPY_IMPL(4, 128) - DISPATCH_KV_COPY_IMPL(5, 64) - DISPATCH_KV_COPY_IMPL(5, 128) - DISPATCH_KV_COPY_IMPL(8, 64) - DISPATCH_KV_COPY_IMPL(8, 128) + LAUNCH_KV_COPY_FOR_Q_RATIO(1) + LAUNCH_KV_COPY_FOR_Q_RATIO(2) + LAUNCH_KV_COPY_FOR_Q_RATIO(4) + LAUNCH_KV_COPY_FOR_Q_RATIO(5) + LAUNCH_KV_COPY_FOR_Q_RATIO(8) } #define INSTANTIATE_KV_COPY_KERNEL(TYPE) \ diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh index be38ff30c46c..ff24b3f5bd80 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh @@ -18,6 +18,8 @@ void launch_kv_rotary_kernel(T* kv_cache, T* k, T* v, T* inv_freq, + const int32_t rotary_dim, + const float theta_base, const BatchWrapperCPP batch_desc, const int qkv_stride, const int kv_cache_stride, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h index 0615825c0a21..c0700eda7147 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h @@ -45,6 +45,8 @@ void kv_rotary_embeddings(torch::Tensor& kv_cache, torch::Tensor& q, torch::Tensor& k, torch::Tensor& v, + const int32_t rotary_dim, + const float theta_base, torch::Tensor& batch_metadata, torch::Tensor& seq_metadata, torch::Tensor& tokens_to_seq, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py index 50d9aca061f3..7fe38d258e6c 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py @@ -18,10 +18,11 @@ class BlockedRotaryEmbeddings(DSKernelBase): """ supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] - supported_head_sizes = [64, 128] + supported_head_sizes = [64, 80, 128] supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71] - def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None: + def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype, rotary_dim: int, + theta_base: float) -> None: """ Args: head_size: The size of the attention head. @@ -51,6 +52,8 @@ def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch self.head_size = head_size self.n_q_heads = n_q_heads self.n_kv_heads = n_kv_heads + self.rotary_dim = rotary_dim + self.theta_base = theta_base def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper) -> None: """ @@ -66,5 +69,5 @@ def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: Ragg k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)] v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):] - self.kernel(kv_cache, q, k, v, ragged_batch.batch_metadata_buffer(), ragged_batch.inflight_seq_descriptors(), - ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs()) + self.kernel(kv_cache, q, k, v, self.rotary_dim, self.theta_base, ragged_batch.batch_metadata_buffer(), + ragged_batch.inflight_seq_descriptors(), ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs()) diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py index 59da1db0f5d6..f8c5b2b13804 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py @@ -23,7 +23,7 @@ class BlockedTrainedRotaryEmbeddings(DSKernelBase): """ supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] - supported_head_sizes = [64, 128] + supported_head_sizes = [64, 80, 128] supported_q_ratios = [1, 2, 4, 5, 8] def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None: @@ -65,7 +65,7 @@ def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: Ragg kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size] qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)] ragged_batch: Wrapper for the ragged batch. - inverse_freqs: Inverse frequencies for the rotary embeddings. Shape [max_seq_len, head_size // 2] + inverse_freqs: Inverse frequencies for the rotary embeddings. Shape [max_seq_len, rotary_dim // 2] """ q = qkv[:, :self.head_size * self.n_q_heads] diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py index c9f6ffd37b3e..a885eadd78a1 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py @@ -23,7 +23,7 @@ class LinearBlockedKVCopy(DSKernelBase): """ supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] - supported_head_sizes = [64, 128] + supported_head_sizes = [64, 80, 128] supported_q_ratios = [1, 2, 4, 5, 8] def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None: diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp index e55e1f48c125..506629406f0d 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp @@ -16,6 +16,8 @@ n_channels, \ n_experts, \ n_tokens, \ + n_top_k, \ + normalize_scales, \ at::cuda::getCurrentCUDAStream()); \ return; \ } @@ -27,17 +29,21 @@ void moe_gather(torch::Tensor& layer_output, const torch::Tensor& moe_output, const torch::Tensor& scores, const torch::Tensor& mapped_slots, - const torch::Tensor& expert_count) + const torch::Tensor& expert_count, + const bool normalize_scales) { const int32_t n_channels = layer_output.size(1); const int32_t n_experts = expert_count.size(0); const int32_t n_tokens = layer_output.size(0); + const int32_t n_top_k = mapped_slots.size(1); - TORCH_CHECK(moe_output.size(0) == n_tokens); + TORCH_CHECK(moe_output.size(0) == n_tokens * n_top_k); TORCH_CHECK(moe_output.size(1) == n_channels); TORCH_CHECK(scores.size(0) == n_tokens); TORCH_CHECK(mapped_slots.size(0) == n_tokens); + TORCH_CHECK(scores.size(1) == n_top_k); + TORCH_CHECK(layer_output.scalar_type() == moe_output.scalar_type()); TORCH_CHECK(scores.scalar_type() == torch::kFloat32); TORCH_CHECK(mapped_slots.scalar_type() == torch::kInt32); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu index c2fae24f5080..4153a2a3636f 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu @@ -7,7 +7,8 @@ #include "ds_kernel_utils.h" #include "moe_gather.cuh" #include "reduction_utils.h" -#include "top_1_gating.cuh" +#include "top_k_gating.cuh" +#include "top_k_utils.h" namespace gather { @@ -16,65 +17,105 @@ constexpr int threads = 256; } // namespace gather -template +template __global__ void moe_gather_kernel(T* layer_output, const T* moe_output, const float* scores, const int32_t* mapped_slots, int32_t* expert_counts, const int32_t n_channels, - const int32_t n_experts) + const int32_t n_experts, + const bool normalize_scales) { constexpr int32_t vector_size = gather::access_granularity / sizeof(T); constexpr int32_t stride = vector_size * gather::threads; const int32_t token_idx = blockIdx.x; - const int32_t mapped_slot = mapped_slots[token_idx]; + int32_t token_mapped_slots[N_TOP_K]; + + bool all_slots_invalid = true; + for (int i = 0; i < N_TOP_K; i++) { + token_mapped_slots[i] = mapped_slots[token_idx * N_TOP_K + i]; + all_slots_invalid &= (token_mapped_slots[i] == gating::unassigned); + } if (token_idx == 0) { // Reset expert counts for its next use. if (threadIdx.x < n_experts) { expert_counts[threadIdx.x] = 0; } } - if (mapped_slot == gating::unassigned) { - // This token was not assigned. + if (all_slots_invalid) { + // This token was not assigned to anything. // TODO(cmikeh2): It's possible we want different behavior here moving forward. return; } - const float score = scores[token_idx]; + float token_scores[N_TOP_K]; + for (int i = 0; i < N_TOP_K; i++) { token_scores[i] = scores[token_idx * N_TOP_K + i]; } + + if (normalize_scales) { + // Normalize the scores so that they sum to 1. + float sum = 0.0f; + for (int i = 0; i < N_TOP_K; i++) { sum += token_scores[i]; } + + if (sum > 0.0f) { + for (int i = 0; i < N_TOP_K; i++) { token_scores[i] /= sum; } + } + } + const int32_t channel_offset = threadIdx.x * vector_size; - const T* moe_output_base = moe_output + mapped_slot * n_channels + channel_offset; + const T* moe_output_bases[N_TOP_K]; +#pragma unroll + for (int i = 0; i < N_TOP_K; i++) { + moe_output_bases[i] = moe_output + token_mapped_slots[i] * n_channels + channel_offset; + } + T* layer_output_base = layer_output + token_idx * n_channels + channel_offset; #pragma unroll for (int i = 0; i < copyUnroll; i++) { - T reg_buffer[vector_size]; - if (i * stride + channel_offset < n_channels) { - mem_access::load_global(reg_buffer, - moe_output_base + i * stride); + float accum_buffer[vector_size]; + for (int j = 0; j < vector_size; j++) { + accum_buffer[j] = reduce::init(); + } + +#pragma unroll + for (int j = 0; j < N_TOP_K; j++) { + T reg_buffer[vector_size]; + mem_access::load_global( + reg_buffer, moe_output_bases[j] + i * stride); +#pragma unroll + for (int k = 0; k < vector_size; k++) { + float up_cast = conversion::to(reg_buffer[k]); + accum_buffer[k] += up_cast * token_scores[j]; + } + } + + T store_buffer[vector_size]; #pragma unroll for (int j = 0; j < vector_size; j++) { - // There are accuracy implications of downcasting the score to a 16-bit - // data type, so we up-convert the input to 32-bit, multiply, and then - // down-convert back to 16-bit. - float up_cast = conversion::to(reg_buffer[j]); - reg_buffer[j] = conversion::to(up_cast * score); + store_buffer[j] = conversion::to(accum_buffer[j]); } mem_access::store_global(layer_output_base + i * stride, - reg_buffer); + store_buffer); } } } -#define LAUNCH_FOR_UNROLL(COUNT) \ - case COUNT: \ - moe_gather_kernel<<>>( \ - layer_output, moe_output, scores, mapped_slots, expert_counts, n_channels, n_experts); \ +#define LAUNCH_FOR_UNROLL(COUNT) \ + case COUNT: \ + moe_gather_kernel<<>>(layer_output, \ + moe_output, \ + scores, \ + mapped_slots, \ + expert_counts, \ + n_channels, \ + n_experts, \ + normalize_scales); \ break; template @@ -86,6 +127,8 @@ void launch_moe_gather(T* layer_output, const int32_t n_channels, const int32_t n_experts, const int32_t n_tokens, + const int32_t n_top_k, + const bool normalize_scales, cudaStream_t stream) { constexpr int vals_per_unroll = gather::threads * gather::access_granularity / sizeof(T); @@ -94,14 +137,16 @@ void launch_moe_gather(T* layer_output, const dim3 block(gather::threads); const dim3 grid(n_tokens); - switch (copy_unroll) { - LAUNCH_FOR_UNROLL(1) - LAUNCH_FOR_UNROLL(2) - LAUNCH_FOR_UNROLL(3) - LAUNCH_FOR_UNROLL(4) - LAUNCH_FOR_UNROLL(5) - LAUNCH_FOR_UNROLL(6) - } + TOP_K_SWITCH(n_top_k, [&] { + switch (copy_unroll) { + LAUNCH_FOR_UNROLL(1) + LAUNCH_FOR_UNROLL(2) + LAUNCH_FOR_UNROLL(3) + LAUNCH_FOR_UNROLL(4) + LAUNCH_FOR_UNROLL(5) + LAUNCH_FOR_UNROLL(6) + } + }); } #define INSTANTIATE_GATHER_FOR_TYPE(TYPE) \ @@ -113,6 +158,8 @@ void launch_moe_gather(T* layer_output, const int32_t n_channels, \ const int32_t n_experts, \ const int32_t n_tokens, \ + const int32_t n_top_k, \ + const bool normalize_scales, \ cudaStream_t stream); INSTANTIATE_GATHER_FOR_TYPE(__half) diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh index f98a727ead58..b348d0cfb330 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh @@ -17,4 +17,6 @@ void launch_moe_gather(T* layer_output, const int32_t n_channels, const int32_t n_experts, const int32_t n_tokens, + const int32_t n_top_k, + const bool normalize_scales, cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h index 7ffe9f8b4dc6..ec9e03057eb8 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h @@ -16,4 +16,5 @@ void moe_gather(torch::Tensor& layer_output, const torch::Tensor& moe_output, const torch::Tensor& scores, const torch::Tensor& mapped_slots, - const torch::Tensor& expert_counts); + const torch::Tensor& expert_counts, + const bool normalize_scales); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py index c37683d03fbe..f03938171ba4 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py @@ -18,7 +18,7 @@ class MoEGather(DSKernelBase): supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] - def __init__(self, dtype: DtypeEnum, channels: int) -> None: + def __init__(self, dtype: DtypeEnum, channels: int, normalize_scores: bool = False) -> None: if not isinstance(dtype, DtypeEnum): dtype = DtypeEnum(dtype) @@ -31,6 +31,7 @@ def __init__(self, dtype: DtypeEnum, channels: int) -> None: inf_module = RaggedOpsBuilder().load() self.kernel = inf_module.moe_gather + self.normalize_scores = normalize_scores def __call__(self, layer_output: torch.Tensor, moe_output: torch.Tensor, scores: torch.Tensor, mapped_slots: torch.Tensor, expert_counts: torch.Tensor) -> torch.Tensor: @@ -40,13 +41,13 @@ def __call__(self, layer_output: torch.Tensor, moe_output: torch.Tensor, scores: Arguments: layer_output (torch.Tensor): The output of the layer of shape [n_tokens, hidden_size]. This has been scaled appropriately. - moe_output (torch.Tensor): The output of the MoE of shape [n_tokens, hidden_size]. + moe_output (torch.Tensor): The output of the MoE of shape [n_tokens * n_top_k, hidden_size]. scores (torch.Tensor): The gating scores of shape [n_tokens]. - mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens]. The index of token ``i`` in layer_output is ``mapped_slots[i]``. + mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens, n_top_k]. The indices of token ``i`` in layer_output is ``mapped_slots[i]``. expert_counts (torch.Tensor): The number of tokens assigned to each expert of shape [n_experts]. This is passed to fuse the clearing of this data structure into the gather. Returns: layer_output """ - self.kernel(layer_output, moe_output, scores, mapped_slots, expert_counts) + self.kernel(layer_output, moe_output, scores, mapped_slots, expert_counts, self.normalize_scores) return layer_output diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp index 902f1cc0ea15..8f7ecbd1a287 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp @@ -18,6 +18,7 @@ n_channels, \ n_tokens, \ n_experts, \ + n_top_k, \ at::cuda::getCurrentCUDAStream()); \ return; \ } @@ -36,13 +37,17 @@ void moe_scatter(torch::Tensor& moe_input, { const int32_t n_tokens = activations.size(0); const int32_t n_channels = activations.size(1); + const int32_t n_top_k = assignments.size(1); // Should have a lot of matching buffer sizes here. - TORCH_CHECK(n_tokens == moe_input.size(0)); TORCH_CHECK(n_tokens == assignments.size(0)); TORCH_CHECK(n_tokens == offsets.size(0)); TORCH_CHECK(n_channels == moe_input.size(1)); + TORCH_CHECK(n_top_k == offsets.size(1)); + TORCH_CHECK(n_top_k * n_tokens == moe_input.size(0)); + TORCH_CHECK(n_top_k == mapped_slots.size(1)); + const int32_t n_experts = expert_count_cumsums.size(0); TORCH_CHECK(moe_input.scalar_type() == activations.scalar_type()); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu index 0746cd7be645..d3eb4f649e79 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu @@ -4,9 +4,9 @@ // DeepSpeed Team #include "ds_kernel_utils.h" -#include "moe_scatter.cuh" #include "reduction_utils.h" -#include "top_1_gating.cuh" +#include "top_k_gating.cuh" +#include "top_k_utils.h" using ROp = reduce::ROpType; @@ -15,10 +15,11 @@ namespace scatter { constexpr int access_granularity = 16; constexpr int threads = 256; constexpr int warps = threads / hw_warp_size; +constexpr int max_experts = 1024; } // namespace scatter -template +template __global__ void moe_scatter_kernel(T* moe_input, int64_t* expert_count_cumsums, int32_t* mapped_slots, @@ -38,88 +39,78 @@ __global__ void moe_scatter_kernel(T* moe_input, // Bank aligned and sufficient __shared__ int32_t red_buffer[32]; - __shared__ int32_t token_0_row; + __shared__ int32_t expert_offsets[scatter::max_experts]; // CG helpers cg::thread_block tb = cg::this_thread_block(); cg::thread_block_tile warp = cg::tiled_partition(tb); - const int assigned_expert = assignments[token_idx]; - - // For the different codepaths, we'll converge on this variable for doing - // the token copy. - int32_t token_base_row; + // Fetch the assigned experts for this token. + int assigned_experts[N_TOP_K]; + for (int i = 0; i < N_TOP_K; i++) { + assigned_experts[i] = assignments[token_idx * N_TOP_K + i]; + } - if (token_idx == 0) { - // Token 0 will perform a cumsum on the data - int32_t expert_vals; - if (tidx < n_experts) { - expert_vals = expert_counts[tidx]; + bool all_unassigned = true; + for (int i = 0; i < N_TOP_K; i++) { + if (assigned_experts[i] != gating::unassigned) { + all_unassigned = false; } else { - expert_vals = 0; + mapped_slots[token_idx * N_TOP_K + i] = gating::unassigned; } + } + if (all_unassigned && token_idx != 0) return; + + // Do a prefix scan on the expert counts to get the base offsets. Here we use the + // single up-sweep variant. + int32_t expert_vals; + if (tidx < n_experts) { + expert_vals = expert_counts[tidx]; + } else { + expert_vals = 0; + } #pragma unroll - for (int i = 1; i < hw_warp_size; i *= 2) { - int32_t maybe_add = warp.shfl_up(expert_vals, i); - expert_vals = (warp.thread_rank() < i) ? expert_vals : expert_vals + maybe_add; - } + for (int i = 1; i < hw_warp_size; i *= 2) { + int32_t maybe_add = warp.shfl_up(expert_vals, i); + expert_vals = (warp.thread_rank() < i) ? expert_vals : expert_vals + maybe_add; + } - if (warp.thread_rank() == hw_warp_size - 1) { - mem_access::store_shared<4>(red_buffer + warp_rank, &expert_vals); - } + if (warp.thread_rank() == hw_warp_size - 1) { + mem_access::store_shared<4>(red_buffer + warp_rank, &expert_vals); + } - tb.sync(); + tb.sync(); - int32_t phase_2_val = 0; - if (warp.thread_rank() < scatter::warps) { - mem_access::load_shared<4>(&phase_2_val, red_buffer + warp.thread_rank()); - } + int32_t phase_2_val = 0; + if (warp.thread_rank() < scatter::warps) { + mem_access::load_shared<4>(&phase_2_val, red_buffer + warp.thread_rank()); + } #pragma unroll - for (int i = 1; i < hw_warp_size; i *= 2) { - int32_t maybe_add = warp.shfl_up(phase_2_val, i); - phase_2_val = (warp.thread_rank() < i) ? phase_2_val : phase_2_val + maybe_add; - } - - int warp_offset = 0; - if (warp_rank > 0) { warp_offset = warp.shfl(phase_2_val, warp_rank - 1); } - const int32_t expert_cumsum = warp_offset + expert_vals; - - if (tidx < n_experts) { - int64_t expert_cumsum_64 = (int64_t)expert_cumsum; - expert_count_cumsums[tidx] = expert_cumsum_64; - } - - if (assigned_expert == gating::unassigned) return; - if (assigned_expert - 1 == tidx) token_0_row = expert_cumsum; + for (int i = 1; i < hw_warp_size; i *= 2) { + int32_t maybe_add = warp.shfl_up(phase_2_val, i); + phase_2_val = (warp.thread_rank() < i) ? phase_2_val : phase_2_val + maybe_add; + } - tb.sync(); + int warp_offset = 0; + if (warp_rank > 0) { warp_offset = warp.shfl(phase_2_val, warp_rank - 1); } + const int32_t expert_cumsum = warp_offset + expert_vals; - if (assigned_expert != 0) { - token_base_row = token_0_row; - } else { - token_base_row = 0; - } + // Token 0 will write the + if (token_idx == 0 && tidx < n_experts) { + int64_t expert_cumsum_64 = (int64_t)expert_cumsum; + expert_count_cumsums[tidx] = expert_cumsum_64; + } - } else if (assigned_expert == gating::unassigned) { - // For whatever reason, don't need to perform the copy, so we'll early return - // and signal this wasn't mapped with a negative 1. - if (tidx == 0) mapped_slots[token_idx] = gating::unassigned; - return; - } else { - // For all other valid tokens, we can just do a block-scoped sum. - if (tidx < assigned_expert) { - token_base_row = expert_counts[tidx]; - } else { - token_base_row = 0; - } + // Since token 0 has now written the expert cumsum to global memory, + // if it has no valid experts, we can early return. + if (token_idx == 0 && all_unassigned) return; - warp.sync(); + if (tidx < n_experts) { expert_offsets[tidx] = expert_cumsum; } - // TODO(cmikeh2): Shouldn't use the internal api. - reduce::_block(tb, warp, &token_base_row); - } + // Ensure all the expert offsets are written in shared memory. + tb.sync(); // Data copy to appropriate location const int32_t thread_offset = tidx * vector_size; @@ -127,9 +118,16 @@ __global__ void moe_scatter_kernel(T* moe_input, const int32_t base_load_offset = token_idx * n_channels + thread_offset; const T* load_base_ptr = activations + base_load_offset; - const int32_t store_row = token_base_row + offsets[token_idx]; - const int32_t base_store_offset = store_row * n_channels + thread_offset; - T* store_base_ptr = moe_input + base_store_offset; + int32_t store_rows[N_TOP_K]; + T* store_base_ptrs[N_TOP_K]; +#pragma unroll + for (int i = 0; i < N_TOP_K; i++) { + const int32_t cur_expert_offset = + (assigned_experts[i] > 0) ? expert_offsets[assigned_experts[i] - 1] : 0; + store_rows[i] = cur_expert_offset + offsets[token_idx * N_TOP_K + i]; + const int32_t base_store_offset = store_rows[i] * n_channels + thread_offset; + store_base_ptrs[i] = moe_input + base_store_offset; + } #pragma unroll for (int i = 0; i < copyUnroll; i++) { @@ -138,25 +136,31 @@ __global__ void moe_scatter_kernel(T* moe_input, if (i * load_stride + thread_offset < n_channels) { mem_access::load_global(tmp_buf, load_base_ptr + i * load_stride); - mem_access::store_global(store_base_ptr + i * load_stride, - tmp_buf); +#pragma unroll + for (int j = 0; j < N_TOP_K; j++) { + mem_access::store_global( + store_base_ptrs[j] + i * load_stride, tmp_buf); + } } } - if (threadIdx.x == 0) { mapped_slots[token_idx] = store_row; } + if (threadIdx.x == 0) { + for (int i = 0; i < N_TOP_K; i++) { mapped_slots[token_idx * N_TOP_K + i] = store_rows[i]; } + } } -#define LAUNCH_FOR_UNROLL(COUNT) \ - case COUNT: \ - moe_scatter_kernel<<>>(moe_input, \ - expert_count_cumsums, \ - mapped_slots, \ - activations, \ - assignments, \ - expert_counts, \ - offsets, \ - n_channels, \ - n_experts); \ +#define LAUNCH_FOR_UNROLL(COUNT) \ + case COUNT: \ + moe_scatter_kernel \ + <<>>(moe_input, \ + expert_count_cumsums, \ + mapped_slots, \ + activations, \ + assignments, \ + expert_counts, \ + offsets, \ + n_channels, \ + n_experts); \ break; template @@ -170,6 +174,7 @@ void launch_moe_scatter(T* moe_input, const int32_t n_channels, const int32_t n_tokens, const int32_t n_experts, + const int32_t n_top_k, cudaStream_t stream) { constexpr int vals_per_unroll = scatter::threads * scatter::access_granularity / sizeof(T); @@ -178,14 +183,16 @@ void launch_moe_scatter(T* moe_input, const dim3 block(scatter::threads); const dim3 grid(n_tokens); - switch (copy_unroll) { - LAUNCH_FOR_UNROLL(1); - LAUNCH_FOR_UNROLL(2); - LAUNCH_FOR_UNROLL(3); - LAUNCH_FOR_UNROLL(4); - LAUNCH_FOR_UNROLL(5); - LAUNCH_FOR_UNROLL(6); - } + TOP_K_SWITCH(n_top_k, [&] { + switch (copy_unroll) { + LAUNCH_FOR_UNROLL(1); + LAUNCH_FOR_UNROLL(2); + LAUNCH_FOR_UNROLL(3); + LAUNCH_FOR_UNROLL(4); + LAUNCH_FOR_UNROLL(5); + LAUNCH_FOR_UNROLL(6); + } + }); } #define INSTANTIATE_SCATTER_FOR_TYPE(TYPE) \ @@ -199,6 +206,7 @@ void launch_moe_scatter(T* moe_input, const int32_t, \ const int32_t, \ const int32_t, \ + const int32_t, \ cudaStream_t); INSTANTIATE_SCATTER_FOR_TYPE(__half); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh index 5c94cb0ef734..d9756c80f05a 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh @@ -19,4 +19,5 @@ void launch_moe_scatter(T* moe_input, const int32_t n_channels, const int32_t n_tokens, const int32_t n_experts, + const int32_t n_top_k, cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py index 5cd6ae5f0fe2..7efcedb4e880 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py @@ -40,13 +40,13 @@ def __call__(self, moe_input: torch.Tensor, expert_cumsum: torch.Tensor, mapped_ Scatters the hidden states such that the token stride for each expert's input is contiguous. Arguments: - moe_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, hidden_size]. + moe_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens * n_top_k, hidden_size]. expert_cumsum (torch.Tensor): The cumulative sum of the expert counts of shape [n_experts]. - mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens]. + mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens, n_top_k]. hidden_states (torch.Tensor): The hidden states of shape [n_tokens, hidden_size]. expert_counts (torch.Tensor): The number of tokens assigned to each expert of shape [n_experts]. - assignments (torch.Tensor): The expert assignments of shape [n_tokens]. - offsets (torch.Tensor): The offsets into the expert for a given token of shape [n_tokens]. + assignments (torch.Tensor): The expert assignments of shape [n_tokens, n_top_k]. + offsets (torch.Tensor): The offsets into the expert for a given token of shape [n_tokens, n_top_K]. Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The MoE input (with scattered values), the cumsum of the offsets (for the MoE kernels themselves), and the assignments Tensor modified in place to show which row that token was mapped to in the input. diff --git a/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp b/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp index 1c09fc52bbb1..f320f46e2620 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp @@ -12,7 +12,7 @@ #include "logits_gather.h" #include "moe_gather.h" #include "moe_scatter.h" -#include "top_1_gating.h" +#include "top_k_gating.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -43,6 +43,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) // moe_scatter.h m.def("moe_scatter", &moe_scatter, "MoE scatter for top-1-gating."); - // top_1_gating.h - m.def("top_1_gating", &top_1_gating, "Top-1 gating for MoE with ragged batch awareness."); + // top_k_gating.h + m.def("top_k_gating", &top_k_gating, "Top-1 gating for MoE with ragged batch awareness."); } diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/__init__.py similarity index 69% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/__init__.py rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/__init__.py index b50a0838d9f8..487735b015b0 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/__init__.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/__init__.py @@ -3,4 +3,4 @@ # DeepSpeed Team -from .top_1_gating import RaggedTop1Gating +from .top_k_gating import RaggedTopKGating diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp similarity index 67% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp index 55c68454b228..5eec7e2b955f 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp @@ -3,12 +3,12 @@ // DeepSpeed Team -#include "top_1_gating.h" +#include "top_k_gating.h" #include -#define DISPATCH_TOP_1_GATING(T_TYPE, C_TYPE) \ +#define DISPATCH_TOP_K_GATING(T_TYPE, C_TYPE) \ if (logits.options().dtype() == torch::T_TYPE) { \ - launch_top_1_gating((int32_t*)expert_counts.data_ptr(), \ + launch_top_k_gating((int32_t*)expert_counts.data_ptr(), \ (float*)scores.data_ptr(), \ (int32_t*)assignments.data_ptr(), \ (int32_t*)offsets.data_ptr(), \ @@ -16,14 +16,15 @@ batch_metadata_ptr, \ n_tokens, \ n_experts, \ + n_top_k, \ at::cuda::getCurrentCUDAStream()); \ return; \ } /* -Perform softmax plus atomics in order to do first pass of top_1_gating. +Perform softmax plus atomics in order to do first pass of top_k_gating. */ -void top_1_gating(torch::Tensor& expert_counts, +void top_k_gating(torch::Tensor& expert_counts, torch::Tensor& scores, torch::Tensor& assignments, torch::Tensor& offsets, @@ -31,10 +32,15 @@ void top_1_gating(torch::Tensor& expert_counts, torch::Tensor& batch_metadata) { const int32_t n_tokens = scores.size(0); + const int32_t n_top_k = scores.size(1); - // Should have the same buffer size for scores and offsets + // Should have the same buffer size for scores, offsets, and assignments TORCH_CHECK(n_tokens == offsets.size(0)); TORCH_CHECK(n_tokens == logits.size(0)); + TORCH_CHECK(n_tokens == assignments.size(0)); + + TORCH_CHECK(n_top_k == offsets.size(1)); + TORCH_CHECK(n_top_k == assignments.size(1)); TORCH_CHECK(expert_counts.scalar_type() == torch::kInt32); TORCH_CHECK(scores.scalar_type() == torch::kFloat); @@ -45,11 +51,11 @@ void top_1_gating(torch::Tensor& expert_counts, const RaggedBatchDescriptor* batch_metadata_ptr = reinterpret_cast(batch_metadata.data_ptr()); - DISPATCH_TOP_1_GATING(kFloat, float) - DISPATCH_TOP_1_GATING(kHalf, __half) + DISPATCH_TOP_K_GATING(kFloat, float) + DISPATCH_TOP_K_GATING(kHalf, __half) #ifdef BF16_AVAILABLE - DISPATCH_TOP_1_GATING(kBFloat16, __nv_bfloat16) + DISPATCH_TOP_K_GATING(kBFloat16, __nv_bfloat16) #endif - TORCH_CHECK(false, "Unsupported dtype for logits in top_1_gating"); + TORCH_CHECK(false, "Unsupported dtype for logits in top_k_gating"); } diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu similarity index 59% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu index 02daee9f692e..58f95c045593 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu @@ -6,12 +6,13 @@ #include "conversion_utils.h" #include "memory_access_utils.h" #include "reduction_utils.h" -#include "top_1_gating.cuh" +#include "top_k_gating.cuh" +#include "top_k_utils.h" using ROp = reduce::ROpType; -template -__global__ void top_1_gating_kernel(int32_t* expert_counts, +template +__global__ void top_k_gating_kernel(int32_t* expert_counts, float* scores, int32_t* assignments, int32_t* offsets, @@ -30,8 +31,11 @@ __global__ void top_1_gating_kernel(int32_t* expert_counts, // Padding tokens do not require if (token_idx >= batch_metadata->n_tokens) { if (threadIdx.x == 0) { - offsets[token_idx] = gating::unassigned; - assignments[token_idx] = gating::unassigned; +#pragma unroll + for (int i = 0; i < TOP_K; i++) { + assignments[token_idx * TOP_K + i] = gating::unassigned; + offsets[token_idx * TOP_K + i] = gating::unassigned; + } } return; } @@ -44,34 +48,46 @@ __global__ void top_1_gating_kernel(int32_t* expert_counts, } else { reduce::init(&logit_val); } + float reduce_val = logit_val; + + int32_t local_assigned_experts[TOP_K]; + float local_assigned_logits[TOP_K]; // Training code tends to use ``torch.argmax`` to select the expert, which // which has ties broken by the lower index. Since our fused comparison algorithm // breaks ties by the higher index (since it's the lower 32-bits of the 64-bit // comparison), we invert the expert index to break ties by the lower index. int32_t inverted_expert = n_experts - expert_idx - 1; - // Perform softmax - const reduce::IdxReduceResult res = - reduce::idx_reduce(tb, warp, logit_val, inverted_expert); - // Recover the original expert index - const int32_t assigned_expert = n_experts - res.idx - 1; - const float max_logit = res.val; + // Find the top k logits + for (int i = 0; i < TOP_K; ++i) { + const reduce::IdxReduceResult res = + reduce::idx_reduce(tb, warp, reduce_val, inverted_expert); + local_assigned_experts[i] = n_experts - res.idx - 1; + local_assigned_logits[i] = res.val; + + // Set the max logit to -inf so that it is not selected again + if (threadIdx.x == n_experts - res.idx - 1) { reduce::init(&reduce_val); } + } + + const float max_logit = local_assigned_logits[0]; float softmax_sum = __expf(logit_val - max_logit); reduce::block(tb, warp, softmax_sum); - // Compute the score - const float score = __expf(max_logit - max_logit) / softmax_sum; + for (int i = 0; i < TOP_K; ++i) { + const float softmax = __expf(local_assigned_logits[i] - max_logit) / softmax_sum; - if (threadIdx.x == 0) { - scores[token_idx] = score; - assignments[token_idx] = assigned_expert; - offsets[token_idx] = atomicAdd(expert_counts + assigned_expert, 1); + if (threadIdx.x == 0) { + scores[token_idx * TOP_K + i] = softmax; + assignments[token_idx * TOP_K + i] = local_assigned_experts[i]; + offsets[token_idx * TOP_K + i] = + atomicAdd(expert_counts + local_assigned_experts[i], 1); + } } } template -void launch_top_1_gating(int32_t* expert_counts, +void launch_top_k_gating(int32_t* expert_counts, float* scores, int32_t* assignments, int32_t* offsets, @@ -79,17 +95,20 @@ void launch_top_1_gating(int32_t* expert_counts, const RaggedBatchDescriptor* batch_metadata, const int32_t n_tokens, const int32_t n_experts, + const int32_t n_top_k, cudaStream_t stream) { const dim3 grid(n_tokens); const dim3 block(((n_experts + hw_warp_size - 1) / hw_warp_size) * hw_warp_size); - top_1_gating_kernel<<>>( - expert_counts, scores, assignments, offsets, logits, batch_metadata, n_experts); + TOP_K_SWITCH(n_top_k, [&] { + top_k_gating_kernel<<>>( + expert_counts, scores, assignments, offsets, logits, batch_metadata, n_experts); + }); } -#define INSTANTIATE_TOP_1_KERNEL(T) \ - template void launch_top_1_gating(int32_t * expert_counts, \ +#define INSTANTIATE_top_k_KERNEL(T) \ + template void launch_top_k_gating(int32_t * expert_counts, \ float* scores, \ int32_t* assignments, \ int32_t* offsets, \ @@ -97,10 +116,10 @@ void launch_top_1_gating(int32_t* expert_counts, const RaggedBatchDescriptor* batch_metadata, \ const int32_t n_tokens, \ const int32_t n_experts, \ + const int32_t n_top_k, \ cudaStream_t stream); -INSTANTIATE_TOP_1_KERNEL(float) -INSTANTIATE_TOP_1_KERNEL(__half) +INSTANTIATE_top_k_KERNEL(float) INSTANTIATE_top_k_KERNEL(__half) #ifdef BF16_AVAILABLE -INSTANTIATE_TOP_1_KERNEL(__nv_bfloat16) + INSTANTIATE_top_k_KERNEL(__nv_bfloat16) #endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cuh b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cuh similarity index 87% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cuh rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cuh index c83ad56ff2f1..c525cc5f524e 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cuh +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cuh @@ -13,7 +13,7 @@ constexpr int unassigned = -1; } // namespace gating template -void launch_top_1_gating(int32_t* expert_counts, +void launch_top_k_gating(int32_t* expert_counts, float* scores, int32_t* assignments, int32_t* offsets, @@ -21,4 +21,5 @@ void launch_top_1_gating(int32_t* expert_counts, const RaggedBatchDescriptor* batch_metadata, const int32_t n_tokens, const int32_t n_experts, + const int32_t n_top_k, cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.h b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.h similarity index 86% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.h rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.h index b431f4cad30c..00840c3c93b5 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.h +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.h @@ -8,12 +8,12 @@ #include #include #include "ragged_dtypes.h" -#include "top_1_gating.cuh" +#include "top_k_gating.cuh" /* Perform softmax plus atomics to get token mapping. */ -void top_1_gating(torch::Tensor& expert_counts, +void top_k_gating(torch::Tensor& expert_counts, torch::Tensor& scores, torch::Tensor& assignments, torch::Tensor& offsets, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.py b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.py similarity index 87% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.py rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.py index 1df97c2e9f8d..72ba2b6019bb 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.py @@ -13,7 +13,7 @@ from deepspeed.ops.op_builder import RaggedOpsBuilder -class RaggedTop1Gating(DSKernelBase): +class RaggedTopKGating(DSKernelBase): """ CUDA implementation of top-1 gating. This will perform a softmax on the logits, and return the scale as well as its idx within that expert's allocation. @@ -26,28 +26,28 @@ def __init__(self, logit_dtype: DtypeEnum) -> None: if not isinstance(logit_dtype, DtypeEnum): logit_dtype = DtypeEnum(logit_dtype) - if logit_dtype not in RaggedTop1Gating.supported_logit_dtypes: + if logit_dtype not in RaggedTopKGating.supported_logit_dtypes: raise RuntimeError(f"Unsupported logit dtype {logit_dtype}") inf_module = RaggedOpsBuilder().load() - self.kernel = inf_module.top_1_gating + self.kernel = inf_module.top_k_gating def __call__(self, expert_counts: torch.Tensor, scores: torch.Tensor, assignments: torch.Tensor, offsets: torch.Tensor, logits: torch.Tensor, batch: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ - Perform the ragged top_1_gating. + Perform the ragged top_k_gating. Arguments: expert_counts (torch.Tensor): Tensor of 0s of shape [n_experts] to be filled with number of tokens assigned to each expert. This must be filled with 0s else the copy kernel will buffer overflow. In order to minimize the zero-fill cost, it is recommended to write to 0 during the MoE output remapping. - scores (torch.Tensor): Preallocated output of shape [n_tokens] to place expert scaling + scores (torch.Tensor): Preallocated output of shape [n_tokens, n_top_k] to place expert scaling value. - expert_assignment (torch.Tensor): Preallocated output of shape [n_tokens] to place + expert_assignment (torch.Tensor): Preallocated output of shape [n_tokens, n_top_k] to place which expert a token has been assigned to. - expert_offset (torch.Tensor): Preallocated output of shape [n_tokens] to place which + expert_offset (torch.Tensor): Preallocated output of shape [n_tokens, n_top_k] to place which offset within an experts group a token is. logits (torch.Tensor): Raw logits of gating function. batch (RaggedBatchWrapper): Batch information for ragged tensor. diff --git a/deepspeed/inference/v2/model_implementations/__init__.py b/deepspeed/inference/v2/model_implementations/__init__.py index 481be2e5940e..a08d966ac6d0 100644 --- a/deepspeed/inference/v2/model_implementations/__init__.py +++ b/deepspeed/inference/v2/model_implementations/__init__.py @@ -12,4 +12,6 @@ from .llama_v2 import * from .opt import * from .mistral import * +from .mixtral import * from .falcon import * +from .phi import * diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py index df5f1427a5cf..8ababf567ba9 100644 --- a/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py +++ b/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py @@ -33,7 +33,7 @@ class UnfusedMoEMLP1Parameter(ParameterBase): and need to be joined into a single group. """ - experts: ParamList("num_experts") # noqa: F821 + experts: ParamList("n_experts") # noqa: F821 def finalize(self) -> torch.Tensor: stacked_experts = torch.stack([p for p in self.experts], dim=0) @@ -46,7 +46,7 @@ class UnfusedMoEMLP2Parameter(ParameterBase): and need to be joined into a single group. """ - experts: ParamList("num_experts") # noqa: F821 + experts: ParamList("n_experts") # noqa: F821 def finalize(self) -> torch.Tensor: stacked_experts = torch.stack([p for p in self.experts], dim=0) @@ -57,13 +57,22 @@ class UnfusedMoEGatedMLPParameter(ParameterBase): """ MoE Parameter for a gated activation function in which the gating matrix is not fused in the same parameter as the non-gating matrix. + + This is a stacked version of the ``GatedMLPParameter``. Please see that class for more + documentation on the layout of the parameters. """ - gating_experts: ParamList("num_experts") # noqa: F821 + gating_experts: ParamList("n_experts") # noqa: F821 - up_experts: ParamList("num_experts") # noqa: F821 + up_experts: ParamList("n_experts") # noqa: F821 def finalize(self) -> torch.Tensor: - fused_params = [torch.cat([gate, weight], dim=0) for gate, weight in zip(self.gating_experts, self.up_experts)] - stacked_params = torch.stack(fused_params, dim=0) - return self.inference_model.transform_moe_mlp_2_param(stacked_params) + transposed_experts = [] + for gate, up in zip(self.gating_experts, self.up_experts): + assert gate.shape[0] == up.shape[0], "Gated MLP parameters must have the same number of neurons." + total_neurons = gate.shape[0] + up.shape[0] + fused_expert = torch.cat([gate, up], dim=-1).reshape(total_neurons, -1) + transposed_experts.append(fused_expert) + + stacked_experts = torch.stack(transposed_experts, dim=0) + return self.inference_model.transform_moe_mlp_1_param(stacked_experts) diff --git a/deepspeed/inference/v2/model_implementations/falcon/__init__.py b/deepspeed/inference/v2/model_implementations/falcon/__init__.py index ff66879b44be..20f37538274c 100644 --- a/deepspeed/inference/v2/model_implementations/falcon/__init__.py +++ b/deepspeed/inference/v2/model_implementations/falcon/__init__.py @@ -3,4 +3,4 @@ # DeepSpeed Team -from .falcon_policy import FalconPolicy +from .policy import FalconPolicy diff --git a/deepspeed/inference/v2/model_implementations/falcon/falcon_containers.py b/deepspeed/inference/v2/model_implementations/falcon/container.py similarity index 97% rename from deepspeed/inference/v2/model_implementations/falcon/falcon_containers.py rename to deepspeed/inference/v2/model_implementations/falcon/container.py index f3cbe6609cdd..caccfe1ecb00 100644 --- a/deepspeed/inference/v2/model_implementations/falcon/falcon_containers.py +++ b/deepspeed/inference/v2/model_implementations/falcon/container.py @@ -5,8 +5,8 @@ # Create a container object to save model-specific tensors using the policy file above. -from ...model_implementations.common_parameters import * -from ...model_implementations.layer_container_base import LayerContainer +from ..common_parameters import * +from ..layer_container_base import LayerContainer ''' # HF Falcon 7b model looks like this: diff --git a/deepspeed/inference/v2/model_implementations/falcon/falcon_model.py b/deepspeed/inference/v2/model_implementations/falcon/model.py similarity index 98% rename from deepspeed/inference/v2/model_implementations/falcon/falcon_model.py rename to deepspeed/inference/v2/model_implementations/falcon/model.py index a00f754744a4..d1ccc38280a0 100644 --- a/deepspeed/inference/v2/model_implementations/falcon/falcon_model.py +++ b/deepspeed/inference/v2/model_implementations/falcon/model.py @@ -11,12 +11,12 @@ from ...allocator import empty_from from ...inference_utils import ActivationType, DtypeEnum -from ...model_implementations import * +from .. import * from ...modules.configs import * from ...modules.interfaces import * from ...ragged import RaggedBatchWrapper -from .falcon_containers import FalconNonTransformerContainer, FalconTransformerContainer +from .container import FalconNonTransformerContainer, FalconTransformerContainer class FalconInferenceModel(DSTransformerModelBase): diff --git a/deepspeed/inference/v2/model_implementations/falcon/falcon_policy.py b/deepspeed/inference/v2/model_implementations/falcon/policy.py similarity index 74% rename from deepspeed/inference/v2/model_implementations/falcon/falcon_policy.py rename to deepspeed/inference/v2/model_implementations/falcon/policy.py index 5672d45a8d13..c6612090a0df 100644 --- a/deepspeed/inference/v2/model_implementations/falcon/falcon_policy.py +++ b/deepspeed/inference/v2/model_implementations/falcon/policy.py @@ -6,10 +6,10 @@ from typing import Any from ...config_v2 import RaggedInferenceEngineConfig -from ...model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy -from ...model_implementations.falcon.falcon_containers import FalconNonTransformerContainer, FalconTransformerContainer -from ...model_implementations.falcon.falcon_containers import FalconNewArchTransformerContainer -from ...model_implementations.falcon.falcon_model import FalconInferenceModel +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import FalconNonTransformerContainer, FalconTransformerContainer +from .container import FalconNewArchTransformerContainer +from .model import FalconInferenceModel class FalconPolicy(InferenceV2Policy): diff --git a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py index 8f6a0b7fa688..e78a161b4cd0 100644 --- a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py +++ b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py @@ -521,12 +521,26 @@ def transform_norm_param(self, param: torch.Tensor) -> InferenceParameter: class DSMoETransformerModelBase(DSTransformerModelBase): @property - def num_experts(self) -> int: + def n_experts(self) -> int: """ Return the number of experts in the model. """ raise NotImplementedError("Attempted to access an unimplemented number of experts") + @property + def n_top_k(self) -> int: + """ + Number of experts per token. + """ + raise NotImplementedError("Attempted to access an unimplemented number of experts per token") + + @property + def normalize_expert_scores(self) -> bool: + """ + Whether to normalize expert scores. If true, sum(expert_scores) = 1. + """ + raise NotImplementedError("Attempted to access an unimplemented normalization flag") + def make_moe_layer(self) -> None: """ Instantiates the MoE layer for the model. This sets the `self.moe` attribute. @@ -538,9 +552,11 @@ def make_moe_layer(self) -> None: model_dim=self.model_dim, intermediate_features=sharded_dim, activation=self.mlp_activation_fn, - n_experts=self.num_experts, + n_experts=self.n_experts, + top_k=self.n_top_k, input_dtype=self.activation_dtype, output_dtype=self.activation_dtype, + normalize_scores=self.normalize_expert_scores, ) self.moe = heuristics.instantiate_moe(moe_config, self._engine_config) diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py b/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py index 5d2b5ae562ee..79605a76a4c2 100644 --- a/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py +++ b/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py @@ -3,4 +3,4 @@ # DeepSpeed Team -from .llama_v2_policy import Llama2Policy +from .policy import Llama2Policy diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_containers.py b/deepspeed/inference/v2/model_implementations/llama_v2/container.py similarity index 95% rename from deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_containers.py rename to deepspeed/inference/v2/model_implementations/llama_v2/container.py index e9c473ce512b..9de9bdb34574 100644 --- a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_containers.py +++ b/deepspeed/inference/v2/model_implementations/llama_v2/container.py @@ -5,8 +5,8 @@ # Create a container object to save model-specific tensors using the policy file above. -from ...model_implementations.common_parameters import * -from ...model_implementations.layer_container_base import LayerContainer +from ..common_parameters import * +from ..layer_container_base import LayerContainer ''' # HF Llama model looks like this: diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_model.py b/deepspeed/inference/v2/model_implementations/llama_v2/model.py similarity index 80% rename from deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_model.py rename to deepspeed/inference/v2/model_implementations/llama_v2/model.py index 9b628f77de01..735e8f52cca3 100644 --- a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_model.py +++ b/deepspeed/inference/v2/model_implementations/llama_v2/model.py @@ -11,12 +11,13 @@ from ...allocator import empty_from from ...inference_utils import ActivationType, DtypeEnum -from ...model_implementations import * +from .. import * from ...modules.configs import * from ...modules.interfaces import * +from ...modules import heuristics from ...ragged import RaggedBatchWrapper -from .llama_v2_containers import Llama2NonTransformerContainer, Llama2TransformerContainer +from .container import Llama2NonTransformerContainer, Llama2TransformerContainer class Llama2InferenceModel(DSTransformerModelBase): @@ -105,6 +106,27 @@ def norm_type(self) -> NormTypeEnum: def positional_embedding_type(self) -> PositionalEmbeddingType: return PositionalEmbeddingType.rotate_half + def make_attn_layer(self) -> None: + """ + Builds the attention layer for the model. This sets the `self.attn` attribute. + """ + softmax_scale = 1.0 / (self.head_size**0.5) + + rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta) + + attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + n_heads_q=self.n_heads_q_local, + n_heads_kv=self.n_heads_kv_local, + head_size=self.head_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + scale_factor=softmax_scale, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + positional_embedding_type=self.positional_embedding_type, + positional_embedding_config=rotary_config) + + self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + """ Forward implementations """ @@ -145,8 +167,7 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid kv_cache = self.state_manager.get_cache(layer_idx) hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None) - hidden_states = self.attn(hidden_states, kv_cache, - ragged_batch_info) #, inv_freqs=None) #cur_params.rotary_emb) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) if self.tp_size > 1: @@ -176,8 +197,10 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge Performs unembedding of the hidden states to logits. This will only sample the final token of each sequence. """ - logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info, - self._non_transformer.final_norm) + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm) if self.tp_size > 1: comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_policy.py b/deepspeed/inference/v2/model_implementations/llama_v2/policy.py similarity index 76% rename from deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_policy.py rename to deepspeed/inference/v2/model_implementations/llama_v2/policy.py index c8253be79fad..bb13ab6d5bf4 100644 --- a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_policy.py +++ b/deepspeed/inference/v2/model_implementations/llama_v2/policy.py @@ -6,9 +6,9 @@ from typing import Any from ...config_v2 import RaggedInferenceEngineConfig -from ...model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy -from ...model_implementations.llama_v2.llama_v2_containers import Llama2NonTransformerContainer, Llama2TransformerContainer -from ...model_implementations.llama_v2.llama_v2_model import Llama2InferenceModel +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import Llama2NonTransformerContainer, Llama2TransformerContainer +from .model import Llama2InferenceModel class Llama2Policy(InferenceV2Policy): diff --git a/deepspeed/inference/v2/model_implementations/mistral/model.py b/deepspeed/inference/v2/model_implementations/mistral/model.py index d9b06b91e308..9c707026f9dd 100644 --- a/deepspeed/inference/v2/model_implementations/mistral/model.py +++ b/deepspeed/inference/v2/model_implementations/mistral/model.py @@ -14,6 +14,7 @@ from ...model_implementations import * from ...modules.configs import * from ...modules.interfaces import * +from ...modules import heuristics from ...ragged import RaggedBatchWrapper from .container import MistralNonTransformerContainer, MistralTransformerContainer @@ -104,6 +105,27 @@ def norm_type(self) -> NormTypeEnum: def positional_embedding_type(self) -> PositionalEmbeddingType: return PositionalEmbeddingType.rotate_half + def make_attn_layer(self) -> None: + """ + Builds the attention layer for the model. This sets the `self.attn` attribute. + """ + softmax_scale = 1.0 / (self.head_size**0.5) + + rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta) + + attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + n_heads_q=self.n_heads_q_local, + n_heads_kv=self.n_heads_kv_local, + head_size=self.head_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + scale_factor=softmax_scale, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + positional_embedding_type=self.positional_embedding_type, + positional_embedding_config=rotary_config) + + self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + """ Forward implementations """ @@ -144,8 +166,7 @@ def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_st kv_cache = self.state_manager.get_cache(layer_idx) hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None) - hidden_states = self.attn(hidden_states, kv_cache, - ragged_batch_info) #, inv_freqs=None) #cur_params.rotary_emb) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) if self.tp_size > 1: @@ -175,8 +196,10 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge Performs unembedding of the hidden states to logits. This will only sample the final token of each sequence. """ - logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info, - self._non_transformer.final_norm) + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm) if self.tp_size > 1: comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) diff --git a/deepspeed/inference/v2/model_implementations/mistral/policy.py b/deepspeed/inference/v2/model_implementations/mistral/policy.py index f6d0a0fe5987..b67ec311c952 100644 --- a/deepspeed/inference/v2/model_implementations/mistral/policy.py +++ b/deepspeed/inference/v2/model_implementations/mistral/policy.py @@ -5,10 +5,10 @@ from typing import Any -from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig -from deepspeed.inference.v2.model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy -from deepspeed.inference.v2.model_implementations.mistral.container import MistralNonTransformerContainer, MistralTransformerContainer -from deepspeed.inference.v2.model_implementations.mistral.model import MistralInferenceModel +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import MistralNonTransformerContainer, MistralTransformerContainer +from .model import MistralInferenceModel class MistralPolicy(InferenceV2Policy): diff --git a/deepspeed/inference/v2/model_implementations/mixtral/__init__.py b/deepspeed/inference/v2/model_implementations/mixtral/__init__.py new file mode 100644 index 000000000000..2cb1aa889291 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import MixtralPolicy diff --git a/deepspeed/inference/v2/model_implementations/mixtral/container.py b/deepspeed/inference/v2/model_implementations/mixtral/container.py new file mode 100644 index 000000000000..6ec4a0552b8f --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/container.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from deepspeed.inference.v2.model_implementations.common_parameters import * +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + + +class MixtralTransformerContainer(LayerContainer): + + qkv_w: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + moe_gate: MoEGatingWeightParameter + moe_mlp_1: UnfusedMoEGatedMLPParameter + moe_mlp_2: UnfusedMoEMLP2Parameter + attn_norm_gamma: NormParameter + mlp_norm_gamma: NormParameter + + PARAM_MAPPING = { + "input_layernorm.weight": "attn_norm_gamma.params", + "post_attention_layernorm.weight": "mlp_norm_gamma.params", + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.o_proj.weight": "attn_out_w.params", + "block_sparse_moe.gate.weight": "moe_gate.params", + "block_sparse_moe.experts.*.w1.weight": "moe_mlp_1.gating_experts", + "block_sparse_moe.experts.*.w3.weight": "moe_mlp_1.up_experts", + "block_sparse_moe.experts.*.w2.weight": "moe_mlp_2.experts", + } + + +class MixtralNonTransformerContainer(LayerContainer): + + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "lm_head.weight": "word_unembed.params", + "model.norm.weight": "final_norm.params", + } diff --git a/deepspeed/inference/v2/model_implementations/mixtral/model.py b/deepspeed/inference/v2/model_implementations/mixtral/model.py new file mode 100644 index 000000000000..d0cae0ff307b --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/model.py @@ -0,0 +1,276 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...config_v2 import RaggedInferenceEngineConfig +from ...inference_utils import ActivationType, DtypeEnum +from ...model_implementations import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...modules import heuristics +from ...ragged import RaggedBatchWrapper +from ..inference_model_base import ( + DSModelImplementationConfig, + MPType, +) + +from .container import MixtralNonTransformerContainer, MixtralTransformerContainer + + +class MixtralInferenceModel(DSMoETransformerModelBase): + """ + Inference model implementation for Mixtral models. + """ + + _non_transformer: Optional[MixtralNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[MixtralTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties ineherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_position_embeddings + + """ + Properties ineherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + activation = self._config.hidden_act.lower() + if activation == "gelu": + return ActivationType.GEGLU + elif activation == "relu": + return ActivationType.ReGLU + elif activation == "gegelu": + return ActivationType.GEGLU + elif activation == "silu": + return ActivationType.SiGLU + else: + raise NotImplementedError(f"Activation {activation} not supported") + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.RMSNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + """ + Inherited from `DSMoETransformerModelBase` + """ + + @property + def n_experts(self) -> int: + return self._config.num_local_experts + + @property + def n_top_k(self) -> int: + return self._config.num_experts_per_tok + + @property + def normalize_expert_scores(self) -> bool: + return True + + """ + Model implementation + """ + + def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig, + base_mp_group: MPType) -> None: + """ + Base implementation for initialization. By default, this will initialize + the traditional components of a transformer model: + - Embedding + - QKV projection + - Self attention + - Attention output projection + - Feed forward network + - Normalization + - Unembedding + + Arguments: + config (DSModelImplementationConfig): Model-specific configuration. No assumptions + should be made about this config that are not closely tied to the specific + model implementation. + engine_config (RaggedInferenceEngineConfig): Engine configuration. + base_mp_group (MPType): Base communication group for Tensor-parallel inference. + """ + super().__init__(config, engine_config, base_mp_group) + + self.make_norm_layer() + self.make_qkv_layer() + self.make_attn_layer() + self.make_attn_out_layer() + self.make_moe_layer() + self.make_embedding_layer() + self.make_unembedding_layer() + self._kv_cache_config = None + + def make_attn_layer(self) -> None: + """ + Builds the attention layer for the model. This sets the `self.attn` attribute. + """ + softmax_scale = 1.0 / (self.head_size**0.5) + + rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta) + + attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + n_heads_q=self.n_heads_q_local, + n_heads_kv=self.n_heads_kv_local, + head_size=self.head_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + scale_factor=softmax_scale, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + positional_embedding_type=self.positional_embedding_type, + positional_embedding_config=rotary_config) + + self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + # TODO(cmikeh2): Distribute ragged_batch_info to all modules + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma) + + hidden_states = self.moe(hidden_states, ragged_batch_info, cur_params.moe_gate, cur_params.moe_mlp_1, + cur_params.moe_mlp_2) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer(layer_idx, residual, hidden_states, wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/mixtral/policy.py b/deepspeed/inference/v2/model_implementations/mixtral/policy.py new file mode 100644 index 000000000000..2f0087919720 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/policy.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import MixtralTransformerContainer, MixtralNonTransformerContainer +from .model import MixtralInferenceModel + + +class MixtralPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> MixtralInferenceModel: + return MixtralInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + + map = ContainerMap() + + transformer_containers = [MixtralTransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(MixtralNonTransformerContainer(self.model)) + + map.set_unmapped_params([]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/opt/container.py b/deepspeed/inference/v2/model_implementations/opt/container.py index 5ddbbde3f141..e97599ef8e50 100644 --- a/deepspeed/inference/v2/model_implementations/opt/container.py +++ b/deepspeed/inference/v2/model_implementations/opt/container.py @@ -5,8 +5,8 @@ # Create a container object to save model-specific tensors using the policy file above. -from ...model_implementations.common_parameters import * -from ...model_implementations.layer_container_base import LayerContainer +from ..common_parameters import * +from ..layer_container_base import LayerContainer ''' # HF OPT model looks like this: diff --git a/deepspeed/inference/v2/model_implementations/opt/model.py b/deepspeed/inference/v2/model_implementations/opt/model.py index fa221e15a0b7..8bad12f10475 100644 --- a/deepspeed/inference/v2/model_implementations/opt/model.py +++ b/deepspeed/inference/v2/model_implementations/opt/model.py @@ -131,8 +131,7 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid kv_cache = self.state_manager.get_cache(layer_idx) hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b) - hidden_states = self.attn(hidden_states, kv_cache, - ragged_batch_info) #, inv_freqs=None) #cur_params.rotary_emb) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=cur_params.attn_out_b) if self.tp_size > 1: @@ -164,8 +163,11 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid return residual, hidden_states def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: - logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info, - self._non_transformer.final_norm_w, self._non_transformer.final_norm_b) + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed, + ragged_batch_info, + gamma=self._non_transformer.final_norm_w, + beta=self._non_transformer.final_norm_b) if self.tp_size > 1: comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) diff --git a/deepspeed/inference/v2/model_implementations/opt/policy.py b/deepspeed/inference/v2/model_implementations/opt/policy.py index af5750260ead..d57d5beb48d5 100644 --- a/deepspeed/inference/v2/model_implementations/opt/policy.py +++ b/deepspeed/inference/v2/model_implementations/opt/policy.py @@ -6,9 +6,9 @@ from typing import Any from ...config_v2 import RaggedInferenceEngineConfig -from ...model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy -from ...model_implementations.opt.container import OPTNonTransformerContainer, OPTTransformerContainer -from ...model_implementations.opt.model import OPTInferenceModel +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import OPTNonTransformerContainer, OPTTransformerContainer +from .model import OPTInferenceModel class OPTPolicy(InferenceV2Policy): diff --git a/deepspeed/inference/v2/model_implementations/phi/__init__.py b/deepspeed/inference/v2/model_implementations/phi/__init__.py new file mode 100644 index 000000000000..3ab107e75a91 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import PhiPolicy diff --git a/deepspeed/inference/v2/model_implementations/phi/containers.py b/deepspeed/inference/v2/model_implementations/phi/containers.py new file mode 100644 index 000000000000..ab6d0181611c --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi/containers.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from ..common_parameters import * +from ..layer_container_base import LayerContainer +''' + # HF Phi-2 model looks like this: + +PhiForCausalLM( + (transformer): PhiModel( + (embd): Embedding( + (wte): Embedding(51200, 2560) + (drop): Dropout(p=0.0, inplace=False) + ) + (h): ModuleList( + (0-31): 32 x ParallelBlock( + (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) + (resid_dropout): Dropout(p=0.1, inplace=False) + (mixer): MHA( + (rotary_emb): RotaryEmbedding() + (Wqkv): Linear(in_features=2560, out_features=7680, bias=True) + (out_proj): Linear(in_features=2560, out_features=2560, bias=True) + (inner_attn): SelfAttention( + (drop): Dropout(p=0.0, inplace=False) + ) + (inner_cross_attn): CrossAttention( + (drop): Dropout(p=0.0, inplace=False) + ) + ) + (mlp): MLP( + (fc1): Linear(in_features=2560, out_features=10240, bias=True) + (fc2): Linear(in_features=10240, out_features=2560, bias=True) + (act): NewGELUActivation() + ) + ) + ) + ) + (lm_head): CausalLMHead( + (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) + (linear): Linear(in_features=2560, out_features=51200, bias=True) + ) + (loss): CausalLMLoss( + (loss_fct): CrossEntropyLoss() + ) +) +''' + + +class PhiTransformerContainer(LayerContainer): + """ + Transformer layer container for the Phi model. + """ + qkv_w: FusedQKVParameter + qkv_b: FusedQKVParameter + attn_out_w: AttentionOutputParameter + attn_out_b: AttentionOutputParameter + mlp_1_w: MLP1Parameter + mlp_1_b: MLP1Parameter + mlp_2_w: MLP2Parameter + mlp_2_b: MLP2Parameter + ln_gamma: NormParameter + ln_beta: NormParameter + + PARAM_MAPPING = { + "mixer.Wqkv.weight": "qkv_w.params", + "mixer.Wqkv.bias": "qkv_b.params", + "mixer.out_proj.weight": "attn_out_w.params", + "mixer.out_proj.bias": "attn_out_b.params", + "mlp.fc1.weight": "mlp_1_w.params", + "mlp.fc1.bias": "mlp_1_b.params", + "mlp.fc2.weight": "mlp_2_w.params", + "mlp.fc2.bias": "mlp_2_b.params", + "ln.weight": "ln_gamma.params", + "ln.bias": "ln_beta.params", + } + + +class PhiNonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Phi model. + """ + word_emb: EmbeddingParameter + word_unembed_w: UnembedParameter + word_unembed_b: UnembedParameter + final_norm_gamma: NormParameter + final_norm_beta: NormParameter + + PARAM_MAPPING = { + "transformer.embd.wte.weight": "word_emb.params", + "lm_head.ln.weight": "final_norm_gamma.params", + "lm_head.ln.bias": "final_norm_beta.params", + "lm_head.linear.weight": "word_unembed_w.params", + "lm_head.linear.bias": "word_unembed_b.params", + } diff --git a/deepspeed/inference/v2/model_implementations/phi/model.py b/deepspeed/inference/v2/model_implementations/phi/model.py new file mode 100644 index 000000000000..a95b12bb119f --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi/model.py @@ -0,0 +1,256 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...config_v2 import RaggedInferenceEngineConfig +from ...inference_utils import ActivationType, DtypeEnum +from .. import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...modules import heuristics +from ...ragged import RaggedBatchWrapper +from ..inference_model_base import ( + DSModelImplementationConfig, + MPType, +) + +from .containers import PhiNonTransformerContainer, PhiTransformerContainer + + +class PhiInferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for Llama-2 models. + """ + + _non_transformer: Optional[PhiNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[PhiTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties inherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties inherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.n_layer + + @property + def model_dim(self) -> int: + return self._config.n_embd + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.n_head + + @property + def intermediate_dim(self) -> int: + n_inner = getattr(self._config, "n_inner", None) + return n_inner if n_inner is not None else 4 * self.model_dim + + @property + def n_heads_kv(self) -> int: + return getattr(self._config, "n_head_kv", None) or self.n_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + return ActivationType.GELU + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.LayerNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + """ + Model implementation + """ + + def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig, + base_mp_group: MPType) -> None: + """ + Base implementation for initialization. By default, this will initialize + the traditional components of a transformer model: + - Embedding + - QKV projection + - Self attention + - Attention output projection + - Feed forward network + - Normalization + - Unembedding + + Arguments: + config (DSModelImplementationConfig): Model-specific configuration. No assumptions + should be made about this config that are not closely tied to the specific + model implementation. + engine_config (RaggedInferenceEngineConfig): Engine configuration. + base_mp_group (MPType): Base communication group for Tensor-parallel inference. + """ + super().__init__(config, engine_config, base_mp_group) + + self.make_norm_layer() + self.make_qkv_layer() + self.make_attn_layer() + self.make_attn_out_layer() + self.make_embedding_layer() + self.make_unembedding_layer() + self._kv_cache_config = None + + def make_attn_layer(self) -> None: + """ + Builds the attention layer for the model. This sets the `self.attn` attribute. + """ + softmax_scale = 1.0 / (self.head_size**0.5) + + rotary_config = RotateHalfConfig(rotate_dim=self._config.rotary_dim) + + attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + n_heads_q=self.n_heads_q_local, + n_heads_kv=self.n_heads_kv_local, + head_size=self.head_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + scale_factor=softmax_scale, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + positional_embedding_type=self.positional_embedding_type, + positional_embedding_config=rotary_config) + + self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + attn_ln_out = hidden_states + attn_hidden_state = self.qkv(attn_ln_out, cur_params.qkv_w, b=cur_params.qkv_b) + attn_hidden_state = self.attn(attn_hidden_state, kv_cache, ragged_batch_info) + attention_output = self.attn_out(attn_hidden_state, cur_params.attn_out_w, b=cur_params.attn_out_b) + + mlp_ln_out = hidden_states + mlp_hidden_state = self.mlp_1(mlp_ln_out, cur_params.mlp_1_w, b=cur_params.mlp_1_b) + mlp_output = self.mlp_2(mlp_hidden_state, cur_params.mlp_2_w, b=cur_params.mlp_2_b) + + mlp_output.add_(attention_output) + + if self.tp_size > 1: + dist.all_reduce(mlp_output, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, mlp_output = self.norm(residual, mlp_output, next_params.ln_gamma, beta=next_params.ln_beta) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(mlp_output) + + return residual, mlp_output + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, + self._non_transformer.word_unembed_w, + ragged_batch_info, + bias=self._non_transformer.word_unembed_b, + gamma=self._non_transformer.final_norm_gamma, + beta=self._non_transformer.final_norm_beta) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, + None, + gamma=self._transformer[0].ln_gamma, + beta=self._transformer[0].ln_beta) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/phi/policy.py b/deepspeed/inference/v2/model_implementations/phi/policy.py new file mode 100644 index 000000000000..1e9db24022f5 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/phi/policy.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .containers import PhiNonTransformerContainer, PhiTransformerContainer +from .model import PhiInferenceModel + + +class PhiPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> PhiInferenceModel: + return PhiInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + trans_container_cls = PhiTransformerContainer + transformer_containers = [trans_container_cls(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['transformer.h'], transformer_containers) + + map.set_non_transformer_params(PhiNonTransformerContainer(self.model)) + + map.set_unmapped_params( + [f'model.layers.{i}.self_attn.rotary_emb.inv_freq' for i in range(self.model.num_layers)]) + + return map diff --git a/deepspeed/inference/v2/modules/configs/__init__.py b/deepspeed/inference/v2/modules/configs/__init__.py index 19b9fb99ddea..3429e69b47de 100644 --- a/deepspeed/inference/v2/modules/configs/__init__.py +++ b/deepspeed/inference/v2/modules/configs/__init__.py @@ -3,7 +3,12 @@ # DeepSpeed Team -from .attention_configs import (DSSelfAttentionConfig, PositionalEmbeddingType, MaskingType) +from .attention_configs import ( + DSSelfAttentionConfig, + PositionalEmbeddingType, + MaskingType, + RotateHalfConfig, +) from .embedding_config import DSEmbeddingsConfig from .linear_config import DSLinearConfig from .moe_config import DSMoEConfig diff --git a/deepspeed/inference/v2/modules/configs/attention_configs.py b/deepspeed/inference/v2/modules/configs/attention_configs.py index bcdc3d2613d5..be6a3535024c 100644 --- a/deepspeed/inference/v2/modules/configs/attention_configs.py +++ b/deepspeed/inference/v2/modules/configs/attention_configs.py @@ -4,10 +4,11 @@ # DeepSpeed Team from enum import Enum -from typing import Dict +from typing import Dict, Optional from ...inference_utils import DtypeEnum from ...modules.ds_module import DSModuleConfig +from deepspeed.runtime.config_utils import DeepSpeedConfigModel class PositionalEmbeddingType(Enum): @@ -25,6 +26,28 @@ class PositionalEmbeddingType(Enum): alibi = "alibi" +class RotateHalfConfig(DeepSpeedConfigModel): + + use_trained_freqs: bool = False + """ + Whether to use a passed `trained_freqs` tensor for the attention implementation + or to use default synthesized frequencies. + """ + + theta_base: float = 10_000.0 + """ + Base for theta. This will only be used if `use_trained_freqs` is False. + """ + + rotate_dim: Optional[int] = None + """ + How many neurons to rotate. If None, then all neurons will be rotated. Many external configs + will set this number to half the head dimension and then internally multiply by 2. To make it + more clear to understand what is happening (rotate_dim < head_dim -> then only partial rotation), + we do not do this multiplication internally. + """ + + class MaskingType(Enum): # No masking @@ -79,4 +102,9 @@ class DSSelfAttentionConfig(DSModuleConfig): positional_embedding_type: PositionalEmbeddingType = PositionalEmbeddingType.none # Positional embedding args - positional_embedding_args: Dict = {} + positional_embedding_config: Optional[RotateHalfConfig] = None + """ + To extend this for the other positional embedding types, we would need to add + new configs for each type (as necessary) and annotate this with the + Union[RotateHalfConfig, OtherConfig, ...] type. + """ diff --git a/deepspeed/inference/v2/modules/configs/moe_config.py b/deepspeed/inference/v2/modules/configs/moe_config.py index 1a88d54af19f..7bc944f55e17 100644 --- a/deepspeed/inference/v2/modules/configs/moe_config.py +++ b/deepspeed/inference/v2/modules/configs/moe_config.py @@ -48,3 +48,9 @@ class DSMoEConfig(DSModuleConfig): """ Activation function of the first MLP1 """ + + normalize_scores: bool = False + """ + Whether normalization is applied to the selected scores. If true, the module + should rescale the scores such that their sum is 1.0. + """ diff --git a/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py b/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py index bb482f0c58d6..5f41b5ff6e13 100644 --- a/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py +++ b/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py @@ -68,9 +68,16 @@ def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[st Args: config (DSSelfAttentionConfig): The self attention config for all attention DSModules. - implementation_config (Dict[str, Any]): The implementation config for this DSModule may - contain a `trained_freqs` key. If passed, the implementation will expect a `trained_freqs` - tensor in the `forward` method and will not synthesize the frequencies internally. + implementation_config (Dict[str, Any]): + There are two (dependent) potential components in the implementtion config. + + 1. `trained_freqs` - If the embedding weights for RoPE are trained, the implementation + config should contain {'trained_freqs': True}. This will mean the implementation will + expect a `trained_freqs` tensor in the `forward` method and will not synthesize the + values internally. + + 2. `theta_base` - The base value for synthesized frequencies in the rotary embeddings. + This will only be used if `trained_freqs` is False or not present in the `implementation_config`. If this is not included, the default value of 10000.0 will be used. """ super().__init__(config, implementation_config) @@ -79,14 +86,18 @@ def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[st self._kv_copy = LinearBlockedKVCopy(self._config.head_size, self._config.n_heads_q, self._config.n_heads_kv, self._config.input_dtype) elif embed_type == PositionalEmbeddingType.rotate_half: - use_trained_freqs = "trained_freqs" in self._config.positional_embedding_args and self._config.positional_embedding_args[ - "trained_freqs"] - if use_trained_freqs: + rotary_config = config.positional_embedding_config + if rotary_config.use_trained_freqs: + # Theta and rotary dim are effectively embedded into either the values (theta) or the shape (rotary_dim) + # of the trained_freqs tensor. self._kv_copy = BlockedTrainedRotaryEmbeddings(self._config.head_size, self._config.n_heads_q, self._config.n_heads_kv, self._config.input_dtype) else: + theta_base = rotary_config.theta_base + rotary_dim = rotary_config.rotate_dim if rotary_config.rotate_dim is not None else self._config.head_size self._kv_copy = BlockedRotaryEmbeddings(self._config.head_size, self._config.n_heads_q, - self._config.n_heads_kv, self._config.input_dtype) + self._config.n_heads_kv, self._config.input_dtype, rotary_dim, + theta_base) self._softmax_scale = self._config.scale_factor diff --git a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py index e43a737515ed..38c0000d7f78 100644 --- a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py +++ b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py @@ -9,12 +9,12 @@ from deepspeed.accelerator import get_accelerator from ....allocator import empty_from -from ....inference_utils import ActivationType -from ....kernels.core_ops import BlasLibLinear +from ....inference_utils import ActivationType, is_gated +from ....kernels.core_ops import BlasLibLinear, CUDAGatedActivation from ....kernels.ragged_ops import ( MoEGather, MoEScatter, - RaggedTop1Gating, + RaggedTopKGating, ) from ....ragged import RaggedBatchWrapper @@ -42,11 +42,7 @@ def supports_config(config: DSMoEConfig) -> bool: if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16: return False - if config.top_k != 1: - return False - - if config.activation in [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU]: - # Currently not supporting gated activations in MoE + if config.top_k != 1 and config.top_k != 2: return False return True @@ -57,15 +53,24 @@ def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) - # Convenience variables for frequently accessed items. self.max_tokens = self._config.max_tokens self.n_experts = self._config.n_experts + self.n_top_k = self._config.top_k self.intermediate_dim = self._config.intermediate_features - self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=config.activation) + moe_op_act_fn = ActivationType.IDENTITY if is_gated(self._config.activation) else self._config.activation + + self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=moe_op_act_fn) self._mlp_2 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=ActivationType.IDENTITY) + if is_gated(self._config.activation): + self._activation = CUDAGatedActivation(self._config.model_dim, self._config.input_dtype, + self._config.activation) + else: + self._activation = None + self._gate_proj = BlasLibLinear(self._config.input_dtype) - self._top_1_gate = RaggedTop1Gating(config.input_dtype) + self._top_1_gate = RaggedTopKGating(config.input_dtype) self._moe_scatter = MoEScatter(config.input_dtype, config.model_dim) - self._moe_gather = MoEGather(config.input_dtype, config.model_dim) + self._moe_gather = MoEGather(config.input_dtype, config.model_dim, config.normalize_scores) self._create_buffers() @@ -78,32 +83,38 @@ def _create_buffers(self): self._expert_counts = torch.empty((self.n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - self._scores = torch.empty((self._config.max_tokens, ), + self._scores = torch.empty((self._config.max_tokens, self.n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) - self._assignments = torch.empty((self._config.max_tokens, ), + self._assignments = torch.empty((self._config.max_tokens, self.n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) - self._offsets = torch.empty((self._config.max_tokens, ), + self._offsets = torch.empty((self._config.max_tokens, self.n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) # Scatter buffers - self._moe_input = torch.empty((self._config.max_tokens, self._config.model_dim), + self._moe_input = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim), dtype=self._config.input_dtype, device=get_accelerator().current_device()) self._expert_cumsum = torch.empty((self._config.n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) - self._mapped_slots = torch.empty((self._config.max_tokens, ), + self._mapped_slots = torch.empty((self._config.max_tokens, self.n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) # GEMM Buffers - self._intermediate = torch.empty((self._config.max_tokens, self._config.intermediate_features), + self._intermediate = torch.empty((self._config.max_tokens * self.n_top_k, self._config.intermediate_features), dtype=self._config.output_dtype, device=get_accelerator().current_device()) - self._output_unordered = torch.empty((self._config.max_tokens, self._config.model_dim), + if self._activation is not None: + self._gated_intermediate = torch.empty( + (self._config.max_tokens * self.n_top_k, self._config.intermediate_features * 2), + dtype=self._config.output_dtype, + device=get_accelerator().current_device()) + + self._output_unordered = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim), dtype=self._config.output_dtype, device=get_accelerator().current_device()) @@ -167,11 +178,11 @@ def _gate(self, hidden_states: torch.Tensor, batch_metadata: RaggedBatchWrapper, # Get views on the buffers for gating logits = empty_from(self._logits, (hidden_states.shape[0], self._logits.shape[-1])) - scores = empty_from(self._scores, (hidden_states.shape[0], )) - assignments = empty_from(self._assignments, (hidden_states.shape[0], )) - offsets = empty_from(self._offsets, (hidden_states.shape[0], )) - mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], )) - moe_input = empty_from(self._moe_input, (hidden_states.shape[0], self._moe_input.shape[-1])) + scores = empty_from(self._scores, (hidden_states.shape[0], self.n_top_k)) + assignments = empty_from(self._assignments, (hidden_states.shape[0], self.n_top_k)) + offsets = empty_from(self._offsets, (hidden_states.shape[0], self.n_top_k)) + mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], self.n_top_k)) + moe_input = empty_from(self._moe_input, (hidden_states.shape[0] * self.n_top_k, self._moe_input.shape[-1])) self._gate_proj(logits, hidden_states, gate_w) self._expert_counts.zero_() @@ -200,18 +211,31 @@ def forward(self, moe_input, expert_cumsum, scores, mapped_slots = self._gate(hidden_states, batch_metadata, gate_w) # Get views on the buffers for GEMM - intermediate = empty_from(self._intermediate, (hidden_states.shape[0], self._intermediate.shape[-1])) + intermediate = empty_from(self._intermediate, + (hidden_states.shape[0] * self.n_top_k, self._intermediate.shape[-1])) output_unordered = empty_from(self._output_unordered, - (hidden_states.shape[0], self._output_unordered.shape[-1])) + (hidden_states.shape[0] * self.n_top_k, self._output_unordered.shape[-1])) output = empty_from(self._output, (hidden_states.shape[0], self._output.shape[-1])) - self._mlp_1( - intermediate, - moe_input, - mlp_1_w, - expert_cumsum, - mlp_1_b, - ) + if self._activation is not None: + gated_intermediate = empty_from( + self._gated_intermediate, (hidden_states.shape[0] * self.n_top_k, self._gated_intermediate.shape[-1])) + self._mlp_1( + gated_intermediate, + moe_input, + mlp_1_w, + expert_cumsum, + mlp_1_b, + ) + self._activation(intermediate, gated_intermediate) + else: + self._mlp_1( + intermediate, + moe_input, + mlp_1_w, + expert_cumsum, + mlp_1_b, + ) self._mlp_2( output_unordered, diff --git a/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py b/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py index 40d70cbd4df7..36130902c665 100644 --- a/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py +++ b/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py @@ -9,8 +9,8 @@ from deepspeed.accelerator import get_accelerator from ....allocator import empty_from -from ....inference_utils import DtypeEnum -from ....kernels.core_ops import CUDAFPLN, BlasLibLinear, CUDARMSNorm +from ....inference_utils import DtypeEnum, ActivationType +from ....kernels.core_ops import CUDAFPLN, BlasLibLinear, CUDARMSNorm, CUDABiasActivation from ....kernels.ragged_ops import RaggedLogitsGather from ....ragged import RaggedBatchWrapper from ...interfaces import DSUnembedBase, DSUnembedRegistry @@ -65,6 +65,8 @@ def __init__(self, config: DSUnembedConfig, implementation_config: Dict[str, Any self._norm = None self._linear = BlasLibLinear(self._config.dtype) + # Here the activation kernel is being used to apply bias, hence the identity activation type! + self._act_fn = CUDABiasActivation(self._config.vocab_size, self._config.dtype, ActivationType.IDENTITY) self._intermediate = torch.empty((self._config.max_sequences, self._config.model_dim), dtype=self._config.dtype, @@ -82,6 +84,7 @@ def forward(self, hidden_states: torch.Tensor, vocab_embedding: torch.Tensor, ragged_metadata: RaggedBatchWrapper, + bias: Optional[torch.Tensor] = None, gamma: Optional[torch.Tensor] = None, beta: Optional[torch.Tensor] = None) -> torch.Tensor: """ @@ -111,5 +114,7 @@ def forward(self, output = empty_from(self._output, (ragged_metadata.current_sequences, self._config.vocab_size)) self._linear(output, cut_down_hidden_states, vocab_embedding) + if bias is not None: + self._act_fn(output, bias) return output diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py index 8eff4560b4d0..ecc3c52a5834 100644 --- a/deepspeed/inference/v2/ragged/ragged_manager.py +++ b/deepspeed/inference/v2/ragged/ragged_manager.py @@ -127,10 +127,7 @@ def get_sequence(self, uid: int) -> Optional[DSSequenceDescriptor]: Get the sequence descriptor for the given sequence id. If the sequence does not exist, then None is returned. """ - if uid not in self._seqs: - return None - - return self._seqs[uid] + return self._seqs.get(uid, None) def get_or_create_sequence(self, uid: int) -> DSSequenceDescriptor: """ @@ -139,8 +136,9 @@ def get_or_create_sequence(self, uid: int) -> DSSequenceDescriptor: if one may be allocated and should not be used from APIs that are attempting to test the schedulability of a hypothetical batch. """ - if uid in self._seqs: - return self._seqs[uid] + seq = self.get_sequence(uid) + if seq is not None: + return seq else: return self._create_sequence(uid) diff --git a/deepspeed/model_implementations/diffusers/unet.py b/deepspeed/model_implementations/diffusers/unet.py index 7da571975958..8d5ddd95437a 100644 --- a/deepspeed/model_implementations/diffusers/unet.py +++ b/deepspeed/model_implementations/diffusers/unet.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from deepspeed.accelerator import get_accelerator from ..features.cuda_graph import CUDAGraph @@ -29,7 +30,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[k].copy_(kwargs[k]) - self._cuda_graphs.replay() + get_accelerator().replay_graph(self._cuda_graphs) return self.static_output def forward(self, *inputs, **kwargs): @@ -53,11 +54,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._cuda_graphs = torch.cuda.CUDAGraph() + self._cuda_graphs = get_accelerator().create_graph() self.static_inputs = inputs self.static_kwargs = kwargs - with torch.cuda.graph(self._cuda_graphs): + with get_accelerator().capture_to_graph(self._cuda_graphs): self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) self.cuda_graph_created = True diff --git a/deepspeed/model_implementations/diffusers/vae.py b/deepspeed/model_implementations/diffusers/vae.py index 05084f1b985a..ce50ade647a8 100644 --- a/deepspeed/model_implementations/diffusers/vae.py +++ b/deepspeed/model_implementations/diffusers/vae.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from deepspeed.accelerator import get_accelerator from ..features.cuda_graph import CUDAGraph @@ -27,7 +28,7 @@ def _graph_replay_decoder(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_decoder_kwargs[k].copy_(kwargs[k]) - self._decoder_cuda_graph.replay() + get_accelerator().replay_graph(self._decoder_cuda_graph) return self.static_decoder_output def _decode(self, x, return_dict=True, generator=None): @@ -43,11 +44,11 @@ def _create_cuda_graph_decoder(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._decoder_cuda_graph = torch.cuda.CUDAGraph() + self._decoder_cuda_graph = get_accelerator().create_graph() self.static_decoder_inputs = inputs self.static_decoder_kwargs = kwargs - with torch.cuda.graph(self._decoder_cuda_graph): + with get_accelerator().capture_to_graph(self._decoder_cuda_graph): self.static_decoder_output = self._decode(*self.static_decoder_inputs, **self.static_decoder_kwargs) self.decoder_cuda_graph_created = True @@ -70,7 +71,7 @@ def _graph_replay_encoder(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_encoder_kwargs[k].copy_(kwargs[k]) - self._encoder_cuda_graph.replay() + get_accelerator().replay_graph(self._encoder_cuda_graph) return self.static_encoder_output def _encode(self, x, return_dict=True): @@ -86,11 +87,11 @@ def _create_cuda_graph_encoder(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._encoder_cuda_graph = torch.cuda.CUDAGraph() + self._encoder_cuda_graph = get_accelerator().create_graph() self.static_encoder_inputs = inputs self.static_encoder_kwargs = kwargs - with torch.cuda.graph(self._encoder_cuda_graph): + with get_accelerator().capture_to_graph(self._encoder_cuda_graph): self.static_encoder_output = self._encode(*self.static_encoder_inputs, **self.static_encoder_kwargs) self.encoder_cuda_graph_created = True @@ -113,7 +114,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[k].copy_(kwargs[k]) - self._all_cuda_graph.replay() + get_accelerator().replay_graph(self._all_cuda_graph) return self.static_output def forward(self, *inputs, **kwargs): @@ -137,11 +138,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._all_cuda_graph = torch.cuda.CUDAGraph() + self._all_cuda_graph = get_accelerator().create_graph() self.static_inputs = inputs self.static_kwargs = kwargs - with torch.cuda.graph(self._all_cuda_graph): + with get_accelerator().capture_to_graph(self._all_cuda_graph): self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) self.all_cuda_graph_created = True diff --git a/deepspeed/model_implementations/transformers/clip_encoder.py b/deepspeed/model_implementations/transformers/clip_encoder.py index 8d9291896986..848a5b48dcf1 100644 --- a/deepspeed/model_implementations/transformers/clip_encoder.py +++ b/deepspeed/model_implementations/transformers/clip_encoder.py @@ -38,7 +38,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[self.iter][k].copy_(kwargs[k]) - self._cuda_graphs[self.iter].replay() + get_accelerator().replay_graph(self._cuda_graphs[self.iter]) return self.static_output[self.iter] def forward(self, *inputs, **kwargs): @@ -63,11 +63,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._cuda_graphs[self.iter] = torch.cuda.CUDAGraph() + self._cuda_graphs[self.iter] = get_accelerator().create_graph() self.static_inputs[self.iter] = inputs self.static_kwargs[self.iter] = kwargs - with torch.cuda.graph(self._cuda_graphs[self.iter]): + with get_accelerator().capture_to_graph(self._cuda_graphs[self.iter]): self.static_output[self.iter] = self._forward(*self.static_inputs[self.iter], **self.static_kwargs[self.iter]) diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index 772d23f2d0ac..02e0b197e927 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -439,7 +439,9 @@ def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint num_non_fp_tensors += 1 continue - arg.data = inp.data + arg.data = torch.empty([], device=arg.device).data + arg.saved_data = inp.data + new_args.append(arg) i = arg_index - num_non_fp_tensors @@ -472,7 +474,8 @@ def get_cpu_activations_for_backward(args, inputs): new_args.append(arg) continue - arg.data = inp.data + arg.data = torch.empty([], device=arg.device).data + arg.saved_data = inp.data new_args.append(arg) return new_args @@ -628,6 +631,12 @@ def backward(ctx, *grads): global cuda_device, transport_stream, PARTITION_ACTIVATIONS + # Rebuild deepspeed_saved_tensors + for t in ctx.deepspeed_saved_tensors: + if t is not None and hasattr(t, 'saved_data') and t.saved_data is not None: + t.data = t.saved_data.to(t.device) + t.saved_data = None + if PARTITION_ACTIVATIONS: # with get_accelerator().stream(transport_stream): inputs = gather_partitioned_activations(ctx.deepspeed_saved_tensors, diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 888505279290..0aefd1946c36 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -16,7 +16,7 @@ from deepspeed.git_version_info import version from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim, align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank, - is_model_parallel_parameter, see_memory_usage) + is_model_parallel_parameter, see_memory_usage, graph_process) from deepspeed.utils import link_hp_params, fragment_address from deepspeed.checkpoint import enable_universal_checkpoint @@ -38,7 +38,8 @@ def __init__(self, allgather_bucket_size=5000000000, dp_process_group=None, timers=None, - grad_acc_dtype=None): + grad_acc_dtype=None, + graph_harvesting=False): super().__init__() see_memory_usage('begin bf16_optimizer', force=True) self.timers = timers @@ -81,7 +82,7 @@ def __init__(self, self.fp32_groups_has_gradients = [] self.group_paddings = [] - + self.graph_harvesting = graph_harvesting if self.using_real_optimizer: self._setup_for_real_optimizer() @@ -248,7 +249,8 @@ def step(self, closure=None): all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(), mpu=self.mpu, - norm_type=self.norm_type) + norm_type=self.norm_type, + use_graph=self.graph_harvesting) self._global_grad_norm = all_groups_norm assert all_groups_norm > 0. @@ -256,7 +258,8 @@ def step(self, closure=None): clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True), max_norm=self.clip_grad, global_norm=all_groups_norm, - mpu=self.mpu) + mpu=self.mpu, + use_graph=self.graph_harvesting) self.optimizer.step() @@ -281,23 +284,33 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg @torch.no_grad() def update_hp_grads(self, clear_lp_grads=False): + + def _update_hp_grads_func(clear_lp_grads=False): + for i, group in enumerate(self.bf16_groups): + for j, lp in enumerate(group): + if lp.grad is None: + continue + hp_grad = self.fp32_groups_gradients[i][j] + assert hp_grad is not None, \ + f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]' + hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) + lp._hp_grad = hp_grad + self.fp32_groups_has_gradients[i][j] = True + # clear gradients + if clear_lp_grads: + lp.grad._zero() + + if self.graph_harvesting: + graph_process(False, _update_hp_grads_func, clear_lp_grads) + else: + _update_hp_grads_func(clear_lp_grads) + #cpu op for i, group in enumerate(self.bf16_groups): for j, lp in enumerate(group): if lp.grad is None: continue - - hp_grad = self.fp32_groups_gradients[i][j] - assert hp_grad is not None, \ - f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]' - - hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) - lp._hp_grad = hp_grad self.fp32_groups_has_gradients[i][j] = True - # clear gradients - if clear_lp_grads: - lp.grad = None - @torch.no_grad() def get_grads_for_reduction(self): return self.fp32_groups_gradients_flat @@ -348,7 +361,9 @@ def clear_hp_grads(self): def clear_lp_grads(self): for group in self.bf16_groups: for param in group: - param.grad = None + if param.grad is not None: + # Using zero_() fixed memory address for graph replay + param.grad.zero_() def state_dict(self): state_dict = {} diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index b49469b94f11..80754df50c20 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -279,6 +279,10 @@ def get_gradient_clipping(param_dict): return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT) +def get_graph_harvesting(param_dict): + return get_scalar_param(param_dict, GRAPH_HARVESTING, GRAPH_HARVESTING_DEFAULT) + + def get_sparse_attention(param_dict): if SPARSE_ATTENTION in param_dict.keys(): sparsity = param_dict[SPARSE_ATTENTION] @@ -823,6 +827,7 @@ def _initialize_params(self, param_dict): self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict) self.compression_config = get_compression_config(param_dict) + self.graph_harvesting = get_graph_harvesting(param_dict) self.optimizer_name = get_optimizer_name(param_dict) if (self.optimizer_name is not None and self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS): diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 96f2a38bd05c..82d8a0557a41 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -210,6 +210,18 @@ GRADIENT_CLIPPING = 'gradient_clipping' GRADIENT_CLIPPING_DEFAULT = 0. +######################################### +# Capture graph for short kernels sequences +######################################### +# Graph harvesting. By default, this feature is not enabled. +# Users can configure in ds_config.json as below example: +GRAPH_HARVESTING_FORMAT = ''' +Graph harvesting should be enabled as: +"graph_harvesting": true +''' +GRAPH_HARVESTING = 'graph_harvesting' +GRAPH_HARVESTING_DEFAULT = False + ######################################### # Communication data type ######################################### diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 03a0de329397..3cbc4c8414b7 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -69,7 +69,7 @@ STEP_MICRO_TIMER, \ FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \ STEP_GLOBAL_TIMER -from deepspeed.utils.debug import debug_extract_module_and_param_names +from deepspeed.utils.debug import debug_extract_module_and_param_names, debug_clear_module_and_param_names from deepspeed.monitor.monitor import MonitorMaster from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop from deepspeed.runtime.utils import clip_grad_norm_ @@ -366,6 +366,7 @@ def __init__( def destroy(self): if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): self.optimizer.destroy() + debug_clear_module_and_param_names() def _get_model_parameters(self): if self.autotuning_profile_model_info(): @@ -774,6 +775,9 @@ def zero_legacy_stage1(self): def zero_ignore_unused_parameters(self): return self._config.zero_config.ignore_unused_parameters + def graph_harvesting(self): + return self._config.graph_harvesting + def fp16_enabled(self): return self._config.fp16_enabled @@ -1018,6 +1022,9 @@ def _supported_optims(self): # Validate configuration based on command line arguments def _do_sanity_check(self): + if self.fp16_enabled() and not get_accelerator().is_fp16_supported(): + raise ValueError("Type fp16 is not supported.") + expected_optim_types = self._supported_optims() expected_optim_types += [type(None), Callable] assert isinstance(self.client_optimizer, tuple(expected_optim_types)), \ @@ -1455,7 +1462,8 @@ def _configure_bf16_optimizer(self, optimizer): allgather_bucket_size=self.zero_allgather_bucket_size(), dp_process_group=self.seq_data_parallel_group, timers=timers, - grad_acc_dtype=self.get_data_types()[1]) + grad_acc_dtype=self.get_data_types()[1], + graph_harvesting=self.graph_harvesting()) return optimizer @@ -1750,14 +1758,17 @@ def eval(self): self.warn_unscaled_loss = True self.module.train(False) - def _scale_loss_by_gas(self, prescaled_loss): + def _scale_loss_by_gas(self, prescaled_loss, eval_micro_batches=None): + # In pipeline evaluation, there is an option to use different micro-bs, which creates different number of + # micro batches, thus the training gas, is not valid in this case. need to use the number of eval_micro_batches + scaling_factor = self.gradient_accumulation_steps() if eval_micro_batches is None else eval_micro_batches if isinstance(prescaled_loss, torch.Tensor): - scaled_loss = prescaled_loss / self.gradient_accumulation_steps() + scaled_loss = prescaled_loss / scaling_factor elif isinstance(prescaled_loss, tuple) or isinstance(prescaled_loss, list): scaled_loss = [] for l in prescaled_loss: if isinstance(l, torch.Tensor): - scaled_loss.append(l / self.gradient_accumulation_steps()) + scaled_loss.append(l / scaling_factor) else: scaled_loss.append(l) else: @@ -3219,7 +3230,6 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa state.update(client_state) logger.info(f'Saving model checkpoint: {save_path}') self.checkpoint_engine.save(state, save_path) - self._curr_save_path = None def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint): name_function = (self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 27fa5b69d35d..05029e44d0e8 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -402,7 +402,13 @@ def train_batch(self, data_iter=None): # TODO: should return precisely what loss returned and allow others to be queried? return self.agg_train_loss - def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_output='avg', bcast_loss=True): + def eval_batch(self, + data_iter, + return_logits=False, + compute_loss=True, + reduce_output='avg', + bcast_loss=True, + num_micro_batches=None): """Evaluate the pipeline on a batch of data from ``data_iter``. The engine will evaluate ``self.train_batch_size()`` total samples collectively across all workers. @@ -451,6 +457,9 @@ def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_o train_iterator = self.data_iterator self.set_dataiterator(data_iter) + # set the number micro batches in case the user chose value than training + micro_batches = self.micro_batches if num_micro_batches is None else num_micro_batches + # Do the work sched = schedule.InferenceSchedule(micro_batches=self.micro_batches, stages=self.num_stages, @@ -463,7 +472,7 @@ def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_o self._exec_schedule(sched) if self.is_last_stage(): - eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output) + eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output, micro_batches=micro_batches) if compute_loss and (bcast_loss or self.monitor.enabled): eval_output = self._bcast_pipe_scalar(eval_output) @@ -505,7 +514,7 @@ def is_last_stage(self): """True if this process is in the last stage in the pipeline.""" return self.stage_id == self.num_stages - 1 - def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True): + def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True, micro_batches=None): if reduce is None: return outputs @@ -520,7 +529,7 @@ def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True): reduced[idx] += out # Average over the microbatches - reduced = self._scale_loss_by_gas(reduced) + reduced = self._scale_loss_by_gas(reduced, eval_micro_batches=micro_batches) # Average over DP groups if reduce_dp and self.is_data_parallel: diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index bc7a782e590c..82f200fccf9f 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -47,6 +47,27 @@ def __init__(self, params): self.param_groups.append({'params': params}) +graph_cache = {} + + +def graph_process(replay_first_step, func, *args, **kwargs): + # `func` should only contain operations on the GPU + # Please ensure that the memory address of the data required by 'func' remains constant + if func.__name__ not in graph_cache: + cuda_stream = get_accelerator().Stream() + cuda_stream.wait_stream(get_accelerator().current_stream()) + with get_accelerator().stream(cuda_stream): + func(*args, **kwargs) + get_accelerator().current_stream().wait_stream(cuda_stream) + graph_cache[func.__name__] = get_accelerator().create_graph() + with get_accelerator().capture_to_graph(graph_cache[func.__name__]): + func(*args, **kwargs) + if replay_first_step: + get_accelerator().replay_graph(graph_cache[func.__name__]) + else: + get_accelerator().replay_graph(graph_cache[func.__name__]) + + def noop_decorator(func): return func @@ -831,7 +852,7 @@ def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, ep return global_grad_norm -def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): +def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False): """Get norm of an iterable of tensors. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and @@ -845,7 +866,6 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): Returns: Total norm of the tensors (viewed as a single vector). """ - assert isinstance(input_tensors, Iterable), f'expected Iterable type not {type(input_tensors)}' assert all([torch.is_tensor(t) for t in input_tensors]), f'expected list of only tensors' @@ -857,8 +877,24 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) total_norm = total_norm_cuda[0].item() else: - total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors]) - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + if use_graph: + if 'norm_tensors_compute_buffer' not in graph_cache: + graph_cache['norm_tensors_compute_buffer'] = [t.data.float().norm(norm_type) for t in input_tensors] + compute_buffer = graph_cache['norm_tensors_compute_buffer'] + + def _norm_tensors(tensor_list, _compute_buffer, _norm_type): + for i, t in enumerate(tensor_list): + _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type) + if i != 0: + _compute_buffer[0].data.add_(_compute_buffer[i].data) + + graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type) + + total_norm = compute_buffer[0] + else: + total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors]) + + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach() if mpu is not None: dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) total_norm = total_norm_cuda[0].item()**(1. / norm_type) @@ -869,7 +905,7 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): return total_norm -def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6): +def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6, use_graph=False): """Clip list of tensors by global norm. Args: input_tensors: List of tensors to be clipped @@ -880,14 +916,26 @@ def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, m float: the global norm """ if global_norm is None: - global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu) - + global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu, use_graph=use_graph) clip_coef = max_norm / (global_norm + eps) - if clip_coef < 1: - for t in input_tensors: - t.detach().mul_(clip_coef) + if use_graph: + def clip_tensors(_tensor_list, _clip_coef_tensor): + for t in _tensor_list: + t.detach().mul_(_clip_coef_tensor) + + if 'clip_coef_tensor' not in graph_cache: + # Alloc memory + graph_cache['clip_coef_tensor'] = torch.tensor(clip_coef, + dtype=torch.float32).to(get_accelerator().device_name()) + clip_coef_tensor = graph_cache['clip_coef_tensor'] + clip_coef_tensor.copy_(torch.tensor(clip_coef, dtype=torch.float32)) + graph_process(False, clip_tensors, input_tensors, clip_coef_tensor) + + else: + for t in input_tensors: + t.detach().mul_(clip_coef) return global_norm diff --git a/deepspeed/runtime/zero/contiguous_memory_allocator.py b/deepspeed/runtime/zero/contiguous_memory_allocator.py index 632f68ca1718..35b3d5c7dd5d 100644 --- a/deepspeed/runtime/zero/contiguous_memory_allocator.py +++ b/deepspeed/runtime/zero/contiguous_memory_allocator.py @@ -193,7 +193,7 @@ def _defragment_memory(self): tensor = self.tensor_map[self.tensor_ids[tensor_addr]] assert tensor_size == tensor.numel(), \ - "Size mismatch. {tensor_size} is allocated at addr {tensor_addr} but tensor size is {tensor.numel()} " + f"Size mismatch. {tensor_size} is allocated at addr {tensor_addr} but tensor size is {tensor.numel()} " assert empty_addr != tensor_addr, \ f"Cannot have same empty address {empty_addr} and tensor address {tensor_addr}" diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index cdf7de512b9b..992dcd446ad6 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -1064,7 +1064,9 @@ def all_gather(param_list=None, async_op=False, hierarchy=0): def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_process_group): partition_sz = sum(p.ds_tensor.ds_numel for p in params) - if params[0].ds_secondary_tensor is not None and not forward: + use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward + + if use_secondary_tensor: partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params) flat_tensor = torch.empty(partition_sz * world_size, @@ -1076,13 +1078,11 @@ def _all_gather_dtype(dtype, params, forward, world_size, rank_in_group, ds_proc for i in range(world_size): partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz)) - if params[0].ds_secondary_tensor is not None and not forward: - use_secondary_tensor = True + if use_secondary_tensor: instrument_w_nvtx( torch.cat)([p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params], out=partitions[rank_in_group]) else: - use_secondary_tensor = False instrument_w_nvtx(torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params], out=partitions[rank_in_group]) handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group) @@ -1118,7 +1118,7 @@ def all_gather_coalesced(params: Iterable[Parameter], ds_process_group = self.ds_process_group rank_in_group = self.rank world_size = self.dp_world_size - use_secondary_tensor = False + use_secondary_tensor = params[0].ds_secondary_tensor is not None and not forward if self.zero_param_process_group and not forward: ds_process_group = self.zero_param_process_group #intragroup rank_in_group = self.rank_in_group @@ -1149,10 +1149,10 @@ def all_gather_coalesced(params: Iterable[Parameter], # have an opportunity to avoid some intermediate memory allocations param, = params buffer_size = math.ceil(param.ds_numel / world_size) * world_size - if not forward and param.ds_secondary_tensor is not None: + if use_secondary_tensor: buffer_size = param.ds_secondary_tensor.shape[0] * world_size #make sure out is appropriately sized - param_ds_tensor = param.ds_secondary_tensor if not forward and param.ds_secondary_tensor is not None else param.ds_tensor + param_ds_tensor = param.ds_secondary_tensor if use_secondary_tensor else param.ds_tensor param_buffer = torch.empty( buffer_size, dtype=param_ds_tensor.dtype if not quantize else torch.int8, @@ -1207,7 +1207,7 @@ def all_gather_coalesced(params: Iterable[Parameter], else: partition_sz = sum(p.ds_tensor.ds_numel for p in params) - if params[0].ds_secondary_tensor is not None and not forward: + if use_secondary_tensor: partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params) flat_tensor = torch.empty(partition_sz * world_size, @@ -1215,8 +1215,7 @@ def all_gather_coalesced(params: Iterable[Parameter], device=get_accelerator().current_device_name(), requires_grad=False) - if params[0].ds_secondary_tensor is not None and not forward: - use_secondary_tensor = True + if use_secondary_tensor: if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"): quantized_param = instrument_w_nvtx(torch.cat)([ p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 30a168dcd396..ce4137028195 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -377,6 +377,7 @@ def __init__( #creates backward hooks for gradient partitioning ###Calls all gather param + self._grad_acc_hooks = [] self.create_reduce_and_remove_grad_hooks() #exit(0) @@ -397,6 +398,9 @@ def __init__( def destroy(self): self.parameter_offload.destroy() + for hook in self._grad_acc_hooks: + hook.remove() + print_rank_0("Removed grad acc hooks", force=False) del self.__ipg_bucket_flat_buffer def initialize_ds_offload( @@ -1118,7 +1122,7 @@ def wrapper(param): def reduce_partition_and_remove_grads(*notneeded): self.reduce_ready_partitions_and_remove_grads(param) - grad_acc.register_hook(reduce_partition_and_remove_grads) + self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) self.grad_accs.append(grad_acc) #print(f"param grad fn {param.expand_as(param).grad_fn}") @@ -1324,7 +1328,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): param_id = self.get_param_id(p) if param_id in self.norm_for_param_grads.keys(): param_norm = self.norm_for_param_grads[param_id] - total_norm += param_norm.item()**2 + total_norm += param_norm**2 # Sum across all model parallel GPUs. total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) @@ -1333,10 +1337,14 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + total_norm = total_norm_cuda[0]**(1. / norm_type) - if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 + norm_is_inf = total_norm.isinf() + norm_is_nan = total_norm.isnan() + inf_or_nan = norm_is_nan.logical_or(norm_is_inf) + + err = torch.tensor(-1.0, device=self.device, dtype=torch.float) + total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm return total_norm @@ -1665,7 +1673,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): # Take max across all GPUs. self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX) - total_norm = total_norm_cuda[0].item() + total_norm = total_norm_cuda[0] else: # if dist.get_rank() == 0: # logger.info(f"Total Norm beginning {total_norm}") @@ -1686,10 +1694,14 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) - total_norm = total_norm_cuda.item()**(1. / norm_type) + total_norm = total_norm_cuda**(1. / norm_type) + + norm_is_inf = total_norm.isinf() + norm_is_nan = total_norm.isnan() + inf_or_nan = norm_is_nan.logical_or(norm_is_inf) - if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 + err = torch.tensor(-1.0, device=self.device, dtype=torch.float) + total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm return total_norm diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index aeb533698af3..2eec7cbc96a1 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -240,8 +240,8 @@ def __init__(self, if self.reduce_scatter: valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" - assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" - assert self.postscale_gradients, "pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" + assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" + assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" # param flattened by groups self.bit16_groups = [] @@ -490,6 +490,7 @@ def __init__(self, self.reset_partition_gradient_structures() # creates backward hooks for gradient partitioning + self._grad_acc_hooks = [] if self.partition_gradients or self.overlap_comm: self.create_reduce_and_remove_grad_hooks() @@ -522,6 +523,11 @@ def __init__(self, self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() + def destroy(self): + for hook in self._grad_acc_hooks: + hook.remove() + self.print_rank_0("Removed grad acc hooks") + def _enable_universal_checkpoint(self): for lp_param_group in self.bit16_groups: enable_universal_checkpoint(param_list=lp_param_group) @@ -864,7 +870,7 @@ def wrapper(param, i): def reduce_partition_and_remove_grads(*notneeded): self.reduce_ready_partitions_and_remove_grads(param, i) - grad_acc.register_hook(reduce_partition_and_remove_grads) + self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) self.grad_accs.append(grad_acc) wrapper(param, i) diff --git a/deepspeed/utils/debug.py b/deepspeed/utils/debug.py index 02295fa98011..cebea56255d9 100644 --- a/deepspeed/utils/debug.py +++ b/deepspeed/utils/debug.py @@ -11,6 +11,13 @@ param_names = {} +def debug_clear_module_and_param_names(): + global module_names + global param_names + module_names = {} + param_names = {} + + def debug_extract_module_and_param_names(model): # extract the fully qualified names as soon as the model is acquired global module_names diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index c98caae31534..49b846633d6e 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -248,6 +248,11 @@ def _zero2_merge_frozen_params(state_dict, zero_model_states): print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): param_shapes = zero_model_states[0].param_shapes @@ -287,7 +292,7 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero avail_numel = full_single_fp32_vector.numel() for name, shape in shapes.items(): - unpartitioned_numel = shape.numel() + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) total_numel += unpartitioned_numel total_params += 1 diff --git a/docs/_tutorials/advanced-install.md b/docs/_tutorials/advanced-install.md index 10197e62f681..d27ecf021421 100755 --- a/docs/_tutorials/advanced-install.md +++ b/docs/_tutorials/advanced-install.md @@ -27,7 +27,7 @@ ds_report ## Pre-install DeepSpeed Ops -**Note:** [PyTorch](https://pytorch.org/) must be installed _before_ pre-compiling any DeepSpeed c++/cuda ops. However, this is not required if using the default mode of JIT compilation of ops. +**Note:** [PyTorch](https://pytorch.org/) must be installed _before_ pre-compiling any DeepSpeed C++/CUDA ops. However, this is not required if using the default mode of JIT compilation of ops. {: .notice--info} Sometimes we have found it useful to pre-install either some or all DeepSpeed @@ -56,22 +56,22 @@ DS_BUILD_FUSED_LAMB=1 pip install deepspeed ``` Available `DS_BUILD` options include: -* `DS_BUILD_OPS` toggles all ops -* `DS_BUILD_AIO` builds asynchronous (NVMe) I/O op -* `DS_BUILD_CCL_COMM` builds the communication collective libs -* `DS_BUILD_CPU_ADAM` builds the CPUAdam op -* `DS_BUILD_CPU_LION` builds the CPULion op -* `DS_BUILD_EVOFORMER_ATTN` builds the EvoformerAttn op (from [Alphafold](https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/)) -* `DS_BUILD_FUSED_ADAM` builds the FusedAdam op (from [apex](https://github.com/NVIDIA/apex)) -* `DS_BUILD_FUSED_LION` builds the FusedLion op -* `DS_BUILD_CPU_ADAGRAD` builds the CPUAdagrad op -* `DS_BUILD_FUSED_LAMB` builds the FusedLamb op -* `DS_BUILD_QUANTIZER` builds the quantizer op -* `DS_BUILD_RANDOM_LTD` builds the random ltd op -* `DS_BUILD_SPARSE_ATTN` builds the sparse attention op -* `DS_BUILD_TRANSFORMER` builds the transformer op -* `DS_BUILD_TRANSFORMER_INFERENCE` builds the transformer-inference op -* `DS_BUILD_STOCHASTIC_TRANSFORMER` builds the stochastic transformer op +* `DS_BUILD_OPS` toggles all ops. +* `DS_BUILD_AIO` builds asynchronous (NVMe) I/O op. +* `DS_BUILD_CCL_COMM` builds the communication collective libs. +* `DS_BUILD_CPU_ADAM` builds the CPUAdam op. +* `DS_BUILD_CPU_LION` builds the CPULion op. +* `DS_BUILD_EVOFORMER_ATTN` builds the EvoformerAttn op (from [Alphafold](https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/)). +* `DS_BUILD_FUSED_ADAM` builds the FusedAdam op (from [apex](https://github.com/NVIDIA/apex)). +* `DS_BUILD_FUSED_LION` builds the FusedLion op. +* `DS_BUILD_CPU_ADAGRAD` builds the CPUAdagrad op. +* `DS_BUILD_FUSED_LAMB` builds the FusedLamb op. +* `DS_BUILD_QUANTIZER` builds the quantizer op. +* `DS_BUILD_RANDOM_LTD` builds the random ltd op. +* `DS_BUILD_SPARSE_ATTN` builds the sparse attention op. +* `DS_BUILD_TRANSFORMER` builds the transformer op. +* `DS_BUILD_TRANSFORMER_INFERENCE` builds the transformer-inference op. +* `DS_BUILD_STOCHASTIC_TRANSFORMER` builds the stochastic transformer op. To speed up the build-all process, you can parallelize the compilation process with: @@ -81,7 +81,7 @@ DS_BUILD_OPS=1 pip install deepspeed --global-option="build_ext" --global-option This should complete the full build 2-3 times faster. You can adjust `-j` to specify how many cpu-cores are to be used during the build. In the example it is set to 8 cores. -You can also build a binary wheel and install it on multiple machines that have the same type of GPUs and the same software environment (CUDA toolkit, pytorch, python, etc.) +You can also build a binary wheel and install it on multiple machines that have the same type of GPUs and the same software environment (CUDA toolkit, PyTorch, Python, etc.) ```bash DS_BUILD_OPS=1 python setup.py build_ext -j8 bdist_wheel @@ -107,7 +107,7 @@ pip install . For installs spanning multiple nodes we find it useful to install DeepSpeed using the [install.sh](https://github.com/microsoft/DeepSpeed/blob/master/install.sh) -script in the repo. This will build a python wheel locally and copy it to all +script in the repo. This will build a Python wheel locally and copy it to all the nodes listed in your hostfile (either given via `--hostfile`, or defaults to `/job/hostfile`). @@ -118,7 +118,7 @@ extensions will be loaded form that directory. If you use multiple virtual environments this could be a problem, since by default there is only one `torch_extensions` directory, but different virtual environments may use different setups (e.g., different -python or cuda versions) and then the loading of a CUDA extension built by another environment will +Python or CUDA versions) and then the loading of a CUDA extension built by another environment will fail. Therefore, if you need to you can override the default location with the help of the `TORCH_EXTENSIONS_DIR` environment variable. So in each virtual environment you can point it to a unique directory and DeepSpeed will use it to save and load CUDA extensions. @@ -146,9 +146,9 @@ If you're getting the following error: ``` RuntimeError: CUDA error: no kernel image is available for execution on the device ``` -when running deepspeed, that means that the cuda extensions weren't built for the card you're trying to use it for. +when running deepspeed, that means that the CUDA extensions weren't built for the card you're trying to use it for. -When building from source deepspeed will try to support a wide range of architectures, but under jit-mode it'll only +When building from source DeepSpeed will try to support a wide range of architectures, but under jit-mode it'll only support the architectures visible at the time of building. You can build specifically for a desired range of architectures by setting a `TORCH_CUDA_ARCH_LIST` env variable: @@ -159,9 +159,9 @@ TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ... It will also make the build faster when you only build for a few architectures. -This is also recommended to ensure your exact architecture is used. Due to a variety of technical reasons, a distributed pytorch binary isn't built to fully support all architectures, skipping binary compatible ones, at a potential cost of underutilizing your full card's compute capabilities. To see which architectures get included during the deepspeed build from source - save the log and grep for `-gencode` arguments. +This is also recommended to ensure your exact architecture is used. Due to a variety of technical reasons, a distributed PyTorch binary isn't built to fully support all architectures, skipping binary compatible ones, at a potential cost of underutilizing your full card's compute capabilities. To see which architectures get included during the DeepSpeed build from source - save the log and grep for `-gencode` arguments. -The full list of nvidia GPUs and their compute capabilities can be found [here](https://developer.nvidia.com/cuda-gpus). +The full list of Nvidia GPUs and their compute capabilities can be found [here](https://developer.nvidia.com/cuda-gpus). ## CUDA version mismatch @@ -171,7 +171,7 @@ If you're getting the following error: Exception: >- DeepSpeed Op Builder: Installed CUDA version {VERSION} does not match the version torch was compiled with {VERSION}, unable to compile cuda/cpp extensions without a matching cuda version. ``` You have a misaligned version of CUDA installed compared to the version of CUDA -used to compile torch. A mismatch in the major version is likely to result in +used to compile Torch. A mismatch in the major version is likely to result in errors or unexpected behavior. The easiest fix for this error is changing the CUDA version installed (check diff --git a/op_builder/ragged_ops.py b/op_builder/ragged_ops.py index 13d71b476b5a..8cb372e96c37 100644 --- a/op_builder/ragged_ops.py +++ b/op_builder/ragged_ops.py @@ -73,8 +73,8 @@ def sources(self): "inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp", "inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu", "inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.cpp", - "inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp", - "inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu", + "inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp", + "inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu", ] prefix = self.get_prefix() @@ -101,12 +101,13 @@ def include_paths(self): 'inference/v2/kernels/ragged_ops/atom_builder', 'inference/v2/kernels/ragged_ops/blocked_flash', 'inference/v2/kernels/ragged_ops/embed', + 'inference/v2/kernels/ragged_ops/includes', 'inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary', 'inference/v2/kernels/ragged_ops/logits_gather', 'inference/v2/kernels/ragged_ops/moe_gather', 'inference/v2/kernels/ragged_ops/moe_scatter', 'inference/v2/kernels/ragged_ops/ragged_helpers', - 'inference/v2/kernels/ragged_ops/top_1_gating', + 'inference/v2/kernels/ragged_ops/top_k_gating', ] prefix = self.get_prefix() diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py index 90fe26eb4490..5f1ef930952c 100644 --- a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_kv_copy.py @@ -13,11 +13,11 @@ @pytest.mark.inference_v2_ops @pytest.mark.parametrize("n_tokens, history_size", [(1, 0), (17, 0), (33, 8), (63, 1)]) -def test_single_sequence_single_block(n_tokens: int, history_size: int): +@pytest.mark.parametrize("head_size", [64, 80, 128]) +def test_single_sequence_single_block(n_tokens: int, history_size: int, head_size: int): """ Validate that the copy works correctly """ - head_size = 64 n_heads_q = 16 n_heads_kv = 16 kv_block_size = 64 @@ -46,11 +46,11 @@ def test_single_sequence_single_block(n_tokens: int, history_size: int): @pytest.mark.inference_v2_ops @pytest.mark.parametrize("n_tokens, history_size", [(128, 0), (177, 0), (169, 8), (117, 88)]) -def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int): +@pytest.mark.parametrize("head_size", [64, 80, 128]) +def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int, head_size: int): """ Validate that the copy works correctly """ - head_size = 64 n_heads_q = 16 n_heads_kv = 16 kv_block_size = 64 @@ -78,8 +78,8 @@ def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int): @pytest.mark.inference_v2_ops -def test_multi_sequence() -> None: - head_size = 64 +@pytest.mark.parametrize("head_size", [64, 80, 128]) +def test_multi_sequence(head_size: int) -> None: n_heads_q = 16 n_heads_kv = 16 kv_block_size = 64 diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py index 618c2d3b87ec..156be9929d92 100644 --- a/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_blocked_rotary_emb.py @@ -21,13 +21,19 @@ """ -def rotary_pos_embs(q: torch.Tensor, k: torch.Tensor, seq_descs: List[DSSequenceDescriptor], batch: RaggedBatchWrapper, - head_size: int): +def rotary_pos_embs(q: torch.Tensor, + k: torch.Tensor, + seq_descs: List[DSSequenceDescriptor], + batch: RaggedBatchWrapper, + head_size: int, + rotary_dim: int = -1) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + rotary_dim = rotary_dim if rotary_dim >= 0 else head_size def make_cos_sin_emb(seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]: t = torch.arange(seq_len, dtype=torch.float32, device=get_accelerator().current_device()) inv_freq = (1.0 / (10000.0**(torch.arange( - 0, head_size, 2, dtype=torch.float32, device=get_accelerator().current_device()) / head_size))).half() + 0, rotary_dim, 2, dtype=torch.float32, device=get_accelerator().current_device()) / rotary_dim))).half() freqs = torch.einsum("i,j->ij", t, inv_freq) emb = torch.cat((freqs, freqs), dim=-1) @@ -57,11 +63,17 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor: k_src = k[start_idx:start_idx + n_tokens].reshape(n_tokens, n_heads_kv, head_size).float() freq_start_offset = seq_desc.seen_tokens + q_src_rot = q_src[:, :, :rotary_dim] + k_src_rot = k_src[:, :, :rotary_dim] + cos_chunk = cos[range(freq_start_offset, freq_start_offset + n_tokens)] sin_chunk = sin[range(freq_start_offset, freq_start_offset + n_tokens)] - q_emb = q_src * cos_chunk + rotate_half(q_src) * sin_chunk - k_emb = k_src * cos_chunk + rotate_half(k_src) * sin_chunk + q_rot = q_src_rot * cos_chunk + rotate_half(q_src_rot) * sin_chunk + k_rot = k_src_rot * cos_chunk + rotate_half(k_src_rot) * sin_chunk + + q_emb = torch.cat((q_rot, q_src[:, :, rotary_dim:]), dim=-1) + k_emb = torch.cat((k_rot, k_src[:, :, rotary_dim:]), dim=-1) q_out[start_idx:start_idx + n_tokens] = q_emb.reshape(n_tokens, n_heads_q * head_size).to(q_out.dtype) k_out[start_idx:start_idx + n_tokens] = k_emb.reshape(n_tokens, n_heads_kv * head_size).to(k_out.dtype) @@ -72,11 +84,11 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor: @pytest.mark.inference_v2_ops @pytest.mark.parametrize("n_tokens, history_size", [(1, 0), (17, 0), (33, 15), (1, 63)]) @pytest.mark.parametrize("trained_emb", [False, True]) -def test_single_sequence_single_block(n_tokens: int, history_size: int, trained_emb: bool): +@pytest.mark.parametrize("head_size", [64, 80]) +def test_single_sequence_single_block(n_tokens: int, history_size: int, trained_emb: bool, head_size: int): """ Validate that the copy works correctly """ - head_size = 64 n_heads_q = 16 n_heads_kv = 16 kv_block_size = 64 @@ -106,7 +118,7 @@ def test_single_sequence_single_block(n_tokens: int, history_size: int, trained_ copy_impl = BlockedTrainedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) copy_impl(kv_cache, qkv, batch, freqs) else: - copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16, head_size, 10000.0) copy_impl(kv_cache, qkv, batch) assert allclose(qkv[:, :head_size * n_heads_q], q_ref) @@ -116,11 +128,11 @@ def test_single_sequence_single_block(n_tokens: int, history_size: int, trained_ @pytest.mark.inference_v2_ops @pytest.mark.parametrize("n_tokens, history_size", [(128, 0), (177, 0), (169, 8), (117, 88)]) @pytest.mark.parametrize("trained_emb", [False, True]) -def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int, trained_emb: bool): +@pytest.mark.parametrize("head_size", [64, 80]) +def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int, trained_emb: bool, head_size: int): """ Validate that the copy works correctly """ - head_size = 64 n_heads_q = 16 n_heads_kv = 16 kv_block_size = 64 @@ -150,7 +162,7 @@ def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int, train copy_impl = BlockedTrainedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) copy_impl(kv_cache, qkv, batch, freqs) else: - copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16, head_size, 10000.0) copy_impl(kv_cache, qkv, batch) assert allclose(qkv[:, :head_size * n_heads_q], q_ref) @@ -159,8 +171,8 @@ def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int, train @pytest.mark.inference_v2_ops @pytest.mark.parametrize("trained_emb", [False, True]) -def test_multi_sequences(trained_emb: bool) -> None: - head_size = 64 +@pytest.mark.parametrize("head_size", [64, 80]) +def test_multi_sequences(trained_emb: bool, head_size: int) -> None: n_heads_q = 16 n_heads_kv = 16 kv_block_size = 64 @@ -196,8 +208,51 @@ def test_multi_sequences(trained_emb: bool) -> None: copy_impl = BlockedTrainedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) copy_impl(kv_cache, qkv, batch, freqs) else: - copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16, head_size, 10000.0) copy_impl(kv_cache, qkv, batch) assert allclose(qkv[:, :head_size * n_heads_q], q_ref) validate_kv_cache(kv_cache, k, v, seq_descs, batch, exact=False) + + +@pytest.mark.inference_v2_ops +def test_rotary_dim() -> None: + trained_emb = False + head_size = 80 + rotary_dim = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch_config = [ + (128, 0), + (177, 0), + (169, 8), + (117, 88), + (1, 293), + (1, 733), + (1, 33), + ] + + batch, state_manager, seq_descs = build_batch_and_manager(batch_config, head_size, n_heads_kv, kv_block_size) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + qkv_ref = qkv.clone() + + q = qkv_ref[:, :head_size * n_heads_q] + k = qkv_ref[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv_ref[:, head_size * (n_heads_q + n_heads_kv):] + + q_ref, k, freqs = rotary_pos_embs(q, k, seq_descs, batch, head_size, rotary_dim=rotary_dim) + freqs = freqs.half() + + kv_cache = state_manager.get_cache(0) + + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16, rotary_dim, 10000.0) + copy_impl(kv_cache, qkv, batch) + + assert allclose(qkv[:, :head_size * n_heads_q], q_ref) + validate_kv_cache(kv_cache, k, v, seq_descs, batch, exact=False) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py index 5fa375b49c19..3907fc3e3a4b 100644 --- a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py @@ -11,18 +11,28 @@ from deepspeed.inference.v2.kernels.ragged_ops import ( MoEGather, MoEScatter, - RaggedTop1Gating, + RaggedTopKGating, ) from .ragged_testing_utils import build_simple_batch """ -For simplicity's sake, these tests do rely on ``RaggedTop1Gating`` and +For simplicity's sake, these tests do rely on ``RaggedTopKGating`` and ``MoEScatter`` to produce correct inputs. If either of these kernels is broken these tests will fail, so double check the unit test results there before debugging here. """ +TEST_CASES = [ + # (n_tokens, n_experts, n_top_k) + (13, 64, 1), + (278, 64, 1), + (1977, 64, 1), + (13, 8, 2), + (278, 8, 2), + (1977, 8, 2), +] -def build_inputs(n_tokens, n_experts, do_padding): + +def build_inputs(n_tokens: int, n_experts: int, n_top_k: int, do_padding: bool): assert n_tokens <= 2048, "This test will break if n_tokens > 2048" @@ -39,22 +49,28 @@ def build_inputs(n_tokens, n_experts, do_padding): device=get_accelerator().current_device()).repeat_interleave(4096, dim=0).reshape( batch.tensor_toks, 4096).contiguous() - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) # Gating outputs expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((batch.tensor_toks, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((batch.tensor_toks, ), + scores = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) # Scatter outputs - moe_input = torch.empty((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + moe_input = torch.empty((batch.tensor_toks * n_top_k, 4096), + dtype=torch.float16, + device=get_accelerator().current_device()) expert_cumsum = torch.empty((n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) - mapped_slots = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + mapped_slots = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) scatter = MoEScatter(DtypeEnum.fp16, 4096) scatter(moe_input, expert_cumsum, mapped_slots, hidden_states, expert_counts, expert_assignment, expert_offset) @@ -63,11 +79,12 @@ def build_inputs(n_tokens, n_experts, do_padding): @pytest.mark.inference_v2_ops -@pytest.mark.parametrize("n_tokens, n_experts", [(13, 64), (278, 64), (1977, 64)]) -@pytest.mark.parametrize("do_padding", [True, False]) -def test_moe_gather(n_tokens, n_experts, do_padding): +@pytest.mark.parametrize("n_tokens, n_experts, n_top_k", TEST_CASES) +@pytest.mark.parametrize("do_padding", [False]) +def test_moe_gather(n_tokens: int, n_experts: int, n_top_k: int, do_padding: bool): + get_accelerator().manual_seed(0xC0FFEE) - batch, moe_input, scores, mapped_slots, expert_counts = build_inputs(n_tokens, n_experts, do_padding) + batch, moe_input, scores, mapped_slots, expert_counts = build_inputs(n_tokens, n_experts, n_top_k, do_padding) output = torch.randn((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) @@ -75,9 +92,31 @@ def test_moe_gather(n_tokens, n_experts, do_padding): gather(output, moe_input, scores, mapped_slots, expert_counts) for token_idx in range(n_tokens): + effective_score = scores[token_idx].sum().item() assert torch.equal( output[token_idx], torch.full((4096, ), - token_idx * scores[token_idx], + token_idx * effective_score, dtype=torch.float16, device=get_accelerator().current_device())) + + +@pytest.mark.inference_v2_ops +def test_moe_gather_normalize_scales(): + get_accelerator().manual_seed(0xC0FFEE) + + n_tokens = 72 + n_experts = 8 + n_top_k = 2 + do_padding = False + + batch, moe_input, scores, mapped_slots, expert_counts = build_inputs(n_tokens, n_experts, n_top_k, do_padding) + output = torch.randn((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + + gather = MoEGather(DtypeEnum.fp16, 4096, normalize_scores=True) + gather(output, moe_input, scores, mapped_slots, expert_counts) + + for token_idx in range(n_tokens): + assert torch.equal( + output[token_idx], + torch.full((4096, ), token_idx, dtype=torch.float16, device=get_accelerator().current_device())) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py index 4ca051410c1c..aae459f06a6f 100644 --- a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py @@ -8,19 +8,28 @@ from deepspeed.accelerator import get_accelerator from deepspeed.inference.v2.inference_utils import DtypeEnum -from deepspeed.inference.v2.kernels.ragged_ops import MoEScatter, RaggedTop1Gating +from deepspeed.inference.v2.kernels.ragged_ops import MoEScatter, RaggedTopKGating from .ragged_testing_utils import build_simple_batch """ -For simplicity's sake, these tests do rely on ``RaggedTop1Gating`` to produce correct -inputs. If ``RaggedTop1Gating`` is broken, these tests will fail, so double check +For simplicity's sake, these tests do rely on ``RaggedTopKGating`` to produce correct +inputs. If ``RaggedTopKGating`` is broken, these tests will fail, so double check the unit test results there before debugging here. """ +TEST_CONFIGS = [ + (13, 64, 1), + (278, 64, 1), + (1977, 64, 1), + (13, 8, 2), + (278, 8, 2), + (1977, 8, 2), +] + @pytest.mark.inference_v2_ops -@pytest.mark.parametrize("n_tokens, n_experts", [(13, 64), (278, 64), (1977, 64)]) -@pytest.mark.parametrize("do_padding", [True, False]) -def test_moe_scatter(n_tokens, n_experts, do_padding): +@pytest.mark.parametrize("n_tokens, n_experts, n_top_k", TEST_CONFIGS) +@pytest.mark.parametrize("do_padding", [False, True]) +def test_moe_scatter(n_tokens, n_experts, n_top_k, do_padding): # Sequence composition shouldn't matter here batch = build_simple_batch([n_tokens], padding=do_padding) @@ -35,40 +44,52 @@ def test_moe_scatter(n_tokens, n_experts, do_padding): device=get_accelerator().current_device()).repeat_interleave(4096, dim=0).reshape( batch.tensor_toks, 4096).contiguous() - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) # Gating outputs expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((batch.tensor_toks, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((batch.tensor_toks, ), + scores = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) # Scatter outputs - moe_input = torch.empty((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + moe_input = torch.empty((batch.tensor_toks * n_top_k, 4096), + dtype=torch.float16, + device=get_accelerator().current_device()) expert_cumsum = torch.empty((n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) - mapped_slots = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + mapped_slots = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) scatter = MoEScatter(DtypeEnum.fp16, 4096) scatter(moe_input, expert_cumsum, mapped_slots, hidden_states, expert_counts, expert_assignment, expert_offset) + get_accelerator().synchronize() assert torch.equal(expert_cumsum, torch.cumsum(expert_counts, dim=0).to(torch.int64)) + if not do_padding: + assert torch.unique(mapped_slots).size(0) == n_top_k * n_tokens + for token_idx in range(batch.tensor_toks): if token_idx < n_tokens: - expert_idx = expert_assignment[token_idx].item() - if expert_idx == 0: - expert_cumsum_val = 0 - else: - expert_cumsum_val = expert_cumsum[expert_idx - 1] - offset = expert_offset[token_idx] - total_offset = offset + expert_cumsum_val - - assert total_offset == mapped_slots[token_idx].item() - assert torch.equal(moe_input[total_offset], hidden_states[token_idx]) + for k in range(n_top_k): + expert_idx = expert_assignment[token_idx][k].item() + if expert_idx == 0: + expert_cumsum_val = 0 + else: + expert_cumsum_val = expert_cumsum[expert_idx - 1] + offset = expert_offset[token_idx][k] + total_offset = offset + expert_cumsum_val + + assert total_offset == mapped_slots[token_idx][k].item() + assert torch.equal(moe_input[total_offset], hidden_states[token_idx]) else: - assert mapped_slots[token_idx].item() == -1 + for k in range(n_top_k): + assert mapped_slots[token_idx][k].item() == -1 - assert expert_cumsum[-1] == n_tokens + assert expert_cumsum[-1] == n_tokens * n_top_k diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_top_1_gating.py b/tests/unit/inference/v2/kernels/ragged_ops/test_top_k_gating.py similarity index 51% rename from tests/unit/inference/v2/kernels/ragged_ops/test_top_1_gating.py rename to tests/unit/inference/v2/kernels/ragged_ops/test_top_k_gating.py index 6ff2508bf320..5fa0c8a079f0 100644 --- a/tests/unit/inference/v2/kernels/ragged_ops/test_top_1_gating.py +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_top_k_gating.py @@ -9,9 +9,52 @@ from deepspeed.accelerator import get_accelerator from deepspeed.inference.v2.inference_utils import DtypeEnum -from deepspeed.inference.v2.kernels.ragged_ops import RaggedTop1Gating +from deepspeed.inference.v2.kernels.ragged_ops import RaggedTopKGating from .ragged_testing_utils import build_simple_batch -from ....v2.inference_test_utils import allclose +from ...inference_test_utils import allclose + + +def _top_k_gating_testing_helper(n_tokens: int, n_experts: int, n_top_k: int, seed: int = 0xC0FFEE) -> None: + + torch.manual_seed(seed) + logits = torch.randn((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + gate = RaggedTopKGating(DtypeEnum.fp16) + + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + ref_weights = F.softmax(logits, dim=-1, dtype=torch.float32) + ref_scores, ref_indices = torch.topk(ref_weights, n_top_k, dim=-1) + + assert allclose(scores, ref_scores), f"expected {ref_scores}, got {scores}" + assert torch.equal(expert_assignment, + ref_indices.to(torch.int32)), f"expected {ref_indices}, got {expert_assignment}" + assert expert_counts.sum( + ) == n_tokens * n_top_k, f"expected {n_tokens * n_top_k} tokens, got {expert_counts.sum()}" + + # Ensure that the expert offsets are unique + for i in range(n_experts): + expert_idxs = torch.where(expert_assignment == i, expert_offset, 0) + if expert_counts[i] > 0: + assert expert_idxs.unique().shape[0] == expert_counts[ + i], f"expected {expert_counts[i]} unique offsets, got {expert_idxs.unique().shape[0]}" + assert expert_idxs.max( + ) == expert_counts[i] - 1, f"expected max offset {expert_counts[i] - 1}, got {expert_idxs.max()}" + else: + # Should have all 0's so one unique value + assert expert_idxs.unique().shape[0] == 1 + assert expert_idxs.max() == 0 + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('n_tokens', [1, 17, 32, 89, 433]) +def test_top_2_e_8_gating(n_tokens: int) -> None: + _top_k_gating_testing_helper(n_tokens=n_tokens, n_experts=8, n_top_k=2) def _test_single_mapping_helper(n_tokens: int, @@ -19,6 +62,8 @@ def _test_single_mapping_helper(n_tokens: int, assigned_expert: int, logit_fill: float = 0.0, match_fill: float = 1.0) -> None: + + n_top_k = 1 logits = torch.full((n_tokens, n_experts), logit_fill, dtype=torch.float16, @@ -26,12 +71,12 @@ def _test_single_mapping_helper(n_tokens: int, logits[:, assigned_expert] = match_fill - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) batch = build_simple_batch([n_tokens], padding=False) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) @@ -39,7 +84,7 @@ def _test_single_mapping_helper(n_tokens: int, assert expert_counts[assigned_expert] == n_tokens assert torch.all(expert_assignment == assigned_expert) assert torch.unique(expert_offset).shape[0] == n_tokens - assert allclose(scores, F.softmax(logits.float(), dim=1)[:, assigned_expert]) + assert allclose(scores, F.softmax(logits.float(), dim=1)[:, assigned_expert].reshape(-1, n_top_k)) @pytest.mark.inference_v2_ops @@ -72,6 +117,7 @@ def test_determinism(): n_tokens = 512 n_experts = 64 + n_top_k = 1 logits = torch.zeros((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) batch = build_simple_batch([n_tokens], padding=False) @@ -79,13 +125,15 @@ def test_determinism(): logits[:, 19] = 1.0 logits[:, 26] = 1.0 - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) for _ in range(1024): expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) batch = build_simple_batch([n_tokens], padding=False) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) @@ -94,7 +142,7 @@ def test_determinism(): assert expert_counts[26] == 0 assert torch.all(expert_assignment == 19) assert torch.unique(expert_offset).shape[0] == n_tokens - assert allclose(scores, F.softmax(logits.float(), dim=1)[:, 19]) + assert allclose(scores, F.softmax(logits.float(), dim=1)[:, 19].reshape(-1, 1)) @pytest.mark.inference_v2_ops @@ -105,16 +153,19 @@ def test_score_accuracy(n_tokens: int, n_experts: int) -> None: """ logits = torch.randn((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) batch = build_simple_batch([n_tokens], padding=False) + n_top_k = 1 - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) ref_scores = F.softmax(logits.float(), dim=1).max(dim=1).values + ref_scores = ref_scores.reshape(-1, 1) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + assert allclose(scores, ref_scores) assert expert_counts.sum() == n_tokens diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py b/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py index 260236562ee9..06ff9047d648 100644 --- a/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py +++ b/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py @@ -26,7 +26,7 @@ def __init__(self, experts_per_rank: int) -> None: self._num_experts = experts_per_rank @property - def num_experts(self) -> int: + def n_experts(self) -> int: return self._num_experts @on_device diff --git a/tests/unit/inference/v2/modules/test_blocked_attn.py b/tests/unit/inference/v2/modules/test_blocked_attn.py index 215ad64636b1..6556aa460a44 100644 --- a/tests/unit/inference/v2/modules/test_blocked_attn.py +++ b/tests/unit/inference/v2/modules/test_blocked_attn.py @@ -12,7 +12,7 @@ from deepspeed.accelerator import get_accelerator from deepspeed.inference.v2.modules import ConfigBundle -from deepspeed.inference.v2.modules.configs import DSSelfAttentionConfig, PositionalEmbeddingType +from deepspeed.inference.v2.modules.configs import DSSelfAttentionConfig, PositionalEmbeddingType, RotateHalfConfig from deepspeed.inference.v2.modules.interfaces import DSSelfAttentionRegistry, DSSelfAttentionBase from ..kernels.ragged_ops.ragged_testing_utils import build_batch_and_manager @@ -37,13 +37,10 @@ def _blocked_flash_testing_helper(head_size: int, """ if trained_freqs is None: embed_type = PositionalEmbeddingType.none - embed_args = {} + embed_args = None else: embed_type = PositionalEmbeddingType.rotate_half - if trained_freqs: - embed_args = {'trained_freqs': True} - else: - embed_args = {'trained_freqs': False} + embed_args = RotateHalfConfig(use_trained_freqs=trained_freqs) attn_config = DSSelfAttentionConfig(max_tokens=2048, n_heads_q=n_heads_q, @@ -51,7 +48,7 @@ def _blocked_flash_testing_helper(head_size: int, head_size=head_size, max_sequences=32, positional_embedding_type=embed_type, - positional_embedding_args=embed_args) + positional_embedding_config=embed_args) config = ConfigBundle(name='dense_blocked_attention', config=attn_config) attn_module: DSSelfAttentionBase = DSSelfAttentionRegistry.instantiate_config(config) diff --git a/tests/unit/inference/v2/modules/test_cutlass_moe.py b/tests/unit/inference/v2/modules/test_cutlass_moe.py index e21170c9ed8f..b14ba127c6be 100644 --- a/tests/unit/inference/v2/modules/test_cutlass_moe.py +++ b/tests/unit/inference/v2/modules/test_cutlass_moe.py @@ -212,3 +212,117 @@ def test_in_out_channels(in_channels: int, out_channels: int) -> None: dtype=DtypeEnum.fp16, activation_type=ActivationType.IDENTITY, use_bias=True) + + +def _mixtral_moe_baseline(hidden_states: torch.Tensor, + gate_weight: torch.Tensor, + mlp_w1: torch.Tensor, + mlp_w2: torch.Tensor, + mlp_w3: torch.Tensor, + force_float: bool = False) -> torch.Tensor: + """ + Baseline implementation for mixtral MoE module. + + Based on transformers implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py + """ + output_dtype = hidden_states.dtype + if force_float: + hidden_states = hidden_states.float() + gate_weight = gate_weight.float() + mlp_w1 = mlp_w1.float() + mlp_w2 = mlp_w2.float() + mlp_w3 = mlp_w3.float() + + router_logits = torch.nn.functional.linear(hidden_states, gate_weight) + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, selected_experts = routing_weights.topk(k=2, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + # NOTE(cmikeh2): This is a difference implementation, ours will preserve the original scale + # as float32 and perform in-kernel fused FP16->FP32->FP16 conversion. + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros_like(hidden_states) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=gate_weight.shape[0]).permute(2, 1, 0) + get_accelerator().synchronize() + + for expert_idx in range(gate_weight.shape[0]): + exp_mlp_w1 = mlp_w1[expert_idx] + exp_mlp_w2 = mlp_w2[expert_idx] + exp_mlp_w3 = mlp_w3[expert_idx] + + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + current_state = hidden_states[top_x_list] + + linear = torch.nn.functional.linear + intermediate = torch.nn.functional.silu(linear(current_state, exp_mlp_w1)) * linear(current_state, exp_mlp_w3) + output = linear(intermediate, exp_mlp_w2) * routing_weights[top_x_list, idx_list].unsqueeze(-1) + final_hidden_states.index_add_(0, top_x, output.to(final_hidden_states.dtype)) + + return final_hidden_states.to(output_dtype) + + +@pytest.mark.inference_v2_ops +def test_mixtral_moe_config(): + + experts = 8 + n_top_k = 2 + in_channels = 4096 + intermediate_dim = 2048 + dtype = DtypeEnum.bf16 + + # Parameters + gate_weight = torch.randn( + (experts, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + mlp_w1 = torch.randn( + (experts, intermediate_dim, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_w3 = torch.randn( + (experts, intermediate_dim, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_w2 = torch.randn( + (experts, in_channels, intermediate_dim), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + n_tokens = 256 + hidden_states = torch.randn( + (n_tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + baseline = _mixtral_moe_baseline(hidden_states, gate_weight, mlp_w1, mlp_w2, mlp_w3) + + mlp_w13_fused = torch.cat([mlp_w1, mlp_w3], dim=-1).reshape(experts, 2 * intermediate_dim, in_channels) + + config = DSMoEConfig(max_tokens=4096, + model_dim=in_channels, + intermediate_features=intermediate_dim, + n_experts=experts, + activation=ActivationType.SiGLU, + input_dtype=dtype, + output_dtype=dtype, + top_k=n_top_k, + normalize_scores=True) + + implementation_config = {"weight_dtype": DtypeEnum(dtype)} + + bundle = ConfigBundle(name='cutlass_multi_gemm_moe', config=config, implementation_config=implementation_config) + moe_module = DSMoERegistry.instantiate_config(bundle) + + batch = build_simple_batch([n_tokens]) + + gate_ds = moe_module.transform_gate_param(gate_weight) + mlp_w1_ds = moe_module.transform_moe_mlp_1_param(mlp_w13_fused) + mlp_w2_ds = moe_module.transform_moe_mlp_2_param(mlp_w2) + + output = moe_module(hidden_states, batch, gate_ds, mlp_w1_ds, mlp_w2_ds) + + # NOTE(cmikeh2): These are higher than the other tests for reasons that aren't quite + # clear to me. My best guess is that the SiGLU activation is causing larger numerical + # divergence. The thresholds chosen here is based on the observed error between the + # float and bfloat16 reference implementations. + assert allclose(output, baseline.to(dtype.value), tolerances=(5e-2, 5e-2)) diff --git a/tests/unit/runtime/test_ds_config_dict.py b/tests/unit/runtime/test_ds_config_dict.py index 6cd01644fad5..880282bb7e57 100644 --- a/tests/unit/runtime/test_ds_config_dict.py +++ b/tests/unit/runtime/test_ds_config_dict.py @@ -90,7 +90,7 @@ class TestBatchConfig(DistributedTest): def test(self, num_ranks, batch, micro_batch, gas, success): assert dist.get_world_size() == num_ranks, \ - 'The test assumes a world size of f{num_ranks}' + f'The test assumes a world size of {num_ranks}' ds_batch_config = get_test_path('ds_batch_config.json') ds_config = DeepSpeedConfig(ds_batch_config) diff --git a/version.txt b/version.txt index dabff2f13810..e2e3067ddc5f 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.12.6 +0.12.7