diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml
new file mode 100644
index 000000000000..38d4ffd1828b
--- /dev/null
+++ b/.github/workflows/nv-a6000.yml
@@ -0,0 +1,56 @@
+name: nv-a6000
+
+on:
+ pull_request:
+ paths-ignore:
+ - 'docs/**'
+ - 'blogs/**'
+ workflow_dispatch:
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+permissions:
+ contents: read
+ issues: write
+
+jobs:
+ unit-tests:
+ runs-on: [self-hosted, nvidia, a6000]
+ container:
+ image: nvcr.io/nvidia/pytorch:23.03-py3
+ ports:
+ - 80
+ options: --gpus all --shm-size "8G"
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Check container state
+ run: |
+ ldd --version
+ nvcc --version
+ nvidia-smi
+ python -c "import torch; print('torch:', torch.__version__, torch)"
+ python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+ - name: Install transformers
+ run: |
+ git clone https://github.com/huggingface/transformers
+ cd transformers
+ git rev-parse --short HEAD
+ python -m pip install .
+ - name: Install deepspeed
+ run: |
+ python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
+ python -m pip install .[dev,1bit,autotuning]
+ ds_report
+ - name: Python environment
+ run: |
+ python -m pip list
+ - name: Unit tests
+ run: |
+ unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
+ cd tests
+ python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2' unit/ --torch_ver="2.0" --cuda_ver="12"
+ python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2_ops' unit/ --torch_ver="2.0" --cuda_ver="12"
diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml
index 065f8b93f1e0..b5cf46f79011 100644
--- a/.github/workflows/nv-inference.yml
+++ b/.github/workflows/nv-inference.yml
@@ -34,6 +34,7 @@ jobs:
run: |
git clone https://github.com/huggingface/transformers
cd transformers
+ git checkout f370bebdc
git rev-parse --short HEAD
pip install .
diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml
index ccb6c25e14f7..f253340c6966 100644
--- a/.github/workflows/nv-pre-compile-ops.yml
+++ b/.github/workflows/nv-pre-compile-ops.yml
@@ -33,7 +33,7 @@ jobs:
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Compile DeepSpeed Ops
run: |
- TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
+ TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
- name: DS Report
run: |
ds_report
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 6b11b3acba51..2432a7a24124 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -49,6 +49,7 @@ repos:
entry: ./scripts/check-license.py
language: python
files: \.(py|c|cpp|cu|cc|h|hpp|cuh|hip|tr)$
+ exclude: ^(deepspeed/inference/v2/kernels/ragged_ops/blocked_flash|deepspeed/inference/v2/kernels/cutlass_ops/grouped_gemm)
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
diff --git a/MANIFEST.in b/MANIFEST.in
index 2fec750c6644..ab79573ef96c 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,4 +1,6 @@
include *.txt README.md
+include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so
+include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so
recursive-include requirements *.txt
recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json
recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc
diff --git a/README.md b/README.md
index bea26ea1828a..721ba62cee37 100755
--- a/README.md
+++ b/README.md
@@ -15,6 +15,7 @@
## Latest News
DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat).
+* [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen)
* [2023/10] [DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md)]
* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] [[White paper](https://arxiv.org/abs/2310.04610)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)]
* [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md)
diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py
index 4715e28b192c..2786b425ca7f 100644
--- a/accelerator/cuda_accelerator.py
+++ b/accelerator/cuda_accelerator.py
@@ -153,9 +153,26 @@ def max_memory_reserved(self, device_index=None):
def total_memory(self, device_index=None):
return torch.cuda.get_device_properties(device_index).total_memory
+ def _get_nvml_gpu_id(self, torch_gpu_id):
+ """
+ credit: https://discuss.pytorch.org/t/making-pynvml-match-torch-device-ids-cuda-visible-devices/103020
+
+ Remap torch device id to nvml device id, respecting CUDA_VISIBLE_DEVICES.
+
+ If the latter isn't set, return the same id.
+ """
+ # if CUDA_VISIBLE_DEVICES is used automagically remap the id since pynvml ignores this env var
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
+ ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")))
+ return ids[torch_gpu_id] # remap
+ else:
+ return torch_gpu_id
+
def available_memory(self, device_index=None):
if pynvml:
- handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
+ if device_index is None:
+ device_index = self.current_device()
+ handle = pynvml.nvmlDeviceGetHandleByIndex(self._get_nvml_gpu_id(device_index))
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.free
else:
diff --git a/blogs/deepspeed-fastgen/README.md b/blogs/deepspeed-fastgen/README.md
new file mode 100644
index 000000000000..b271c112a762
--- /dev/null
+++ b/blogs/deepspeed-fastgen/README.md
@@ -0,0 +1,302 @@
+
+
+# DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference
+
+
+
+
+
+
+
+
+## Table of Contents
+1. [Introduction](#introduction)
+2. [Key LLM Serving Techniques](#background)
+3. [Dynamic SplitFuse: A Novel Prompt and Generation Composition Strategy](#technical-approach)
+4. [Performance Evaluation](#performance-evaluation)
+5. [DeepSpeed-FastGen: Implementation and Usage](#using-deepspeed-fastgen)
+6. [Try out DeepSpeed-FastGen](#try)
+7. [Acknowledgements](#acknowledgements)
+
+
+## 1. Introduction
+
+Large language models (LLMs) like GPT-4 and LLaMA have emerged as a dominant workload in serving a wide range of applications infused with AI at every level. From general chat models to document summarization, and from autonomous driving to copilots at every layer of the software stack, the demand to deploy and serve these models at scale has skyrocketed. While frameworks like DeepSpeed, PyTorch, and several others can regularly achieve good hardware utilization during LLM training, the interactive nature of these applications and the poor arithmetic intensity of tasks like open-ended text generation have become the bottleneck for inference throughput in existing systems.
+
+To this end, frameworks like [vLLM](https://arxiv.org/pdf/2309.06180.pdf) powered by PagedAttention and research systems like [Orca](https://www.usenix.org/system/files/osdi22-yu.pdf) have significantly improved the performance of inference for LLMs. However, these systems still struggle to provide consistent quality of service, particularly for workloads with longer prompts. These long prompt workloads are becoming increasingly important as more and more models, like [MPT-StoryWriter](https://www.mosaicml.com/blog/mpt-7b), and systems, such as [DeepSpeed Ulysses](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses), support context windows stretching to tens of thousands of tokens. To better understand the problem space, we provide detailed examples below of how text generation for LLMs works in two distinct phases, called prompt processing and generation. When systems treat these as distinct phases, generation is preempted by prompt processing, which risks breaking service level agreements (SLAs).
+
+Today, we are glad to present DeepSpeed-FastGen, a system that overcomes these limitations by leveraging the proposed Dynamic SplitFuse technique and offers up to 2.3x higher effective throughput compared to state-of-the-art systems like vLLM. DeepSpeed-FastGen leverages the combination of DeepSpeed-MII and DeepSpeed-Inference to provide an easy-to-use serving system.
+
+**Quick Start:** Trying DeepSpeed-FastGen is as simple as installing the latest [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII) release:
+
+```bash
+pip install deepspeed-mii
+```
+
+To generate text using a simple non-persistent pipeline deployment, run the following code. For more details, please see [Section 5](#using-deepspeed-fastgen).
+
+```python
+from mii import pipeline
+pipe = pipeline("mistralai/Mistral-7B-v0.1")
+output = pipe(["Hello, my name is", "DeepSpeed is"], max_new_tokens=128)
+print(output)
+```
+
+## 2. Existing LLM Serving Techniques in Literature
+
+A text generation workload for a single sequence consists of two phases: 1) prompt processing, in which the user-provided text is efficiently processed as a batch of tokens to build a key-value (KV) cache for attention, and 2) token generation, in which a single token is appended to that cache and a new token is produced. Over the course of generating a full sequence of text, the system makes many forward calls to the model. Two major techniques have been proposed in the literature and deployed in systems that address various limitations and bottlenecks that may arise during these phases.
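+
+To make these two phases concrete, the following minimal sketch uses the HuggingFace `transformers` API (for illustration only; this is not the DeepSpeed-FastGen implementation, and the model name is an arbitrary example). The prompt is processed in a single batched forward pass that populates the KV cache, after which each generation step feeds one new token back through the model:
+
+```python
+# Illustrative sketch of prompt processing vs. token generation using HuggingFace
+# transformers (not DeepSpeed-FastGen internals). "gpt2" is just a small example model.
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
+
+prompt_ids = tokenizer("DeepSpeed is", return_tensors="pt").input_ids
+
+with torch.no_grad():
+    # Phase 1: prompt processing -- one forward pass over the whole prompt
+    # builds the key-value (KV) cache.
+    out = model(prompt_ids, use_cache=True)
+    past_key_values = out.past_key_values
+    next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
+
+    generated = [next_token]
+    # Phase 2: token generation -- each step feeds a single token and extends the cache.
+    for _ in range(16):
+        out = model(next_token, past_key_values=past_key_values, use_cache=True)
+        past_key_values = out.past_key_values
+        next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
+        generated.append(next_token)
+
+print(tokenizer.decode(torch.cat(generated, dim=-1)[0]))
+```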
+
+_Blocked KV Caching:_
+
+vLLM identified that memory fragmentation due to large monolithic KV-caches significantly reduced the concurrency of LLM serving systems and proposed [Paged Attention](https://arxiv.org/pdf/2309.06180.pdf) to enable non-contiguous caches and increase total system throughput. Rather than assign individual variable-sized contiguous chunks of memory, the underlying storage in the KV cache is fixed-sized blocks (also known as pages). The blocked KV-cache increases system throughput by increasing the amount of potential sequence concurrency by eliminating KV-cache induced memory fragmentation. Non-contiguous KV cache implementations are also included in [HuggingFace TGI](https://github.com/huggingface/text-generation-inference) and [NVIDIA TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM).
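+
+The core bookkeeping behind a blocked KV cache can be sketched in a few lines. The snippet below is a simplified illustration under our own naming (it mirrors neither vLLM's nor DeepSpeed's actual data structures): each sequence owns a list of fixed-size physical blocks, and a logical token position maps to a (block, offset) pair.
+
+```python
+# Simplified sketch of blocked KV-cache bookkeeping. Names and structure are
+# illustrative only, not vLLM or DeepSpeed internals.
+BLOCK_SIZE = 64  # tokens stored per physical KV block
+
+class BlockTable:
+    def __init__(self, free_blocks):
+        self.free_blocks = list(free_blocks)   # pool of physical block ids
+        self.table = {}                        # seq_id -> list of block ids
+
+    def append_token(self, seq_id, num_tokens_so_far):
+        """Ensure the sequence has room for its next token."""
+        blocks = self.table.setdefault(seq_id, [])
+        if num_tokens_so_far % BLOCK_SIZE == 0:    # all current blocks are full
+            blocks.append(self.free_blocks.pop())  # grab any free block; no
+        return blocks                              # contiguity is required
+
+    def locate(self, seq_id, token_pos):
+        """Map a logical token position to (physical block id, offset)."""
+        block_id = self.table[seq_id][token_pos // BLOCK_SIZE]
+        return block_id, token_pos % BLOCK_SIZE
+
+table = BlockTable(free_blocks=range(1024))
+for t in range(130):                # grow one sequence past two full blocks
+    table.append_token("seq-0", t)
+print(table.locate("seq-0", 129))   # -> (id of the third allocated block, 1)
+```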
+
+_Continuous Batching:_
+
+In the past, dynamic batching, in which a server would wait for multiple requests to process in phase with each other, was used to improve GPU utilization. However, this approach has drawbacks, as it typically requires padding inputs to identical lengths or stalling the system to wait to construct a larger batch.
+
+Recent advances in large language model (LLM) inference and serving have focused on fine-grained scheduling and memory efficiency. For instance, Orca proposes _iteration-level scheduling_ (also known as continuous batching), which makes distinct scheduling decisions at each forward pass of the model. This allows requests to join and leave the batch as needed, eliminating the need for padding requests and thus improving overall throughput. In addition to Orca, continuous batching has been implemented in NVIDIA TRT-LLM, HuggingFace TGI, and vLLM.
+
+In current systems, there are two primary approaches to implement continuous batching. In TGI and vLLM, the generation phase is preempted to perform prompt processing (called infill in TGI) before continuing with generation. In Orca, these phases are not distinguished; instead, Orca will add a prompt into the running batch so long as the total number of sequences doesn't reach a fixed bound. Both of these approaches to varying degrees need to stall generation to process long prompts (see [Section 3B](#splitfuse)).
+
+To address these shortcomings, we propose a novel prompt and generation composition strategy, Dynamic SplitFuse.
+
+## 3. Dynamic SplitFuse: A Novel Prompt and Generation Composition Strategy
+
+DeepSpeed-FastGen is built to leverage continuous batching and non-contiguous KV caches to enable increased occupancy and higher responsiveness for serving LLMs in the data center, similar to existing frameworks such as TRT-LLM, TGI, and vLLM. In order to achieve a new level of performance, DeepSpeed-FastGen introduces Dynamic SplitFuse, which leverages dynamic prompt and generation decomposition and unification to further improve continuous batching and system throughput.
+
+### A. Three Performance Insights
+Before describing Dynamic SplitFuse, we answer three key performance questions that together motivate its design.
+
+*__1. What factors impact the forward pass of a single LLM?__* In order to schedule effectively, it is necessary to understand which independent variables the scheduling loop should control. We observe below that the composition of sequences in a forward pass (the batch size in sequences) has a negligible impact on performance compared to the raw number of tokens in the forward pass. This means an effective scheduler can be built around a single signal: the number of tokens in the forward pass.
+
+
+
+
+
+*__2. How does a model's throughput respond to changing the number of tokens in the forward pass?__* An LLM has two key operating regions with a relatively steep transition. With a small number of tokens, the GPU's bottleneck is reading the model from memory, so throughput scales with the number of tokens; with many tokens, the model is compute bound and throughput is near constant. The model runs highly efficiently if all forward passes land in the throughput-saturating region.
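+
+This behavior can be summarized with a simple roofline-style model. The constants below are made up purely for illustration (real curves come from profiling the target model and GPU): below a saturation point, latency is dominated by reading the weights and throughput grows roughly linearly with tokens per forward pass; above it, throughput is flat.
+
+```python
+# Illustrative roofline-style model of tokens-per-forward-pass vs. throughput.
+# Constants are invented for illustration, not measurements.
+def forward_latency_s(num_tokens, weight_read_s=0.012, per_token_s=0.00002):
+    # Memory-bound regime: latency ~ time to read the weights (fixed).
+    # Compute-bound regime: latency grows with the number of tokens.
+    return max(weight_read_s, num_tokens * per_token_s)
+
+def throughput_tokens_per_s(num_tokens):
+    return num_tokens / forward_latency_s(num_tokens)
+
+for n in (64, 256, 512, 1024, 4096):
+    print(n, round(throughput_tokens_per_s(n)))
+# Throughput rises with n until ~600 tokens, then saturates near 1 / per_token_s.
+```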
+
+
+
+
+
+*__3. How should a pool of tokens be scheduled across multiple forward passes?__* We observe above that for well-aligned inputs the token-throughput curve is concave, which means its second derivative is less than or equal to 0. As an example, let $f(x)$ be a concave function mapping the number of tokens in a forward pass to the model's throughput. For a concave function $f(x)$, the following holds:
+
+ $$0 \geq \lim_{h \to 0} \frac{f(x + h) - 2f(x) + f(x - h)}{h^2}$$
+
+ $$0 \geq f(x + h) - 2f(x) + f(x - h)$$
+
+ $$2f(x) \geq f(x + h) + f(x - h)$$
+
+This states that for a given pool of `2x` tokens to process across two forward passes, throughput is maximized by splitting them evenly between the two batches. More generally, in a system that must consume and process P tokens over F forward passes, the ideal partitioning scheme divides them equally.
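+
+A quick numeric check of this claim is shown below, using an arbitrary concave throughput curve (the functional form is invented for illustration, not measured data): the even split of a `2x`-token pool across two forward passes always attains the largest summed throughput.
+
+```python
+# Numeric check of the even-split argument with an illustrative concave curve.
+import math
+
+def f(tokens):
+    # Any concave, increasing curve exhibits the property; sqrt is arbitrary.
+    return 1000.0 * math.sqrt(tokens)
+
+x = 512
+for h in (0, 64, 128, 256):
+    even = 2 * f(x)                 # two passes of x tokens each
+    uneven = f(x + h) + f(x - h)    # two passes of x+h and x-h tokens
+    print(h, round(even), round(uneven), even >= uneven)
+# h = 0 (the even split) maximizes the summed throughput.
+```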
+
+### B. Dynamic SplitFuse
+
+Dynamic SplitFuse is a novel token composition strategy for prompt processing and token generation. DeepSpeed-FastGen utilizes Dynamic SplitFuse to run at a consistent forward size by taking partial tokens from prompts and composing them with generation. In particular, Dynamic SplitFuse performs two key behaviors:
+
+1. Long prompts are decomposed into much smaller chunks and scheduled across multiple forward passes (iterations) with only the final pass performing any generation.
+2. Short prompts will be composed to exactly fill a target token budget. Even short prompts may be decomposed to ensure the budget is precisely met and the forward sizes are well-aligned.
+
+Together, these two techniques provide concrete benefits on all user metrics:
+
+1. *__Better Responsiveness__:* Since long prompts no longer require extremely long forward passes to process, the model will provide lower client latency. More forward passes are performed within the same window of time.
+2. *__Higher Efficiency:__* Fusion of short prompts to larger token budgets enables the model to consistently operate in the high throughput regime.
+3. *__Lower variance and better consistency:__* Since forward passes are of consistent size, and forward pass size is the primary determinant of performance, the latency of each forward pass is much more consistent than in competing systems, as is the perceived generation frequency. There is no preemption and there are no long-running prompts to increase latency, as in prior work.
+
+Consequently, DeepSpeed-FastGen will consume tokens from incoming prompts at a rate that permits fast ongoing generation while adding tokens to the system that increase system utilization, providing lower latency and higher throughput streaming generation to all clients as compared to other state-of-the-art serving systems.
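+
+To make the policy concrete, the sketch below packs each forward pass to a fixed token budget in the spirit of Dynamic SplitFuse (this is our own simplified illustration, not the actual DeepSpeed-FastGen scheduler): every running sequence first contributes its single generation token, and the remaining budget is filled with chunks of pending prompts, splitting long prompts across passes.
+
+```python
+# Highly simplified scheduler sketch in the spirit of Dynamic SplitFuse.
+# Illustrative only; not the actual DeepSpeed-FastGen implementation.
+from collections import deque
+
+TOKEN_BUDGET = 2048  # target tokens per forward pass
+
+def compose_forward_pass(running_seqs, pending_prompts: deque):
+    """Return a list of (seq_id, num_tokens) to run in this iteration."""
+    schedule = [(seq_id, 1) for seq_id in running_seqs]  # one token per decoding seq
+    budget = TOKEN_BUDGET - len(running_seqs)
+
+    while budget > 0 and pending_prompts:
+        seq_id, remaining_prompt = pending_prompts[0]
+        chunk = min(remaining_prompt, budget)             # split long prompts
+        schedule.append((seq_id, chunk))
+        budget -= chunk
+        if chunk == remaining_prompt:
+            pending_prompts.popleft()                     # prompt finished; the
+        else:                                             # sequence decodes next pass
+            pending_prompts[0] = (seq_id, remaining_prompt - chunk)
+    return schedule
+
+prompts = deque([("a", 3000), ("b", 500)])
+print(compose_forward_pass(running_seqs=["c", "d"], pending_prompts=prompts))
+# -> [('c', 1), ('d', 1), ('a', 2046)]; prompt "a" continues in the next pass.
+```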
+
+
+
+
+
+ *Figure 1: Illustration of continuous batching strategies. Each block shows the execution of a forward pass. An arrow indicates that the forward pass has sequences with one or more tokens generated. vLLM performs either token generation or prompt processing in a forward pass; token generation preempts prompt processing. Orca runs prompts at their complete length alongside generation. Dynamic SplitFuse performs dynamic composition of fixed-sized batches composed of both generation and prompt tokens.*
+
+
+
+## 4. Performance Evaluation
+
+DeepSpeed-FastGen provides state-of-the-art LLM serving performance leveraging its blocked KV cache and Dynamic SplitFuse continuous batching. We evaluate DeepSpeed-FastGen against vLLM on a range of models and hardware configurations following the benchmarking methodology discussed below.
+
+### A. Benchmarking Methodology
+
+We use two primary quantitative schemes for measuring performance.
+
+**Throughput-Latency Curves:** Two key metrics for production readiness are throughput (measured in requests per second) and latency (the responsiveness of each request). To measure this, we instantiate multiple clients (ranging from 1 to 32) concurrently and send requests (512 in total) to the server. The resulting latency of each request is measured at the endpoint and throughput is measured by the end-to-end time to complete the experiment.
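+
+A simplified sketch of this measurement harness is shown below (our own illustration, not the released benchmark scripts; `send_request` is a stand-in for a real call to the serving endpoint): N concurrent clients drain a shared queue of 512 requests, per-request latency is recorded at the client, and throughput is the request count divided by the end-to-end wall time.
+
+```python
+# Simplified throughput-latency measurement sketch (illustrative only).
+import queue
+import threading
+import time
+
+def send_request(prompt):
+    time.sleep(0.05)  # stand-in for a real HTTP/gRPC call to the server
+
+def run_benchmark(num_clients=16, num_requests=512):
+    requests = queue.Queue()
+    for i in range(num_requests):
+        requests.put(f"prompt-{i}")
+    latencies, lock = [], threading.Lock()
+
+    def client():
+        while True:
+            try:
+                prompt = requests.get_nowait()
+            except queue.Empty:
+                return
+            start = time.time()
+            send_request(prompt)
+            with lock:
+                latencies.append(time.time() - start)
+
+    start = time.time()
+    workers = [threading.Thread(target=client) for _ in range(num_clients)]
+    for w in workers:
+        w.start()
+    for w in workers:
+        w.join()
+    wall_time = time.time() - start
+    return num_requests / wall_time, sorted(latencies)[len(latencies) // 2]
+
+rps, p50 = run_benchmark()
+print(f"throughput: {rps:.2f} req/s, median latency: {p50 * 1000:.0f} ms")
+```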
+
+**Effective Throughput:** Interactive applications, such as chat applications, can have more stringent and complex requirements than can be captured by top-level metrics like end-to-end latency. In particular, we focus on the increasingly popular chat user scenario:
+
+ 1. A user initiates a task by sending a prompt.
+ 2. The system processes the prompt and returns the first token.
+ 3. Subsequent tokens are streamed to the user as they are produced.
+
+At each point in this process there is an opportunity for a system to provide an adverse user experience; for example, if the first token arrives too slowly or the generation appears to stop for some time. We propose an SLA framework that considers both of these dimensions.
+
+As the lengths of prompts and generated texts vary significantly, affecting computational costs, it is impractical to set rigid SLA values for throughput and latency. Therefore, we define the SLA for prompt latency as |tokens in prompt| / 512 seconds (i.e., 512 tokens/s). Additionally, considering human reading speed, we set the SLA for generation latency, measured as an exponential moving average (EMA), to 2, 4, or 6 tokens/sec. Requests that adhere to these SLAs are deemed successful, and the throughput of these successful requests is referred to as **effective throughput**.
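+
+In code, the SLA check for a single completed request can be sketched as follows. This is our own formulation of the rule described above; in particular, the EMA smoothing factor and checking the EMA at every step (rather than only at the end) are assumptions for illustration:
+
+```python
+# Sketch of the effective-throughput SLA check described above. The EMA smoothing
+# factor (alpha) and per-step checking are illustrative assumptions.
+def meets_sla(prompt_tokens, first_token_latency_s, token_latencies_s,
+              gen_sla_tokens_per_s=4.0, alpha=0.1):
+    # Prompt SLA: first token within |tokens in prompt| / 512 seconds.
+    if first_token_latency_s > prompt_tokens / 512.0:
+        return False
+    # Generation SLA: EMA of per-token latency must stay at or below
+    # 1 / gen_sla_tokens_per_s seconds per token.
+    ema = None
+    for latency in token_latencies_s:
+        ema = latency if ema is None else alpha * latency + (1 - alpha) * ema
+        if ema > 1.0 / gen_sla_tokens_per_s:
+            return False
+    return True
+
+def effective_throughput(results, total_time_s):
+    """results: iterable of (prompt_tokens, first_token_latency_s, token_latencies_s)."""
+    successful = sum(1 for r in results if meets_sla(*r))
+    return successful / total_time_s
+```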
+
+We evaluate vLLM and DeepSpeed-FastGen on Llama-2 7B, Llama-2 13B, and Llama-2 70B on NVIDIA A100, H100, and A6000.
+
+### B. Throughput-Latency Analysis
+
+In this experiment, DeepSpeed-FastGen outperforms vLLM in both throughput and latency, providing either greater throughput at equivalent latency or lower latency at the same throughput. On Llama-2 70B with 4 A100x80GB GPUs, DeepSpeed-FastGen demonstrates up to 2x higher throughput (1.36 rps vs. 0.67 rps) at identical latency (9 seconds), or up to 50% lower latency (7 seconds vs. 14 seconds) at the same throughput (1.2 rps), as shown in Figure 2. These trends hold when evaluating Llama-2 13B, as shown in Figure 3.
+
+
+
+
+ *Figure 2: Throughput and latency of text generation using Llama 2 70B (Tensor parallelism across 4 A100-80GB GPUs). A normal distribution was applied to prompt and generation lengths with averages of 1200/2600 and 128/60, respectively, and a 30% variance*
+
+
+
+
+
+ *Figure 3: Throughput and latency of text generation using Llama 2 13B (A100-80GB GPU, no tensor parallelism). A normal distribution was applied to prompt and generation lengths with averages of 1200/2600 and 60/128, respectively, and a 30% variance*
+
+
+### C. Effective Throughput Analysis
+
+Under the effective throughput analysis that considers both first token latency and the rate at which generation occurs, DeepSpeed-FastGen provides up to 2.3x higher throughput than vLLM. Figure 4 presents a comparative analysis of the effective throughputs of DeepSpeed-FastGen and vLLM. Each plotted point denotes the effective throughput derived from a specific number of clients. As we scaled the number of clients, we initially observed an increase in effective throughput. However, the latency also significantly increases as the number of clients approaches the system's capacity, causing many requests to fail in meeting the SLA. Consequently, the effective throughput will either saturate or decrease at some point. From a usability perspective, it's not particularly relevant how many clients are required to achieve the max effective throughput; the maximum point of the line is the optimal serving point.
+
+
+
+
+ *Figure 4: Effective throughput of DeepSpeed-FastGen and vLLM (Llama 2 70B/A100-80GB using tensor parallelism across 4 A100-80GB GPUs. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 60, respectively, and a 30% variance)*
+
+
+When vLLM preempts the ongoing generation of previous requests, the generation latency experiences a notable increase. This leads to vLLM's effective throughput appearing lower than its directly measured throughput. At vLLM's peak, the effective throughput was 0.63 queries/sec and around 28% of requests failed to meet the 4 tokens/s SLA. At the same SLA, DeepSpeed-FastGen achieved 1.42 queries/sec (less than 1% of requests failed to meet the SLA), which is 2.3x higher than vLLM.
+
+### D. Token Level Timing Analysis
+
+Figure 5 displays the P50, P90, and P95 latencies of the generation processes. Both vLLM and DeepSpeed-FastGen exhibit similar P50 latencies, but vLLM demonstrates significantly higher latencies at P90 and P95.
+At the P95 latency, DeepSpeed-FastGen achieves a 3.7x reduction.
+
+This discrepancy is due to a noticeable spike in vLLM's generation latency when it preempts the ongoing generation to process new prompts.
+In contrast, DeepSpeed-FastGen typically processes the prompt and generation for previous requests concurrently, leading to much more consistent generation latency.
+
+
+
+
+
+ *Figure 5: Per-Token generation Latency of Llama 2 70B/A100-80GB using tensor parallelism across 4 A100-80GB GPUs, 16 clients. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 128, respectively, and a 30% variance.*
+
+
+
+### E. Scalability using Load Balancing
+
+DeepSpeed-FastGen offers replica-level load balancing that evenly distributes requests across multiple servers, allowing you to effortlessly scale up your application.
+
+Figure 6 illustrates the scalability of DeepSpeed-FastGen when employing the load balancer and up to 16 replicas. Note that each replica uses 4 A100 GPUs to run the Llama 2 70B model, so in total we employed 8 nodes to run the 16 replicas. The results demonstrate nearly perfect scalability with DeepSpeed-FastGen.
+Given that the throughput of a single replica is 1.46 queries/sec, the throughput with 16 replicas reaches 23.7 queries/sec, a roughly linear 16x increase over a single replica.
+
+
+
+
+ *Figure 6: Scalability using the load balancing feature. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 60, respectively, and a 30% variance*
+
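+Standing up such a multi-replica deployment uses the same persistent `mii.serve` API described in Section 5 below. The sketch assumes the `tensor_parallel` and `replica_num` arguments exposed by DeepSpeed-MII at the time of writing; please consult the MII documentation for the exact options in your installed version.
+
+```python
+# Sketch of a multi-replica persistent deployment with DeepSpeed-MII.
+# tensor_parallel / replica_num reflect the MII API at the time of writing;
+# check the MII docs for your installed version.
+import mii
+
+# e.g. 2 replicas x 4-way tensor parallelism = 8 GPUs in total
+mii.serve("meta-llama/Llama-2-70b-hf", tensor_parallel=4, replica_num=2)
+
+client = mii.client("meta-llama/Llama-2-70b-hf")
+print(client.generate("DeepSpeed is", max_new_tokens=64))
+```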
+
+### F. Other Hardware Platforms
+
+In addition to our in-depth analysis on A100, we provide additional benchmarking results for H100 and A6000. The performance trends observed on A100 also hold on both H100 and A6000.
+
+
+
+
+ *Figure 7: Throughput-latency curve and effective throughput of Llama 2 70b using 8 H100 GPUs. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 60, respectively, and a 30% variance*
+
+
+
+
+
+ *Figure 8: Throughput-latency curve and effective throughput of Llama 2 7b using A6000. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 60, respectively, and a 30% variance*
+
+
+## 5. DeepSpeed-FastGen: Implementation and Usage
+
+DeepSpeed-FastGen is the synergistic composition of [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII) and [DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed) as illustrated in the figure below. Together, both of these software packages provide various components of the system including the frontend APIs, the host and device infrastructure to schedule batches using Dynamic SplitFuse, optimized kernel implementations, and the tools to construct new model implementations.
+
+
+
+
+
+
+
+The fastest way to get started with our alpha release of DeepSpeed-FastGen is: `pip install deepspeed-mii`.
+
+Please follow our [Getting Started](https://github.com/microsoft/deepspeed-mii#getting-started-with-mii) guide for more details. For usage and reporting issues, please use the [DeepSpeed-MII Github repository](https://github.com/microsoft/DeepSpeed-MII).
+
+### A. Supported Models
+
+We currently support the following model architectures in this alpha release of DeepSpeed-FastGen:
+
+* [LLaMA](https://huggingface.co/models?other=llama) and [LLaMA-2](https://huggingface.co/models?other=llama-2)
+* [Mistral](https://huggingface.co/models?other=mistral)
+* [OPT](https://huggingface.co/models?other=opt)
+
+All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer.
+
+We plan to add additional models in the coming weeks and months after the initial release. If there are specific model architectures you would like supported, please [file an issue](https://github.com/microsoft/DeepSpeed-MII/issues) and let us know.
+
+### B. Deployment options
+All of the examples below are runnable in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/inference/mii). Once installed, you have two options for deployment: an interactive non-persistent pipeline or a persistent serving deployment:
+
+#### Non-persistent pipeline
+
+The non-persistent pipeline deployment is a great and fast way to get started and can be done with only a few lines of code. Non-persistent pipelines exist only for the duration of the Python script you are running and are useful for temporary interactive sessions.
+
+```python
+from mii import pipeline
+pipe = pipeline("mistralai/Mistral-7B-v0.1")
+output = pipe(["Hello, my name is", "DeepSpeed is"], max_new_tokens=128)
+print(output)
+```
+
+#### Persistent deployment
+
+A persistent deployment is ideal for use with long-running and production applications. The persistent deployment uses a lightweight gRPC server that can be created with the following two lines:
+
+
+```python
+import mii
+mii.serve("mistralai/Mistral-7B-v0.1")
+```
+
+The above server can be queried by multiple clients at once thanks to the built-in load balancer from DeepSpeed-MII. Creating a client also just takes 2 lines of code:
+
+```python
+client = mii.client("mistralai/Mistral-7B-v0.1")
+output = client.generate("Deepspeed is", max_new_tokens=128)
+print(output)
+```
+
+A persistent deployment can be terminated when it is no longer needed:
+
+```python
+client.terminate_server()
+```
+
+### C. Advanced Installation Information
+
+For ease of use and a significant reduction in lengthy compile times that many projects require in this space, we distribute a pre-compiled Python wheel covering the majority of our custom kernels through a new library called [DeepSpeed-Kernels](https://github.com/microsoft/DeepSpeed-Kernels). We have found this library to be very portable across environments with NVIDIA GPUs with compute capabilities 8.0+ (Ampere+), CUDA 11.6+, and Ubuntu 20+. In most cases, you shouldn't even need to know this library exists as it is a dependency of DeepSpeed-MII and will be installed with it. However, if for whatever reason you need to compile our kernels manually please see our [advanced installation docs](https://github.com/microsoft/DeepSpeed-Kernels#source).
+
+
+# 6. Try Out DeepSpeed-FastGen
+We are very excited to share this DeepSpeed-FastGen alpha release.
+
+* To get started, please visit our GitHub page for DeepSpeed-MII: [GitHub Landing Page](https://github.com/microsoft/DeepSpeed-MII)
+
+DeepSpeed-FastGen is part of the bigger DeepSpeed ecosystem comprising a multitude of Deep Learning systems and modeling technologies. To learn more,
+
+* Please visit our [website](https://www.deepspeed.ai/) for detailed blog posts, tutorials, and helpful documentation.
+* You can also follow us on our [English Twitter](https://twitter.com/MSFTDeepSpeed), [Japanese Twitter](https://twitter.com/MSFTDeepSpeedJP), and [Chinese Zhihu](https://www.zhihu.com/people/deepspeed) for the latest news on DeepSpeed.
+
+DeepSpeed welcomes your contributions! We encourage you to report issues, contribute PRs, and join discussions on the [DeepSpeed GitHub](https://github.com/microsoft/DeepSpeed/) page. Please see our [contributing guide](https://github.com/microsoft/DeepSpeed/blob/master/CONTRIBUTING.md) for more details. We are open to collaborations with universities, research labs, and companies, such as those working together on deep learning research, applying DeepSpeed to empower real-world AI models and applications, and so on. For such requests (and other requests unsuitable for GitHub), please email deepspeed-info@microsoft.com directly.
+
+The following items are on our roadmap and we plan to engage with our community on these through our GitHub issues and PRs:
+
+- Performance improvements
+- Broader model support
+- New hardware backends through collaboration with partners
+- Release performance benchmarks (used to generate plots in this blog)
+
+**"Star" our [DeepSpeed GitHub](https://github.com/microsoft/DeepSpeed/) and [DeepSpeedMII GitHub](https://github.com/microsoft/DeepSpeed-MII/) repositories if you like our work!**
+
+# 7. Acknowledgements
+
+We would like to thank various open-source community projects including HuggingFace, vLLM, and HuggingFace TGI. We have leveraged HF APIs to support models and tokenizers in our alpha release and will continue to add more models. We especially acknowledge and thank the developers of [Flash Attention](https://github.com/Dao-AILab/flash-attention) for their great work. We have extensively leveraged FlashAttention kernels in our system with modifications that have been acknowledged in our code repositories at appropriate file headers. Finally, we want to thank the developers of [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) kernels that we have used in our MoE kernels (released as part of DeepSpeed-Kernels repository).
diff --git a/blogs/deepspeed-fastgen/assets/images/A6000_benchmark.png b/blogs/deepspeed-fastgen/assets/images/A6000_benchmark.png
new file mode 100644
index 000000000000..9d4ab55f5f7a
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/A6000_benchmark.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/H100_benchmark.png b/blogs/deepspeed-fastgen/assets/images/H100_benchmark.png
new file mode 100644
index 000000000000..89fb9ca3e1ce
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/H100_benchmark.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/effective_throughput.png b/blogs/deepspeed-fastgen/assets/images/effective_throughput.png
new file mode 100644
index 000000000000..11c7f82bc54f
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/effective_throughput.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/effective_throughput_main.png b/blogs/deepspeed-fastgen/assets/images/effective_throughput_main.png
new file mode 100644
index 000000000000..1b9a38306e8e
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/effective_throughput_main.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/fast-gen-overview.jpg b/blogs/deepspeed-fastgen/assets/images/fast-gen-overview.jpg
new file mode 100644
index 000000000000..2affbf8a4cc3
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fast-gen-overview.jpg differ
diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-arch-dark.png b/blogs/deepspeed-fastgen/assets/images/fastgen-arch-dark.png
new file mode 100644
index 000000000000..9b90357a3f1b
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-arch-dark.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-arch-light.png b/blogs/deepspeed-fastgen/assets/images/fastgen-arch-light.png
new file mode 100644
index 000000000000..9e754abde85d
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-arch-light.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-hero-dark.png b/blogs/deepspeed-fastgen/assets/images/fastgen-hero-dark.png
new file mode 100755
index 000000000000..6ac1a775805b
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-hero-dark.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-hero-light.png b/blogs/deepspeed-fastgen/assets/images/fastgen-hero-light.png
new file mode 100755
index 000000000000..af8f1defe653
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-hero-light.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-overview-dark.png b/blogs/deepspeed-fastgen/assets/images/fastgen-overview-dark.png
new file mode 100755
index 000000000000..dde598a985d8
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-overview-dark.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-overview-light.png b/blogs/deepspeed-fastgen/assets/images/fastgen-overview-light.png
new file mode 100755
index 000000000000..bdb5f8df483e
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-overview-light.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-flops.png b/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-flops.png
new file mode 100644
index 000000000000..6d45880588d9
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-flops.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-latency.png b/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-latency.png
new file mode 100644
index 000000000000..7c14e2bf6e53
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-latency.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/repl_scale_llama70b_tp4_p2600g60.png b/blogs/deepspeed-fastgen/assets/images/repl_scale_llama70b_tp4_p2600g60.png
new file mode 100644
index 000000000000..834c06dfb07a
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/repl_scale_llama70b_tp4_p2600g60.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p1200g128.png b/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p1200g128.png
new file mode 100644
index 000000000000..df16b5bebc53
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p1200g128.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p2600g128.png b/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p2600g128.png
new file mode 100644
index 000000000000..8b69a8a1718b
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p2600g128.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/throughput_latency.png b/blogs/deepspeed-fastgen/assets/images/throughput_latency.png
new file mode 100644
index 000000000000..aaceebde7038
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/throughput_latency.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/throughput_latency_13B_no_arrow.png b/blogs/deepspeed-fastgen/assets/images/throughput_latency_13B_no_arrow.png
new file mode 100644
index 000000000000..cc7b8ec1ec05
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/throughput_latency_13B_no_arrow.png differ
diff --git a/blogs/deepspeed-fastgen/assets/images/token_latency.png b/blogs/deepspeed-fastgen/assets/images/token_latency.png
new file mode 100644
index 000000000000..405a3c0d06ed
Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/token_latency.png differ
diff --git a/csrc/includes/activation_type.h b/csrc/includes/activation_type.h
new file mode 100644
index 000000000000..a44921d5d650
--- /dev/null
+++ b/csrc/includes/activation_type.h
@@ -0,0 +1,17 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#pragma once
+
+enum ActivationType {
+ GELU = 0,
+ RELU = 1,
+ SILU = 2,
+ GEGLU = 3,
+ ReGLU = 4,
+ SiGLU = 5,
+ IDENTITY = 6,
+ InvalidType = -1
+};
diff --git a/csrc/includes/ds_kernel_utils.h b/csrc/includes/ds_kernel_utils.h
index a3af561adfe5..8e4888109fcd 100644
--- a/csrc/includes/ds_kernel_utils.h
+++ b/csrc/includes/ds_kernel_utils.h
@@ -11,6 +11,11 @@ used throughout the codebase.
#pragma once
#include <cuda.h>
+#include <cuda_fp16.h>
+
+#ifdef BF16_AVAILABLE
+#include <cuda_bf16.h>
+#endif
#define DS_HD_INLINE __host__ __device__ __forceinline__
#define DS_D_INLINE __device__ __forceinline__
diff --git a/csrc/includes/reduction_utils.h b/csrc/includes/reduction_utils.h
index 8d0c2d6986a1..eb8efab77ac1 100644
--- a/csrc/includes/reduction_utils.h
+++ b/csrc/includes/reduction_utils.h
@@ -233,6 +233,60 @@ DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs)
#endif
}
+template <>
+DS_D_INLINE int32_t element<ROpType::Add>(const int32_t lhs, const int32_t rhs)
+{
+ return lhs + rhs;
+}
+
+template <>
+DS_D_INLINE int32_t element<ROpType::Max>(const int32_t lhs, const int32_t rhs)
+{
+ return (lhs > rhs) ? lhs : rhs;
+}
+
+template <>
+DS_D_INLINE int32_t element<ROpType::Min>(const int32_t lhs, const int32_t rhs)
+{
+ return (lhs < rhs) ? lhs : rhs;
+}
+
+template <>
+DS_D_INLINE uint32_t element<ROpType::Add>(const uint32_t lhs, const uint32_t rhs)
+{
+ return lhs + rhs;
+}
+
+template <>
+DS_D_INLINE uint32_t element<ROpType::Max>(const uint32_t lhs, const uint32_t rhs)
+{
+ return (lhs > rhs) ? lhs : rhs;
+}
+
+template <>
+DS_D_INLINE uint32_t element<ROpType::Min>(const uint32_t lhs, const uint32_t rhs)
+{
+ return (lhs < rhs) ? lhs : rhs;
+}
+
+template <>
+DS_D_INLINE int64_t element<ROpType::Add>(const int64_t lhs, const int64_t rhs)
+{
+ return lhs + rhs;
+}
+
+template <>
+DS_D_INLINE int64_t element<ROpType::Max>(const int64_t lhs, const int64_t rhs)
+{
+ return (lhs > rhs) ? lhs : rhs;
+}
+
+template <>
+DS_D_INLINE int64_t element<ROpType::Min>(const int64_t lhs, const int64_t rhs)
+{
+ return (lhs < rhs) ? lhs : rhs;
+}
+
/*
Reduction initialization primitives
*/
@@ -310,6 +364,78 @@ DS_D_INLINE __half2 init()
#endif
}
+template <>
+DS_D_INLINE int32_t init<ROpType::Add, int32_t>()
+{
+ return 0;
+}
+
+template <>
+DS_D_INLINE int32_t init<ROpType::Min, int32_t>()
+{
+ return 0x7FFFFFFF;
+}
+
+template <>
+DS_D_INLINE int32_t init<ROpType::Max, int32_t>()
+{
+ return 0x80000000;
+}
+
+template <>
+DS_D_INLINE uint32_t init<ROpType::Add, uint32_t>()
+{
+ return 0;
+}
+
+template <>
+DS_D_INLINE uint32_t init<ROpType::Min, uint32_t>()
+{
+ return 0xFFFFFFFF;
+}
+
+template <>
+DS_D_INLINE uint32_t init<ROpType::Max, uint32_t>()
+{
+ return 0;
+}
+
+template <>
+DS_D_INLINE int64_t init<ROpType::Add, int64_t>()
+{
+ return 0;
+}
+
+template <>
+DS_D_INLINE int64_t init<ROpType::Min, int64_t>()
+{
+ return 0x7FFFFFFFFFFFFFFF;
+}
+
+template <>
+DS_D_INLINE int64_t init<ROpType::Max, int64_t>()
+{
+ return 0x8000000000000000;
+}
+
+template <>
+DS_D_INLINE uint64_t init<ROpType::Add, uint64_t>()
+{
+ return 0;
+}
+
+template <>
+DS_D_INLINE uint64_t init<ROpType::Min, uint64_t>()
+{
+ return 0xFFFFFFFFFFFFFFFF;
+}
+
+template <>
+DS_D_INLINE uint64_t init<ROpType::Max, uint64_t>()
+{
+ return 0;
+}
+
template
DS_D_INLINE void init(T* data)
{
@@ -352,8 +478,8 @@ here (fold is C++17 only and I don't think helps and recursion feels like
huge overkill that harms readability) that would be wonderful.
*/
-template
-DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data)
+template
+DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data)
{
#pragma unroll
for (int i = 1; i < reduce_width; i *= 2) {
@@ -361,8 +487,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data)
}
}
-template
-DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data)
+template
+DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data)
{
#pragma unroll
for (int i = 1; i < reduce_width; i *= 2) {
@@ -371,8 +497,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data)
}
}
-template
-DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data)
+template
+DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data)
{
#pragma unroll
for (int i = 1; i < reduce_width; i *= 2) {
@@ -382,8 +508,13 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data)
}
}
-template
-DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data)
+template
+DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data)
{
#pragma unroll
for (int i = 1; i < reduce_width; i *= 2) {
@@ -403,16 +534,15 @@ the number of warps in the block (which may exceed that
if the block is partitioned or if we do a conservative bound at
compile time).
*/
-template
+template
DS_D_INLINE void _block(cg::thread_block& tb,
cg::thread_block_tile& warp_arg,
- float* data)
+ T* data)
{
constexpr int elems = sizeof...(Ops);
- // Separated for now in case this no longer is true
- constexpr int bytes = sizeof(float);
+ constexpr int bytes = sizeof(T);
// Unused when `partition_size == 1` or total_warps == 1
- __shared__ float reduce_buffer[max_warps * elems];
+ __shared__ T reduce_buffer[max_warps * elems];
#ifdef __HIP_PLATFORM_AMD__
const int total_threads = blockDim.x * blockDim.y * blockDim.z;
@@ -422,7 +552,7 @@ DS_D_INLINE void _block(cg::thread_block& tb,
#endif
// Always perform warp-scope reduction
- _warp(warp_arg, data);
+ _warp(warp_arg, data);
// If max_warps == 1 let's skip the runtime check
if (total_warps != 1) {
@@ -447,7 +577,7 @@ DS_D_INLINE void _block(cg::thread_block& tb,
init(data);
}
- _warp(warp_arg, data);
+ _warp(warp_arg, data);
#pragma unroll
for (int i = 0; i < elems; i++) {
@@ -476,7 +606,7 @@ us to obfuscate the details of the partitioned implementation.
template
DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile& warp, float& val)
{
- _block(tb, warp, &val);
+ _block(tb, warp, &val);
}
template
@@ -486,7 +616,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
float& val2)
{
float data[2] = {val1, val2};
- _block(tb, warp, data);
+ _block(tb, warp, data);
val1 = data[0];
val2 = data[1];
}
@@ -499,7 +629,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
float& val3)
{
float data[3] = {val1, val2, val3};
- _block(tb, warp, data);
+ _block(tb, warp, data);
val1 = data[0];
val2 = data[1];
val3 = data[2];
@@ -514,7 +644,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
float& val4)
{
float data[4] = {val1, val2, val3, val4};
- _block(tb, warp, data);
+ _block(tb, warp, data);
val1 = data[0];
val2 = data[1];
val3 = data[2];
@@ -531,10 +661,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
float& val)
{
if (num_threads <= hw_warp_size) {
- _warp(warp, &val);
+ _warp(warp, &val);
} else {
constexpr int num_warps = num_threads / hw_warp_size;
- _block(tb, warp, &val);
+ _block(tb, warp, &val);
}
}
@@ -547,10 +677,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
float data[2] = {val1, val2};
if (num_threads <= hw_warp_size) {
- _warp(warp, data);
+ _warp(warp, data);
} else {
constexpr int num_warps = num_threads / hw_warp_size;
- _block(tb, warp, data);
+ _block(tb, warp, data);
}
val1 = data[0];
@@ -567,10 +697,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
float data[3] = {val1, val2, val3};
if (num_threads <= hw_warp_size) {
- _warp(warp, data);
+ _warp(warp, data);
} else {
constexpr int num_warps = num_threads / hw_warp_size;
- _block(tb, warp, data);
+ _block(tb, warp, data);
}
val1 = data[0];
@@ -589,10 +719,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
float data[4] = {val1, val2, val3, val4};
if (num_threads <= hw_warp_size) {
- _warp(warp, data);
+ _warp(warp, data);
} else {
constexpr int num_warps = num_threads / hw_warp_size;
- _block(tb, warp, data);
+ _block(tb, warp, data);
}
val1 = data[0];
@@ -601,4 +731,48 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
val4 = data[3];
}
+/*
+Arg-reduce is a specialization of the above. We only support this with a single reduction
+parameter. This only works for max/min reductions.
+*/
+
+__align__(8) struct IdxReduceResult {
+ /*
+ NOTE: ORDERING MATTERS HERE! The idx is the least significant set of bits
+ and the val is the most significant. Changing the order of this declaration
+ will break the code.
+ */
+ int idx;
+ float val;
+};
+
+template
+DS_D_INLINE IdxReduceResult
+idx_reduce(cg::thread_block& tb, cg::thread_block_tile& warp, float val, int idx)
+{
+ IdxReduceResult res = {idx, val};
+
+ // Clear out the nan. This shouldn't be an issue for our initial applications
+ if (isnan(val)) res.val = init();
+
+ // Can do float compares as integers. By packing the index into the lower bits
+ // we can just do a single int64 rather than a branch, compare, and select.
+ // One side benefit of this is that it is by nature a stable algorithm and
+ // will always bias ties to the higher index.
+ int64_t* res_as_int = reinterpret_cast<int64_t*>(&res);
+
+ // The way floating point compare works is normally to perform a sign comparison
+ // and if they match, then do a comparison of the rest of the bits as unsigned
+ // integers. Since we are bundling these, that means for negative values we need
+ // to reverse the sort order, which we can do with an XOR.
+ if (val < 0) { *res_as_int ^= 0x7fffffff00000000; }
+
+ _block(tb, warp, res_as_int);
+
+ // Sign bit is preserved, so we can check if we need to invert the mantissa back
+ if (res.val < 0) { *res_as_int ^= 0x7fffffff00000000; }
+
+ return res;
+}
+
} // namespace reduce
diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu
index 3e6701d81e64..dc7ff4d1e7c0 100644
--- a/csrc/transformer/inference/csrc/gelu.cu
+++ b/csrc/transformer/inference/csrc/gelu.cu
@@ -18,8 +18,8 @@ using __nv_bfloat162 = __half2;
inline __device__ float gelu(const float x)
{
- const float sqrt_param = 0.79788456080286535587989211986876f;
- const float mul_param = 0.044715;
+ constexpr float sqrt_param = 0.79788456080286535587989211986876f;
+ constexpr float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
diff --git a/deepspeed/inference/__init__.py b/deepspeed/inference/__init__.py
index 208299fb8c50..7fed50cbe177 100644
--- a/deepspeed/inference/__init__.py
+++ b/deepspeed/inference/__init__.py
@@ -2,3 +2,6 @@
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
+from .v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig
+from .v2.engine_v2 import InferenceEngineV2
+from .v2 import build_hf_engine
diff --git a/deepspeed/inference/v2/__init__.py b/deepspeed/inference/v2/__init__.py
new file mode 100644
index 000000000000..bba5d1a82081
--- /dev/null
+++ b/deepspeed/inference/v2/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+from .config_v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig
+from .engine_v2 import InferenceEngineV2
+from .engine_factory import build_hf_engine
diff --git a/deepspeed/inference/v2/allocator.py b/deepspeed/inference/v2/allocator.py
new file mode 100644
index 000000000000..fa2c5368604e
--- /dev/null
+++ b/deepspeed/inference/v2/allocator.py
@@ -0,0 +1,32 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from functools import reduce
+from typing import Iterable
+
+import torch
+
+from deepspeed.accelerator import get_accelerator
+
+
+def empty_from(tensor: torch.Tensor, shape: Iterable[int]) -> torch.Tensor:
+ shape_size = reduce(lambda x, y: x * y, shape)
+ if shape_size == 0:
+ raise ValueError("Cannot create empty tensor with size 0")
+ return tensor.flatten()[:shape_size].view(shape)
+
+
+def on_device(method) -> torch.Tensor:
+ """
+ Wraps a method to ensure the returned tensor is on the current device.
+ """
+
+ def wrapped(self, *args, **kwargs):
+ tensor = method(self, *args, **kwargs)
+ if isinstance(tensor, torch.Tensor):
+ return tensor.to(get_accelerator().current_device()).contiguous()
+ return tensor
+
+ return wrapped
diff --git a/deepspeed/inference/v2/checkpoint/__init__.py b/deepspeed/inference/v2/checkpoint/__init__.py
new file mode 100644
index 000000000000..45e523ab62b9
--- /dev/null
+++ b/deepspeed/inference/v2/checkpoint/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .base_engine import CheckpointEngineBase
+from .in_memory_engine import InMemoryModelEngine
+from .huggingface_engine import HuggingFaceCheckpointEngine
diff --git a/deepspeed/inference/v2/checkpoint/base_engine.py b/deepspeed/inference/v2/checkpoint/base_engine.py
new file mode 100644
index 000000000000..26fc467d4d86
--- /dev/null
+++ b/deepspeed/inference/v2/checkpoint/base_engine.py
@@ -0,0 +1,41 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from abc import ABC, abstractmethod
+from typing import Iterable, Tuple
+
+import torch
+
+#from .huggingface_engine import HuggingFaceCheckpointEngine
+
+MEGATRON = 'megatron'
+HUGGINGFACE = 'huggingface'
+
+
+class CheckpointEngineBase(ABC):
+ """
+ Abstract interface for checkpoint engines to implement.
+
+ There is no ``__init__`` method here by design, since the creation of the checkpoint
+ engine will happen outside the policy/engine code. The tradeoff being made here is
+ that we will write different frontends for different checkpoint engines, but these
+ frontends can be tailored to the specific checkpoint engine/model source needs.
+ """
+
+ @abstractmethod
+ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
+ """
+ This method should create a generator of tuples of the form (name, parameter) for
+ all parameters in the model. The name should be the fully qualified name of the
+ parameter, and the parameter should be a torch.Tensor.
+
+ The expected use of a checkpoint engine is the following:
+ ```python
+ for name, parameter in checkpoint_engine.parameters():
+ container_map.map_param(name, parameter)
+ ```
+ For a concrete use example, see ``InferenceV2Policy``.
+ """
+ ...
diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
new file mode 100644
index 000000000000..515378d31d02
--- /dev/null
+++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py
@@ -0,0 +1,99 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import os
+import json
+import torch
+from .base_engine import CheckpointEngineBase
+from typing import Iterable, Tuple
+
+from ..logging import inference_logger
+
+
+class HuggingFaceCheckpointEngine(CheckpointEngineBase):
+
+ def __init__(self, model_name_or_path: str, auth_token: str = None) -> None:
+ super().__init__()
+ from transformers import AutoConfig, GenerationConfig
+
+ self.model_name_or_path = model_name_or_path
+ self.auth_token = auth_token
+ self.model_config = AutoConfig.from_pretrained(self.model_name_or_path)
+ self.generation_config = GenerationConfig.from_pretrained(self.model_name_or_path)
+ # Define this property here so we can use it in the model implementation
+ if not hasattr(self.model_config, "max_seq_length"):
+ self.model_config.max_seq_length = self.model_config.max_position_embeddings
+ else:
+ self.model_config.max_seq_length = self.generation_config.max_length
+
+ self._all_ckpt_paths = self._fetch_checkpoint_files()
+
+ def _fetch_checkpoint_files(self):
+ """
+ Fetch the checkpoint files from the HuggingFace Hub.
+ """
+ # TODO(jeff): for models like llama-2 the user will have to provide an auth `token`,
+ # currently coming from the ckpt engine init but maybe a catch all kwargs for other
+ # snapshot download parameters would be more flexible.
+
+ # NOTE(jeff): allow_patterns here are explicitly not using safetensors or other
+ # checkpoint files that may be present. Example of all files in the llama-2-7b
+ # repo here: https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main
+ from huggingface_hub import snapshot_download
+
+ if os.path.isdir(self.model_name_or_path):
+ self._local_checkpoint_dir = self.model_name_or_path
+ else:
+ self._local_checkpoint_dir = snapshot_download(self.model_name_or_path,
+ allow_patterns=[
+ "*.bin",
+ "*.json",
+ "*.pt",
+ ],
+ revision=None,
+ token=self.auth_token)
+
+ assert os.path.isdir(
+ self._local_checkpoint_dir
+ ), f"Checkpoint dir {self._local_checkpoint_dir} is not a directory, cannot load checkpoint."
+
+ model_param_json = os.path.join(self._local_checkpoint_dir, "pytorch_model.bin.index.json")
+
+ if not os.path.isfile(model_param_json):
+ # We don't need any json as all such HF models will have pytorch_model.bin
+ all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, 'pytorch_model.bin')]
+ else:
+ param_map = json.load(open(model_param_json, "r"))
+
+ # weight_map -> { "lm_head.weight": "pytorch_model-00002-of-00002.bin", ... }
+ weight_map = param_map["weight_map"]
+
+ # unique set of all checkpoint files
+ all_checkpoint_files = set(weight_map.values())
+
+ # get absolute path of all unique checkpoint files
+ all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, f) for f in all_checkpoint_files]
+
+ return all_checkpoint_files
+
+ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
+ """
+ Generator of model parameters (satisfies the CheckpointEngineBase interface).
+ """
+ for checkpoint in self._all_ckpt_paths:
+ inference_logger().info(f"Loading checkpoint: {checkpoint}")
+ checkpoint_sd = torch.load(checkpoint, map_location='cpu')
+ param_keys = list(checkpoint_sd.keys())
+ for param_name in param_keys:
+ param = checkpoint_sd[param_name]
+ yield param_name, param
+
+
+if __name__ == "__main__":
+ # To test, add your auth_token here and run `python huggingface_engine.py`
+ engine = HuggingFaceCheckpointEngine(model_name_or_path="meta-llama/Llama-2-7b-hf",
+ auth_token="hf_xxxxxxxxxxxxxxxxx")
+ for name, param in engine.parameters():
+ print(name, param.shape)
diff --git a/deepspeed/inference/v2/checkpoint/in_memory_engine.py b/deepspeed/inference/v2/checkpoint/in_memory_engine.py
new file mode 100644
index 000000000000..13ec7b288f5f
--- /dev/null
+++ b/deepspeed/inference/v2/checkpoint/in_memory_engine.py
@@ -0,0 +1,40 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Iterable, Tuple
+import torch
+
+from .base_engine import CheckpointEngineBase
+
+
+class InMemoryModelEngine(CheckpointEngineBase):
+ """
+ This "checkpoint" engine uses the existing interface to enable loading parameters into an
+ inference model from a model already instantiated in memory. In general, this is not the
+ recommended way to use the inference engine, and should only be used when absolutely necessary.
+
+ The primary limitation of this approach is that the model must be fully instantiated in memory.
+ In a tensor parallel scenario, this means the model is replicated in host memory once per rank.
+ Currently, it is also recommended to only use this approach for models held in host memory.
+
+ In order to free the memory held by this copy of the model, we delete the model in the first call
+ to `parameters`, so it is not safe to make this call twice.
+ """
+
+ def __init__(self, model: torch.nn.Module) -> None:
+ """
+ Create virtual checkpoint engine for the provided module.
+
+ Args:
+ model (torch.nn.Module): Model to load parameters from.
+ """
+ super().__init__()
+ self.model = model
+
+ def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]:
+ for name, parameter in self.model.named_parameters():
+ yield name, parameter
+
+ del self.model
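A minimal usage sketch for the in-memory engine follows; the import path assumes the `deepspeed.inference.v2.checkpoint` package re-exports `InMemoryModelEngine`, as the other engines in this PR suggest.

```python
import torch
from deepspeed.inference.v2.checkpoint import InMemoryModelEngine  # export path assumed

model = torch.nn.Linear(8, 4)
engine = InMemoryModelEngine(model)

# The generator deletes the wrapped model once exhausted, so iterate it exactly once.
for name, param in engine.parameters():
    print(name, tuple(param.shape))
```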
diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py
new file mode 100644
index 000000000000..64e7e29b1844
--- /dev/null
+++ b/deepspeed/inference/v2/config_v2.py
@@ -0,0 +1,31 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from deepspeed.pydantic_v1 import Field
+
+from deepspeed.runtime.config_utils import DeepSpeedConfigModel
+from .ragged import DSStateManagerConfig
+
+
+class DeepSpeedTPConfig(DeepSpeedConfigModel):
+ """ Configure tensor parallelism settings """
+
+ tp_size: int = 1
+ """ Number of devices to split the model across using tensor parallelism. """
+
+
+class RaggedInferenceEngineConfig(DeepSpeedConfigModel):
+ """ Sets parameters for DeepSpeed Inference Engine. """
+
+ tensor_parallel: DeepSpeedTPConfig = Field({}, alias="tp")
+ """
+ Configuration for tensor parallelism used to split the model across several
+ GPUs. Expects a dictionary containing values for :any:`DeepSpeedTPConfig`.
+ """
+
+ state_manager: DSStateManagerConfig = Field({}, alias="manager")
+ """
+ Configuration for managing persistent state
+ """
diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py
new file mode 100644
index 000000000000..48274d6c3d53
--- /dev/null
+++ b/deepspeed/inference/v2/engine_factory.py
@@ -0,0 +1,46 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import logging
+from typing import Any
+
+from .engine_v2 import InferenceEngineV2
+from .config_v2 import RaggedInferenceEngineConfig
+from .checkpoint import HuggingFaceCheckpointEngine
+from .logging import inference_logger
+
+
+def build_hf_engine(path: str,
+ engine_config: RaggedInferenceEngineConfig,
+ debug_level: int = logging.INFO,
+ random_weights_config: Any = None,
+ fill_random: bool = False) -> InferenceEngineV2:
+ """
+ Build an InferenceV2 engine for HuggingFace models.
+ """
+ # Set up logging
+ inference_logger(level=debug_level)
+
+ # get HF checkpoint engine
+ checkpoint_engine = HuggingFaceCheckpointEngine(path)
+
+ # get model config from HF AutoConfig
+ model_config = checkpoint_engine.model_config
+
+ # get the policy
+ # TODO: generalize this to other models
+ if model_config.model_type == "opt":
+ from .model_implementations.opt.policy import OPTPolicy
+ policy = OPTPolicy(checkpoint_engine, model_config)
+ elif model_config.model_type == "llama":
+ from .model_implementations.llama_v2.llama_v2_policy import Llama2Policy
+ policy = Llama2Policy(checkpoint_engine, model_config)
+ elif model_config.model_type == "mistral":
+ from .model_implementations.mistral.policy import MistralPolicy
+ policy = MistralPolicy(checkpoint_engine, model_config)
+ else:
+ raise ValueError(f"Unsupported model type {model_config.model_type}")
+
+ return InferenceEngineV2(policy, engine_config)
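Putting the factory together with the config above, a typical entry point looks like the hedged sketch below; the model path is illustrative, and gated models additionally need an auth token passed through the checkpoint engine.

```python
import logging
from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig
from deepspeed.inference.v2.engine_factory import build_hf_engine

engine_config = RaggedInferenceEngineConfig(tp={"tp_size": 1})
engine = build_hf_engine("mistralai/Mistral-7B-v0.1",   # any supported model_type: opt, llama, mistral
                         engine_config,
                         debug_level=logging.INFO)
```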
diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py
new file mode 100644
index 000000000000..c670ddd09417
--- /dev/null
+++ b/deepspeed/inference/v2/engine_v2.py
@@ -0,0 +1,217 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import os
+from typing import Iterable, Tuple
+
+import torch
+
+import deepspeed.comm as dist
+
+from deepspeed.accelerator import get_accelerator
+from deepspeed.comm.comm import init_distributed
+
+from .model_implementations import InferenceV2Policy
+from .logging import inference_logger
+from .ragged import DSStateManager, RaggedBatchWrapper, PlaceholderSequenceDescriptor
+from .scheduling_utils import SchedulingError, SchedulingResult
+
+from .config_v2 import RaggedInferenceEngineConfig
+
+INFERENCE_MODEL_TIMER = "model-forward-inference"
+
+
+class InferenceEngineV2:
+
+ _config: RaggedInferenceEngineConfig
+ """
+ Configuration of the inference engine.
+ """
+
+ #_model: DSInferenceModelBase
+ """
+ Inference model supporting ragged inference.
+ """
+
+ _state_manager: DSStateManager
+ """
+ Persistent state manager for sequences and KV-cache.
+ """
+
+ @property
+ def free_blocks(self) -> int:
+ """
+ Number of free KV blocks.
+ """
+ return self._state_manager.free_blocks
+
+ def __init__(self, policy: InferenceV2Policy, engine_config: RaggedInferenceEngineConfig) -> None:
+ """
+ Create the Inference V2 engine.
+
+ Arguments:
+ policy (InferenceV2Policy): Policy for the model implementation. This policy object
+ will be used to build the model and load the checkpoint associated with it.
+ engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine.
+ """
+ self._config = engine_config
+ self._policy = policy
+ self._base_mp_group = self._initialize_tp_group()
+
+ # Build model from policy
+ inference_logger().info("Building model...")
+ self._model = self._policy.build_model(self._config, self._base_mp_group)
+ inference_logger().info("Model built.")
+
+ # Create state manager
+ self._batch = RaggedBatchWrapper(self._config.state_manager)
+ self._state_manager = DSStateManager(self._config.state_manager,
+ self._model.kv_cache_config(),
+ base_mp_group=self._base_mp_group)
+ self._model.set_state_manager(self._state_manager)
+
+ def _initialize_tp_group(self):
+ """
+ Implementation of our TP group initialization.
+ """
+ init_distributed()
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
+ get_accelerator().set_device(local_rank)
+
+ if local_rank >= self._config.tensor_parallel.tp_size:
+ raise RuntimeError("Local rank is greater than TP size, ensure that the TP config is correct.")
+
+ ranks = list(range(self._config.tensor_parallel.tp_size))
+ return dist.new_group(ranks=ranks)
+
+ def put(self, batch_uids: Iterable[int], batch_tokens: Iterable[torch.Tensor]) -> torch.Tensor:
+ """
+ Put a ragged batch onto the inference engine. This will perform one forward and return
+ a Tensor of the shape [len(batch_uids), *output_shape]. Logits for the non-final tokens
+ are not calculated.
+
+ Arguments:
+ batch_uids: Iterable of uids for the batch on the host
+ batch_tokens: Iterable of token tensors for the batch on the host
+ """
+
+ token_lens = [len(tokens) for tokens in batch_tokens]
+ schedule_check = self.can_schedule(batch_uids, token_lens)
+ if schedule_check != SchedulingResult.Success:
+ raise SchedulingError(schedule_check)
+
+ self._batch.clear()
+ for uid, tokens in zip(batch_uids, batch_tokens):
+
+ host_seq_desc = self._state_manager.get_or_create_sequence(uid)
+ self._model.maybe_allocate_kv(host_seq_desc, tokens.numel())
+ host_seq_desc.pre_forward(tokens.numel())
+
+ # We can disable checks since we already validated schedulability.
+ self._batch.insert_sequence(host_seq_desc, tokens, do_checks=False)
+
+ # Send all metadata to the device
+ self._batch.finalize()
+
+ # Prep all data structures for the actual forward (in anticipation of CG in the future)
+ # and also to amortize some of the costs in a more straightforward way.
+ self._model.prepare_batch(self._batch)
+
+ # Model implementation will pick up in the forward.
+ logits = self._model.forward(self._batch)
+
+ # We return one set of logits per sequence in the batch (saves cost on unembedding)
+ assert logits.shape[0] == self._batch.current_sequences
+
+ for uid in batch_uids:
+ host_seq_desc = self._state_manager.get_sequence(uid)
+ host_seq_desc.post_forward() # Updates sequence metadata.
+ self._model.maybe_free_kv(host_seq_desc)
+
+ return logits
+
+ def query(self, uid: int, max_request_tokens: int, max_request_blocks: int) -> Tuple[int, int]:
+ """
+ Determine the number of tokens and KV blocks to reserve for a given request. Given a UID
+ (this UID may not be recognized by the model yet), this will return the number of tokens
+ and blocks to reserve for the request.
+
+ Arguments:
+ uid (int): The UID of the sequence (as tracked by the scheduling entity). If
+ this is a new sequence (with a UID unknown to the inference engine), then
+ an empty placeholder is created to pass to the occupancy logic.
+ max_request_tokens (int): The maximum number of tokens to hypothetically send.
+ max_request_blocks (int): The maximum number of KV blocks available for the request.
+
+ Returns:
+ Tuple[int, int]: The number of tokens and the number of KV blocks to reserve for the request.
+ """
+ seq_desc = self._state_manager.get_sequence(uid)
+ if seq_desc is None:
+ if (self._state_manager.n_tracked_sequences == self._config.state_manager.max_tracked_sequences):
+ return (0, 0)
+ seq_desc = PlaceholderSequenceDescriptor()
+
+ req_tokens, req_blocks = self._model.get_kv_requirements(seq_desc, max_request_tokens, max_request_blocks)
+
+ return (req_tokens, req_blocks)
+
+ def can_schedule(self, uids: Iterable[int], lengths: Iterable[int]) -> SchedulingResult:
+ """
+ Dry run a batch to determine if it can be scheduled. Placeholder sequences will be
+ created for any UIDs that are unknown to the inference engine.
+
+ Arguments:
+ uids (Iterable[int]): Iterable of UIDs for the batch
+ lengths (Iterable[int]): Iterable of lengths for each sequence of the batch. These lengths
+ correspond to the number of tokens to send in the hypothetical forward; history
+ tokens will be determined via UID lookup and future tokens are disregarded.
+
+ Returns:
+ SchedulingResult: SchedulingResult.Success if the batch can be scheduled, otherwise the
+ reason the batch cannot be scheduled.
+ """
+
+ cur_seqs = self._state_manager.n_tracked_sequences
+ free_blocks = self._state_manager.free_blocks
+ req_blocks = 0
+ batch_len = 0
+
+ if len(uids) > self._config.state_manager.max_ragged_sequence_count:
+ # Can only compose a batch from a limited number of sequences
+ return SchedulingResult.BatchSequenceLimitExceeded
+
+ for uid, length in zip(uids, lengths):
+ seq_desc = self._state_manager.get_sequence(uid)
+ if seq_desc is None:
+ cur_seqs += 1
+ seq_desc = PlaceholderSequenceDescriptor()
+
+ sched_len, sched_blocks = self._model.get_kv_requirements(seq_desc, length, free_blocks)
+
+ if sched_len != length:
+ # We ran out of KV cache
+ return SchedulingResult.KVCacheLimitExceeded
+
+ batch_len += length
+ free_blocks -= sched_blocks
+
+ if cur_seqs > self._config.state_manager.max_tracked_sequences:
+ # Would run out of tracking metadata
+ return SchedulingResult.EngineSequenceLimitExceeded
+
+ if batch_len > self._config.state_manager.max_ragged_batch_size:
+ # Would exceed the maximum batch size
+ return SchedulingResult.BatchTokenLimitExceeded
+
+ return SchedulingResult.Success
+
+ def flush(self, uid: int) -> None:
+ """
+ Remove all state associated with a sequence from the inference engine.
+
+ Arguments:
+ uid (int): The UID of the sequence to flush.
+ """
+ self._state_manager.flush_sequence(uid)
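The scheduling contract above (`can_schedule` followed by `put`, then `flush` when a sequence finishes) is easiest to see as a toy driver loop. This is a sketch under the assumption that `engine` was built as in the factory example and that a matching HF tokenizer is available; it is not part of the PR.

```python
import torch
from transformers import AutoTokenizer
from deepspeed.inference.v2.scheduling_utils import SchedulingResult

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # illustrative
uid = 0
tokens = tokenizer("DeepSpeed is", return_tensors="pt").input_ids[0]

if engine.can_schedule([uid], [tokens.numel()]) == SchedulingResult.Success:
    logits = engine.put([uid], [tokens])        # one row of logits per sequence in the batch
    next_token = logits[0].argmax(dim=-1).unsqueeze(0)
    logits = engine.put([uid], [next_token])    # decode step: same uid, one new token

engine.flush(uid)  # release KV blocks and tracking metadata for this sequence
```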
diff --git a/deepspeed/inference/v2/inference_utils.py b/deepspeed/inference/v2/inference_utils.py
new file mode 100644
index 000000000000..7b2dd4237353
--- /dev/null
+++ b/deepspeed/inference/v2/inference_utils.py
@@ -0,0 +1,105 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Dict
+
+import torch
+
+from enum import Enum, IntEnum
+
+
+class NormTypeEnum(Enum):
+ LayerNorm: str = "layer_norm"
+ RMSNorm: str = "rms_norm"
+
+
+class DtypeEnum(Enum):
+ # The torch dtype must always be the first value (so we return torch.dtype)
+ fp16 = torch.float16, "torch.float16", "fp16", "float16", "half"
+ fp32 = torch.float32, "torch.float32", "fp32", "float32", "float"
+ bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat"
+ int8 = torch.int8, "torch.int8", "int8"
+
+ # Copied from https://stackoverflow.com/a/43210118
+ # Allows us to use multiple values for each Enum index and returns first
+ # listed value when Enum is called
+ def __new__(cls, *values):
+ obj = object.__new__(cls)
+ # first value is canonical value
+ obj._value_ = values[0]
+ for other_value in values[1:]:
+ cls._value2member_map_[other_value] = obj
+ obj._all_values = values
+ return obj
+
+ def __repr__(self):
+ return "<%s.%s: %s>" % (
+ self.__class__.__name__,
+ self._name_,
+ ", ".join([repr(v) for v in self._all_values]),
+ )
+
+
+ELEM_SIZES: Dict[torch.dtype, int] = {
+ torch.float16: 2,
+ torch.bfloat16: 2,
+ torch.float32: 4,
+ torch.float64: 8,
+ torch.int8: 1,
+ torch.uint8: 1,
+ torch.int16: 2,
+ torch.int32: 4,
+ torch.int64: 8,
+ torch.bool: 1,
+}
+
+
+class ActivationType(IntEnum):
+ """
+ Types of activations supported by DS-Inference
+ """
+
+ GELU = 0
+
+ RELU = 1
+
+ SILU = 2
+
+ GEGLU = 3
+
+ ReGLU = 4
+
+ SiGLU = 5
+
+ IDENTITY = 6
+
+ InvalidType = -1
+
+
+def is_gated(act_fn: ActivationType) -> bool:
+ """
+ Return True if the given activation function is gated.
+ """
+ if not isinstance(act_fn, ActivationType):
+ act_fn = ActivationType(act_fn)
+
+ return act_fn in [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU]
+
+
+def elem_size(dtype: torch.dtype) -> int:
+ """
+ Return size in bytes of the given dtype.
+ """
+ try:
+ return ELEM_SIZES[dtype]
+ except KeyError:
+ raise ValueError("Unknown dtype size for {}".format(dtype))
+
+
+def ceil_div(a: int, b: int) -> int:
+ """
+ Return ceil(a / b).
+ """
+ return -(-a // b)
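The `__new__` override in `DtypeEnum` is the standard multi-alias enum trick: every extra value is registered in `_value2member_map_`, so a lookup by any alias resolves to the same member. A self-contained toy version (names are illustrative):

```python
from enum import Enum

class Precision(Enum):
    # First value is canonical; the rest become lookup aliases.
    FP16 = "fp16", "float16", "half"
    FP32 = "fp32", "float32", "float"

    def __new__(cls, *values):
        obj = object.__new__(cls)
        obj._value_ = values[0]
        for alias in values[1:]:
            cls._value2member_map_[alias] = obj
        return obj

assert Precision("half") is Precision("fp16") is Precision.FP16
```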
diff --git a/deepspeed/inference/v2/kernels/__init__.py b/deepspeed/inference/v2/kernels/__init__.py
new file mode 100644
index 000000000000..01b7b0580073
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .ds_kernel import DSKernelBase
diff --git a/deepspeed/inference/v2/kernels/core_ops/__init__.py b/deepspeed/inference/v2/kernels/core_ops/__init__.py
new file mode 100644
index 000000000000..bbb53e5b58a2
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .bias_activations import *
+from .blas_kernels import *
+from .cuda_layer_norm import *
+from .cuda_rms_norm import *
+from .gated_activations import *
diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/__init__.py b/deepspeed/inference/v2/kernels/core_ops/bias_activations/__init__.py
new file mode 100644
index 000000000000..ea7f8a7d1996
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .bias_activation import *
diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp
new file mode 100644
index 000000000000..4f0cc9cbd77c
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp
@@ -0,0 +1,68 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include "bias_activation.h"
+#include <c10/cuda/CUDAStream.h>
+#include "ds_kernel_utils.h"
+
+#ifdef BF16_AVAILABLE
+#define DTYPE_SWITCH(DTYPE, ...) \
+ [&] { \
+ if (DTYPE == torch::kFloat16) { \
+ using scalar_t = __half; \
+ return __VA_ARGS__(); \
+ } else if (DTYPE == torch::kBFloat16) { \
+ using scalar_t = __nv_bfloat16; \
+ return __VA_ARGS__(); \
+ } else { \
+ TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \
+ } \
+ }()
+#else
+#define DTYPE_SWITCH(DTYPE, ...) \
+ [&] { \
+ if (DTYPE == torch::kFloat16) { \
+ using scalar_t = __half; \
+ return __VA_ARGS__(); \
+ } else { \
+ TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \
+ } \
+ }()
+#endif
+
+/*
+In-place bias and activation fusion kernel.
+*/
+void bias_activation(torch::Tensor& activation,
+ c10::optional<torch::Tensor>& bias,
+ const int32_t act_type)
+{
+ const ActivationType atype = static_cast<ActivationType>(act_type);
+ const int32_t rows = activation.size(0);
+ const int32_t cols = activation.size(1);
+
+ TORCH_CHECK(atype == ActivationType::GELU || atype == ActivationType::RELU ||
+ atype == ActivationType::SILU || atype == ActivationType::IDENTITY,
+ "Unsupported activation type for BiasActivation");
+ TORCH_CHECK(activation.dim() == 2, "BiasActivation only supports 2D activation tensors");
+
+ DTYPE_SWITCH(activation.scalar_type(), [&] {
+ scalar_t* activation_ptr = reinterpret_cast<scalar_t*>(activation.data_ptr());
+
+ const scalar_t* bias_ptr;
+ if (bias.has_value()) {
+ TORCH_CHECK(activation.scalar_type() == bias.value().scalar_type(),
+ "BiasActivation activation and bias must have same dtype");
+ bias_ptr = reinterpret_cast<const scalar_t*>(bias.value().data_ptr());
+ } else {
+ bias_ptr = nullptr;
+ }
+
+ if (atype == ActivationType::IDENTITY && bias_ptr == nullptr) { return; }
+
+ launch_bias_activation(
+ activation_ptr, bias_ptr, rows, cols, atype, c10::cuda::getCurrentCUDAStream());
+ });
+}
diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cu b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cu
new file mode 100644
index 000000000000..66bca0c175c3
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cu
@@ -0,0 +1,140 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include <cassert>
+#include "activation_type.h"
+#include "conversion_utils.h"
+#include "ds_kernel_utils.h"
+#include "memory_access_utils.h"
+
+// Default activation function will error out
+template <ActivationType ActType>
+DS_D_INLINE float act_fn(float val);
+
+template <>
+DS_D_INLINE float act_fn<ActivationType::IDENTITY>(float val)
+{
+ return val;
+}
+
+template <>
+DS_D_INLINE float act_fn<ActivationType::RELU>(float val)
+{
+ return val > 0.0f ? val : 0.0f;
+}
+
+template <>
+DS_D_INLINE float act_fn<ActivationType::GELU>(float val)
+{
+ constexpr float sqrt_param = 0.79788456080286535587989211986876f;
+ constexpr float mul_param = 0.044715f;
+ return val * 0.5f * (1.0f + tanhf(sqrt_param * (val + mul_param * val * val * val)));
+}
+
+template <>
+DS_D_INLINE float act_fn<ActivationType::SILU>(float val)
+{
+ return val / (1.0f + expf(-val));
+}
+
+namespace bias_act {
+
+constexpr int access_size = 16;
+constexpr int threads = 512;
+constexpr int unroll = 4;
+
+} // namespace bias_act
+
+template <typename T, ActivationType ActType>
+__global__ void bias_activation_kernel(T* activation,
+ const T* bias,
+ const int32_t rows,
+ const int32_t cols)
+{
+ constexpr int vector_T = bias_act::access_size / sizeof(T);
+
+ const int32_t thread_offset = threadIdx.x * vector_T;
+ const int32_t block_offset = blockIdx.x * vector_T * bias_act::unroll * bias_act::threads;
+ const int32_t base_offset = block_offset + thread_offset;
+
+ const int32_t thread_stride = bias_act::threads * vector_T;
+
+#pragma unroll
+ for (int i = 0; i < bias_act::unroll; i++) {
+ const int32_t iter_offset = base_offset + i * thread_stride;
+
+ const int32_t row = iter_offset / cols;
+
+ T buffer[vector_T];
+ T bias_buffer[vector_T];
+
+ if (row < rows) {
+ const int32_t col = iter_offset % cols;
+
+ mem_access::load_global<bias_act::access_size>(buffer, activation + iter_offset);
+ mem_access::load_global<bias_act::access_size>(
+ bias_buffer, bias + col, bias != nullptr);
+
+#pragma unroll
+ for (int j = 0; j < vector_T; j++) {
+ float val =
+ conversion::to<float>(buffer[j]) + conversion::to<float>(bias_buffer[j]);
+ buffer[j] = conversion::to<T>(act_fn<ActType>(val));
+ }
+
+ mem_access::store_global<bias_act::access_size>(activation + iter_offset, buffer);
+ }
+ }
+}
+
+#define ACT_TYPE_SWITCH(ACT_TYPE, ...) \
+ if (ACT_TYPE == ActivationType::IDENTITY) { \
+ constexpr ActivationType act_fn_t = ActivationType::IDENTITY; \
+ return __VA_ARGS__(); \
+ } else if (ACT_TYPE == ActivationType::RELU) { \
+ constexpr ActivationType act_fn_t = ActivationType::RELU; \
+ return __VA_ARGS__(); \
+ } else if (ACT_TYPE == ActivationType::GELU) { \
+ constexpr ActivationType act_fn_t = ActivationType::GELU; \
+ return __VA_ARGS__(); \
+ } else if (ACT_TYPE == ActivationType::SILU) { \
+ constexpr ActivationType act_fn_t = ActivationType::SILU; \
+ return __VA_ARGS__(); \
+ } else { \
+ assert(false); \
+ }
+
+template <typename T>
+void launch_bias_activation(T* activation,
+ const T* bias,
+ const int32_t n_rows,
+ const int32_t n_cols,
+ const ActivationType activation_type,
+ cudaStream_t stream)
+{
+ constexpr int32_t elems_per_block =
+ bias_act::threads * bias_act::unroll * bias_act::access_size / sizeof(T);
+ const int32_t total_elems = n_rows * n_cols;
+
+ const int32_t blocks = (total_elems + elems_per_block - 1) / elems_per_block;
+
+ const dim3 grid(blocks);
+ const dim3 block(bias_act::threads);
+
+ ACT_TYPE_SWITCH(activation_type, [&] {
+ bias_activation_kernel<T, act_fn_t>
+ <<<grid, block, 0, stream>>>(activation, bias, n_rows, n_cols);
+ });
+}
+
+#define INSTANTIATE_FOR_T(T) \
+ template void launch_bias_activation( \
+ T*, const T*, const int32_t, const int32_t, const ActivationType, cudaStream_t);
+
+INSTANTIATE_FOR_T(__half);
+
+#ifdef BF16_AVAILABLE
+INSTANTIATE_FOR_T(__nv_bfloat16);
+#endif
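In eager PyTorch the fused kernel's semantics reduce to `act(x + bias)` applied in place. The sketch below is a testing reference only, with the GELU matching the tanh approximation used in the device code; the helper name is illustrative.

```python
import torch
import torch.nn.functional as F

def bias_activation_reference(activation: torch.Tensor, bias=None, act: str = "gelu") -> torch.Tensor:
    """Eager reference for the fused bias + activation kernel."""
    x = activation if bias is None else activation + bias
    fns = {
        "identity": lambda t: t,
        "relu": F.relu,
        "gelu": lambda t: F.gelu(t, approximate="tanh"),  # kernel uses the tanh approximation
        "silu": F.silu,
    }
    return fns[act](x)
```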
diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.h b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.h
new file mode 100644
index 000000000000..db6174633a09
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.h
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#pragma once
+
+#include <c10/cuda/CUDAStream.h>
+#include <torch/extension.h>
+#include "activation_type.h"
+
+template <typename T>
+void launch_bias_activation(T* activation,
+ const T* bias,
+ const int32_t n_rows,
+ const int32_t n_cols,
+ const ActivationType activation_type,
+ cudaStream_t stream);
+
+void bias_activation(torch::Tensor& activation,
+ c10::optional<torch::Tensor>& bias,
+ const int32_t activation_type);
diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.py b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.py
new file mode 100644
index 000000000000..436d7f8805d5
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.py
@@ -0,0 +1,62 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Optional
+
+import torch
+
+from ....inference_utils import ActivationType, DtypeEnum
+from deepspeed.ops.op_builder import InferenceCoreBuilder
+from ... import DSKernelBase
+
+
+class CUDABiasActivation(DSKernelBase):
+ """
+ CUDA implementation of bias activation kernel. This kernel should be deprecated once
+ we are fusing the bias activation into the linear kernel in all scenarios.
+ """
+
+ supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16]
+ supported_act_fns = [ActivationType.IDENTITY, ActivationType.GELU, ActivationType.RELU, ActivationType.SILU]
+
+ def __init__(self, channels: int, dtype: DtypeEnum, act_fn: ActivationType) -> None:
+ """
+ Compile and validate for the fused bias-activation kernel.
+
+ Parameters:
+ channels (int): Number of channels to expect in the activation.
+ dtype (torch.dtype): Data type for the input/output. Supported values
+ are DtypeEnum.fp16 and DtypeEnum.bf16.
+ act_fn (ActivationType): Activation function to use. Only IDENTITY, GELU, RELU, and SILU are supported.
+ """
+
+ if channels % 8 != 0:
+ raise ValueError("channels must be divisible by 8")
+
+ if DtypeEnum(dtype) not in CUDABiasActivation.supported_dtypes:
+ raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
+ dtype, CUDABiasActivation.supported_dtypes))
+
+ act_fn = ActivationType(act_fn)
+ if act_fn not in CUDABiasActivation.supported_act_fns:
+ raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format(
+ act_fn, CUDABiasActivation.supported_act_fns))
+
+ inf_module = InferenceCoreBuilder().load()
+ self.kernel = inf_module.bias_activation
+ self.act_fn = act_fn
+
+ def __call__(self, activation: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Add an optional bias and perform the non-linear activation function.
+
+ Parameters:
+ activation (torch.Tensor): Input tensor of shape [tokens, channels]
+ bias (torch.Tensor): Optional bias tensor of shape [channels]
+
+ Returns:
+ activation that has been updated in-place
+ """
+ self.kernel(activation, bias, self.act_fn.value)
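A hedged usage sketch for the kernel wrapper; it assumes a CUDA device, a successful JIT build of `InferenceCoreBuilder`, and that the package `__init__` re-exports `CUDABiasActivation`.

```python
import torch
from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum
from deepspeed.inference.v2.kernels.core_ops import CUDABiasActivation  # export assumed

act = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
bias = torch.randn(4096, dtype=torch.float16, device="cuda")

kernel = CUDABiasActivation(channels=4096, dtype=DtypeEnum.fp16, act_fn=ActivationType.GELU)
kernel(act, bias)  # `act` is updated in place
```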
diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/__init__.py b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/__init__.py
new file mode 100644
index 000000000000..4af5a579ca1b
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .blas_linear import *
diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas.h b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas.h
new file mode 100644
index 000000000000..1854e40a227d
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas.h
@@ -0,0 +1,138 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#pragma once
+
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <torch/extension.h>
+#include "blas_utils.h"
+
+#define DISPATCH_BLAS_MATMUL(T_TYPE, C_TYPE) \
+ if (output.options().dtype() == torch::T_TYPE) { \
+ blas_gemm_ex(output.data_ptr(), \
+ (const void*)weights.data_ptr(), \
+ (const void*)hidden_states.data_ptr(), \
+ m, \
+ n, \
+ k, \
+ lda, \
+ ldb, \
+ ldc, \
+ trans_a, \
+ trans_b, \
+ &alpha, \
+ &beta, \
+ C_TYPE); \
+ }
+
+void blas_linear(at::Tensor& output, at::Tensor& hidden_states, at::Tensor& weights)
+{
+ /*
+ Expected shape: output([total_tokens_across_dims], out_neurons)
+ hidden_states([total_tokens_across_dims], in_neurons)
+ weights(out_neurons, in_neurons)
+
+ We are going to assume contiguous for the above shapes.
+
+ The shapes are going to get messed with a little internally to handle column-major
+ GEMMs.
+ */
+
+ // Number of tokens is N (since the GEMM output is column-major but our Tensor
+ // is row-major, we need to transpose the shapes)
+ const int n = output.numel() / output.size(-1);
+ const int k = weights.size(1);
+ const int m = weights.size(0);
+
+ // A strides
+ const bool trans_a = weights.stride(1) == 1;
+ const int lda = (trans_a) ? weights.stride(0) : weights.stride(1);
+
+ // B strides
+ const bool trans_b = hidden_states.stride(-1) != 1;
+ const int ldb = (trans_b) ? hidden_states.stride(-1) : hidden_states.stride(-2);
+
+ // C strides
+ const int ldc = output.stride(-2);
+
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+
+ TORCH_CHECK(output.scalar_type() == hidden_states.scalar_type(),
+ "Output and hidden states must have the same scalar type");
+ TORCH_CHECK(output.scalar_type() == weights.scalar_type(),
+ "Output and weights must have the same scalar type");
+
+ // Dispatch the datatypes
+ DISPATCH_BLAS_MATMUL(kFloat, BlasType::FP32);
+ DISPATCH_BLAS_MATMUL(kHalf, BlasType::FP16);
+#ifdef BF16_AVAILABLE
+ DISPATCH_BLAS_MATMUL(kBFloat16, BlasType::BF16);
+#endif
+}
+
+#define DISPATCH_4D_BLAS(T_TYPE, C_TYPE) \
+ if (C.options().dtype() == torch::T_TYPE) { \
+ blas_strided_batched_gemm(C.data_ptr(), \
+ (const void*)A.data_ptr(), \
+ (const void*)B.data_ptr(), \
+ m, \
+ n, \
+ k, \
+ lda, \
+ ldb, \
+ ldc, \
+ trans_a, \
+ trans_b, \
+ &alpha, \
+ &beta, \
+ stride_a, \
+ stride_b, \
+ stride_c, \
+ batch, \
+ C_TYPE); \
+ }
+
+void blas_4d_matmul(at::Tensor& C, at::Tensor& B, at::Tensor& A)
+{
+ /*
+ C shape: (batch_size, N, M)
+ A shape: (batch_size, N, K)
+ B shape: (batch_size, K, M)
+ */
+
+ const int n = C.size(-2);
+ const int k = C.size(-1);
+ const int m = B.size(-1);
+
+ // A strides
+ const bool trans_a = A.stride(-1) == 1;
+ const int lda = (trans_a) ? A.stride(-2) : A.stride(-1);
+ const int stride_a = A.stride(-3);
+
+ // B strides
+ const bool trans_b = B.stride(-1) != 1;
+ const int ldb = (trans_b) ? B.stride(-1) : B.stride(-2);
+ const int stride_b = B.stride(-3);
+
+ // C strides
+ const int ldc = C.stride(-2);
+ const int stride_c = C.stride(-3);
+
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+
+ const int batch = C.numel() / (n * m);
+
+ // Dispatch the datatypes
+ DISPATCH_4D_BLAS(kFloat, BlasType::FP32);
+ DISPATCH_4D_BLAS(kHalf, BlasType::FP16);
+#ifdef BF16_AVAILABLE
+ DISPATCH_4D_BLAS(kBFloat16, BlasType::BF16);
+#endif
+}
+
+void create_handle() { BlasContext::getInstance().get_handle(); }
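The transposition bookkeeping above exists because cuBLAS is column-major while the tensors are row-major; the net effect of `blas_linear` is an ordinary `hidden_states @ weights.T`. A shape-only sketch of the eager equivalent:

```python
import torch

tokens, in_neurons, out_neurons = 8, 32, 64
hidden_states = torch.randn(tokens, in_neurons)      # [total_tokens, in_neurons]
weights = torch.randn(out_neurons, in_neurons)       # [out_neurons, in_neurons]

output = hidden_states @ weights.T                   # what the GEMM produces
assert output.shape == (tokens, out_neurons)
```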
diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_linear.py b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_linear.py
new file mode 100644
index 000000000000..9a151ce36dc4
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_linear.py
@@ -0,0 +1,55 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+
+from ....inference_utils import DtypeEnum
+from deepspeed.ops.op_builder import InferenceCoreBuilder
+from ... import DSKernelBase
+
+
+class BlasLibLinear(DSKernelBase):
+ """
+ Wrapper around the BLAS matmul kernel for FP16/BF16/FP32 for CUDA/ROCm.
+
+ Performs z = x @ y
+ """
+
+ supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16, DtypeEnum.fp32]
+
+ def __init__(self, fp_dtype: DtypeEnum):
+ """
+ Parameters:
+ fp_dtype (torch.dtype): Data type for the input/output. Supported values
+ are torch.float16, torch.bfloat16, and torch.float32.
+ """
+ fp_dtype = DtypeEnum(fp_dtype)
+ if fp_dtype not in BlasLibLinear.supported_dtypes:
+ raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
+ fp_dtype, BlasLibLinear.supported_dtypes))
+
+ self.inf_module = InferenceCoreBuilder().load()
+ self.inf_module.create_handle()
+ self.kernel = self.inf_module.blas_linear
+
+ def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
+ """
+ Matmul kernel as implemented by platform BLAS library. The input must be 2D or larger. If
+ n-dimensional, the leading dimensions are folded into each other:
+ 2D: m = x.size(0)
+ 3D: m = x.size(0) * x.size(1)
+ 4D: m = x.size(0) * x.size(1) * x.size(2) (etc...)
+ All inputs should be contiguous.
+
+ Parameters:
+ output (torch.Tensor): Output tensor. Shape is of [*, out_features]
+ hidden_states (torch.Tensor): Input tensor. Shape is of [*, in_features]
+ weights (torch.Tensor): Input tensor. Shape is of [out_features, in_features]
+
+ Returns:
+ z (torch.Tensor): Output tensor. Shape is of [m, n]
+ """
+ self.kernel(output, hidden_states, weights)
+ return output
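Leading dimensions of the activation are folded into the token dimension, so batched 3D inputs work directly. A hedged sketch (a CUDA device and the core-ops build are required; the `__init__` export is assumed):

```python
import torch
from deepspeed.inference.v2.inference_utils import DtypeEnum
from deepspeed.inference.v2.kernels.core_ops import BlasLibLinear  # export assumed

linear = BlasLibLinear(DtypeEnum.fp16)
x = torch.randn(2, 16, 512, dtype=torch.float16, device="cuda")   # [batch, seq, in_features]
w = torch.randn(1024, 512, dtype=torch.float16, device="cuda")    # [out_features, in_features]
out = torch.empty(2, 16, 1024, dtype=torch.float16, device="cuda")

linear(out, x, w)   # equivalent to x @ w.T, computed by the vendor BLAS
```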
diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h
new file mode 100644
index 000000000000..450991b3c387
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h
@@ -0,0 +1,275 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#pragma once
+
+#include <ATen/cuda/CUDAContext.h>
+#include <cublas_v2.h>
+#include <cuda.h>
+#ifdef BF16_AVAILABLE
+#include <cuda_bf16.h>
+#endif
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#ifndef __HIP_PLATFORM_HCC__
+#include <mma.h>
+#endif
+#include <stdio.h>
+#include <iostream>
+#include <stdexcept>
+
+class BlasContext {
+ /*
+ Slim wrapper for managing the lifetime of the platform's BLAS handle. This should
+ be hipified for ROCm.
+ */
+public:
+ BlasContext()
+ {
+ if (cublasCreate(&_handle) != CUBLAS_STATUS_SUCCESS) {
+ auto message = std::string("Fail to create cublas handle.");
+ std::cerr << message << std::endl;
+ throw std::runtime_error(message);
+ }
+#ifndef __HIP_PLATFORM_HCC__
+ cublasSetMathMode(_handle, CUBLAS_TENSOR_OP_MATH);
+#endif
+ }
+
+ virtual ~BlasContext() { cublasDestroy(_handle); }
+
+ static BlasContext& getInstance()
+ {
+ // Should always access the singleton through this function.
+ static BlasContext _instance;
+ return _instance;
+ }
+
+ cublasHandle_t get_handle() const { return _handle; }
+
+private:
+ cublasHandle_t _handle;
+};
+
+enum class BlasType { FP32, FP16, BF16 };
+
+#ifdef __HIP_PLATFORM_HCC__
+rocblas_operation get_trans_op(bool do_trans)
+{
+ return (do_trans) ? rocblas_operation_transpose : rocblas_operation_none;
+}
+
+rocblas_datatype get_datatype(BlasType type)
+{
+ switch (type) {
+ case BlasType::FP32: return rocblas_datatype_f32_r;
+ case BlasType::FP16: return rocblas_datatype_f16_r;
+ case BlasType::BF16: return rocblas_datatype_bf16_r;
+ default: throw std::runtime_error("Unsupported BlasType");
+ }
+}
+#else
+cublasOperation_t get_trans_op(bool do_trans) { return (do_trans) ? CUBLAS_OP_T : CUBLAS_OP_N; }
+
+cublasDataType_t get_datatype(BlasType type)
+{
+ switch (type) {
+ case BlasType::FP32: return CUDA_R_32F;
+ case BlasType::FP16: return CUDA_R_16F;
+ case BlasType::BF16: return CUDA_R_16BF;
+ default: throw std::runtime_error("Unsupported BlasType");
+ }
+}
+#endif
+
+int blas_gemm_ex(void* C,
+ const void* A,
+ const void* B,
+ int m,
+ int n,
+ int k,
+ int lda,
+ int ldb,
+ int ldc,
+ bool transa,
+ bool transb,
+ const float* alpha,
+ const float* beta,
+ BlasType type)
+{
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_operation_t transa_op = get_trans_op(transa);
+ rocblas_operation_t transb_op = get_trans_op(transb);
+
+ rocblas_datatype_t abc_type = get_datatype(type);
+
+ rocblas_status status = rocblas_gemm_ex(BlasContext::getInstance().get_handle(),
+ transa_op,
+ transb_op,
+ m,
+ n,
+ k,
+ (const void*)alpha,
+ A,
+ abc_type,
+ lda,
+ B,
+ abc_type,
+ ldb,
+ (const void*)beta,
+ C,
+ abc_type,
+ ldc,
+ C,
+ abc_type,
+ ldc,
+ rocblas_datatype_f32_r,
+ rocblas_gemm_algo_standard,
+ 0,
+ 0);
+#else
+ cublasOperation_t transa_op = get_trans_op(transa);
+ cublasOperation_t transb_op = get_trans_op(transb);
+
+ cublasDataType_t abc_type = get_datatype(type);
+ cublasStatus_t status = cublasGemmEx(BlasContext::getInstance().get_handle(),
+ transa_op,
+ transb_op,
+ m,
+ n,
+ k,
+ (const void*)alpha,
+ A,
+ abc_type,
+ lda,
+ B,
+ abc_type,
+ ldb,
+ (const void*)beta,
+ C,
+ abc_type,
+ ldc,
+ CUDA_R_32F,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+#endif
+
+#ifdef __HIP_PLATFORM_HCC__
+ if (status != rocblas_status_success) {
+#else
+ if (status != CUBLAS_STATUS_SUCCESS) {
+#endif
+ fprintf(stderr,
+ "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
+ m,
+ n,
+ k,
+ (int)status);
+ return EXIT_FAILURE;
+ }
+ return 0;
+}
+
+int blas_strided_batched_gemm(void* C,
+ const void* A,
+ const void* B,
+ int m,
+ int n,
+ int k,
+ int lda,
+ int ldb,
+ int ldc,
+ bool transa,
+ bool transb,
+ const float* alpha,
+ const float* beta,
+ int stride_A,
+ int stride_B,
+ int stride_C,
+ int batch,
+ BlasType type)
+{
+#ifdef __HIP_PLATFORM_HCC__
+ rocblas_operation_t transa_op = get_trans_op(transa);
+ rocblas_operation_t transb_op = get_trans_op(transb);
+
+ rocblas_datatype_t abc_type = get_datatype(type);
+
+ rocblas_status status =
+ rocblas_gemm_strided_batched_ex(BlasContext::getInstance().get_handle(),
+ transa_op,
+ transb_op,
+ m,
+ n,
+ k,
+ (const void*)alpha,
+ A,
+ abc_type,
+ lda,
+ stride_A,
+ B,
+ abc_type,
+ ldb,
+ stride_B,
+ (const void*)beta,
+ C,
+ abc_type,
+ ldc,
+ stride_C,
+ C,
+ abc_type,
+ ldc,
+ stride_C,
+ batch,
+ rocblas_datatype_f32_r,
+ rocblas_gemm_algo_standard,
+ 0,
+ 0);
+#else
+ cublasOperation_t transa_op = get_trans_op(transa);
+ cublasOperation_t transb_op = get_trans_op(transb);
+
+ cublasDataType_t abc_type = get_datatype(type);
+
+ cublasStatus_t status = cublasGemmStridedBatchedEx(BlasContext::getInstance().get_handle(),
+ transa_op,
+ transb_op,
+ m,
+ n,
+ k,
+ (const void*)alpha,
+ A,
+ abc_type,
+ lda,
+ stride_A,
+ B,
+ abc_type,
+ ldb,
+ stride_B,
+ (const void*)beta,
+ C,
+ abc_type,
+ ldc,
+ stride_C,
+ batch,
+ CUDA_R_32F,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+#endif
+
+#ifdef __HIP_PLATFORM_HCC__
+ if (status != rocblas_status_success) {
+#else
+ if (status != CUBLAS_STATUS_SUCCESS) {
+#endif
+ fprintf(stderr,
+ "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
+ batch,
+ m,
+ n,
+ k,
+ (int)status);
+ return EXIT_FAILURE;
+ }
+ return 0;
+}
diff --git a/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp b/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp
new file mode 100644
index 000000000000..58df88e56136
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include <c10/cuda/CUDAStream.h>
+#include <torch/extension.h>
+
+#include "bias_activation.h"
+#include "blas.h"
+#include "gated_activation_kernels.h"
+#include "layer_norm.h"
+#include "rms_norm.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ // bias_activation.h
+ m.def("bias_activation", &bias_activation, "DeepSpeed bias activation in CUDA");
+
+ // layer_norm.h
+ m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm in CUDA");
+ m.def("pre_layer_norm", &ds_pre_layer_norm, "DeepSpeed pre layer norm in CUDA");
+ m.def("post_layer_norm", &ds_post_layer_norm, "DeepSpeed pre layer norm in CUDA");
+
+ // blas.h
+ m.def("blas_linear", &blas_linear, "Linear implemented by vendor BLAS");
+ m.def("blas_4d_matmul", &blas_4d_matmul, "4D matmul implemented by vendor BLAS");
+ m.def("create_handle", &create_handle, "Create a handle for vendor BLAS");
+
+ // gated_activation_kernels.h
+ m.def("gated_activation", &ds_gated_activation, "DeepSpeed gated activation in CUDA");
+
+ // rms_norm.h
+ m.def("rms_norm", &rms_norm, "DeepSpeed rms norm in CUDA");
+ m.def("rms_pre_norm", &rms_pre_norm, "DeepSpeed rms pre norm in CUDA");
+}
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/__init__.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/__init__.py
new file mode 100644
index 000000000000..bed7688b15d2
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .cuda_ln import *
+from .cuda_post_ln import *
+from .cuda_pre_ln import *
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_fp_ln_base.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_fp_ln_base.py
new file mode 100644
index 000000000000..3c2aa5cb5eb4
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_fp_ln_base.py
@@ -0,0 +1,37 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+
+from ... import DSKernelBase
+from ....inference_utils import elem_size
+from deepspeed.ops.op_builder import InferenceCoreBuilder
+
+
+class CUDAFPLNBase(DSKernelBase):
+ """
+ Base class for CUDA LN kernels. They all share the same validation logic,
+ so we can share it here.
+ """
+
+ supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+ def __init__(self, channels: int, fp_dtype: torch.dtype, epsilon: float = 1e-5):
+ """
+ Parameters:
+ channels (int): Number of channels in the input tensor. The channel dimension must
+ align to a multiple of 16 bytes for the given dtype.
+ fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values
+ are torch.float16, torch.bfloat16, and torch.float32.
+ """
+ if fp_dtype not in CUDAFPLNBase.supported_dtypes:
+ raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
+ fp_dtype, CUDAFPLNBase.supported_dtypes))
+
+ if elem_size(fp_dtype) * channels % 16 != 0:
+ raise ValueError("channels must be divisible by 16 bytes")
+
+ self.inf_module = InferenceCoreBuilder().load()
+ self.epsilon = epsilon
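The 16-byte constraint comes from the vectorized loads in the kernels, so whether a channel count passes depends on the dtype width. A quick check using `elem_size`:

```python
import torch
from deepspeed.inference.v2.inference_utils import elem_size

for dtype, channels in [(torch.float16, 4096), (torch.float32, 4096), (torch.float16, 4100)]:
    aligned = (elem_size(dtype) * channels) % 16 == 0
    print(dtype, channels, "ok" if aligned else "rejected by CUDAFPLNBase")
```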
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_ln.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_ln.py
new file mode 100644
index 000000000000..583736fb8bbc
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_ln.py
@@ -0,0 +1,30 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+
+from .cuda_fp_ln_base import CUDAFPLNBase
+
+
+class CUDAFPLN(CUDAFPLNBase):
+ """
+ Floating point layer norm kernel for CUDA/ROCm.
+
+ Performs: z = ln(x)
+ """
+
+ def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, gamma: torch.Tensor,
+ beta: torch.Tensor) -> torch.Tensor:
+ """
+ output_z may alias input_x directly. All Tensors should have the same shape.
+
+ Parameters:
+ output_z (torch.Tensor): Output tensor.
+ input_x (torch.Tensor): Input tensor.
+ gamma (torch.Tensor): Gamma tensor.
+ beta (torch.Tensor): Beta tensor.
+ """
+ self.inf_module.layer_norm(output_z, input_x, gamma, beta, self.epsilon)
+ return output_z
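For validation, the kernel's output can be compared against the eager PyTorch layer norm; a minimal reference sketch, assuming the same default epsilon:

```python
import torch.nn.functional as F

def layer_norm_reference(x, gamma, beta, eps=1e-5):
    # Eager equivalent of CUDAFPLN: z = ln(x)
    return F.layer_norm(x, (x.shape[-1],), weight=gamma, bias=beta, eps=eps)
```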
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_post_ln.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_post_ln.py
new file mode 100644
index 000000000000..0ced1ecf207e
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_post_ln.py
@@ -0,0 +1,34 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+
+from .cuda_fp_ln_base import CUDAFPLNBase
+
+
+class CUDAFPPostLN(CUDAFPLNBase):
+ """
+ Floating point post-LayerNorm kernel for CUDA/ROCm.
+
+ Performs: z = ln(x + y)
+ """
+
+ def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, input_y: torch.Tensor, gamma: torch.Tensor,
+ beta: torch.Tensor) -> torch.Tensor:
+ """
+ Either input_x or input_y can alias output_z.
+
+ Parameters:
+ output_z (torch.Tensor): Output tensor.
+ input_x (torch.Tensor): Input tensor.
+ input_y (torch.Tensor): Input tensor.
+ gamma (torch.Tensor): Gamma tensor.
+ beta (torch.Tensor): Beta tensor.
+
+ Returns:
+ output (torch.Tensor): Output tensor.
+ """
+ self.inf_module.post_layer_norm(output_z, input_x, input_y, gamma, beta, self.epsilon)
+ return output_z
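The post-LN variant only differs in normalizing the residual sum; an eager reference sketch:

```python
import torch.nn.functional as F

def post_ln_reference(x, y, gamma, beta, eps=1e-5):
    # Eager equivalent of CUDAFPPostLN: z = ln(x + y)
    return F.layer_norm(x + y, (x.shape[-1],), weight=gamma, bias=beta, eps=eps)
```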
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_pre_ln.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_pre_ln.py
new file mode 100644
index 000000000000..74b2d9cf5880
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_pre_ln.py
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Tuple
+
+import torch
+
+from .cuda_fp_ln_base import CUDAFPLNBase
+
+
+class CUDAFPPreLN(CUDAFPLNBase):
+ """
+ Floating point pre-LayerNorm kernel for CUDA/ROCm.
+
+ Performs: z_res = x_res + y_hid
+ z_hid = ln(z_res)
+ """
+
+ def __call__(self, z_res: torch.Tensor, z_hid: torch.Tensor, x_res: torch.Tensor, y_hid: torch.Tensor,
+ gamma: torch.Tensor, beta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ z_res can alias x_res. All non-parameter input/output tensors
+ must have the same shape. z_hid can alias y_hid.
+
+ Parameters:
+ z_res (torch.Tensor): Output residual.
+ z_hid (torch.Tensor): Output hidden states.
+ x_res (torch.Tensor): Input residual.
+ y_hid (torch.Tensor): Input hidden states.
+ gamma (torch.Tensor): Gamma tensor.
+ beta (torch.Tensor): Beta tensor.
+
+ Returns:
+ (z_res, z_hid) (Tuple[torch.Tensor, torch.Tensor]): The residual sum and its normalization.
+ """
+ self.inf_module.pre_layer_norm(z_res, z_hid, x_res, y_hid, gamma, beta, self.epsilon)
+ return z_res, z_hid
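The pre-LN variant returns both the residual sum and its normalization; an eager reference sketch:

```python
import torch.nn.functional as F

def pre_ln_reference(x_res, y_hid, gamma, beta, eps=1e-5):
    # Eager equivalent of CUDAFPPreLN: z_res = x_res + y_hid, z_hid = ln(z_res)
    z_res = x_res + y_hid
    z_hid = F.layer_norm(z_res, (z_res.shape[-1],), weight=gamma, bias=beta, eps=eps)
    return z_res, z_hid
```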
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cpp
new file mode 100644
index 000000000000..b2c95d410a1f
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cpp
@@ -0,0 +1,102 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include "layer_norm.h"
+
+#define DISPATCH_LAYER_NORM(T_TYPE, C_TYPE) \
+ if (input.options().dtype() == torch::T_TYPE) { \
+ launch_fused_ln((C_TYPE*)output.data_ptr(), \
+ (const C_TYPE*)input.data_ptr(), \
+ (const C_TYPE*)gamma.data_ptr(), \
+ (const C_TYPE*)beta.data_ptr(), \
+ epsilon, \
+ rows, \
+ elems_per_row, \
+ at::cuda::getCurrentCUDAStream()); \
+ }
+
+void ds_layer_norm(at::Tensor& output,
+ at::Tensor& input,
+ at::Tensor& gamma,
+ at::Tensor& beta,
+ float epsilon)
+{
+ bool ragged_input = input.dim() == 2;
+
+ const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1);
+ const int elems_per_row = ragged_input ? input.size(1) : input.size(2);
+
+ DISPATCH_LAYER_NORM(kFloat, float);
+ DISPATCH_LAYER_NORM(kHalf, __half);
+#ifdef BF16_AVAILABLE
+ DISPATCH_LAYER_NORM(kBFloat16, __nv_bfloat16);
+#endif
+}
+
+#define DISPATCH_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \
+ if (input.options().dtype() == torch::T_TYPE) { \
+ launch_fused_post_ln((C_TYPE*)output.data_ptr(), \
+ (const C_TYPE*)input.data_ptr(), \
+ (const C_TYPE*)residual.data_ptr(), \
+ (const C_TYPE*)gamma.data_ptr(), \
+ (const C_TYPE*)beta.data_ptr(), \
+ epsilon, \
+ rows, \
+ elems_per_row, \
+ at::cuda::getCurrentCUDAStream()); \
+ }
+
+void ds_post_layer_norm(at::Tensor& output,
+ at::Tensor& input,
+ at::Tensor& residual,
+ at::Tensor& gamma,
+ at::Tensor& beta,
+ float epsilon)
+{
+ bool ragged_input = input.dim() == 2;
+
+ const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1);
+ const int elems_per_row = ragged_input ? input.size(1) : input.size(2);
+
+ DISPATCH_LAYER_NORM_RESIDUAL(kFloat, float);
+ DISPATCH_LAYER_NORM_RESIDUAL(kHalf, __half);
+#ifdef BF16_AVAILABLE
+ DISPATCH_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16);
+#endif
+}
+
+#define DISPATCH_PRE_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \
+ if (input.options().dtype() == torch::T_TYPE) { \
+ launch_fused_pre_ln((C_TYPE*)norm_output.data_ptr(), \
+ (C_TYPE*)res_output.data_ptr(), \
+ (const C_TYPE*)input.data_ptr(), \
+ (const C_TYPE*)residual.data_ptr(), \
+ (const C_TYPE*)gamma.data_ptr(), \
+ (const C_TYPE*)beta.data_ptr(), \
+ epsilon, \
+ rows, \
+ elems_per_row, \
+ at::cuda::getCurrentCUDAStream()); \
+ }
+
+void ds_pre_layer_norm(at::Tensor& res_output,
+ at::Tensor& norm_output,
+ at::Tensor& input,
+ at::Tensor& residual,
+ at::Tensor& gamma,
+ at::Tensor& beta,
+ float epsilon)
+{
+ bool ragged_input = input.dim() == 2;
+
+ const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1);
+ const int elems_per_row = ragged_input ? input.size(1) : input.size(2);
+
+ DISPATCH_PRE_LAYER_NORM_RESIDUAL(kFloat, float);
+ DISPATCH_PRE_LAYER_NORM_RESIDUAL(kHalf, __half);
+#ifdef BF16_AVAILABLE
+ DISPATCH_PRE_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16);
+#endif
+}
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cu
new file mode 100644
index 000000000000..15f52c46622b
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cu
@@ -0,0 +1,490 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include "conversion_utils.h"
+#include "ds_kernel_utils.h"
+#include "memory_access_utils.h"
+#include "reduction_utils.h"
+
+namespace cg = cooperative_groups;
+using rop = reduce::ROpType;
+
+namespace ln {
+constexpr int granularity = 16;
+} // namespace ln
+
+/*
+Regular layer norm implementation. Assumes elems_per_row % 8
+is equal to 0.
+
+Args:
+ output: buffer for output data
+ vals: buffer for input data
+ gamma: gain for normalization
+ beta: bias for normalization
+ epsilon: numeric stability
+ elems_per_row: number of elements each block will normalize
+*/
+template <typename T, int unRoll, int threadsPerGroup, int maxThreads>
+__global__ void fused_ln(T* output,
+ const T* vals,
+ const T* gamma,
+ const T* beta,
+ float epsilon,
+ int elems_per_row)
+{
+ constexpr int T_per_load = ln::granularity / sizeof(T);
+
+ cg::thread_block tb = cg::this_thread_block();
+ cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
+
+ // X-dimension of the block
+ const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
+ (tb.thread_index().y * elems_per_row);
+ const int thread_offset = tb.thread_index().x * T_per_load;
+ const int base_offset = block_offset + thread_offset;
+ const int stride = blockDim.x * T_per_load;
+
+ float sum = reduce::init<rop::Add>();
+
+ const T* input_base = vals + base_offset;
+
+ T local_buffer[unRoll * T_per_load];
+
+#pragma unRoll
+ for (int i = 0; i < unRoll; i++) {
+ T* iteration_buffer = local_buffer + i * T_per_load;
+
+ mem_access::load_global<ln::granularity>(
+ iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);
+
+#pragma unRoll
+ for (int j = 0; j < T_per_load; j++) {
+ float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
+ sum = reduce::element<rop::Add>(sum, vals_up_cast);
+ }
+ }
+
+ reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, sum);
+ const float mean = sum / elems_per_row;
+
+ float mean_diff = reduce::init<rop::Add>();
+
+#pragma unRoll
+ for (int i = 0; i < unRoll; i++) {
+#pragma unRoll
+ for (int j = 0; j < T_per_load; j++) {
+ // Using a 0 value here skews the variance, have to if-guard
+ if (thread_offset + i * stride < elems_per_row) {
+ float diff = (conversion::to<float>(local_buffer[i * T_per_load + j]) - mean);
+ mean_diff = reduce::element<rop::Add>(mean_diff, diff * diff);
+ }
+ }
+ }
+
+ reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, mean_diff);
+ const float variance = mean_diff / elems_per_row;
+ const float denom = __frsqrt_rn(variance + epsilon);
+
+ T* block_output = output + block_offset;
+
+#pragma unRoll
+ for (int i = 0; i < unRoll; i++) {
+ T* iteration_buffer = local_buffer + i * T_per_load;
+ const int iter_idx = i * stride + thread_offset;
+ const bool do_loads = iter_idx < elems_per_row;
+
+ T gamma_local[T_per_load], beta_local[T_per_load];
+
+ mem_access::load_global<ln::granularity>(gamma_local, gamma + iter_idx, do_loads);
+ mem_access::load_global<ln::granularity>(beta_local, beta + iter_idx, do_loads);
+
+#pragma unRoll
+ for (int j = 0; j < T_per_load; j++) {
+ float val = conversion::to<float>(iteration_buffer[j]);
+ val = (val - mean) * denom;
+ val =
+ val * conversion::to<float>(gamma_local[j]) + conversion::to<float>(beta_local[j]);
+ iteration_buffer[j] = conversion::to<T>(val);
+ }
+
+ if (do_loads) {
+ mem_access::store_global<ln::granularity>(block_output + iter_idx, iteration_buffer);
+ }
+ }
+}
+
+#define LAUNCH_FUSED_LN(unRollFactor, threadsPerGroup, maxThreads) \
+ fused_ln<T, unRollFactor, threadsPerGroup, maxThreads> \
+ <<<grid, block, 0, stream>>>(output, vals, gamma, beta, epsilon, elems_per_row);
+
+template <typename T>
+void launch_fused_ln(T* output,
+ const T* vals,
+ const T* gamma,
+ const T* beta,
+ float epsilon,
+ int rows,
+ int elems_per_row,
+ cudaStream_t stream)
+{
+ // 8 for __half, 4 for float
+ constexpr int T_per_load = ln::granularity / sizeof(T);
+
+ constexpr int maxThreads = 256;
+
+ // For float, unRoll 4; for __half, unRoll 2
+ constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
+
+ const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
+ const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
+
+ // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
+ // warp-sized blocks rather than stepping up to 64/96 threads
+ const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
+ const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
+
+ const int groups_per_block_max =
+ is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
+ const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
+ const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
+
+ dim3 block(threadsPerGroup, groups_per_block);
+ dim3 grid(groups_launch);
+
+ const int elems_per_step = threadsPerGroup * h_per_step;
+ const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
+
+ if (is_subblock_schedule) {
+ // <=128
+ if (threadsPerGroup == 1) {
+ LAUNCH_FUSED_LN(1, 1, maxThreads);
+ } else if (threadsPerGroup == 2) {
+ LAUNCH_FUSED_LN(1, 2, maxThreads);
+ } else if (threadsPerGroup == 4) {
+ LAUNCH_FUSED_LN(1, 4, maxThreads);
+ } else if (threadsPerGroup == 8) {
+ LAUNCH_FUSED_LN(1, 8, maxThreads);
+ } else if (threadsPerGroup == 16) {
+ LAUNCH_FUSED_LN(1, 16, maxThreads);
+ }
+ } else if (external_unRoll == 1) {
+ // 129 - 4096 elems
+ // (this can launch with 1-7 warps as well)
+ LAUNCH_FUSED_LN(1 * internal_unRoll, maxThreads, maxThreads);
+ } else if (external_unRoll == 2) {
+ // 4097 - 8192 elems
+ LAUNCH_FUSED_LN(2 * internal_unRoll, maxThreads, maxThreads);
+ } else if (external_unRoll == 3) {
+ // 8193 - 12288 elems
+ LAUNCH_FUSED_LN(3 * internal_unRoll, maxThreads, maxThreads);
+ } else if (external_unRoll == 4) {
+ // 12289 - 16384 elems
+ LAUNCH_FUSED_LN(4 * internal_unRoll, maxThreads, maxThreads);
+ }
+}
+
+#define INSTANTIATE_FUSED_LN(T) \
+ template void launch_fused_ln(T*, const T*, const T*, const T*, float, int, int, cudaStream_t);
+
+INSTANTIATE_FUSED_LN(__half);
+#ifdef BF16_AVAILABLE
+INSTANTIATE_FUSED_LN(__nv_bfloat16);
+#endif
+INSTANTIATE_FUSED_LN(float);
+
+/*
+Fused residual + bias + layer norm implementation. Assumes elems_per_row % 8
+is equal to 0.
+
+TODO(cmikeh2): Goal is to deprecate this implementation. The bias + residual
+need to be fused into compute-bound producer operations.
+
+Args:
+ output: buffer for output data
+ res_output: output of residual addition
+ vals: buffer for input data
+ residual: residual data
+ bias: bias of input data
+ gamma: gain for normalization
+ beta: bias for normalization
+ epsilon: numeric stability
+ elems_per_row: number of elements each block will normalize
+Template arg:
+ StoreResidual: controls whether the residual calculation is stored
+ or not. When set to false, the input `res_output` is unused.
+*/
+template <typename T, int unRoll, int threadsPerGroup, int maxThreads, bool preLnResidual>
+__global__ void fused_residual_ln(T* output,
+ T* res_output,
+ const T* vals,
+ const T* residual,
+ const T* gamma,
+ const T* beta,
+ float epsilon,
+ int elems_per_row)
+{
+ constexpr int T_per_load = ln::granularity / sizeof(T);
+
+ cg::thread_block tb = cg::this_thread_block();
+ cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
+
+ // X-dimension of the block
+ const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
+ (tb.thread_index().y * elems_per_row);
+ const int thread_offset = tb.thread_index().x * T_per_load;
+ const int base_offset = block_offset + thread_offset;
+ const int stride = tb.size() * T_per_load;
+
+ float sum = reduce::init<rop::Add>();
+
+ const T* input_base = vals + base_offset;
+ const T* residual_base = residual + base_offset;
+
+ T local_buffer[unRoll * T_per_load];
+
+ // Unlike a vanilla layernorm, since we're fusing the two adds as well
+ // an inner unRoll seems to be less valuable. If anything, a double unRoll
+ // makes the most sense if we find we are having performance issues.
+#pragma unRoll
+ for (int i = 0; i < unRoll; i++) {
+ T* iteration_buffer = local_buffer + i * T_per_load;
+ T residual_buffer[T_per_load];
+ T bias_buffer[T_per_load];
+
+ mem_access::load_global<ln::granularity>(
+ iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);
+ mem_access::load_global<ln::granularity>(residual_buffer,
+ residual_base + i * stride,
+ thread_offset + i * stride < elems_per_row);
+
+#pragma unRoll
+ for (int j = 0; j < T_per_load; j++) {
+ float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
+ float res_up_cast = conversion::to<float>(residual_buffer[j]);
+ vals_up_cast += res_up_cast;
+ sum = reduce::element<rop::Add>(sum, vals_up_cast);
+ iteration_buffer[j] = conversion::to<T>(vals_up_cast);
+ }
+
+ if (preLnResidual && (thread_offset + i * stride < elems_per_row)) {
+ mem_access::store_global<ln::granularity>(res_output + base_offset + i * stride,
+ iteration_buffer);
+ }
+ }
+
+ reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, sum);
+ const float mean = sum / elems_per_row;
+
+ float mean_diff = reduce::init<rop::Add>();
+#pragma unRoll
+ for (int i = 0; i < unRoll; i++) {
+#pragma unRoll
+ for (int j = 0; j < T_per_load; j++) {
+ // Using a 0 value here skews the variance, have to if-guard
+ if (thread_offset + i * stride < elems_per_row) {
+                float diff = (conversion::to<float>(local_buffer[i * T_per_load + j]) - mean);
+                mean_diff = reduce::element<rop::Add>(mean_diff, diff * diff);
+ }
+ }
+ }
+
+    reduce::partitioned_block<rop::Add>(tb, warp, mean_diff);
+ const float variance = mean_diff / elems_per_row;
+ const float denom = __frsqrt_rn(variance + epsilon);
+
+ T* block_output = output + block_offset;
+
+#pragma unRoll
+ for (int i = 0; i < unRoll; i++) {
+ T* iteration_buffer = local_buffer + i * T_per_load;
+ const int iter_idx = i * stride + thread_offset;
+ const bool do_loads = iter_idx < elems_per_row;
+
+ T gamma_local[T_per_load], beta_local[T_per_load];
+
+        mem_access::load_global<ln::granularity>(gamma_local, gamma + iter_idx, do_loads);
+        mem_access::load_global<ln::granularity>(beta_local, beta + iter_idx, do_loads);
+
+#pragma unRoll
+ for (int j = 0; j < T_per_load; j++) {
+            float val = conversion::to<float>(iteration_buffer[j]);
+            val = (val - mean) * denom;
+            val = val * conversion::to<float>(gamma_local[j]) +
+                  conversion::to<float>(beta_local[j]);
+            iteration_buffer[j] = conversion::to<T>(val);
+ }
+
+ if (do_loads) {
+            mem_access::store_global<ln::granularity>(block_output + iter_idx, iteration_buffer);
+ }
+ }
+}
+
+// TODO(cmikeh2): There's a bunch of redundancy here that needs to be removed/simplified.
+#define LAUNCH_FUSED_RES_LN(unRollFactor, threadsPerGroup, maxThreads) \
+    fused_residual_ln<T, unRollFactor, threadsPerGroup, maxThreads, false>                 \
+        <<<grid, block, 0, stream>>>(                                                      \
+ output, nullptr, vals, residual, gamma, beta, epsilon, elems_per_row);
+
+template <typename T>
+void launch_fused_post_ln(T* output,
+ const T* vals,
+ const T* residual,
+ const T* gamma,
+ const T* beta,
+ float epsilon,
+ int rows,
+ int elems_per_row,
+ cudaStream_t stream)
+{
+ // 8 for __half, 4 for float
+ constexpr int T_per_load = ln::granularity / sizeof(T);
+
+ constexpr int maxThreads = 256;
+
+    // For float, unRoll 4; for __half, unRoll 2
+ constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
+
+ const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
+ const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
+
+ // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
+ // warp-sized blocks rather than stepping up to 64/96 threads
+ const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
+ const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
+
+ const int groups_per_block_max =
+ is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
+ const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
+ const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
+
+ dim3 block(threadsPerGroup, groups_per_block);
+ dim3 grid(groups_launch);
+
+ const int elems_per_step = threadsPerGroup * h_per_step;
+ const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
+
+ if (is_subblock_schedule) {
+ // <=128
+ if (threadsPerGroup == 1) {
+ LAUNCH_FUSED_RES_LN(1, 1, maxThreads);
+ } else if (threadsPerGroup == 2) {
+ LAUNCH_FUSED_RES_LN(1, 2, maxThreads);
+ } else if (threadsPerGroup == 4) {
+ LAUNCH_FUSED_RES_LN(1, 4, maxThreads);
+ } else if (threadsPerGroup == 8) {
+ LAUNCH_FUSED_RES_LN(1, 8, maxThreads);
+ } else if (threadsPerGroup == 16) {
+ LAUNCH_FUSED_RES_LN(1, 16, maxThreads);
+ }
+ } else if (external_unRoll == 1) {
+ // 129 - 4096 elems
+ // (this can launch with 1-7 warps as well)
+ LAUNCH_FUSED_RES_LN(1 * internal_unRoll, maxThreads, maxThreads);
+ } else if (external_unRoll == 2) {
+ // 4097 - 8192 elems
+ LAUNCH_FUSED_RES_LN(2 * internal_unRoll, maxThreads, maxThreads);
+ } else if (external_unRoll == 3) {
+ // 8193 - 12288 elems
+ LAUNCH_FUSED_RES_LN(3 * internal_unRoll, maxThreads, maxThreads);
+ } else if (external_unRoll == 4) {
+ // 12289 - 16384 elems
+ LAUNCH_FUSED_RES_LN(4 * internal_unRoll, maxThreads, maxThreads);
+ }
+}
+
+#define LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(unRollFactor, threadsPerGroup, maxThreads) \
+    fused_residual_ln<T, unRollFactor, threadsPerGroup, maxThreads, true>                  \
+        <<<grid, block, 0, stream>>>(                                                      \
+ norm_output, res_output, vals, residual, gamma, beta, epsilon, elems_per_row);
+
+template <typename T>
+void launch_fused_pre_ln(T* norm_output,
+ T* res_output,
+ const T* vals,
+ const T* residual,
+ const T* gamma,
+ const T* beta,
+ float epsilon,
+ int rows,
+ int elems_per_row,
+ cudaStream_t stream)
+{
+ // 8 for __half, 4 for float
+ constexpr int T_per_load = ln::granularity / sizeof(T);
+
+ constexpr int maxThreads = 256;
+
+    // For float, unRoll 4; for __half, unRoll 2
+ constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
+
+ const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
+ const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
+
+ // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
+ // warp-sized blocks rather than stepping up to 64/96 threads
+ const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
+ const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
+
+ const int groups_per_block_max =
+ is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
+ const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
+ const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
+
+ dim3 block(threadsPerGroup, groups_per_block);
+ dim3 grid(groups_launch);
+
+ const int elems_per_step = threadsPerGroup * h_per_step;
+ const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
+
+ if (is_subblock_schedule) {
+ // <=128
+ if (threadsPerGroup == 1) {
+ LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 1, maxThreads);
+ } else if (threadsPerGroup == 2) {
+ LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 2, maxThreads);
+ } else if (threadsPerGroup == 4) {
+ LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 4, maxThreads);
+ } else if (threadsPerGroup == 8) {
+ LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 8, maxThreads);
+ } else if (threadsPerGroup == 16) {
+ LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 16, maxThreads);
+ }
+ } else if (external_unRoll == 1) {
+ // 129 - 4096 elems
+ // (this can launch with 1-7 warps as well)
+ LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1 * internal_unRoll, maxThreads, maxThreads);
+ } else if (external_unRoll == 2) {
+ // 4097 - 8192 elems
+ LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(2 * internal_unRoll, maxThreads, maxThreads);
+ } else if (external_unRoll == 3) {
+ // 8193 - 12288 elems
+ LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(3 * internal_unRoll, maxThreads, maxThreads);
+ } else if (external_unRoll == 4) {
+ // 12289 - 16384 elems
+ LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(4 * internal_unRoll, maxThreads, maxThreads);
+ }
+}
+
+#define INSTANTIATE_RES_LN(T) \
+ template void launch_fused_post_ln( \
+ T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t);
+
+#define INSTANTIATE_PRE_LN_RES(T) \
+ template void launch_fused_pre_ln( \
+ T*, T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t);
+
+INSTANTIATE_RES_LN(__half);
+INSTANTIATE_RES_LN(float);
+#ifdef BF16_AVAILABLE
+INSTANTIATE_RES_LN(__nv_bfloat16);
+#endif
+
+INSTANTIATE_PRE_LN_RES(__half);
+INSTANTIATE_PRE_LN_RES(float);
+#ifdef BF16_AVAILABLE
+INSTANTIATE_PRE_LN_RES(__nv_bfloat16);
+#endif
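For orientation, a minimal PyTorch sketch of what the two launchers above compute (a reference only; the function names below are illustrative, not part of this extension, and results will match the fused kernels only up to rounding):

import torch
import torch.nn.functional as F

def reference_post_ln(vals, residual, gamma, beta, epsilon=1e-5):
    # launch_fused_post_ln: normalize the sum of hidden states and residual.
    hidden = vals + residual
    return F.layer_norm(hidden, (hidden.shape[-1],), gamma, beta, epsilon)

def reference_pre_ln(vals, residual, gamma, beta, epsilon=1e-5):
    # launch_fused_pre_ln: additionally return the un-normalized sum so it can
    # feed the next residual connection.
    res_out = vals + residual
    norm_out = F.layer_norm(res_out, (res_out.shape[-1],), gamma, beta, epsilon)
    return norm_out, res_out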
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.h b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.h
new file mode 100644
index 000000000000..9ea3a8c42524
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.h
@@ -0,0 +1,67 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#pragma once
+
+#include <c10/cuda/CUDAStream.h>
+#include <torch/extension.h>
+#include "ds_kernel_utils.h"
+
+/*
+Kernel launch methods for layer norm variants.
+*/
+
+template <typename T>
+void launch_fused_ln(T* output,
+ const T* vals,
+ const T* gamma,
+ const T* beta,
+ float epsilon,
+ int rows,
+ int elems_per_row,
+ cudaStream_t stream);
+
+template <typename T>
+void launch_fused_post_ln(T* output,
+ const T* vals,
+ const T* residual,
+ const T* gamma,
+ const T* beta,
+ float epsilon,
+ int rows,
+ int elems_per_row,
+ cudaStream_t stream);
+template <typename T>
+void launch_fused_pre_ln(T* norm_output,
+ T* res_output,
+ const T* vals,
+ const T* residual,
+ const T* gamma,
+ const T* beta,
+ float epsilon,
+ int rows,
+ int elems_per_row,
+ cudaStream_t stream);
+
+void ds_layer_norm(at::Tensor& output,
+ at::Tensor& input,
+ at::Tensor& gamma,
+ at::Tensor& beta,
+ float epsilon);
+
+void ds_post_layer_norm(at::Tensor& output,
+ at::Tensor& input,
+ at::Tensor& residual,
+ at::Tensor& gamma,
+ at::Tensor& beta,
+ float epsilon);
+
+void ds_pre_layer_norm(at::Tensor& res_output,
+ at::Tensor& norm_output,
+ at::Tensor& input,
+ at::Tensor& residual,
+ at::Tensor& gamma,
+ at::Tensor& beta,
+ float epsilon);
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/__init__.py b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/__init__.py
new file mode 100644
index 000000000000..640a72307650
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .rms_norm import CUDARMSNorm
+from .rms_pre_norm import CUDARMSPreNorm
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cpp
new file mode 100644
index 000000000000..c67712df438a
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cpp
@@ -0,0 +1,123 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include "rms_norm.h"
+
+#ifdef BF16_AVAILABLE
+#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
+ [&] { \
+ if (DTYPE == torch::kFloat32) { \
+ using scalar_t = float; \
+ return __VA_ARGS__(); \
+ } else if (DTYPE == torch::kFloat16) { \
+ using scalar_t = __half; \
+ return __VA_ARGS__(); \
+ } else if (DTYPE == torch::kBFloat16) { \
+ using scalar_t = __nv_bfloat16; \
+ return __VA_ARGS__(); \
+ } else { \
+ TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \
+ } \
+ }()
+#else
+#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
+ [&] { \
+ if (DTYPE == torch::kFloat32) { \
+ using scalar_t = float; \
+ return __VA_ARGS__(); \
+ } else if (DTYPE == torch::kFloat16) { \
+ using scalar_t = __half; \
+ return __VA_ARGS__(); \
+ } else { \
+ TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \
+ } \
+ }()
+#endif
+
+void rms_norm(torch::Tensor& norm_output,
+ torch::Tensor& norm_input,
+ torch::Tensor& gamma,
+ float epsilon)
+{
+ TORCH_CHECK(norm_output.scalar_type() == norm_input.scalar_type(),
+ "norm_output and norm_input should have the same data type");
+ TORCH_CHECK(norm_output.scalar_type() == gamma.scalar_type(),
+ "norm_output and gamma should have the same data type");
+
+ const int32_t rows = norm_input.size(0);
+ const int32_t cols = norm_input.size(1);
+
+ TORCH_CHECK(norm_output.size(0) == rows,
+ "norm_output and norm_input should have the same first dimension");
+ TORCH_CHECK(norm_output.size(1) == cols,
+ "norm_output and norm_input should have the same second dimension");
+
+ DISPATCH_FOR_FLOAT(norm_output.scalar_type(), [&] {
+        scalar_t* norm_output_ptr = reinterpret_cast<scalar_t*>(norm_output.data_ptr());
+        scalar_t* norm_input_ptr = reinterpret_cast<scalar_t*>(norm_input.data_ptr());
+        scalar_t* gamma_ptr = reinterpret_cast<scalar_t*>(gamma.data_ptr());
+ scalar_t* null_t = nullptr;
+
+ launch_rms_norm(norm_output_ptr,
+ null_t,
+ norm_input_ptr,
+ null_t,
+ gamma_ptr,
+ epsilon,
+ rows,
+ cols,
+ at::cuda::getCurrentCUDAStream());
+ });
+}
+
+void rms_pre_norm(torch::Tensor& norm_output,
+ torch::Tensor& residual_output,
+ torch::Tensor& norm_input,
+ torch::Tensor& residual_input,
+ torch::Tensor& gamma,
+ float epsilon)
+{
+ TORCH_CHECK(norm_output.scalar_type() == norm_input.scalar_type(),
+ "norm_output and norm_input should have the same data type");
+ TORCH_CHECK(norm_output.scalar_type() == gamma.scalar_type(),
+ "norm_output and gamma should have the same data type");
+
+ const int32_t rows = norm_input.size(0);
+ const int32_t cols = norm_input.size(1);
+
+ TORCH_CHECK(norm_output.size(0) == rows,
+ "norm_output and norm_input should have the same first dimension");
+ TORCH_CHECK(norm_output.size(1) == cols,
+ "norm_output and norm_input should have the same second dimension");
+
+ TORCH_CHECK(residual_output.size(0) == rows,
+ "residual_output and norm_input should have the same first dimension");
+ TORCH_CHECK(residual_output.size(1) == cols,
+ "residual_output and norm_input should have the same second dimension");
+
+ TORCH_CHECK(residual_input.size(0) == rows,
+ "residual_input and norm_input should have the same first dimension");
+ TORCH_CHECK(residual_input.size(1) == cols,
+ "residual_input and norm_input should have the same second dimension");
+
+ DISPATCH_FOR_FLOAT(norm_output.scalar_type(), [&] {
+        scalar_t* norm_output_ptr = reinterpret_cast<scalar_t*>(norm_output.data_ptr());
+        scalar_t* residual_output_ptr = reinterpret_cast<scalar_t*>(residual_output.data_ptr());
+        const scalar_t* norm_input_ptr = reinterpret_cast<const scalar_t*>(norm_input.data_ptr());
+        const scalar_t* residual_input_ptr =
+            reinterpret_cast<const scalar_t*>(residual_input.data_ptr());
+        const scalar_t* gamma_ptr = reinterpret_cast<const scalar_t*>(gamma.data_ptr());
+
+ launch_rms_norm(norm_output_ptr,
+ residual_output_ptr,
+ norm_input_ptr,
+ residual_input_ptr,
+ gamma_ptr,
+ epsilon,
+ rows,
+ cols,
+ at::cuda::getCurrentCUDAStream());
+ });
+}
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cu
new file mode 100644
index 000000000000..e69d3c36cc00
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cu
@@ -0,0 +1,262 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include "conversion_utils.h"
+#include "ds_kernel_utils.h"
+#include "memory_access_utils.h"
+#include "reduction_utils.h"
+
+namespace cg = cooperative_groups;
+using rop = reduce::ROpType;
+
+namespace rms {
+constexpr int granularity = 16;
+} // namespace rms
+
+template <typename T, int UNROLL, int threadsPerGroup, int maxThreads>
+__global__ void rms_norm(T* output, const T* vals, const T* gamma, float epsilon, int elems_per_row)
+{
+ constexpr int T_per_load = rms::granularity / sizeof(T);
+
+ cg::thread_block tb = cg::this_thread_block();
+    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
+
+ // X-dimension of the block
+ const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
+ (tb.thread_index().y * elems_per_row);
+ const int thread_offset = tb.thread_index().x * T_per_load;
+ const int base_offset = block_offset + thread_offset;
+ const int stride = blockDim.x * T_per_load;
+
+    float var_sum = reduce::init<rop::Add>();
+
+ const T* input_base = vals + base_offset;
+
+ T local_buffer[UNROLL * T_per_load];
+
+#pragma unroll
+ for (int i = 0; i < UNROLL; i++) {
+ T* iteration_buffer = local_buffer + (i * T_per_load);
+
+        mem_access::load_global<rms::granularity>(iteration_buffer,
+                                                  input_base + (i * stride),
+                                                  thread_offset + (i * stride) < elems_per_row);
+
+#pragma unroll
+ for (int j = 0; j < T_per_load; j++) {
+            float up_cast = conversion::to<float>(iteration_buffer[j]);
+            float sq_val = up_cast * up_cast;
+            var_sum = reduce::element<rop::Add>(var_sum, sq_val);
+ }
+ }
+
+    reduce::partitioned_block<rop::Add>(tb, warp, var_sum);
+    const float var = var_sum / elems_per_row;
+    const T denom = conversion::to<T>(__frsqrt_rn(var + epsilon));
+
+ T* block_output = output + block_offset;
+
+#pragma unroll
+ for (int i = 0; i < UNROLL; i++) {
+ T* iteration_buffer = local_buffer + (i * T_per_load);
+ const int iter_idx = i * stride + thread_offset;
+ const bool do_loads = (iter_idx < elems_per_row);
+
+ T gamma_local[T_per_load];
+
+        mem_access::load_global<rms::granularity>(gamma_local, gamma + iter_idx, do_loads);
+
+#pragma unroll
+ for (int j = 0; j < T_per_load; j++) {
+ iteration_buffer[j] *= denom;
+ iteration_buffer[j] *= gamma_local[j];
+ }
+
+ if (do_loads) {
+            mem_access::store_global<rms::granularity>(block_output + iter_idx, iteration_buffer);
+ }
+ }
+}
+
+template <typename T, int UNROLL, int threadsPerGroup, int maxThreads>
+__global__ void pre_rms_norm(T* output,
+ T* res_out,
+ const T* vals,
+ const T* residual,
+ const T* gamma,
+ float epsilon,
+ int elems_per_row)
+{
+ constexpr int T_per_load = rms::granularity / sizeof(T);
+
+ cg::thread_block tb = cg::this_thread_block();
+    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
+
+ // X-dimension of the block
+ const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
+ (tb.thread_index().y * elems_per_row);
+ const int thread_offset = tb.thread_index().x * T_per_load;
+ const int base_offset = block_offset + thread_offset;
+ const int stride = blockDim.x * T_per_load;
+
+    float var_sum = reduce::init<rop::Add>();
+
+ const T* input_base = vals + base_offset;
+ const T* residual_base = residual + base_offset;
+ T* res_output = res_out + base_offset;
+
+ T local_buffer[UNROLL * T_per_load];
+
+#pragma unroll
+ for (int i = 0; i < UNROLL; i++) {
+ T* iteration_buffer = local_buffer + (i * T_per_load);
+ T residual_buffer[T_per_load];
+
+ const int iter_offset = i * stride + thread_offset;
+ const bool do_loads = (iter_offset < elems_per_row);
+
+        mem_access::load_global<rms::granularity>(
+            iteration_buffer, input_base + (i * stride), do_loads);
+        mem_access::load_global<rms::granularity>(
+            residual_buffer, residual_base + (i * stride), do_loads);
+
+#pragma unroll
+ for (int j = 0; j < T_per_load; j++) {
+ iteration_buffer[j] += residual_buffer[j];
+            float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
+
+            var_sum = reduce::element<rop::Add>(var_sum, vals_up_cast * vals_up_cast);
+ }
+
+ if (do_loads) {
+            mem_access::store_global<rms::granularity>(res_output + i * stride, iteration_buffer);
+ }
+ }
+
+    reduce::partitioned_block<rop::Add>(tb, warp, var_sum);
+    const float var = var_sum / elems_per_row;
+    const T denom = conversion::to<T>(__frsqrt_rn(var + epsilon));
+
+ T* block_output = output + block_offset;
+
+#pragma unroll
+ for (int i = 0; i < UNROLL; i++) {
+ T* iteration_buffer = local_buffer + (i * T_per_load);
+ const int iter_idx = i * stride + thread_offset;
+ const bool do_loads = (iter_idx < elems_per_row);
+
+ T gamma_local[T_per_load];
+
+        mem_access::load_global<rms::granularity>(gamma_local, gamma + iter_idx, do_loads);
+
+#pragma unroll
+ for (int j = 0; j < T_per_load; j++) {
+ iteration_buffer[j] *= denom;
+ iteration_buffer[j] *= gamma_local[j];
+ }
+
+ if (do_loads) {
+            mem_access::store_global<rms::granularity>(block_output + iter_idx, iteration_buffer);
+ }
+ }
+}
+
+#define LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
+    rms_norm<T, UNROLL, threadsPerGroup, maxThreads>                                       \
+        <<<grid, block, 0, stream>>>(norm_output, vals, gamma, epsilon, elems_per_row);
+
+#define LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
+    pre_rms_norm<T, UNROLL, threadsPerGroup, maxThreads><<<grid, block, 0, stream>>>( \
+ norm_output, res_output, vals, residual, gamma, epsilon, elems_per_row);
+
+#define LAUNCH_ALL_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
+ if (pre_norm) { \
+ LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
+ } else { \
+ LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \
+ }
+
+template <typename T>
+void launch_rms_norm(T* norm_output,
+ T* res_output,
+ const T* vals,
+ const T* residual,
+ const T* gamma,
+ float epsilon,
+ int rows,
+ int elems_per_row,
+ cudaStream_t stream)
+{
+ // 8 for __half, 4 for float
+ constexpr int T_per_load = rms::granularity / sizeof(T);
+ constexpr int maxThreads = 256;
+ constexpr int internalUnroll = sizeof(T) == 4 ? 4 : 2;
+
+ const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
+ const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internalUnroll;
+
+ // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
+ // warp-sized blocks rather than stepping up to 64/96 threads
+ const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
+ const int threads_per_group = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
+
+ const int groups_per_block_max =
+ is_subblock_schedule ? (maxThreads + threads_per_group - 1) / threads_per_group : 1;
+ const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
+ const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
+
+ dim3 block(threads_per_group, groups_per_block);
+ dim3 grid(groups_launch);
+
+ const int elems_per_step = threads_per_group * h_per_step;
+ const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
+
+ bool pre_norm = (residual == nullptr) ? false : true;
+
+ if (is_subblock_schedule) {
+ // <=128
+ if (threads_per_group == 1) {
+ LAUNCH_ALL_RMS_NORM(1, 1, maxThreads);
+ } else if (threads_per_group == 2) {
+ LAUNCH_ALL_RMS_NORM(1, 2, maxThreads);
+ } else if (threads_per_group == 4) {
+ LAUNCH_ALL_RMS_NORM(1, 4, maxThreads);
+ } else if (threads_per_group == 8) {
+ LAUNCH_ALL_RMS_NORM(1, 8, maxThreads);
+ } else if (threads_per_group == 16) {
+ LAUNCH_ALL_RMS_NORM(1, 16, maxThreads);
+ }
+ } else if (external_unRoll == 1) {
+ // 129 - 4096 elems
+ // (this can launch with 1-7 warps as well)
+ LAUNCH_ALL_RMS_NORM(1 * internalUnroll, maxThreads, maxThreads);
+ } else if (external_unRoll == 2) {
+ // 4097 - 8192 elems
+ LAUNCH_ALL_RMS_NORM(2 * internalUnroll, maxThreads, maxThreads);
+ } else if (external_unRoll == 3) {
+ // 8193 - 12288 elems
+ LAUNCH_ALL_RMS_NORM(3 * internalUnroll, maxThreads, maxThreads);
+ } else if (external_unRoll == 4) {
+ // 12289 - 16384 elems
+ LAUNCH_ALL_RMS_NORM(4 * internalUnroll, maxThreads, maxThreads);
+ }
+}
+
+#define INSTANTIATE_LAUNCH_RMS_NORM(T) \
+ template void launch_rms_norm(T * norm_output, \
+ T * res_output, \
+ const T* vals, \
+ const T* residual, \
+ const T* gamma, \
+ float epsilon, \
+ int rows, \
+ int elems_per_row, \
+ cudaStream_t stream);
+
+INSTANTIATE_LAUNCH_RMS_NORM(float)
+INSTANTIATE_LAUNCH_RMS_NORM(__half)
+#ifdef BF16_AVAILABLE
+INSTANTIATE_LAUNCH_RMS_NORM(__nv_bfloat16)
+#endif
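As a reference for the RMS norm math above (a sketch only; names are illustrative, and the kernel keeps the final scaling in the storage dtype, so results match only up to rounding):

import torch

def reference_rms_norm(vals, gamma, epsilon=1e-5):
    # Mean of squares per row; unlike layer norm, there is no mean subtraction.
    var = vals.float().pow(2).mean(dim=-1, keepdim=True)
    denom = torch.rsqrt(var + epsilon).to(vals.dtype)
    return vals * denom * gamma

def reference_pre_rms_norm(vals, residual, gamma, epsilon=1e-5):
    # pre_rms_norm also writes out the pre-normalization sum (res_out).
    res_out = vals + residual
    return reference_rms_norm(res_out, gamma, epsilon), res_out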
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.h b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.h
new file mode 100644
index 000000000000..7867fb65964f
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.h
@@ -0,0 +1,33 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#pragma once
+
+#include <c10/cuda/CUDAStream.h>
+#include <torch/extension.h>
+#include "ds_kernel_utils.h"
+
+template <typename T>
+void launch_rms_norm(T* norm_output,
+ T* res_output,
+ const T* vals,
+ const T* residual,
+ const T* gamma,
+ float epsilon,
+ int rows,
+ int elems_per_row,
+ cudaStream_t stream);
+
+void rms_norm(torch::Tensor& norm_output,
+ torch::Tensor& norm_input,
+ torch::Tensor& gamma,
+ float epsilon);
+
+void rms_pre_norm(torch::Tensor& norm_output,
+ torch::Tensor& residual_output,
+ torch::Tensor& norm_input,
+ torch::Tensor& residual_input,
+ torch::Tensor& gamma,
+ float epsilon);
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.py b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.py
new file mode 100644
index 000000000000..deb5d33111a9
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.py
@@ -0,0 +1,28 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+
+from .rms_norm_base import CUDARMSNormBase
+
+
+class CUDARMSNorm(CUDARMSNormBase):
+ """
+    Floating point RMS norm kernel for CUDA/ROCm.
+
+    Performs: z = rms_norm(x)
+ """
+
+ def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
+ """
+ output_z may alias input_x directly. All Tensors should have the same shape.
+
+ Parameters:
+ output_z (torch.Tensor): Output tensor.
+ input_x (torch.Tensor): Input tensor.
+ gamma (torch.Tensor): Gamma tensor.
+ """
+ self.inf_module.rms_norm(output_z, input_x, gamma, self.epsilon)
+ return output_z
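A minimal usage sketch for CUDARMSNorm (shapes and dtype chosen for illustration; it assumes the InferenceCoreBuilder extension builds on the local machine and uses the import path introduced by this PR):

import torch
from deepspeed.inference.v2.kernels.core_ops.cuda_rms_norm import CUDARMSNorm

channels = 4096
rms = CUDARMSNorm(channels, torch.float16, epsilon=1e-5)

x = torch.randn(8, channels, dtype=torch.float16, device="cuda")
gamma = torch.ones(channels, dtype=torch.float16, device="cuda")
z = torch.empty_like(x)

rms(z, x, gamma)  # writes the normalized rows into z and returns it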
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_base.py b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_base.py
new file mode 100644
index 000000000000..62bc9d056ade
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_base.py
@@ -0,0 +1,37 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+
+from ... import DSKernelBase
+from ....inference_utils import elem_size
+from deepspeed.ops.op_builder import InferenceCoreBuilder
+
+
+class CUDARMSNormBase(DSKernelBase):
+ """
+    Base class for CUDA RMS norm kernels. They all share the same validation logic,
+    so we can share it here.
+ """
+
+ supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+ def __init__(self, channels: int, fp_dtype: torch.dtype, epsilon: float = 1e-5):
+ """
+ Parameters:
+            channels (int): Number of channels in the input tensor. The size of each row
+                in bytes (elem_size(fp_dtype) * channels) must be divisible by 16.
+ fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values
+ are torch.float16, torch.bfloat16, and torch.float32.
+ """
+ if fp_dtype not in CUDARMSNormBase.supported_dtypes:
+ raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
+ fp_dtype, CUDARMSNormBase.supported_dtypes))
+
+ if elem_size(fp_dtype) * channels % 16 != 0:
+ raise ValueError("channels must be divisible by 16 bytes")
+
+ self.inf_module = InferenceCoreBuilder().load()
+ self.epsilon = epsilon
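A quick worked check of the 16-byte rule above (an illustration, not part of the source): with torch.float16 each element is 2 bytes, so channels = 4096 gives 8192 bytes per row and is accepted, while channels = 4100 gives 8200 bytes and the constructor raises.

import torch
from deepspeed.inference.v2.inference_utils import elem_size

assert elem_size(torch.float16) * 4096 % 16 == 0  # accepted by CUDARMSNormBase
assert elem_size(torch.float16) * 4100 % 16 != 0  # would raise ValueError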
diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_pre_norm.py b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_pre_norm.py
new file mode 100644
index 000000000000..3b040d88b50f
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_pre_norm.py
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Tuple
+
+import torch
+
+from .rms_norm_base import CUDARMSNormBase
+
+
+class CUDARMSPreNorm(CUDARMSNormBase):
+ """
+    Floating point pre-RMS norm kernel for CUDA/ROCm.
+
+    Performs: z_res = x_res + y_hid
+              z_hid = rms_norm(z_res)
+ """
+
+ def __call__(self, z_res: torch.Tensor, z_hid: torch.Tensor, x_res: torch.Tensor, y_hid: torch.Tensor,
+ gamma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ z_res can alias x_res. All non-parameter input/output tensors
+ must have the same shape. z_hid can alias y_hid.
+
+ Parameters:
+ z_res (torch.Tensor): Output residual.
+ z_hid (torch.Tensor): Output hidden states.
+ x_res (torch.Tensor): Input residual.
+ y_hid (torch.Tensor): Input hidden states.
+ gamma (torch.Tensor): Gamma tensor.
+
+ Returns:
+ output (torch.Tensor): Output tensor.
+ """
+ self.inf_module.rms_pre_norm(z_hid, z_res, y_hid, x_res, gamma, self.epsilon)
+ return z_res, z_hid
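A usage sketch of the pre-norm variant (shapes illustrative; as documented above, z_res may alias x_res and z_hid may alias y_hid):

import torch
from deepspeed.inference.v2.kernels.core_ops.cuda_rms_norm import CUDARMSPreNorm

channels = 4096
pre_rms = CUDARMSPreNorm(channels, torch.float16)

x_res = torch.randn(8, channels, dtype=torch.float16, device="cuda")
y_hid = torch.randn(8, channels, dtype=torch.float16, device="cuda")
gamma = torch.ones(channels, dtype=torch.float16, device="cuda")
z_res = torch.empty_like(x_res)
z_hid = torch.empty_like(y_hid)

z_res, z_hid = pre_rms(z_res, z_hid, x_res, y_hid, gamma)
# z_res now holds x_res + y_hid; z_hid holds the RMS-normalized copy of that sum.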
diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/__init__.py b/deepspeed/inference/v2/kernels/core_ops/gated_activations/__init__.py
new file mode 100644
index 000000000000..05479d86c906
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .gated_activation import *
diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation.py b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation.py
new file mode 100644
index 000000000000..ca1b62ba5c36
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation.py
@@ -0,0 +1,65 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Optional
+
+import torch
+
+from ... import DSKernelBase
+from ....inference_utils import ActivationType, elem_size
+from deepspeed.ops.op_builder import InferenceCoreBuilder
+
+
+class CUDAGatedActivation(DSKernelBase):
+ """
+ CUDA implementation of gated activation kernel. This kernel assumes that the input
+ tensor has gate and activation values in adjacent channels. The output tensor should
+ have half the dimensionality of the input tensor.
+ """
+
+ supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+ supported_act_fns = [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU]
+
+ def __init__(self, channels: int, fp_dtype: torch.dtype, act_fn: ActivationType) -> None:
+ """
+ Compile and validate for the gated activation function.
+
+ Args:
+            channels (int): Number of columns in the output tensor. The size of each row
+                in bytes must be divisible by 8.
+ fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values
+ are torch.float16, torch.bfloat16, and torch.float32.
+            act_fn (ActivationType): Activation function to use. Only GEGLU, ReGLU, and SiGLU are supported.
+ """
+ if fp_dtype not in CUDAGatedActivation.supported_dtypes:
+ raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format(
+ fp_dtype, CUDAGatedActivation.supported_dtypes))
+
+ act_fn = ActivationType(act_fn)
+ if act_fn not in CUDAGatedActivation.supported_act_fns:
+ raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format(
+ act_fn, CUDAGatedActivation.supported_act_fns))
+
+ if elem_size(fp_dtype) * channels % 8 != 0:
+ raise ValueError("Channels must be divisible by 16 bytes")
+
+ if elem_size(fp_dtype) * channels > 98304:
+ raise ValueError(
+ "Kernel only compiled to support 98304 bytes per row, please file an issue if your model requires more."
+ )
+
+ self.inf_module = InferenceCoreBuilder().load()
+ self.act_fn = act_fn
+ self.kernel = self.inf_module.gated_activation
+
+ def __call__(self, output: torch.Tensor, input: torch.Tensor, bias: Optional[torch.Tensor] = None) -> None:
+ """
+ Performs gated activation on the input tensor, writing the result to the output tensor.
+
+ Args:
+ output (torch.Tensor): Output tensor. Can be of [T, C // 2] or [B, S, C // 2]
+ input (torch.Tensor): Input tensor. Can be of [T, C] or [B, S, C]
+ """
+ self.kernel(output, input, bias, self.act_fn.value)
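A short usage sketch for CUDAGatedActivation (an illustration under the constructor docstring's convention that `channels` is the output width; the input interleaves gate and activation values, so it has twice as many columns):

import torch
from deepspeed.inference.v2.inference_utils import ActivationType
from deepspeed.inference.v2.kernels.core_ops.gated_activations import CUDAGatedActivation

out_channels = 4096
act = CUDAGatedActivation(out_channels, torch.float16, ActivationType.GEGLU)

x = torch.randn(16, 2 * out_channels, dtype=torch.float16, device="cuda")
out = torch.empty(16, out_channels, dtype=torch.float16, device="cuda")
act(out, x)  # bias is optional and may be omitted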
diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp
new file mode 100644
index 000000000000..05463c75138c
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp
@@ -0,0 +1,72 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include "gated_activation_kernels.h"
+
+#ifdef BF16_AVAILABLE
+#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
+ [&] { \
+ if (DTYPE == torch::kFloat32) { \
+ using scalar_t = float; \
+ return __VA_ARGS__(); \
+ } else if (DTYPE == torch::kFloat16) { \
+ using scalar_t = __half; \
+ return __VA_ARGS__(); \
+ } else if (DTYPE == torch::kBFloat16) { \
+ using scalar_t = __nv_bfloat16; \
+ return __VA_ARGS__(); \
+ } else { \
+ TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \
+ } \
+ }()
+#else
+#define DISPATCH_FOR_FLOAT(DTYPE, ...) \
+ [&] { \
+ if (DTYPE == torch::kFloat32) { \
+ using scalar_t = float; \
+ return __VA_ARGS__(); \
+ } else if (DTYPE == torch::kFloat16) { \
+ using scalar_t = __half; \
+ return __VA_ARGS__(); \
+ } else { \
+ TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \
+ } \
+ }()
+#endif
+
+void ds_gated_activation(at::Tensor& output,
+ at::Tensor& input,
+                         c10::optional<torch::Tensor>& bias,
+ int activation_type_raw)
+{
+ bool ragged_input = input.dim() == 2;
+
+    const ActivationType activation_type = static_cast<ActivationType>(activation_type_raw);
+
+ const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1);
+ const int cols = ragged_input ? input.size(1) : input.size(2);
+
+ DISPATCH_FOR_FLOAT(input.scalar_type(), [&] {
+ scalar_t* bias_ptr = nullptr;
+ if (bias.has_value()) {
+ TORCH_CHECK(bias.value().scalar_type() == input.scalar_type(),
+ "Bias type must match input type");
+ TORCH_CHECK(bias.value().numel() == cols,
+ "Bias must have the same number of elements as the input channels");
+            bias_ptr = reinterpret_cast<scalar_t*>(bias.value().data_ptr());
+ }
+
+        scalar_t* output_ptr = reinterpret_cast<scalar_t*>(output.data_ptr());
+        const scalar_t* input_ptr = reinterpret_cast<const scalar_t*>(input.data_ptr());
+
+ launch_gated_activation(output_ptr,
+ input_ptr,
+ bias_ptr,
+ rows,
+ cols,
+ activation_type,
+ c10::cuda::getCurrentCUDAStream());
+ });
+}
diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cu b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cu
new file mode 100644
index 000000000000..84a9906cf037
--- /dev/null
+++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cu
@@ -0,0 +1,169 @@
+// Copyright (c) Microsoft Corporation.
+// SPDX-License-Identifier: Apache-2.0
+
+// DeepSpeed Team
+
+#include <stdexcept>
+#include "activation_type.h"
+#include "conversion_utils.h"
+#include "ds_kernel_utils.h"
+#include "memory_access_utils.h"
+
+namespace cg = cooperative_groups;
+
+namespace gated_act {
+
+constexpr int access_size = 16;
+constexpr int threads = 1024;
+
+template <ActivationType ActType>
+float gated_act_fn(float x, float y);
+
+template <>
+DS_D_INLINE float gated_act_fn<ActivationType::GEGLU>(float x, float y)
+{
+ constexpr float sqrt_param = 0.79788456080286535587989211986876f;
+ constexpr float mul_param = 0.044715;
+ return y * x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
+}
+
+template <>
+DS_D_INLINE float gated_act_fn<ActivationType::ReGLU>(float x, float y)
+{
+ return y * (x > 0.0f ? x : 0.0f);
+}
+
+template <>
+DS_D_INLINE float gated_act_fn<ActivationType::SiGLU>(float x, float y)
+{
+ return y * (x / (1.0f + expf(-x)));
+}
+
+} // namespace gated_act
+
+template <typename T, ActivationType ActType, int loopUnroll>
+__global__ void gated_activation_kernel(T* output,
+ const T* input,
+ const T* bias,
+ int rows,
+ int cols)
+{
+ constexpr int read_vector = gated_act::access_size / sizeof(T);
+ constexpr int write_vector = read_vector / 2;
+
+ const int row = blockIdx.x;
+ const int col = threadIdx.x * read_vector;
+
+ const T* input_row = input + row * cols;
+ T* output_row = output + row * cols / 2;
+
+#pragma unroll
+ for (int i = 0; i < loopUnroll; i++) {
+ T read[read_vector];
+ T bias_read[read_vector];
+ T store[write_vector];
+
+ const int read_offset = col + gated_act::threads * read_vector * i;
+ const int write_offset = col / 2 + gated_act::threads * write_vector * i;
+
+ if (i != loopUnroll - 1 || read_offset < cols) {
+            mem_access::load_global<gated_act::access_size>(read, input_row + read_offset);
+            mem_access::load_global<gated_act::access_size>(
+                bias_read, bias + read_offset, bias != nullptr);
+
+ for (int j = 0; j < write_vector; j++) {
+                float g_val = conversion::to<float>(read[j * 2]) +
+                              conversion::to<float>(bias_read[j * 2]);
+                float a_val = conversion::to<float>(read[j * 2 + 1]) +
+                              conversion::to<float>(bias_read[j * 2 + 1]);
+
+                float act_val = gated_act::gated_act_fn<ActType>(g_val, a_val);
+                store[j] = conversion::to<T>(act_val);
+ }
+
+            mem_access::store_global<gated_act::access_size / 2>(output_row + write_offset, store);
+ }
+ }
+}
+
+#define DISPATCH_UNROLL(unroll_val) \
+    gated_activation_kernel<T, ActType, unroll_val>               \
+        <<<grid, block, 0, stream>>>(output, input, bias, rows, cols);
+
+template <typename T, ActivationType ActType>
+void launch_gated_activation_impl(T* output,
+ const T* input,
+ const T* bias,
+ int rows,
+ int cols,
+ cudaStream_t stream)
+{
+ constexpr int read_vector = gated_act::access_size / sizeof(T);
+ constexpr int cols_per_unroll = gated_act::threads * read_vector;
+ const int req_threads = (cols + read_vector - 1) / read_vector;
+ const int threads = std::min(req_threads, gated_act::threads);
+
+ const dim3 grid(rows);
+ const dim3 block(threads);
+ const int unroll = (cols + cols_per_unroll - 1) / cols_per_unroll;
+
+ if (unroll == 1) {
+ DISPATCH_UNROLL(1);
+ } else if (unroll == 2) {
+ DISPATCH_UNROLL(2);
+ } else if (unroll == 3) {
+ DISPATCH_UNROLL(3);
+ } else if (unroll == 4) {
+ DISPATCH_UNROLL(4);
+ } else if (unroll == 5) {
+ DISPATCH_UNROLL(5);
+ } else if (unroll == 6) {
+ DISPATCH_UNROLL(6);
+ } else {
+ throw std::runtime_error(
+ "Called with more columns than supported, please report this bug and this limit will "
+ "be increased.");
+ }
+}
+
+template <typename T>
+void launch_gated_activation(T* output,
+ const T* input,
+ const T* bias,
+ int rows,
+ int cols,
+ ActivationType act_type,
+ cudaStream_t stream)
+{
+ switch (act_type) {
+ case ActivationType::GEGLU:
+            launch_gated_activation_impl<T, ActivationType::GEGLU>(
+ output, input, bias, rows, cols, stream);
+ break;
+ case ActivationType::ReGLU:
+ launch_gated_activation_impl