From d5a7c1e0b494fbd0958bf8274bde0bacb2c16854 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 21 Dec 2023 04:51:36 +0800 Subject: [PATCH 01/10] Capture short kernel sequences to graph (#4318) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Motivation:** 1. This is a series of cases where short kernel sequences are launched and executed serially(no dynamic shape), with the launch overhead being much higher than the execution overhead. We can use a graph to solve this problem. Compared to ```multi-tensor-apply```, using graph is more concise and only requires PyTorch as a dependency. 2. Some device software stacks also support lazy-mode PyTorch, enabling full utilization of the compiler to perform graph optimization. However, in lazy mode, operation accumulation time (host time) could become significantly higher compared to device time in such scenario, and devices are usually not well utilized. By using the same API(after adding to accelerator cc @delock ) with cuda graph, this issue could also be resolved. **Change:** We modified three functions, ```update_hp_grads```. Here, we executed the operations for the CPU and GPU separately because the graph is unable to record the execution of CPU operations. Additionally, the data input required by the graph must not have its address modified, or the address modification must be captured by the capture operation(In this case, set ```replay_first_step``` to ```True```). Therefore, we changed ```grad=None``` to ```grad.zero_()```. Similarly, we have also placed some inputs that require fixed addresses in the ```graph_cache``` For ```clip_tensors_by_global_norm```, ```clip_coef``` is a scalar with a non-fixed value, so it needs to be moved to the GPU when using a graph. For ```total_norm = sum ([t. data. float (). norm (norm_type). item () * * norm_type for t in input_tensors])```, ```item () ```, synchronous operation is also not supported by graph. We directly put the ```sum``` and ```* * norm_type``` on the GPU to execute the computation. Other similar scenarios can also use this ```graph_process()```, or a slightly modified version of ```graph_process()``` you can checkout [4abab21](https://github.com/microsoft/DeepSpeed/pull/4318/commits/4abab212c8f5aef1eec4f8abe10b4262bb5a5c8a) and set it to True here to do some benchmarking. https://github.com/microsoft/DeepSpeed/pull/4318/commits/4abab212c8f5aef1eec4f8abe10b4262bb5a5c8a#diff-f8f0b3feb55b0374615405e542c1c3e0f017982b177c46c562bf688532ac935cR42 --------- Co-authored-by: Olatunji Ruwase --- accelerator/abstract_accelerator.py | 13 ++++ accelerator/cpu_accelerator.py | 12 +++- accelerator/cuda_accelerator.py | 11 +++ accelerator/mps_accelerator.py | 11 +++ accelerator/npu_accelerator.py | 11 +++ deepspeed/inference/engine.py | 6 +- .../model_implementations/diffusers/unet.py | 7 +- .../model_implementations/diffusers/vae.py | 19 +++--- .../transformers/clip_encoder.py | 6 +- deepspeed/runtime/bf16_optimizer.py | 49 ++++++++----- deepspeed/runtime/config.py | 5 ++ deepspeed/runtime/constants.py | 12 ++++ deepspeed/runtime/engine.py | 6 +- deepspeed/runtime/utils.py | 68 ++++++++++++++++--- 14 files changed, 189 insertions(+), 47 deletions(-) diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py index 7c1f463009b2..8ba652a8de92 100644 --- a/accelerator/abstract_accelerator.py +++ b/accelerator/abstract_accelerator.py @@ -193,6 +193,19 @@ def communication_backend_name(self): def is_triton_supported(self): ... + # Graph operations + @abc.abstractmethod + def create_graph(self): + ... + + @abc.abstractmethod + def capture_to_graph(self, graph, pool=None, stream=None): + ... + + @abc.abstractmethod + def replay_graph(self, graph): + ... + # Tensor operations @property @abc.abstractmethod diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index 05c4f6c650f0..a02777f5223b 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -198,8 +198,18 @@ def is_fp16_supported(self): def supported_dtypes(self): return [torch.float, torch.bfloat16] - # Tensor operations + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + # Tensor operations @property def BFloat16Tensor(self): return torch.BFloat16Tensor diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 4b94ddb6865c..2d74daecf3df 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -222,6 +222,17 @@ def is_triton_supported(self): else: return False + # Graph operations + def create_graph(self): + return torch.cuda.CUDAGraph() + + def capture_to_graph(self, graph, pool=None, stream=None): + return torch.cuda.graph(graph, pool, stream) + + def replay_graph(self, graph): + graph.replay() + return + # Tensor operations @property diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py index f0d4cac721b5..f6303cf9890f 100644 --- a/accelerator/mps_accelerator.py +++ b/accelerator/mps_accelerator.py @@ -166,6 +166,17 @@ def communication_backend_name(self): def is_triton_supported(self): return False + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + # Tensor operations @property def BFloat16Tensor(self): diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py index 8bfd59cd2ad6..4e20445d9d32 100644 --- a/accelerator/npu_accelerator.py +++ b/accelerator/npu_accelerator.py @@ -176,6 +176,17 @@ def communication_backend_name(self): def is_triton_supported(self): return False + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + # Tensor operations @property diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index ccecc8376ad6..5cdd99ff0b90 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -531,11 +531,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): get_accelerator().current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._cuda_graphs = torch.cuda.CUDAGraph() + self._cuda_graphs = get_accelerator().create_graph() self.static_inputs = inputs self.static_kwargs = kwargs - with torch.cuda.graph(self._cuda_graphs): + with get_accelerator().capture_to_graph(self._cuda_graphs): self.static_output = self.module(*self.static_inputs, **self.static_kwargs) self.cuda_graph_created = True @@ -547,7 +547,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[k].copy_(kwargs[k]) - self._cuda_graphs.replay() + get_accelerator().replay_graph(self._cuda_graphs) return self.static_output def model_times(self): diff --git a/deepspeed/model_implementations/diffusers/unet.py b/deepspeed/model_implementations/diffusers/unet.py index 7da571975958..8d5ddd95437a 100644 --- a/deepspeed/model_implementations/diffusers/unet.py +++ b/deepspeed/model_implementations/diffusers/unet.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from deepspeed.accelerator import get_accelerator from ..features.cuda_graph import CUDAGraph @@ -29,7 +30,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[k].copy_(kwargs[k]) - self._cuda_graphs.replay() + get_accelerator().replay_graph(self._cuda_graphs) return self.static_output def forward(self, *inputs, **kwargs): @@ -53,11 +54,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._cuda_graphs = torch.cuda.CUDAGraph() + self._cuda_graphs = get_accelerator().create_graph() self.static_inputs = inputs self.static_kwargs = kwargs - with torch.cuda.graph(self._cuda_graphs): + with get_accelerator().capture_to_graph(self._cuda_graphs): self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) self.cuda_graph_created = True diff --git a/deepspeed/model_implementations/diffusers/vae.py b/deepspeed/model_implementations/diffusers/vae.py index 05084f1b985a..ce50ade647a8 100644 --- a/deepspeed/model_implementations/diffusers/vae.py +++ b/deepspeed/model_implementations/diffusers/vae.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +from deepspeed.accelerator import get_accelerator from ..features.cuda_graph import CUDAGraph @@ -27,7 +28,7 @@ def _graph_replay_decoder(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_decoder_kwargs[k].copy_(kwargs[k]) - self._decoder_cuda_graph.replay() + get_accelerator().replay_graph(self._decoder_cuda_graph) return self.static_decoder_output def _decode(self, x, return_dict=True, generator=None): @@ -43,11 +44,11 @@ def _create_cuda_graph_decoder(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._decoder_cuda_graph = torch.cuda.CUDAGraph() + self._decoder_cuda_graph = get_accelerator().create_graph() self.static_decoder_inputs = inputs self.static_decoder_kwargs = kwargs - with torch.cuda.graph(self._decoder_cuda_graph): + with get_accelerator().capture_to_graph(self._decoder_cuda_graph): self.static_decoder_output = self._decode(*self.static_decoder_inputs, **self.static_decoder_kwargs) self.decoder_cuda_graph_created = True @@ -70,7 +71,7 @@ def _graph_replay_encoder(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_encoder_kwargs[k].copy_(kwargs[k]) - self._encoder_cuda_graph.replay() + get_accelerator().replay_graph(self._encoder_cuda_graph) return self.static_encoder_output def _encode(self, x, return_dict=True): @@ -86,11 +87,11 @@ def _create_cuda_graph_encoder(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._encoder_cuda_graph = torch.cuda.CUDAGraph() + self._encoder_cuda_graph = get_accelerator().create_graph() self.static_encoder_inputs = inputs self.static_encoder_kwargs = kwargs - with torch.cuda.graph(self._encoder_cuda_graph): + with get_accelerator().capture_to_graph(self._encoder_cuda_graph): self.static_encoder_output = self._encode(*self.static_encoder_inputs, **self.static_encoder_kwargs) self.encoder_cuda_graph_created = True @@ -113,7 +114,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[k].copy_(kwargs[k]) - self._all_cuda_graph.replay() + get_accelerator().replay_graph(self._all_cuda_graph) return self.static_output def forward(self, *inputs, **kwargs): @@ -137,11 +138,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._all_cuda_graph = torch.cuda.CUDAGraph() + self._all_cuda_graph = get_accelerator().create_graph() self.static_inputs = inputs self.static_kwargs = kwargs - with torch.cuda.graph(self._all_cuda_graph): + with get_accelerator().capture_to_graph(self._all_cuda_graph): self.static_output = self._forward(*self.static_inputs, **self.static_kwargs) self.all_cuda_graph_created = True diff --git a/deepspeed/model_implementations/transformers/clip_encoder.py b/deepspeed/model_implementations/transformers/clip_encoder.py index 8d9291896986..848a5b48dcf1 100644 --- a/deepspeed/model_implementations/transformers/clip_encoder.py +++ b/deepspeed/model_implementations/transformers/clip_encoder.py @@ -38,7 +38,7 @@ def _graph_replay(self, *inputs, **kwargs): for k in kwargs: if torch.is_tensor(kwargs[k]): self.static_kwargs[self.iter][k].copy_(kwargs[k]) - self._cuda_graphs[self.iter].replay() + get_accelerator().replay_graph(self._cuda_graphs[self.iter]) return self.static_output[self.iter] def forward(self, *inputs, **kwargs): @@ -63,11 +63,11 @@ def _create_cuda_graph(self, *inputs, **kwargs): torch.cuda.current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs - self._cuda_graphs[self.iter] = torch.cuda.CUDAGraph() + self._cuda_graphs[self.iter] = get_accelerator().create_graph() self.static_inputs[self.iter] = inputs self.static_kwargs[self.iter] = kwargs - with torch.cuda.graph(self._cuda_graphs[self.iter]): + with get_accelerator().capture_to_graph(self._cuda_graphs[self.iter]): self.static_output[self.iter] = self._forward(*self.static_inputs[self.iter], **self.static_kwargs[self.iter]) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 888505279290..0aefd1946c36 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -16,7 +16,7 @@ from deepspeed.git_version_info import version from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim, align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank, - is_model_parallel_parameter, see_memory_usage) + is_model_parallel_parameter, see_memory_usage, graph_process) from deepspeed.utils import link_hp_params, fragment_address from deepspeed.checkpoint import enable_universal_checkpoint @@ -38,7 +38,8 @@ def __init__(self, allgather_bucket_size=5000000000, dp_process_group=None, timers=None, - grad_acc_dtype=None): + grad_acc_dtype=None, + graph_harvesting=False): super().__init__() see_memory_usage('begin bf16_optimizer', force=True) self.timers = timers @@ -81,7 +82,7 @@ def __init__(self, self.fp32_groups_has_gradients = [] self.group_paddings = [] - + self.graph_harvesting = graph_harvesting if self.using_real_optimizer: self._setup_for_real_optimizer() @@ -248,7 +249,8 @@ def step(self, closure=None): all_groups_norm = get_global_norm_of_tensors(input_tensors=self.get_grads_for_norm(), mpu=self.mpu, - norm_type=self.norm_type) + norm_type=self.norm_type, + use_graph=self.graph_harvesting) self._global_grad_norm = all_groups_norm assert all_groups_norm > 0. @@ -256,7 +258,8 @@ def step(self, closure=None): clip_tensors_by_global_norm(input_tensors=self.get_grads_for_norm(for_clipping=True), max_norm=self.clip_grad, global_norm=all_groups_norm, - mpu=self.mpu) + mpu=self.mpu, + use_graph=self.graph_harvesting) self.optimizer.step() @@ -281,23 +284,33 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg @torch.no_grad() def update_hp_grads(self, clear_lp_grads=False): + + def _update_hp_grads_func(clear_lp_grads=False): + for i, group in enumerate(self.bf16_groups): + for j, lp in enumerate(group): + if lp.grad is None: + continue + hp_grad = self.fp32_groups_gradients[i][j] + assert hp_grad is not None, \ + f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]' + hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) + lp._hp_grad = hp_grad + self.fp32_groups_has_gradients[i][j] = True + # clear gradients + if clear_lp_grads: + lp.grad._zero() + + if self.graph_harvesting: + graph_process(False, _update_hp_grads_func, clear_lp_grads) + else: + _update_hp_grads_func(clear_lp_grads) + #cpu op for i, group in enumerate(self.bf16_groups): for j, lp in enumerate(group): if lp.grad is None: continue - - hp_grad = self.fp32_groups_gradients[i][j] - assert hp_grad is not None, \ - f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]' - - hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) - lp._hp_grad = hp_grad self.fp32_groups_has_gradients[i][j] = True - # clear gradients - if clear_lp_grads: - lp.grad = None - @torch.no_grad() def get_grads_for_reduction(self): return self.fp32_groups_gradients_flat @@ -348,7 +361,9 @@ def clear_hp_grads(self): def clear_lp_grads(self): for group in self.bf16_groups: for param in group: - param.grad = None + if param.grad is not None: + # Using zero_() fixed memory address for graph replay + param.grad.zero_() def state_dict(self): state_dict = {} diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index b49469b94f11..80754df50c20 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -279,6 +279,10 @@ def get_gradient_clipping(param_dict): return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT) +def get_graph_harvesting(param_dict): + return get_scalar_param(param_dict, GRAPH_HARVESTING, GRAPH_HARVESTING_DEFAULT) + + def get_sparse_attention(param_dict): if SPARSE_ATTENTION in param_dict.keys(): sparsity = param_dict[SPARSE_ATTENTION] @@ -823,6 +827,7 @@ def _initialize_params(self, param_dict): self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict) self.compression_config = get_compression_config(param_dict) + self.graph_harvesting = get_graph_harvesting(param_dict) self.optimizer_name = get_optimizer_name(param_dict) if (self.optimizer_name is not None and self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS): diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 96f2a38bd05c..82d8a0557a41 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -210,6 +210,18 @@ GRADIENT_CLIPPING = 'gradient_clipping' GRADIENT_CLIPPING_DEFAULT = 0. +######################################### +# Capture graph for short kernels sequences +######################################### +# Graph harvesting. By default, this feature is not enabled. +# Users can configure in ds_config.json as below example: +GRAPH_HARVESTING_FORMAT = ''' +Graph harvesting should be enabled as: +"graph_harvesting": true +''' +GRAPH_HARVESTING = 'graph_harvesting' +GRAPH_HARVESTING_DEFAULT = False + ######################################### # Communication data type ######################################### diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index c5f4d3e6530d..a8cd4fffcce9 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -770,6 +770,9 @@ def zero_legacy_stage1(self): def zero_ignore_unused_parameters(self): return self._config.zero_config.ignore_unused_parameters + def graph_harvesting(self): + return self._config.graph_harvesting + def fp16_enabled(self): return self._config.fp16_enabled @@ -1451,7 +1454,8 @@ def _configure_bf16_optimizer(self, optimizer): allgather_bucket_size=self.zero_allgather_bucket_size(), dp_process_group=self.seq_data_parallel_group, timers=timers, - grad_acc_dtype=self.get_data_types()[1]) + grad_acc_dtype=self.get_data_types()[1], + graph_harvesting=self.graph_harvesting()) return optimizer diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index bc7a782e590c..82f200fccf9f 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -47,6 +47,27 @@ def __init__(self, params): self.param_groups.append({'params': params}) +graph_cache = {} + + +def graph_process(replay_first_step, func, *args, **kwargs): + # `func` should only contain operations on the GPU + # Please ensure that the memory address of the data required by 'func' remains constant + if func.__name__ not in graph_cache: + cuda_stream = get_accelerator().Stream() + cuda_stream.wait_stream(get_accelerator().current_stream()) + with get_accelerator().stream(cuda_stream): + func(*args, **kwargs) + get_accelerator().current_stream().wait_stream(cuda_stream) + graph_cache[func.__name__] = get_accelerator().create_graph() + with get_accelerator().capture_to_graph(graph_cache[func.__name__]): + func(*args, **kwargs) + if replay_first_step: + get_accelerator().replay_graph(graph_cache[func.__name__]) + else: + get_accelerator().replay_graph(graph_cache[func.__name__]) + + def noop_decorator(func): return func @@ -831,7 +852,7 @@ def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, ep return global_grad_norm -def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): +def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False): """Get norm of an iterable of tensors. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and @@ -845,7 +866,6 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): Returns: Total norm of the tensors (viewed as a single vector). """ - assert isinstance(input_tensors, Iterable), f'expected Iterable type not {type(input_tensors)}' assert all([torch.is_tensor(t) for t in input_tensors]), f'expected list of only tensors' @@ -857,8 +877,24 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) total_norm = total_norm_cuda[0].item() else: - total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors]) - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + if use_graph: + if 'norm_tensors_compute_buffer' not in graph_cache: + graph_cache['norm_tensors_compute_buffer'] = [t.data.float().norm(norm_type) for t in input_tensors] + compute_buffer = graph_cache['norm_tensors_compute_buffer'] + + def _norm_tensors(tensor_list, _compute_buffer, _norm_type): + for i, t in enumerate(tensor_list): + _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type) + if i != 0: + _compute_buffer[0].data.add_(_compute_buffer[i].data) + + graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type) + + total_norm = compute_buffer[0] + else: + total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors]) + + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach() if mpu is not None: dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) total_norm = total_norm_cuda[0].item()**(1. / norm_type) @@ -869,7 +905,7 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): return total_norm -def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6): +def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, mpu=None, eps=1e-6, use_graph=False): """Clip list of tensors by global norm. Args: input_tensors: List of tensors to be clipped @@ -880,14 +916,26 @@ def clip_tensors_by_global_norm(input_tensors, max_norm=1.0, global_norm=None, m float: the global norm """ if global_norm is None: - global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu) - + global_norm = get_global_norm_of_tensors(input_tensors, mpu=mpu, use_graph=use_graph) clip_coef = max_norm / (global_norm + eps) - if clip_coef < 1: - for t in input_tensors: - t.detach().mul_(clip_coef) + if use_graph: + def clip_tensors(_tensor_list, _clip_coef_tensor): + for t in _tensor_list: + t.detach().mul_(_clip_coef_tensor) + + if 'clip_coef_tensor' not in graph_cache: + # Alloc memory + graph_cache['clip_coef_tensor'] = torch.tensor(clip_coef, + dtype=torch.float32).to(get_accelerator().device_name()) + clip_coef_tensor = graph_cache['clip_coef_tensor'] + clip_coef_tensor.copy_(torch.tensor(clip_coef, dtype=torch.float32)) + graph_process(False, clip_tensors, input_tensors, clip_coef_tensor) + + else: + for t in input_tensors: + t.detach().mul_(clip_coef) return global_norm From 9e455d76516e785cbaf058d351b6a78d02c42ed8 Mon Sep 17 00:00:00 2001 From: Max Kovalenko <75629718+deepcharm@users.noreply.github.com> Date: Wed, 20 Dec 2023 23:52:31 +0200 Subject: [PATCH 02/10] Checkpointing: Avoid assigning tensor storage with different device (#4836) On some back-ends, assigning tensor.data to a storage being on a different device than the tensor is not supported. The fix is to save the storage in a temp data member and restore tensor.data when needed. Co-authored-by: Olatunji Ruwase --- .../activation_checkpointing/checkpointing.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index 772d23f2d0ac..02e0b197e927 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -439,7 +439,9 @@ def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint num_non_fp_tensors += 1 continue - arg.data = inp.data + arg.data = torch.empty([], device=arg.device).data + arg.saved_data = inp.data + new_args.append(arg) i = arg_index - num_non_fp_tensors @@ -472,7 +474,8 @@ def get_cpu_activations_for_backward(args, inputs): new_args.append(arg) continue - arg.data = inp.data + arg.data = torch.empty([], device=arg.device).data + arg.saved_data = inp.data new_args.append(arg) return new_args @@ -628,6 +631,12 @@ def backward(ctx, *grads): global cuda_device, transport_stream, PARTITION_ACTIVATIONS + # Rebuild deepspeed_saved_tensors + for t in ctx.deepspeed_saved_tensors: + if t is not None and hasattr(t, 'saved_data') and t.saved_data is not None: + t.data = t.saved_data.to(t.device) + t.saved_data = None + if PARTITION_ACTIVATIONS: # with get_accelerator().stream(transport_stream): inputs = gather_partitioned_activations(ctx.deepspeed_saved_tensors, From 18643914bb5b4be9150711fa26abddc2de4641e7 Mon Sep 17 00:00:00 2001 From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com> Date: Thu, 21 Dec 2023 01:39:50 +0200 Subject: [PATCH 03/10] engine.py: remove unused _curr_save_path (#4844) Co-authored-by: Michael Wyatt --- deepspeed/runtime/engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index a8cd4fffcce9..9c9641a1c4cf 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3212,7 +3212,6 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa state.update(client_state) logger.info(f'Saving model checkpoint: {save_path}') self.checkpoint_engine.save(state, save_path) - self._curr_save_path = None def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint): name_function = (self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name) From c00388a2ef933f243b28d89bafb1b329d72557ad Mon Sep 17 00:00:00 2001 From: Connor Holmes Date: Wed, 20 Dec 2023 16:05:26 -0800 Subject: [PATCH 04/10] Mixtral FastGen Support (#4828) Adds support for Mixtral with FastGen. Key features implemented: 1. Top-2 MoE support 2. Better support for RoPE thetas 3. The mistral model implementation --------- Co-authored-by: Michael Wyatt --- .../v2/checkpoint/huggingface_engine.py | 2 +- deepspeed/inference/v2/engine_factory.py | 7 + .../v2/kernels/ragged_ops/__init__.py | 2 +- .../kernels/ragged_ops/includes/top_k_utils.h | 15 + .../blocked_kv_rotary.cpp | 4 + .../blocked_kv_rotary.cu | 26 +- .../blocked_kv_rotary.cuh | 1 + .../blocked_kv_rotary.h | 1 + .../blocked_kv_rotary.py | 12 +- .../ragged_ops/moe_gather/moe_gather.cpp | 10 +- .../ragged_ops/moe_gather/moe_gather.cu | 107 +++++-- .../ragged_ops/moe_gather/moe_gather.cuh | 2 + .../ragged_ops/moe_gather/moe_gather.h | 3 +- .../ragged_ops/moe_gather/moe_gather.py | 9 +- .../ragged_ops/moe_scatter/moe_scatter.cpp | 7 +- .../ragged_ops/moe_scatter/moe_scatter.cu | 188 ++++++------ .../ragged_ops/moe_scatter/moe_scatter.cuh | 1 + .../ragged_ops/moe_scatter/moe_scatter.py | 8 +- .../v2/kernels/ragged_ops/ragged_ops.cpp | 6 +- .../__init__.py | 2 +- .../top_k_gating.cpp} | 26 +- .../top_k_gating.cu} | 69 +++-- .../top_k_gating.cuh} | 3 +- .../top_k_gating.h} | 4 +- .../top_k_gating.py} | 14 +- .../v2/model_implementations/__init__.py | 1 + .../common_parameters/moe_parameters.py | 23 +- .../model_implementations/falcon/__init__.py | 2 +- .../{falcon_containers.py => container.py} | 4 +- .../falcon/{falcon_model.py => model.py} | 4 +- .../falcon/{falcon_policy.py => policy.py} | 8 +- .../inference_transformer_base.py | 20 +- .../llama_v2/__init__.py | 2 +- .../{llama_v2_containers.py => container.py} | 4 +- .../llama_v2/{llama_v2_model.py => model.py} | 29 +- .../{llama_v2_policy.py => policy.py} | 6 +- .../v2/model_implementations/mistral/model.py | 25 +- .../model_implementations/mistral/policy.py | 8 +- .../model_implementations/mixtral/__init__.py | 6 + .../mixtral/container.py | 46 +++ .../v2/model_implementations/mixtral/model.py | 274 ++++++++++++++++++ .../model_implementations/mixtral/policy.py | 31 ++ .../v2/model_implementations/opt/container.py | 4 +- .../v2/model_implementations/opt/model.py | 3 +- .../v2/model_implementations/opt/policy.py | 6 +- .../inference/v2/modules/configs/__init__.py | 7 +- .../v2/modules/configs/attention_configs.py | 24 +- .../v2/modules/configs/moe_config.py | 6 + .../attention/dense_blocked_attention.py | 20 +- .../implementations/moe/cutlass_multi_gemm.py | 88 ++++-- op_builder/ragged_ops.py | 7 +- .../v2/kernels/ragged_ops/test_moe_gather.py | 67 ++++- .../v2/kernels/ragged_ops/test_moe_scatter.py | 69 +++-- ...t_top_1_gating.py => test_top_k_gating.py} | 83 +++++- .../parameters/test_parameter_list.py | 2 +- .../inference/v2/modules/test_blocked_attn.py | 11 +- .../inference/v2/modules/test_cutlass_moe.py | 114 ++++++++ 57 files changed, 1193 insertions(+), 340 deletions(-) create mode 100644 deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h rename deepspeed/inference/v2/kernels/ragged_ops/{top_1_gating => top_k_gating}/__init__.py (69%) rename deepspeed/inference/v2/kernels/ragged_ops/{top_1_gating/top_1_gating.cpp => top_k_gating/top_k_gating.cpp} (67%) rename deepspeed/inference/v2/kernels/ragged_ops/{top_1_gating/top_1_gating.cu => top_k_gating/top_k_gating.cu} (59%) rename deepspeed/inference/v2/kernels/ragged_ops/{top_1_gating/top_1_gating.cuh => top_k_gating/top_k_gating.cuh} (87%) rename deepspeed/inference/v2/kernels/ragged_ops/{top_1_gating/top_1_gating.h => top_k_gating/top_k_gating.h} (86%) rename deepspeed/inference/v2/kernels/ragged_ops/{top_1_gating/top_1_gating.py => top_k_gating/top_k_gating.py} (87%) rename deepspeed/inference/v2/model_implementations/falcon/{falcon_containers.py => container.py} (97%) rename deepspeed/inference/v2/model_implementations/falcon/{falcon_model.py => model.py} (98%) rename deepspeed/inference/v2/model_implementations/falcon/{falcon_policy.py => policy.py} (74%) rename deepspeed/inference/v2/model_implementations/llama_v2/{llama_v2_containers.py => container.py} (95%) rename deepspeed/inference/v2/model_implementations/llama_v2/{llama_v2_model.py => model.py} (83%) rename deepspeed/inference/v2/model_implementations/llama_v2/{llama_v2_policy.py => policy.py} (76%) create mode 100644 deepspeed/inference/v2/model_implementations/mixtral/__init__.py create mode 100644 deepspeed/inference/v2/model_implementations/mixtral/container.py create mode 100644 deepspeed/inference/v2/model_implementations/mixtral/model.py create mode 100644 deepspeed/inference/v2/model_implementations/mixtral/policy.py rename tests/unit/inference/v2/kernels/ragged_ops/{test_top_1_gating.py => test_top_k_gating.py} (51%) diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py index 6b64ed3185a2..ca9fb113b15a 100644 --- a/deepspeed/inference/v2/checkpoint/huggingface_engine.py +++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py @@ -61,7 +61,7 @@ def model_has_safetensors(model_name_or_path: str) -> bool: # We need to download the checkpoint files from HF if model_has_safetensors(self.model_name_or_path): # Prioritize downloading safetensors if they are available - allow_patterns = ["*.safetensors", "*.json", "*.pt"] + allow_patterns = ["*.safetensors", "*.json"] else: # Fallback to bin files when safetensors are not present allow_patterns = ["*.bin", "*.json", "*.pt"] diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py index 9558125ff934..a0dc050bbbf9 100644 --- a/deepspeed/inference/v2/engine_factory.py +++ b/deepspeed/inference/v2/engine_factory.py @@ -17,6 +17,7 @@ OPTPolicy, Llama2Policy, MistralPolicy, + MixtralPolicy, FalconPolicy, ) from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy @@ -105,6 +106,12 @@ def build_hf_engine(path: str, assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \ f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}" policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "mixtral": + # Ensure we're using the correct version of transformers for mistral + import transformers + assert version.parse(transformers.__version__) >= version.parse("4.36.1"), \ + f"Mistral requires transformers >= 4.36.1, you have version {transformers.__version__}" + policy = MixtralPolicy(model_config, checkpoint_engine=checkpoint_engine) elif model_config.model_type == "falcon": policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine) else: diff --git a/deepspeed/inference/v2/kernels/ragged_ops/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/__init__.py index 988152b2e7c0..38a4ebd6fba3 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/__init__.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/__init__.py @@ -10,4 +10,4 @@ from .logits_gather import * from .moe_gather import * from .moe_scatter import * -from .top_1_gating import * +from .top_k_gating import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h new file mode 100644 index 000000000000..abb9e15f8f6f --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#define TOP_K_SWITCH(N_TOP_K, ...) \ + [&] { \ + if (1 == N_TOP_K) { \ + constexpr int CONST_TOP_K = 1; \ + __VA_ARGS__(); \ + } else if (2 == N_TOP_K) { \ + constexpr int CONST_TOP_K = 2; \ + __VA_ARGS__(); \ + } \ + }() diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp index 8493bbf4b9af..a640c2b30164 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp @@ -13,6 +13,7 @@ (C_TYPE*)k.data_ptr(), \ (C_TYPE*)v.data_ptr(), \ (C_TYPE*)inv_freq_ptr, \ + theta_base, \ batch_wrapper, \ qkv_stride, \ kv_cache_stride, \ @@ -51,6 +52,8 @@ void kv_trained_rotary_embeddings(torch::Tensor& kv_cache, TORCH_CHECK(n_tokens == k.size(0)); TORCH_CHECK(n_tokens == v.size(0)); + const float theta_base = 0.f; + // Dimensions const int32_t block_size = kv_cache.size(1); const int32_t n_kv_heads = kv_cache.size(3); @@ -91,6 +94,7 @@ void kv_rotary_embeddings(torch::Tensor& kv_cache, torch::Tensor& q, torch::Tensor& k, torch::Tensor& v, + const float theta_base, torch::Tensor& batch_metadata, torch::Tensor& seq_metadata, torch::Tensor& tokens_to_seq, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu index 980334f02b0b..5dd79f0c636a 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu @@ -27,6 +27,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, T* k, T* v, const T* inv_freq, + const float theta_base, const BatchWrapperCPP batch_desc, const int qkv_stride, const int kv_cache_stride, @@ -114,7 +115,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, // Conversion to T and back means that both branches of this if statement // will produce the same results if using the same algo for producing the // freqs. - T trunc_freq = conversion::to(1.0 / powf(10000.0, inv_freq_flt)); + T trunc_freq = conversion::to(1.0 / powf(theta_base, inv_freq_flt)); inv_freq_flt = conversion::to(trunc_freq) * (float)global_token_idx; } @@ -158,7 +159,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, } else { inv_freq_flt = (float)((head_neuron_idx % half_head_size) * 2) / (float)headSize; - inv_freq_flt = 1.0 / powf(10000.0, inv_freq_flt) * (float)global_token_idx; + inv_freq_flt = 1.0 / powf(theta_base, inv_freq_flt) * (float)global_token_idx; } float rotary_sign = (head_neuron_idx >= half_head_size) ? -1.0f : 1.0f; @@ -186,6 +187,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, k, \ v, \ inv_freq, \ + theta_base, \ batch_desc, \ qkv_stride, \ kv_cache_stride, \ @@ -198,6 +200,7 @@ void launch_kv_rotary_kernel(T* kv_cache, T* k, T* v, T* inv_freq, + const float theta_base, const BatchWrapperCPP batch_desc, const int qkv_stride, const int kv_cache_stride, @@ -245,6 +248,7 @@ void launch_kv_rotary_kernel(T* kv_cache, TYPE * k, \ TYPE * v, \ TYPE * inv_freq, \ + const float theta_base, \ const BatchWrapperCPP batch_desc, \ const int qkv_stride, \ const int kv_cache_stride, \ @@ -262,10 +266,20 @@ INSTANTIATE_KV_ROTARY_KERNEL(__half) INSTANTIATE_KV_ROTARY_KERNEL(__nv_bfloat16) #endif -#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE) \ - if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ - kv_rotary_pos_kernel<<>>( \ - kv_cache, q, k, v, nullptr, batch_desc, qkv_stride, kv_cache_stride, v_offset, 0); +#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE) \ + if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ + kv_rotary_pos_kernel \ + <<>>(kv_cache, \ + q, \ + k, \ + v, \ + nullptr, \ + 0.f, \ + batch_desc, \ + qkv_stride, \ + kv_cache_stride, \ + v_offset, \ + 0); template void launch_kv_copy_kernel(T* kv_cache, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh index be38ff30c46c..41a69d3b397b 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh @@ -18,6 +18,7 @@ void launch_kv_rotary_kernel(T* kv_cache, T* k, T* v, T* inv_freq, + const float theta_base, const BatchWrapperCPP batch_desc, const int qkv_stride, const int kv_cache_stride, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h index 0615825c0a21..e56ce644dbbc 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h @@ -45,6 +45,7 @@ void kv_rotary_embeddings(torch::Tensor& kv_cache, torch::Tensor& q, torch::Tensor& k, torch::Tensor& v, + const float theta_base, torch::Tensor& batch_metadata, torch::Tensor& seq_metadata, torch::Tensor& tokens_to_seq, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py index 50d9aca061f3..f206a4f5d28c 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py @@ -21,7 +21,12 @@ class BlockedRotaryEmbeddings(DSKernelBase): supported_head_sizes = [64, 128] supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71] - def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None: + def __init__(self, + head_size: int, + n_q_heads: int, + n_kv_heads: int, + dtype: torch.dtype, + theta_base: float = 10000.0) -> None: """ Args: head_size: The size of the attention head. @@ -51,6 +56,7 @@ def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch self.head_size = head_size self.n_q_heads = n_q_heads self.n_kv_heads = n_kv_heads + self.theta_base = theta_base def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper) -> None: """ @@ -66,5 +72,5 @@ def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: Ragg k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)] v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):] - self.kernel(kv_cache, q, k, v, ragged_batch.batch_metadata_buffer(), ragged_batch.inflight_seq_descriptors(), - ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs()) + self.kernel(kv_cache, q, k, v, self.theta_base, ragged_batch.batch_metadata_buffer(), + ragged_batch.inflight_seq_descriptors(), ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs()) diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp index e55e1f48c125..506629406f0d 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp @@ -16,6 +16,8 @@ n_channels, \ n_experts, \ n_tokens, \ + n_top_k, \ + normalize_scales, \ at::cuda::getCurrentCUDAStream()); \ return; \ } @@ -27,17 +29,21 @@ void moe_gather(torch::Tensor& layer_output, const torch::Tensor& moe_output, const torch::Tensor& scores, const torch::Tensor& mapped_slots, - const torch::Tensor& expert_count) + const torch::Tensor& expert_count, + const bool normalize_scales) { const int32_t n_channels = layer_output.size(1); const int32_t n_experts = expert_count.size(0); const int32_t n_tokens = layer_output.size(0); + const int32_t n_top_k = mapped_slots.size(1); - TORCH_CHECK(moe_output.size(0) == n_tokens); + TORCH_CHECK(moe_output.size(0) == n_tokens * n_top_k); TORCH_CHECK(moe_output.size(1) == n_channels); TORCH_CHECK(scores.size(0) == n_tokens); TORCH_CHECK(mapped_slots.size(0) == n_tokens); + TORCH_CHECK(scores.size(1) == n_top_k); + TORCH_CHECK(layer_output.scalar_type() == moe_output.scalar_type()); TORCH_CHECK(scores.scalar_type() == torch::kFloat32); TORCH_CHECK(mapped_slots.scalar_type() == torch::kInt32); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu index c2fae24f5080..4153a2a3636f 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu @@ -7,7 +7,8 @@ #include "ds_kernel_utils.h" #include "moe_gather.cuh" #include "reduction_utils.h" -#include "top_1_gating.cuh" +#include "top_k_gating.cuh" +#include "top_k_utils.h" namespace gather { @@ -16,65 +17,105 @@ constexpr int threads = 256; } // namespace gather -template +template __global__ void moe_gather_kernel(T* layer_output, const T* moe_output, const float* scores, const int32_t* mapped_slots, int32_t* expert_counts, const int32_t n_channels, - const int32_t n_experts) + const int32_t n_experts, + const bool normalize_scales) { constexpr int32_t vector_size = gather::access_granularity / sizeof(T); constexpr int32_t stride = vector_size * gather::threads; const int32_t token_idx = blockIdx.x; - const int32_t mapped_slot = mapped_slots[token_idx]; + int32_t token_mapped_slots[N_TOP_K]; + + bool all_slots_invalid = true; + for (int i = 0; i < N_TOP_K; i++) { + token_mapped_slots[i] = mapped_slots[token_idx * N_TOP_K + i]; + all_slots_invalid &= (token_mapped_slots[i] == gating::unassigned); + } if (token_idx == 0) { // Reset expert counts for its next use. if (threadIdx.x < n_experts) { expert_counts[threadIdx.x] = 0; } } - if (mapped_slot == gating::unassigned) { - // This token was not assigned. + if (all_slots_invalid) { + // This token was not assigned to anything. // TODO(cmikeh2): It's possible we want different behavior here moving forward. return; } - const float score = scores[token_idx]; + float token_scores[N_TOP_K]; + for (int i = 0; i < N_TOP_K; i++) { token_scores[i] = scores[token_idx * N_TOP_K + i]; } + + if (normalize_scales) { + // Normalize the scores so that they sum to 1. + float sum = 0.0f; + for (int i = 0; i < N_TOP_K; i++) { sum += token_scores[i]; } + + if (sum > 0.0f) { + for (int i = 0; i < N_TOP_K; i++) { token_scores[i] /= sum; } + } + } + const int32_t channel_offset = threadIdx.x * vector_size; - const T* moe_output_base = moe_output + mapped_slot * n_channels + channel_offset; + const T* moe_output_bases[N_TOP_K]; +#pragma unroll + for (int i = 0; i < N_TOP_K; i++) { + moe_output_bases[i] = moe_output + token_mapped_slots[i] * n_channels + channel_offset; + } + T* layer_output_base = layer_output + token_idx * n_channels + channel_offset; #pragma unroll for (int i = 0; i < copyUnroll; i++) { - T reg_buffer[vector_size]; - if (i * stride + channel_offset < n_channels) { - mem_access::load_global(reg_buffer, - moe_output_base + i * stride); + float accum_buffer[vector_size]; + for (int j = 0; j < vector_size; j++) { + accum_buffer[j] = reduce::init(); + } + +#pragma unroll + for (int j = 0; j < N_TOP_K; j++) { + T reg_buffer[vector_size]; + mem_access::load_global( + reg_buffer, moe_output_bases[j] + i * stride); +#pragma unroll + for (int k = 0; k < vector_size; k++) { + float up_cast = conversion::to(reg_buffer[k]); + accum_buffer[k] += up_cast * token_scores[j]; + } + } + + T store_buffer[vector_size]; #pragma unroll for (int j = 0; j < vector_size; j++) { - // There are accuracy implications of downcasting the score to a 16-bit - // data type, so we up-convert the input to 32-bit, multiply, and then - // down-convert back to 16-bit. - float up_cast = conversion::to(reg_buffer[j]); - reg_buffer[j] = conversion::to(up_cast * score); + store_buffer[j] = conversion::to(accum_buffer[j]); } mem_access::store_global(layer_output_base + i * stride, - reg_buffer); + store_buffer); } } } -#define LAUNCH_FOR_UNROLL(COUNT) \ - case COUNT: \ - moe_gather_kernel<<>>( \ - layer_output, moe_output, scores, mapped_slots, expert_counts, n_channels, n_experts); \ +#define LAUNCH_FOR_UNROLL(COUNT) \ + case COUNT: \ + moe_gather_kernel<<>>(layer_output, \ + moe_output, \ + scores, \ + mapped_slots, \ + expert_counts, \ + n_channels, \ + n_experts, \ + normalize_scales); \ break; template @@ -86,6 +127,8 @@ void launch_moe_gather(T* layer_output, const int32_t n_channels, const int32_t n_experts, const int32_t n_tokens, + const int32_t n_top_k, + const bool normalize_scales, cudaStream_t stream) { constexpr int vals_per_unroll = gather::threads * gather::access_granularity / sizeof(T); @@ -94,14 +137,16 @@ void launch_moe_gather(T* layer_output, const dim3 block(gather::threads); const dim3 grid(n_tokens); - switch (copy_unroll) { - LAUNCH_FOR_UNROLL(1) - LAUNCH_FOR_UNROLL(2) - LAUNCH_FOR_UNROLL(3) - LAUNCH_FOR_UNROLL(4) - LAUNCH_FOR_UNROLL(5) - LAUNCH_FOR_UNROLL(6) - } + TOP_K_SWITCH(n_top_k, [&] { + switch (copy_unroll) { + LAUNCH_FOR_UNROLL(1) + LAUNCH_FOR_UNROLL(2) + LAUNCH_FOR_UNROLL(3) + LAUNCH_FOR_UNROLL(4) + LAUNCH_FOR_UNROLL(5) + LAUNCH_FOR_UNROLL(6) + } + }); } #define INSTANTIATE_GATHER_FOR_TYPE(TYPE) \ @@ -113,6 +158,8 @@ void launch_moe_gather(T* layer_output, const int32_t n_channels, \ const int32_t n_experts, \ const int32_t n_tokens, \ + const int32_t n_top_k, \ + const bool normalize_scales, \ cudaStream_t stream); INSTANTIATE_GATHER_FOR_TYPE(__half) diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh index f98a727ead58..b348d0cfb330 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh @@ -17,4 +17,6 @@ void launch_moe_gather(T* layer_output, const int32_t n_channels, const int32_t n_experts, const int32_t n_tokens, + const int32_t n_top_k, + const bool normalize_scales, cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h index 7ffe9f8b4dc6..ec9e03057eb8 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h @@ -16,4 +16,5 @@ void moe_gather(torch::Tensor& layer_output, const torch::Tensor& moe_output, const torch::Tensor& scores, const torch::Tensor& mapped_slots, - const torch::Tensor& expert_counts); + const torch::Tensor& expert_counts, + const bool normalize_scales); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py index c37683d03fbe..f03938171ba4 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py @@ -18,7 +18,7 @@ class MoEGather(DSKernelBase): supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] - def __init__(self, dtype: DtypeEnum, channels: int) -> None: + def __init__(self, dtype: DtypeEnum, channels: int, normalize_scores: bool = False) -> None: if not isinstance(dtype, DtypeEnum): dtype = DtypeEnum(dtype) @@ -31,6 +31,7 @@ def __init__(self, dtype: DtypeEnum, channels: int) -> None: inf_module = RaggedOpsBuilder().load() self.kernel = inf_module.moe_gather + self.normalize_scores = normalize_scores def __call__(self, layer_output: torch.Tensor, moe_output: torch.Tensor, scores: torch.Tensor, mapped_slots: torch.Tensor, expert_counts: torch.Tensor) -> torch.Tensor: @@ -40,13 +41,13 @@ def __call__(self, layer_output: torch.Tensor, moe_output: torch.Tensor, scores: Arguments: layer_output (torch.Tensor): The output of the layer of shape [n_tokens, hidden_size]. This has been scaled appropriately. - moe_output (torch.Tensor): The output of the MoE of shape [n_tokens, hidden_size]. + moe_output (torch.Tensor): The output of the MoE of shape [n_tokens * n_top_k, hidden_size]. scores (torch.Tensor): The gating scores of shape [n_tokens]. - mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens]. The index of token ``i`` in layer_output is ``mapped_slots[i]``. + mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens, n_top_k]. The indices of token ``i`` in layer_output is ``mapped_slots[i]``. expert_counts (torch.Tensor): The number of tokens assigned to each expert of shape [n_experts]. This is passed to fuse the clearing of this data structure into the gather. Returns: layer_output """ - self.kernel(layer_output, moe_output, scores, mapped_slots, expert_counts) + self.kernel(layer_output, moe_output, scores, mapped_slots, expert_counts, self.normalize_scores) return layer_output diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp index 902f1cc0ea15..8f7ecbd1a287 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp @@ -18,6 +18,7 @@ n_channels, \ n_tokens, \ n_experts, \ + n_top_k, \ at::cuda::getCurrentCUDAStream()); \ return; \ } @@ -36,13 +37,17 @@ void moe_scatter(torch::Tensor& moe_input, { const int32_t n_tokens = activations.size(0); const int32_t n_channels = activations.size(1); + const int32_t n_top_k = assignments.size(1); // Should have a lot of matching buffer sizes here. - TORCH_CHECK(n_tokens == moe_input.size(0)); TORCH_CHECK(n_tokens == assignments.size(0)); TORCH_CHECK(n_tokens == offsets.size(0)); TORCH_CHECK(n_channels == moe_input.size(1)); + TORCH_CHECK(n_top_k == offsets.size(1)); + TORCH_CHECK(n_top_k * n_tokens == moe_input.size(0)); + TORCH_CHECK(n_top_k == mapped_slots.size(1)); + const int32_t n_experts = expert_count_cumsums.size(0); TORCH_CHECK(moe_input.scalar_type() == activations.scalar_type()); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu index 0746cd7be645..d3eb4f649e79 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu @@ -4,9 +4,9 @@ // DeepSpeed Team #include "ds_kernel_utils.h" -#include "moe_scatter.cuh" #include "reduction_utils.h" -#include "top_1_gating.cuh" +#include "top_k_gating.cuh" +#include "top_k_utils.h" using ROp = reduce::ROpType; @@ -15,10 +15,11 @@ namespace scatter { constexpr int access_granularity = 16; constexpr int threads = 256; constexpr int warps = threads / hw_warp_size; +constexpr int max_experts = 1024; } // namespace scatter -template +template __global__ void moe_scatter_kernel(T* moe_input, int64_t* expert_count_cumsums, int32_t* mapped_slots, @@ -38,88 +39,78 @@ __global__ void moe_scatter_kernel(T* moe_input, // Bank aligned and sufficient __shared__ int32_t red_buffer[32]; - __shared__ int32_t token_0_row; + __shared__ int32_t expert_offsets[scatter::max_experts]; // CG helpers cg::thread_block tb = cg::this_thread_block(); cg::thread_block_tile warp = cg::tiled_partition(tb); - const int assigned_expert = assignments[token_idx]; - - // For the different codepaths, we'll converge on this variable for doing - // the token copy. - int32_t token_base_row; + // Fetch the assigned experts for this token. + int assigned_experts[N_TOP_K]; + for (int i = 0; i < N_TOP_K; i++) { + assigned_experts[i] = assignments[token_idx * N_TOP_K + i]; + } - if (token_idx == 0) { - // Token 0 will perform a cumsum on the data - int32_t expert_vals; - if (tidx < n_experts) { - expert_vals = expert_counts[tidx]; + bool all_unassigned = true; + for (int i = 0; i < N_TOP_K; i++) { + if (assigned_experts[i] != gating::unassigned) { + all_unassigned = false; } else { - expert_vals = 0; + mapped_slots[token_idx * N_TOP_K + i] = gating::unassigned; } + } + if (all_unassigned && token_idx != 0) return; + + // Do a prefix scan on the expert counts to get the base offsets. Here we use the + // single up-sweep variant. + int32_t expert_vals; + if (tidx < n_experts) { + expert_vals = expert_counts[tidx]; + } else { + expert_vals = 0; + } #pragma unroll - for (int i = 1; i < hw_warp_size; i *= 2) { - int32_t maybe_add = warp.shfl_up(expert_vals, i); - expert_vals = (warp.thread_rank() < i) ? expert_vals : expert_vals + maybe_add; - } + for (int i = 1; i < hw_warp_size; i *= 2) { + int32_t maybe_add = warp.shfl_up(expert_vals, i); + expert_vals = (warp.thread_rank() < i) ? expert_vals : expert_vals + maybe_add; + } - if (warp.thread_rank() == hw_warp_size - 1) { - mem_access::store_shared<4>(red_buffer + warp_rank, &expert_vals); - } + if (warp.thread_rank() == hw_warp_size - 1) { + mem_access::store_shared<4>(red_buffer + warp_rank, &expert_vals); + } - tb.sync(); + tb.sync(); - int32_t phase_2_val = 0; - if (warp.thread_rank() < scatter::warps) { - mem_access::load_shared<4>(&phase_2_val, red_buffer + warp.thread_rank()); - } + int32_t phase_2_val = 0; + if (warp.thread_rank() < scatter::warps) { + mem_access::load_shared<4>(&phase_2_val, red_buffer + warp.thread_rank()); + } #pragma unroll - for (int i = 1; i < hw_warp_size; i *= 2) { - int32_t maybe_add = warp.shfl_up(phase_2_val, i); - phase_2_val = (warp.thread_rank() < i) ? phase_2_val : phase_2_val + maybe_add; - } - - int warp_offset = 0; - if (warp_rank > 0) { warp_offset = warp.shfl(phase_2_val, warp_rank - 1); } - const int32_t expert_cumsum = warp_offset + expert_vals; - - if (tidx < n_experts) { - int64_t expert_cumsum_64 = (int64_t)expert_cumsum; - expert_count_cumsums[tidx] = expert_cumsum_64; - } - - if (assigned_expert == gating::unassigned) return; - if (assigned_expert - 1 == tidx) token_0_row = expert_cumsum; + for (int i = 1; i < hw_warp_size; i *= 2) { + int32_t maybe_add = warp.shfl_up(phase_2_val, i); + phase_2_val = (warp.thread_rank() < i) ? phase_2_val : phase_2_val + maybe_add; + } - tb.sync(); + int warp_offset = 0; + if (warp_rank > 0) { warp_offset = warp.shfl(phase_2_val, warp_rank - 1); } + const int32_t expert_cumsum = warp_offset + expert_vals; - if (assigned_expert != 0) { - token_base_row = token_0_row; - } else { - token_base_row = 0; - } + // Token 0 will write the + if (token_idx == 0 && tidx < n_experts) { + int64_t expert_cumsum_64 = (int64_t)expert_cumsum; + expert_count_cumsums[tidx] = expert_cumsum_64; + } - } else if (assigned_expert == gating::unassigned) { - // For whatever reason, don't need to perform the copy, so we'll early return - // and signal this wasn't mapped with a negative 1. - if (tidx == 0) mapped_slots[token_idx] = gating::unassigned; - return; - } else { - // For all other valid tokens, we can just do a block-scoped sum. - if (tidx < assigned_expert) { - token_base_row = expert_counts[tidx]; - } else { - token_base_row = 0; - } + // Since token 0 has now written the expert cumsum to global memory, + // if it has no valid experts, we can early return. + if (token_idx == 0 && all_unassigned) return; - warp.sync(); + if (tidx < n_experts) { expert_offsets[tidx] = expert_cumsum; } - // TODO(cmikeh2): Shouldn't use the internal api. - reduce::_block(tb, warp, &token_base_row); - } + // Ensure all the expert offsets are written in shared memory. + tb.sync(); // Data copy to appropriate location const int32_t thread_offset = tidx * vector_size; @@ -127,9 +118,16 @@ __global__ void moe_scatter_kernel(T* moe_input, const int32_t base_load_offset = token_idx * n_channels + thread_offset; const T* load_base_ptr = activations + base_load_offset; - const int32_t store_row = token_base_row + offsets[token_idx]; - const int32_t base_store_offset = store_row * n_channels + thread_offset; - T* store_base_ptr = moe_input + base_store_offset; + int32_t store_rows[N_TOP_K]; + T* store_base_ptrs[N_TOP_K]; +#pragma unroll + for (int i = 0; i < N_TOP_K; i++) { + const int32_t cur_expert_offset = + (assigned_experts[i] > 0) ? expert_offsets[assigned_experts[i] - 1] : 0; + store_rows[i] = cur_expert_offset + offsets[token_idx * N_TOP_K + i]; + const int32_t base_store_offset = store_rows[i] * n_channels + thread_offset; + store_base_ptrs[i] = moe_input + base_store_offset; + } #pragma unroll for (int i = 0; i < copyUnroll; i++) { @@ -138,25 +136,31 @@ __global__ void moe_scatter_kernel(T* moe_input, if (i * load_stride + thread_offset < n_channels) { mem_access::load_global(tmp_buf, load_base_ptr + i * load_stride); - mem_access::store_global(store_base_ptr + i * load_stride, - tmp_buf); +#pragma unroll + for (int j = 0; j < N_TOP_K; j++) { + mem_access::store_global( + store_base_ptrs[j] + i * load_stride, tmp_buf); + } } } - if (threadIdx.x == 0) { mapped_slots[token_idx] = store_row; } + if (threadIdx.x == 0) { + for (int i = 0; i < N_TOP_K; i++) { mapped_slots[token_idx * N_TOP_K + i] = store_rows[i]; } + } } -#define LAUNCH_FOR_UNROLL(COUNT) \ - case COUNT: \ - moe_scatter_kernel<<>>(moe_input, \ - expert_count_cumsums, \ - mapped_slots, \ - activations, \ - assignments, \ - expert_counts, \ - offsets, \ - n_channels, \ - n_experts); \ +#define LAUNCH_FOR_UNROLL(COUNT) \ + case COUNT: \ + moe_scatter_kernel \ + <<>>(moe_input, \ + expert_count_cumsums, \ + mapped_slots, \ + activations, \ + assignments, \ + expert_counts, \ + offsets, \ + n_channels, \ + n_experts); \ break; template @@ -170,6 +174,7 @@ void launch_moe_scatter(T* moe_input, const int32_t n_channels, const int32_t n_tokens, const int32_t n_experts, + const int32_t n_top_k, cudaStream_t stream) { constexpr int vals_per_unroll = scatter::threads * scatter::access_granularity / sizeof(T); @@ -178,14 +183,16 @@ void launch_moe_scatter(T* moe_input, const dim3 block(scatter::threads); const dim3 grid(n_tokens); - switch (copy_unroll) { - LAUNCH_FOR_UNROLL(1); - LAUNCH_FOR_UNROLL(2); - LAUNCH_FOR_UNROLL(3); - LAUNCH_FOR_UNROLL(4); - LAUNCH_FOR_UNROLL(5); - LAUNCH_FOR_UNROLL(6); - } + TOP_K_SWITCH(n_top_k, [&] { + switch (copy_unroll) { + LAUNCH_FOR_UNROLL(1); + LAUNCH_FOR_UNROLL(2); + LAUNCH_FOR_UNROLL(3); + LAUNCH_FOR_UNROLL(4); + LAUNCH_FOR_UNROLL(5); + LAUNCH_FOR_UNROLL(6); + } + }); } #define INSTANTIATE_SCATTER_FOR_TYPE(TYPE) \ @@ -199,6 +206,7 @@ void launch_moe_scatter(T* moe_input, const int32_t, \ const int32_t, \ const int32_t, \ + const int32_t, \ cudaStream_t); INSTANTIATE_SCATTER_FOR_TYPE(__half); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh index 5c94cb0ef734..d9756c80f05a 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh @@ -19,4 +19,5 @@ void launch_moe_scatter(T* moe_input, const int32_t n_channels, const int32_t n_tokens, const int32_t n_experts, + const int32_t n_top_k, cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py index 5cd6ae5f0fe2..7efcedb4e880 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py @@ -40,13 +40,13 @@ def __call__(self, moe_input: torch.Tensor, expert_cumsum: torch.Tensor, mapped_ Scatters the hidden states such that the token stride for each expert's input is contiguous. Arguments: - moe_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, hidden_size]. + moe_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens * n_top_k, hidden_size]. expert_cumsum (torch.Tensor): The cumulative sum of the expert counts of shape [n_experts]. - mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens]. + mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens, n_top_k]. hidden_states (torch.Tensor): The hidden states of shape [n_tokens, hidden_size]. expert_counts (torch.Tensor): The number of tokens assigned to each expert of shape [n_experts]. - assignments (torch.Tensor): The expert assignments of shape [n_tokens]. - offsets (torch.Tensor): The offsets into the expert for a given token of shape [n_tokens]. + assignments (torch.Tensor): The expert assignments of shape [n_tokens, n_top_k]. + offsets (torch.Tensor): The offsets into the expert for a given token of shape [n_tokens, n_top_K]. Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The MoE input (with scattered values), the cumsum of the offsets (for the MoE kernels themselves), and the assignments Tensor modified in place to show which row that token was mapped to in the input. diff --git a/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp b/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp index 1c09fc52bbb1..f320f46e2620 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp @@ -12,7 +12,7 @@ #include "logits_gather.h" #include "moe_gather.h" #include "moe_scatter.h" -#include "top_1_gating.h" +#include "top_k_gating.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -43,6 +43,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) // moe_scatter.h m.def("moe_scatter", &moe_scatter, "MoE scatter for top-1-gating."); - // top_1_gating.h - m.def("top_1_gating", &top_1_gating, "Top-1 gating for MoE with ragged batch awareness."); + // top_k_gating.h + m.def("top_k_gating", &top_k_gating, "Top-1 gating for MoE with ragged batch awareness."); } diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/__init__.py similarity index 69% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/__init__.py rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/__init__.py index b50a0838d9f8..487735b015b0 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/__init__.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/__init__.py @@ -3,4 +3,4 @@ # DeepSpeed Team -from .top_1_gating import RaggedTop1Gating +from .top_k_gating import RaggedTopKGating diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp similarity index 67% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp index 55c68454b228..5eec7e2b955f 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp @@ -3,12 +3,12 @@ // DeepSpeed Team -#include "top_1_gating.h" +#include "top_k_gating.h" #include -#define DISPATCH_TOP_1_GATING(T_TYPE, C_TYPE) \ +#define DISPATCH_TOP_K_GATING(T_TYPE, C_TYPE) \ if (logits.options().dtype() == torch::T_TYPE) { \ - launch_top_1_gating((int32_t*)expert_counts.data_ptr(), \ + launch_top_k_gating((int32_t*)expert_counts.data_ptr(), \ (float*)scores.data_ptr(), \ (int32_t*)assignments.data_ptr(), \ (int32_t*)offsets.data_ptr(), \ @@ -16,14 +16,15 @@ batch_metadata_ptr, \ n_tokens, \ n_experts, \ + n_top_k, \ at::cuda::getCurrentCUDAStream()); \ return; \ } /* -Perform softmax plus atomics in order to do first pass of top_1_gating. +Perform softmax plus atomics in order to do first pass of top_k_gating. */ -void top_1_gating(torch::Tensor& expert_counts, +void top_k_gating(torch::Tensor& expert_counts, torch::Tensor& scores, torch::Tensor& assignments, torch::Tensor& offsets, @@ -31,10 +32,15 @@ void top_1_gating(torch::Tensor& expert_counts, torch::Tensor& batch_metadata) { const int32_t n_tokens = scores.size(0); + const int32_t n_top_k = scores.size(1); - // Should have the same buffer size for scores and offsets + // Should have the same buffer size for scores, offsets, and assignments TORCH_CHECK(n_tokens == offsets.size(0)); TORCH_CHECK(n_tokens == logits.size(0)); + TORCH_CHECK(n_tokens == assignments.size(0)); + + TORCH_CHECK(n_top_k == offsets.size(1)); + TORCH_CHECK(n_top_k == assignments.size(1)); TORCH_CHECK(expert_counts.scalar_type() == torch::kInt32); TORCH_CHECK(scores.scalar_type() == torch::kFloat); @@ -45,11 +51,11 @@ void top_1_gating(torch::Tensor& expert_counts, const RaggedBatchDescriptor* batch_metadata_ptr = reinterpret_cast(batch_metadata.data_ptr()); - DISPATCH_TOP_1_GATING(kFloat, float) - DISPATCH_TOP_1_GATING(kHalf, __half) + DISPATCH_TOP_K_GATING(kFloat, float) + DISPATCH_TOP_K_GATING(kHalf, __half) #ifdef BF16_AVAILABLE - DISPATCH_TOP_1_GATING(kBFloat16, __nv_bfloat16) + DISPATCH_TOP_K_GATING(kBFloat16, __nv_bfloat16) #endif - TORCH_CHECK(false, "Unsupported dtype for logits in top_1_gating"); + TORCH_CHECK(false, "Unsupported dtype for logits in top_k_gating"); } diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu similarity index 59% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu index 02daee9f692e..58f95c045593 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu @@ -6,12 +6,13 @@ #include "conversion_utils.h" #include "memory_access_utils.h" #include "reduction_utils.h" -#include "top_1_gating.cuh" +#include "top_k_gating.cuh" +#include "top_k_utils.h" using ROp = reduce::ROpType; -template -__global__ void top_1_gating_kernel(int32_t* expert_counts, +template +__global__ void top_k_gating_kernel(int32_t* expert_counts, float* scores, int32_t* assignments, int32_t* offsets, @@ -30,8 +31,11 @@ __global__ void top_1_gating_kernel(int32_t* expert_counts, // Padding tokens do not require if (token_idx >= batch_metadata->n_tokens) { if (threadIdx.x == 0) { - offsets[token_idx] = gating::unassigned; - assignments[token_idx] = gating::unassigned; +#pragma unroll + for (int i = 0; i < TOP_K; i++) { + assignments[token_idx * TOP_K + i] = gating::unassigned; + offsets[token_idx * TOP_K + i] = gating::unassigned; + } } return; } @@ -44,34 +48,46 @@ __global__ void top_1_gating_kernel(int32_t* expert_counts, } else { reduce::init(&logit_val); } + float reduce_val = logit_val; + + int32_t local_assigned_experts[TOP_K]; + float local_assigned_logits[TOP_K]; // Training code tends to use ``torch.argmax`` to select the expert, which // which has ties broken by the lower index. Since our fused comparison algorithm // breaks ties by the higher index (since it's the lower 32-bits of the 64-bit // comparison), we invert the expert index to break ties by the lower index. int32_t inverted_expert = n_experts - expert_idx - 1; - // Perform softmax - const reduce::IdxReduceResult res = - reduce::idx_reduce(tb, warp, logit_val, inverted_expert); - // Recover the original expert index - const int32_t assigned_expert = n_experts - res.idx - 1; - const float max_logit = res.val; + // Find the top k logits + for (int i = 0; i < TOP_K; ++i) { + const reduce::IdxReduceResult res = + reduce::idx_reduce(tb, warp, reduce_val, inverted_expert); + local_assigned_experts[i] = n_experts - res.idx - 1; + local_assigned_logits[i] = res.val; + + // Set the max logit to -inf so that it is not selected again + if (threadIdx.x == n_experts - res.idx - 1) { reduce::init(&reduce_val); } + } + + const float max_logit = local_assigned_logits[0]; float softmax_sum = __expf(logit_val - max_logit); reduce::block(tb, warp, softmax_sum); - // Compute the score - const float score = __expf(max_logit - max_logit) / softmax_sum; + for (int i = 0; i < TOP_K; ++i) { + const float softmax = __expf(local_assigned_logits[i] - max_logit) / softmax_sum; - if (threadIdx.x == 0) { - scores[token_idx] = score; - assignments[token_idx] = assigned_expert; - offsets[token_idx] = atomicAdd(expert_counts + assigned_expert, 1); + if (threadIdx.x == 0) { + scores[token_idx * TOP_K + i] = softmax; + assignments[token_idx * TOP_K + i] = local_assigned_experts[i]; + offsets[token_idx * TOP_K + i] = + atomicAdd(expert_counts + local_assigned_experts[i], 1); + } } } template -void launch_top_1_gating(int32_t* expert_counts, +void launch_top_k_gating(int32_t* expert_counts, float* scores, int32_t* assignments, int32_t* offsets, @@ -79,17 +95,20 @@ void launch_top_1_gating(int32_t* expert_counts, const RaggedBatchDescriptor* batch_metadata, const int32_t n_tokens, const int32_t n_experts, + const int32_t n_top_k, cudaStream_t stream) { const dim3 grid(n_tokens); const dim3 block(((n_experts + hw_warp_size - 1) / hw_warp_size) * hw_warp_size); - top_1_gating_kernel<<>>( - expert_counts, scores, assignments, offsets, logits, batch_metadata, n_experts); + TOP_K_SWITCH(n_top_k, [&] { + top_k_gating_kernel<<>>( + expert_counts, scores, assignments, offsets, logits, batch_metadata, n_experts); + }); } -#define INSTANTIATE_TOP_1_KERNEL(T) \ - template void launch_top_1_gating(int32_t * expert_counts, \ +#define INSTANTIATE_top_k_KERNEL(T) \ + template void launch_top_k_gating(int32_t * expert_counts, \ float* scores, \ int32_t* assignments, \ int32_t* offsets, \ @@ -97,10 +116,10 @@ void launch_top_1_gating(int32_t* expert_counts, const RaggedBatchDescriptor* batch_metadata, \ const int32_t n_tokens, \ const int32_t n_experts, \ + const int32_t n_top_k, \ cudaStream_t stream); -INSTANTIATE_TOP_1_KERNEL(float) -INSTANTIATE_TOP_1_KERNEL(__half) +INSTANTIATE_top_k_KERNEL(float) INSTANTIATE_top_k_KERNEL(__half) #ifdef BF16_AVAILABLE -INSTANTIATE_TOP_1_KERNEL(__nv_bfloat16) + INSTANTIATE_top_k_KERNEL(__nv_bfloat16) #endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cuh b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cuh similarity index 87% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cuh rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cuh index c83ad56ff2f1..c525cc5f524e 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cuh +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cuh @@ -13,7 +13,7 @@ constexpr int unassigned = -1; } // namespace gating template -void launch_top_1_gating(int32_t* expert_counts, +void launch_top_k_gating(int32_t* expert_counts, float* scores, int32_t* assignments, int32_t* offsets, @@ -21,4 +21,5 @@ void launch_top_1_gating(int32_t* expert_counts, const RaggedBatchDescriptor* batch_metadata, const int32_t n_tokens, const int32_t n_experts, + const int32_t n_top_k, cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.h b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.h similarity index 86% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.h rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.h index b431f4cad30c..00840c3c93b5 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.h +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.h @@ -8,12 +8,12 @@ #include #include #include "ragged_dtypes.h" -#include "top_1_gating.cuh" +#include "top_k_gating.cuh" /* Perform softmax plus atomics to get token mapping. */ -void top_1_gating(torch::Tensor& expert_counts, +void top_k_gating(torch::Tensor& expert_counts, torch::Tensor& scores, torch::Tensor& assignments, torch::Tensor& offsets, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.py b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.py similarity index 87% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.py rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.py index 1df97c2e9f8d..72ba2b6019bb 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.py @@ -13,7 +13,7 @@ from deepspeed.ops.op_builder import RaggedOpsBuilder -class RaggedTop1Gating(DSKernelBase): +class RaggedTopKGating(DSKernelBase): """ CUDA implementation of top-1 gating. This will perform a softmax on the logits, and return the scale as well as its idx within that expert's allocation. @@ -26,28 +26,28 @@ def __init__(self, logit_dtype: DtypeEnum) -> None: if not isinstance(logit_dtype, DtypeEnum): logit_dtype = DtypeEnum(logit_dtype) - if logit_dtype not in RaggedTop1Gating.supported_logit_dtypes: + if logit_dtype not in RaggedTopKGating.supported_logit_dtypes: raise RuntimeError(f"Unsupported logit dtype {logit_dtype}") inf_module = RaggedOpsBuilder().load() - self.kernel = inf_module.top_1_gating + self.kernel = inf_module.top_k_gating def __call__(self, expert_counts: torch.Tensor, scores: torch.Tensor, assignments: torch.Tensor, offsets: torch.Tensor, logits: torch.Tensor, batch: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ - Perform the ragged top_1_gating. + Perform the ragged top_k_gating. Arguments: expert_counts (torch.Tensor): Tensor of 0s of shape [n_experts] to be filled with number of tokens assigned to each expert. This must be filled with 0s else the copy kernel will buffer overflow. In order to minimize the zero-fill cost, it is recommended to write to 0 during the MoE output remapping. - scores (torch.Tensor): Preallocated output of shape [n_tokens] to place expert scaling + scores (torch.Tensor): Preallocated output of shape [n_tokens, n_top_k] to place expert scaling value. - expert_assignment (torch.Tensor): Preallocated output of shape [n_tokens] to place + expert_assignment (torch.Tensor): Preallocated output of shape [n_tokens, n_top_k] to place which expert a token has been assigned to. - expert_offset (torch.Tensor): Preallocated output of shape [n_tokens] to place which + expert_offset (torch.Tensor): Preallocated output of shape [n_tokens, n_top_k] to place which offset within an experts group a token is. logits (torch.Tensor): Raw logits of gating function. batch (RaggedBatchWrapper): Batch information for ragged tensor. diff --git a/deepspeed/inference/v2/model_implementations/__init__.py b/deepspeed/inference/v2/model_implementations/__init__.py index 481be2e5940e..ab1f984fba7e 100644 --- a/deepspeed/inference/v2/model_implementations/__init__.py +++ b/deepspeed/inference/v2/model_implementations/__init__.py @@ -12,4 +12,5 @@ from .llama_v2 import * from .opt import * from .mistral import * +from .mixtral import * from .falcon import * diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py index df5f1427a5cf..8ababf567ba9 100644 --- a/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py +++ b/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py @@ -33,7 +33,7 @@ class UnfusedMoEMLP1Parameter(ParameterBase): and need to be joined into a single group. """ - experts: ParamList("num_experts") # noqa: F821 + experts: ParamList("n_experts") # noqa: F821 def finalize(self) -> torch.Tensor: stacked_experts = torch.stack([p for p in self.experts], dim=0) @@ -46,7 +46,7 @@ class UnfusedMoEMLP2Parameter(ParameterBase): and need to be joined into a single group. """ - experts: ParamList("num_experts") # noqa: F821 + experts: ParamList("n_experts") # noqa: F821 def finalize(self) -> torch.Tensor: stacked_experts = torch.stack([p for p in self.experts], dim=0) @@ -57,13 +57,22 @@ class UnfusedMoEGatedMLPParameter(ParameterBase): """ MoE Parameter for a gated activation function in which the gating matrix is not fused in the same parameter as the non-gating matrix. + + This is a stacked version of the ``GatedMLPParameter``. Please see that class for more + documentation on the layout of the parameters. """ - gating_experts: ParamList("num_experts") # noqa: F821 + gating_experts: ParamList("n_experts") # noqa: F821 - up_experts: ParamList("num_experts") # noqa: F821 + up_experts: ParamList("n_experts") # noqa: F821 def finalize(self) -> torch.Tensor: - fused_params = [torch.cat([gate, weight], dim=0) for gate, weight in zip(self.gating_experts, self.up_experts)] - stacked_params = torch.stack(fused_params, dim=0) - return self.inference_model.transform_moe_mlp_2_param(stacked_params) + transposed_experts = [] + for gate, up in zip(self.gating_experts, self.up_experts): + assert gate.shape[0] == up.shape[0], "Gated MLP parameters must have the same number of neurons." + total_neurons = gate.shape[0] + up.shape[0] + fused_expert = torch.cat([gate, up], dim=-1).reshape(total_neurons, -1) + transposed_experts.append(fused_expert) + + stacked_experts = torch.stack(transposed_experts, dim=0) + return self.inference_model.transform_moe_mlp_1_param(stacked_experts) diff --git a/deepspeed/inference/v2/model_implementations/falcon/__init__.py b/deepspeed/inference/v2/model_implementations/falcon/__init__.py index ff66879b44be..20f37538274c 100644 --- a/deepspeed/inference/v2/model_implementations/falcon/__init__.py +++ b/deepspeed/inference/v2/model_implementations/falcon/__init__.py @@ -3,4 +3,4 @@ # DeepSpeed Team -from .falcon_policy import FalconPolicy +from .policy import FalconPolicy diff --git a/deepspeed/inference/v2/model_implementations/falcon/falcon_containers.py b/deepspeed/inference/v2/model_implementations/falcon/container.py similarity index 97% rename from deepspeed/inference/v2/model_implementations/falcon/falcon_containers.py rename to deepspeed/inference/v2/model_implementations/falcon/container.py index f3cbe6609cdd..caccfe1ecb00 100644 --- a/deepspeed/inference/v2/model_implementations/falcon/falcon_containers.py +++ b/deepspeed/inference/v2/model_implementations/falcon/container.py @@ -5,8 +5,8 @@ # Create a container object to save model-specific tensors using the policy file above. -from ...model_implementations.common_parameters import * -from ...model_implementations.layer_container_base import LayerContainer +from ..common_parameters import * +from ..layer_container_base import LayerContainer ''' # HF Falcon 7b model looks like this: diff --git a/deepspeed/inference/v2/model_implementations/falcon/falcon_model.py b/deepspeed/inference/v2/model_implementations/falcon/model.py similarity index 98% rename from deepspeed/inference/v2/model_implementations/falcon/falcon_model.py rename to deepspeed/inference/v2/model_implementations/falcon/model.py index a00f754744a4..d1ccc38280a0 100644 --- a/deepspeed/inference/v2/model_implementations/falcon/falcon_model.py +++ b/deepspeed/inference/v2/model_implementations/falcon/model.py @@ -11,12 +11,12 @@ from ...allocator import empty_from from ...inference_utils import ActivationType, DtypeEnum -from ...model_implementations import * +from .. import * from ...modules.configs import * from ...modules.interfaces import * from ...ragged import RaggedBatchWrapper -from .falcon_containers import FalconNonTransformerContainer, FalconTransformerContainer +from .container import FalconNonTransformerContainer, FalconTransformerContainer class FalconInferenceModel(DSTransformerModelBase): diff --git a/deepspeed/inference/v2/model_implementations/falcon/falcon_policy.py b/deepspeed/inference/v2/model_implementations/falcon/policy.py similarity index 74% rename from deepspeed/inference/v2/model_implementations/falcon/falcon_policy.py rename to deepspeed/inference/v2/model_implementations/falcon/policy.py index 5672d45a8d13..c6612090a0df 100644 --- a/deepspeed/inference/v2/model_implementations/falcon/falcon_policy.py +++ b/deepspeed/inference/v2/model_implementations/falcon/policy.py @@ -6,10 +6,10 @@ from typing import Any from ...config_v2 import RaggedInferenceEngineConfig -from ...model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy -from ...model_implementations.falcon.falcon_containers import FalconNonTransformerContainer, FalconTransformerContainer -from ...model_implementations.falcon.falcon_containers import FalconNewArchTransformerContainer -from ...model_implementations.falcon.falcon_model import FalconInferenceModel +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import FalconNonTransformerContainer, FalconTransformerContainer +from .container import FalconNewArchTransformerContainer +from .model import FalconInferenceModel class FalconPolicy(InferenceV2Policy): diff --git a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py index 8f6a0b7fa688..e78a161b4cd0 100644 --- a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py +++ b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py @@ -521,12 +521,26 @@ def transform_norm_param(self, param: torch.Tensor) -> InferenceParameter: class DSMoETransformerModelBase(DSTransformerModelBase): @property - def num_experts(self) -> int: + def n_experts(self) -> int: """ Return the number of experts in the model. """ raise NotImplementedError("Attempted to access an unimplemented number of experts") + @property + def n_top_k(self) -> int: + """ + Number of experts per token. + """ + raise NotImplementedError("Attempted to access an unimplemented number of experts per token") + + @property + def normalize_expert_scores(self) -> bool: + """ + Whether to normalize expert scores. If true, sum(expert_scores) = 1. + """ + raise NotImplementedError("Attempted to access an unimplemented normalization flag") + def make_moe_layer(self) -> None: """ Instantiates the MoE layer for the model. This sets the `self.moe` attribute. @@ -538,9 +552,11 @@ def make_moe_layer(self) -> None: model_dim=self.model_dim, intermediate_features=sharded_dim, activation=self.mlp_activation_fn, - n_experts=self.num_experts, + n_experts=self.n_experts, + top_k=self.n_top_k, input_dtype=self.activation_dtype, output_dtype=self.activation_dtype, + normalize_scores=self.normalize_expert_scores, ) self.moe = heuristics.instantiate_moe(moe_config, self._engine_config) diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py b/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py index 5d2b5ae562ee..79605a76a4c2 100644 --- a/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py +++ b/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py @@ -3,4 +3,4 @@ # DeepSpeed Team -from .llama_v2_policy import Llama2Policy +from .policy import Llama2Policy diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_containers.py b/deepspeed/inference/v2/model_implementations/llama_v2/container.py similarity index 95% rename from deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_containers.py rename to deepspeed/inference/v2/model_implementations/llama_v2/container.py index e9c473ce512b..9de9bdb34574 100644 --- a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_containers.py +++ b/deepspeed/inference/v2/model_implementations/llama_v2/container.py @@ -5,8 +5,8 @@ # Create a container object to save model-specific tensors using the policy file above. -from ...model_implementations.common_parameters import * -from ...model_implementations.layer_container_base import LayerContainer +from ..common_parameters import * +from ..layer_container_base import LayerContainer ''' # HF Llama model looks like this: diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_model.py b/deepspeed/inference/v2/model_implementations/llama_v2/model.py similarity index 83% rename from deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_model.py rename to deepspeed/inference/v2/model_implementations/llama_v2/model.py index 9b628f77de01..b91e3258caa0 100644 --- a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_model.py +++ b/deepspeed/inference/v2/model_implementations/llama_v2/model.py @@ -11,12 +11,13 @@ from ...allocator import empty_from from ...inference_utils import ActivationType, DtypeEnum -from ...model_implementations import * +from .. import * from ...modules.configs import * from ...modules.interfaces import * +from ...modules import heuristics from ...ragged import RaggedBatchWrapper -from .llama_v2_containers import Llama2NonTransformerContainer, Llama2TransformerContainer +from .container import Llama2NonTransformerContainer, Llama2TransformerContainer class Llama2InferenceModel(DSTransformerModelBase): @@ -105,6 +106,27 @@ def norm_type(self) -> NormTypeEnum: def positional_embedding_type(self) -> PositionalEmbeddingType: return PositionalEmbeddingType.rotate_half + def make_attn_layer(self) -> None: + """ + Builds the attention layer for the model. This sets the `self.attn` attribute. + """ + softmax_scale = 1.0 / (self.head_size**0.5) + + rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta) + + attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + n_heads_q=self.n_heads_q_local, + n_heads_kv=self.n_heads_kv_local, + head_size=self.head_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + scale_factor=softmax_scale, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + positional_embedding_type=self.positional_embedding_type, + positional_embedding_config=rotary_config) + + self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + """ Forward implementations """ @@ -145,8 +167,7 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid kv_cache = self.state_manager.get_cache(layer_idx) hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None) - hidden_states = self.attn(hidden_states, kv_cache, - ragged_batch_info) #, inv_freqs=None) #cur_params.rotary_emb) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) if self.tp_size > 1: diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_policy.py b/deepspeed/inference/v2/model_implementations/llama_v2/policy.py similarity index 76% rename from deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_policy.py rename to deepspeed/inference/v2/model_implementations/llama_v2/policy.py index c8253be79fad..bb13ab6d5bf4 100644 --- a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_policy.py +++ b/deepspeed/inference/v2/model_implementations/llama_v2/policy.py @@ -6,9 +6,9 @@ from typing import Any from ...config_v2 import RaggedInferenceEngineConfig -from ...model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy -from ...model_implementations.llama_v2.llama_v2_containers import Llama2NonTransformerContainer, Llama2TransformerContainer -from ...model_implementations.llama_v2.llama_v2_model import Llama2InferenceModel +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import Llama2NonTransformerContainer, Llama2TransformerContainer +from .model import Llama2InferenceModel class Llama2Policy(InferenceV2Policy): diff --git a/deepspeed/inference/v2/model_implementations/mistral/model.py b/deepspeed/inference/v2/model_implementations/mistral/model.py index d9b06b91e308..08a9dae78e43 100644 --- a/deepspeed/inference/v2/model_implementations/mistral/model.py +++ b/deepspeed/inference/v2/model_implementations/mistral/model.py @@ -14,6 +14,7 @@ from ...model_implementations import * from ...modules.configs import * from ...modules.interfaces import * +from ...modules import heuristics from ...ragged import RaggedBatchWrapper from .container import MistralNonTransformerContainer, MistralTransformerContainer @@ -104,6 +105,27 @@ def norm_type(self) -> NormTypeEnum: def positional_embedding_type(self) -> PositionalEmbeddingType: return PositionalEmbeddingType.rotate_half + def make_attn_layer(self) -> None: + """ + Builds the attention layer for the model. This sets the `self.attn` attribute. + """ + softmax_scale = 1.0 / (self.head_size**0.5) + + rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta) + + attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + n_heads_q=self.n_heads_q_local, + n_heads_kv=self.n_heads_kv_local, + head_size=self.head_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + scale_factor=softmax_scale, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + positional_embedding_type=self.positional_embedding_type, + positional_embedding_config=rotary_config) + + self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + """ Forward implementations """ @@ -144,8 +166,7 @@ def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_st kv_cache = self.state_manager.get_cache(layer_idx) hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None) - hidden_states = self.attn(hidden_states, kv_cache, - ragged_batch_info) #, inv_freqs=None) #cur_params.rotary_emb) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) if self.tp_size > 1: diff --git a/deepspeed/inference/v2/model_implementations/mistral/policy.py b/deepspeed/inference/v2/model_implementations/mistral/policy.py index f6d0a0fe5987..b67ec311c952 100644 --- a/deepspeed/inference/v2/model_implementations/mistral/policy.py +++ b/deepspeed/inference/v2/model_implementations/mistral/policy.py @@ -5,10 +5,10 @@ from typing import Any -from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig -from deepspeed.inference.v2.model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy -from deepspeed.inference.v2.model_implementations.mistral.container import MistralNonTransformerContainer, MistralTransformerContainer -from deepspeed.inference.v2.model_implementations.mistral.model import MistralInferenceModel +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import MistralNonTransformerContainer, MistralTransformerContainer +from .model import MistralInferenceModel class MistralPolicy(InferenceV2Policy): diff --git a/deepspeed/inference/v2/model_implementations/mixtral/__init__.py b/deepspeed/inference/v2/model_implementations/mixtral/__init__.py new file mode 100644 index 000000000000..2cb1aa889291 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import MixtralPolicy diff --git a/deepspeed/inference/v2/model_implementations/mixtral/container.py b/deepspeed/inference/v2/model_implementations/mixtral/container.py new file mode 100644 index 000000000000..6ec4a0552b8f --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/container.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from deepspeed.inference.v2.model_implementations.common_parameters import * +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + + +class MixtralTransformerContainer(LayerContainer): + + qkv_w: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + moe_gate: MoEGatingWeightParameter + moe_mlp_1: UnfusedMoEGatedMLPParameter + moe_mlp_2: UnfusedMoEMLP2Parameter + attn_norm_gamma: NormParameter + mlp_norm_gamma: NormParameter + + PARAM_MAPPING = { + "input_layernorm.weight": "attn_norm_gamma.params", + "post_attention_layernorm.weight": "mlp_norm_gamma.params", + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.o_proj.weight": "attn_out_w.params", + "block_sparse_moe.gate.weight": "moe_gate.params", + "block_sparse_moe.experts.*.w1.weight": "moe_mlp_1.gating_experts", + "block_sparse_moe.experts.*.w3.weight": "moe_mlp_1.up_experts", + "block_sparse_moe.experts.*.w2.weight": "moe_mlp_2.experts", + } + + +class MixtralNonTransformerContainer(LayerContainer): + + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "lm_head.weight": "word_unembed.params", + "model.norm.weight": "final_norm.params", + } diff --git a/deepspeed/inference/v2/model_implementations/mixtral/model.py b/deepspeed/inference/v2/model_implementations/mixtral/model.py new file mode 100644 index 000000000000..731a907716f4 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/model.py @@ -0,0 +1,274 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...config_v2 import RaggedInferenceEngineConfig +from ...inference_utils import ActivationType, DtypeEnum +from ...model_implementations import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...modules import heuristics +from ...ragged import RaggedBatchWrapper +from ..inference_model_base import ( + DSModelImplementationConfig, + MPType, +) + +from .container import MixtralNonTransformerContainer, MixtralTransformerContainer + + +class MixtralInferenceModel(DSMoETransformerModelBase): + """ + Inference model implementation for Mixtral models. + """ + + _non_transformer: Optional[MixtralNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[MixtralTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties ineherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_position_embeddings + + """ + Properties ineherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + activation = self._config.hidden_act.lower() + if activation == "gelu": + return ActivationType.GEGLU + elif activation == "relu": + return ActivationType.ReGLU + elif activation == "gegelu": + return ActivationType.GEGLU + elif activation == "silu": + return ActivationType.SiGLU + else: + raise NotImplementedError(f"Activation {activation} not supported") + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.RMSNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + """ + Inherited from `DSMoETransformerModelBase` + """ + + @property + def n_experts(self) -> int: + return self._config.num_local_experts + + @property + def n_top_k(self) -> int: + return self._config.num_experts_per_tok + + @property + def normalize_expert_scores(self) -> bool: + return True + + """ + Model implementation + """ + + def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig, + base_mp_group: MPType) -> None: + """ + Base implementation for initialization. By default, this will initialize + the traditional components of a transformer model: + - Embedding + - QKV projection + - Self attention + - Attention output projection + - Feed forward network + - Normalization + - Unembedding + + Arguments: + config (DSModelImplementationConfig): Model-specific configuration. No assumptions + should be made about this config that are not closely tied to the specific + model implementation. + engine_config (RaggedInferenceEngineConfig): Engine configuration. + base_mp_group (MPType): Base communication group for Tensor-parallel inference. + """ + super().__init__(config, engine_config, base_mp_group) + + self.make_norm_layer() + self.make_qkv_layer() + self.make_attn_layer() + self.make_attn_out_layer() + self.make_moe_layer() + self.make_embedding_layer() + self.make_unembedding_layer() + self._kv_cache_config = None + + def make_attn_layer(self) -> None: + """ + Builds the attention layer for the model. This sets the `self.attn` attribute. + """ + softmax_scale = 1.0 / (self.head_size**0.5) + + rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta) + + attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + n_heads_q=self.n_heads_q_local, + n_heads_kv=self.n_heads_kv_local, + head_size=self.head_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + scale_factor=softmax_scale, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + positional_embedding_type=self.positional_embedding_type, + positional_embedding_config=rotary_config) + + self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + # TODO(cmikeh2): Distribute ragged_batch_info to all modules + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma) + + hidden_states = self.moe(hidden_states, ragged_batch_info, cur_params.moe_gate, cur_params.moe_mlp_1, + cur_params.moe_mlp_2) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info, + self._non_transformer.final_norm) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer(layer_idx, residual, hidden_states, wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/mixtral/policy.py b/deepspeed/inference/v2/model_implementations/mixtral/policy.py new file mode 100644 index 000000000000..2f0087919720 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/policy.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import MixtralTransformerContainer, MixtralNonTransformerContainer +from .model import MixtralInferenceModel + + +class MixtralPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> MixtralInferenceModel: + return MixtralInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + + map = ContainerMap() + + transformer_containers = [MixtralTransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(MixtralNonTransformerContainer(self.model)) + + map.set_unmapped_params([]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/opt/container.py b/deepspeed/inference/v2/model_implementations/opt/container.py index 5ddbbde3f141..e97599ef8e50 100644 --- a/deepspeed/inference/v2/model_implementations/opt/container.py +++ b/deepspeed/inference/v2/model_implementations/opt/container.py @@ -5,8 +5,8 @@ # Create a container object to save model-specific tensors using the policy file above. -from ...model_implementations.common_parameters import * -from ...model_implementations.layer_container_base import LayerContainer +from ..common_parameters import * +from ..layer_container_base import LayerContainer ''' # HF OPT model looks like this: diff --git a/deepspeed/inference/v2/model_implementations/opt/model.py b/deepspeed/inference/v2/model_implementations/opt/model.py index fa221e15a0b7..8bd26ba044e5 100644 --- a/deepspeed/inference/v2/model_implementations/opt/model.py +++ b/deepspeed/inference/v2/model_implementations/opt/model.py @@ -131,8 +131,7 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid kv_cache = self.state_manager.get_cache(layer_idx) hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b) - hidden_states = self.attn(hidden_states, kv_cache, - ragged_batch_info) #, inv_freqs=None) #cur_params.rotary_emb) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=cur_params.attn_out_b) if self.tp_size > 1: diff --git a/deepspeed/inference/v2/model_implementations/opt/policy.py b/deepspeed/inference/v2/model_implementations/opt/policy.py index af5750260ead..d57d5beb48d5 100644 --- a/deepspeed/inference/v2/model_implementations/opt/policy.py +++ b/deepspeed/inference/v2/model_implementations/opt/policy.py @@ -6,9 +6,9 @@ from typing import Any from ...config_v2 import RaggedInferenceEngineConfig -from ...model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy -from ...model_implementations.opt.container import OPTNonTransformerContainer, OPTTransformerContainer -from ...model_implementations.opt.model import OPTInferenceModel +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import OPTNonTransformerContainer, OPTTransformerContainer +from .model import OPTInferenceModel class OPTPolicy(InferenceV2Policy): diff --git a/deepspeed/inference/v2/modules/configs/__init__.py b/deepspeed/inference/v2/modules/configs/__init__.py index 19b9fb99ddea..3429e69b47de 100644 --- a/deepspeed/inference/v2/modules/configs/__init__.py +++ b/deepspeed/inference/v2/modules/configs/__init__.py @@ -3,7 +3,12 @@ # DeepSpeed Team -from .attention_configs import (DSSelfAttentionConfig, PositionalEmbeddingType, MaskingType) +from .attention_configs import ( + DSSelfAttentionConfig, + PositionalEmbeddingType, + MaskingType, + RotateHalfConfig, +) from .embedding_config import DSEmbeddingsConfig from .linear_config import DSLinearConfig from .moe_config import DSMoEConfig diff --git a/deepspeed/inference/v2/modules/configs/attention_configs.py b/deepspeed/inference/v2/modules/configs/attention_configs.py index bcdc3d2613d5..823104b13fc2 100644 --- a/deepspeed/inference/v2/modules/configs/attention_configs.py +++ b/deepspeed/inference/v2/modules/configs/attention_configs.py @@ -4,10 +4,11 @@ # DeepSpeed Team from enum import Enum -from typing import Dict +from typing import Dict, Optional from ...inference_utils import DtypeEnum from ...modules.ds_module import DSModuleConfig +from deepspeed.runtime.config_utils import DeepSpeedConfigModel class PositionalEmbeddingType(Enum): @@ -25,6 +26,20 @@ class PositionalEmbeddingType(Enum): alibi = "alibi" +class RotateHalfConfig(DeepSpeedConfigModel): + + use_trained_freqs: bool = False + """ + Whether to use a passed `trained_freqs` tensor for the attention implementation + or to use default synthesized frequencies. + """ + + theta_base: float = 10_000.0 + """ + Base for theta. This will only be used if `use_trained_freqs` is False. + """ + + class MaskingType(Enum): # No masking @@ -79,4 +94,9 @@ class DSSelfAttentionConfig(DSModuleConfig): positional_embedding_type: PositionalEmbeddingType = PositionalEmbeddingType.none # Positional embedding args - positional_embedding_args: Dict = {} + positional_embedding_config: Optional[RotateHalfConfig] = None + """ + To extend this for the other positional embedding types, we would need to add + new configs for each type (as necessary) and annotate this with the + Union[RotateHalfConfig, OtherConfig, ...] type. + """ diff --git a/deepspeed/inference/v2/modules/configs/moe_config.py b/deepspeed/inference/v2/modules/configs/moe_config.py index 1a88d54af19f..7bc944f55e17 100644 --- a/deepspeed/inference/v2/modules/configs/moe_config.py +++ b/deepspeed/inference/v2/modules/configs/moe_config.py @@ -48,3 +48,9 @@ class DSMoEConfig(DSModuleConfig): """ Activation function of the first MLP1 """ + + normalize_scores: bool = False + """ + Whether normalization is applied to the selected scores. If true, the module + should rescale the scores such that their sum is 1.0. + """ diff --git a/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py b/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py index bb482f0c58d6..b2727ffca620 100644 --- a/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py +++ b/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py @@ -68,9 +68,16 @@ def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[st Args: config (DSSelfAttentionConfig): The self attention config for all attention DSModules. - implementation_config (Dict[str, Any]): The implementation config for this DSModule may - contain a `trained_freqs` key. If passed, the implementation will expect a `trained_freqs` - tensor in the `forward` method and will not synthesize the frequencies internally. + implementation_config (Dict[str, Any]): + There are two (dependent) potential components in the implementtion config. + + 1. `trained_freqs` - If the embedding weights for RoPE are trained, the implementation + config should contain {'trained_freqs': True}. This will mean the implementation will + expect a `trained_freqs` tensor in the `forward` method and will not synthesize the + values internally. + + 2. `theta_base` - The base value for synthesized frequencies in the rotary embeddings. + This will only be used if `trained_freqs` is False or not present in the `implementation_config`. If this is not included, the default value of 10000.0 will be used. """ super().__init__(config, implementation_config) @@ -79,14 +86,13 @@ def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[st self._kv_copy = LinearBlockedKVCopy(self._config.head_size, self._config.n_heads_q, self._config.n_heads_kv, self._config.input_dtype) elif embed_type == PositionalEmbeddingType.rotate_half: - use_trained_freqs = "trained_freqs" in self._config.positional_embedding_args and self._config.positional_embedding_args[ - "trained_freqs"] - if use_trained_freqs: + if config.positional_embedding_config.use_trained_freqs: self._kv_copy = BlockedTrainedRotaryEmbeddings(self._config.head_size, self._config.n_heads_q, self._config.n_heads_kv, self._config.input_dtype) else: + theta_base = config.positional_embedding_config.theta_base self._kv_copy = BlockedRotaryEmbeddings(self._config.head_size, self._config.n_heads_q, - self._config.n_heads_kv, self._config.input_dtype) + self._config.n_heads_kv, self._config.input_dtype, theta_base) self._softmax_scale = self._config.scale_factor diff --git a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py index e43a737515ed..38c0000d7f78 100644 --- a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py +++ b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py @@ -9,12 +9,12 @@ from deepspeed.accelerator import get_accelerator from ....allocator import empty_from -from ....inference_utils import ActivationType -from ....kernels.core_ops import BlasLibLinear +from ....inference_utils import ActivationType, is_gated +from ....kernels.core_ops import BlasLibLinear, CUDAGatedActivation from ....kernels.ragged_ops import ( MoEGather, MoEScatter, - RaggedTop1Gating, + RaggedTopKGating, ) from ....ragged import RaggedBatchWrapper @@ -42,11 +42,7 @@ def supports_config(config: DSMoEConfig) -> bool: if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16: return False - if config.top_k != 1: - return False - - if config.activation in [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU]: - # Currently not supporting gated activations in MoE + if config.top_k != 1 and config.top_k != 2: return False return True @@ -57,15 +53,24 @@ def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) - # Convenience variables for frequently accessed items. self.max_tokens = self._config.max_tokens self.n_experts = self._config.n_experts + self.n_top_k = self._config.top_k self.intermediate_dim = self._config.intermediate_features - self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=config.activation) + moe_op_act_fn = ActivationType.IDENTITY if is_gated(self._config.activation) else self._config.activation + + self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=moe_op_act_fn) self._mlp_2 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=ActivationType.IDENTITY) + if is_gated(self._config.activation): + self._activation = CUDAGatedActivation(self._config.model_dim, self._config.input_dtype, + self._config.activation) + else: + self._activation = None + self._gate_proj = BlasLibLinear(self._config.input_dtype) - self._top_1_gate = RaggedTop1Gating(config.input_dtype) + self._top_1_gate = RaggedTopKGating(config.input_dtype) self._moe_scatter = MoEScatter(config.input_dtype, config.model_dim) - self._moe_gather = MoEGather(config.input_dtype, config.model_dim) + self._moe_gather = MoEGather(config.input_dtype, config.model_dim, config.normalize_scores) self._create_buffers() @@ -78,32 +83,38 @@ def _create_buffers(self): self._expert_counts = torch.empty((self.n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - self._scores = torch.empty((self._config.max_tokens, ), + self._scores = torch.empty((self._config.max_tokens, self.n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) - self._assignments = torch.empty((self._config.max_tokens, ), + self._assignments = torch.empty((self._config.max_tokens, self.n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) - self._offsets = torch.empty((self._config.max_tokens, ), + self._offsets = torch.empty((self._config.max_tokens, self.n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) # Scatter buffers - self._moe_input = torch.empty((self._config.max_tokens, self._config.model_dim), + self._moe_input = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim), dtype=self._config.input_dtype, device=get_accelerator().current_device()) self._expert_cumsum = torch.empty((self._config.n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) - self._mapped_slots = torch.empty((self._config.max_tokens, ), + self._mapped_slots = torch.empty((self._config.max_tokens, self.n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) # GEMM Buffers - self._intermediate = torch.empty((self._config.max_tokens, self._config.intermediate_features), + self._intermediate = torch.empty((self._config.max_tokens * self.n_top_k, self._config.intermediate_features), dtype=self._config.output_dtype, device=get_accelerator().current_device()) - self._output_unordered = torch.empty((self._config.max_tokens, self._config.model_dim), + if self._activation is not None: + self._gated_intermediate = torch.empty( + (self._config.max_tokens * self.n_top_k, self._config.intermediate_features * 2), + dtype=self._config.output_dtype, + device=get_accelerator().current_device()) + + self._output_unordered = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim), dtype=self._config.output_dtype, device=get_accelerator().current_device()) @@ -167,11 +178,11 @@ def _gate(self, hidden_states: torch.Tensor, batch_metadata: RaggedBatchWrapper, # Get views on the buffers for gating logits = empty_from(self._logits, (hidden_states.shape[0], self._logits.shape[-1])) - scores = empty_from(self._scores, (hidden_states.shape[0], )) - assignments = empty_from(self._assignments, (hidden_states.shape[0], )) - offsets = empty_from(self._offsets, (hidden_states.shape[0], )) - mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], )) - moe_input = empty_from(self._moe_input, (hidden_states.shape[0], self._moe_input.shape[-1])) + scores = empty_from(self._scores, (hidden_states.shape[0], self.n_top_k)) + assignments = empty_from(self._assignments, (hidden_states.shape[0], self.n_top_k)) + offsets = empty_from(self._offsets, (hidden_states.shape[0], self.n_top_k)) + mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], self.n_top_k)) + moe_input = empty_from(self._moe_input, (hidden_states.shape[0] * self.n_top_k, self._moe_input.shape[-1])) self._gate_proj(logits, hidden_states, gate_w) self._expert_counts.zero_() @@ -200,18 +211,31 @@ def forward(self, moe_input, expert_cumsum, scores, mapped_slots = self._gate(hidden_states, batch_metadata, gate_w) # Get views on the buffers for GEMM - intermediate = empty_from(self._intermediate, (hidden_states.shape[0], self._intermediate.shape[-1])) + intermediate = empty_from(self._intermediate, + (hidden_states.shape[0] * self.n_top_k, self._intermediate.shape[-1])) output_unordered = empty_from(self._output_unordered, - (hidden_states.shape[0], self._output_unordered.shape[-1])) + (hidden_states.shape[0] * self.n_top_k, self._output_unordered.shape[-1])) output = empty_from(self._output, (hidden_states.shape[0], self._output.shape[-1])) - self._mlp_1( - intermediate, - moe_input, - mlp_1_w, - expert_cumsum, - mlp_1_b, - ) + if self._activation is not None: + gated_intermediate = empty_from( + self._gated_intermediate, (hidden_states.shape[0] * self.n_top_k, self._gated_intermediate.shape[-1])) + self._mlp_1( + gated_intermediate, + moe_input, + mlp_1_w, + expert_cumsum, + mlp_1_b, + ) + self._activation(intermediate, gated_intermediate) + else: + self._mlp_1( + intermediate, + moe_input, + mlp_1_w, + expert_cumsum, + mlp_1_b, + ) self._mlp_2( output_unordered, diff --git a/op_builder/ragged_ops.py b/op_builder/ragged_ops.py index 13d71b476b5a..8cb372e96c37 100644 --- a/op_builder/ragged_ops.py +++ b/op_builder/ragged_ops.py @@ -73,8 +73,8 @@ def sources(self): "inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp", "inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu", "inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.cpp", - "inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp", - "inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu", + "inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp", + "inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu", ] prefix = self.get_prefix() @@ -101,12 +101,13 @@ def include_paths(self): 'inference/v2/kernels/ragged_ops/atom_builder', 'inference/v2/kernels/ragged_ops/blocked_flash', 'inference/v2/kernels/ragged_ops/embed', + 'inference/v2/kernels/ragged_ops/includes', 'inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary', 'inference/v2/kernels/ragged_ops/logits_gather', 'inference/v2/kernels/ragged_ops/moe_gather', 'inference/v2/kernels/ragged_ops/moe_scatter', 'inference/v2/kernels/ragged_ops/ragged_helpers', - 'inference/v2/kernels/ragged_ops/top_1_gating', + 'inference/v2/kernels/ragged_ops/top_k_gating', ] prefix = self.get_prefix() diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py index 5fa375b49c19..3907fc3e3a4b 100644 --- a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py @@ -11,18 +11,28 @@ from deepspeed.inference.v2.kernels.ragged_ops import ( MoEGather, MoEScatter, - RaggedTop1Gating, + RaggedTopKGating, ) from .ragged_testing_utils import build_simple_batch """ -For simplicity's sake, these tests do rely on ``RaggedTop1Gating`` and +For simplicity's sake, these tests do rely on ``RaggedTopKGating`` and ``MoEScatter`` to produce correct inputs. If either of these kernels is broken these tests will fail, so double check the unit test results there before debugging here. """ +TEST_CASES = [ + # (n_tokens, n_experts, n_top_k) + (13, 64, 1), + (278, 64, 1), + (1977, 64, 1), + (13, 8, 2), + (278, 8, 2), + (1977, 8, 2), +] -def build_inputs(n_tokens, n_experts, do_padding): + +def build_inputs(n_tokens: int, n_experts: int, n_top_k: int, do_padding: bool): assert n_tokens <= 2048, "This test will break if n_tokens > 2048" @@ -39,22 +49,28 @@ def build_inputs(n_tokens, n_experts, do_padding): device=get_accelerator().current_device()).repeat_interleave(4096, dim=0).reshape( batch.tensor_toks, 4096).contiguous() - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) # Gating outputs expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((batch.tensor_toks, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((batch.tensor_toks, ), + scores = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) # Scatter outputs - moe_input = torch.empty((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + moe_input = torch.empty((batch.tensor_toks * n_top_k, 4096), + dtype=torch.float16, + device=get_accelerator().current_device()) expert_cumsum = torch.empty((n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) - mapped_slots = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + mapped_slots = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) scatter = MoEScatter(DtypeEnum.fp16, 4096) scatter(moe_input, expert_cumsum, mapped_slots, hidden_states, expert_counts, expert_assignment, expert_offset) @@ -63,11 +79,12 @@ def build_inputs(n_tokens, n_experts, do_padding): @pytest.mark.inference_v2_ops -@pytest.mark.parametrize("n_tokens, n_experts", [(13, 64), (278, 64), (1977, 64)]) -@pytest.mark.parametrize("do_padding", [True, False]) -def test_moe_gather(n_tokens, n_experts, do_padding): +@pytest.mark.parametrize("n_tokens, n_experts, n_top_k", TEST_CASES) +@pytest.mark.parametrize("do_padding", [False]) +def test_moe_gather(n_tokens: int, n_experts: int, n_top_k: int, do_padding: bool): + get_accelerator().manual_seed(0xC0FFEE) - batch, moe_input, scores, mapped_slots, expert_counts = build_inputs(n_tokens, n_experts, do_padding) + batch, moe_input, scores, mapped_slots, expert_counts = build_inputs(n_tokens, n_experts, n_top_k, do_padding) output = torch.randn((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) @@ -75,9 +92,31 @@ def test_moe_gather(n_tokens, n_experts, do_padding): gather(output, moe_input, scores, mapped_slots, expert_counts) for token_idx in range(n_tokens): + effective_score = scores[token_idx].sum().item() assert torch.equal( output[token_idx], torch.full((4096, ), - token_idx * scores[token_idx], + token_idx * effective_score, dtype=torch.float16, device=get_accelerator().current_device())) + + +@pytest.mark.inference_v2_ops +def test_moe_gather_normalize_scales(): + get_accelerator().manual_seed(0xC0FFEE) + + n_tokens = 72 + n_experts = 8 + n_top_k = 2 + do_padding = False + + batch, moe_input, scores, mapped_slots, expert_counts = build_inputs(n_tokens, n_experts, n_top_k, do_padding) + output = torch.randn((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + + gather = MoEGather(DtypeEnum.fp16, 4096, normalize_scores=True) + gather(output, moe_input, scores, mapped_slots, expert_counts) + + for token_idx in range(n_tokens): + assert torch.equal( + output[token_idx], + torch.full((4096, ), token_idx, dtype=torch.float16, device=get_accelerator().current_device())) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py index 4ca051410c1c..aae459f06a6f 100644 --- a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py @@ -8,19 +8,28 @@ from deepspeed.accelerator import get_accelerator from deepspeed.inference.v2.inference_utils import DtypeEnum -from deepspeed.inference.v2.kernels.ragged_ops import MoEScatter, RaggedTop1Gating +from deepspeed.inference.v2.kernels.ragged_ops import MoEScatter, RaggedTopKGating from .ragged_testing_utils import build_simple_batch """ -For simplicity's sake, these tests do rely on ``RaggedTop1Gating`` to produce correct -inputs. If ``RaggedTop1Gating`` is broken, these tests will fail, so double check +For simplicity's sake, these tests do rely on ``RaggedTopKGating`` to produce correct +inputs. If ``RaggedTopKGating`` is broken, these tests will fail, so double check the unit test results there before debugging here. """ +TEST_CONFIGS = [ + (13, 64, 1), + (278, 64, 1), + (1977, 64, 1), + (13, 8, 2), + (278, 8, 2), + (1977, 8, 2), +] + @pytest.mark.inference_v2_ops -@pytest.mark.parametrize("n_tokens, n_experts", [(13, 64), (278, 64), (1977, 64)]) -@pytest.mark.parametrize("do_padding", [True, False]) -def test_moe_scatter(n_tokens, n_experts, do_padding): +@pytest.mark.parametrize("n_tokens, n_experts, n_top_k", TEST_CONFIGS) +@pytest.mark.parametrize("do_padding", [False, True]) +def test_moe_scatter(n_tokens, n_experts, n_top_k, do_padding): # Sequence composition shouldn't matter here batch = build_simple_batch([n_tokens], padding=do_padding) @@ -35,40 +44,52 @@ def test_moe_scatter(n_tokens, n_experts, do_padding): device=get_accelerator().current_device()).repeat_interleave(4096, dim=0).reshape( batch.tensor_toks, 4096).contiguous() - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) # Gating outputs expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((batch.tensor_toks, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((batch.tensor_toks, ), + scores = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) # Scatter outputs - moe_input = torch.empty((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + moe_input = torch.empty((batch.tensor_toks * n_top_k, 4096), + dtype=torch.float16, + device=get_accelerator().current_device()) expert_cumsum = torch.empty((n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) - mapped_slots = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + mapped_slots = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) scatter = MoEScatter(DtypeEnum.fp16, 4096) scatter(moe_input, expert_cumsum, mapped_slots, hidden_states, expert_counts, expert_assignment, expert_offset) + get_accelerator().synchronize() assert torch.equal(expert_cumsum, torch.cumsum(expert_counts, dim=0).to(torch.int64)) + if not do_padding: + assert torch.unique(mapped_slots).size(0) == n_top_k * n_tokens + for token_idx in range(batch.tensor_toks): if token_idx < n_tokens: - expert_idx = expert_assignment[token_idx].item() - if expert_idx == 0: - expert_cumsum_val = 0 - else: - expert_cumsum_val = expert_cumsum[expert_idx - 1] - offset = expert_offset[token_idx] - total_offset = offset + expert_cumsum_val - - assert total_offset == mapped_slots[token_idx].item() - assert torch.equal(moe_input[total_offset], hidden_states[token_idx]) + for k in range(n_top_k): + expert_idx = expert_assignment[token_idx][k].item() + if expert_idx == 0: + expert_cumsum_val = 0 + else: + expert_cumsum_val = expert_cumsum[expert_idx - 1] + offset = expert_offset[token_idx][k] + total_offset = offset + expert_cumsum_val + + assert total_offset == mapped_slots[token_idx][k].item() + assert torch.equal(moe_input[total_offset], hidden_states[token_idx]) else: - assert mapped_slots[token_idx].item() == -1 + for k in range(n_top_k): + assert mapped_slots[token_idx][k].item() == -1 - assert expert_cumsum[-1] == n_tokens + assert expert_cumsum[-1] == n_tokens * n_top_k diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_top_1_gating.py b/tests/unit/inference/v2/kernels/ragged_ops/test_top_k_gating.py similarity index 51% rename from tests/unit/inference/v2/kernels/ragged_ops/test_top_1_gating.py rename to tests/unit/inference/v2/kernels/ragged_ops/test_top_k_gating.py index 6ff2508bf320..5fa0c8a079f0 100644 --- a/tests/unit/inference/v2/kernels/ragged_ops/test_top_1_gating.py +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_top_k_gating.py @@ -9,9 +9,52 @@ from deepspeed.accelerator import get_accelerator from deepspeed.inference.v2.inference_utils import DtypeEnum -from deepspeed.inference.v2.kernels.ragged_ops import RaggedTop1Gating +from deepspeed.inference.v2.kernels.ragged_ops import RaggedTopKGating from .ragged_testing_utils import build_simple_batch -from ....v2.inference_test_utils import allclose +from ...inference_test_utils import allclose + + +def _top_k_gating_testing_helper(n_tokens: int, n_experts: int, n_top_k: int, seed: int = 0xC0FFEE) -> None: + + torch.manual_seed(seed) + logits = torch.randn((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + gate = RaggedTopKGating(DtypeEnum.fp16) + + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + ref_weights = F.softmax(logits, dim=-1, dtype=torch.float32) + ref_scores, ref_indices = torch.topk(ref_weights, n_top_k, dim=-1) + + assert allclose(scores, ref_scores), f"expected {ref_scores}, got {scores}" + assert torch.equal(expert_assignment, + ref_indices.to(torch.int32)), f"expected {ref_indices}, got {expert_assignment}" + assert expert_counts.sum( + ) == n_tokens * n_top_k, f"expected {n_tokens * n_top_k} tokens, got {expert_counts.sum()}" + + # Ensure that the expert offsets are unique + for i in range(n_experts): + expert_idxs = torch.where(expert_assignment == i, expert_offset, 0) + if expert_counts[i] > 0: + assert expert_idxs.unique().shape[0] == expert_counts[ + i], f"expected {expert_counts[i]} unique offsets, got {expert_idxs.unique().shape[0]}" + assert expert_idxs.max( + ) == expert_counts[i] - 1, f"expected max offset {expert_counts[i] - 1}, got {expert_idxs.max()}" + else: + # Should have all 0's so one unique value + assert expert_idxs.unique().shape[0] == 1 + assert expert_idxs.max() == 0 + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('n_tokens', [1, 17, 32, 89, 433]) +def test_top_2_e_8_gating(n_tokens: int) -> None: + _top_k_gating_testing_helper(n_tokens=n_tokens, n_experts=8, n_top_k=2) def _test_single_mapping_helper(n_tokens: int, @@ -19,6 +62,8 @@ def _test_single_mapping_helper(n_tokens: int, assigned_expert: int, logit_fill: float = 0.0, match_fill: float = 1.0) -> None: + + n_top_k = 1 logits = torch.full((n_tokens, n_experts), logit_fill, dtype=torch.float16, @@ -26,12 +71,12 @@ def _test_single_mapping_helper(n_tokens: int, logits[:, assigned_expert] = match_fill - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) batch = build_simple_batch([n_tokens], padding=False) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) @@ -39,7 +84,7 @@ def _test_single_mapping_helper(n_tokens: int, assert expert_counts[assigned_expert] == n_tokens assert torch.all(expert_assignment == assigned_expert) assert torch.unique(expert_offset).shape[0] == n_tokens - assert allclose(scores, F.softmax(logits.float(), dim=1)[:, assigned_expert]) + assert allclose(scores, F.softmax(logits.float(), dim=1)[:, assigned_expert].reshape(-1, n_top_k)) @pytest.mark.inference_v2_ops @@ -72,6 +117,7 @@ def test_determinism(): n_tokens = 512 n_experts = 64 + n_top_k = 1 logits = torch.zeros((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) batch = build_simple_batch([n_tokens], padding=False) @@ -79,13 +125,15 @@ def test_determinism(): logits[:, 19] = 1.0 logits[:, 26] = 1.0 - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) for _ in range(1024): expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) batch = build_simple_batch([n_tokens], padding=False) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) @@ -94,7 +142,7 @@ def test_determinism(): assert expert_counts[26] == 0 assert torch.all(expert_assignment == 19) assert torch.unique(expert_offset).shape[0] == n_tokens - assert allclose(scores, F.softmax(logits.float(), dim=1)[:, 19]) + assert allclose(scores, F.softmax(logits.float(), dim=1)[:, 19].reshape(-1, 1)) @pytest.mark.inference_v2_ops @@ -105,16 +153,19 @@ def test_score_accuracy(n_tokens: int, n_experts: int) -> None: """ logits = torch.randn((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) batch = build_simple_batch([n_tokens], padding=False) + n_top_k = 1 - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) ref_scores = F.softmax(logits.float(), dim=1).max(dim=1).values + ref_scores = ref_scores.reshape(-1, 1) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + assert allclose(scores, ref_scores) assert expert_counts.sum() == n_tokens diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py b/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py index 260236562ee9..06ff9047d648 100644 --- a/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py +++ b/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py @@ -26,7 +26,7 @@ def __init__(self, experts_per_rank: int) -> None: self._num_experts = experts_per_rank @property - def num_experts(self) -> int: + def n_experts(self) -> int: return self._num_experts @on_device diff --git a/tests/unit/inference/v2/modules/test_blocked_attn.py b/tests/unit/inference/v2/modules/test_blocked_attn.py index 215ad64636b1..6556aa460a44 100644 --- a/tests/unit/inference/v2/modules/test_blocked_attn.py +++ b/tests/unit/inference/v2/modules/test_blocked_attn.py @@ -12,7 +12,7 @@ from deepspeed.accelerator import get_accelerator from deepspeed.inference.v2.modules import ConfigBundle -from deepspeed.inference.v2.modules.configs import DSSelfAttentionConfig, PositionalEmbeddingType +from deepspeed.inference.v2.modules.configs import DSSelfAttentionConfig, PositionalEmbeddingType, RotateHalfConfig from deepspeed.inference.v2.modules.interfaces import DSSelfAttentionRegistry, DSSelfAttentionBase from ..kernels.ragged_ops.ragged_testing_utils import build_batch_and_manager @@ -37,13 +37,10 @@ def _blocked_flash_testing_helper(head_size: int, """ if trained_freqs is None: embed_type = PositionalEmbeddingType.none - embed_args = {} + embed_args = None else: embed_type = PositionalEmbeddingType.rotate_half - if trained_freqs: - embed_args = {'trained_freqs': True} - else: - embed_args = {'trained_freqs': False} + embed_args = RotateHalfConfig(use_trained_freqs=trained_freqs) attn_config = DSSelfAttentionConfig(max_tokens=2048, n_heads_q=n_heads_q, @@ -51,7 +48,7 @@ def _blocked_flash_testing_helper(head_size: int, head_size=head_size, max_sequences=32, positional_embedding_type=embed_type, - positional_embedding_args=embed_args) + positional_embedding_config=embed_args) config = ConfigBundle(name='dense_blocked_attention', config=attn_config) attn_module: DSSelfAttentionBase = DSSelfAttentionRegistry.instantiate_config(config) diff --git a/tests/unit/inference/v2/modules/test_cutlass_moe.py b/tests/unit/inference/v2/modules/test_cutlass_moe.py index e21170c9ed8f..b14ba127c6be 100644 --- a/tests/unit/inference/v2/modules/test_cutlass_moe.py +++ b/tests/unit/inference/v2/modules/test_cutlass_moe.py @@ -212,3 +212,117 @@ def test_in_out_channels(in_channels: int, out_channels: int) -> None: dtype=DtypeEnum.fp16, activation_type=ActivationType.IDENTITY, use_bias=True) + + +def _mixtral_moe_baseline(hidden_states: torch.Tensor, + gate_weight: torch.Tensor, + mlp_w1: torch.Tensor, + mlp_w2: torch.Tensor, + mlp_w3: torch.Tensor, + force_float: bool = False) -> torch.Tensor: + """ + Baseline implementation for mixtral MoE module. + + Based on transformers implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py + """ + output_dtype = hidden_states.dtype + if force_float: + hidden_states = hidden_states.float() + gate_weight = gate_weight.float() + mlp_w1 = mlp_w1.float() + mlp_w2 = mlp_w2.float() + mlp_w3 = mlp_w3.float() + + router_logits = torch.nn.functional.linear(hidden_states, gate_weight) + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, selected_experts = routing_weights.topk(k=2, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + # NOTE(cmikeh2): This is a difference implementation, ours will preserve the original scale + # as float32 and perform in-kernel fused FP16->FP32->FP16 conversion. + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros_like(hidden_states) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=gate_weight.shape[0]).permute(2, 1, 0) + get_accelerator().synchronize() + + for expert_idx in range(gate_weight.shape[0]): + exp_mlp_w1 = mlp_w1[expert_idx] + exp_mlp_w2 = mlp_w2[expert_idx] + exp_mlp_w3 = mlp_w3[expert_idx] + + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + current_state = hidden_states[top_x_list] + + linear = torch.nn.functional.linear + intermediate = torch.nn.functional.silu(linear(current_state, exp_mlp_w1)) * linear(current_state, exp_mlp_w3) + output = linear(intermediate, exp_mlp_w2) * routing_weights[top_x_list, idx_list].unsqueeze(-1) + final_hidden_states.index_add_(0, top_x, output.to(final_hidden_states.dtype)) + + return final_hidden_states.to(output_dtype) + + +@pytest.mark.inference_v2_ops +def test_mixtral_moe_config(): + + experts = 8 + n_top_k = 2 + in_channels = 4096 + intermediate_dim = 2048 + dtype = DtypeEnum.bf16 + + # Parameters + gate_weight = torch.randn( + (experts, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + mlp_w1 = torch.randn( + (experts, intermediate_dim, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_w3 = torch.randn( + (experts, intermediate_dim, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_w2 = torch.randn( + (experts, in_channels, intermediate_dim), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + n_tokens = 256 + hidden_states = torch.randn( + (n_tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + baseline = _mixtral_moe_baseline(hidden_states, gate_weight, mlp_w1, mlp_w2, mlp_w3) + + mlp_w13_fused = torch.cat([mlp_w1, mlp_w3], dim=-1).reshape(experts, 2 * intermediate_dim, in_channels) + + config = DSMoEConfig(max_tokens=4096, + model_dim=in_channels, + intermediate_features=intermediate_dim, + n_experts=experts, + activation=ActivationType.SiGLU, + input_dtype=dtype, + output_dtype=dtype, + top_k=n_top_k, + normalize_scores=True) + + implementation_config = {"weight_dtype": DtypeEnum(dtype)} + + bundle = ConfigBundle(name='cutlass_multi_gemm_moe', config=config, implementation_config=implementation_config) + moe_module = DSMoERegistry.instantiate_config(bundle) + + batch = build_simple_batch([n_tokens]) + + gate_ds = moe_module.transform_gate_param(gate_weight) + mlp_w1_ds = moe_module.transform_moe_mlp_1_param(mlp_w13_fused) + mlp_w2_ds = moe_module.transform_moe_mlp_2_param(mlp_w2) + + output = moe_module(hidden_states, batch, gate_ds, mlp_w1_ds, mlp_w2_ds) + + # NOTE(cmikeh2): These are higher than the other tests for reasons that aren't quite + # clear to me. My best guess is that the SiGLU activation is causing larger numerical + # divergence. The thresholds chosen here is based on the observed error between the + # float and bfloat16 reference implementations. + assert allclose(output, baseline.to(dtype.value), tolerances=(5e-2, 5e-2)) From 827e1ca8aa76926d0765743efc59d8c2d1cb9db5 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Wed, 20 Dec 2023 16:48:47 -0800 Subject: [PATCH 05/10] Update version.txt after 0.12.6 release (#4850) **Auto-generated PR to update version.txt after a DeepSpeed release** Released version - 0.12.6 Author - @mrwyattii Co-authored-by: mrwyattii --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index dabff2f13810..e2e3067ddc5f 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.12.6 +0.12.7 From 75c772021484fa0f9d9a3872a90c0876f7cf59d8 Mon Sep 17 00:00:00 2001 From: Gavin Goodship Date: Thu, 21 Dec 2023 19:13:24 +0000 Subject: [PATCH 06/10] doc corrections (#4861) --- docs/_tutorials/advanced-install.md | 50 ++++++++++++++--------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/docs/_tutorials/advanced-install.md b/docs/_tutorials/advanced-install.md index 10197e62f681..d27ecf021421 100755 --- a/docs/_tutorials/advanced-install.md +++ b/docs/_tutorials/advanced-install.md @@ -27,7 +27,7 @@ ds_report ## Pre-install DeepSpeed Ops -**Note:** [PyTorch](https://pytorch.org/) must be installed _before_ pre-compiling any DeepSpeed c++/cuda ops. However, this is not required if using the default mode of JIT compilation of ops. +**Note:** [PyTorch](https://pytorch.org/) must be installed _before_ pre-compiling any DeepSpeed C++/CUDA ops. However, this is not required if using the default mode of JIT compilation of ops. {: .notice--info} Sometimes we have found it useful to pre-install either some or all DeepSpeed @@ -56,22 +56,22 @@ DS_BUILD_FUSED_LAMB=1 pip install deepspeed ``` Available `DS_BUILD` options include: -* `DS_BUILD_OPS` toggles all ops -* `DS_BUILD_AIO` builds asynchronous (NVMe) I/O op -* `DS_BUILD_CCL_COMM` builds the communication collective libs -* `DS_BUILD_CPU_ADAM` builds the CPUAdam op -* `DS_BUILD_CPU_LION` builds the CPULion op -* `DS_BUILD_EVOFORMER_ATTN` builds the EvoformerAttn op (from [Alphafold](https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/)) -* `DS_BUILD_FUSED_ADAM` builds the FusedAdam op (from [apex](https://github.com/NVIDIA/apex)) -* `DS_BUILD_FUSED_LION` builds the FusedLion op -* `DS_BUILD_CPU_ADAGRAD` builds the CPUAdagrad op -* `DS_BUILD_FUSED_LAMB` builds the FusedLamb op -* `DS_BUILD_QUANTIZER` builds the quantizer op -* `DS_BUILD_RANDOM_LTD` builds the random ltd op -* `DS_BUILD_SPARSE_ATTN` builds the sparse attention op -* `DS_BUILD_TRANSFORMER` builds the transformer op -* `DS_BUILD_TRANSFORMER_INFERENCE` builds the transformer-inference op -* `DS_BUILD_STOCHASTIC_TRANSFORMER` builds the stochastic transformer op +* `DS_BUILD_OPS` toggles all ops. +* `DS_BUILD_AIO` builds asynchronous (NVMe) I/O op. +* `DS_BUILD_CCL_COMM` builds the communication collective libs. +* `DS_BUILD_CPU_ADAM` builds the CPUAdam op. +* `DS_BUILD_CPU_LION` builds the CPULion op. +* `DS_BUILD_EVOFORMER_ATTN` builds the EvoformerAttn op (from [Alphafold](https://www.deepspeed.ai/tutorials/ds4sci_evoformerattention/)). +* `DS_BUILD_FUSED_ADAM` builds the FusedAdam op (from [apex](https://github.com/NVIDIA/apex)). +* `DS_BUILD_FUSED_LION` builds the FusedLion op. +* `DS_BUILD_CPU_ADAGRAD` builds the CPUAdagrad op. +* `DS_BUILD_FUSED_LAMB` builds the FusedLamb op. +* `DS_BUILD_QUANTIZER` builds the quantizer op. +* `DS_BUILD_RANDOM_LTD` builds the random ltd op. +* `DS_BUILD_SPARSE_ATTN` builds the sparse attention op. +* `DS_BUILD_TRANSFORMER` builds the transformer op. +* `DS_BUILD_TRANSFORMER_INFERENCE` builds the transformer-inference op. +* `DS_BUILD_STOCHASTIC_TRANSFORMER` builds the stochastic transformer op. To speed up the build-all process, you can parallelize the compilation process with: @@ -81,7 +81,7 @@ DS_BUILD_OPS=1 pip install deepspeed --global-option="build_ext" --global-option This should complete the full build 2-3 times faster. You can adjust `-j` to specify how many cpu-cores are to be used during the build. In the example it is set to 8 cores. -You can also build a binary wheel and install it on multiple machines that have the same type of GPUs and the same software environment (CUDA toolkit, pytorch, python, etc.) +You can also build a binary wheel and install it on multiple machines that have the same type of GPUs and the same software environment (CUDA toolkit, PyTorch, Python, etc.) ```bash DS_BUILD_OPS=1 python setup.py build_ext -j8 bdist_wheel @@ -107,7 +107,7 @@ pip install . For installs spanning multiple nodes we find it useful to install DeepSpeed using the [install.sh](https://github.com/microsoft/DeepSpeed/blob/master/install.sh) -script in the repo. This will build a python wheel locally and copy it to all +script in the repo. This will build a Python wheel locally and copy it to all the nodes listed in your hostfile (either given via `--hostfile`, or defaults to `/job/hostfile`). @@ -118,7 +118,7 @@ extensions will be loaded form that directory. If you use multiple virtual environments this could be a problem, since by default there is only one `torch_extensions` directory, but different virtual environments may use different setups (e.g., different -python or cuda versions) and then the loading of a CUDA extension built by another environment will +Python or CUDA versions) and then the loading of a CUDA extension built by another environment will fail. Therefore, if you need to you can override the default location with the help of the `TORCH_EXTENSIONS_DIR` environment variable. So in each virtual environment you can point it to a unique directory and DeepSpeed will use it to save and load CUDA extensions. @@ -146,9 +146,9 @@ If you're getting the following error: ``` RuntimeError: CUDA error: no kernel image is available for execution on the device ``` -when running deepspeed, that means that the cuda extensions weren't built for the card you're trying to use it for. +when running deepspeed, that means that the CUDA extensions weren't built for the card you're trying to use it for. -When building from source deepspeed will try to support a wide range of architectures, but under jit-mode it'll only +When building from source DeepSpeed will try to support a wide range of architectures, but under jit-mode it'll only support the architectures visible at the time of building. You can build specifically for a desired range of architectures by setting a `TORCH_CUDA_ARCH_LIST` env variable: @@ -159,9 +159,9 @@ TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ... It will also make the build faster when you only build for a few architectures. -This is also recommended to ensure your exact architecture is used. Due to a variety of technical reasons, a distributed pytorch binary isn't built to fully support all architectures, skipping binary compatible ones, at a potential cost of underutilizing your full card's compute capabilities. To see which architectures get included during the deepspeed build from source - save the log and grep for `-gencode` arguments. +This is also recommended to ensure your exact architecture is used. Due to a variety of technical reasons, a distributed PyTorch binary isn't built to fully support all architectures, skipping binary compatible ones, at a potential cost of underutilizing your full card's compute capabilities. To see which architectures get included during the DeepSpeed build from source - save the log and grep for `-gencode` arguments. -The full list of nvidia GPUs and their compute capabilities can be found [here](https://developer.nvidia.com/cuda-gpus). +The full list of Nvidia GPUs and their compute capabilities can be found [here](https://developer.nvidia.com/cuda-gpus). ## CUDA version mismatch @@ -171,7 +171,7 @@ If you're getting the following error: Exception: >- DeepSpeed Op Builder: Installed CUDA version {VERSION} does not match the version torch was compiled with {VERSION}, unable to compile cuda/cpp extensions without a matching cuda version. ``` You have a misaligned version of CUDA installed compared to the version of CUDA -used to compile torch. A mismatch in the major version is likely to result in +used to compile Torch. A mismatch in the major version is likely to result in errors or unexpected behavior. The easiest fix for this error is changing the CUDA version installed (check From c37fe9cbfb8bc10c8dd6ccd8cac9b34ded218990 Mon Sep 17 00:00:00 2001 From: Heyang Qin Date: Fri, 22 Dec 2023 11:48:48 -0800 Subject: [PATCH 07/10] Fix exception handling in get_all_ranks_from_group() function (#4862) In the latest Pytorch nightly, the exception thrown from `torch.distributed.distributed_c10d.get_global_rank()` is changed from `RuntimeError` to `ValueError` so we need to update our try-catch in `deepspeed.comm` Tested with torch version 2.3.0.dev20231221+cu121 Fixes: https://github.com/microsoft/DeepSpeed/issues/4853 --- deepspeed/comm/comm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index 568211645f40..fb92c1e98421 100644 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -595,7 +595,7 @@ def get_all_ranks_from_group(group=None): while True: group_ranks.append(cdb.get_global_rank(group, rank)) rank += 1 - except RuntimeError: + except (RuntimeError, ValueError): pass return group_ranks From 3e94f8c75116377d4b1c32b8c674368a27fb2a77 Mon Sep 17 00:00:00 2001 From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com> Date: Tue, 26 Dec 2023 18:00:11 +0200 Subject: [PATCH 08/10] deepspeed engine: fp16 support validation on init (#4843) --- deepspeed/runtime/engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 9c9641a1c4cf..4f672f0ba5e2 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1017,6 +1017,9 @@ def _supported_optims(self): # Validate configuration based on command line arguments def _do_sanity_check(self): + if self.fp16_enabled() and not get_accelerator().is_fp16_supported(): + raise ValueError("Type fp16 is not supported.") + expected_optim_types = self._supported_optims() expected_optim_types += [type(None), Callable] assert isinstance(self.client_optimizer, tuple(expected_optim_types)), \ From 40342055cefda4c453e803759861b00ca5cfb879 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Sat, 30 Dec 2023 08:18:49 +0530 Subject: [PATCH 09/10] Remove hooks on gradient accumulation on engine/optimizer destroy (#4858) Fixes: #4856 See: https://github.com/pytorch/pytorch/issues/46386 --- deepspeed/runtime/engine.py | 3 ++- deepspeed/runtime/zero/stage3.py | 6 +++++- deepspeed/runtime/zero/stage_1_and_2.py | 8 +++++++- deepspeed/utils/debug.py | 7 +++++++ 4 files changed, 21 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 4f672f0ba5e2..79bdba90e6d4 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -69,7 +69,7 @@ STEP_MICRO_TIMER, \ FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \ STEP_GLOBAL_TIMER -from deepspeed.utils.debug import debug_extract_module_and_param_names +from deepspeed.utils.debug import debug_extract_module_and_param_names, debug_clear_module_and_param_names from deepspeed.monitor.monitor import MonitorMaster from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop from deepspeed.runtime.utils import clip_grad_norm_ @@ -362,6 +362,7 @@ def __init__( def destroy(self): if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): self.optimizer.destroy() + debug_clear_module_and_param_names() def _get_model_parameters(self): if self.autotuning_profile_model_info(): diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 30a168dcd396..fa4e64faf5a5 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -377,6 +377,7 @@ def __init__( #creates backward hooks for gradient partitioning ###Calls all gather param + self._grad_acc_hooks = [] self.create_reduce_and_remove_grad_hooks() #exit(0) @@ -397,6 +398,9 @@ def __init__( def destroy(self): self.parameter_offload.destroy() + for hook in self._grad_acc_hooks: + hook.remove() + print_rank_0("Removed grad acc hooks", force=False) del self.__ipg_bucket_flat_buffer def initialize_ds_offload( @@ -1118,7 +1122,7 @@ def wrapper(param): def reduce_partition_and_remove_grads(*notneeded): self.reduce_ready_partitions_and_remove_grads(param) - grad_acc.register_hook(reduce_partition_and_remove_grads) + self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) self.grad_accs.append(grad_acc) #print(f"param grad fn {param.expand_as(param).grad_fn}") diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index aeb533698af3..e17bcbe6ade8 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -490,6 +490,7 @@ def __init__(self, self.reset_partition_gradient_structures() # creates backward hooks for gradient partitioning + self._grad_acc_hooks = [] if self.partition_gradients or self.overlap_comm: self.create_reduce_and_remove_grad_hooks() @@ -522,6 +523,11 @@ def __init__(self, self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() + def destroy(self): + for hook in self._grad_acc_hooks: + hook.remove() + self.print_rank_0("Removed grad acc hooks") + def _enable_universal_checkpoint(self): for lp_param_group in self.bit16_groups: enable_universal_checkpoint(param_list=lp_param_group) @@ -864,7 +870,7 @@ def wrapper(param, i): def reduce_partition_and_remove_grads(*notneeded): self.reduce_ready_partitions_and_remove_grads(param, i) - grad_acc.register_hook(reduce_partition_and_remove_grads) + self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) self.grad_accs.append(grad_acc) wrapper(param, i) diff --git a/deepspeed/utils/debug.py b/deepspeed/utils/debug.py index 02295fa98011..cebea56255d9 100644 --- a/deepspeed/utils/debug.py +++ b/deepspeed/utils/debug.py @@ -11,6 +11,13 @@ param_names = {} +def debug_clear_module_and_param_names(): + global module_names + global param_names + module_names = {} + param_names = {} + + def debug_extract_module_and_param_names(model): # extract the fully qualified names as soon as the model is acquired global module_names From ea0d81143c6ba0801828919ea53888843d0fb19f Mon Sep 17 00:00:00 2001 From: mmhab <132277730+mmhab@users.noreply.github.com> Date: Tue, 2 Jan 2024 15:19:08 +0200 Subject: [PATCH 10/10] optimize grad_norm calculation in stage3.py (#4436) reduce the synchronization between the device and the host by removing .item() from the loops that calculate the total norm. --------- Co-authored-by: Olatunji Ruwase Co-authored-by: Michael Wyatt Co-authored-by: Michael Wyatt Co-authored-by: Shaden Smith --- deepspeed/runtime/zero/stage3.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index fa4e64faf5a5..ce4137028195 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1328,7 +1328,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): param_id = self.get_param_id(p) if param_id in self.norm_for_param_grads.keys(): param_norm = self.norm_for_param_grads[param_id] - total_norm += param_norm.item()**2 + total_norm += param_norm**2 # Sum across all model parallel GPUs. total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) @@ -1337,10 +1337,14 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + total_norm = total_norm_cuda[0]**(1. / norm_type) - if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 + norm_is_inf = total_norm.isinf() + norm_is_nan = total_norm.isnan() + inf_or_nan = norm_is_nan.logical_or(norm_is_inf) + + err = torch.tensor(-1.0, device=self.device, dtype=torch.float) + total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm return total_norm @@ -1669,7 +1673,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): # Take max across all GPUs. self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX) - total_norm = total_norm_cuda[0].item() + total_norm = total_norm_cuda[0] else: # if dist.get_rank() == 0: # logger.info(f"Total Norm beginning {total_norm}") @@ -1690,10 +1694,14 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) - total_norm = total_norm_cuda.item()**(1. / norm_type) + total_norm = total_norm_cuda**(1. / norm_type) + + norm_is_inf = total_norm.isinf() + norm_is_nan = total_norm.isnan() + inf_or_nan = norm_is_nan.logical_or(norm_is_inf) - if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 + err = torch.tensor(-1.0, device=self.device, dtype=torch.float) + total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm return total_norm