diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml new file mode 100644 index 000000000000..38d4ffd1828b --- /dev/null +++ b/.github/workflows/nv-a6000.yml @@ -0,0 +1,56 @@ +name: nv-a6000 + +on: + pull_request: + paths-ignore: + - 'docs/**' + - 'blogs/**' + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + issues: write + +jobs: + unit-tests: + runs-on: [self-hosted, nvidia, a6000] + container: + image: nvcr.io/nvidia/pytorch:23.03-py3 + ports: + - 80 + options: --gpus all --shm-size "8G" + + steps: + - uses: actions/checkout@v3 + + - name: Check container state + run: | + ldd --version + nvcc --version + nvidia-smi + python -c "import torch; print('torch:', torch.__version__, torch)" + python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + - name: Install transformers + run: | + git clone https://github.com/huggingface/transformers + cd transformers + git rev-parse --short HEAD + python -m pip install . + - name: Install deepspeed + run: | + python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja + python -m pip install .[dev,1bit,autotuning] + ds_report + - name: Python environment + run: | + python -m pip list + - name: Unit tests + run: | + unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch + cd tests + python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2' unit/ --torch_ver="2.0" --cuda_ver="12" + python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2_ops' unit/ --torch_ver="2.0" --cuda_ver="12" diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml index 065f8b93f1e0..b5cf46f79011 100644 --- a/.github/workflows/nv-inference.yml +++ b/.github/workflows/nv-inference.yml @@ -34,6 +34,7 @@ jobs: run: | git clone https://github.com/huggingface/transformers cd transformers + git checkout f370bebdc git rev-parse --short HEAD pip install . diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml index ccb6c25e14f7..f253340c6966 100644 --- a/.github/workflows/nv-pre-compile-ops.yml +++ b/.github/workflows/nv-pre-compile-ops.yml @@ -33,7 +33,7 @@ jobs: #python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - name: Compile DeepSpeed Ops run: | - TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install . + TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install . 
- name: DS Report run: | ds_report diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6b11b3acba51..2432a7a24124 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -49,6 +49,7 @@ repos: entry: ./scripts/check-license.py language: python files: \.(py|c|cpp|cu|cc|h|hpp|cuh|hip|tr)$ + exclude: ^(deepspeed/inference/v2/kernels/ragged_ops/blocked_flash|deepspeed/inference/v2/kernels/cutlass_ops/grouped_gemm) - repo: https://github.com/codespell-project/codespell rev: v2.1.0 diff --git a/MANIFEST.in b/MANIFEST.in index 2fec750c6644..ab79573ef96c 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,6 @@ include *.txt README.md +include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so +include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so recursive-include requirements *.txt recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc diff --git a/README.md b/README.md index bea26ea1828a..721ba62cee37 100755 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ ## Latest News DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat). +* [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen) * [2023/10] [DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md)] * [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] [[White paper](https://arxiv.org/abs/2310.04610)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)] * [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md) diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 4715e28b192c..2786b425ca7f 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -153,9 +153,26 @@ def max_memory_reserved(self, device_index=None): def total_memory(self, device_index=None): return torch.cuda.get_device_properties(device_index).total_memory + def _get_nvml_gpu_id(self, torch_gpu_id): + """ + credit: 
https://discuss.pytorch.org/t/making-pynvml-match-torch-device-ids-cuda-visible-devices/103020 + + Remap torch device id to nvml device id, respecting CUDA_VISIBLE_DEVICES. + + If the latter isn't set return the same id + """ + # if CUDA_VISIBLE_DEVICES is used automagically remap the id since pynvml ignores this env var + if "CUDA_VISIBLE_DEVICES" in os.environ: + ids = list(map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(","))) + return ids[torch_gpu_id] # remap + else: + return torch_gpu_id + def available_memory(self, device_index=None): if pynvml: - handle = pynvml.nvmlDeviceGetHandleByIndex(device_index) + if device_index is None: + device_index = self.current_device() + handle = pynvml.nvmlDeviceGetHandleByIndex(self._get_nvml_gpu_id(device_index)) info = pynvml.nvmlDeviceGetMemoryInfo(handle) return info.free else: diff --git a/blogs/deepspeed-fastgen/README.md b/blogs/deepspeed-fastgen/README.md new file mode 100644 index 000000000000..b271c112a762 --- /dev/null +++ b/blogs/deepspeed-fastgen/README.md @@ -0,0 +1,302 @@ +
+ +# DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference + +
+ +
+ + +
+
+## Table of Contents
+1. [Introduction](#introduction)
+2. [Existing LLM Serving Techniques in Literature](#background)
+3. [Dynamic SplitFuse: A Novel Prompt and Generation Composition Strategy](#technical-approach)
+4. [Performance Evaluation](#performance-evaluation)
+5. [DeepSpeed-FastGen: Implementation and Usage](#using-deepspeed-fastgen)
+6. [Try out DeepSpeed-FastGen](#try)
+7. [Acknowledgements](#acknowledgements)
+
+
+## 1. Introduction
+
+Large language models (LLMs) like GPT-4 and LLaMA have emerged as a dominant workload, powering a wide range of applications infused with AI at every level. From general chat models to document summarization, and from autonomous driving to copilots at every layer of the software stack, the demand to deploy and serve these models at scale has skyrocketed. While frameworks like DeepSpeed, PyTorch, and several others can regularly achieve good hardware utilization during LLM training, the interactive nature of these applications and the poor arithmetic intensity of tasks like open-ended text generation have made inference throughput the bottleneck in existing systems.
+
+To this end, frameworks like [vLLM](https://arxiv.org/pdf/2309.06180.pdf), powered by PagedAttention, and research systems like [Orca](https://www.usenix.org/system/files/osdi22-yu.pdf) have significantly improved the performance of LLM inference. However, these systems still struggle to provide consistent quality of service, particularly for workloads with longer prompts. Long-prompt workloads are becoming increasingly important as more and more models, like [MPT-StoryWriter](https://www.mosaicml.com/blog/mpt-7b), and systems, such as [DeepSpeed Ulysses](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses), support context windows stretching to tens of thousands of tokens. To better understand the problem space, we describe below how text generation for LLMs proceeds in two distinct phases, prompt processing and generation. When systems treat these as strictly separate phases, generation is preempted by prompt processing, which risks breaking service level agreements (SLAs).
+
+Today, we are glad to present DeepSpeed-FastGen, a system that overcomes these limitations by leveraging the proposed Dynamic SplitFuse technique and offers up to 2.3x higher effective throughput compared to state-of-the-art systems like vLLM. DeepSpeed-FastGen combines DeepSpeed-MII and DeepSpeed-Inference to provide an easy-to-use serving system.
+
+**Quick Start:** Trying DeepSpeed-FastGen is as simple as installing the latest [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII) release:
+
+```bash
+pip install deepspeed-mii
+```
+
+To generate text using a simple non-persistent pipeline deployment, run the following code. For more details, please see [Section 5](#using-deepspeed-fastgen).
+
+```python
+from mii import pipeline
+pipe = pipeline("mistralai/Mistral-7B-v0.1")
+output = pipe(["Hello, my name is", "DeepSpeed is"], max_new_tokens=128)
+print(output)
+```
+
+## 2. Existing LLM Serving Techniques in Literature
+
+A text generation workload for a single sequence consists of two phases: 1) prompt processing, in which the user-provided text is efficiently processed as a batch of tokens to build a key-value (KV) cache for attention, and 2) token generation, which appends a single token to that cache and generates a new token. Over the course of generating a sequence of text, the model performs many forward passes to produce the full sequence. Two major techniques have been proposed in the literature and deployed in systems to address the limitations and bottlenecks that arise during these phases.
+
+_Blocked KV Caching:_
+
+vLLM identified that memory fragmentation due to large monolithic KV caches significantly reduced the concurrency of LLM serving systems and proposed [Paged Attention](https://arxiv.org/pdf/2309.06180.pdf) to enable non-contiguous caches and increase total system throughput. Rather than assigning each sequence an individual variable-sized contiguous chunk of memory, the underlying storage of the KV cache is composed of fixed-sized blocks (also known as pages). The blocked KV cache increases potential sequence concurrency, and therefore system throughput, by eliminating KV-cache-induced memory fragmentation. Non-contiguous KV cache implementations are also included in [HuggingFace TGI](https://github.com/huggingface/text-generation-inference) and [NVIDIA TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM).
+
+_Continuous Batching:_
+
+In the past, dynamic batching, in which a server would wait for multiple requests to process in phase with each other, was used to improve GPU utilization. However, this approach has drawbacks, as it typically requires padding inputs to identical lengths or stalling the system to wait to construct a larger batch.
+
+Recent advances in large language model (LLM) inference and serving have focused on fine-grained scheduling and memory efficiency. For instance, Orca proposes _iteration-level scheduling_ (also known as continuous batching), which makes scheduling decisions at each forward pass of the model. This allows requests to join and leave the batch as needed, eliminating the need to pad requests and thus improving overall throughput. In addition to Orca, continuous batching has been implemented in NVIDIA TRT-LLM, HuggingFace TGI, and vLLM.
+
+In current systems, there are two primary approaches to implementing continuous batching. In TGI and vLLM, the generation phase is preempted to perform prompt processing (called infill in TGI) before continuing with generation. In Orca, these phases are not distinguished; instead, Orca adds a prompt into the running batch so long as the total number of sequences does not reach a fixed bound. To varying degrees, both of these approaches need to stall generation to process long prompts (see [Section 3B](#splitfuse)).
+
+To address these shortcomings, we propose a novel prompt and generation composition strategy, Dynamic SplitFuse.
+
+## 3. Dynamic SplitFuse: A Novel Prompt and Generation Composition Strategy
+
+DeepSpeed-FastGen is built to leverage continuous batching and non-contiguous KV caches to enable increased occupancy and higher responsiveness for serving LLMs in the data center, similar to existing frameworks such as TRT-LLM, TGI, and vLLM. In order to achieve a new level of performance, DeepSpeed-FastGen introduces Dynamic SplitFuse, which decomposes and recomposes prompt and generation tokens to further improve continuous batching and system throughput.
+
+### A. Three Performance Insights
+Before describing Dynamic SplitFuse, we answer three key performance questions that together motivate its design.
+
+*__1. What factors impact the forward pass of a single LLM?__* In order to effectively schedule, it is necessary to understand the relevant independent variables that the scheduling loop should control. We observe below that the composition of sequences in a forward pass (the batch size in sequences) has a negligible impact on performance compared to the raw number of tokens in the forward pass. This means an effective scheduler can be built around a single signal: the number of tokens in the forward pass.
+
+
+
+
+*__2. How does a model's throughput respond to changing the number of tokens in the forward pass?__* An LLM has two key operating regions with a relatively steep transition between them. With a small number of tokens, the GPU bottleneck is reading the model from memory, so throughput scales with the number of tokens; with many tokens, the model is compute bound and sees near-constant throughput. The model runs most efficiently if all forward passes stay in the throughput-saturating region.
+
+
+
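The shape of this curve is straightforward to reproduce independently. The following micro-benchmark is a rough sketch only (it is not the script used to produce the plots above; the model name and pass sizes are illustrative), timing forward passes of growing token counts and reporting tokens per second:

```python
import time

import torch
from transformers import AutoModelForCausalLM

# Illustrative only: a small model stands in for the larger ones studied above.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16).to("cuda")
model.eval()

for n_tokens in (64, 256, 1024, 2048):
    # A single sequence of n_tokens random ids; per the observation above, the
    # total token count matters far more than how it is split across sequences.
    input_ids = torch.randint(0, model.config.vocab_size, (1, n_tokens), device="cuda")
    with torch.inference_mode():
        model(input_ids)  # warm-up pass
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(10):
            model(input_ids)
        torch.cuda.synchronize()
    per_pass = (time.time() - start) / 10
    print(f"{n_tokens:5d} tokens/pass -> {n_tokens / per_pass:12.0f} tokens/s")
```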
+
+*__3. How should a pool of tokens be scheduled across multiple forward passes?__* We observe above that for well-aligned inputs the token-throughput curve is concave, which means its second derivative is less than or equal to 0. As an example, let $f(x)$ be a concave function mapping the number of tokens in a forward pass to throughput for a given model. For a concave function $f(x)$, the following holds:
+
+ $$0 \geq \lim_{h \to 0} \frac{f(x + h) - 2f(x) + f(x - h)}{h^2}$$
+
+ $$0 \geq f(x + h) - 2f(x) + f(x - h)$$
+
+ $$2f(x) \geq f(x + h) + f(x - h)$$
+
+This states that for a given pool of `2x` tokens to process, the manner that maximizes throughput is the one that evenly splits them between two batches. More generally, in a system that must consume and process P tokens over F forward passes, the ideal partitioning scheme will divide them equally.
+
+### B. Dynamic SplitFuse
+
+Dynamic SplitFuse is a novel token composition strategy for prompt processing and token generation. DeepSpeed-FastGen utilizes Dynamic SplitFuse to run at a consistent forward size by taking partial tokens from prompts and composing them with generation. In particular, Dynamic SplitFuse performs two key behaviors:
+
+1. Long prompts are decomposed into much smaller chunks and scheduled across multiple forward passes (iterations), with only the final pass performing any generation.
+2. Short prompts are composed to exactly fill a target token budget. Even short prompts may be decomposed to ensure the budget is precisely met and the forward sizes are well-aligned.
+
+Together, these two techniques provide concrete benefits on all user metrics:
+
+1. *__Better responsiveness:__* Since long prompts no longer require extremely long forward passes to process, the model provides lower client latency. More forward passes are performed within the same window of time.
+2. *__Higher efficiency:__* Fusion of short prompts to larger token budgets enables the model to consistently operate in the high-throughput regime.
+3. *__Lower variance and better consistency:__* Since forward passes are of consistent size, and forward pass size is the primary determinant of performance, the latency of each forward pass is much more consistent than in competing systems, as is the perceived generation frequency. There are no pre-emptions or long-running prompts to increase latency as in prior work.
+
+Consequently, DeepSpeed-FastGen consumes tokens from incoming prompts at a rate that permits fast ongoing generation while adding tokens to the system that increase system utilization, providing lower latency and higher-throughput streaming generation to all clients compared to other state-of-the-art serving systems. A simplified sketch of this token-budgeting idea follows.
+
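The sketch below is not the actual DeepSpeed-FastGen scheduler (which lives in DeepSpeed-MII and DeepSpeed-Inference and also handles KV-block allocation, ragged batching, and more); the fixed budget and all names are illustrative. It only shows the core budgeting behavior described above:

```python
from dataclasses import dataclass
from typing import List, Tuple

TOKEN_BUDGET = 2048  # illustrative fixed forward-pass size


@dataclass
class Sequence:
    uid: int
    remaining_prompt: int       # prompt tokens not yet processed
    in_generation: bool = False


def schedule_forward_pass(sequences: List[Sequence]) -> List[Tuple[int, int]]:
    """Return (uid, n_tokens) work items for one fixed-size forward pass."""
    budget = TOKEN_BUDGET
    batch = []

    # 1) Ongoing generation always proceeds: one token per generating sequence.
    for seq in sequences:
        if seq.in_generation and budget > 0:
            batch.append((seq.uid, 1))
            budget -= 1

    # 2) The remaining budget is filled with prompt chunks: long prompts are
    #    split across several passes, short prompts are fused together so the
    #    target size is met exactly.
    for seq in sequences:
        if not seq.in_generation and seq.remaining_prompt > 0 and budget > 0:
            chunk = min(seq.remaining_prompt, budget)
            batch.append((seq.uid, chunk))
            seq.remaining_prompt -= chunk
            budget -= chunk
            if seq.remaining_prompt == 0:
                seq.in_generation = True  # generation starts on a later pass

    return batch
```

Every forward pass therefore stays close to the same, throughput-saturating size, regardless of how long any individual prompt is.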
+ +
+ + *Figure 1: Illustration of continuous batching strategies. Each block shows the execution of a forward pass. An arrow indicates that the forward pass has sequences with one or more tokens generated. vLLM performs either token generations or prompt processing in a forward pass; token generation preempts prompt processing. Orca runs prompts at their complete length alongside generation. Dynamic SplitFuse performs dynamic composition of fixed-sized batches composed of both generation and prompt tokens.* + +
+
+## 4. Performance Evaluation
+
+DeepSpeed-FastGen provides state-of-the-art LLM serving performance by leveraging its blocked KV cache and Dynamic SplitFuse continuous batching. We evaluate DeepSpeed-FastGen against vLLM on a range of models and hardware configurations following the benchmarking methodology discussed below.
+
+### A. Benchmarking Methodology
+
+We use two primary quantitative schemes for measuring performance.
+
+**Throughput-Latency Curves:** Two key metrics for production readiness are throughput (measured in requests per second) and latency (the responsiveness of each request). To measure these, we instantiate multiple clients (ranging from 1 to 32) concurrently and send requests (512 in total) to the server. The latency of each request is measured at the endpoint, and throughput is computed from the end-to-end time to complete the experiment.
+
+**Effective Throughput:** Interactive applications, such as chat applications, can have more stringent and complex requirements than can be captured by top-level metrics like end-to-end latency. In particular, we focus on the increasingly popular chat user scenario:
+
+ 1. A user initiates a task by sending a prompt.
+ 2. The system processes the prompt and returns the first token.
+ 3. Subsequent tokens are streamed to the user as they are produced.
+
+At each point in this process there is an opportunity for a system to provide an adverse user experience; for example, if the first token arrives too slowly or the generation appears to stop for some time. We propose an SLA framework that considers both of these dimensions.
+
+As the lengths of prompts and generated texts vary significantly, affecting computational costs, it is impractical to set rigid SLA values for throughput and latency. Therefore, we define the SLA for prompt latency as |tokens in prompt| / 512 seconds (i.e., a prompt processing rate of at least 512 tokens/s). Additionally, considering typical human reading speeds, we set the SLA for generation latency as an exponential moving average (EMA) of the generation rate above 2, 4, or 6 tokens/sec. Requests that adhere to these SLAs are deemed successful, and the throughput of these successful requests is referred to as **effective throughput**.
+
+We evaluate vLLM and DeepSpeed-FastGen on Llama-2 7B, Llama-2 13B, and Llama-2 70B using NVIDIA A100, H100, and A6000 GPUs.
+
+### B. Throughput-Latency Analysis
+
+In this experiment, DeepSpeed-FastGen outperforms vLLM in both throughput and latency, delivering either greater throughput at equivalent latency or lower latency at the same throughput. On Llama-2 70B with 4 A100-80GB GPUs, DeepSpeed-FastGen demonstrates up to 2x higher throughput (1.36 rps vs. 0.67 rps) at identical latency (9 seconds), or up to 50% latency reduction (7 seconds vs. 14 seconds) while achieving the same throughput (1.2 rps), as shown in Figure 2. These trends hold when evaluating Llama-2 13B, as shown in Figure 3.
+
+
+ + *Figure 2: Throughput and latency of text generation using Llama 2 70B (Tensor parallelism across 4 A100-80GB GPUs). A normal distribution was applied to prompt and generation lengths with averages of 1200/2600 and 128/60, respectively, and a 30% variance* +

+ +
+
+ + *Figure 3: Throughput and latency of text generation using Llama 2 13B (A100-80GB GPU, no tensor parallelism). A normal distribution was applied to prompt and generation lengths with averages of 1200/2600 and 60/128, respectively, and a 30% variance* +
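Curves of this shape can be reproduced before the official benchmark scripts are released (see the roadmap in Section 6) by driving a sweep along the following lines against a persistent deployment. This is only a sketch: the prompt set, client counts, and the assumption that a single MII client handle can be shared across threads are illustrative, not part of the released tooling.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple

import mii


def run_sweep(deployment: str, prompts: List[str], n_clients: int) -> Tuple[float, List[float]]:
    """Send all prompts with n_clients concurrent workers; return (req/s, per-request latencies)."""
    client = mii.client(deployment)  # assumes the handle can be shared across threads

    def one_request(prompt: str) -> float:
        start = time.time()
        client.generate(prompt, max_new_tokens=60)
        return time.time() - start

    t0 = time.time()
    with ThreadPoolExecutor(max_workers=n_clients) as pool:
        latencies = list(pool.map(one_request, prompts))
    throughput = len(prompts) / (time.time() - t0)
    return throughput, latencies
```

Sweeping `n_clients` (for example, 1 to 32) and plotting throughput against mean latency yields a curve comparable to Figures 2 and 3.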
+ +### C. Effective Throughput Analysis + +Under the effective throughput analysis that considers both first token latency and the rate at which generation occurs, DeepSpeed-FastGen provides up to 2.3x higher throughput than vLLM. Figure 4 presents a comparative analysis of the effective throughputs of DeepSpeed-FastGen and vLLM. Each plotted point denotes the effective throughput derived from a specific number of clients. As we scaled the number of clients, we initially observed an increase in effective throughput. However, the latency also significantly increases as the number of clients approaches the system's capacity, causing many requests to fail in meeting the SLA. Consequently, the effective throughput will either saturate or decrease at some point. From a usability perspective, it's not particularly relevant how many clients are required to achieve the max effective throughput; the maximum point of the line is the optimal serving point. + +
+ + + *Figure 4: Effective throughput of DeepSpeed-FastGen and vLLM (Llama 2 70B/A100-80GB using tensor parallelism across 4 A100-80GB GPUs. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 60, respectively, and a 30% variance)* +

+
+When vLLM preempts the ongoing generation of previous requests, the generation latency experiences a notable increase. This leads to vLLM's effective throughput appearing lower than its directly measured throughput. At vLLM's peak, the effective throughput was 0.63 queries/sec and around 28% of requests failed to meet the 4 tokens/s SLA. At the same SLA, DeepSpeed-FastGen achieved 1.42 queries/sec (less than 1% of requests failed to meet the SLA), which is 2.3x higher than vLLM.
+
+### D. Token Level Timing Analysis
+
+Figure 5 displays the P50, P90, and P95 latencies of the generation processes. Both vLLM and DeepSpeed-FastGen exhibit similar P50 latencies, but vLLM demonstrates significantly higher latencies at P90 and P95.
+For the P95 latencies, DeepSpeed-FastGen achieves a 3.7x reduction.
+
+This discrepancy is due to a noticeable spike in vLLM's generation latency when it preempts ongoing generation to process new prompts.
+In contrast, DeepSpeed-FastGen typically processes the prompt chunks and generation of previous requests concurrently, leading to much more consistent generation latency.
+
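The metrics used in Sections 4A and 4D can be computed from per-request timestamps with a few lines of NumPy. The following is only a sketch of the definitions above; in particular, the EMA smoothing factor is an assumption and is not specified in this post.

```python
import numpy as np

PROMPT_SLA_TOKENS_PER_S = 512   # first token within |prompt| / 512 seconds
GEN_SLA_TOKENS_PER_S = 4        # 2, 4, or 6 tokens/s in the analysis above
EMA_ALPHA = 0.1                 # assumed smoothing factor for the generation-rate EMA


def meets_sla(prompt_len: int, first_token_s: float, token_times: list) -> bool:
    """token_times: wall-clock timestamps of each streamed token for one request."""
    # Prompt SLA: the first token must arrive within |prompt| / 512 seconds.
    if first_token_s > prompt_len / PROMPT_SLA_TOKENS_PER_S:
        return False
    # Generation SLA: the EMA of the instantaneous generation rate must stay
    # above the threshold for the whole request.
    ema = None
    for prev, cur in zip(token_times, token_times[1:]):
        rate = 1.0 / (cur - prev)                       # instantaneous tokens/s
        ema = rate if ema is None else EMA_ALPHA * rate + (1 - EMA_ALPHA) * ema
        if ema < GEN_SLA_TOKENS_PER_S:
            return False
    return True


def summarize(token_latencies: list, request_ok: list, wall_clock_s: float):
    p50, p90, p95 = np.percentile(token_latencies, [50, 90, 95])
    effective_throughput = sum(request_ok) / wall_clock_s  # successful requests/s
    return p50, p90, p95, effective_throughput
```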
+
+ + *Figure 5: Per-Token generation Latency of Llama 2 70B/A100-80GB using tensor parallelism across 4 A100-80GB GPUs, 16 clients. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 128, respectively, and a 30% variance.* +

+
+
+### E. Scalability using Load Balancing
+
+DeepSpeed-FastGen offers replica-level load balancing that evenly distributes requests across multiple servers, allowing you to effortlessly scale up your application.
+
+Figure 6 illustrates the scalability of DeepSpeed-FastGen when employing the load balancer and up to 16 replicas. Note that each replica serves the Llama 2 70B model on 4 A100 GPUs; in total, we employed 8 nodes to run the 16 replicas. The results demonstrate nearly perfect scaling with DeepSpeed-FastGen.
+Given that the throughput of a single replica is 1.46 queries/sec, the throughput with 16 replicas reaches 23.7 queries/sec, a near-linear 16x increase over a single replica.
+
+
+ + *Figure 6: Scalability using the load balancing feature. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 60, respectively, and a 30% variance*
+
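A multi-replica deployment like the one in Figure 6 can be stood up through DeepSpeed-MII. The keyword arguments below (`deployment_name`, `tensor_parallel`, `replica_num`) are assumed from the MII configuration at the time of writing; check the MII documentation for the exact options in your installed version.

```python
import mii

# 16 replicas of Llama 2 70B, each sharded across 4 GPUs, behind MII's
# built-in load balancer (matching the Figure 6 setup).
mii.serve(
    "meta-llama/Llama-2-70b-hf",
    deployment_name="llama-70b-fastgen",
    tensor_parallel=4,
    replica_num=16,
)

client = mii.client("llama-70b-fastgen")
print(client.generate("DeepSpeed is", max_new_tokens=60))
```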
+
+### F. Other Hardware Platforms
+
+In addition to the deep analysis on A100, we provide additional benchmarking results for H100 and A6000. The performance trends we observed on A100 hold on both of these platforms.
+
+
+ + *Figure 7: Throughput-latency curve and effective throughput of Llama 2 70b using 8 H100 GPUs. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 60, respectively, and a 30% variance*
+
+ +
+
+ + *Figure 8: Throughput-latency curve and effective throughput of Llama 2 7b using A6000. A normal distribution was applied to prompt and generation lengths with averages of 2600 and 60, respectively, and a 30% variance*
+
+
+## 5. DeepSpeed-FastGen: Implementation and Usage
+
+DeepSpeed-FastGen is the synergistic composition of [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII) and [DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed), as illustrated in the figure below. Together, these two software packages provide various components of the system, including the frontend APIs, the host and device infrastructure to schedule batches using Dynamic SplitFuse, optimized kernel implementations, and the tools to construct new model implementations.
+
+ + +
+ + +The fastest way to get started with our alpha release of DeepSpeed-FastGen is: `pip install deepspeed-mii`. + +Please follow our [Getting Started](https://github.com/microsoft/deepspeed-mii#getting-started-with-mii) guide for more details. For usage and reporting issues, please use the [DeepSpeed-MII Github repository](https://github.com/microsoft/DeepSpeed-MII). + +### A. Supported Models + +We currently support the following model architectures in this alpha release of DeepSpeed-FastGen: + +* [LLaMA](https://huggingface.co/models?other=llama) and [LLaMA-2](https://huggingface.co/models?other=llama-2) +* [Mistral](https://huggingface.co/models?other=mistral) +* [OPT](https://huggingface.co/models?other=opt) + +All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer. + +We plan to add additional models in the coming weeks and months after the initial release. If there are specific model architectures you would like supported, please [file an issue](https://github.com/microsoft/DeepSpeed-MII/issues) and let us know. + +### B. Deployment options +All of the examples below are runnable in [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/inference/mii). Once installed you have two options for deployment: an interactive non-persistent pipeline or a persistent serving deployment: + +#### Non-persistent pipeline + +The non-persistent pipeline deployment is a great and fast way to get started and can be done with only a few lines of code. Non-persistent models are only around for the duration of the python script you are running but are useful for temporary interactive sessions. + +```python +from mii import pipeline +pipe = pipeline("mistralai/Mistral-7B-v0.1") +output = pipe(["Hello, my name is", "DeepSpeed is"], max_new_tokens=128) +print(output) +``` + +#### Persistent deployment + +A persistent deployment is ideal for use with long-running and production applications. The persistent deployment uses a lightweight GRPC server that can be created using the following 2 lines: + + +```python +import mii +mii.serve("mistralai/Mistral-7B-v0.1") +``` + +The above server can be queried by multiple clients at once thanks to the built-in load balancer from DeepSpeed-MII. Creating a client also just takes 2 lines of code: + +```python +client = mii.client("mistralai/Mistral-7B-v0.1") +output = client.generate("Deepspeed is", max_new_tokens=128) +print(output) +``` + +A persistent deployment can be terminated when it is no longer needed: + +```python +client.terminate_server() +``` + +### C. Advanced Installation Information + +For ease of use and a significant reduction in lengthy compile times that many projects require in this space, we distribute a pre-compiled Python wheel covering the majority of our custom kernels through a new library called [DeepSpeed-Kernels](https://github.com/microsoft/DeepSpeed-Kernels). We have found this library to be very portable across environments with NVIDIA GPUs with compute capabilities 8.0+ (Ampere+), CUDA 11.6+, and Ubuntu 20+. In most cases, you shouldn't even need to know this library exists as it is a dependency of DeepSpeed-MII and will be installed with it. However, if for whatever reason you need to compile our kernels manually please see our [advanced installation docs](https://github.com/microsoft/DeepSpeed-Kernels#source). + + +# 6. 
Try Out DeepSpeed-FastGen
+
+We are very excited to share this DeepSpeed-FastGen alpha release.
+
+* To get started, please visit our GitHub page for DeepSpeed-MII: [GitHub Landing Page](https://github.com/microsoft/DeepSpeed-MII)
+
+DeepSpeed-FastGen is part of the bigger DeepSpeed ecosystem, comprising a multitude of deep learning systems and modeling technologies. To learn more:
+
+* Please visit our [website](https://www.deepspeed.ai/) for detailed blog posts, tutorials, and helpful documentation.
+* You can also follow us on our [English Twitter](https://twitter.com/MSFTDeepSpeed), [Japanese Twitter](https://twitter.com/MSFTDeepSpeedJP), and [Chinese Zhihu](https://www.zhihu.com/people/deepspeed) for the latest news on DeepSpeed.
+
+DeepSpeed welcomes your contributions! We encourage you to report issues, contribute PRs, and join discussions on the [DeepSpeed GitHub](https://github.com/microsoft/DeepSpeed/) page. Please see our [contributing guide](https://github.com/microsoft/DeepSpeed/blob/master/CONTRIBUTING.md) for more details. We are open to collaborations with universities, research labs, and companies, such as those working together on deep learning research, applying DeepSpeed to empower real-world AI models and applications, and so on. For such requests (and other requests unsuitable for GitHub), please email deepspeed-info@microsoft.com directly.
+
+The following items are on our roadmap, and we plan to engage with our community on them through our GitHub issues and PRs:
+
+- Performance improvements
+- Broader model support
+- New hardware backends through collaboration with partners
+- Release of the performance benchmarks (used to generate the plots in this blog)
+
+**"Star" our [DeepSpeed GitHub](https://github.com/microsoft/DeepSpeed/) and [DeepSpeed-MII GitHub](https://github.com/microsoft/DeepSpeed-MII/) repositories if you like our work!**
+
+# 7. Acknowledgements
+
+We would like to thank various open-source community projects, including HuggingFace, vLLM, and HuggingFace TGI. We have leveraged HF APIs to support models and tokenizers in our alpha release and will continue to add more models. We especially acknowledge and thank the developers of [Flash Attention](https://github.com/Dao-AILab/flash-attention) for their great work. We have extensively leveraged FlashAttention kernels in our system, with modifications that are acknowledged in our code repositories at the appropriate file headers. Finally, we want to thank the developers of the [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) kernels that we have used in our MoE kernels (released as part of the DeepSpeed-Kernels repository).
diff --git a/blogs/deepspeed-fastgen/assets/images/A6000_benchmark.png b/blogs/deepspeed-fastgen/assets/images/A6000_benchmark.png new file mode 100644 index 000000000000..9d4ab55f5f7a Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/A6000_benchmark.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/H100_benchmark.png b/blogs/deepspeed-fastgen/assets/images/H100_benchmark.png new file mode 100644 index 000000000000..89fb9ca3e1ce Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/H100_benchmark.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/effective_throughput.png b/blogs/deepspeed-fastgen/assets/images/effective_throughput.png new file mode 100644 index 000000000000..11c7f82bc54f Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/effective_throughput.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/effective_throughput_main.png b/blogs/deepspeed-fastgen/assets/images/effective_throughput_main.png new file mode 100644 index 000000000000..1b9a38306e8e Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/effective_throughput_main.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/fast-gen-overview.jpg b/blogs/deepspeed-fastgen/assets/images/fast-gen-overview.jpg new file mode 100644 index 000000000000..2affbf8a4cc3 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fast-gen-overview.jpg differ diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-arch-dark.png b/blogs/deepspeed-fastgen/assets/images/fastgen-arch-dark.png new file mode 100644 index 000000000000..9b90357a3f1b Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-arch-dark.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-arch-light.png b/blogs/deepspeed-fastgen/assets/images/fastgen-arch-light.png new file mode 100644 index 000000000000..9e754abde85d Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-arch-light.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-hero-dark.png b/blogs/deepspeed-fastgen/assets/images/fastgen-hero-dark.png new file mode 100755 index 000000000000..6ac1a775805b Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-hero-dark.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-hero-light.png b/blogs/deepspeed-fastgen/assets/images/fastgen-hero-light.png new file mode 100755 index 000000000000..af8f1defe653 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-hero-light.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-overview-dark.png b/blogs/deepspeed-fastgen/assets/images/fastgen-overview-dark.png new file mode 100755 index 000000000000..dde598a985d8 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-overview-dark.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/fastgen-overview-light.png b/blogs/deepspeed-fastgen/assets/images/fastgen-overview-light.png new file mode 100755 index 000000000000..bdb5f8df483e Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/fastgen-overview-light.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-flops.png b/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-flops.png new file mode 100644 index 000000000000..6d45880588d9 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-flops.png differ diff --git 
a/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-latency.png b/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-latency.png new file mode 100644 index 000000000000..7c14e2bf6e53 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/observation-prompt-v-latency.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/repl_scale_llama70b_tp4_p2600g60.png b/blogs/deepspeed-fastgen/assets/images/repl_scale_llama70b_tp4_p2600g60.png new file mode 100644 index 000000000000..834c06dfb07a Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/repl_scale_llama70b_tp4_p2600g60.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p1200g128.png b/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p1200g128.png new file mode 100644 index 000000000000..df16b5bebc53 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p1200g128.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p2600g128.png b/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p2600g128.png new file mode 100644 index 000000000000..8b69a8a1718b Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/th_lat_curve_llama70b_tp4_p2600g128.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/throughput_latency.png b/blogs/deepspeed-fastgen/assets/images/throughput_latency.png new file mode 100644 index 000000000000..aaceebde7038 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/throughput_latency.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/throughput_latency_13B_no_arrow.png b/blogs/deepspeed-fastgen/assets/images/throughput_latency_13B_no_arrow.png new file mode 100644 index 000000000000..cc7b8ec1ec05 Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/throughput_latency_13B_no_arrow.png differ diff --git a/blogs/deepspeed-fastgen/assets/images/token_latency.png b/blogs/deepspeed-fastgen/assets/images/token_latency.png new file mode 100644 index 000000000000..405a3c0d06ed Binary files /dev/null and b/blogs/deepspeed-fastgen/assets/images/token_latency.png differ diff --git a/csrc/includes/activation_type.h b/csrc/includes/activation_type.h new file mode 100644 index 000000000000..a44921d5d650 --- /dev/null +++ b/csrc/includes/activation_type.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +enum ActivationType { + GELU = 0, + RELU = 1, + SILU = 2, + GEGLU = 3, + ReGLU = 4, + SiGLU = 5, + IDENTITY = 6, + InvalidType = -1 +}; diff --git a/csrc/includes/ds_kernel_utils.h b/csrc/includes/ds_kernel_utils.h index a3af561adfe5..8e4888109fcd 100644 --- a/csrc/includes/ds_kernel_utils.h +++ b/csrc/includes/ds_kernel_utils.h @@ -11,6 +11,11 @@ used throughout the codebase. 
#pragma once #include +#include + +#ifdef BF16_AVAILABLE +#include +#endif #define DS_HD_INLINE __host__ __device__ __forceinline__ #define DS_D_INLINE __device__ __forceinline__ diff --git a/csrc/includes/reduction_utils.h b/csrc/includes/reduction_utils.h index 8d0c2d6986a1..eb8efab77ac1 100644 --- a/csrc/includes/reduction_utils.h +++ b/csrc/includes/reduction_utils.h @@ -233,6 +233,60 @@ DS_D_INLINE __half2 element(const __half2 lhs, const __half2 rhs) #endif } +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int32_t element(const int32_t lhs, const int32_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE uint32_t element(const uint32_t lhs, const uint32_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return lhs + rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return (lhs > rhs) ? lhs : rhs; +} + +template <> +DS_D_INLINE int64_t element(const int64_t lhs, const int64_t rhs) +{ + return (lhs < rhs) ? lhs : rhs; +} + /* Reduction initialization primitives */ @@ -310,6 +364,78 @@ DS_D_INLINE __half2 init() #endif } +template <> +DS_D_INLINE int32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0x7FFFFFFF; +} + +template <> +DS_D_INLINE int32_t init() +{ + return 0x80000000; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0xFFFFFFFF; +} + +template <> +DS_D_INLINE uint32_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0x7FFFFFFFFFFFFFFF; +} + +template <> +DS_D_INLINE int64_t init() +{ + return 0x8000000000000000; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0xFFFFFFFFFFFFFFFF; +} + +template <> +DS_D_INLINE uint64_t init() +{ + return 0; +} + template DS_D_INLINE void init(T* data) { @@ -352,8 +478,8 @@ here (fold is C++17 only and I don't think helps and recursion feels like huge overkill that harms readability) that would be wonderful. 
*/ -template -DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { @@ -361,8 +487,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) } } -template -DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { @@ -371,8 +497,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) } } -template -DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { @@ -382,8 +508,13 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) } } -template -DS_D_INLINE void _warp(cg::thread_block_tile& warp, float* data) +template +DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { @@ -403,16 +534,15 @@ the number of warps in the block (which may exceed that if the block is partitioned or if we do a conservative bound at compile time). */ -template +template DS_D_INLINE void _block(cg::thread_block& tb, cg::thread_block_tile& warp_arg, - float* data) + T* data) { constexpr int elems = sizeof...(Ops); - // Separated for now in case this no longer is true - constexpr int bytes = sizeof(float); + constexpr int bytes = sizeof(T); // Unused when `partition_size == 1` or total_warps == 1 - __shared__ float reduce_buffer[max_warps * elems]; + __shared__ T reduce_buffer[max_warps * elems]; #ifdef __HIP_PLATFORM_AMD__ const int total_threads = blockDim.x * blockDim.y * blockDim.z; @@ -422,7 +552,7 @@ DS_D_INLINE void _block(cg::thread_block& tb, #endif // Always perform warp-scope reduction - _warp(warp_arg, data); + _warp(warp_arg, data); // If max_warps == 1 let's skip the runtime check if (total_warps != 1) { @@ -447,7 +577,7 @@ DS_D_INLINE void _block(cg::thread_block& tb, init(data); } - _warp(warp_arg, data); + _warp(warp_arg, data); #pragma unroll for (int i = 0; i < elems; i++) { @@ -476,7 +606,7 @@ us to obfuscate the details of the partitioned implementation. 
template DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile& warp, float& val) { - _block(tb, warp, &val); + _block(tb, warp, &val); } template @@ -486,7 +616,7 @@ DS_D_INLINE void block(cg::thread_block& tb, float& val2) { float data[2] = {val1, val2}; - _block(tb, warp, data); + _block(tb, warp, data); val1 = data[0]; val2 = data[1]; } @@ -499,7 +629,7 @@ DS_D_INLINE void block(cg::thread_block& tb, float& val3) { float data[3] = {val1, val2, val3}; - _block(tb, warp, data); + _block(tb, warp, data); val1 = data[0]; val2 = data[1]; val3 = data[2]; @@ -514,7 +644,7 @@ DS_D_INLINE void block(cg::thread_block& tb, float& val4) { float data[4] = {val1, val2, val3, val4}; - _block(tb, warp, data); + _block(tb, warp, data); val1 = data[0]; val2 = data[1]; val3 = data[2]; @@ -531,10 +661,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb, float& val) { if (num_threads <= hw_warp_size) { - _warp(warp, &val); + _warp(warp, &val); } else { constexpr int num_warps = num_threads / hw_warp_size; - _block(tb, warp, &val); + _block(tb, warp, &val); } } @@ -547,10 +677,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb, float data[2] = {val1, val2}; if (num_threads <= hw_warp_size) { - _warp(warp, data); + _warp(warp, data); } else { constexpr int num_warps = num_threads / hw_warp_size; - _block(tb, warp, data); + _block(tb, warp, data); } val1 = data[0]; @@ -567,10 +697,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb, float data[3] = {val1, val2, val3}; if (num_threads <= hw_warp_size) { - _warp(warp, data); + _warp(warp, data); } else { constexpr int num_warps = num_threads / hw_warp_size; - _block(tb, warp, data); + _block(tb, warp, data); } val1 = data[0]; @@ -589,10 +719,10 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb, float data[4] = {val1, val2, val3, val4}; if (num_threads <= hw_warp_size) { - _warp(warp, data); + _warp(warp, data); } else { constexpr int num_warps = num_threads / hw_warp_size; - _block(tb, warp, data); + _block(tb, warp, data); } val1 = data[0]; @@ -601,4 +731,48 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb, val4 = data[3]; } +/* +Arg-reduce is a specialization of the above. We only support this with a single reduction +parameter. This only works for max/min reductions. +*/ + +__align__(8) struct IdxReduceResult { + /* + NOTE: ORDERING MATTERS HERE! The idx is the least significant set of bits + and the val is the most significant. Changing the order of this declaration + will break the code. + */ + int idx; + float val; +}; + +template +DS_D_INLINE IdxReduceResult +idx_reduce(cg::thread_block& tb, cg::thread_block_tile& warp, float val, int idx) +{ + IdxReduceResult res = {idx, val}; + + // Clear out the nan. This shouldn't be an issue for our initial applications + if (isnan(val)) res.val = init(); + + // Can do float compares as integers. By packing the index into the lower bits + // we can just do a single int64 rather than a branch, compare, and select. + // One side benefit of this is that it is by nature a stable algorithm and + // will always bias ties to the higher index. + int64_t* res_as_int = reinterpret_cast(&res); + + // The way floating point compare works is normally to perform a sign comparison + // and if they match, then do a comparison of the rest of the bits as unsigned + // integers. Since we are bundling these, that means for negative values we need + // to reverse the sort order, which we can do with an XOR. 
+ if (val < 0) { *res_as_int ^= 0x7fffffff00000000; } + + _block(tb, warp, res_as_int); + + // Sign bit is preserved, so we can check if we need to invert the mantissa back + if (res.val < 0) { *res_as_int ^= 0x7fffffff00000000; } + + return res; +} + } // namespace reduce diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index 3e6701d81e64..dc7ff4d1e7c0 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -18,8 +18,8 @@ using __nv_bfloat162 = __half2; inline __device__ float gelu(const float x) { - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; + constexpr float sqrt_param = 0.79788456080286535587989211986876f; + constexpr float mul_param = 0.044715; return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); } diff --git a/deepspeed/inference/__init__.py b/deepspeed/inference/__init__.py index 208299fb8c50..7fed50cbe177 100644 --- a/deepspeed/inference/__init__.py +++ b/deepspeed/inference/__init__.py @@ -2,3 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team +from .v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig +from .v2.engine_v2 import InferenceEngineV2 +from .v2 import build_hf_engine diff --git a/deepspeed/inference/v2/__init__.py b/deepspeed/inference/v2/__init__.py new file mode 100644 index 000000000000..bba5d1a82081 --- /dev/null +++ b/deepspeed/inference/v2/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +from .config_v2 import RaggedInferenceEngineConfig, DeepSpeedTPConfig +from .engine_v2 import InferenceEngineV2 +from .engine_factory import build_hf_engine diff --git a/deepspeed/inference/v2/allocator.py b/deepspeed/inference/v2/allocator.py new file mode 100644 index 000000000000..fa2c5368604e --- /dev/null +++ b/deepspeed/inference/v2/allocator.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from functools import reduce +from typing import Iterable + +import torch + +from deepspeed.accelerator import get_accelerator + + +def empty_from(tensor: torch.Tensor, shape: Iterable[int]) -> torch.Tensor: + shape_size = reduce(lambda x, y: x * y, shape) + if shape_size == 0: + raise ValueError("Cannot create empty tensor with size 0") + return tensor.flatten()[:shape_size].view(shape) + + +def on_device(method) -> torch.Tensor: + """ + Wraps a method to ensure the returned tensor is on the current device. + """ + + def wrapped(self, *args, **kwargs): + tensor = method(self, *args, **kwargs) + if isinstance(tensor, torch.Tensor): + return tensor.to(get_accelerator().current_device()).contiguous() + return tensor + + return wrapped diff --git a/deepspeed/inference/v2/checkpoint/__init__.py b/deepspeed/inference/v2/checkpoint/__init__.py new file mode 100644 index 000000000000..45e523ab62b9 --- /dev/null +++ b/deepspeed/inference/v2/checkpoint/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .base_engine import CheckpointEngineBase +from .in_memory_engine import InMemoryModelEngine +from .huggingface_engine import HuggingFaceCheckpointEngine diff --git a/deepspeed/inference/v2/checkpoint/base_engine.py b/deepspeed/inference/v2/checkpoint/base_engine.py new file mode 100644 index 000000000000..26fc467d4d86 --- /dev/null +++ b/deepspeed/inference/v2/checkpoint/base_engine.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import ABC, abstractmethod +from typing import Iterable, Tuple + +import torch + +#from .huggingface_engine import HuggingFaceCheckpointEngine + +MEGATRON = 'megatron' +HUGGINGFACE = 'huggingface' + + +class CheckpointEngineBase(ABC): + """ + Abstract interface for checkpoint engines to implement. + + There is no ``__init__`` method here by design, since the creation of the checkpoint + engine will happen outside the policy/engine code. The tradeoff being made here is + that we will write different frontends for different checkpoint engines, but these + frontends can be tailored to the specific checkpoint engine/model source needs. + """ + + @abstractmethod + def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]: + """ + This method should create a generator of tuples of the form (name, parameter) for + all parameters in the model. The name should be the fully qualified name of the + parameter, and the parameter should be a torch.Tensor. + + The expected use of a checkpoint engine is the following: + ```python + for name, parameter in checkpoint_engine.parameters(): + container_map.map_param(name, parameter) + ``` + For a concrete use example, see ``InferenceV2Policy``. + """ + ... diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py new file mode 100644 index 000000000000..515378d31d02 --- /dev/null +++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import json +import torch +from .base_engine import CheckpointEngineBase +from typing import Iterable, Tuple + +from ..logging import inference_logger + + +class HuggingFaceCheckpointEngine(CheckpointEngineBase): + + def __init__(self, model_name_or_path: str, auth_token: str = None) -> None: + super().__init__() + from transformers import AutoConfig, GenerationConfig + + self.model_name_or_path = model_name_or_path + self.auth_token = auth_token + self.model_config = AutoConfig.from_pretrained(self.model_name_or_path) + self.generation_config = GenerationConfig.from_pretrained(self.model_name_or_path) + # Define this property here so we can use it in the model implementation + if not hasattr(self.model_config, "max_seq_length"): + self.model_config.max_seq_length = self.model_config.max_position_embeddings + else: + self.model_config.max_seq_length = self.generation_config.max_length + + self._all_ckpt_paths = self._fetch_checkpoint_files() + + def _fetch_checkpoint_files(self): + """ + Fetch the checkpoint files from the HuggingFace Hub. + """ + # TODO(jeff): for models like llama-2 the user will have to provide an auth `token`, + # currently coming from the ckpt engine init but maybe a catch all kwargs for other + # snapshot download parameters would be more flexible. 
+ + # NOTE(jeff): allow_patterns here are explicitly not using safetensors or other + # checkpoint files that may be present. Example of all files in the llama-2-7b + # repo here: https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main + from huggingface_hub import snapshot_download + + if os.path.isdir(self.model_name_or_path): + self._local_checkpoint_dir = self.model_name_or_path + else: + self._local_checkpoint_dir = snapshot_download(self.model_name_or_path, + allow_patterns=[ + "*.bin", + "*.json", + "*.pt", + ], + revision=None, + token=self.auth_token) + + assert os.path.isdir( + self._local_checkpoint_dir + ), f"Checkpoint dir {self._local_checkpoint_dir} is not a directory, cannot load checkpoint." + + model_param_json = os.path.join(self._local_checkpoint_dir, "pytorch_model.bin.index.json") + + if not os.path.isfile(model_param_json): + # We don't need any json as all such HF models will have pytorch_model.bin + all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, 'pytorch_model.bin')] + else: + param_map = json.load(open(model_param_json, "r")) + + # weight_map -> { "lm_head.weight": "pytorch_model-00002-of-00002.bin", ... } + weight_map = param_map["weight_map"] + + # unique set of all checkpoint files + all_checkpoint_files = set(weight_map.values()) + + # get absolute path of all unique checkpoint files + all_checkpoint_files = [os.path.join(self._local_checkpoint_dir, f) for f in all_checkpoint_files] + + return all_checkpoint_files + + def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]: + """ + Generator of model parameters (satisfies the CheckpointEngineBase interface). + """ + for checkpoint in self._all_ckpt_paths: + inference_logger().info(f"Loading checkpoint: {checkpoint}") + checkpoint_sd = torch.load(checkpoint, map_location='cpu') + param_keys = list(checkpoint_sd.keys()) + for param_name in param_keys: + param = checkpoint_sd[param_name] + yield param_name, param + + +if __name__ == "__main__": + # To test, add your auth_token here and run `python huggingface_engine.py` + engine = HuggingFaceCheckpointEngine(model_name_or_path="meta-llama/Llama-2-7b-hf", + auth_token="hf_xxxxxxxxxxxxxxxxx") + for name, param in engine.parameters(): + print(name, param.shape) diff --git a/deepspeed/inference/v2/checkpoint/in_memory_engine.py b/deepspeed/inference/v2/checkpoint/in_memory_engine.py new file mode 100644 index 000000000000..13ec7b288f5f --- /dev/null +++ b/deepspeed/inference/v2/checkpoint/in_memory_engine.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Tuple +import torch + +from .base_engine import CheckpointEngineBase + + +class InMemoryModelEngine(CheckpointEngineBase): + """ + This "checkpoint" engine uses the existing interface to enable loading parameters into an + inference model from a model already instantiated in memory. In general, this is not the + recommended way to use the inference engine, and should only be used when absolutely necessary. + + The primary limitation of this approach is that the model must be fully instantiated in memory. + In a tensor parallel scenario, this means that the model is either replicated many times in host + memory. Currently, it is also recommended to only use this approach for models held in host memory. + + In order to free the memory held by this copy of the model, we delete the model in the first call + to `parameters`, so it is not safe to make this call twice. 
+ """ + + def __init__(self, model: torch.nn.Module) -> None: + """ + Create virtual checkpoint engine for the provided module. + + Args: + model (torch.nn.Module): Model to load parameters from. + """ + super().__init__() + self.model = model + + def parameters(self) -> Iterable[Tuple[str, torch.Tensor]]: + for name, parameter in self.model.named_parameters(): + yield name, parameter + + del self.model diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py new file mode 100644 index 000000000000..64e7e29b1844 --- /dev/null +++ b/deepspeed/inference/v2/config_v2.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.pydantic_v1 import Field + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from .ragged import DSStateManagerConfig + + +class DeepSpeedTPConfig(DeepSpeedConfigModel): + """ Configure tensor parallelism settings """ + + tp_size: int = 1 + """ Number of devices to split the model across using tensor parallelism. """ + + +class RaggedInferenceEngineConfig(DeepSpeedConfigModel): + """ Sets parameters for DeepSpeed Inference Engine. """ + + tensor_parallel: DeepSpeedTPConfig = Field({}, alias="tp") + """ + Configuration for tensor parallelism used to split the model across several + GPUs. Expects a dictionary containing values for :any:`DeepSpeedTPConfig`. + """ + + state_manager: DSStateManagerConfig = Field({}, alias="manager") + """ + Configuration for managing persistent state + """ diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py new file mode 100644 index 000000000000..48274d6c3d53 --- /dev/null +++ b/deepspeed/inference/v2/engine_factory.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import logging +from typing import Any + +from .engine_v2 import InferenceEngineV2 +from .config_v2 import RaggedInferenceEngineConfig +from .checkpoint import HuggingFaceCheckpointEngine +from .logging import inference_logger + + +def build_hf_engine(path: str, + engine_config: RaggedInferenceEngineConfig, + debug_level: int = logging.INFO, + random_weights_config: Any = None, + fill_random: bool = False) -> InferenceEngineV2: + """ + Build an InferenceV2 engine for HuggingFace models. 
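A usage sketch of the config model above: the `tp` and `manager` aliases let the engine config be built from plain dictionaries. `tp_size` is defined above; `max_ragged_batch_size` is a state-manager field referenced later in `engine_v2.py`, and the value shown is arbitrary.

```python
# Sketch: building RaggedInferenceEngineConfig from plain dicts via its aliases.
from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig

config = RaggedInferenceEngineConfig(tp={"tp_size": 2},
                                     manager={"max_ragged_batch_size": 768})
assert config.tensor_parallel.tp_size == 2
```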
+ """ + # Set up logging + inference_logger(level=debug_level) + + # get HF checkpoint engine + checkpoint_engine = HuggingFaceCheckpointEngine(path) + + # get model config from HF AutoConfig + model_config = checkpoint_engine.model_config + + # get the policy + # TODO: generalize this to other models + if model_config.model_type == "opt": + from .model_implementations.opt.policy import OPTPolicy + policy = OPTPolicy(checkpoint_engine, model_config) + elif model_config.model_type == "llama": + from .model_implementations.llama_v2.llama_v2_policy import Llama2Policy + policy = Llama2Policy(checkpoint_engine, model_config) + elif model_config.model_type == "mistral": + from .model_implementations.mistral.policy import MistralPolicy + policy = MistralPolicy(checkpoint_engine, model_config) + else: + raise ValueError(f"Unsupported model type {model_config.model_type}") + + return InferenceEngineV2(policy, engine_config) diff --git a/deepspeed/inference/v2/engine_v2.py b/deepspeed/inference/v2/engine_v2.py new file mode 100644 index 000000000000..c670ddd09417 --- /dev/null +++ b/deepspeed/inference/v2/engine_v2.py @@ -0,0 +1,217 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +from typing import Iterable, Tuple + +import torch + +import deepspeed.comm as dist + +from deepspeed.accelerator import get_accelerator +from deepspeed.comm.comm import init_distributed + +from .model_implementations import InferenceV2Policy +from .logging import inference_logger +from .ragged import DSStateManager, RaggedBatchWrapper, PlaceholderSequenceDescriptor +from .scheduling_utils import SchedulingError, SchedulingResult + +from .config_v2 import RaggedInferenceEngineConfig + +INFERENCE_MODEL_TIMER = "model-forward-inference" + + +class InferenceEngineV2: + + _config: RaggedInferenceEngineConfig + """ + Configuration of the inference engine. + """ + + #_model: DSInferenceModelBase + """ + Inference model supporting ragged inference. + """ + + _state_manager: DSStateManager + """ + Persistent state manager for sequences and KV-cache. + """ + + @property + def free_blocks(self) -> int: + """ + Number of free KV blocks. + """ + return self._state_manager.free_blocks + + def __init__(self, policy: InferenceV2Policy, engine_config: RaggedInferenceEngineConfig) -> None: + """ + Create the Inference V2 engine. + + Arguments: + policy (InferenceV2Policy): Policy for the model implementation. This policy object + will be used to build the model and load the checkpoint associated with it. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + """ + self._config = engine_config + self._policy = policy + self._base_mp_group = self._initialize_tp_group() + + # Build model from policy + inference_logger().info("Building model...") + self._model = self._policy.build_model(self._config, self._base_mp_group) + inference_logger().info("Model built.") + + # Create state manager + self._batch = RaggedBatchWrapper(self._config.state_manager) + self._state_manager = DSStateManager(self._config.state_manager, + self._model.kv_cache_config(), + base_mp_group=self._base_mp_group) + self._model.set_state_manager(self._state_manager) + + def _initialize_tp_group(self): + """ + Implementation of our TP group initialization. 
+ """ + init_distributed() + local_rank = int(os.getenv("LOCAL_RANK", 0)) + get_accelerator().set_device(local_rank) + + if local_rank >= self._config.tensor_parallel.tp_size: + raise RuntimeError("Local rank is greater than TP size, ensure that the TP config is correct.") + + ranks = list(range(self._config.tensor_parallel.tp_size)) + return dist.new_group(ranks=ranks) + + def put(self, batch_uids: Iterable[int], batch_tokens: Iterable[torch.Tensor]) -> torch.Tensor: + """ + Put a ragged batch onto the inference engine. This will perform one forward and return + a Tensor of the shape [len(batch_uids), *output_shape]. Logits for the non-final tokens + are not calculated. + + Arguments: + batch_uids: Iterable of uids for the batch on the host + batch_tokens: Iterable of token tensors for the batch on the host + """ + + token_lens = [len(tokens) for tokens in batch_tokens] + schedule_check = self.can_schedule(batch_uids, token_lens) + if schedule_check != SchedulingResult.Success: + raise SchedulingError(schedule_check) + + self._batch.clear() + for uid, tokens in zip(batch_uids, batch_tokens): + + host_seq_desc = self._state_manager.get_or_create_sequence(uid) + self._model.maybe_allocate_kv(host_seq_desc, tokens.numel()) + host_seq_desc.pre_forward(tokens.numel()) + + # We can disable checks since we already validated schedulability. + self._batch.insert_sequence(host_seq_desc, tokens, do_checks=False) + + # Send all metadata to the device + self._batch.finalize() + + # Prep all data structures for the actual forward (in anticipation of CG in the future) + # and also to amortize some of the costs in a more straightforward way. + self._model.prepare_batch(self._batch) + + # Model implementation will pick up in the forward. + logits = self._model.forward(self._batch) + + # We return one set of logits per sequence in the batch (saves cost on unembedding) + assert logits.shape[0] == self._batch.current_sequences + + for uid in batch_uids: + host_seq_desc = self._state_manager.get_sequence(uid) + host_seq_desc.post_forward() # Updates sequence metadata. + self._model.maybe_free_kv(host_seq_desc) + + return logits + + def query(self, uid: int, max_request_tokens: int, max_request_blocks) -> Tuple[int, int]: + """ + Determine the number of tokens and KV blocks to reserve for a given request. Given a UID + (this UID may not be recognized by the model yet), this will return the number of tokens + and blocks to reserve for the request. + + Arguments: + uid (int): The UID of the sequence (as tracked by the scheduling entity). If + this is a new sequence (with a UID unknown to the inference engine), then + an empty placeholder is created to pass to the occupancy logic. + n_tokens (int): The number of tokens to hypothetically send. + + Returns: + Tuple[int, Optional[int]]: Tuple of free kv blocks and the number of blocks + required to schedule the sequence. + """ + seq_desc = self._state_manager.get_sequence(uid) + if seq_desc is None: + if (self._state_manager.n_tracked_sequences == self._config.state_manager.max_tracked_sequences): + return (0, 0) + seq_desc = PlaceholderSequenceDescriptor() + + req_tokens, req_blocks = self._model.get_kv_requirements(seq_desc, max_request_tokens, max_request_blocks) + + return (req_tokens, req_blocks) + + def can_schedule(self, uids: Iterable[int], lengths: Iterable[int]) -> SchedulingResult: + """ + Dry run a batch to determine if it can be scheduled. Placeholder sequences will be + created for any UIDs that are unknown to the inference engine. 
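A sketch of a single ragged forward step through `put()` as described above. Greedy next-token selection is shown purely for illustration (the engine itself does not sample), and `engine` is assumed to come from `build_hf_engine`.

```python
# Illustrative decode step: one forward over a ragged batch, then a greedy
# next-token choice on the one-row-per-sequence logits returned by put().
import torch

uids = [0, 1]
prompts = [torch.tensor([101, 2009, 2003]), torch.tensor([101, 7592])]

logits = engine.put(uids, prompts)           # shape: [len(uids), vocab_size]
next_tokens = logits.argmax(dim=-1).cpu()    # greedy choice per sequence

# Subsequent steps feed a single new token per sequence back through put().
logits = engine.put(uids, [tok.view(1) for tok in next_tokens])
```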
+ + Arguments: + uids (Iterable[int]): Iterable of UIDs for the batch + lengths (Iterable[int]): Iterable of lengths for each sequence of the batch. This lengths + corresponds to the number of tokens to send in the hypothetical forward; history + tokens will be determined via UID lookup and future tokens are disregarded. + + Returns: + bool: True if the batch can be scheduled, False otherwise. + """ + + cur_seqs = self._state_manager.n_tracked_sequences + free_blocks = self._state_manager.free_blocks + req_blocks = 0 + batch_len = 0 + + if len(uids) > self._config.state_manager.max_ragged_sequence_count: + # Can only compose a batch from a limited number of sequences + return SchedulingResult.BatchSequenceLimitExceeded + + for uid, length in zip(uids, lengths): + seq_desc = self._state_manager.get_sequence(uid) + if seq_desc is None: + cur_seqs += 1 + seq_desc = PlaceholderSequenceDescriptor() + + sched_len, sched_blocks = self._model.get_kv_requirements(seq_desc, length, free_blocks) + + if sched_len != length: + # We ran out of KV cache + return SchedulingResult.KVCacheLimitExceeded + + batch_len += length + free_blocks -= sched_blocks + + if cur_seqs > self._config.state_manager.max_tracked_sequences: + # Would run out of tracking metadata + return SchedulingResult.EngineSequenceLimitExceeded + + if batch_len > self._config.state_manager.max_ragged_batch_size: + # Would exceed the maximum batch size + return SchedulingResult.BatchTokenLimitExceeded + + return SchedulingResult.Success + + def flush(self, uid: int) -> None: + """ + Remove all state associated with a sequence from the inference engine. + + Arguments: + uid (int): The UID of the sequence to flush. + """ + self._state_manager.flush_sequence(uid) diff --git a/deepspeed/inference/v2/inference_utils.py b/deepspeed/inference/v2/inference_utils.py new file mode 100644 index 000000000000..7b2dd4237353 --- /dev/null +++ b/deepspeed/inference/v2/inference_utils.py @@ -0,0 +1,105 @@ +# Copyright (c) Microsoft Corporation. 
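A hedged sketch of how a scheduler might wrap `can_schedule()`, `put()`, and `flush()` from the engine above; the "finished sequence" bookkeeping is an assumption made for illustration only.

```python
# Sketch of a scheduling step around InferenceEngineV2.
from deepspeed.inference.v2.scheduling_utils import SchedulingResult


def step(engine, uids, token_lists, finished_uids):
    # Dry-run the batch first; only launch the forward if it fits.
    lengths = [len(toks) for toks in token_lists]
    if engine.can_schedule(uids, lengths) != SchedulingResult.Success:
        return None

    logits = engine.put(uids, token_lists)

    # Release KV-cache and tracking state for sequences the caller marked done.
    for uid in finished_uids:
        engine.flush(uid)
    return logits
```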
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Dict + +import torch + +from enum import Enum, IntEnum + + +class NormTypeEnum(Enum): + LayerNorm: str = "layer_norm" + RMSNorm: str = "rms_norm" + + +class DtypeEnum(Enum): + # The torch dtype must always be the first value (so we return torch.dtype) + fp16 = torch.float16, "torch.float16", "fp16", "float16", "half" + fp32 = torch.float32, "torch.float32", "fp32", "float32", "float" + bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16", "bfloat" + int8 = torch.int8, "torch.int8", "int8" + + # Copied from https://stackoverflow.com/a/43210118 + # Allows us to use multiple values for each Enum index and returns first + # listed value when Enum is called + def __new__(cls, *values): + obj = object.__new__(cls) + # first value is canonical value + obj._value_ = values[0] + for other_value in values[1:]: + cls._value2member_map_[other_value] = obj + obj._all_values = values + return obj + + def __repr__(self): + return "<%s.%s: %s>" % ( + self.__class__.__name__, + self._name_, + ", ".join([repr(v) for v in self._all_values]), + ) + + +ELEM_SIZES: Dict[torch.dtype, int] = { + torch.float16: 2, + torch.bfloat16: 2, + torch.float32: 4, + torch.float64: 8, + torch.int8: 1, + torch.uint8: 1, + torch.int16: 2, + torch.int32: 4, + torch.int64: 8, + torch.bool: 1, +} + + +class ActivationType(IntEnum): + """ + Types of activations supported by DS-Inference + """ + + GELU = 0 + + RELU = 1 + + SILU = 2 + + GEGLU = 3 + + ReGLU = 4 + + SiGLU = 5 + + IDENTITY = 6 + + InvalidType = -1 + + +def is_gated(act_fn: ActivationType) -> bool: + """ + Return True if the given activation function is gated. + """ + if not isinstance(act_fn, ActivationType): + act_fn = ActivationType(act_fn) + + return act_fn in [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU] + + +def elem_size(dtype: torch.dtype) -> int: + """ + Return size in bytes of the given dtype. + """ + try: + return ELEM_SIZES[dtype] + except KeyError: + raise ValueError("Unknown dtype size for {}".format(dtype)) + + +def ceil_div(a: int, b: int) -> int: + """ + Return ceil(a / b). + """ + return -(-a // b) diff --git a/deepspeed/inference/v2/kernels/__init__.py b/deepspeed/inference/v2/kernels/__init__.py new file mode 100644 index 000000000000..01b7b0580073 --- /dev/null +++ b/deepspeed/inference/v2/kernels/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .ds_kernel import DSKernelBase diff --git a/deepspeed/inference/v2/kernels/core_ops/__init__.py b/deepspeed/inference/v2/kernels/core_ops/__init__.py new file mode 100644 index 000000000000..bbb53e5b58a2 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .bias_activations import * +from .blas_kernels import * +from .cuda_layer_norm import * +from .cuda_rms_norm import * +from .gated_activations import * diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/__init__.py b/deepspeed/inference/v2/kernels/core_ops/bias_activations/__init__.py new file mode 100644 index 000000000000..ea7f8a7d1996 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. 
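A few quick usage examples of the helpers defined above; the expected values follow directly from the definitions (multi-valued `DtypeEnum` lookup, byte sizes, gated-activation check, ceiling division).

```python
import torch

from deepspeed.inference.v2.inference_utils import (ActivationType, DtypeEnum,
                                                    ceil_div, elem_size, is_gated)

assert DtypeEnum("fp16") is DtypeEnum(torch.float16)      # any listed value maps to the member
assert elem_size(torch.bfloat16) == 2                      # bytes per element
assert is_gated(ActivationType.GEGLU) and not is_gated(ActivationType.GELU)
assert ceil_div(10, 3) == 4
```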
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .bias_activation import * diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp new file mode 100644 index 000000000000..4f0cc9cbd77c --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "bias_activation.h" +#include +#include "ds_kernel_utils.h" + +#ifdef BF16_AVAILABLE +#define DTYPE_SWITCH(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat16) { \ + using scalar_t = __half; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kBFloat16) { \ + using scalar_t = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \ + } \ + }() +#else +#define DTYPE_SWITCH(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat16) { \ + using scalar_t = __half; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \ + } \ + }() +#endif + +/* +In-place bias and activation fusion kernel. +*/ +void bias_activation(torch::Tensor& activation, + c10::optional& bias, + const int32_t act_type) +{ + const ActivationType atype = static_cast(act_type); + const int32_t rows = activation.size(0); + const int32_t cols = activation.size(1); + + TORCH_CHECK(atype == ActivationType::GELU || atype == ActivationType::RELU || + atype == ActivationType::SILU || atype == ActivationType::IDENTITY, + "Unsupported activation type for BiasActivation"); + TORCH_CHECK(activation.dim() == 2, "BiasActivation only supports 2D activation tensors"); + + DTYPE_SWITCH(activation.scalar_type(), [&] { + scalar_t* activation_ptr = reinterpret_cast(activation.data_ptr()); + + const scalar_t* bias_ptr; + if (bias.has_value()) { + TORCH_CHECK(activation.scalar_type() == bias.value().scalar_type(), + "BiasActivation activation and bias must have same dtype"); + bias_ptr = reinterpret_cast(bias.value().data_ptr()); + } else { + bias_ptr = nullptr; + } + + if (atype == ActivationType::IDENTITY && bias_ptr == nullptr) { return; } + + launch_bias_activation( + activation_ptr, bias_ptr, rows, cols, atype, c10::cuda::getCurrentCUDAStream()); + }); +} diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cu b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cu new file mode 100644 index 000000000000..66bca0c175c3 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.cu @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "activation_type.h" +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" + +// Default activation function will error out +template +DS_D_INLINE float act_fn(float val); + +template <> +DS_D_INLINE float act_fn(float val) +{ + return val; +} + +template <> +DS_D_INLINE float act_fn(float val) +{ + return val > 0.0f ? 
val : 0.0f; +} + +template <> +DS_D_INLINE float act_fn(float val) +{ + constexpr float sqrt_param = 0.79788456080286535587989211986876f; + constexpr float mul_param = 0.044715f; + return val * 0.5f * (1.0f + tanhf(sqrt_param * (val + mul_param * val * val * val))); +} + +template <> +DS_D_INLINE float act_fn(float val) +{ + return val / (1.0f + expf(-val)); +} + +namespace bias_act { + +constexpr int access_size = 16; +constexpr int threads = 512; +constexpr int unroll = 4; + +} // namespace bias_act + +template +__global__ void bias_activation_kernel(T* activation, + const T* bias, + const int32_t rows, + const int32_t cols) +{ + constexpr int vector_T = bias_act::access_size / sizeof(T); + + const int32_t thread_offset = threadIdx.x * vector_T; + const int32_t block_offset = blockIdx.x * vector_T * bias_act::unroll * bias_act::threads; + const int32_t base_offset = block_offset + thread_offset; + + const int32_t thread_stride = bias_act::threads * vector_T; + +#pragma unroll + for (int i = 0; i < bias_act::unroll; i++) { + const int32_t iter_offset = base_offset + i * thread_stride; + + const int32_t row = iter_offset / cols; + + T buffer[vector_T]; + T bias_buffer[vector_T]; + + if (row < rows) { + const int32_t col = iter_offset % cols; + + mem_access::load_global(buffer, activation + iter_offset); + mem_access::load_global( + bias_buffer, bias + col, bias != nullptr); + +#pragma unroll + for (int j = 0; j < vector_T; j++) { + float val = + conversion::to(buffer[j]) + conversion::to(bias_buffer[j]); + buffer[j] = conversion::to(act_fn(val)); + } + + mem_access::store_global(activation + iter_offset, buffer); + } + } +} + +#define ACT_TYPE_SWITCH(ACT_TYPE, ...) \ + if (ACT_TYPE == ActivationType::IDENTITY) { \ + constexpr ActivationType act_fn_t = ActivationType::IDENTITY; \ + return __VA_ARGS__(); \ + } else if (ACT_TYPE == ActivationType::RELU) { \ + constexpr ActivationType act_fn_t = ActivationType::RELU; \ + return __VA_ARGS__(); \ + } else if (ACT_TYPE == ActivationType::GELU) { \ + constexpr ActivationType act_fn_t = ActivationType::GELU; \ + return __VA_ARGS__(); \ + } else if (ACT_TYPE == ActivationType::SILU) { \ + constexpr ActivationType act_fn_t = ActivationType::SILU; \ + return __VA_ARGS__(); \ + } else { \ + assert(false); \ + } + +template +void launch_bias_activation(T* activation, + const T* bias, + const int32_t n_rows, + const int32_t n_cols, + const ActivationType activation_type, + cudaStream_t stream) +{ + constexpr int32_t elems_per_block = + bias_act::threads * bias_act::unroll * bias_act::access_size / sizeof(T); + const int32_t total_elems = n_rows * n_cols; + + const int32_t blocks = (total_elems + elems_per_block - 1) / elems_per_block; + + const dim3 grid(blocks); + const dim3 block(bias_act::threads); + + ACT_TYPE_SWITCH(activation_type, [&] { + bias_activation_kernel + <<>>(activation, bias, n_rows, n_cols); + }); +} + +#define INSTANTIATE_FOR_T(T) \ + template void launch_bias_activation( \ + T*, const T*, const int32_t, const int32_t, const ActivationType, cudaStream_t); + +INSTANTIATE_FOR_T(__half); + +#ifdef BF16_AVAILABLE +INSTANTIATE_FOR_T(__nv_bfloat16); +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.h b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.h new file mode 100644 index 000000000000..db6174633a09 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.h @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. 
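For reference, the launch above assigns `threads * unroll * access_size / sizeof(T)` elements to each thread block; a small sketch of that arithmetic for `__half` inputs (the token/channel counts are illustrative):

```python
# Block-sizing arithmetic from launch_bias_activation, reproduced for clarity.
threads, unroll, access_size = 512, 4, 16      # constants from the bias_act namespace
sizeof_half = 2
elems_per_block = threads * unroll * access_size // sizeof_half   # 16384 fp16 elements

total_elems = 256 * 4096                       # e.g. 256 tokens x 4096 channels
blocks = (total_elems + elems_per_block - 1) // elems_per_block    # 64 thread blocks
```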
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "activation_type.h" + +template +void launch_bias_activation(T* activation, + const T* bias, + const int32_t n_rows, + const int32_t n_cols, + const ActivationType activation_type, + cudaStream_t stream); + +void bias_activation(torch::Tensor& activation, + c10::optional& bias, + const int32_t activation_type); diff --git a/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.py b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.py new file mode 100644 index 000000000000..436d7f8805d5 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/bias_activations/bias_activation.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from ....inference_utils import ActivationType, DtypeEnum +from deepspeed.ops.op_builder import InferenceCoreBuilder +from ... import DSKernelBase + + +class CUDABiasActivation(DSKernelBase): + """ + CUDA implementation of bias activation kernel. This kernel should be deprecated once + we are fusing the bias activation into the linear kernel in all scenarios. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_act_fns = [ActivationType.IDENTITY, ActivationType.GELU, ActivationType.RELU, ActivationType.SILU] + + def __init__(self, channels: int, dtype: DtypeEnum, act_fn: ActivationType) -> None: + """ + Compile and validate for the fused bias-activation kernel. + + Parameters: + channels (int): Number of channels to expect in the activation. + dtype (torch.dtype): Data type for the input/output. Supported values + are DtypeEnum.fp16 and DtypeEnum.bf16. + act_fn (ActivationType): Activation function to use. Only IDENTITY, GELU, RELU, and SILU are supported. + """ + + if channels % 8 != 0: + raise ValueError("channels must be divisible by 8") + + if DtypeEnum(dtype) not in CUDABiasActivation.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + dtype, CUDABiasActivation.supported_dtypes)) + + act_fn = ActivationType(act_fn) + if act_fn not in CUDABiasActivation.supported_act_fns: + raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format( + act_fn, CUDABiasActivation.supported_act_fns)) + + inf_module = InferenceCoreBuilder().load() + self.kernel = inf_module.bias_activation + self.act_fn = act_fn + + def __call__(self, activation: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Add an optional bias and perform the non-linear activation function. + + Parameters: + activation (torch.Tensor): Input tensor of shape [tokens, channels] + bias (torch.Tensor): Optional bias tensor of shape [channels] + + Returns: + activation that has been updated in-place + """ + self.kernel(activation, bias, self.act_fn.value) diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/__init__.py b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/__init__.py new file mode 100644 index 000000000000..4af5a579ca1b --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. 
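A usage sketch of the `CUDABiasActivation` wrapper above (tensor shapes are illustrative, and the import path assumes the kernel is re-exported through `kernels.core_ops` as the `__init__` files suggest). Note that the op mutates the activation in place.

```python
import torch

from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum
from deepspeed.inference.v2.kernels.core_ops import CUDABiasActivation

channels = 4096
act = CUDABiasActivation(channels, DtypeEnum.fp16, ActivationType.GELU)

hidden = torch.randn(128, channels, dtype=torch.float16, device="cuda")
bias = torch.randn(channels, dtype=torch.float16, device="cuda")
act(hidden, bias)   # hidden now holds gelu(hidden + bias), updated in place
```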
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .blas_linear import * diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas.h b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas.h new file mode 100644 index 000000000000..1854e40a227d --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas.h @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include +#include "blas_utils.h" + +#define DISPATCH_BLAS_MATMUL(T_TYPE, C_TYPE) \ + if (output.options().dtype() == torch::T_TYPE) { \ + blas_gemm_ex(output.data_ptr(), \ + (const void*)weights.data_ptr(), \ + (const void*)hidden_states.data_ptr(), \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc, \ + trans_a, \ + trans_b, \ + &alpha, \ + &beta, \ + C_TYPE); \ + } + +void blas_linear(at::Tensor& output, at::Tensor& hidden_states, at::Tensor& weights) +{ + /* + Expected shape: output([total_tokens_across_dims], out_neurons) + hidden_states([total_tokens_across_dims], in_neurons) + weights(out_neurons, in_neurons) + + We are going to assume contiguous for the above shapes. + + The shapes are going to get messed with a little internally to handle column-major + GEMMs. + */ + + // Number of tokens is N (since the GEMM output is column-major but our Tensor + // is row-major, we need to transpose the shapes) + const int n = output.numel() / output.size(-1); + const int k = weights.size(1); + const int m = weights.size(0); + + // A strides + const bool trans_a = weights.stride(1) == 1; + const int lda = (trans_a) ? weights.stride(0) : weights.stride(1); + + // B strides + const bool trans_b = hidden_states.stride(-1) != 1; + const int ldb = (trans_b) ? hidden_states.stride(-1) : hidden_states.stride(-2); + + // C strides + const int ldc = output.stride(-2); + + const float alpha = 1.0f; + const float beta = 0.0f; + + TORCH_CHECK(output.scalar_type() == hidden_states.scalar_type(), + "Output and hidden states must have the same scalar type"); + TORCH_CHECK(output.scalar_type() == weights.scalar_type(), + "Output and weights must have the same scalar type"); + + // Dispatch the datatypes + DISPATCH_BLAS_MATMUL(kFloat, BlasType::FP32); + DISPATCH_BLAS_MATMUL(kHalf, BlasType::FP16); +#ifdef BF16_AVAILABLE + DISPATCH_BLAS_MATMUL(kBFloat16, BlasType::BF16); +#endif +} + +#define DISPATCH_4D_BLAS(T_TYPE, C_TYPE) \ + if (C.options().dtype() == torch::T_TYPE) { \ + blas_strided_batched_gemm(C.data_ptr(), \ + (const void*)A.data_ptr(), \ + (const void*)B.data_ptr(), \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc, \ + trans_a, \ + trans_b, \ + &alpha, \ + &beta, \ + stride_a, \ + stride_b, \ + stride_c, \ + batch, \ + C_TYPE); \ + } + +void blas_4d_matmul(at::Tensor& C, at::Tensor& B, at::Tensor& A) +{ + /* + C shape: (batch_size, N, M) + A shape: (batch_size, N, K) + B shape: (batch_size, K, M) + */ + + const int n = C.size(-2); + const int k = C.size(-1); + const int m = B.size(-1); + + // A strides + const bool trans_a = A.stride(-1) == 1; + const int lda = (trans_a) ? A.stride(-2) : A.stride(-1); + const int stride_a = A.stride(-3); + + // B strides + const bool trans_b = B.stride(-1) != 1; + const int ldb = (trans_b) ? 
B.stride(-1) : B.stride(-2); + const int stride_b = B.stride(-3); + + // C strides + const int ldc = C.stride(-2); + const int stride_c = C.stride(-3); + + const float alpha = 1.0f; + const float beta = 0.0f; + + const int batch = C.numel() / (n * m); + + // Dispatch the datatypes + DISPATCH_4D_BLAS(kFloat, BlasType::FP32); + DISPATCH_4D_BLAS(kHalf, BlasType::FP16); +#ifdef BF16_AVAILABLE + DISPATCH_4D_BLAS(kBFloat16, BlasType::BF16); +#endif +} + +void create_handle() { BlasContext::getInstance().get_handle(); } diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_linear.py b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_linear.py new file mode 100644 index 000000000000..9a151ce36dc4 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_linear.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ....inference_utils import DtypeEnum +from deepspeed.ops.op_builder import InferenceCoreBuilder +from ... import DSKernelBase + + +class BlasLibLinear(DSKernelBase): + """ + Wrapper around the BLAS matmul kernel for FP16/BF16/FP32 for CUDA/RoCM. + + Performs z = x @ y + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16, DtypeEnum.fp32] + + def __init__(self, fp_dtype: DtypeEnum): + """ + Parameters: + fp_dtype (torch.dtype): Data type for the input/output. Supported values + are torch.float16, torch.bfloat16, and torch.float32. + """ + fp_dtype = DtypeEnum(fp_dtype) + if fp_dtype not in BlasLibLinear.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, BlasLibLinear.supported_dtypes)) + + self.inf_module = InferenceCoreBuilder().load() + self.inf_module.create_handle() + self.kernel = self.inf_module.blas_linear + + def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + """ + Matmul kernel as implemented by platform BLAS library. The input must be 2D or larger. If + n-dimensional, the leading dimensions are folded into each other: + 2D: m = x.size(0) + 3D: m = x.size(0) * x.size(1) + 4D: m = x.size(0) * x.size(1) * x.size(2) (etc...) + All inputs should be contiguous. + + Parameters: + output (torch.Tensor): Output tensor. Shape is of [*, out_features] + hidden_states (torch.Tensor): Input tensor. Shape is of [*, in_features] + weights (torch.Tensor): Input tensor. Shape is of [out_features, in_features] + + Returns: + z (torch.Tensor): Output tensor. Shape is of [m, n] + """ + self.kernel(output, hidden_states, weights) + return output diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h new file mode 100644 index 000000000000..450991b3c387 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include +#ifdef BF16_AVAILABLE +#include +#endif +#include +#include +#ifndef __HIP_PLATFORM_HCC__ +#include +#endif +#include +#include +#include + +class BlasContext { + /* + Slim wrapper for managing the lifetime of the platform's BLAS handle. This should + be hipified for ROCm. 
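A usage sketch of `BlasLibLinear` above: the caller allocates the output buffer, leading dimensions are folded into the token dimension, and the computation is equivalent to `hidden @ weights.T`. Shapes are illustrative and the import path assumes the re-export through `kernels.core_ops`.

```python
import torch

from deepspeed.inference.v2.inference_utils import DtypeEnum
from deepspeed.inference.v2.kernels.core_ops import BlasLibLinear

linear = BlasLibLinear(DtypeEnum.fp16)

tokens, in_features, out_features = 512, 4096, 11008
hidden = torch.randn(tokens, in_features, dtype=torch.float16, device="cuda")
weights = torch.randn(out_features, in_features, dtype=torch.float16, device="cuda")
output = torch.empty(tokens, out_features, dtype=torch.float16, device="cuda")

linear(output, hidden, weights)   # output = hidden @ weights.T via the vendor BLAS
```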
+ */ +public: + BlasContext() + { + if (cublasCreate(&_handle) != CUBLAS_STATUS_SUCCESS) { + auto message = std::string("Fail to create cublas handle."); + std::cerr << message << std::endl; + throw std::runtime_error(message); + } +#ifndef __HIP_PLATFORM_HCC__ + cublasSetMathMode(_handle, CUBLAS_TENSOR_OP_MATH); +#endif + } + + virtual ~BlasContext() { cublasDestroy(_handle); } + + static BlasContext& getInstance() + { + // Should always access the singleton through this function. + static BlasContext _instance; + return _instance; + } + + cublasHandle_t get_handle() const { return _handle; } + +private: + cublasHandle_t _handle; +}; + +enum class BlasType { FP32, FP16, BF16 }; + +#ifdef __HIP_PLATFORM_HCC__ +rocblas_operation get_trans_op(bool do_trans) +{ + return (do_trans) ? rocblas_operation_transpose : rocblas_operation_none; +} + +rocblas_datatype get_datatype(BlasType type) +{ + switch (type) { + case BlasType::FP32: return rocblas_datatype_f32_r; + case BlasType::FP16: return rocblas_datatype_f16_r; + case BlasType::BF16: return rocblas_datatype_bf16_r; + default: throw std::runtime_error("Unsupported BlasType"); + } +} +#else +cublasOperation_t get_trans_op(bool do_trans) { return (do_trans) ? CUBLAS_OP_T : CUBLAS_OP_N; } + +cublasDataType_t get_datatype(BlasType type) +{ + switch (type) { + case BlasType::FP32: return CUDA_R_32F; + case BlasType::FP16: return CUDA_R_16F; + case BlasType::BF16: return CUDA_R_16BF; + default: throw std::runtime_error("Unsupported BlasType"); + } +} +#endif + +int blas_gemm_ex(void* C, + const void* A, + const void* B, + int m, + int n, + int k, + int lda, + int ldb, + int ldc, + bool transa, + bool transb, + const float* alpha, + const float* beta, + BlasType type) +{ +#ifdef __HIP_PLATFORM_HCC__ + rocblas_operation_t transa_op = get_trans_op(transa); + rocblas_operation_t transb_op = get_trans_op(transb); + + rocblas_datatype_t abc_type = get_datatype(type); + + rocblas_status status = rocblas_gemm_ex(BlasContext::getInstance().get_handle(), + transa_op, + transb_op, + m, + n, + k, + (const void*)alpha, + A, + abc_type, + lda, + B, + abc_type, + ldb, + (const void*)beta, + C, + abc_type, + ldc, + C, + abc_type, + ldc, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, + 0, + 0); +#else + cublasOperation_t transa_op = get_trans_op(transa); + cublasOperation_t transb_op = get_trans_op(transb); + + cublasDataType_t abc_type = get_datatype(type); + cublasStatus_t status = cublasGemmEx(BlasContext::getInstance().get_handle(), + transa_op, + transb_op, + m, + n, + k, + (const void*)alpha, + A, + abc_type, + lda, + B, + abc_type, + ldb, + (const void*)beta, + C, + abc_type, + ldc, + CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + +#ifdef __HIP_PLATFORM_HCC__ + if (status != rocblas_status_success) { +#else + if (status != CUBLAS_STATUS_SUCCESS) { +#endif + fprintf(stderr, + "!!!! kernel execution error. 
(m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int blas_strided_batched_gemm(void* C, + const void* A, + const void* B, + int m, + int n, + int k, + int lda, + int ldb, + int ldc, + bool transa, + bool transb, + const float* alpha, + const float* beta, + int stride_A, + int stride_B, + int stride_C, + int batch, + BlasType type) +{ +#ifdef __HIP_PLATFORM_HCC__ + rocblas_operation_t transa_op = get_trans_op(transa); + rocblas_operation_t transb_op = get_trans_op(transb); + + rocblas_datatype_t abc_type = get_datatype(type); + + rocblas_status status = + rocblas_gemm_strided_batched_ex(BlasContext::getInstance()::get_handle(), + transa_op, + transb_op, + m, + n, + k, + (const void*)alpha, + A, + abc_type, + lda, + stride_A, + B, + abc_type, + ldb, + stride_B, + (const void*)beta, + C, + abc_type, + ldc, + stride_C, + C, + abc_type, + ldc, + stride_C, + batch, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, + 0, + 0); +#else + cublasOperation_t transa_op = get_trans_op(transa); + cublasOperation_t transb_op = get_trans_op(transb); + + cublasDataType_t abc_type = get_datatype(type); + + cublasStatus_t status = cublasGemmStridedBatchedEx(BlasContext::getInstance().get_handle(), + transa_op, + transb_op, + m, + n, + k, + (const void*)alpha, + A, + abc_type, + lda, + stride_A, + B, + abc_type, + ldb, + stride_B, + (const void*)beta, + C, + abc_type, + ldc, + stride_C, + batch, + CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + +#ifdef __HIP_PLATFORM_HCC__ + if (status != rocblas_status_success) { +#else + if (status != CUBLAS_STATUS_SUCCESS) { +#endif + fprintf(stderr, + "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n", + batch, + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} diff --git a/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp b/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp new file mode 100644 index 000000000000..58df88e56136 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include + +#include "bias_activation.h" +#include "blas.h" +#include "gated_activation_kernels.h" +#include "layer_norm.h" +#include "rms_norm.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + // bias_activation.h + m.def("bias_activation", &bias_activation, "DeepSpeed bias activation in CUDA"); + + // layer_norm.h + m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm in CUDA"); + m.def("pre_layer_norm", &ds_pre_layer_norm, "DeepSpeed pre layer norm in CUDA"); + m.def("post_layer_norm", &ds_post_layer_norm, "DeepSpeed pre layer norm in CUDA"); + + // blas.h + m.def("blas_linear", &blas_linear, "Linear implemented by vendor BLAS"); + m.def("blas_4d_matmul", &blas_4d_matmul, "4D matmul implemented by vendor BLAS"); + m.def("create_handle", &create_handle, "Create a handle for vendor BLAS"); + + // gated_activation_kernels.h + m.def("gated_activation", &ds_gated_activation, "DeepSpeed gated activation in CUDA"); + + // rms_norm.h + m.def("rms_norm", &rms_norm, "DeepSpeed rms norm in CUDA"); + m.def("rms_pre_norm", &rms_pre_norm, "DeepSpeed rms pre norm in CUDA"); +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/__init__.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/__init__.py new file mode 100644 index 000000000000..bed7688b15d2 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .cuda_ln import * +from .cuda_post_ln import * +from .cuda_pre_ln import * diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_fp_ln_base.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_fp_ln_base.py new file mode 100644 index 000000000000..3c2aa5cb5eb4 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_fp_ln_base.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... import DSKernelBase +from ....inference_utils import elem_size +from deepspeed.ops.op_builder import InferenceCoreBuilder + + +class CUDAFPLNBase(DSKernelBase): + """ + Base class for CUDA LN kernels. They all same the same validation logic, + so we can share it here. + """ + + supported_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def __init__(self, channels: int, fp_dtype: torch.dtype, epsilon: float = 1e-5): + """ + Parameters: + channels (int): Number of channels in the input tensor. Must be divisible to align + to 16 bytes. + fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values + are torch.float16, torch.bfloat16, and torch.float32. + """ + if fp_dtype not in CUDAFPLNBase.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, CUDAFPLNBase.supported_dtypes)) + + if elem_size(fp_dtype) * channels % 16 != 0: + raise ValueError("channels must be divisible by 16 bytes") + + self.inf_module = InferenceCoreBuilder().load() + self.epsilon = epsilon diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_ln.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_ln.py new file mode 100644 index 000000000000..583736fb8bbc --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_ln.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. 
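The constructor check in `CUDAFPLNBase` above requires the row size in bytes to be 16-byte aligned; a quick sketch of what that means per dtype:

```python
# 16-byte alignment check from CUDAFPLNBase.__init__, spelled out per dtype.
import torch

from deepspeed.inference.v2.inference_utils import elem_size

for dtype, min_channels in [(torch.float16, 8), (torch.bfloat16, 8), (torch.float32, 4)]:
    assert elem_size(dtype) * min_channels % 16 == 0
# i.e. channels must be a multiple of 8 for fp16/bf16 and a multiple of 4 for fp32.
```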
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from .cuda_fp_ln_base import CUDAFPLNBase + + +class CUDAFPLN(CUDAFPLNBase): + """ + Floating point layer norm kernel for CUDA/RoCM. + + Performs: z = ln(x) + """ + + def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor) -> torch.Tensor: + """ + output_z may alias input_x directly. All Tensors should have the same shape. + + Parameters: + output_z (torch.Tensor): Output tensor. + input_x (torch.Tensor): Input tensor. + gamma (torch.Tensor): Gamma tensor. + beta (torch.Tensor): Beta tensor. + """ + self.inf_module.layer_norm(output_z, input_x, gamma, beta, self.epsilon) + return output_z diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_post_ln.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_post_ln.py new file mode 100644 index 000000000000..0ced1ecf207e --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_post_ln.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from .cuda_fp_ln_base import CUDAFPLNBase + + +class CUDAFPPostLN(CUDAFPLNBase): + """ + Floating point post-LayerNorm kernel for CUDA/RoCM. + + Performs: z = ln(x + y) + """ + + def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, input_y: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor) -> torch.Tensor: + """ + Either input_x or input_y can alias output_z. + + Parameters: + output_z (torch.Tensor): Output tensor. + input_x (torch.Tensor): Input tensor. + input_y (torch.Tensor): Input tensor. + gamma (torch.Tensor): Gamma tensor. + beta (torch.Tensor): Beta tensor. + + Returns: + output (torch.Tensor): Output tensor. + """ + self.inf_module.post_layer_norm(output_z, input_x, input_y, gamma, beta, self.epsilon) + return output_z diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_pre_ln.py b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_pre_ln.py new file mode 100644 index 000000000000..74b2d9cf5880 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/cuda_pre_ln.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import torch + +from .cuda_fp_ln_base import CUDAFPLNBase + + +class CUDAFPPreLN(CUDAFPLNBase): + """ + Floating point pre-LayerNorm kernel for CUDA/RoCM. + + Performs: z_res = x_res + y_hid + z_hid = ln(z_hid) + """ + + def __call__(self, z_res: torch.Tensor, z_hid: torch.Tensor, x_res: torch.Tensor, y_hid: torch.Tensor, + gamma: torch.Tensor, beta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + z_res can alias x_res. All non-parameter input/output tensors + must have the same shape. z_hid can alias y_hid. + + Parameters: + z_res (torch.Tensor): Output residual. + z_hid (torch.Tensor): Output hidden states. + x_res (torch.Tensor): Input residual. + y_hid (torch.Tensor): Input hidden states. + gamma (torch.Tensor): Gamma tensor. + beta (torch.Tensor): Beta tensor. + + Returns: + output (torch.Tensor): Output tensor. 
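A usage sketch of the layer-norm wrappers above (shapes illustrative; the import path assumes the re-export through `kernels.core_ops`). As noted in the docstrings, the output may alias an input, which the post-LN call below exploits.

```python
import torch

from deepspeed.inference.v2.kernels.core_ops import CUDAFPLN, CUDAFPPostLN

channels = 4096
ln = CUDAFPLN(channels, torch.float16)
post_ln = CUDAFPPostLN(channels, torch.float16)

x = torch.randn(64, channels, dtype=torch.float16, device="cuda")
y = torch.randn_like(x)
gamma = torch.ones(channels, dtype=torch.float16, device="cuda")
beta = torch.zeros(channels, dtype=torch.float16, device="cuda")

z = ln(torch.empty_like(x), x, gamma, beta)   # z = ln(x)
post_ln(x, x, y, gamma, beta)                 # in place: x = ln(x + y)
```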
+ """ + self.inf_module.pre_layer_norm(z_res, z_hid, x_res, y_hid, gamma, beta, self.epsilon) + return z_res, z_hid diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cpp new file mode 100644 index 000000000000..b2c95d410a1f --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cpp @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "layer_norm.h" + +#define DISPATCH_LAYER_NORM(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_ln((C_TYPE*)output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + at::cuda::getCurrentCUDAStream()); \ + } + +void ds_layer_norm(at::Tensor& output, + at::Tensor& input, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon) +{ + bool ragged_input = input.dim() == 2; + + const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1); + const int elems_per_row = ragged_input ? input.size(1) : input.size(2); + + DISPATCH_LAYER_NORM(kFloat, float); + DISPATCH_LAYER_NORM(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_LAYER_NORM(kBFloat16, __nv_bfloat16); +#endif +} + +#define DISPATCH_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_post_ln((C_TYPE*)output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)residual.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + at::cuda::getCurrentCUDAStream()); \ + } + +void ds_post_layer_norm(at::Tensor& output, + at::Tensor& input, + at::Tensor& residual, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon) +{ + bool ragged_input = input.dim() == 2; + + const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1); + const int elems_per_row = ragged_input ? input.size(1) : input.size(2); + + DISPATCH_LAYER_NORM_RESIDUAL(kFloat, float); + DISPATCH_LAYER_NORM_RESIDUAL(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16); +#endif +} + +#define DISPATCH_PRE_LAYER_NORM_RESIDUAL(T_TYPE, C_TYPE) \ + if (input.options().dtype() == torch::T_TYPE) { \ + launch_fused_pre_ln((C_TYPE*)norm_output.data_ptr(), \ + (C_TYPE*)res_output.data_ptr(), \ + (const C_TYPE*)input.data_ptr(), \ + (const C_TYPE*)residual.data_ptr(), \ + (const C_TYPE*)gamma.data_ptr(), \ + (const C_TYPE*)beta.data_ptr(), \ + epsilon, \ + rows, \ + elems_per_row, \ + at::cuda::getCurrentCUDAStream()); \ + } + +void ds_pre_layer_norm(at::Tensor& res_output, + at::Tensor& norm_output, + at::Tensor& input, + at::Tensor& residual, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon) +{ + bool ragged_input = input.dim() == 2; + + const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1); + const int elems_per_row = ragged_input ? 
input.size(1) : input.size(2); + + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kFloat, float); + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kHalf, __half); +#ifdef BF16_AVAILABLE + DISPATCH_PRE_LAYER_NORM_RESIDUAL(kBFloat16, __nv_bfloat16); +#endif +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cu new file mode 100644 index 000000000000..15f52c46622b --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cu @@ -0,0 +1,490 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" +#include "reduction_utils.h" + +namespace cg = cooperative_groups; +using rop = reduce::ROpType; + +namespace ln { +constexpr int granularity = 16; +} // namespace ln + +/* +Regular layer norm implementation. Assumes elems_per_row % 8 +is equal to 0. + +Args: + output: buffer for output data + vals: buffer for input data + gamma: gain for normalization + beta: bias for normalization + epsilon: numeric stability + elems_per_row: number of elements each block will normalize +*/ +template +__global__ void fused_ln(T* output, + const T* vals, + const T* gamma, + const T* beta, + float epsilon, + int elems_per_row) +{ + constexpr int T_per_load = ln::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = blockDim.x * T_per_load; + + float sum = reduce::init(); + + const T* input_base = vals + base_offset; + + T local_buffer[unRoll * T_per_load]; + +#pragma unRoll + for (int i = 0; i < unRoll; i++) { + T* iteration_buffer = local_buffer + i * T_per_load; + + mem_access::load_global( + iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row); + +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + float vals_up_cast = conversion::to(iteration_buffer[j]); + sum = reduce::element(sum, vals_up_cast); + } + } + + reduce::partitioned_block(tb, warp, sum); + const float mean = sum / elems_per_row; + + float mean_diff = reduce::init(); + +#pragma unRoll + for (int i = 0; i < unRoll; i++) { +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + // Using a 0 value here skews the variance, have to if-guard + if (thread_offset + i * stride < elems_per_row) { + float diff = (conversion::to(local_buffer[i * T_per_load + j]) - mean); + mean_diff = reduce::element(mean_diff, diff * diff); + } + } + } + + reduce::partitioned_block(tb, warp, mean_diff); + const float variance = mean_diff / elems_per_row; + const float denom = __frsqrt_rn(variance + epsilon); + + T* block_output = output + block_offset; + +#pragma unRoll + for (int i = 0; i < unRoll; i++) { + T* iteration_buffer = local_buffer + i * T_per_load; + const int iter_idx = i * stride + thread_offset; + const bool do_loads = iter_idx < elems_per_row; + + T gamma_local[T_per_load], beta_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + mem_access::load_global(beta_local, beta + iter_idx, do_loads); + +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + 
float val = conversion::to(iteration_buffer[j]); + val = (val - mean) * denom; + val = + val * conversion::to(gamma_local[j]) + conversion::to(beta_local[j]); + iteration_buffer[j] = conversion::to(val); + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +#define LAUNCH_FUSED_LN(unRollFactor, threadsPerGroup, maxThreads) \ + fused_ln \ + <<>>(output, vals, gamma, beta, epsilon, elems_per_row); + +template +void launch_fused_ln(T* output, + const T* vals, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream) +{ + // 8 for __half, 4 for float + constexpr int T_per_load = ln::granularity / sizeof(T); + + constexpr int maxThreads = 256; + + // For Flaoat, unRoll 4, for __half, unRoll 2 + constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; + + // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of + // warp-sized blocks rather than stepping up to 64/96 threads + const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = + is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1; + const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + dim3 block(threadsPerGroup, groups_per_block); + dim3 grid(groups_launch); + + const int elems_per_step = threadsPerGroup * h_per_step; + const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; + + if (is_subblock_schedule) { + // <=128 + if (threadsPerGroup == 1) { + LAUNCH_FUSED_LN(1, 1, maxThreads); + } else if (threadsPerGroup == 2) { + LAUNCH_FUSED_LN(1, 2, maxThreads); + } else if (threadsPerGroup == 4) { + LAUNCH_FUSED_LN(1, 4, maxThreads); + } else if (threadsPerGroup == 8) { + LAUNCH_FUSED_LN(1, 8, maxThreads); + } else if (threadsPerGroup == 16) { + LAUNCH_FUSED_LN(1, 16, maxThreads); + } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_FUSED_LN(1 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_FUSED_LN(2 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_FUSED_LN(3 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_FUSED_LN(4 * internal_unRoll, maxThreads, maxThreads); + } +} + +#define INSTANTIATE_FUSED_LN(T) \ + template void launch_fused_ln(T*, const T*, const T*, const T*, float, int, int, cudaStream_t); + +INSTANTIATE_FUSED_LN(__half); +#ifdef BF16_AVAILABLE +INSTANTIATE_FUSED_LN(__nv_bfloat16); +#endif +INSTANTIATE_FUSED_LN(float); + +/* +Fused resiual + bias + layer norm implementation. Assumes elems_per_row % 8 +is equal to 0. + +TODO(cmikeh2): Goal is to deprecate this implementation. The bias + residual +need to be fused into compute-bound producer operations. 
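A worked example of the scheduling heuristic in `launch_fused_ln` above, for `__half` rows of 4096 elements. Constants are taken from the code; the behavior of `next_pow2` (round up to the next power of two) is an assumption.

```python
# Sketch of launch_fused_ln's block-sizing logic for sizeof(T) == 2, elems_per_row == 4096.
def next_pow2(x):                                # assumed helper behavior
    return 1 << (x - 1).bit_length()

granularity, sizeof_T, max_threads = 16, 2, 256
T_per_load = granularity // sizeof_T             # 8 halves per 16-byte load
internal_unroll = 2                              # the sizeof(T) == 2 path
elems_per_row = 4096

h_per_step = T_per_load * internal_unroll        # 16 elements per thread per step
threads_per_group = min(next_pow2(-(-elems_per_row // h_per_step)), max_threads)   # 256
elems_per_step = threads_per_group * h_per_step  # 4096
external_unroll = -(-elems_per_row // elems_per_step)                              # 1

# One 256-thread group per row, total unroll = external_unroll * internal_unroll = 2.
assert threads_per_group * T_per_load * internal_unroll * external_unroll == elems_per_row
```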
+ +Args: + output: buffer for output data + res_output: output of residual addition + vals: buffer for input data + residual: residual data + bias: bias of of input data + gamma: gain for normalization + beta: bias for normalization + epsilon: numeric stability + elems_per_row: number of elements each block will normalize +Template arg: + StoreResidual: controls whether the residual calculation is stored + or not. When set to false, the input `res_output` is unused. +*/ +template +__global__ void fused_residual_ln(T* output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int elems_per_row) +{ + constexpr int T_per_load = ln::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = tb.size() * T_per_load; + + float sum = reduce::init(); + + const T* input_base = vals + base_offset; + const T* residual_base = residual + base_offset; + + T local_buffer[unRoll * T_per_load]; + + // Unlike a vanilla layernorm, since we're fusing the two adds as well + // an inner unRoll seems to be less valuable. If anything, a double unRoll + // makes the most sense if we find we are having performance issues. +#pragma unRoll + for (int i = 0; i < unRoll; i++) { + T* iteration_buffer = local_buffer + i * T_per_load; + T residual_buffer[T_per_load]; + T bias_buffer[T_per_load]; + + mem_access::load_global( + iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row); + mem_access::load_global(residual_buffer, + residual_base + i * stride, + thread_offset + i * stride < elems_per_row); + +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + float vals_up_cast = conversion::to(iteration_buffer[j]); + float res_up_cast = conversion::to(residual_buffer[j]); + vals_up_cast += res_up_cast; + sum = reduce::element(sum, vals_up_cast); + iteration_buffer[j] = conversion::to(vals_up_cast); + } + + if (preLnResidual && (thread_offset + i * stride < elems_per_row)) { + mem_access::store_global(res_output + base_offset + i * stride, + iteration_buffer); + } + } + + reduce::partitioned_block(tb, warp, sum); + const float mean = sum / elems_per_row; + + float mean_diff = reduce::init(); +#pragma unRoll + for (int i = 0; i < unRoll; i++) { +#pragma unRoll + for (int j = 0; j < T_per_load; j++) { + // Using a 0 value here skews the variance, have to if-guard + if (thread_offset + i * stride < elems_per_row) { + float diff = (conversion::to(local_buffer[i * T_per_load + j]) - mean); + mean_diff = reduce::element(mean_diff, diff * diff); + } + } + } + + reduce::partitioned_block(tb, warp, mean_diff); + const float variance = mean_diff / elems_per_row; + const float denom = __frsqrt_rn(variance + epsilon); + + T* block_output = output + block_offset; + +#pragma unRoll + for (int i = 0; i < unRoll; i++) { + T* iteration_buffer = local_buffer + i * T_per_load; + const int iter_idx = i * stride + thread_offset; + const bool do_loads = iter_idx < elems_per_row; + + T gamma_local[T_per_load], beta_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + mem_access::load_global(beta_local, beta + iter_idx, do_loads); + +#pragma 
unRoll + for (int j = 0; j < T_per_load; j++) { + float val = conversion::to(iteration_buffer[j]); + val = (val - mean) * denom; + val = + val * conversion::to(gamma_local[j]) + conversion::to(beta_local[j]); + iteration_buffer[j] = conversion::to(val); + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +// TODO(cmikeh2): There's a bunch of redundancy here that needs to be removed/simplified. +#define LAUNCH_FUSED_RES_LN(unRollFactor, threadsPerGroup, maxThreads) \ + fused_residual_ln \ + <<>>( \ + output, nullptr, vals, residual, gamma, beta, epsilon, elems_per_row); + +template +void launch_fused_post_ln(T* output, + const T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream) +{ + // 8 for __half, 4 for float + constexpr int T_per_load = ln::granularity / sizeof(T); + + constexpr int maxThreads = 256; + + // For Flaoat, unRoll 4, for __half, unRoll 2 + constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; + + // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of + // warp-sized blocks rather than stepping up to 64/96 threads + const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = + is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1; + const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + dim3 block(threadsPerGroup, groups_per_block); + dim3 grid(groups_launch); + + const int elems_per_step = threadsPerGroup * h_per_step; + const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; + + if (is_subblock_schedule) { + // <=128 + if (threadsPerGroup == 1) { + LAUNCH_FUSED_RES_LN(1, 1, maxThreads); + } else if (threadsPerGroup == 2) { + LAUNCH_FUSED_RES_LN(1, 2, maxThreads); + } else if (threadsPerGroup == 4) { + LAUNCH_FUSED_RES_LN(1, 4, maxThreads); + } else if (threadsPerGroup == 8) { + LAUNCH_FUSED_RES_LN(1, 8, maxThreads); + } else if (threadsPerGroup == 16) { + LAUNCH_FUSED_RES_LN(1, 16, maxThreads); + } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_FUSED_RES_LN(1 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_FUSED_RES_LN(2 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_FUSED_RES_LN(3 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_FUSED_RES_LN(4 * internal_unRoll, maxThreads, maxThreads); + } +} + +#define LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(unRollFactor, threadsPerGroup, maxThreads) \ + fused_residual_ln \ + <<>>( \ + norm_output, res_output, vals, residual, gamma, beta, epsilon, elems_per_row); + +template +void launch_fused_pre_ln(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream) +{ + // 8 for __half, 4 for float + 
constexpr int T_per_load = ln::granularity / sizeof(T); + + constexpr int maxThreads = 256; + + // For Flaoat, unRoll 4, for __half, unRoll 2 + constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll; + + // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of + // warp-sized blocks rather than stepping up to 64/96 threads + const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = + is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1; + const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + dim3 block(threadsPerGroup, groups_per_block); + dim3 grid(groups_launch); + + const int elems_per_step = threadsPerGroup * h_per_step; + const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; + + if (is_subblock_schedule) { + // <=128 + if (threadsPerGroup == 1) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 1, maxThreads); + } else if (threadsPerGroup == 2) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 2, maxThreads); + } else if (threadsPerGroup == 4) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 4, maxThreads); + } else if (threadsPerGroup == 8) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 8, maxThreads); + } else if (threadsPerGroup == 16) { + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 16, maxThreads); + } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(2 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(3 * internal_unRoll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(4 * internal_unRoll, maxThreads, maxThreads); + } +} + +#define INSTANTIATE_RES_LN(T) \ + template void launch_fused_post_ln( \ + T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t); + +#define INSTANTIATE_PRE_LN_RES(T) \ + template void launch_fused_pre_ln( \ + T*, T*, const T*, const T*, const T*, const T*, float, int, int, cudaStream_t); + +INSTANTIATE_RES_LN(__half); +INSTANTIATE_RES_LN(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_RES_LN(__nv_bfloat16); +#endif + +INSTANTIATE_PRE_LN_RES(__half); +INSTANTIATE_PRE_LN_RES(float); +#ifdef BF16_AVAILABLE +INSTANTIATE_PRE_LN_RES(__nv_bfloat16); +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.h b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.h new file mode 100644 index 000000000000..9ea3a8c42524 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.h @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "ds_kernel_utils.h" + +/* +Kernel launch methods for layer norm variants. 
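
Editor's note on the launchers above: launch_fused_post_ln and launch_fused_pre_ln (and the RMS-norm launcher later in this patch) share the same scheduling arithmetic. The Python sketch below is an illustration added for review, not part of the patch; it mirrors that arithmetic under the constants used in the source (16-byte vector accesses, a 256-thread cap, an internal unroll of 4 for fp32 and 2 for 16-bit types).

def next_pow2(n: int) -> int:
    p = 1
    while p < n:
        p *= 2
    return p

def ln_launch_config(elems_per_row: int, rows: int, elem_bytes: int = 2):
    t_per_load = 16 // elem_bytes                      # elements per 16-byte vector access
    internal_unroll = 4 if elem_bytes == 4 else 2      # fp32 unrolls 4x, 16-bit types 2x
    max_threads = 256

    is_subblock = elems_per_row <= 128
    h_per_step = t_per_load if is_subblock else t_per_load * internal_unroll

    one_step_threads = next_pow2((elems_per_row + h_per_step - 1) // h_per_step)
    threads_per_group = min(one_step_threads, max_threads)

    groups_per_block_max = (
        (max_threads + threads_per_group - 1) // threads_per_group if is_subblock else 1
    )
    groups_per_block = min(rows, groups_per_block_max)
    blocks_launched = (rows + groups_per_block - 1) // groups_per_block

    elems_per_step = threads_per_group * h_per_step
    external_unroll = (elems_per_row + elems_per_step - 1) // elems_per_step
    return threads_per_group, groups_per_block, blocks_launched, external_unroll

# A 4096-wide fp16 row uses a full 256-thread group with no external unroll;
# an 8192-wide row keeps 256 threads and doubles the unroll instead.
print(ln_launch_config(4096, 32))   # (256, 1, 32, 1)
print(ln_launch_config(8192, 32))   # (256, 1, 32, 2)
print(ln_launch_config(64, 32))     # sub-block schedule: (8, 32, 1, 1)
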
+*/ + +template +void launch_fused_ln(T* output, + const T* vals, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream); + +template +void launch_fused_post_ln(T* output, + const T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream); +template +void launch_fused_pre_ln(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream); + +void ds_layer_norm(at::Tensor& output, + at::Tensor& input, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon); + +void ds_post_layer_norm(at::Tensor& output, + at::Tensor& input, + at::Tensor& residual, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon); + +void ds_pre_layer_norm(at::Tensor& res_output, + at::Tensor& norm_output, + at::Tensor& input, + at::Tensor& residual, + at::Tensor& gamma, + at::Tensor& beta, + float epsilon); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/__init__.py b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/__init__.py new file mode 100644 index 000000000000..640a72307650 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .rms_norm import CUDARMSNorm +from .rms_pre_norm import CUDARMSPreNorm diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cpp new file mode 100644 index 000000000000..c67712df438a --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cpp @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "rms_norm.h" + +#ifdef BF16_AVAILABLE +#define DISPATCH_FOR_FLOAT(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat32) { \ + using scalar_t = float; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kFloat16) { \ + using scalar_t = __half; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kBFloat16) { \ + using scalar_t = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \ + } \ + }() +#else +#define DISPATCH_FOR_FLOAT(DTYPE, ...) 
\ + [&] { \ + if (DTYPE == torch::kFloat32) { \ + using scalar_t = float; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kFloat16) { \ + using scalar_t = __half; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \ + } \ + }() +#endif + +void rms_norm(torch::Tensor& norm_output, + torch::Tensor& norm_input, + torch::Tensor& gamma, + float epsilon) +{ + TORCH_CHECK(norm_output.scalar_type() == norm_input.scalar_type(), + "norm_output and norm_input should have the same data type"); + TORCH_CHECK(norm_output.scalar_type() == gamma.scalar_type(), + "norm_output and gamma should have the same data type"); + + const int32_t rows = norm_input.size(0); + const int32_t cols = norm_input.size(1); + + TORCH_CHECK(norm_output.size(0) == rows, + "norm_output and norm_input should have the same first dimension"); + TORCH_CHECK(norm_output.size(1) == cols, + "norm_output and norm_input should have the same second dimension"); + + DISPATCH_FOR_FLOAT(norm_output.scalar_type(), [&] { + scalar_t* norm_output_ptr = reinterpret_cast(norm_output.data_ptr()); + scalar_t* norm_input_ptr = reinterpret_cast(norm_input.data_ptr()); + scalar_t* gamma_ptr = reinterpret_cast(gamma.data_ptr()); + scalar_t* null_t = nullptr; + + launch_rms_norm(norm_output_ptr, + null_t, + norm_input_ptr, + null_t, + gamma_ptr, + epsilon, + rows, + cols, + at::cuda::getCurrentCUDAStream()); + }); +} + +void rms_pre_norm(torch::Tensor& norm_output, + torch::Tensor& residual_output, + torch::Tensor& norm_input, + torch::Tensor& residual_input, + torch::Tensor& gamma, + float epsilon) +{ + TORCH_CHECK(norm_output.scalar_type() == norm_input.scalar_type(), + "norm_output and norm_input should have the same data type"); + TORCH_CHECK(norm_output.scalar_type() == gamma.scalar_type(), + "norm_output and gamma should have the same data type"); + + const int32_t rows = norm_input.size(0); + const int32_t cols = norm_input.size(1); + + TORCH_CHECK(norm_output.size(0) == rows, + "norm_output and norm_input should have the same first dimension"); + TORCH_CHECK(norm_output.size(1) == cols, + "norm_output and norm_input should have the same second dimension"); + + TORCH_CHECK(residual_output.size(0) == rows, + "residual_output and norm_input should have the same first dimension"); + TORCH_CHECK(residual_output.size(1) == cols, + "residual_output and norm_input should have the same second dimension"); + + TORCH_CHECK(residual_input.size(0) == rows, + "residual_input and norm_input should have the same first dimension"); + TORCH_CHECK(residual_input.size(1) == cols, + "residual_input and norm_input should have the same second dimension"); + + DISPATCH_FOR_FLOAT(norm_output.scalar_type(), [&] { + scalar_t* norm_output_ptr = reinterpret_cast(norm_output.data_ptr()); + scalar_t* residual_output_ptr = reinterpret_cast(residual_output.data_ptr()); + const scalar_t* norm_input_ptr = reinterpret_cast(norm_input.data_ptr()); + const scalar_t* residual_input_ptr = + reinterpret_cast(residual_input.data_ptr()); + const scalar_t* gamma_ptr = reinterpret_cast(gamma.data_ptr()); + + launch_rms_norm(norm_output_ptr, + residual_output_ptr, + norm_input_ptr, + residual_input_ptr, + gamma_ptr, + epsilon, + rows, + cols, + at::cuda::getCurrentCUDAStream()); + }); +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cu new file mode 100644 index 000000000000..e69d3c36cc00 --- /dev/null +++ 
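
Editor's note: the rms_norm and rms_pre_norm bindings above can be sanity-checked against a plain PyTorch reference of the same math. The sketch below is an illustration for review, not part of the patch; it accumulates the variance in fp32, as the kernel does, and applies the scale in the storage dtype.

import torch

def rms_norm_ref(x: torch.Tensor, gamma: torch.Tensor, eps: float) -> torch.Tensor:
    # x * rsqrt(mean(x**2) + eps) * gamma
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    denom = torch.rsqrt(var + eps).to(x.dtype)
    return x * denom * gamma

def rms_pre_norm_ref(x: torch.Tensor, residual: torch.Tensor, gamma: torch.Tensor, eps: float):
    # the residual add happens first; the summed value is both returned and normalized
    summed = x + residual
    return summed, rms_norm_ref(summed, gamma, eps)
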
b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cu @@ -0,0 +1,262 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" +#include "reduction_utils.h" + +namespace cg = cooperative_groups; +using rop = reduce::ROpType; + +namespace rms { +constexpr int granularity = 16; +} // namespace rms + +template +__global__ void rms_norm(T* output, const T* vals, const T* gamma, float epsilon, int elems_per_row) +{ + constexpr int T_per_load = rms::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = blockDim.x * T_per_load; + + float var_sum = reduce::init(); + + const T* input_base = vals + base_offset; + + T local_buffer[UNROLL * T_per_load]; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + + mem_access::load_global(iteration_buffer, + input_base + (i * stride), + thread_offset + (i * stride) < elems_per_row); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + float up_cast = conversion::to(iteration_buffer[j]); + float sq_val = up_cast * up_cast; + var_sum = reduce::element(var_sum, sq_val); + } + } + + reduce::partitioned_block(tb, warp, var_sum); + const float var = var_sum / elems_per_row; + const T denom = conversion::to(__frsqrt_rn(var + epsilon)); + + T* block_output = output + block_offset; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + const int iter_idx = i * stride + thread_offset; + const bool do_loads = (iter_idx < elems_per_row); + + T gamma_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] *= denom; + iteration_buffer[j] *= gamma_local[j]; + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +template +__global__ void pre_rms_norm(T* output, + T* res_out, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int elems_per_row) +{ + constexpr int T_per_load = rms::granularity / sizeof(T); + + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // X-dimension of the block + const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) + + (tb.thread_index().y * elems_per_row); + const int thread_offset = tb.thread_index().x * T_per_load; + const int base_offset = block_offset + thread_offset; + const int stride = blockDim.x * T_per_load; + + float var_sum = reduce::init(); + + const T* input_base = vals + base_offset; + const T* residual_base = residual + base_offset; + T* res_output = res_out + base_offset; + + T local_buffer[UNROLL * T_per_load]; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + T residual_buffer[T_per_load]; + + const int iter_offset = i * stride + thread_offset; + const bool do_loads = (iter_offset < elems_per_row); + + 
mem_access::load_global( + iteration_buffer, input_base + (i * stride), do_loads); + mem_access::load_global( + residual_buffer, residual_base + (i * stride), do_loads); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] += residual_buffer[j]; + float vals_up_cast = conversion::to(iteration_buffer[j]); + + var_sum = reduce::element(var_sum, vals_up_cast * vals_up_cast); + } + + if (do_loads) { + mem_access::store_global(res_output + i * stride, iteration_buffer); + } + } + + reduce::partitioned_block(tb, warp, var_sum); + const float var = var_sum / elems_per_row; + const T denom = conversion::to(__frsqrt_rn(var + epsilon)); + + T* block_output = output + block_offset; + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + T* iteration_buffer = local_buffer + (i * T_per_load); + const int iter_idx = i * stride + thread_offset; + const bool do_loads = (iter_idx < elems_per_row); + + T gamma_local[T_per_load]; + + mem_access::load_global(gamma_local, gamma + iter_idx, do_loads); + +#pragma unroll + for (int j = 0; j < T_per_load; j++) { + iteration_buffer[j] *= denom; + iteration_buffer[j] *= gamma_local[j]; + } + + if (do_loads) { + mem_access::store_global(block_output + iter_idx, iteration_buffer); + } + } +} + +#define LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + rms_norm \ + <<>>(norm_output, vals, gamma, epsilon, elems_per_row); + +#define LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + pre_rms_norm<<>>( \ + norm_output, res_output, vals, residual, gamma, epsilon, elems_per_row); + +#define LAUNCH_ALL_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + if (pre_norm) { \ + LAUNCH_PRE_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + } else { \ + LAUNCH_RMS_NORM(UNROLL, threadsPerGroup, maxThreads) \ + } + +template +void launch_rms_norm(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream) +{ + // 8 for __half, 4 for float + constexpr int T_per_load = rms::granularity / sizeof(T); + constexpr int maxThreads = 256; + constexpr int internalUnroll = sizeof(T) == 4 ? 4 : 2; + + const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false; + const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internalUnroll; + + // Scheduling concern: may be slightly faster for some inputs to assign multiple stages of + // warp-sized blocks rather than stepping up to 64/96 threads + const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step); + const int threads_per_group = (one_step_threads < maxThreads) ? one_step_threads : maxThreads; + + const int groups_per_block_max = + is_subblock_schedule ? (maxThreads + threads_per_group - 1) / threads_per_group : 1; + const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max; + const int groups_launch = (groups_per_block + rows - 1) / groups_per_block; + + dim3 block(threads_per_group, groups_per_block); + dim3 grid(groups_launch); + + const int elems_per_step = threads_per_group * h_per_step; + const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step; + + bool pre_norm = (residual == nullptr) ? 
false : true; + + if (is_subblock_schedule) { + // <=128 + if (threads_per_group == 1) { + LAUNCH_ALL_RMS_NORM(1, 1, maxThreads); + } else if (threads_per_group == 2) { + LAUNCH_ALL_RMS_NORM(1, 2, maxThreads); + } else if (threads_per_group == 4) { + LAUNCH_ALL_RMS_NORM(1, 4, maxThreads); + } else if (threads_per_group == 8) { + LAUNCH_ALL_RMS_NORM(1, 8, maxThreads); + } else if (threads_per_group == 16) { + LAUNCH_ALL_RMS_NORM(1, 16, maxThreads); + } + } else if (external_unRoll == 1) { + // 129 - 4096 elems + // (this can launch with 1-7 warps as well) + LAUNCH_ALL_RMS_NORM(1 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 2) { + // 4097 - 8192 elems + LAUNCH_ALL_RMS_NORM(2 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 3) { + // 8193 - 12288 elems + LAUNCH_ALL_RMS_NORM(3 * internalUnroll, maxThreads, maxThreads); + } else if (external_unRoll == 4) { + // 12289 - 16384 elems + LAUNCH_ALL_RMS_NORM(4 * internalUnroll, maxThreads, maxThreads); + } +} + +#define INSTANTIATE_LAUNCH_RMS_NORM(T) \ + template void launch_rms_norm(T * norm_output, \ + T * res_output, \ + const T* vals, \ + const T* residual, \ + const T* gamma, \ + float epsilon, \ + int rows, \ + int elems_per_row, \ + cudaStream_t stream); + +INSTANTIATE_LAUNCH_RMS_NORM(float) +INSTANTIATE_LAUNCH_RMS_NORM(__half) +#ifdef BF16_AVAILABLE +INSTANTIATE_LAUNCH_RMS_NORM(__nv_bfloat16) +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.h b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.h new file mode 100644 index 000000000000..7867fb65964f --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "ds_kernel_utils.h" + +template +void launch_rms_norm(T* norm_output, + T* res_output, + const T* vals, + const T* residual, + const T* gamma, + float epsilon, + int rows, + int elems_per_row, + cudaStream_t stream); + +void rms_norm(torch::Tensor& norm_output, + torch::Tensor& norm_input, + torch::Tensor& gamma, + float epsilon); + +void rms_pre_norm(torch::Tensor& norm_output, + torch::Tensor& residual_output, + torch::Tensor& norm_input, + torch::Tensor& residual_input, + torch::Tensor& gamma, + float epsilon); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.py b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.py new file mode 100644 index 000000000000..deb5d33111a9 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from .rms_norm_base import CUDARMSNormBase + + +class CUDARMSNorm(CUDARMSNormBase): + """ + Floating point layer norm kernel for CUDA/RoCM. + + Performs: z = ln(x) + """ + + def __call__(self, output_z: torch.Tensor, input_x: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor: + """ + output_z may alias input_x directly. All Tensors should have the same shape. + + Parameters: + output_z (torch.Tensor): Output tensor. + input_x (torch.Tensor): Input tensor. + gamma (torch.Tensor): Gamma tensor. 
+ """ + self.inf_module.rms_norm(output_z, input_x, gamma, self.epsilon) + return output_z diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_base.py b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_base.py new file mode 100644 index 000000000000..62bc9d056ade --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_base.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... import DSKernelBase +from ....inference_utils import elem_size +from deepspeed.ops.op_builder import InferenceCoreBuilder + + +class CUDARMSNormBase(DSKernelBase): + """ + Base class for CUDA LN kernels. They all same the same validation logic, + so we can share it here. + """ + + supported_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def __init__(self, channels: int, fp_dtype: torch.dtype, epsilon: float = 1e-5): + """ + Parameters: + channels (int): Number of channels in the input tensor. Must be divisible to align + to 16 bytes. + fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values + are torch.float16, torch.bfloat16, and torch.float32. + """ + if fp_dtype not in CUDARMSNormBase.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, CUDARMSNormBase.supported_dtypes)) + + if elem_size(fp_dtype) * channels % 16 != 0: + raise ValueError("channels must be divisible by 16 bytes") + + self.inf_module = InferenceCoreBuilder().load() + self.epsilon = epsilon diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_pre_norm.py b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_pre_norm.py new file mode 100644 index 000000000000..3b040d88b50f --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_rms_norm/rms_pre_norm.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import torch + +from .rms_norm_base import CUDARMSNormBase + + +class CUDARMSPreNorm(CUDARMSNormBase): + """ + Floating point pre-LayerNorm kernel for CUDA/RoCM. + + Performs: z_res = x_res + y_hid + z_hid = ln(z_hid) + """ + + def __call__(self, z_res: torch.Tensor, z_hid: torch.Tensor, x_res: torch.Tensor, y_hid: torch.Tensor, + gamma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + z_res can alias x_res. All non-parameter input/output tensors + must have the same shape. z_hid can alias y_hid. + + Parameters: + z_res (torch.Tensor): Output residual. + z_hid (torch.Tensor): Output hidden states. + x_res (torch.Tensor): Input residual. + y_hid (torch.Tensor): Input hidden states. + gamma (torch.Tensor): Gamma tensor. + beta (torch.Tensor): Beta tensor. + + Returns: + output (torch.Tensor): Output tensor. + """ + self.inf_module.rms_pre_norm(z_hid, z_res, y_hid, x_res, gamma, self.epsilon) + return z_res, z_hid diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/__init__.py b/deepspeed/inference/v2/kernels/core_ops/gated_activations/__init__.py new file mode 100644 index 000000000000..05479d86c906 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .gated_activation import * diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation.py b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation.py new file mode 100644 index 000000000000..ca1b62ba5c36 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from ... import DSKernelBase +from ....inference_utils import ActivationType, elem_size +from deepspeed.ops.op_builder import InferenceCoreBuilder + + +class CUDAGatedActivation(DSKernelBase): + """ + CUDA implementation of gated activation kernel. This kernel assumes that the input + tensor has gate and activation values in adjacent channels. The output tensor should + have half the dimensionality of the input tensor. + """ + + supported_dtypes = [torch.float16, torch.bfloat16, torch.float32] + supported_act_fns = [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU] + + def __init__(self, channels: int, fp_dtype: torch.dtype, act_fn: ActivationType) -> None: + """ + Compile and validate for the gated activation function. + + Args: + channels (int): Number of columns in the output tensor. Must be divisible to align + to 8 bytes. + fp_dtype (torch.dtype): Data type for the input/output/gamma. Supported values + are torch.float16, torch.bfloat16, and torch.float32. + act_fn (ActivationType): Activation function to use. Only GEGLU is supported. + """ + if fp_dtype not in CUDAGatedActivation.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, CUDAGatedActivation.supported_dtypes)) + + act_fn = ActivationType(act_fn) + if act_fn not in CUDAGatedActivation.supported_act_fns: + raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format( + act_fn, CUDAGatedActivation.supported_act_fns)) + + if elem_size(fp_dtype) * channels % 8 != 0: + raise ValueError("Channels must be divisible by 16 bytes") + + if elem_size(fp_dtype) * channels > 98304: + raise ValueError( + "Kernel only compiled to support 98304 bytes per row, please file an issue if your model requires more." + ) + + self.inf_module = InferenceCoreBuilder().load() + self.act_fn = act_fn + self.kernel = self.inf_module.gated_activation + + def __call__(self, output: torch.Tensor, input: torch.Tensor, bias: Optional[torch.Tensor] = None) -> None: + """ + Performs gated activation on the input tensor, writing the result to the output tensor. + + Args: + output (torch.Tensor): Output tensor. Can be of [T, C // 2] or [B, S, C // 2] + input (torch.Tensor): Input tensor. Can be of [T, C] or [B, S, C] + """ + self.kernel(output, input, bias, self.act_fn.value) diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp new file mode 100644 index 000000000000..05463c75138c --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "gated_activation_kernels.h" + +#ifdef BF16_AVAILABLE +#define DISPATCH_FOR_FLOAT(DTYPE, ...) 
\ + [&] { \ + if (DTYPE == torch::kFloat32) { \ + using scalar_t = float; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kFloat16) { \ + using scalar_t = __half; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kBFloat16) { \ + using scalar_t = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \ + } \ + }() +#else +#define DISPATCH_FOR_FLOAT(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat32) { \ + using scalar_t = float; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kFloat16) { \ + using scalar_t = __half; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dtype for BiasActivation"); \ + } \ + }() +#endif + +void ds_gated_activation(at::Tensor& output, + at::Tensor& input, + c10::optional& bias, + int activation_type_raw) +{ + bool ragged_input = input.dim() == 2; + + const ActivationType activation_type = static_cast(activation_type_raw); + + const int rows = ragged_input ? input.size(0) : input.size(0) * input.size(1); + const int cols = ragged_input ? input.size(1) : input.size(2); + + DISPATCH_FOR_FLOAT(input.scalar_type(), [&] { + scalar_t* bias_ptr = nullptr; + if (bias.has_value()) { + TORCH_CHECK(bias.value().scalar_type() == input.scalar_type(), + "Bias type must match input type"); + TORCH_CHECK(bias.value().numel() == cols, + "Bias must have the same number of elements as the input channels"); + bias_ptr = reinterpret_cast(bias.value().data_ptr()); + } + + scalar_t* output_ptr = reinterpret_cast(output.data_ptr()); + const scalar_t* input_ptr = reinterpret_cast(input.data_ptr()); + + launch_gated_activation(output_ptr, + input_ptr, + bias_ptr, + rows, + cols, + activation_type, + c10::cuda::getCurrentCUDAStream()); + }); +} diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cu b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cu new file mode 100644 index 000000000000..84a9906cf037 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cu @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "activation_type.h" +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" + +namespace cg = cooperative_groups; + +namespace gated_act { + +constexpr int access_size = 16; +constexpr int threads = 1024; + +template +float gated_act_fn(float x, float y); + +template <> +DS_D_INLINE float gated_act_fn(float x, float y) +{ + constexpr float sqrt_param = 0.79788456080286535587989211986876f; + constexpr float mul_param = 0.044715; + return y * x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); +} + +template <> +DS_D_INLINE float gated_act_fn(float x, float y) +{ + return y * (x > 0.0f ? 
x : 0.0f); +} + +template <> +DS_D_INLINE float gated_act_fn(float x, float y) +{ + return y * (x / (1.0f + expf(-x))); +} + +} // namespace gated_act + +template +__global__ void gated_activation_kernel(T* output, + const T* input, + const T* bias, + int rows, + int cols) +{ + constexpr int read_vector = gated_act::access_size / sizeof(T); + constexpr int write_vector = read_vector / 2; + + const int row = blockIdx.x; + const int col = threadIdx.x * read_vector; + + const T* input_row = input + row * cols; + T* output_row = output + row * cols / 2; + +#pragma unroll + for (int i = 0; i < loopUnroll; i++) { + T read[read_vector]; + T bias_read[read_vector]; + T store[write_vector]; + + const int read_offset = col + gated_act::threads * read_vector * i; + const int write_offset = col / 2 + gated_act::threads * write_vector * i; + + if (i != loopUnroll - 1 || read_offset < cols) { + mem_access::load_global(read, input_row + read_offset); + mem_access::load_global( + bias_read, bias + read_offset, bias != nullptr); + + for (int j = 0; j < write_vector; j++) { + float g_val = + conversion::to(read[j * 2]) + conversion::to(bias_read[j * 2]); + float a_val = conversion::to(read[j * 2 + 1]) + + conversion::to(bias_read[j * 2 + 1]); + + float act_val = gated_act::gated_act_fn(g_val, a_val); + store[j] = conversion::to(act_val); + } + + mem_access::store_global(output_row + write_offset, store); + } + } +} + +#define DISPATCH_UNROLL(unroll_val) \ + gated_activation_kernel \ + <<>>(output, input, bias, rows, cols); + +template +void launch_gated_activation_impl(T* output, + const T* input, + const T* bias, + int rows, + int cols, + cudaStream_t stream) +{ + constexpr int read_vector = gated_act::access_size / sizeof(T); + constexpr int cols_per_unroll = gated_act::threads * read_vector; + const int req_threads = (cols + read_vector - 1) / read_vector; + const int threads = std::min(req_threads, gated_act::threads); + + const dim3 grid(rows); + const dim3 block(threads); + const int unroll = (cols + cols_per_unroll - 1) / cols_per_unroll; + + if (unroll == 1) { + DISPATCH_UNROLL(1); + } else if (unroll == 2) { + DISPATCH_UNROLL(2); + } else if (unroll == 3) { + DISPATCH_UNROLL(3); + } else if (unroll == 4) { + DISPATCH_UNROLL(4); + } else if (unroll == 5) { + DISPATCH_UNROLL(5); + } else if (unroll == 6) { + DISPATCH_UNROLL(6); + } else { + throw std::runtime_error( + "Called with more columns than supported, please report this bug and this limit will " + "be increased."); + } +} + +template +void launch_gated_activation(T* output, + const T* input, + const T* bias, + int rows, + int cols, + ActivationType act_type, + cudaStream_t stream) +{ + switch (act_type) { + case ActivationType::GEGLU: + launch_gated_activation_impl( + output, input, bias, rows, cols, stream); + break; + case ActivationType::ReGLU: + launch_gated_activation_impl( + output, input, bias, rows, cols, stream); + break; + case ActivationType::SiGLU: + launch_gated_activation_impl( + output, input, bias, rows, cols, stream); + break; + default: throw std::runtime_error("Unsupported activation type"); + } +} + +#define INSTANTIATE_FOR_TYPE(T) \ + template void launch_gated_activation(T * output, \ + const T* input, \ + const T* bias, \ + int rows, \ + int cols, \ + ActivationType act_type, \ + cudaStream_t stream); + +INSTANTIATE_FOR_TYPE(float) +INSTANTIATE_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_FOR_TYPE(__nv_bfloat16) +#endif diff --git 
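
Editor's note: a PyTorch reference (illustration only, not part of the patch) for the channel layout the gated-activation kernel above assumes: channel 2j is passed through the activation, channel 2j+1 is the multiplier, and the output has half as many channels as the input. The GEGLU path uses the tanh GELU approximation, matching the kernel; the helper name is hypothetical.

from typing import Optional
import torch
import torch.nn.functional as F

def gated_activation_ref(x: torch.Tensor, bias: Optional[torch.Tensor], act: str) -> torch.Tensor:
    if bias is not None:
        x = x + bias
    a, g = x[..., 0::2], x[..., 1::2]      # interleaved: (activated value, gate multiplier)
    if act == "geglu":
        a = F.gelu(a, approximate="tanh")
    elif act == "reglu":
        a = F.relu(a)
    elif act == "siglu":
        a = F.silu(a)
    else:
        raise ValueError(f"unsupported activation: {act}")
    return a * g
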
a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.h b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.h new file mode 100644 index 000000000000..6ae01e99679a --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "activation_type.h" +#include "ds_kernel_utils.h" + +template +void launch_gated_activation(T* output, + const T* vals, + const T* bias, + int rows, + int cols, + ActivationType activation_type, + cudaStream_t stream); + +void ds_gated_activation(at::Tensor& output, + at::Tensor& input, + c10::optional& bias, + int activation_type_raw); diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/LICENSE b/deepspeed/inference/v2/kernels/cutlass_ops/LICENSE new file mode 100644 index 000000000000..d64569567334 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/__init__.py b/deepspeed/inference/v2/kernels/cutlass_ops/__init__.py new file mode 100644 index 000000000000..44b9adbae794 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .mixed_gemm import * +from .moe_gemm import * diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/cutlass_ops.cpp b/deepspeed/inference/v2/kernels/cutlass_ops/cutlass_ops.cpp new file mode 100644 index 000000000000..18e834f3e60a --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/cutlass_ops.cpp @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +#include "mixed_gemm.h" +#include "moe_gemm.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + // mixed_gemm.h + m.def("mixed_gemm", &mixed_gemm, "Mixed-precision GEMM"); + + // moe_gemm.h + m.def("moe_gemm", &moe_gemm, "MultiGEMM for MoE (16-bit weights)"); + m.def("mixed_moe_gemm", &mixed_moe_gemm, "MultiGEMM for MoE (4-bit/8-bit weights)"); +} diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/__init__.py b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/__init__.py new file mode 100644 index 000000000000..14ccf2ce5354 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .mixed_gemm import * diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu new file mode 100644 index 000000000000..7c522203bb48 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "mixed_gemm.h" +#include "mixed_gemm_api.h" +#include "weight_variant.h" + +// Switch helpers inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#define ACT_DTYPE_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using ActivationDtype = __half; \ + return __VA_ARGS__(); \ + } else { \ + using ActivationDtype = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + }() + +#define WEIGHT_VARIANT_SWITCH(COND, ...) 
\ + [&] { \ + if (COND) { \ + constexpr WeightVariant WVariant = WeightVariant::kFP8; \ + return __VA_ARGS__(); \ + } else { \ + constexpr WeightVariant WVariant = WeightVariant::kFP4; \ + return __VA_ARGS__(); \ + } \ + }() + +void mixed_gemm(at::Tensor& output, + at::Tensor& hidden_states, + at::Tensor& weight, + at::Tensor& scales, + c10::optional& bias, + int num_bits, + int activation_raw) +{ + TORCH_CHECK(output.dtype() == hidden_states.dtype(), + "Output and hidden states must have the same dtype"); + TORCH_CHECK(num_bits == 4 || num_bits == 8, "Data width must be 4 or 8"); + TORCH_CHECK(output.size(0) == hidden_states.size(0), "Token dimension mismatch"); + + int32_t m = output.size(0); + int32_t k = hidden_states.size(1); + int32_t n = weight.size(1); + + TORCH_CHECK(weight.size(0) == k, "Weight dimension mismatch"); + + ACT_DTYPE_SWITCH(hidden_states.dtype() == torch::kFloat16, [&] { + WEIGHT_VARIANT_SWITCH(num_bits == 8, [&] { + fastertransformer::CutlassFpAIntBGemmRunner runner = + *MixedGemmContext::Instance().GeMM_Runner(); + + ActivationType activation_type = (ActivationType)activation_raw; + if (!bias.has_value() && activation_type == ActivationType::IDENTITY) { + runner.gemm((ActivationDtype*)hidden_states.data_ptr(), + (const char*)weight.data_ptr(), + (ActivationDtype*)scales.data_ptr(), + (ActivationDtype*)output.data_ptr(), + m, + n, + k, + nullptr, + 0, + at::cuda::getCurrentCUDAStream()); + return; + } else { + ActivationDtype* bias_ptr = nullptr; + if (bias.has_value()) { bias_ptr = (ActivationDtype*)bias.value().data_ptr(); } + runner.gemm_bias_act((ActivationDtype*)hidden_states.data_ptr(), + (char*)weight.data_ptr(), + (ActivationDtype*)scales.data_ptr(), + bias_ptr, + (ActivationDtype*)output.data_ptr(), + m, + n, + k, + activation_type, + nullptr, + 0, + at::cuda::getCurrentCUDAStream()); + return; + } + }); + }); +} diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.h b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.h new file mode 100644 index 000000000000..1fc3831e9084 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include + +void mixed_gemm(at::Tensor& output, + at::Tensor& hidden_states, + at::Tensor& weight, + at::Tensor& scales, + c10::optional& bias, + int num_bits, + int activation_raw); diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.py b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.py new file mode 100644 index 000000000000..dddb555e267a --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... import DSKernelBase +from ....inference_utils import ActivationType, DtypeEnum +from deepspeed.ops.op_builder import InferenceCutlassBuilder + +from typing import Optional + + +class MixedGEMM(DSKernelBase): + """ + CUTLASS implementation of MoE GEMM. 
+ """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_act_fns = [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU, ActivationType.IDENTITY] + + def __init__(self, fp_dtype: DtypeEnum, act_fn: ActivationType, num_bits: int) -> None: + + if not isinstance(fp_dtype, DtypeEnum): + fp_dtype = DtypeEnum(fp_dtype) + + if fp_dtype not in MixedGEMM.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, MixedGEMM.supported_dtypes)) + + if act_fn not in MixedGEMM.supported_act_fns: + raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format( + act_fn, MixedGEMM.supported_act_fns)) + + if num_bits != 4 and num_bits != 8: + raise ValueError("Unsupported num_bits: {}, supported num_bits are 4 and 8".format(num_bits)) + + inf_module = InferenceCutlassBuilder().load() + self.num_bits = num_bits + self.kernel = inf_module.moe_gemm + self.act_fn = act_fn + + def __call__(self, + output: torch.Tensor, + hidden_states: torch.Tensor, + weights: torch.Tensor, + scales: torch.Tensor, + biases: Optional[torch.Tensor] = None) -> None: + """ + Performs a MoE GEMM. Note that the stride between token inputs must be even (the distance between byte 1 of token 0 and token 1 must be the same as the distance between byte 1 of token 1 and token 2). + + Arguments: + output (torch.Tensor): The output of the MoE GEMM of shape [n_tokens, out_neurons]. + hidden_states (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, in_neurons]. + weights (torch.Tensor): The weights of shape [in_neurons, out_neurons]. These weights must be contiguous. + scales (torch.Tensor): The scales of shape [out_neurons]. These scales must be contiguous. + biases (torch.Tensor): The biases of shape [out_neurons]. These biases must be contiguous. + + Returns: + output + """ + self.kernel(output, hidden_states, weights, biases, self.num_bits, self.act_fn) + return output diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm_api.h b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm_api.h new file mode 100644 index 000000000000..74fc07ffc4a2 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm_api.h @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "activation_type.h" +#include "weight_variant.h" + +namespace fastertransformer { + +template +class CutlassFpAIntBGemmRunner { +public: + void gemm(const T* A, + const char* B, + const T* weight_scales, + T* C, + int m, + int n, + int k, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); + + void gemm_bias_act(const T* A, + const char* B, + const T* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + ActivationType activation_type, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); +}; + +} // namespace fastertransformer + +template +class MixedGemmContext { +public: + MixedGemmContext() { _runner = new fastertransformer::CutlassFpAIntBGemmRunner(); } + + virtual ~MixedGemmContext() { delete _runner; } + + static MixedGemmContext& Instance() + { + static MixedGemmContext _ctx; + return _ctx; + } + + fastertransformer::CutlassFpAIntBGemmRunner* GeMM_Runner() const { return _runner; } + + fastertransformer::CutlassFpAIntBGemmRunner* _runner; +}; diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/__init__.py b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/__init__.py new file mode 100644 index 000000000000..aff4e77bba98 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .mixed_moe_gemm import * +from .moe_gemm import * diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/mixed_moe_gemm.py b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/mixed_moe_gemm.py new file mode 100644 index 000000000000..9c55ce341532 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/mixed_moe_gemm.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... import DSKernelBase +from ....inference_utils import ActivationType, DtypeEnum +from deepspeed.ops.op_builder import InferenceCutlassBuilder + +from typing import Optional + + +class MixedMoEGEMM(DSKernelBase): + """ + CUTLASS implementation of MoE GEMM. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_act_fns = [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU, ActivationType.IDENTITY] + + def __init__(self, fp_dtype: DtypeEnum, act_fn: ActivationType, num_bits: int) -> None: + + if not isinstance(fp_dtype, DtypeEnum): + fp_dtype = DtypeEnum(fp_dtype) + + if fp_dtype not in MixedMoEGEMM.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, MixedMoEGEMM.supported_dtypes)) + + if act_fn not in MixedMoEGEMM.supported_act_fns: + raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format( + act_fn, MixedMoEGEMM.supported_act_fns)) + + if num_bits != 4 and num_bits != 8: + raise ValueError("Unsupported num_bits: {}, supported num_bits are 4 and 8".format(num_bits)) + + inf_module = InferenceCutlassBuilder().load() + self.num_bits = num_bits + self.kernel = inf_module.moe_gemm + self.act_fn = act_fn + + def __call__(self, + ordered_output: torch.Tensor, + ordered_input: torch.Tensor, + weights: torch.Tensor, + scales: torch.Tensor, + total_rows_before_expert: torch.Tensor, + biases: Optional[torch.Tensor] = None) -> None: + """ + Performs a MoE GEMM. 
Note that the stride between token inputs must be even (the distance between byte 1 of token 0 and token 1 must be the same as the distance between byte 1 of token 1 and token 2). + + Arguments: + ordered_output (torch.Tensor): The output of the MoE GEMM of shape [n_tokens, out_neurons]. + ordered_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, in_neurons]. + weights (torch.Tensor): The weights of shape [n_experts, in_neurons, out_neurons]. These weights must be contiguous. + scales (torch.Tensor): The scales of shape [n_experts, out_neurons]. These scales must be contiguous. + total_rows_before_expert (torch.Tensor): The total number of rows before each expert of shape [n_experts]. + biases (torch.Tensor): The biases of shape [n_experts, out_neurons]. These biases must be contiguous. + + Returns: + ordered_output + """ + self.kernel(ordered_output, ordered_input, weights, scales, biases, total_rows_before_expert, self.num_bits, + self.act_fn) + return ordered_output diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu new file mode 100644 index 000000000000..d1cafc9fff4c --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "moe_gemm.h" +#include "moe_gemm_api.h" +#include "weight_variant.h" + +// Switch helpers inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#define HIDDEN_DTYPE_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using ActivationDtype = __half; \ + constexpr WeightVariant WVariant = WeightVariant::kFP16; \ + return __VA_ARGS__(); \ + } else { \ + using ActivationDtype = __nv_bfloat16; \ + constexpr WeightVariant WVariant = WeightVariant::kBF16; \ + return __VA_ARGS__(); \ + } \ + }() + +void moe_gemm(at::Tensor& output, + at::Tensor& hidden_states, + at::Tensor& weight, + c10::optional& bias, + at::Tensor& total_rows_before_expert, + int activation_raw) +{ + TORCH_CHECK(output.dtype() == hidden_states.dtype(), + "Output and hidden states must have the same dtype"); + TORCH_CHECK(output.dtype() == weight.dtype(), "Output and weight must have the same dtype"); + + int64_t total_rows = hidden_states.size(0); + int64_t gemm_k = hidden_states.size(1); + int64_t gemm_n = weight.size(2); + int num_experts = weight.size(0); + + TORCH_CHECK(total_rows == output.size(0), "Total rows dimension mismatch"); + TORCH_CHECK(gemm_k == weight.size(1), "GEMM K dimension mismatch"); + TORCH_CHECK(gemm_n == output.size(1), "GEMM N dimension mismatch"); + TORCH_CHECK(num_experts == total_rows_before_expert.size(0), "Number of experts mismatch"); + + HIDDEN_DTYPE_SWITCH(hidden_states.dtype() == torch::kFloat16, [&] { + fastertransformer::MoeGemmRunner runner = + *MoeGemmContext::Instance().GeMM_Runner(); + + ActivationType activation_type = (ActivationType)activation_raw; + if (!bias.has_value() && activation_type == ActivationType::IDENTITY) { + runner.moe_gemm((ActivationDtype*)hidden_states.data_ptr(), + (char*)weight.data_ptr(), + nullptr, + (ActivationDtype*)output.data_ptr(), + (int64_t*)total_rows_before_expert.data_ptr(), + total_rows, + gemm_n, + gemm_k, + num_experts, + at::cuda::getCurrentCUDAStream()); + return; + } else { + ActivationDtype* bias_ptr = nullptr; + if 
(bias.has_value()) { + bias_ptr = (ActivationDtype*)bias.value().data_ptr(); + TORCH_CHECK(num_experts == bias.value().size(0), "Number of experts mismatch"); + TORCH_CHECK(gemm_n == bias.value().size(1), "GEMM N dimension mismatch"); + } + runner.moe_gemm_bias_act((ActivationDtype*)hidden_states.data_ptr(), + (char*)weight.data_ptr(), + nullptr, + bias_ptr, + (ActivationDtype*)output.data_ptr(), + (int64_t*)total_rows_before_expert.data_ptr(), + total_rows, + gemm_n, + gemm_k, + num_experts, + activation_type, + at::cuda::getCurrentCUDAStream()); + return; + } + }); +} + +#define ACT_DTYPE_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using ActivationDtype = __half; \ + return __VA_ARGS__(); \ + } else { \ + using ActivationDtype = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + }() + +#define WEIGHT_VARIANT_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + constexpr WeightVariant WVariant = WeightVariant::kFP8; \ + return __VA_ARGS__(); \ + } else { \ + constexpr WeightVariant WVariant = WeightVariant::kFP4; \ + return __VA_ARGS__(); \ + } \ + }() + +void mixed_moe_gemm(at::Tensor& output, + at::Tensor& hidden_states, + at::Tensor& weight, + at::Tensor& scales, + c10::optional& bias, + at::Tensor& total_rows_before_expert, + int num_bits, + int activation_raw) +{ + TORCH_CHECK(output.dtype() == hidden_states.dtype(), + "Output and hidden states must have the same dtype"); + + int64_t total_rows = hidden_states.size(0); + int64_t gemm_k = hidden_states.size(1); + int64_t gemm_n = weight.size(2); + int num_experts = weight.size(0); + + TORCH_CHECK(total_rows == output.size(0), "Total rows dimension mismatch"); + TORCH_CHECK(gemm_k == weight.size(1), "GEMM K dimension mismatch"); + TORCH_CHECK(gemm_n == output.size(1), "GEMM N dimension mismatch"); + TORCH_CHECK(num_experts == total_rows_before_expert.size(0), "Number of experts mismatch"); + + ACT_DTYPE_SWITCH(hidden_states.dtype() == torch::kFloat16, [&] { + WEIGHT_VARIANT_SWITCH(num_bits == 8, [&] { + fastertransformer::MoeGemmRunner runner = + *MoeGemmContext::Instance().GeMM_Runner(); + + ActivationType activation_type = (ActivationType)activation_raw; + if (!bias.has_value() && activation_type == ActivationType::IDENTITY) { + runner.moe_gemm((ActivationDtype*)hidden_states.data_ptr(), + (char*)weight.data_ptr(), + (ActivationDtype*)scales.data_ptr(), + (ActivationDtype*)output.data_ptr(), + (int64_t*)total_rows_before_expert.data_ptr(), + total_rows, + gemm_n, + gemm_k, + num_experts, + at::cuda::getCurrentCUDAStream()); + return; + } else { + ActivationDtype* bias_ptr = nullptr; + if (bias.has_value()) { + bias_ptr = (ActivationDtype*)bias.value().data_ptr(); + TORCH_CHECK(num_experts == bias.value().size(0), "Number of experts mismatch"); + TORCH_CHECK(gemm_n == bias.value().size(1), "GEMM N dimension mismatch"); + } + runner.moe_gemm_bias_act((ActivationDtype*)hidden_states.data_ptr(), + (char*)weight.data_ptr(), + (ActivationDtype*)scales.data_ptr(), + bias_ptr, + (ActivationDtype*)output.data_ptr(), + (int64_t*)total_rows_before_expert.data_ptr(), + total_rows, + gemm_n, + gemm_k, + num_experts, + activation_type, + at::cuda::getCurrentCUDAStream()); + return; + } + }); + }); +} diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.h b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.h new file mode 100644 index 000000000000..dfd3d4561567 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include + +void moe_gemm(at::Tensor& output, + at::Tensor& hidden_states, + at::Tensor& weight, + c10::optional& bias, + at::Tensor& total_rows_before_expert, + int activation_raw); + +void mixed_moe_gemm(at::Tensor& output, + at::Tensor& hidden_states, + at::Tensor& weight, + at::Tensor& scales, + c10::optional& bias, + at::Tensor& total_rows_before_expert, + int num_bits, + int activation_raw); diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.py b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.py new file mode 100644 index 000000000000..0cc233e8d87a --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... import DSKernelBase +from ....inference_utils import ActivationType, DtypeEnum +from deepspeed.ops.op_builder import InferenceCutlassBuilder + +from typing import Optional + + +class MoEGEMM(DSKernelBase): + """ + CUTLASS implementation of MoE GEMM. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_act_fns = [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU, ActivationType.IDENTITY] + + def __init__(self, fp_dtype: DtypeEnum, act_fn: ActivationType) -> None: + + if not isinstance(fp_dtype, DtypeEnum): + fp_dtype = DtypeEnum(fp_dtype) + + if fp_dtype not in MoEGEMM.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, MoEGEMM.supported_dtypes)) + + if act_fn not in MoEGEMM.supported_act_fns: + raise ValueError("Unsupported activation function: {}, supported_act_fns are {}".format( + act_fn, MoEGEMM.supported_act_fns)) + + inf_module = InferenceCutlassBuilder().load() + self.kernel = inf_module.moe_gemm + self.act_fn = act_fn + + def __call__(self, + ordered_output: torch.Tensor, + ordered_input: torch.Tensor, + weights: torch.Tensor, + total_rows_before_expert: torch.Tensor, + biases: Optional[torch.Tensor] = None) -> None: + """ + Performs a MoE GEMM. Note that the stride between token inputs must be even (the distance between byte 1 of token 0 and token 1 must be the same as the distance between byte 1 of token 1 and token 2). + + Arguments: + ordered_output (torch.Tensor): The output of the MoE GEMM of shape [n_tokens, out_neurons]. + ordered_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, in_neurons]. + weights (torch.Tensor): The weights of shape [n_experts, in_neurons, out_neurons]. These weights must be contiguous. + total_rows_before_expert (torch.Tensor): The total number of rows before each expert of shape [n_experts]. + biases (torch.Tensor): The biases of shape [n_experts, out_neurons]. These biases must be contiguous. + + Returns: + ordered_output + """ + self.kernel(ordered_output, ordered_input, weights, biases, total_rows_before_expert, self.act_fn) + return ordered_output diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm_api.h b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm_api.h new file mode 100644 index 000000000000..7ad92070b35f --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm_api.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "activation_type.h" +#include "weight_variant.h" + +namespace fastertransformer { + +template +class MoeGemmRunner { +public: + MoeGemmRunner(); + + void moe_gemm_bias_act(const T* A, + const char* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + ActivationType activation_type, + cudaStream_t stream); + + void moe_gemm(const T* A, + const char* B, + const T* weight_scales, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + cudaStream_t stream); + +private: + int sm_; + int multi_processor_count_; +}; + +} // namespace fastertransformer + +template +class MoeGemmContext { +public: + MoeGemmContext() { _runner = new fastertransformer::MoeGemmRunner(); } + + virtual ~MoeGemmContext() { delete _runner; } + + static MoeGemmContext& Instance() + { + static MoeGemmContext _ctx; + return _ctx; + } + + fastertransformer::MoeGemmRunner* GeMM_Runner() const { return _runner; } + + fastertransformer::MoeGemmRunner* _runner; +}; diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/shared_resources/weight_variant.h b/deepspeed/inference/v2/kernels/cutlass_ops/shared_resources/weight_variant.h new file mode 100644 index 000000000000..4d17c799f726 --- /dev/null +++ b/deepspeed/inference/v2/kernels/cutlass_ops/shared_resources/weight_variant.h @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// Data structure that allows us to abstract internal CUTLASS datatypes/mappings +// to the DeepSpeed-Kernels repo. + +#pragma once + +enum WeightVariant { kFP16, kBF16, kFP8, kFP4 }; diff --git a/deepspeed/inference/v2/kernels/ds_kernel.py b/deepspeed/inference/v2/kernels/ds_kernel.py new file mode 100644 index 000000000000..8dbfa1de86a6 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ds_kernel.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import ABC, abstractmethod + + +class DSKernelBase(ABC): + + @abstractmethod + def __init__(self, *args, **kwargs): + """ + If necessary trigger compilation and warmup + Autotuning of the kernel would happen at this stage to + eliminate any potential hangs that might occur mid-deployment + Validate that the desired run configuration is compatible. + + It is not necessary to call super on this method. + """ + raise NotImplementedError() + + @abstractmethod + def __call__(self, *args, **kwargs): + """ + However the kernel needs to be called, it can be called here. Auto-tuning + should never be performed here. + + All inputs/outputs should be passed as arguments to this function. No allocations + should be performed here. + """ + raise NotImplementedError() diff --git a/deepspeed/inference/v2/kernels/ragged_ops/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/__init__.py new file mode 100644 index 000000000000..988152b2e7c0 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .atom_builder import * +from .blocked_flash import * +from .embed import * +from .linear_blocked_kv_rotary import * +from .logits_gather import * +from .moe_gather import * +from .moe_scatter import * +from .top_1_gating import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/__init__.py new file mode 100644 index 000000000000..c79201cdf165 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .atom_builder import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.cpp b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.cpp new file mode 100644 index 000000000000..7ad4dc5faa20 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.cpp @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "atom_builder.h" +#include "attention_atom.h" +#include "ragged_dtypes.h" + +int32_t build_atoms(torch::Tensor& atoms_ten, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& kv_ptrs, + const int32_t q_block_size, + const int32_t kv_block_size) +{ + const RaggedBatchDescriptor* batch_desc = + reinterpret_cast(batch_metadata.data_ptr()); + + const InflightSeqDescriptor* seq_desc = + reinterpret_cast(seq_metadata.data_ptr()); + + int32_t** kv_ptr_list = reinterpret_cast(kv_ptrs.data_ptr()); + + AttentionAtom* atoms = reinterpret_cast(atoms_ten.data_ptr()); + + int32_t n_atoms = 0; + for (int i = 0; i < batch_desc->n_sequences; i++) { + const int seq_atoms = (seq_desc[i].n_tokens + q_block_size - 1) / q_block_size; + int32_t cur_start_idx = seq_desc[i].start_idx; + int32_t global_start_idx = seq_desc[i].seen_tokens; + int32_t remaining_toks = seq_desc[i].n_tokens; + + for (int j = 0; j < seq_atoms; j++) { + atoms[n_atoms].block_idx_list = kv_ptr_list[i]; + atoms[n_atoms].q_start_idx = cur_start_idx; + atoms[n_atoms].q_len = std::min(remaining_toks, q_block_size); + atoms[n_atoms].global_q_idx = global_start_idx; + + const int32_t end_toks = global_start_idx + atoms[n_atoms].q_len; + // TODO(cmikeh2): This logic needs to be changed for sparse implementations + atoms[n_atoms].kv_blocks = (end_toks + kv_block_size - 1) / kv_block_size; + atoms[n_atoms].total_extent = end_toks; + + cur_start_idx += atoms[n_atoms].q_len; + global_start_idx += atoms[n_atoms].q_len; + remaining_toks -= atoms[n_atoms].q_len; + n_atoms++; + } + } + + return n_atoms; +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.h b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.h new file mode 100644 index 000000000000..a3342d0e6695 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include + +/* +Construct the attention atoms given the ragged metadata for the current batch. +This could largely be done at the Python level, but since we pack the KV ptr +alongside the int32_t metadata, it gets very ugly to handle the mixed-width +data structures (since we're packing them in a single tensor). 
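+
+For reference, each atom occupies eight int32 slots of the output tensor: on a
+64-bit build the first two slots hold the KV block-list pointer, followed by
+q_start_idx, q_len, kv_blocks, total_extent, global_q_idx, and one unused slot
+(see attention_atom.h).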
+*/ +int32_t build_atoms(torch::Tensor& atoms_ten, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& kv_ptrs, + const int32_t q_block_size, + const int32_t kv_block_size); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.py b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.py new file mode 100644 index 000000000000..3355ca76c6a4 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/atom_builder/atom_builder.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import torch + +from ... import DSKernelBase +from deepspeed.ops.op_builder import RaggedOpsBuilder +from ....ragged import RaggedBatchWrapper + + +class AtomBuilder(DSKernelBase): + """ + C++ implementation to populate the attention atoms for the blocked attention + kernel. + """ + + def __init__(self) -> None: + """ + Triggers compilation of the C++ implementation. + """ + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.build_atoms + + def __call__(self, atoms: torch.Tensor, ragged_batch: RaggedBatchWrapper, q_block_size: int, + kv_block_size: int) -> Tuple[torch.Tensor, int]: + """ + Populates the attention atoms for the blocked attention kernel. + + Args: + atoms (torch.Tensor): Pre-allocated int32 tensor of shape [max_atoms, 8] + ragged_batch (torch.Tensor): Wrapper for the ragged batch. + q_block_size (int): The block size for the queries (as determined by the + attention implementation) + kv_block_size (int): The block size for the keys/values (as determined by the + attention implementation) + + Returns: + + """ + if atoms.device != torch.device("cpu"): + raise RuntimeError("AtomBuilder must be called on tensors") + + n_atoms = self.kernel(atoms, ragged_batch.batch_metadata_buffer(on_device=False), + ragged_batch.inflight_seq_descriptors(on_device=False), + ragged_batch.kv_ptrs(on_device=False), q_block_size, kv_block_size) + return atoms, n_atoms diff --git a/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/__init__.py new file mode 100644 index 000000000000..87b2b3d68777 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .blocked_flash import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/attention_atom.h b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/attention_atom.h new file mode 100644 index 000000000000..ed8eb9e19b3d --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/attention_atom.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include "cuda.h" + +struct AttentionAtom { + /* + The attention atom describes the workload of a particular query. The attention + kernel will execute each ``AttentionAtom`` for each head of the model. + */ + + // Pointer to a list of KV block indices. + int32_t* block_idx_list; + + // Index of first token in the ragged batch associated with this atom. + int32_t q_start_idx; + + // Number of tokens in the ragged batch associated with this atom. + int32_t q_len; + + // Number of key/value blocks associated with this atom. All but the last are + // assumed to be fully dense. 
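+    // For example, an atom whose extent reaches 300 tokens of KV history with a
+    // KV block size of 64 spans ceil(300 / 64) = 5 blocks.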
+ int32_t kv_blocks; + + // Number of tokens in the last key/value block. + int32_t total_extent; + + // Global index of the first token in the atom. For example, in a prompt continuation + // in which we have already processed 768 tokens, this would be 768. + int32_t global_q_idx; + + // Unused + int32_t unused; +}; diff --git a/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.cpp b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.cpp new file mode 100644 index 000000000000..6cb60ad00f9b --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.cpp @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include + +#include "blocked_flash.h" +#include "flash.h" + +#define CHECK_SHAPE(x, ...) \ + TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \ + #x " must have shape (" #__VA_ARGS__ ")") + +void flash_attn_by_atoms(at::Tensor& out, + at::Tensor& q, + at::Tensor& k, + at::Tensor& v, + at::Tensor& attention_atoms, + const float softmax_scale, + const bool is_causal) +{ + auto dprops = at::cuda::getCurrentDeviceProperties(); + + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + if (q_dtype == torch::kBFloat16) { + TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + } + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device"); + TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device"); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + const int total_q = q.size(0); + const int head_size = k.size(-1); + const int num_heads_kv = k.size(-2); + const int num_heads_q = q.size(-1) / head_size; + + TORCH_CHECK(head_size <= 256, "head_size must be <= 256"); + TORCH_CHECK(head_size % 8 == 0, "head_size must be divisible by 8"); + TORCH_CHECK(num_heads_q % num_heads_kv == 0, "num_heads_q must be divisible by num_heads_kv"); + + Flash_fwd_params params; + + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + params.o_ptr = out.data_ptr(); + params.atoms = reinterpret_cast(attention_atoms.data_ptr()); + + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(0); + params.k_row_stride = k.stride(1); + params.v_row_stride = v.stride(1); + params.o_row_stride = out.stride(0); + + // Assume heads are contiguous. 
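+    // (i.e. the stride between heads within a token is exactly head_size elements).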
+ params.q_head_stride = head_size; + params.k_head_stride = head_size; + params.v_head_stride = head_size; + params.o_head_stride = head_size; + + // Head params + params.h = num_heads_q; + params.h_k = num_heads_kv; + params.h_h_k_ratio = num_heads_q / num_heads_kv; + params.d = head_size; + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + params.d_rounded = round_multiple(head_size, 32); + params.num_atoms = attention_atoms.size(0); + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + params.is_causal = is_causal; + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.h b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.h new file mode 100644 index 000000000000..68037b425113 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include + +void flash_attn_by_atoms(at::Tensor& out, + at::Tensor& q, + at::Tensor& k, + at::Tensor& v, + at::Tensor& attention_atoms, + const float softmax_scale, + const bool is_causal); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.py b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.py new file mode 100644 index 000000000000..54d465698b4e --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from deepspeed.accelerator import get_accelerator +from ....inference_utils import DtypeEnum +from deepspeed.ops.op_builder import RaggedOpsBuilder + +from ... import DSKernelBase + + +def get_q_block_size(head_size: int) -> int: + """ + Returns the query block size required by the kernel given a head size. + """ + cc_major, cc_minor = torch.cuda.get_device_capability(get_accelerator().current_device()) #ignore-cuda + + if cc_major < 8: + raise RuntimeError("Blocked attention requires CUDA compute capability >= 8.0") + + if head_size <= 64: + return 128 + elif head_size <= 160: + if cc_minor != 0: + return 64 + else: + return 128 + elif head_size == 192: + return 128 + elif head_size == 224: + if cc_minor != 0: + return 64 + else: + return 128 + else: + if cc_major == 8 and cc_minor == 0: + return 128 + else: + return 64 + + +def get_kv_block_size(head_size: int) -> int: + """ + Return preferred granulatity for blocked KV-cache implementation. + """ + cc_major, cc_minor = torch.cuda.get_device_capability(get_accelerator().current_device()) #ignore-cuda + + if cc_major < 8: + raise RuntimeError("Blocked attention requires CUDA compute capability >= 8.0") + + if head_size <= 64: + return 128 + elif head_size != 160 or cc_minor != 0: + return 64 + else: + return 32 + + +class BlockedFlashAttn(DSKernelBase): + """ + Modified implementation of flash-attn-2 tuned for inference on blocked KV-cache and wider + range of input sequence lengths. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + + def __init__(self, head_size: int, dtype: DtypeEnum) -> None: + """ + Triggers any compilation of the kernels. 
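+
+        Args:
+            head_size (int): Size of each attention head.
+            dtype (DtypeEnum): Activation data type; fp16 and bf16 are supported.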
+ """ + if not isinstance(dtype, DtypeEnum): + dtype = DtypeEnum(dtype) + + if dtype not in BlockedFlashAttn.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported data types are {}".format( + dtype, BlockedFlashAttn.supported_dtypes)) + + # For testing, need to revert to 32 + if head_size % 16 != 0: + raise ValueError("Head size must be divisible by 32 (configured with {})".format(head_size)) + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.flash_attn_by_atoms + + def __call__(self, out: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, atoms: torch.Tensor, + softmax_scale: float) -> torch.Tensor: + """ + Flash attention implementation atop a blocked KV-cache. Atoms should be pre-populated. + See attention_atom.h for further details on the structure of the information. + + Arguments: + out (torch.Tensor): Output tensor of shape [tokens, hidden_size] + q (torch.Tensor): Query tensor of shape [tokens, hidden_size] + k (torch.Tensor): Key cache tensor of shape [n_blocks, block_size, n_heads_kv, head_size]. This Tensor only needs to be contiguous on the final dimension. + v (torch.Tensor): Value cache tensor of shape [n_blocks, block_size, n_heads_kv, head_size]. This Tensor only needs to be contiguous on the final dimension. + atoms (torch.Tensor): Atom information tensor of shape [num_atoms, 8] and type int32. + Not all data is readable in this format. See attention_atom.h for further details. + softmax_scale (float): Softmax scale factor. + + Returns: + out (torch.Tensor): Output tensor of shape [tokens, hidden_size] + """ + self.kernel(out, q, k, v, atoms, softmax_scale, True) + return out diff --git a/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/flash.h b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/flash.h new file mode 100644 index 000000000000..b4a53e6d7f52 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/blocked_flash/flash.h @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/****************************************************************************** +Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +#include "attention_atom.h" + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = uint32_t; + // The QKV matrices. + void* __restrict__ q_ptr; + void* __restrict__ k_ptr; + void* __restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + // The O matrix (output). + void* __restrict__ o_ptr; + + // The attention metadata + AttentionAtom* __restrict__ atoms; + + // Total attention atoms + int num_atoms; + + // The stride between rows of O. 
+ index_t o_row_stride; + index_t o_head_stride; + + // The dimensions + int d, d_rounded; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + bool is_bf16; + bool is_causal; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/embed/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/embed/__init__.py new file mode 100644 index 000000000000..d6b8e6047d74 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/embed/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .embed import RaggedEmbeddingKernel diff --git a/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cpp b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cpp new file mode 100644 index 000000000000..04b72bf948db --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cpp @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "embed.h" +#include "ragged_kernel_helpers.h" + +#ifdef BF16_AVAILABLE +#define DISPATCH_FOR_FLOAT(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat32) { \ + using float_t = float; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kFloat16) { \ + using float_t = __half; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kBFloat16) { \ + using float_t = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dispatch type"); \ + } \ + }() +#else +#define DISPATCH_FOR_FLOAT(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kFloat32) { \ + using float_t = float; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kFloat16) { \ + using float_t = __half; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dispatch type"); \ + } \ + }() +#endif + +#define DISPATCH_FOR_INT(DTYPE, ...) \ + [&] { \ + if (DTYPE == torch::kInt32) { \ + using int_t = int32_t; \ + return __VA_ARGS__(); \ + } else if (DTYPE == torch::kInt64) { \ + using int_t = int64_t; \ + return __VA_ARGS__(); \ + } else { \ + TORCH_CHECK(false, "Unsupported dispatch type"); \ + } \ + }() + +/* +Embeddings kernel aware of ragged batch structure. 
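+Token ids may be int32 or int64, and the embedding weights may be fp32, fp16, or
+(when BF16_AVAILABLE is defined) bf16; the dispatch macros above select the
+instantiation.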
+*/ +void ragged_embed(torch::Tensor& embedded_tokens, + torch::Tensor& input_ids, + torch::Tensor& embedding_weight, + c10::optional& position_embedding_weight, + int32_t pos_embed_offset, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs) +{ + // We don't care about KV cache here, so just hardcoding 0s for block_size/num_blocks + BatchWrapperCPP batch_wrapper = + make_cpp_batch_wrapper(batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, 0, 0); + + const int32_t n_tokens = input_ids.numel(); + const int32_t embed_dim = embedding_weight.size(1); + const int32_t vocab_size = embedding_weight.size(0); + + DISPATCH_FOR_INT(input_ids.scalar_type(), [&] { + DISPATCH_FOR_FLOAT(embedding_weight.scalar_type(), [&] { + float_t* pos_embed_ptr = nullptr; + int32_t max_position_embed_idx = 0; + if (position_embedding_weight.has_value()) { + TORCH_CHECK( + position_embedding_weight.value().options().dtype() == + embedding_weight.options().dtype(), + "position_embedding_weight and embedding_weight must have the same dtype"); + pos_embed_ptr = + reinterpret_cast(position_embedding_weight.value().data_ptr()); + max_position_embed_idx = position_embedding_weight.value().size(0) - 1; + } + + launch_ragged_embed_kernel((float_t*)embedded_tokens.data_ptr(), + (const int_t*)input_ids.data_ptr(), + (const float_t*)embedding_weight.data_ptr(), + pos_embed_ptr, + batch_wrapper, + n_tokens, + embed_dim, + vocab_size, + max_position_embed_idx, + pos_embed_offset, + at::cuda::getCurrentCUDAStream()); + }); + }); +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cu b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cu new file mode 100644 index 000000000000..81d6d534ddf5 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cu @@ -0,0 +1,137 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "ds_kernel_utils.h" +#include "embed.cuh" +#include "memory_access_utils.h" +#include "ragged_dtypes.h" + +namespace embed { + +constexpr int granularity = 16; +constexpr int threads = 512; + +} // namespace embed + +template +__global__ void ragged_embed_kernel(EmbedType* embedded_tokens, + const TokenType* input_ids, + const EmbedType* embedding_weight, + const EmbedType* position_weight, + const BatchWrapperCPP batch_desc, + const int32_t embed_dim, + const int32_t vocab_size, + const int32_t max_position_embed_idx, + const int32_t position_embed_offset) +{ + constexpr int T_vector = embed::granularity / sizeof(EmbedType); + + const int32_t token_idx = blockIdx.y; + + // It's possible our batch is padded (under CG conditions typically) + if (token_idx >= batch_desc.batch_metadata->n_tokens) return; + + TokenType token_value = input_ids[token_idx]; + + if (token_value >= vocab_size || token_value < 0) { + // TODO(cmikeh2): This is invalid, but not sure how we want to handle it being invalid + // yet. 
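+        // For now the out-of-range token is skipped and its output row is left unwritten.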
+ return; + } + + const EmbedType* embedding_row = embedding_weight + token_value * embed_dim; + EmbedType* dest_row = embedded_tokens + token_idx * embed_dim; + + const int channel_offset = (threadIdx.x + embed::threads * blockIdx.x) * T_vector; + + if (channel_offset < embed_dim) { + EmbedType reg_buf[T_vector]; + + mem_access::load_global(reg_buf, embedding_row + channel_offset); + + if (position_weight != nullptr) { + // Map the token to its global idx (indirect memory accesses aren't great but whatever) + const int32_t seq_idx = batch_desc.tokens_to_seq[token_idx]; + const InflightSeqDescriptor seq_desc = batch_desc.seq_metadata[seq_idx]; + int32_t pos_emb_idx = seq_desc.seen_tokens + (token_idx - seq_desc.start_idx); + + // Position embed offset is an OPT-specific feature I think? + pos_emb_idx = pos_emb_idx + position_embed_offset; + + // This clamping is technically + pos_emb_idx = (pos_emb_idx < 0) ? 0 : pos_emb_idx; + pos_emb_idx = (pos_emb_idx >= max_position_embed_idx) ? max_position_embed_idx + : pos_emb_idx; + + const EmbedType* position_embedding_row = position_weight + pos_emb_idx * embed_dim; + + EmbedType pos_buf[T_vector]; + mem_access::load_global(pos_buf, + position_embedding_row + channel_offset); + +#pragma unroll + for (int i = 0; i < T_vector; i++) { reg_buf[i] += pos_buf[i]; } + } + + mem_access::store_global(dest_row + channel_offset, reg_buf); + } +} + +template +void launch_ragged_embed_kernel(EmbedType* embedded_tokens, + const TokenType* input_ids, + const EmbedType* embedding_weight, + const EmbedType* position_weight, + const BatchWrapperCPP batch_desc, + const int32_t n_tokens, + const int32_t embed_dim, + const int32_t vocab_size, + const int32_t max_position_embed_idx, + const int32_t position_embed_offset, + cudaStream_t stream) +{ + constexpr int T_vector = embed::granularity / sizeof(EmbedType); + constexpr int elems_per_block = embed::threads * T_vector; + const int parallel_blocks = (embed_dim + elems_per_block - 1) / elems_per_block; + + const dim3 grid_dim(parallel_blocks, n_tokens, 1); + const dim3 block_dim(embed::threads, 1, 1); + + ragged_embed_kernel + <<>>(embedded_tokens, + input_ids, + embedding_weight, + position_weight, + batch_desc, + embed_dim, + vocab_size, + max_position_embed_idx, + position_embed_offset); +} + +#define INSTANTIATE_EMBED_FOR_TYPES(TOKEN_TYPE, EMBED_TYPE) \ + template void launch_ragged_embed_kernel( \ + EMBED_TYPE * embedded_tokens, \ + const TOKEN_TYPE* input_ids, \ + const EMBED_TYPE* embedding_weight, \ + const EMBED_TYPE* position_weight, \ + const BatchWrapperCPP batch_descriptor, \ + const int32_t n_tokens, \ + const int32_t embed_dim, \ + const int32_t vocab_size, \ + const int32_t max_position_embed_idx, \ + const int32_t position_embed_offset, \ + cudaStream_t stream); + +INSTANTIATE_EMBED_FOR_TYPES(int32_t, float) +INSTANTIATE_EMBED_FOR_TYPES(int64_t, float) + +INSTANTIATE_EMBED_FOR_TYPES(int32_t, __half) +INSTANTIATE_EMBED_FOR_TYPES(int64_t, __half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_EMBED_FOR_TYPES(int32_t, __nv_bfloat16) +INSTANTIATE_EMBED_FOR_TYPES(int64_t, __nv_bfloat16) +#endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cuh b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cuh new file mode 100644 index 000000000000..94c397439b80 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" +#include "ragged_dtypes.h" + +#ifdef BF16_AVAILABLE +#include +#endif + +template +void launch_ragged_embed_kernel(EmbedType* embedded_tokens, + const TokenType* input_ids, + const EmbedType* embedding_weight, + const EmbedType* position_weight, + const BatchWrapperCPP batch_desc, + const int32_t n_tokens, + const int32_t embed_dim, + const int32_t vocab_size, + const int32_t max_position_embed_idx, + const int32_t position_embed_offset, + cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.h b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.h new file mode 100644 index 000000000000..7897c1362669 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "embed.cuh" + +/* +Embeddings kernel aware of ragged batch structure. +*/ +void ragged_embed(torch::Tensor& embedded_tokens, + torch::Tensor& input_ids, + torch::Tensor& embedding_weight, + c10::optional& position_weight, + int32_t position_embed_offset, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.py b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.py new file mode 100644 index 000000000000..0443ce3fdd8e --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/embed/embed.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from ... import DSKernelBase +from deepspeed.ops.op_builder import RaggedOpsBuilder +from ....inference_utils import elem_size +from ....ragged import RaggedBatchWrapper + + +class RaggedEmbeddingKernel(DSKernelBase): + """ + Ragged-aware CUDA kernel implementation for an embedding lookup. This will only lookup + the necessary tokens for a padded batch (i.e. if we are CGed and running with a slightly + larger batch size than the actual tokens). + """ + + supported_dtypes = [torch.float16, torch.bfloat16, torch.float32] + supported_token_dtypes = [torch.int32, torch.int64] + + def __init__(self, embed_dtype: torch.dtype, token_dtype: torch.dtype, embed_dim: int) -> None: + """ + Args: + fp_dtype (torch.dtype): Data type of the embedding table and output dtype. + Supported values are torch.float16, torch.bfloat16, and torch.float32. + token_dtype (torch.dtype): Data type of the token ids. Supported values are + torch.int32 and torch.int64. + embed_dim (int): Embedding dimension. Must be aligned to 16 bytes. 
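+
+        An illustrative construction, assuming fp16 embeddings of width 4096 and
+        int32 token ids (the output buffer, batch wrapper, and weight tensor passed
+        at call time are pre-allocated by the caller):
+            >>> kernel = RaggedEmbeddingKernel(torch.float16, torch.int32, 4096)
+            >>> kernel(embedded_tokens, ragged_batch, embedding_weight)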
+ """ + if embed_dtype not in RaggedEmbeddingKernel.supported_dtypes: + raise ValueError("Unsupported embedding data type: {}, supported_dtypes are {}".format( + embed_dtype, RaggedEmbeddingKernel.supported_dtypes)) + + if token_dtype not in RaggedEmbeddingKernel.supported_token_dtypes: + raise ValueError("Unsupported token data type: {}, supported_dtypes are {}".format( + token_dtype, RaggedEmbeddingKernel.supported_token_dtypes)) + + if elem_size(embed_dtype) * embed_dim % 16 != 0: + raise ValueError("Embedding dimension must be aligned to 16 bytes, got {}".format(embed_dim)) + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.ragged_embed + + def __call__(self, + embedded_tokens: torch.Tensor, + ragged_wrapper: RaggedBatchWrapper, + embedding_weight: torch.Tensor, + position_embed_weight: Optional[torch.Tensor] = None, + position_embed_offset: int = 0) -> torch.Tensor: + """ + Ragged aware embedding lookup. + + Args: + embedded_tokens (torch.Tensor): Output tensor of shape [num_tokens, embed_dim] + ragged_wrapper (RaggedBatchWrapper): Wrapper for the ragged batch. + embedding_weight (torch.Tensor): Embedding table of shape [vocab_size, embed_dim] + """ + self.kernel(embedded_tokens, ragged_wrapper.input_ids(), + embedding_weight, position_embed_weight, position_embed_offset, + ragged_wrapper.batch_metadata_buffer(), ragged_wrapper.inflight_seq_descriptors(), + ragged_wrapper.tokens_to_seq(), ragged_wrapper.kv_ptrs()) + return embedded_tokens diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/__init__.py new file mode 100644 index 000000000000..0e239dd6b4c7 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .blocked_kv_rotary import * +from .blocked_trained_kv_rotary import * +from .linear_blocked_kv_copy import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp new file mode 100644 index 000000000000..8493bbf4b9af --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "blocked_kv_rotary.h" +#include "ragged_kernel_helpers.h" + +#define DISPATCH_KV_ROTARY(T_TYPE, C_TYPE) \ + if (q.options().dtype() == torch::T_TYPE) { \ + launch_kv_rotary_kernel((C_TYPE*)kv_cache.data_ptr(), \ + (C_TYPE*)q.data_ptr(), \ + (C_TYPE*)k.data_ptr(), \ + (C_TYPE*)v.data_ptr(), \ + (C_TYPE*)inv_freq_ptr, \ + batch_wrapper, \ + qkv_stride, \ + kv_cache_stride, \ + v_offset, \ + inv_freq_stride, \ + q_ratio, \ + head_size, \ + n_tokens, \ + n_q_heads, \ + at::cuda::getCurrentCUDAStream()); \ + } + +/* +Rotary position embeddings + copy into KV cache. This implementation assumes +that the inverse frequencies should be ready from global memory rather than +synthesized in the kernel. 
+ +Arguments: + kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size] + q: [n_tokens, n_q_heads * head_size] + k: [n_tokens, n_kv_heads * head_size] + v: [n_tokens, n_kv_heads * head_size] + inv_freq: [max_seq_len, head_size // 2] +*/ +void kv_trained_rotary_embeddings(torch::Tensor& kv_cache, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& inv_freq, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs) +{ + const int32_t n_tokens = q.size(0); + TORCH_CHECK(n_tokens == k.size(0)); + TORCH_CHECK(n_tokens == v.size(0)); + + // Dimensions + const int32_t block_size = kv_cache.size(1); + const int32_t n_kv_heads = kv_cache.size(3); + const int32_t head_size = kv_cache.size(4); + + // Strides + const int32_t qkv_stride = q.stride(0); // Per token + const int32_t kv_cache_stride = kv_cache.stride(1); // Per token + const int32_t v_offset = kv_cache.stride(2); // From k_cache to v_cache + const int32_t inv_freq_stride = inv_freq.stride(0); // Per token idx + + const int n_q_heads = q.size(1) / head_size; + const int q_ratio = n_q_heads / n_kv_heads; + + void* inv_freq_ptr = (void*)inv_freq.data_ptr(); + + BatchWrapperCPP batch_wrapper = make_cpp_batch_wrapper( + batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, block_size, kv_cache.size(0)); + + DISPATCH_KV_ROTARY(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_KV_ROTARY(kBFloat16, __nv_bfloat16); +#endif +} + +/* +Rotary position embeddings + copy into KV cache. This implementation assumes +that the inverse frequencies should be synthesized in the kernel. + +Arguments: + kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size] + q: [n_tokens, n_q_heads * head_size] + k: [n_tokens, n_kv_heads * head_size] + v: [n_tokens, n_kv_heads * head_size] +*/ +void kv_rotary_embeddings(torch::Tensor& kv_cache, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs) +{ + const int32_t n_tokens = q.size(0); + TORCH_CHECK(n_tokens == k.size(0)); + TORCH_CHECK(n_tokens == v.size(0)); + + // Dimensions + const int32_t block_size = kv_cache.size(1); + const int32_t n_kv_heads = kv_cache.size(3); + const int32_t head_size = kv_cache.size(4); + + // Strides + const int32_t qkv_stride = q.stride(0); // Per token + const int32_t kv_cache_stride = kv_cache.stride(1); // Per token + const int32_t v_offset = kv_cache.stride(2); // From k_cache to v_cache + const int32_t inv_freq_stride = 0; // Per token idx + + const int n_q_heads = q.size(1) / head_size; + const int q_ratio = n_q_heads / n_kv_heads; + + void* inv_freq_ptr = nullptr; + + BatchWrapperCPP batch_wrapper = make_cpp_batch_wrapper( + batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, block_size, kv_cache.size(0)); + + DISPATCH_KV_ROTARY(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_KV_ROTARY(kBFloat16, __nv_bfloat16); +#endif +} + +#define DISPATCH_KV_COPY(T_TYPE, C_TYPE) \ + if (q.options().dtype() == torch::T_TYPE) { \ + launch_kv_copy_kernel((C_TYPE*)kv_cache.data_ptr(), \ + (C_TYPE*)q.data_ptr(), \ + (C_TYPE*)k.data_ptr(), \ + (C_TYPE*)v.data_ptr(), \ + batch_wrapper, \ + qkv_stride, \ + kv_cache_stride, \ + v_offset, \ + q_ratio, \ + head_size, \ + n_tokens, \ + n_q_heads, \ + at::cuda::getCurrentCUDAStream()); \ + } + +/* +Copy into linear KV cache. 
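+Unlike the rotary variants above, no rotary transform is applied on this path;
+K and V are copied into the cache as-is and Q is not modified.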
+*/ +void linear_kv_copy(torch::Tensor& kv_cache, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs) +{ + const int32_t n_tokens = q.size(0); + TORCH_CHECK(n_tokens == k.size(0)); + TORCH_CHECK(n_tokens == v.size(0)); + + // Dimensions + const int32_t block_size = kv_cache.size(1); + const int32_t n_kv_heads = kv_cache.size(3); + const int32_t head_size = kv_cache.size(4); + + // Strides + const int32_t qkv_stride = q.stride(0); // Per token + TORCH_CHECK(qkv_stride == k.stride(0)); + TORCH_CHECK(qkv_stride == v.stride(0)); + + const int32_t kv_cache_stride = kv_cache.stride(1); // Per token + const int32_t v_offset = kv_cache.stride(2); // From k_cache to v_cache + + const int n_q_heads = q.size(1) / head_size; + + TORCH_CHECK(n_q_heads % n_kv_heads == 0); + const int q_ratio = n_q_heads / n_kv_heads; + + BatchWrapperCPP batch_wrapper = make_cpp_batch_wrapper( + batch_metadata, seq_metadata, tokens_to_seq, kv_ptrs, block_size, kv_cache.size(0)); + + DISPATCH_KV_COPY(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_KV_COPY(kBFloat16, __nv_bfloat16); +#endif +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu new file mode 100644 index 000000000000..63ea5bc88bab --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu @@ -0,0 +1,314 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "blocked_kv_rotary.cuh" +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "memory_access_utils.h" + +namespace cg = cooperative_groups; + +namespace kv_rot { + +constexpr int granularity = 16; +constexpr int threads = 256; + +} // namespace kv_rot + +/* +Supports head size 32, 64, 128, 256 +*/ + +template +__global__ void kv_rotary_pos_kernel(T* kv_cache, + T* q, + T* k, + T* v, + const T* inv_freq, + const BatchWrapperCPP batch_desc, + const int qkv_stride, + const int kv_cache_stride, + const int v_offset, + const int inv_freq_stride) +{ + // Derived constexpr + constexpr int vector_T = kv_rot::granularity / sizeof(T); + constexpr int threads_per_head = headSize / vector_T; + constexpr int half_head_size = headSize >> 1; + constexpr int tokens_per_block = kv_rot::threads / threads_per_head; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + cg::thread_block_tile head_group = + cg::tiled_partition(warp); + + // Parallelize on the head dimension for X blocks + const int head_idx = blockIdx.x; + + const int block_seq_idx = threadIdx.x / threads_per_head; + const int base_neuron_idx = (threadIdx.x * vector_T) % headSize; + const int half_idx = base_neuron_idx % half_head_size; + const int half_head_lanes = threads_per_head / 2; + + // Multiple tokens processed by the same threadblock + const int token_idx = blockIdx.y * tokens_per_block + block_seq_idx; + const bool valid_token = token_idx < batch_desc.batch_metadata->n_tokens; + const bool load_inv_freq = (inv_freq != nullptr) && valid_token; + + // If we have GQA, then only one of the Q heads needs to do rotary + copy + // for each of the heads in the group. 
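+    // e.g. with qRatio == 4, only query heads 0, 4, 8, ... copy their shared KV head.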
+ bool need_kv = head_idx % qRatio == 0; + // Make sure the following code is warp uniform + need_kv = warp.shfl(need_kv, 0); + + const int kv_head_idx = head_idx / qRatio; + + // Ensure we don't access invalid portions of the seq_metadata + const int32_t seq_id = (valid_token) ? batch_desc.tokens_to_seq[token_idx] : 0; + const InflightSeqDescriptor seq_desc = batch_desc.seq_metadata[seq_id]; + // This will give an invalid index if valid_token is false, but should never affect memory. + const int32_t global_token_idx = seq_desc.seen_tokens + (token_idx - seq_desc.start_idx); + + T* q_row = q + token_idx * qkv_stride + head_idx * headSize; + T q_reg[vector_T]; + + if (need_kv) { + // The following logic assumes a linearly blocked KV cache. This means that no sparsity has + // been introduced into cache history. + const KVCacheDescriptor kv_desc = batch_desc.kv_desc; + const int32_t seq_kv_block_idx = global_token_idx / kv_desc.block_size; + const int32_t mapped_kv_block_idx = + (valid_token) ? kv_desc.block_lists[seq_id][seq_kv_block_idx] : 0; + + const int32_t kv_block_offset = global_token_idx % kv_desc.block_size; + const int32_t kv_offset = + (mapped_kv_block_idx * kv_desc.block_size + kv_block_offset) * kv_cache_stride + + kv_head_idx * headSize; + + // Load indices from QKV output + T* k_row = k + token_idx * qkv_stride + kv_head_idx * headSize; + T* v_row = v + token_idx * qkv_stride + kv_head_idx * headSize; + + T k_reg[vector_T], v_reg[vector_T], inv_freq_reg[vector_T]; + + mem_access::load_global(q_reg, q_row + base_neuron_idx, valid_token); + mem_access::load_global(k_reg, k_row + base_neuron_idx, valid_token); + mem_access::load_global(v_reg, v_row + base_neuron_idx, valid_token); + mem_access::load_global( + inv_freq_reg, inv_freq + half_idx, load_inv_freq); + + if constexpr (doRotary) { +#pragma unroll + for (int i = 0; i < vector_T; i++) { + const int head_neuron_idx = base_neuron_idx + i; + + float inv_freq_flt; + if (inv_freq != nullptr) { + inv_freq_flt = conversion::to(inv_freq_reg[i]) * (float)global_token_idx; + } else { + inv_freq_flt = + (float)((head_neuron_idx % half_head_size) * 2) / (float)headSize; + // Conversion to T and back means that both branches of this if statement + // will produce the same results if using the same algo for producing the + // freqs. + T trunc_freq = conversion::to(1.0 / powf(10000.0, inv_freq_flt)); + inv_freq_flt = conversion::to(trunc_freq) * (float)global_token_idx; + } + + float rotary_sign = (head_neuron_idx >= half_head_size) ? 
-1.0f : 1.0f; + float q_f = conversion::to(q_reg[i]); + float k_f = conversion::to(k_reg[i]); + float q_rot = q_f * rotary_sign; + float k_rot = k_f * rotary_sign; + + const float q_rot_temp = head_group.shfl_xor(q_rot, half_head_lanes); + const float k_rot_temp = head_group.shfl_xor(k_rot, half_head_lanes); + + q_reg[i] = + conversion::to(q_f * cosf(inv_freq_flt) + q_rot_temp * sinf(inv_freq_flt)); + k_reg[i] = + conversion::to(k_f * cosf(inv_freq_flt) + k_rot_temp * sinf(inv_freq_flt)); + } + } + + if (valid_token) { + mem_access::store_global(kv_cache + kv_offset + base_neuron_idx, + k_reg); + mem_access::store_global( + kv_cache + kv_offset + base_neuron_idx + v_offset, v_reg); + } + } else { + T inv_freq_reg[vector_T]; + + mem_access::load_global(q_reg, q_row + base_neuron_idx, valid_token); + mem_access::load_global( + inv_freq_reg, inv_freq + half_idx, load_inv_freq); + + if constexpr (doRotary) { +#pragma unroll + for (int i = 0; i < vector_T; i++) { + const int head_neuron_idx = base_neuron_idx + i; + + float inv_freq_flt; + if (inv_freq != nullptr) { + inv_freq_flt = conversion::to(inv_freq_reg[i]) * (float)global_token_idx; + } else { + inv_freq_flt = + (float)((head_neuron_idx % half_head_size) * 2) / (float)headSize; + inv_freq_flt = 1.0 / powf(10000.0, inv_freq_flt) * (float)global_token_idx; + } + + float rotary_sign = (head_neuron_idx >= half_head_size) ? -1.0f : 1.0f; + float q_f = conversion::to(q_reg[i]); + float q_rot = q_f * rotary_sign; + + const float q_rot_temp = head_group.shfl_xor(q_rot, half_head_lanes); + + q_reg[i] = + conversion::to(q_f * cosf(inv_freq_flt) + q_rot_temp * sinf(inv_freq_flt)); + } + } + } + + if (valid_token && doRotary) { + mem_access::store_global(q_row + base_neuron_idx, q_reg); + } +} + +#define DISPATCH_KV_ROTARY_IMPL(Q_RATIO, HEAD_SIZE) \ + if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ + kv_rotary_pos_kernel \ + <<>>(kv_cache, \ + q, \ + k, \ + v, \ + inv_freq, \ + batch_desc, \ + qkv_stride, \ + kv_cache_stride, \ + v_offset, \ + inv_freq_stride); + +template +void launch_kv_rotary_kernel(T* kv_cache, + T* q, + T* k, + T* v, + T* inv_freq, + const BatchWrapperCPP batch_desc, + const int qkv_stride, + const int kv_cache_stride, + const int v_offset, + const int inv_freq_stride, + const int q_ratio, + const int head_size, + const int n_tokens, + const int n_q_heads, + cudaStream_t stream) +{ + constexpr int vector_T = kv_rot::granularity / sizeof(T); + const int threads_per_head = head_size / vector_T; + const int tokens_per_block = kv_rot::threads / threads_per_head; + + const dim3 block(kv_rot::threads); + const int token_blocks = (n_tokens + tokens_per_block - 1) / tokens_per_block; + const dim3 grid(n_q_heads, token_blocks); + + DISPATCH_KV_ROTARY_IMPL(1, 64) + DISPATCH_KV_ROTARY_IMPL(1, 128) + DISPATCH_KV_ROTARY_IMPL(2, 64) + DISPATCH_KV_ROTARY_IMPL(2, 128) + DISPATCH_KV_ROTARY_IMPL(4, 64) + DISPATCH_KV_ROTARY_IMPL(4, 128) + DISPATCH_KV_ROTARY_IMPL(5, 64) + DISPATCH_KV_ROTARY_IMPL(5, 128) + DISPATCH_KV_ROTARY_IMPL(8, 64) + DISPATCH_KV_ROTARY_IMPL(8, 128) +} + +#define INSTANTIATE_KV_ROTARY_KERNEL(TYPE) \ + template void launch_kv_rotary_kernel(TYPE * kv_cache, \ + TYPE * q, \ + TYPE * k, \ + TYPE * v, \ + TYPE * inv_freq, \ + const BatchWrapperCPP batch_desc, \ + const int qkv_stride, \ + const int kv_cache_stride, \ + const int v_offset, \ + const int inv_freq_stride, \ + const int q_ratio, \ + const int head_size, \ + const int n_tokens, \ + const int n_q_heads, \ + cudaStream_t stream); + +INSTANTIATE_KV_ROTARY_KERNEL(__half) 
+ +#ifdef BF16_AVAILABLE +INSTANTIATE_KV_ROTARY_KERNEL(__nv_bfloat16) +#endif + +#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE) \ + if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ + kv_rotary_pos_kernel<<>>( \ + kv_cache, q, k, v, nullptr, batch_desc, qkv_stride, kv_cache_stride, v_offset, 0); + +template +void launch_kv_copy_kernel(T* kv_cache, + T* q, + T* k, + T* v, + const BatchWrapperCPP batch_desc, + const int qkv_stride, + const int kv_cache_stride, + const int v_offset, + const int q_ratio, + const int head_size, + const int n_tokens, + const int n_q_heads, + cudaStream_t stream) +{ + constexpr int vector_T = kv_rot::granularity / sizeof(T); + const int threads_per_head = head_size / vector_T; + const int tokens_per_block = kv_rot::threads / threads_per_head; + + const dim3 block(kv_rot::threads); + const int token_blocks = (n_tokens + tokens_per_block - 1) / tokens_per_block; + const dim3 grid(n_q_heads, token_blocks); + + DISPATCH_KV_COPY_IMPL(1, 64) + DISPATCH_KV_COPY_IMPL(1, 128) + DISPATCH_KV_COPY_IMPL(2, 64) + DISPATCH_KV_COPY_IMPL(2, 128) + DISPATCH_KV_COPY_IMPL(4, 64) + DISPATCH_KV_COPY_IMPL(4, 128) + DISPATCH_KV_COPY_IMPL(5, 64) + DISPATCH_KV_COPY_IMPL(5, 128) + DISPATCH_KV_COPY_IMPL(8, 64) + DISPATCH_KV_COPY_IMPL(8, 128) +} + +#define INSTANTIATE_KV_COPY_KERNEL(TYPE) \ + template void launch_kv_copy_kernel(TYPE * kv_cache, \ + TYPE * q, \ + TYPE * k, \ + TYPE * v, \ + const BatchWrapperCPP batch_desc, \ + const int qkv_stride, \ + const int kv_cache_stride, \ + const int v_offset, \ + const int q_ratio, \ + const int head_size, \ + const int n_tokens, \ + const int n_q_heads, \ + cudaStream_t stream); + +INSTANTIATE_KV_COPY_KERNEL(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_KV_COPY_KERNEL(__nv_bfloat16) +#endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh new file mode 100644 index 000000000000..be38ff30c46c --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" +#include "ragged_dtypes.h" + +#ifdef BF16_AVAILABLE +#include +#endif + +template +void launch_kv_rotary_kernel(T* kv_cache, + T* q, + T* k, + T* v, + T* inv_freq, + const BatchWrapperCPP batch_desc, + const int qkv_stride, + const int kv_cache_stride, + const int v_offset, + const int inv_freq_stride, + const int q_ratio, + const int head_size, + const int n_tokens, + const int n_q_heads, + cudaStream_t stream); + +template +void launch_kv_copy_kernel(T* kv_cache, + T* q, + T* k, + T* v, + const BatchWrapperCPP batch_desc, + const int qkv_stride, + const int kv_cache_stride, + const int v_offset, + const int q_ratio, + const int head_size, + const int n_tokens, + const int n_q_heads, + cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h new file mode 100644 index 000000000000..0615825c0a21 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "blocked_kv_rotary.cuh" + +/* +Rotary position embeddings + copy into KV cache. This implementation assumes +that the inverse frequencies should be ready from global memory rather than +synthesized in the kernel. + +Arguments: + kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size] + q: [n_tokens, n_q_heads * head_size] + k: [n_tokens, n_kv_heads * head_size] + v: [n_tokens, n_kv_heads * head_size] + inv_freq: [max_seq_len, head_size // 2] +*/ +void kv_trained_rotary_embeddings(torch::Tensor& kv_cache, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& inv_freq, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs); + +/* +Rotary position embeddings + copy into KV cache. This implementation assumes +that the inverse frequencies should be synthesized in the kernel. + +Arguments: + kv_cache: [n_blocks, block_size, 2, n_kv_heads, head_size] + q: [n_tokens, n_q_heads * head_size] + k: [n_tokens, n_kv_heads * head_size] + v: [n_tokens, n_kv_heads * head_size] +*/ +void kv_rotary_embeddings(torch::Tensor& kv_cache, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs); + +/* +Copy into linear KV cache. +*/ +void linear_kv_copy(torch::Tensor& kv_cache, + torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_ptrs); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py new file mode 100644 index 000000000000..630d58d90a23 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ....inference_utils import DtypeEnum +from deepspeed.ops.op_builder import RaggedOpsBuilder +from ....ragged import RaggedBatchWrapper +from ... import DSKernelBase + + +class BlockedRotaryEmbeddings(DSKernelBase): + """ + CUDA Kernel implementation that will perform rotary position embeddings on the queries and keys + before copying into a blocked KV cache. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_head_sizes = [64, 128] + supported_q_ratios = [1, 2, 4, 5, 8] + + def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None: + """ + Args: + head_size: The size of the attention head. + q_ratio: Ratio of q heads to kv heads (for GQA) + dtype: Data type for the input/output. Supported values are torch.float16 and torch.bfloat16. 
+ """ + + q_ratio = n_q_heads // n_kv_heads + + if head_size not in BlockedRotaryEmbeddings.supported_head_sizes: + raise ValueError("Unsupported head size: {}, supported_head_sizes are {}".format( + head_size, BlockedRotaryEmbeddings.supported_head_sizes)) + + if q_ratio not in BlockedRotaryEmbeddings.supported_q_ratios: + raise ValueError("Unsupported q_ratio: {}, supported_q_ratios are {}".format( + q_ratio, BlockedRotaryEmbeddings.supported_q_ratios)) + + if not isinstance(dtype, DtypeEnum): + dtype = DtypeEnum(dtype) + + if dtype not in BlockedRotaryEmbeddings.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + dtype, BlockedRotaryEmbeddings.supported_dtypes)) + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.kv_rotary_embeddings + self.head_size = head_size + self.n_q_heads = n_q_heads + self.n_kv_heads = n_kv_heads + + def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper) -> None: + """ + Perform rotary embeddings on the queries and keys before copying into a blocked KV cache. + + Args: + kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size] + qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)] + ragged_batch: Wrapper for the ragged batch. + """ + + q = qkv[:, :self.head_size * self.n_q_heads] + k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)] + v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):] + + self.kernel(kv_cache, q, k, v, ragged_batch.batch_metadata_buffer(), ragged_batch.inflight_seq_descriptors(), + ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs()) diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py new file mode 100644 index 000000000000..59da1db0f5d6 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_trained_kv_rotary.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ....inference_utils import DtypeEnum +from deepspeed.ops.op_builder import RaggedOpsBuilder +from ....ragged import RaggedBatchWrapper +from ... import DSKernelBase + + +class BlockedTrainedRotaryEmbeddings(DSKernelBase): + """ + CUDA Kernel implementation that will perform rotary position embeddings on the queries and keys + before copying into a blocked KV cache. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_head_sizes = [64, 128] + supported_q_ratios = [1, 2, 4, 5, 8] + + def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None: + """ + Args: + head_size: The size of the attention head. + dtype: Data type for the input/output. Supported values are torch.float16 and torch.bfloat16. 
+ """ + + q_ratio = n_q_heads // n_kv_heads + + if head_size not in BlockedTrainedRotaryEmbeddings.supported_head_sizes: + raise ValueError("Unsupported head size: {}, supported_head_sizes are {}".format( + head_size, BlockedTrainedRotaryEmbeddings.supported_head_sizes)) + + if q_ratio not in BlockedTrainedRotaryEmbeddings.supported_q_ratios: + raise ValueError("Unsupported q_ratio: {}, supported_q_ratios are {}".format( + q_ratio, BlockedTrainedRotaryEmbeddings.supported_q_ratios)) + + if not isinstance(dtype, DtypeEnum): + dtype = DtypeEnum(dtype) + + if dtype not in BlockedTrainedRotaryEmbeddings.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + dtype, BlockedTrainedRotaryEmbeddings.supported_dtypes)) + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.kv_trained_rotary_embeddings + self.head_size = head_size + self.n_q_heads = n_q_heads + self.n_kv_heads = n_kv_heads + + def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper, + inverse_freqs: torch.Tensor) -> None: + """ + Perform rotary embeddings on the queries and keys before copying into a blocked KV cache. + + Args: + kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size] + qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)] + ragged_batch: Wrapper for the ragged batch. + inverse_freqs: Inverse frequencies for the rotary embeddings. Shape [max_seq_len, head_size // 2] + """ + + q = qkv[:, :self.head_size * self.n_q_heads] + k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)] + v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):] + + self.kernel(kv_cache, q, k, v, inverse_freqs, ragged_batch.batch_metadata_buffer(), + ragged_batch.inflight_seq_descriptors(), ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs()) diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py new file mode 100644 index 000000000000..c9f6ffd37b3e --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/linear_blocked_kv_copy.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ....inference_utils import DtypeEnum +from ....ragged import RaggedBatchWrapper +from deepspeed.ops.op_builder import RaggedOpsBuilder +from ... import DSKernelBase + + +class LinearBlockedKVCopy(DSKernelBase): + """ + CUDA Kernel implementation that will perform rotary position embeddings on the queries and keys + before copying into a blocked KV cache. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + supported_head_sizes = [64, 128] + supported_q_ratios = [1, 2, 4, 5, 8] + + def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None: + """ + Args: + head_size: The size of the attention head. + dtype: Data type for the input/output. Supported values are torch.float16 and torch.bfloat16. 
+ """ + + q_ratio = n_q_heads // n_kv_heads + + if head_size not in LinearBlockedKVCopy.supported_head_sizes: + raise ValueError("Unsupported head size: {}, supported_head_sizes are {}".format( + head_size, LinearBlockedKVCopy.supported_head_sizes)) + + if q_ratio not in LinearBlockedKVCopy.supported_q_ratios: + raise ValueError("Unsupported q_ratio: {}, supported_q_ratios are {}".format( + q_ratio, LinearBlockedKVCopy.supported_q_ratios)) + + if not isinstance(dtype, DtypeEnum): + dtype = DtypeEnum(dtype) + + if dtype not in LinearBlockedKVCopy.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + dtype, LinearBlockedKVCopy.supported_dtypes)) + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.linear_kv_copy + self.head_size = head_size + self.n_q_heads = n_q_heads + self.n_kv_heads = n_kv_heads + + def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper) -> None: + """ + Perform rotary embeddings on the queries and keys before copying into a blocked KV cache. + + Args: + kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size] + qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)] + ragged_batch: Wrapper for the ragged batch. + """ + + q = qkv[:, :self.head_size * self.n_q_heads] + k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)] + v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):] + + self.kernel(kv_cache, q, k, v, ragged_batch.batch_metadata_buffer(), ragged_batch.inflight_seq_descriptors(), + ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs()) diff --git a/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/__init__.py new file mode 100644 index 000000000000..72103a0d82a1 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .logits_gather import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cpp b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cpp new file mode 100644 index 000000000000..1a7e7c0a2167 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cpp @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "logits_gather.h" + +#define DISPATCH_TO_LOGITS_GATHER(T_TYPE, C_TYPE) \ + if (all_acts.options().dtype() == torch::T_TYPE) { \ + launch_logits_gather((C_TYPE*)final_token_acts.data_ptr(), \ + (const C_TYPE*)all_acts.data_ptr(), \ + batch_metadata_raw, \ + seq_metadata_raw, \ + n_seqs, \ + embed_dim, \ + at::cuda::getCurrentCUDAStream()); \ + } + +/* +Logits gather will parse the ragged batch data structure and gather only the logits that +will be used for token sampling. 
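+For example, if the in-flight sequences are packed contiguously and have 5, 1, and 2
+tokens respectively, rows 4, 5, and 7 of all_acts are gathered into rows 0, 1, and 2 of
+final_token_acts.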
+*/ +void gather_for_logits(torch::Tensor& final_token_acts, + torch::Tensor& all_acts, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata) +{ + const RaggedBatchDescriptor* batch_metadata_raw = + reinterpret_cast(batch_metadata.data_ptr()); + + const InflightSeqDescriptor* seq_metadata_raw = + reinterpret_cast(seq_metadata.data_ptr()); + + const int n_seqs = final_token_acts.size(0); + const int embed_dim = final_token_acts.size(1); + + TORCH_CHECK(all_acts.scalar_type() == final_token_acts.scalar_type(), + "all_acts and final_token_acts must have the same scalar type"); + + DISPATCH_TO_LOGITS_GATHER(kFloat, float) + DISPATCH_TO_LOGITS_GATHER(kHalf, half) +#ifdef BF16_AVAILABLE + DISPATCH_TO_LOGITS_GATHER(kBFloat16, __nv_bfloat16) +#endif +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cu b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cu new file mode 100644 index 000000000000..a539888ff904 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cu @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "ds_kernel_utils.h" +#include "logits_gather.cuh" +#include "memory_access_utils.h" +#include "ragged_dtypes.h" + +namespace logits_gather { + +constexpr int granularity = 16; +constexpr int threads = 512; + +} // namespace logits_gather + +template +__global__ void logits_gather_kernel(T* final_token_acts, + const T* token_acts, + const RaggedBatchDescriptor* ragged_batch, + const InflightSeqDescriptor* inflight_batch, + const int32_t embed_dim) +{ + constexpr int T_vector = logits_gather::granularity / sizeof(T); + + const int32_t seq_id = blockIdx.y; + + // It's possible we've padded the output Tensor (under CG conditions) + if (seq_id >= ragged_batch->n_sequences) return; + + const InflightSeqDescriptor seq = inflight_batch[seq_id]; + const int final_token_idx = seq.start_idx + seq.n_tokens - 1; + + const int token_offset = final_token_idx * embed_dim; + const int thread_offset = + threadIdx.x * T_vector + blockIdx.x * logits_gather::threads * T_vector; + + const int final_token_offset = seq_id * embed_dim; + + T reg_buf[T_vector]; + + if (thread_offset < embed_dim) { + mem_access::load_global( + reg_buf, token_acts + token_offset + thread_offset); + + mem_access::store_global( + final_token_acts + final_token_offset + thread_offset, reg_buf); + } +} + +template +void launch_logits_gather(T* final_token_acts, + const T* all_acts, + const RaggedBatchDescriptor* ragged_batch, + const InflightSeqDescriptor* inflight_batch, + const int32_t n_seqs, + const int32_t embed_dim, + cudaStream_t stream) +{ + constexpr int T_vector = logits_gather::granularity / sizeof(T); + constexpr int elems_per_block = logits_gather::threads * T_vector; + const int parallel_blocks = (embed_dim + elems_per_block - 1) / elems_per_block; + + const dim3 grid(parallel_blocks, n_seqs, 1); + const dim3 block(logits_gather::threads, 1, 1); + + logits_gather_kernel<<>>( + final_token_acts, all_acts, ragged_batch, inflight_batch, embed_dim); +} + +#define INSTANTIATE_FOR_TYPE(T) \ + template void launch_logits_gather(T * final_token_acts, \ + const T* all_acts, \ + const RaggedBatchDescriptor* ragged_batch, \ + const InflightSeqDescriptor* inflight_batch, \ + const int32_t n_seqs, \ + const int32_t embed_dim, \ + cudaStream_t stream); + +INSTANTIATE_FOR_TYPE(float) +INSTANTIATE_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE 
+INSTANTIATE_FOR_TYPE(__nv_bfloat16) +#endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cuh b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cuh new file mode 100644 index 000000000000..c4e84c05e6d8 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cuh @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" +#include "ragged_dtypes.h" + +#ifdef BF16_AVAILABLE +#include +#endif + +template +void launch_logits_gather(T* final_token_acts, + const T* all_acts, + const RaggedBatchDescriptor* batch_metadata, + const InflightSeqDescriptor* seq_metadata, + const int32_t n_seqs, + const int32_t embed_dim, + cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.h b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.h new file mode 100644 index 000000000000..73a855984daa --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "logits_gather.cuh" +#include "ragged_dtypes.h" + +/* +Logits gather will parse the ragged batch data structure and gather only the logits that +will be used for token sampling. +*/ +void gather_for_logits(torch::Tensor& final_token_acts, + torch::Tensor& all_acts, + torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.py b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.py new file mode 100644 index 000000000000..64b453e9e9e3 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/logits_gather/logits_gather.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... import DSKernelBase +from deepspeed.ops.op_builder import RaggedOpsBuilder +from ....inference_utils import elem_size +from ....ragged import RaggedBatchWrapper + + +class RaggedLogitsGather(DSKernelBase): + """ + CUDA Kernel implementation for gather the hidden states of the final token + of each sequence. This is used to reduce the cost of the performing the unembedding. + """ + + supported_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def __init__(self, model_dim: int, fp_dtype: torch.dtype): + """ + Parameters: + fp_dtype (torch.dtype): Data type for the input/output. Supported values + are torch.float16, torch.bfloat16, and torch.float32. + """ + if fp_dtype not in RaggedLogitsGather.supported_dtypes: + raise ValueError("Unsupported data type: {}, supported_dtypes are {}".format( + fp_dtype, RaggedLogitsGather.supported_dtypes)) + + if elem_size(fp_dtype) * model_dim % 16 != 0: + raise ValueError("Embedding dimension must be aligned to 16 bytes, got {}".format(model_dim)) + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.gather_for_logits + + def __call__(self, final_token_activations: torch.Tensor, all_activations: torch.Tensor, + ragged_wrapper: RaggedBatchWrapper) -> torch.Tensor: + """ + Gather the hidden states of the final token of each sequence from `all_activations` into + `final_token_activations`. 
+ + Args: + final_token_activations (torch.Tensor): Output tensor of shape [num_seqs, model_dim] + all_activations (torch.Tensor): Input tensor of shape [num_tokens, model_dim] + ragged_wrapper (RaggedBatchWrapper): Wrapper for the ragged batch. + """ + + self.kernel(final_token_activations, all_activations, ragged_wrapper.batch_metadata_buffer(), + ragged_wrapper.inflight_seq_descriptors()) + return final_token_activations diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/__init__.py new file mode 100644 index 000000000000..096c0d984a5a --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .moe_gather import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp new file mode 100644 index 000000000000..e55e1f48c125 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "moe_gather.h" +#include + +#define DISPATCH_MOE_GATHER(T_TYPE, C_TYPE) \ + if (layer_output.options().dtype() == torch::T_TYPE) { \ + launch_moe_gather((C_TYPE*)layer_output.data_ptr(), \ + (const C_TYPE*)moe_output.data_ptr(), \ + (const float*)scores.data_ptr(), \ + (const int32_t*)mapped_slots.data_ptr(), \ + (int32_t*)expert_count.data_ptr(), \ + n_channels, \ + n_experts, \ + n_tokens, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +/* +Re-gather the outputs of MoE and scale them by the gating score. +*/ +void moe_gather(torch::Tensor& layer_output, + const torch::Tensor& moe_output, + const torch::Tensor& scores, + const torch::Tensor& mapped_slots, + const torch::Tensor& expert_count) +{ + const int32_t n_channels = layer_output.size(1); + const int32_t n_experts = expert_count.size(0); + const int32_t n_tokens = layer_output.size(0); + + TORCH_CHECK(moe_output.size(0) == n_tokens); + TORCH_CHECK(moe_output.size(1) == n_channels); + TORCH_CHECK(scores.size(0) == n_tokens); + TORCH_CHECK(mapped_slots.size(0) == n_tokens); + + TORCH_CHECK(layer_output.scalar_type() == moe_output.scalar_type()); + TORCH_CHECK(scores.scalar_type() == torch::kFloat32); + TORCH_CHECK(mapped_slots.scalar_type() == torch::kInt32); + TORCH_CHECK(expert_count.scalar_type() == torch::kInt32); + + DISPATCH_MOE_GATHER(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_MOE_GATHER(kBFloat16, __nv_bfloat16); +#endif + + TORCH_CHECK(false, "Unsupported data type for MoE gather"); +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu new file mode 100644 index 000000000000..c2fae24f5080 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "ds_kernel_utils.h" +#include "moe_gather.cuh" +#include "reduction_utils.h" +#include "top_1_gating.cuh" + +namespace gather { + +constexpr int access_granularity = 16; +constexpr int threads = 256; + +} // namespace gather + +template +__global__ void moe_gather_kernel(T* layer_output, + const T* moe_output, + const float* scores, + const int32_t* mapped_slots, + int32_t* expert_counts, + const int32_t n_channels, + const int32_t n_experts) +{ + constexpr int32_t vector_size = gather::access_granularity / sizeof(T); + constexpr int32_t stride = vector_size * gather::threads; + + const int32_t token_idx = blockIdx.x; + const int32_t mapped_slot = mapped_slots[token_idx]; + + if (token_idx == 0) { + // Reset expert counts for its next use. + if (threadIdx.x < n_experts) { expert_counts[threadIdx.x] = 0; } + } + + if (mapped_slot == gating::unassigned) { + // This token was not assigned. + // TODO(cmikeh2): It's possible we want different behavior here moving forward. + return; + } + + const float score = scores[token_idx]; + const int32_t channel_offset = threadIdx.x * vector_size; + + const T* moe_output_base = moe_output + mapped_slot * n_channels + channel_offset; + T* layer_output_base = layer_output + token_idx * n_channels + channel_offset; + +#pragma unroll + for (int i = 0; i < copyUnroll; i++) { + T reg_buffer[vector_size]; + + if (i * stride + channel_offset < n_channels) { + mem_access::load_global(reg_buffer, + moe_output_base + i * stride); + +#pragma unroll + for (int j = 0; j < vector_size; j++) { + // There are accuracy implications of downcasting the score to a 16-bit + // data type, so we up-convert the input to 32-bit, multiply, and then + // down-convert back to 16-bit. 
+ float up_cast = conversion::to(reg_buffer[j]); + reg_buffer[j] = conversion::to(up_cast * score); + } + + mem_access::store_global(layer_output_base + i * stride, + reg_buffer); + } + } +} + +#define LAUNCH_FOR_UNROLL(COUNT) \ + case COUNT: \ + moe_gather_kernel<<>>( \ + layer_output, moe_output, scores, mapped_slots, expert_counts, n_channels, n_experts); \ + break; + +template +void launch_moe_gather(T* layer_output, + const T* moe_output, + const float* scores, + const int32_t* mapped_slots, + int32_t* expert_counts, + const int32_t n_channels, + const int32_t n_experts, + const int32_t n_tokens, + cudaStream_t stream) +{ + constexpr int vals_per_unroll = gather::threads * gather::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block(gather::threads); + const dim3 grid(n_tokens); + + switch (copy_unroll) { + LAUNCH_FOR_UNROLL(1) + LAUNCH_FOR_UNROLL(2) + LAUNCH_FOR_UNROLL(3) + LAUNCH_FOR_UNROLL(4) + LAUNCH_FOR_UNROLL(5) + LAUNCH_FOR_UNROLL(6) + } +} + +#define INSTANTIATE_GATHER_FOR_TYPE(TYPE) \ + template void launch_moe_gather(TYPE * layer_output, \ + const TYPE* moe_output, \ + const float* scores, \ + const int32_t* mapped_slots, \ + int32_t* expert_counts, \ + const int32_t n_channels, \ + const int32_t n_experts, \ + const int32_t n_tokens, \ + cudaStream_t stream); + +INSTANTIATE_GATHER_FOR_TYPE(__half) + +#ifdef BF16_AVAILABLE +INSTANTIATE_GATHER_FOR_TYPE(__nv_bfloat16) +#endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh new file mode 100644 index 000000000000..f98a727ead58 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" +#include "ragged_dtypes.h" + +template +void launch_moe_gather(T* layer_output, + const T* moe_output, + const float* scores, + const int32_t* mapped_slots, + int32_t* expert_counts, + const int32_t n_channels, + const int32_t n_experts, + const int32_t n_tokens, + cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h new file mode 100644 index 000000000000..7ffe9f8b4dc6 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "moe_gather.cuh" + +/* +Re-gather the outputs of MoE and scale them by the gating score. +*/ +void moe_gather(torch::Tensor& layer_output, + const torch::Tensor& moe_output, + const torch::Tensor& scores, + const torch::Tensor& mapped_slots, + const torch::Tensor& expert_counts); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py new file mode 100644 index 000000000000..c37683d03fbe --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ... 
import DSKernelBase +from ....inference_utils import DtypeEnum +from deepspeed.ops.op_builder import RaggedOpsBuilder + + +class MoEGather(DSKernelBase): + """ + CUDA implementation of MoE gather. This will bring the tokens back + to their original indices and perform the output scaling. + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + + def __init__(self, dtype: DtypeEnum, channels: int) -> None: + + if not isinstance(dtype, DtypeEnum): + dtype = DtypeEnum(dtype) + + if dtype not in MoEGather.supported_dtypes: + raise RuntimeError(f"Unsupported dtype {dtype}") + + if channels % 8 != 0: + raise RuntimeError(f"Channels {channels} must be divisible by 8") + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.moe_gather + + def __call__(self, layer_output: torch.Tensor, moe_output: torch.Tensor, scores: torch.Tensor, + mapped_slots: torch.Tensor, expert_counts: torch.Tensor) -> torch.Tensor: + """ + Reorders the moe_output tokens into their original order and scales them by their + gating scale. This will be a no-op for padded tokens. + + Arguments: + layer_output (torch.Tensor): The output of the layer of shape [n_tokens, hidden_size]. This has been scaled appropriately. + moe_output (torch.Tensor): The output of the MoE of shape [n_tokens, hidden_size]. + scores (torch.Tensor): The gating scores of shape [n_tokens]. + mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens]. The index of token ``i`` in layer_output is ``mapped_slots[i]``. + expert_counts (torch.Tensor): The number of tokens assigned to each expert of shape [n_experts]. This is passed to fuse the clearing of this data structure into the gather. + + Returns: + layer_output + """ + self.kernel(layer_output, moe_output, scores, mapped_slots, expert_counts) + return layer_output diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/__init__.py new file mode 100644 index 000000000000..a7ca91fe5363 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .moe_scatter import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp new file mode 100644 index 000000000000..902f1cc0ea15 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "moe_scatter.h" +#include + +#define DISPATCH_MOE_SCATTER(T_TYPE, C_TYPE) \ + if (activations.options().dtype() == torch::T_TYPE) { \ + launch_moe_scatter((C_TYPE*)moe_input.data_ptr(), \ + (int64_t*)expert_count_cumsums.data_ptr(), \ + (int32_t*)mapped_slots.data_ptr(), \ + (const C_TYPE*)activations.data_ptr(), \ + (const int32_t*)expert_counts.data_ptr(), \ + (const int32_t*)assignments.data_ptr(), \ + (const int32_t*)offsets.data_ptr(), \ + n_channels, \ + n_tokens, \ + n_experts, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +/* +Performs a cumsum on the expert counts and copies the hidden states to the +appropriate spot to ensure that each experts inputs are contiguous. 
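+For example, with expert_counts = [2, 1, 3], tokens assigned to expert 0 are written to
+rows 0-1 of moe_input, the token assigned to expert 1 goes to row 2, and tokens assigned
+to expert 2 occupy rows 3-5.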
+*/ +void moe_scatter(torch::Tensor& moe_input, + torch::Tensor& expert_count_cumsums, + torch::Tensor& mapped_slots, + torch::Tensor& activations, + torch::Tensor& expert_counts, + torch::Tensor& assignments, + torch::Tensor& offsets) +{ + const int32_t n_tokens = activations.size(0); + const int32_t n_channels = activations.size(1); + + // Should have a lot of matching buffer sizes here. + TORCH_CHECK(n_tokens == moe_input.size(0)); + TORCH_CHECK(n_tokens == assignments.size(0)); + TORCH_CHECK(n_tokens == offsets.size(0)); + TORCH_CHECK(n_channels == moe_input.size(1)); + + const int32_t n_experts = expert_count_cumsums.size(0); + + TORCH_CHECK(moe_input.scalar_type() == activations.scalar_type()); + TORCH_CHECK(expert_count_cumsums.scalar_type() == torch::kInt64); + TORCH_CHECK(mapped_slots.scalar_type() == torch::kInt32); + TORCH_CHECK(expert_counts.scalar_type() == torch::kInt32); + TORCH_CHECK(assignments.scalar_type() == torch::kInt32); + TORCH_CHECK(offsets.scalar_type() == torch::kInt32); + + DISPATCH_MOE_SCATTER(kHalf, __half); + +#ifdef BF16_AVAILABLE + DISPATCH_MOE_SCATTER(kBFloat16, __nv_bfloat16); +#endif + + TORCH_CHECK(false, "Unsupported dtype for moe_scatter") +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu new file mode 100644 index 000000000000..0746cd7be645 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu @@ -0,0 +1,208 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "ds_kernel_utils.h" +#include "moe_scatter.cuh" +#include "reduction_utils.h" +#include "top_1_gating.cuh" + +using ROp = reduce::ROpType; + +namespace scatter { + +constexpr int access_granularity = 16; +constexpr int threads = 256; +constexpr int warps = threads / hw_warp_size; + +} // namespace scatter + +template +__global__ void moe_scatter_kernel(T* moe_input, + int64_t* expert_count_cumsums, + int32_t* mapped_slots, + const T* activations, + const int32_t* assignments, + const int32_t* expert_counts, + const int32_t* offsets, + const int32_t n_channels, + const int32_t n_experts) +{ + constexpr int32_t vector_size = scatter::access_granularity / sizeof(T); + constexpr int32_t load_stride = vector_size * scatter::threads; + + const int32_t token_idx = blockIdx.x; + const int32_t tidx = threadIdx.x; + const int32_t warp_rank = tidx / hw_warp_size; + + // Bank aligned and sufficient + __shared__ int32_t red_buffer[32]; + __shared__ int32_t token_0_row; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + const int assigned_expert = assignments[token_idx]; + + // For the different codepaths, we'll converge on this variable for doing + // the token copy. + int32_t token_base_row; + + if (token_idx == 0) { + // Token 0 will perform a cumsum on the data + int32_t expert_vals; + if (tidx < n_experts) { + expert_vals = expert_counts[tidx]; + } else { + expert_vals = 0; + } + +#pragma unroll + for (int i = 1; i < hw_warp_size; i *= 2) { + int32_t maybe_add = warp.shfl_up(expert_vals, i); + expert_vals = (warp.thread_rank() < i) ? 
expert_vals : expert_vals + maybe_add; + } + + if (warp.thread_rank() == hw_warp_size - 1) { + mem_access::store_shared<4>(red_buffer + warp_rank, &expert_vals); + } + + tb.sync(); + + int32_t phase_2_val = 0; + if (warp.thread_rank() < scatter::warps) { + mem_access::load_shared<4>(&phase_2_val, red_buffer + warp.thread_rank()); + } + +#pragma unroll + for (int i = 1; i < hw_warp_size; i *= 2) { + int32_t maybe_add = warp.shfl_up(phase_2_val, i); + phase_2_val = (warp.thread_rank() < i) ? phase_2_val : phase_2_val + maybe_add; + } + + int warp_offset = 0; + if (warp_rank > 0) { warp_offset = warp.shfl(phase_2_val, warp_rank - 1); } + const int32_t expert_cumsum = warp_offset + expert_vals; + + if (tidx < n_experts) { + int64_t expert_cumsum_64 = (int64_t)expert_cumsum; + expert_count_cumsums[tidx] = expert_cumsum_64; + } + + if (assigned_expert == gating::unassigned) return; + if (assigned_expert - 1 == tidx) token_0_row = expert_cumsum; + + tb.sync(); + + if (assigned_expert != 0) { + token_base_row = token_0_row; + } else { + token_base_row = 0; + } + + } else if (assigned_expert == gating::unassigned) { + // For whatever reason, don't need to perform the copy, so we'll early return + // and signal this wasn't mapped with a negative 1. + if (tidx == 0) mapped_slots[token_idx] = gating::unassigned; + return; + } else { + // For all other valid tokens, we can just do a block-scoped sum. + if (tidx < assigned_expert) { + token_base_row = expert_counts[tidx]; + } else { + token_base_row = 0; + } + + warp.sync(); + + // TODO(cmikeh2): Shouldn't use the internal api. + reduce::_block(tb, warp, &token_base_row); + } + + // Data copy to appropriate location + const int32_t thread_offset = tidx * vector_size; + + const int32_t base_load_offset = token_idx * n_channels + thread_offset; + const T* load_base_ptr = activations + base_load_offset; + + const int32_t store_row = token_base_row + offsets[token_idx]; + const int32_t base_store_offset = store_row * n_channels + thread_offset; + T* store_base_ptr = moe_input + base_store_offset; + +#pragma unroll + for (int i = 0; i < copyUnroll; i++) { + T tmp_buf[vector_size]; + + if (i * load_stride + thread_offset < n_channels) { + mem_access::load_global(tmp_buf, + load_base_ptr + i * load_stride); + mem_access::store_global(store_base_ptr + i * load_stride, + tmp_buf); + } + } + + if (threadIdx.x == 0) { mapped_slots[token_idx] = store_row; } +} + +#define LAUNCH_FOR_UNROLL(COUNT) \ + case COUNT: \ + moe_scatter_kernel<<>>(moe_input, \ + expert_count_cumsums, \ + mapped_slots, \ + activations, \ + assignments, \ + expert_counts, \ + offsets, \ + n_channels, \ + n_experts); \ + break; + +template +void launch_moe_scatter(T* moe_input, + int64_t* expert_count_cumsums, + int32_t* mapped_slots, + const T* activations, + const int32_t* expert_counts, + const int32_t* assignments, + const int32_t* offsets, + const int32_t n_channels, + const int32_t n_tokens, + const int32_t n_experts, + cudaStream_t stream) +{ + constexpr int vals_per_unroll = scatter::threads * scatter::access_granularity / sizeof(T); + const int copy_unroll = (n_channels + vals_per_unroll - 1) / vals_per_unroll; + + const dim3 block(scatter::threads); + const dim3 grid(n_tokens); + + switch (copy_unroll) { + LAUNCH_FOR_UNROLL(1); + LAUNCH_FOR_UNROLL(2); + LAUNCH_FOR_UNROLL(3); + LAUNCH_FOR_UNROLL(4); + LAUNCH_FOR_UNROLL(5); + LAUNCH_FOR_UNROLL(6); + } +} + +#define INSTANTIATE_SCATTER_FOR_TYPE(TYPE) \ + template void launch_moe_scatter(TYPE*, \ + int64_t*, \ + int32_t*, \ + const TYPE*, 
\ + const int32_t*, \ + const int32_t*, \ + const int32_t*, \ + const int32_t, \ + const int32_t, \ + const int32_t, \ + cudaStream_t); + +INSTANTIATE_SCATTER_FOR_TYPE(__half); + +#ifdef BF16_AVAILABLE +INSTANTIATE_SCATTER_FOR_TYPE(__nv_bfloat16); +#endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh new file mode 100644 index 000000000000..5c94cb0ef734 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" +#include "ragged_dtypes.h" + +template +void launch_moe_scatter(T* moe_input, + int64_t* expert_count_cumsums, + int32_t* mapped_slots, + const T* activations, + const int32_t* expert_counts, + const int32_t* assignments, + const int32_t* offsets, + const int32_t n_channels, + const int32_t n_tokens, + const int32_t n_experts, + cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.h b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.h new file mode 100644 index 000000000000..59597f63d123 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "moe_scatter.cuh" +#include "ragged_dtypes.h" + +/* +Performs a cumsum on the expert counts and copies the hidden states to the +appropriate spot to ensure that each experts inputs are contiguous. +*/ +void moe_scatter(torch::Tensor& moe_input, + torch::Tensor& expert_count_cumsums, + torch::Tensor& mapped_slots, + torch::Tensor& activations, + torch::Tensor& expert_counts, + torch::Tensor& assignments, + torch::Tensor& offsets); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py new file mode 100644 index 000000000000..5cd6ae5f0fe2 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from typing import Tuple + +from ... import DSKernelBase +from ....inference_utils import DtypeEnum +from deepspeed.ops.op_builder import RaggedOpsBuilder + + +class MoEScatter(DSKernelBase): + """ + CUDA implementation of MoE scatter + """ + + supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] + + def __init__(self, dtype: DtypeEnum, channels: int) -> None: + + if not isinstance(dtype, DtypeEnum): + dtype = DtypeEnum(dtype) + + if dtype not in MoEScatter.supported_dtypes: + raise RuntimeError(f"Unsupported dtype {dtype}") + + if channels % 8 != 0: + raise RuntimeError(f"Channels {channels} must be divisible by 8") + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.moe_scatter + + def __call__(self, moe_input: torch.Tensor, expert_cumsum: torch.Tensor, mapped_slots: torch.Tensor, + activations: torch.Tensor, expert_counts: torch.Tensor, assignments: torch.Tensor, + offsets: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Scatters the hidden states such that the token stride for each expert's input is contiguous. 
+ + Arguments: + moe_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, hidden_size]. + expert_cumsum (torch.Tensor): The cumulative sum of the expert counts of shape [n_experts]. + mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens]. + hidden_states (torch.Tensor): The hidden states of shape [n_tokens, hidden_size]. + expert_counts (torch.Tensor): The number of tokens assigned to each expert of shape [n_experts]. + assignments (torch.Tensor): The expert assignments of shape [n_tokens]. + offsets (torch.Tensor): The offsets into the expert for a given token of shape [n_tokens]. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The MoE input (with scattered values), the cumsum of the offsets (for the MoE kernels themselves), and the assignments Tensor modified in place to show which row that token was mapped to in the input. + """ + self.kernel(moe_input, expert_cumsum, mapped_slots, activations, expert_counts, assignments, offsets) + return moe_input, expert_cumsum, mapped_slots diff --git a/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_dtypes.h b/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_dtypes.h new file mode 100644 index 000000000000..7876b354af0d --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_dtypes.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include + +struct +#ifdef __CUDA_CC__ + __align__(8) +#endif +{ + int32_t n_tokens; + int32_t n_sequences; +} +typedef RaggedBatchDescriptor; + +struct +#ifdef __CUDA_CC__ + __align__(16) +#endif +{ + int32_t start_idx; + int32_t n_tokens; + int32_t seen_tokens; + int32_t UNUSED; // Explicit padding to match the Python code pattern. +} +typedef InflightSeqDescriptor; + +struct +#ifdef __CUDA_CC__ + __align__(8) +#endif +{ + int32_t** block_lists; + int32_t block_size; + int32_t n_blocks; +} +typedef KVCacheDescriptor; + +struct { + const RaggedBatchDescriptor* batch_metadata; // Offset 0 + const InflightSeqDescriptor* seq_metadata; // Offset 8 + const int32_t* tokens_to_seq; // Offset 16 + const KVCacheDescriptor kv_desc; // Offset 24 +} typedef BatchWrapperCPP; diff --git a/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.cpp b/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.cpp new file mode 100644 index 000000000000..a6cb7f275366 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.cpp @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "ragged_kernel_helpers.h" + +BatchWrapperCPP make_cpp_batch_wrapper(torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_cache_desc, + int32_t block_size, + int32_t n_blocks) +{ + const RaggedBatchDescriptor* batch_metadata_raw = + reinterpret_cast(batch_metadata.data_ptr()); + + const InflightSeqDescriptor* seq_metadata_raw = + reinterpret_cast(seq_metadata.data_ptr()); + + const int32_t* tokens_to_seq_raw = tokens_to_seq.data_ptr(); + + int32_t** kv_ptrs_raw = reinterpret_cast(kv_cache_desc.data_ptr()); + KVCacheDescriptor kv_desc = {kv_ptrs_raw, block_size, n_blocks}; + + BatchWrapperCPP wrapper = {batch_metadata_raw, seq_metadata_raw, tokens_to_seq_raw, kv_desc}; + return wrapper; +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.h b/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.h new file mode 100644 index 000000000000..7ce082d31853 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include "ragged_dtypes.h" + +BatchWrapperCPP make_cpp_batch_wrapper(torch::Tensor& batch_metadata, + torch::Tensor& seq_metadata, + torch::Tensor& tokens_to_seq, + torch::Tensor& kv_cache_desc, + int32_t block_size, + int32_t n_blocks); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp b/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp new file mode 100644 index 000000000000..1c09fc52bbb1 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +#include "atom_builder.h" +#include "blocked_flash.h" +#include "blocked_kv_rotary.h" +#include "embed.h" +#include "logits_gather.h" +#include "moe_gather.h" +#include "moe_scatter.h" +#include "top_1_gating.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + // atom_builder.h + m.def("build_atoms", &build_atoms, "Host kernel for building the atoms."); + + // blocked_flash.h + m.def("flash_attn_by_atoms", + &flash_attn_by_atoms, + "Blocked flash attention scheduled with atoms"); + + // blocked_kv_rotary.h + m.def("kv_rotary_embeddings", &kv_rotary_embeddings, "KV rotary embedding for blocked KV"); + m.def("kv_trained_rotary_embeddings", + &kv_trained_rotary_embeddings, + "KV rotary embeddings for blocked KV"); + m.def("linear_kv_copy", &linear_kv_copy, "Linear copy for blocked KV"); + + // embed.h + m.def("ragged_embed", &ragged_embed, "Embedding lookup for ragged batch"); + + // logits_gather.h + m.def("gather_for_logits", &gather_for_logits, "Sparse gather from ragged batch"); + + // moe_gather.h + m.def("moe_gather", &moe_gather, "MoE gather for top-1-gating."); + + // moe_scatter.h + m.def("moe_scatter", &moe_scatter, "MoE scatter for top-1-gating."); + + // top_1_gating.h + m.def("top_1_gating", &top_1_gating, "Top-1 gating for MoE with ragged batch awareness."); +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/__init__.py new file mode 100644 index 000000000000..b50a0838d9f8 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .top_1_gating import RaggedTop1Gating diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp b/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp new file mode 100644 index 000000000000..55c68454b228 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "top_1_gating.h" +#include + +#define DISPATCH_TOP_1_GATING(T_TYPE, C_TYPE) \ + if (logits.options().dtype() == torch::T_TYPE) { \ + launch_top_1_gating((int32_t*)expert_counts.data_ptr(), \ + (float*)scores.data_ptr(), \ + (int32_t*)assignments.data_ptr(), \ + (int32_t*)offsets.data_ptr(), \ + (const C_TYPE*)logits.data_ptr(), \ + batch_metadata_ptr, \ + n_tokens, \ + n_experts, \ + at::cuda::getCurrentCUDAStream()); \ + return; \ + } + +/* +Perform softmax plus atomics in order to do first pass of top_1_gating. 
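+For each token, this records the assigned expert, the softmax probability of that expert,
+and the token's offset within that expert's allocation (obtained via an atomicAdd on
+expert_counts).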
+*/ +void top_1_gating(torch::Tensor& expert_counts, + torch::Tensor& scores, + torch::Tensor& assignments, + torch::Tensor& offsets, + torch::Tensor& logits, + torch::Tensor& batch_metadata) +{ + const int32_t n_tokens = scores.size(0); + + // Should have the same buffer size for scores and offsets + TORCH_CHECK(n_tokens == offsets.size(0)); + TORCH_CHECK(n_tokens == logits.size(0)); + + TORCH_CHECK(expert_counts.scalar_type() == torch::kInt32); + TORCH_CHECK(scores.scalar_type() == torch::kFloat); + TORCH_CHECK(assignments.scalar_type() == torch::kInt32); + TORCH_CHECK(offsets.scalar_type() == torch::kInt32); + + const int32_t n_experts = logits.size(1); + const RaggedBatchDescriptor* batch_metadata_ptr = + reinterpret_cast(batch_metadata.data_ptr()); + + DISPATCH_TOP_1_GATING(kFloat, float) + DISPATCH_TOP_1_GATING(kHalf, __half) +#ifdef BF16_AVAILABLE + DISPATCH_TOP_1_GATING(kBFloat16, __nv_bfloat16) +#endif + + TORCH_CHECK(false, "Unsupported dtype for logits in top_1_gating"); +} diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu b/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu new file mode 100644 index 000000000000..02daee9f692e --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "conversion_utils.h" +#include "memory_access_utils.h" +#include "reduction_utils.h" +#include "top_1_gating.cuh" + +using ROp = reduce::ROpType; + +template +__global__ void top_1_gating_kernel(int32_t* expert_counts, + float* scores, + int32_t* assignments, + int32_t* offsets, + const T* logits, + const RaggedBatchDescriptor* batch_metadata, + const int32_t n_experts) +{ + const int32_t token_idx = blockIdx.x; + const int32_t expert_idx = threadIdx.x; + const int32_t max_warps = 1024 / hw_warp_size; + + // CG helpers + cg::thread_block tb = cg::this_thread_block(); + cg::thread_block_tile warp = cg::tiled_partition(tb); + + // Padding tokens do not require + if (token_idx >= batch_metadata->n_tokens) { + if (threadIdx.x == 0) { + offsets[token_idx] = gating::unassigned; + assignments[token_idx] = gating::unassigned; + } + return; + } + + const T* token_logits = logits + token_idx * n_experts; + + float logit_val; + if (expert_idx < n_experts) { + logit_val = conversion::to(token_logits[expert_idx]); + } else { + reduce::init(&logit_val); + } + + // Training code tends to use ``torch.argmax`` to select the expert, which + // which has ties broken by the lower index. Since our fused comparison algorithm + // breaks ties by the higher index (since it's the lower 32-bits of the 64-bit + // comparison), we invert the expert index to break ties by the lower index. 
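+    // For example, with n_experts = 8 and tied logits for experts 2 and 5, the inverted
+    // indices are 5 and 2; the reduction keeps the higher inverted index (5), which maps
+    // back to expert 2, the lower original index, matching torch.argmax.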
+ int32_t inverted_expert = n_experts - expert_idx - 1; + // Perform softmax + const reduce::IdxReduceResult res = + reduce::idx_reduce(tb, warp, logit_val, inverted_expert); + // Recover the original expert index + const int32_t assigned_expert = n_experts - res.idx - 1; + const float max_logit = res.val; + + float softmax_sum = __expf(logit_val - max_logit); + reduce::block(tb, warp, softmax_sum); + + // Compute the score + const float score = __expf(max_logit - max_logit) / softmax_sum; + + if (threadIdx.x == 0) { + scores[token_idx] = score; + assignments[token_idx] = assigned_expert; + offsets[token_idx] = atomicAdd(expert_counts + assigned_expert, 1); + } +} + +template +void launch_top_1_gating(int32_t* expert_counts, + float* scores, + int32_t* assignments, + int32_t* offsets, + const T* logits, + const RaggedBatchDescriptor* batch_metadata, + const int32_t n_tokens, + const int32_t n_experts, + cudaStream_t stream) +{ + const dim3 grid(n_tokens); + const dim3 block(((n_experts + hw_warp_size - 1) / hw_warp_size) * hw_warp_size); + + top_1_gating_kernel<<>>( + expert_counts, scores, assignments, offsets, logits, batch_metadata, n_experts); +} + +#define INSTANTIATE_TOP_1_KERNEL(T) \ + template void launch_top_1_gating(int32_t * expert_counts, \ + float* scores, \ + int32_t* assignments, \ + int32_t* offsets, \ + const T* logits, \ + const RaggedBatchDescriptor* batch_metadata, \ + const int32_t n_tokens, \ + const int32_t n_experts, \ + cudaStream_t stream); + +INSTANTIATE_TOP_1_KERNEL(float) +INSTANTIATE_TOP_1_KERNEL(__half) +#ifdef BF16_AVAILABLE +INSTANTIATE_TOP_1_KERNEL(__nv_bfloat16) +#endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cuh b/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cuh new file mode 100644 index 000000000000..c83ad56ff2f1 --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cuh @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" +#include "ragged_dtypes.h" + +namespace gating { +constexpr int unassigned = -1; +} // namespace gating + +template +void launch_top_1_gating(int32_t* expert_counts, + float* scores, + int32_t* assignments, + int32_t* offsets, + const T* logits, + const RaggedBatchDescriptor* batch_metadata, + const int32_t n_tokens, + const int32_t n_experts, + cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.h b/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.h new file mode 100644 index 000000000000..b431f4cad30c --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "ragged_dtypes.h" +#include "top_1_gating.cuh" + +/* +Perform softmax plus atomics to get token mapping. 
+*/ +void top_1_gating(torch::Tensor& expert_counts, + torch::Tensor& scores, + torch::Tensor& assignments, + torch::Tensor& offsets, + torch::Tensor& logits, + torch::Tensor& batch_metadata); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.py b/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.py new file mode 100644 index 000000000000..1df97c2e9f8d --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from typing import Tuple + +from ... import DSKernelBase +from ....inference_utils import DtypeEnum +from ....ragged import RaggedBatchWrapper +from deepspeed.ops.op_builder import RaggedOpsBuilder + + +class RaggedTop1Gating(DSKernelBase): + """ + CUDA implementation of top-1 gating. This will perform a softmax on the logits, + and return the scale as well as its idx within that expert's allocation. + """ + + supported_logit_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16, DtypeEnum.fp32] + + def __init__(self, logit_dtype: DtypeEnum) -> None: + + if not isinstance(logit_dtype, DtypeEnum): + logit_dtype = DtypeEnum(logit_dtype) + + if logit_dtype not in RaggedTop1Gating.supported_logit_dtypes: + raise RuntimeError(f"Unsupported logit dtype {logit_dtype}") + + inf_module = RaggedOpsBuilder().load() + self.kernel = inf_module.top_1_gating + + def __call__(self, expert_counts: torch.Tensor, scores: torch.Tensor, assignments: torch.Tensor, + offsets: torch.Tensor, logits: torch.Tensor, + batch: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Perform the ragged top_1_gating. + + Arguments: + expert_counts (torch.Tensor): Tensor of 0s of shape [n_experts] to be filled with + number of tokens assigned to each expert. This must be filled with 0s else + the copy kernel will buffer overflow. In order to minimize the zero-fill cost, + it is recommended to write to 0 during the MoE output remapping. + scores (torch.Tensor): Preallocated output of shape [n_tokens] to place expert scaling + value. + expert_assignment (torch.Tensor): Preallocated output of shape [n_tokens] to place + which expert a token has been assigned to. + expert_offset (torch.Tensor): Preallocated output of shape [n_tokens] to place which + offset within an experts group a token is. + logits (torch.Tensor): Raw logits of gating function. + batch (RaggedBatchWrapper): Batch information for ragged tensor. + + Returns: + tuple of (expert_counts, scores, expert_assignment, expert_offset) + """ + self.kernel(expert_counts, scores, assignments, offsets, logits, batch.batch_metadata_buffer()) + return expert_counts, scores, assignments, offsets diff --git a/deepspeed/inference/v2/logging.py b/deepspeed/inference/v2/logging.py new file mode 100644 index 000000000000..77afe351cbea --- /dev/null +++ b/deepspeed/inference/v2/logging.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import logging + +from deepspeed.utils.logging import LoggerFactory + +inf_logger = None + + +def inference_logger(level: int = logging.INFO) -> logging.Logger: + """ + Create the inference logger. NOTE: Logging is not cost free. On a 3960X, + there is a cost of about 6 us per call to a no-op logger, so this should + be used during setup only and not during the inference loop. + + Args: + level (int, optional): The logging level. 
Defaults to logging.INFO.
+    """
+    global inf_logger
+    if inf_logger is None:
+        inf_logger = LoggerFactory.create_logger(name="DS-Inference", level=level)
+        inf_logger.debug("Inference logger created.")
+    return inf_logger
diff --git a/deepspeed/inference/v2/model_implementations/AddingAModel.md b/deepspeed/inference/v2/model_implementations/AddingAModel.md
new file mode 100644
index 000000000000..8fe27297080b
--- /dev/null
+++ b/deepspeed/inference/v2/model_implementations/AddingAModel.md
@@ -0,0 +1,84 @@
+# Adding Support for a New Model in DeepSpeed Inference V2
+
+Adding support for a new model in DeepSpeed Inference requires developing three related components:
+- Containers: These describe the parameters contained in the model.
+- Model implementation: How the model should be computed.
+- Policy: The map for adding parameters to your containers and creating the model implementation.
+
+In this tutorial, we will assume that you'd like to serve a relatively traditionally styled Transformer model, so you can inherit from `DSTransformerModelBase` and take advantage of the utilities it provides.
+
+## Defining Your Containers
+
+A container is the bridge between the original model's parameters and how to transform them to serve them for inference. For a model implementation, there are two primary kinds of containers: transformer containers and non-transformer containers. A transformer container consists of the parameters for a single Transformer layer in the model, so this includes your traditional parameters like the projections for the fully connected network or the query-key-value projections. The non-transformer container will contain basically everything else! However, before defining these containers, we need to understand how to define an individual parameter.
+
+In DeepSpeed inference, the original model parameters are populated into the model and mapped as dependencies of a parameter. A `Parameter` has two primary components: its dependencies and its `finalize` method. Let's do an example. In Llama models, the native format is for the `query`, `key`, and `value` projections to be performed independently. However, we can achieve higher throughput by fusing them into a single larger projection. We can define this fusion with a parameter:
+
+```python
+import torch
+
+from deepspeed.inference.v2.model_implementations.parameter_base import ParameterBase
+
+class UnfusedQKVParameter(ParameterBase):
+    query: torch.Tensor
+    key: torch.Tensor
+    value: torch.Tensor
+
+    def finalize(self) -> torch.Tensor:
+        fused_param = torch.cat([self.query, self.key, self.value], dim=0)
+        return self.inference_model.transform_qkv_param(fused_param)
+```
+
+Let's walk through each part of this implementation. First, parameters should inherit from `ParameterBase`. This allows a parameter to automatically determine when its dependencies are met and to set the appropriate components of its parent `LayerContainer`. The second key component is the type annotations on the class itself. Each type annotation represents a dependency of the parameter. Since the original Llama model has separate query, key, and value dependencies, our fused parameter declares a dependency for each. Finally, we have the `finalize` method. This method is automatically called once all dependencies are met and should return the final parameter.
+
+In this `finalize` method, we are doing two things. The first is fusing the parameters together via concatenation; note that each of the dependencies can be accessed via `self.{name}`. The second is calling `self.inference_model.transform_qkv_param`. A parameter's `finalize` method always has access to the inference model, and here we use that access to call a helper provided by `DSTransformerModelBase`. This method will automatically shard the parameter for tensor parallelism and then pass it to the linear module implementation to perform additional optimizations or shape transformations, like quantization.
+
+Since many patterns are very common in Transformer models, `model_implementations.common_parameters` provides implementations for many of these patterns (all compatible with `DSTransformerModelBase`) to help accelerate development.
+
+Once all parameters are created, we need to compose them into a layer container. In our simplified Llama model, let's assume there are only the QKV and attention output projection matrices. A layer container would appear as the following:
+
+```python
+from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer
+from deepspeed.inference.v2.model_implementations.common_parameters import AttentionOutputParameter
+
+class ExampleContainer(LayerContainer):
+    qkvw: UnfusedQKVParameter
+
+    attn_o: AttentionOutputParameter
+
+    PARAM_MAPPING = {
+        "self_attn.q_proj.weight": "qkvw.query",
+        "self_attn.k_proj.weight": "qkvw.key",
+        "self_attn.v_proj.weight": "qkvw.value",
+        "self_attn.o_proj.weight": "attn_o.params",
+    }
+```
+
+Once again, we have a couple of key components. The first is the parameter type annotations. Each annotation corresponds to a parameter that can be used in the model implementation: in the model implementation, you can simply write `container.qkvw` to access the fused and transformed QKV parameter. The second key component is the `PARAM_MAPPING` dictionary. This is our explicit mapping of the names of parameters in the source model to a parameter dependency. This mapping dictionary will be used by the policy to automatically populate dependencies.
+
+Once you have written `LayerContainer`s for both the transformer and non-transformer parameters, it's time to work on the model implementation!
+
+## Building a Model Implementation that Inherits from `DSTransformerModelBase`
+
+By inheriting from `DSTransformerModelBase`, most of the implementation work for sharding and transforming parameters will be automatically handled for you. However, there are four key tasks that still need to be completed:
+
+1. Defining the abstract properties based on your model configuration.
+2. Configuring embedding and unembedding modules and the forward implementations for them.
+3. Configuring attention and the desired KV cache behaviors.
+4. Writing the forward implementation for your layer.
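+
+As a rough illustration of the first task, the sketch below shows what the dimensioning properties of a hypothetical `MyLlamaInferenceModel` could look like when they are read from a Hugging Face-style config object. The config attribute names (`hidden_size`, `num_hidden_layers`, etc.) are assumptions for illustration; this is a sketch rather than a complete implementation.
+
+```python
+from deepspeed.inference.v2.model_implementations import DSTransformerModelBase
+
+
+class MyLlamaInferenceModel(DSTransformerModelBase):
+    """
+    Hypothetical model implementation. `self._config` is whatever model config object
+    the policy passes to the constructor; a LlamaConfig-style object is assumed here.
+    """
+
+    @property
+    def num_layers(self) -> int:
+        return self._config.num_hidden_layers
+
+    @property
+    def model_dim(self) -> int:
+        return self._config.hidden_size
+
+    @property
+    def vocab_size(self) -> int:
+        return self._config.vocab_size
+
+    @property
+    def head_size(self) -> int:
+        return self._config.hidden_size // self._config.num_attention_heads
+
+    @property
+    def n_heads(self) -> int:
+        return self._config.num_attention_heads
+
+    @property
+    def n_heads_kv(self) -> int:
+        # Only needed for GQA/MQA models; the base class defaults this to `n_heads`.
+        return self._config.num_key_value_heads
+
+    @property
+    def intermediate_dim(self) -> int:
+        return self._config.intermediate_size
+
+    @property
+    def max_sequence_length(self) -> int:
+        return self._config.max_position_embeddings
+
+    # The remaining abstract properties (activation_dtype, mlp_activation_fn, norm_type
+    # and positional_embedding_type) should return the dtype/enum members that match the
+    # architecture, and tasks 2-4 (embedding/unembedding setup, attention/KV cache
+    # configuration and the per-layer forward) still need to be written.
+```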
+
+## Writing a Policy
+
+The `InferenceV2Policy` is the final level of composition. This is the object that will be passed directly to the inference engine; it composes the model implementation and your containers into an end-to-end solution. There are two main components to implement. The first is to create the model that you defined earlier, which is done by implementing the `instantiate_model` method of the policy. In general, this can be implemented by calling the constructor for your model and passing it the engine config, the tensor-parallel communication object, and your custom model config.
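+
+For concreteness, a minimal `instantiate_model` could look like the following sketch. `MyLlamaInferenceModel` refers to the hypothetical model implementation sketched above (it would need to be imported), and the constructor arguments mirror `DSInferenceModelBase.__init__` (model config, engine config, base communication group).
+
+```python
+from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig
+from deepspeed.inference.v2.model_implementations import DSInferenceModelBase, InferenceV2Policy
+
+
+class MyLlamaPolicy(InferenceV2Policy):
+
+    def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group) -> DSInferenceModelBase:
+        # `self._model_config` was handed to the policy constructor by the caller;
+        # the model implementation receives it alongside the engine config and
+        # the tensor-parallel communication group.
+        return MyLlamaInferenceModel(self._model_config, engine_config, mp_group)
+```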
+
+The second component is to define how the parameters from the checkpoint will map to each container. From the section on `LayerContainer`s above, you may remember that the `LayerContainer` can handle the internal routing of a checkpoint parameter to its dependency. In order to find the correct `LayerContainer`, though, we need a second abstraction: the `ContainerMap`.
+
+A `ContainerMap` performs this mapping by categorizing checkpoint prefix strings by the type of container they map to. Typically, the easiest way to do this is by iterating over a model checkpoint's state dict or over the `named_parameters` of a PyTorch model. There are three types of mappings to define: the transformer mappings, the non-transformer mappings, and what we'll call the rest (prefixes that should be left unmapped). Let's work through an example:
+
+```python
+from deepspeed.inference.v2.model_implementations.inference_policy_base import ContainerMap
+
+def build_container_map(self) -> ContainerMap:
+    map = ContainerMap()
+
+    transformer_containers = [MyTransformerContainer(self.model) for _ in range(self.model.num_layers)]
+    map.set_transformer_params("model.layers", transformer_containers)
+
+    non_transformer_container = MyNonTransformerContainer(self.model)
+    map.set_non_transformer_params(non_transformer_container)
+
+    # Prefixes that should be ignored entirely, if any.
+    map.set_unmapped_params([])
+
+    return map
+```
diff --git a/deepspeed/inference/v2/model_implementations/__init__.py b/deepspeed/inference/v2/model_implementations/__init__.py
new file mode 100644
index 000000000000..a3481023a8fd
--- /dev/null
+++ b/deepspeed/inference/v2/model_implementations/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .inference_model_base import DSInferenceModelBase
+from .inference_transformer_base import DSTransformerModelBase, DSMoETransformerModelBase
+from .inference_policy_base import InferenceV2Policy, ContainerMap
+from .sharding import *
diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/__init__.py b/deepspeed/inference/v2/model_implementations/common_parameters/__init__.py
new file mode 100644
index 000000000000..60963011cd66
--- /dev/null
+++ b/deepspeed/inference/v2/model_implementations/common_parameters/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .attn_output_parameters import *
+from .embedding_parameters import *
+from .mlp_parameters import *
+from .moe_parameters import *
+from .norm_parameters import *
+from .qkv_parameters import *
+from .unembed_parameters import *
+from .invfreq_parameters import *
diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/attn_output_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/attn_output_parameters.py
new file mode 100644
index 000000000000..f220cf7a7125
--- /dev/null
+++ b/deepspeed/inference/v2/model_implementations/common_parameters/attn_output_parameters.py
@@ -0,0 +1,29 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+
+from ...model_implementations.parameter_base import ParameterBase
+"""
+Common Attention Output Parameter Patterns
+"""
+
+
+class AttentionOutputParameter(ParameterBase):
+    """
+    Attention output parameter container.
+
+    Note: The differentiation for something like GQA for this matrix is primarily
+    encompassed in the sharding logic, which is currently expected to be performed by
+    the model implementation.
+ """ + + params: torch.Tensor + """ + Unsharded attention output parameter of shape [model_dim, model_dim] + """ + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform_attn_out_param(self.params) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/embedding_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/embedding_parameters.py new file mode 100644 index 000000000000..4babc0ee0127 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/embedding_parameters.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase +from ...allocator import on_device +""" +Embedding containers. +""" + + +class EmbeddingParameter(ParameterBase): + """ + Embedding container. This should be safe to use for all types of embeddings (i.e. word, position, + and token type). + """ + + params: torch.Tensor + """ + Vocabulary parameter of shape [vocab_size, model_dim]. + """ + + @on_device + def finalize(self) -> torch.Tensor: + return self.params + #return self.inference_model.transform_embed_param(self.params) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/invfreq_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/invfreq_parameters.py new file mode 100644 index 000000000000..3a5a7fb04b9a --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/invfreq_parameters.py @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase +from ...allocator import on_device +""" +Common InvFreq Parameter Patterns +""" + + +class InvFreqParameter(ParameterBase): + + params: torch.Tensor + + @on_device + def finalize(self) -> torch.Tensor: + return self.params.to(self.inference_model.activation_dtype.value) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/mlp_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/mlp_parameters.py new file mode 100644 index 000000000000..ddb8996e03a3 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/mlp_parameters.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase +""" +MLP Parameter Containers +""" + + +class MLP1Parameter(ParameterBase): + """ + First MLP projection weight container. This performs a straight pass-through to the + model implementation for transformation. + """ + params: torch.Tensor + + def finalize(self) -> torch.Tensor: + # NOTE(cmikeh2): If we are gated but not in the format specified below, we should trigger a permutation here. + # I am not currently aware of any models that use this format (or how we should even detect it; probably should + # just be a different param entirely, but until then we'll just assume the format is correct). + return self.inference_model.transform_mlp_1_param(self.params) + + +class GatedMLPParameter(ParameterBase): + """ + Gated MLP projection container. + """ + + gate_params: torch.Tensor + """ + Weight parameter for the gating matrix. + """ + + up_params: torch.Tensor + """ + For lack of a better name, the non-gating weight parameters. 
+ """ + + def finalize(self) -> torch.Tensor: + """ + Our gated format (this is different from InferenceV1!) is to have the gate and activated neurons + interleaved. So if we have 4 output neurons (two effective neurons) with 4 input neurons, the finalized + parameter will look like: + [g0_0, g0_1, g0_2, g0_3] + [a0_0, a0_1, a0_2, a0_3] + [g1_0, g1_1, g1_2, g1_3] + [a1_0, a1_1, a1_2, a1_3] + + As a reference, in inference v1, the format is: + [g0_0, g0_1, g0_2, g0_3] + [g1_0, g1_1, g1_2, g1_3] + [a0_0, a0_1, a0_2, a0_3] + [a1_0, a1_1, a1_2, a1_3] + """ + assert self.gate_params.shape[0] == self.up_params.shape[ + 0], "Gated MLP parameters must have the same number of neurons." + total_neurons = self.gate_params.shape[0] + self.up_params.shape[0] + + # flip the order if even with the correct tokenizer we get wrong output + #fused_param = torch.cat([self.up_params, self.gate_params], dim=-1).reshape(total_neurons, -1) + fused_param = torch.cat([self.gate_params, self.up_params], dim=-1).reshape(total_neurons, -1) + return self.inference_model.transform_mlp_1_param(fused_param) + + +class MLP2Parameter(ParameterBase): + """ + Second MLP projection weight container. This performs a straight pass-through to the + model implementation for transformation. + """ + + params: torch.Tensor + """ + Full weight parameter. + """ + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform_mlp_2_param(self.params) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py new file mode 100644 index 000000000000..ae95e628b779 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...allocator import on_device +from ...model_implementations.parameter_base import ParameterBase, ParamList +""" +Moe Parameters + +These parameters are compatible with any model inheriting from ``DSMoETransformerModelBase``. +""" + + +class MoEGatingWeightParameter(ParameterBase): + """ + Gating weight matrix. + """ + + params: torch.Tensor + """ + Projection matrix from the input activations to the gate logits. + """ + + @on_device + def finalize(self) -> torch.Tensor: + return self.inference_model.transform_moe_gate_param(self.params) + + +class UnfusedMoEMLP1Parameter(ParameterBase): + """ + This container should be used when the experts are held in separate parameters + and need to be joined into a single group. + """ + + experts: ParamList("num_experts") # noqa: F821 + + def finalize(self) -> torch.Tensor: + stacked_experts = torch.stack([p for p in self.experts], dim=0) + return self.inference_model.transform_moe_mlp_1_param(stacked_experts) + + +class UnfusedMoEMLP2Parameter(ParameterBase): + """ + This container should be used when the experts are held in separate parameters + and need to be joined into a single group. + """ + + experts: ParamList("num_experts") # noqa: F821 + + def finalize(self) -> torch.Tensor: + stacked_experts = torch.stack([p for p in self.experts], dim=0) + return self.inference_model.transform_moe_mlp_2_param(stacked_experts) + + +class UnfusedMoEGatedMLPParameter(ParameterBase): + """ + MoE Parameter for a gated activation function in which the gating matrix is not + fused in the same parameter as the non-gating matrix. 
+ """ + + gating_experts: ParamList("num_experts") # noqa: F821 + + up_experts: ParamList("num_experts") # noqa: F821 + + def finalize(self) -> torch.Tensor: + fused_params = [torch.cat([gate, weight], dim=0) for gate, weight in zip(self.gating_experts, self.up_experts)] + stacked_params = torch.stack(fused_params, dim=0) + return self.inference_model.transform_moe_mlp_2_param(stacked_params) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/norm_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/norm_parameters.py new file mode 100644 index 000000000000..81ffcc3221df --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/norm_parameters.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase +""" +Common Attention Output Parameter Patterns +""" + + +class NormParameter(ParameterBase): + """ + Simple normalization container. + """ + + params: torch.Tensor + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform_norm_param(self.params) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/qkv_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/qkv_parameters.py new file mode 100644 index 000000000000..2ed8a8654f5b --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/qkv_parameters.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase +""" +Common QKV Parameter Patterns +""" + + +class FusedQKVParameter(ParameterBase): + """ + Traditional fused QKV parameters for QKV projection. This is functionally + a direct copy. + + src_qkv_w shape: [3 * out_features, in_features] + qkv_w shape: [3 * out_features, in_features] + """ + + params: torch.Tensor + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform_qkv_param(self.params) + + +class UnfusedQKVParameter(ParameterBase): + """ + QKV parameter container for unfused QKV projection. + + src_param shapes: 3 x [out_features, in_features] + dst_param shape: [3 x out_features, in_features] + """ + + q_params: torch.Tensor + + k_params: torch.Tensor + + v_params: torch.Tensor + + def finalize(self): + fused_param = torch.cat([self.q_params, self.k_params, self.v_params], dim=0) + return self.inference_model.transform_qkv_param(fused_param) + + +def megatron_qkv_reshape(param: torch.Tensor, head_size: int, n_heads: int) -> torch.Tensor: + assert param.shape[0] == 3 * n_heads * head_size + + all_heads = torch.chunk(param, chunks=3 * n_heads, dim=0) + q_heads = all_heads[::3] + k_heads = all_heads[1::3] + v_heads = all_heads[2::3] + return torch.cat([q_heads, k_heads, v_heads], dim=0) + + +class MegatronQKVParameter(ParameterBase): + """ + QKV parameter container for Megatron-style QKV projection. Megatron stores the parameter + as [n_heads, 3, head_size, in_features] whereas our inference system is built around + [3, n_heads, head_size, in_features]. This container handles the conversion. + + Note: this container expects the model implementation to implement properties for + `head_size` and `n_heads`. 
+ + src_qkv_w shape: [3 * out_features, in_features] + qkv_w shape: [3 * out_features, in_features] + """ + + params: torch.Tensor + + def finalize(self) -> torch.Tensor: + head_size = self.inference_model.head_size + n_heads = self.inference_model.n_heads + + transposed_param = megatron_qkv_reshape(self.params, head_size, n_heads) + return self.inference_model.transform_qkv_param(transposed_param) + + +def transform_gqa_megatron(src_param: torch.Tensor, head_size: int, n_q_heads: int, n_kv_heads: int) -> torch.Tensor: + assert src_param.shape[0] == (2 * n_kv_heads + n_q_heads) * head_size + + head_ratio = n_q_heads // n_kv_heads + + # Reshape to get the groups as the leading dimension + groups_leading_view = src_param.reshape(n_kv_heads, 2 + head_ratio, head_size, -1) + q_heads = groups_leading_view[:, :head_ratio, :, :].reshape(-1, groups_leading_view.shape[-1]) + k_heads = groups_leading_view[:, head_ratio, :, :].reshape(-1, groups_leading_view.shape[-1]) + v_heads = groups_leading_view[:, head_ratio + 1, :, :].reshape(-1, groups_leading_view.shape[-1]) + # Squeeze will remove extra dimension for bias + return torch.cat([q_heads, k_heads, v_heads], dim=0).squeeze() + + +class GQAMegatronQKVParameter(ParameterBase): + """ + QKV parameter for Megatron-style QKV projection with GQA-style QKV projection. In this + storage format each of the groups is stored consecutively, so there will be multiple q_heads, + then one k head, and one v head. + + Note: this container expects the model implementation to implement properties for + `head_size`, `n_q_heads`, and `n_kv_heads`. + + src_qkv_w shape: [(2 * n_kv_heads + n_q_heads) * head_size, in_features] + qkv_w shape: [(2 * n_kv_heads + n_q_heads) * head_size, in_features] + """ + + params: torch.Tensor + + def finalize(self) -> torch.Tensor: + head_size = self.inference_model.head_size + n_q_heads = self.inference_model.n_heads_q + n_kv_heads = self.inference_model.n_heads_kv + + transposed_param = transform_gqa_megatron(self.params, head_size, n_q_heads, n_kv_heads) + return self.inference_model.transform_qkv_param(transposed_param) diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/unembed_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/unembed_parameters.py new file mode 100644 index 000000000000..9f67c0ce3c27 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/common_parameters/unembed_parameters.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ...model_implementations.parameter_base import ParameterBase +""" +Unembedding containers. +""" + + +class UnembedParameter(ParameterBase): + """ + Unembedding parameter. This will likely be mapped to the same original weight in the model as the + embedding, but we have a different preferred sharding approach. + """ + + params: torch.Tensor + """ + Unembedding parameter of shape [vocab_size, model_dim]. + """ + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform_unembed_param(self.params) diff --git a/deepspeed/inference/v2/model_implementations/inference_model_base.py b/deepspeed/inference/v2/model_implementations/inference_model_base.py new file mode 100644 index 000000000000..bc6b28f8ce52 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/inference_model_base.py @@ -0,0 +1,196 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import ABC, abstractmethod +from typing import Iterable, Optional, Tuple, Type + +import torch + +from ..ragged import DSStateManager, RaggedBatchWrapper +from ..ragged.manager_configs import KVCacheConfig +from ..ragged import DSSequenceDescriptor +from ..model_implementations.layer_container_base import LayerContainer +from ..config_v2 import RaggedInferenceEngineConfig +""" +This abstract class defines the interfaces that a model implementation should implement +in order to include anything that may be called by the engine. Most models should be able +to inherit from `DSInferenceTransformerModelBase` to reduce implementation work so it is recommended +to begin there. +""" +""" +Placeholder for typing the model config, which can vary based on model implementation/ +""" +DSModelImplementationConfig = Type['DSModelImplementationConfig'] +""" +Placeholder for typing the distributed comm object. + +TODO(cmikeh2): Replace when we have a more defined API for the inference communication system. +""" +MPType = Type["MPType"] + + +class DSInferenceModelBase(torch.nn.Module, ABC): + """ + Implementation of a model for inference composable with ragged batching. + """ + + _config: DSModelImplementationConfig + """ + Model-specific configuration. No abstraction surrounds this yet. + """ + + _engine_config: RaggedInferenceEngineConfig + """ + Engine configuration. + """ + + _base_mp_group: MPType + """ + Base communication group for Tensor-parallel inference. + """ + + _non_transformer: Optional[LayerContainer] + """ + Abstract container for storing both embedding (pre-transformer) and unembedding (post-transformer) + parameters. This attribute should be None at model instantiation until the Policy sets + the model parameters. These parameters are grouped together since many model implementations + will tie the embedding and unembedding parameters together. + """ + + _transformer: Optional[Iterable[LayerContainer]] + """ + List of abstract containers (1 per layer) for storing transformer (transformer) + parameters. This attribute should be None at model instantiation until the Policy + sets the model parameters. + """ + + state_manager: Optional[DSStateManager] + """ + Since the state manager is lazy initialized, by the engine, it is not guaranteed to be present + until full initialization. + """ + + def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig, + base_mp_group: MPType) -> None: + """ + Minimal initialization of the model. + + Arguments: + config (DSModelImplementationConfig): Model-specific configuration. No assumptions + should be made about this config that are not closely tied to the specific + model implementation. + engine_config (RaggedInferenceEngineConfig): Engine configuration. + base_mp_group (MPType): Base communication group for Tensor-parallel inference. + """ + super().__init__() + self._config = config + self._engine_config = engine_config + self._base_mp_group = base_mp_group + + # Set to None until the Policy sets the model parameters + self._non_transformer = None + self._transformer = None + + def set_parameters(self, transformer: Iterable[LayerContainer], non_transformer: LayerContainer): + """ + Set the model parameters for the embedding, transformer, and unembedding containers. 
+ """ + self._transformer = transformer + self._non_transformer = non_transformer + + def set_state_manager(self, state_manager: DSStateManager): + """ + Sets the state manager attribute. This is called by the inference engine after + the model is fully initialized. + """ + self.state_manager = state_manager + + @abstractmethod + def get_kv_requirements(self, sequence: DSSequenceDescriptor, max_new_tokens: int, + max_new_blocks: int) -> Tuple[int, int]: + """ + Given a sequence and the number of new tokens in the sequence, determine the + number of new KV blocks needed to support the sequence. This method is + used to help the engine provide schedulability APIs and can be used as a helper + for ``maybe_allocate_kv``. + + Args: + sequence (DSSequenceDescriptor): The sequence for which to allocate KV-storage. + max_new_tokens (int): Maximum number of tokens to hypothetically schedule. + max_new_blocks (int): Maximum number of blocks to hypothetically allocate. + + Returns: + Tuple[int, int]: The tuple of number of tokens scheduled and number + of blocks allocated. In general, only one of these numbers will match the + corresponding input argument, but this is not guaranteed. + """ + raise NotImplementedError() + + @abstractmethod + def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> None: + """ + Given a sequence and the number of new tokens in the sequence, determine + whether or not additional KV-storage is needed and allocate it if so. + + Args: + sequence (DSSequenceDescriptor): The sequence for which to allocate KV-storage. + n_new_tokens (int): The number of new tokens in the sequence. + """ + raise NotImplementedError() + + @abstractmethod + def kv_cache_config(self) -> KVCacheConfig: + """ + Return the KV-cache configuration for this model. + """ + raise NotImplementedError() + + @property + @abstractmethod + def max_sequence_length(self) -> int: + """ + The maximum sequence length supported by the model. + """ + ... + + def maybe_free_kv(self, sequence: DSSequenceDescriptor): + """ + After completing a forward pass, determine whether or not the there are any KV blocks + that maybe freed since they are no longer in use. + + Consider the following example: + + We have a block size of 4 and a local window size of 8. At the beginning of the forward + pass there 10 tokens had been seen and the new forward has a size of 4. This would lend + itself to the following cache structure prior to the forward: + [[0, 1, 2*, 3*] [4*, 5*, 6*, 7*] [8*, 9*, x, x] [x x x x]] + Where x's denote empty cache locations and * denote values that are needed for attention + of the next open slot. After the forward, the cache would look like the following: + [[0, 1, 2, 3] [4, 5, 6*, 7*] [8*, 9*, 10*, 11*] [12* 13* x x]] + In this case, the first block is no longer needed since it is not needed for any future + local attention windows. This function would be responsible for freeing that block. + + Default behavior assumes no local patterns that require freeing and in general should + be sufficient. + """ + pass + + @abstractmethod + def prepare_batch(self, wrapped_batch: RaggedBatchWrapper) -> None: + """ + This will be called before each forward with the intent of building forward-specific metadata + about a batch. The intent here is to build data structures like attention atoms without necessarily + needing to implement graphable kernels to do so. + + Abstract so as to force model implementations to opt out of doing anything here explicitly. 
+ """ + raise NotImplementedError() + + def forward(wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Complete a forward pass of the model. This interface should be graphable, so it + should not rely on the ability to use python control flow. + """ + raise NotImplementedError() diff --git a/deepspeed/inference/v2/model_implementations/inference_policy_base.py b/deepspeed/inference/v2/model_implementations/inference_policy_base.py new file mode 100644 index 000000000000..f87c4f0c7cfc --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/inference_policy_base.py @@ -0,0 +1,155 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import ABC, abstractmethod +from typing import Any, Iterable, List, Union + +from ..config_v2 import RaggedInferenceEngineConfig +from ..checkpoint import CheckpointEngineBase +from ..logging import inference_logger +from .layer_container_base import LayerContainer +from .inference_model_base import DSInferenceModelBase + + +class ContainerMap: + + def __init__(self) -> None: + self._prefix_map = {} + self._transformer_params = None + self._non_transformer_params = None + + @property + def transformer_params(self) -> Iterable[LayerContainer]: + return self._transformer_params + + @property + def non_transformer_params(self) -> LayerContainer: + return self._non_transformer_params + + def set_transformer_params(self, prefixes: Union[str, Iterable[str]], containers: List[LayerContainer]) -> None: + if not isinstance(containers, list): + raise ValueError( + f"The transformer containers should be a list, of one container per layer, but got {type(containers)} instead." + ) + + self._transformer_prefixes = prefixes if isinstance(prefixes, list) else [prefixes] + self._transformer_params = containers + + def set_non_transformer_params(self, container: LayerContainer) -> None: + self._non_transformer_params = container + + def set_unmapped_params(self, prefixes: Union[str, Iterable[str]]) -> None: + self._unmapped_prefixes = prefixes + + def map_param(self, name, parameter) -> None: + for unmapped_prefix in self._unmapped_prefixes: + if name.startswith(unmapped_prefix): + inference_logger().debug(f"Ignoring: {name} for {unmapped_prefix}") + return + + for transformer_prefix in self._transformer_prefixes: + if name.startswith(transformer_prefix): + popped_name = name[len(transformer_prefix) + 1:] + layer_idx = popped_name.split(".")[0] + assert layer_idx.isdigit( + ), f"expected name to start w. list index but got {layer_idx} instead, name={name}" + layer_idx = int(layer_idx) + inference_logger().debug( + f"Setting: {'.'.join(popped_name.split('.')[1:])} for layer-idx={layer_idx} to {parameter.shape}") + self._transformer_params[layer_idx].set_dependency(".".join(popped_name.split(".")[1:]), parameter) + return + + try: + inference_logger().debug(f"Setting: {name} to {parameter.shape}") + self._non_transformer_params.set_dependency(name, parameter) + except ValueError: + # Catch the ValueError here from the non_transformer_params because we are knowingly + # calling it with something that may not match. This should allow us to raise a slightly more + # informative error message. 
+ raise ValueError(f"Cannot find container for {name}, please double check the Containers/ContainerMap") + + def validate(self) -> None: + if not self._non_transformer_params.is_initialized: + raise RuntimeError("Non-transformer parameters not fully initialized after checkpoint load.") + + for layer_idx, container in enumerate(self._transformer_params): + if not container.is_initialized: + raise RuntimeError( + f"Transformer container at index {layer_idx} not fully initialized after checkpoint load.") + + +class InferenceV2Policy(ABC): + """ + The InferenceV2Policy is the base class for all inference policies. An inference policy + is responsible for instantiating the inference model and mapping the parameters from the + checkpoint engine to the model itself. + """ + + def __init__(self, checkpoint_engine: CheckpointEngineBase, model_config: Any) -> None: + self._checkpoint_engine = checkpoint_engine + self._model_config = model_config + + def build_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> DSInferenceModelBase: + """ + Completely instantiate the inference model. This will both create the ops needed to run the + model, as well as load the model parameters via the checkpoint engine. For more context + on each of these components please see ``instantiate_model`` and ``populate_model_parameters``. + + Arguments: + engine_config: The config that has been used to instantiate the engine. This is used + to communicate to the model implementation the limits on batches (sequences/tokens) + and bound the size of intermediate buffers. + mp_group: Object to enable communication between tensor parallel ranks. + + Returns: + DSInferenceModelBase: An implementation of the inference model abstraction that will be + run by the engine. + """ + self.model = self.instantiate_model(engine_config, mp_group) + self.populate_model_parameters() + return self.model + + @abstractmethod + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig) -> DSInferenceModelBase: + """ + Instantiate the inference model. Depending on the engine/model config, this could be where + different model implementations could be selected. + + Arguments: + engine_config: The config that has been used to instantiate the engine. This is used + to communicate to the model implementation the limits on batches (sequences/tokens) + and bound the size of intermediate buffers. + + Returns: + DSInferenceModelBase: An implementation of the inference model abstraction that will be + run by the engine. + """ + ... + + @abstractmethod + def build_container_map(self) -> ContainerMap: + """ + Build a dictionary representing the structure of the string prefixes leading + to the parameters to be mapped to the container. + + Returns: + ContainerMap: An instantiated mapping describing how checkpoint prefixes map + to ``LayerContainer`` instances. 
+ """ + raise NotImplementedError() + + def populate_model_parameters(self) -> None: + """ + This model will iterate over the parameters (as provided by the checkpoint engine) and + use the container map built by ``build_container_map`` to populate the model + """ + + container_map = self.build_container_map() + for name, parameter in self._checkpoint_engine.parameters(): + container_map.map_param(name, parameter) + container_map.validate() + + self.model.set_parameters(transformer=container_map.transformer_params, + non_transformer=container_map.non_transformer_params) diff --git a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py new file mode 100644 index 000000000000..ce3a486373bc --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py @@ -0,0 +1,616 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import abstractmethod +from typing import Optional + +import torch + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from ..allocator import on_device +from ..config_v2 import RaggedInferenceEngineConfig +from ..inference_utils import ActivationType, ceil_div, is_gated +from ..model_implementations import * +from ..model_implementations.sharding import * +from ..modules.configs import ( + DSEmbeddingsConfig, + DSLinearConfig, + DSMoEConfig, + DSNormConfig, + DSSelfAttentionConfig, + DSUnembedConfig, + NormTypeEnum, + PositionalEmbeddingType, +) +from ..modules import heuristics +from ..ragged import ( + DSSequenceDescriptor, + KVCacheConfig, + RaggedBatchWrapper, +) +from .inference_model_base import ( + DSInferenceModelBase, + DSModelImplementationConfig, + MPType, +) + +try: + from functools import cached_property +except ImportError: + + def cached_property(func): + return property(func) + + +class DSTransformerModelBase(DSInferenceModelBase): + """ + Dimensioning properties + """ + + @property + @abstractmethod + def num_layers(self) -> int: + """ + Number of the layers in the model + """ + ... + + @property + @abstractmethod + def model_dim(self) -> int: + """ + Size of embedding projection and residuals. + """ + ... + + @property + @abstractmethod + def vocab_size(self) -> int: + """ + Size of the vocabulary (including padding). + """ + ... + + @property + @abstractmethod + def head_size(self) -> int: + """ + Size of each attention head. + """ + ... + + @property + @abstractmethod + def n_heads(self) -> int: + """ + The number of query heads on the model. This should not take into account + any dimension reductions from model sharding. + """ + ... + + @property + def n_heads_q(self) -> int: + """ + Alias to n_heads. + """ + return self.n_heads + + @property + def n_heads_kv(self) -> int: + """ + The number of key and value heads on the model. For GQA or MQA, overload this attribute. + Otherwise it adopts MHA formulations and uses n_heads. This should not take into account + any dimension reductions from model sharding. + """ + return self.n_heads + + @property + @abstractmethod + def intermediate_dim(self) -> int: + """ + The size of the (unsharded) intermediate projection dim. For a gated activation function + this is the size of the input to the second MLP layer. This should not take into account + any dimension reductions from model sharding. + """ + ... 
+ + @property + @abstractmethod + def positional_embedding_type(self) -> PositionalEmbeddingType: + """ + The type of positional embedding used by the model. + """ + ... + + """ + Architectural properties + """ + + @property + @abstractmethod + def activation_dtype(self) -> torch.dtype: + """ + The activation dtype of the model. + """ + ... + + @property + @abstractmethod + def mlp_activation_fn(self) -> ActivationType: + """ + The activation function used in the MLP. + """ + ... + + @property + @abstractmethod + def norm_type(self) -> NormTypeEnum: + """ + The type of normalization used in the model. + """ + ... + + """ + Derived helpers + """ + + @cached_property + def tp_rank(self) -> int: + """ + The rank of the current process. + + # TODO(cmikeh2): Kind of a hack right now, but this is too verbose to use at + the frequency we need. + """ + return dist.get_rank(group=self._base_mp_group) + + @cached_property + def tp_size(self) -> int: + """ + The total number of processes. + + # TODO(cmikeh2): Kind of a hack right now, but this is too verbose to use at + the frequency we need. + """ + return dist.get_world_size(group=self._base_mp_group) + + @cached_property + def n_heads_q_local(self) -> int: + """ + Number of local heads post sharding. + """ + return get_local_heads(self.tp_rank, self.tp_size, self.n_heads_q, self.n_heads_kv)[0] + + @cached_property + def n_heads_kv_local(self) -> int: + """ + Number of local heads post sharding. + """ + return get_local_heads(self.tp_rank, self.tp_size, self.n_heads_q, self.n_heads_kv)[1] + + @property + def gated_mlp(self) -> bool: + """ + Return a boolean to determine whether the model uses a gated activation function. + """ + return is_gated(self.mlp_activation_fn) + + """ + Method implementations + """ + + def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig, + base_mp_group: MPType) -> None: + """ + Base implementation for initialization. By default, this will initialize + the traditional components of a transformer model: + - Embedding + - QKV projection + - Self attention + - Attention output projection + - Feed forward network + - Normalization + - Unembedding + + Arguments: + config (DSModelImplementationConfig): Model-specific configuration. No assumptions + should be made about this config that are not closely tied to the specific + model implementation. + engine_config (RaggedInferenceEngineConfig): Engine configuration. + base_mp_group (MPType): Base communication group for Tensor-parallel inference. + """ + super().__init__(config, engine_config, base_mp_group) + + self.make_norm_layer() + self.make_qkv_layer() + self.make_attn_layer() + self.make_attn_out_layer() + self.make_mlp_1_layer() + self.make_mlp_2_layer() + self.make_embedding_layer() + self.make_unembedding_layer() + self._kv_cache_config = None + + ######### Embedding ######### + def make_embedding_layer(self) -> None: + """ + Performs setup and creates embedding DSModule. This will set the `self.embed` attribute. + """ + + embed_config = DSEmbeddingsConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + residual_dtype=self.activation_dtype, + embedding_dim=self.model_dim, + ) + + self.embed = heuristics.instantiate_embed(embed_config, self._engine_config) + + @on_device + def transform_embedding_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Performs embedding sharding along the channels dimension. + """ + # Until we can do non-contiguous all-gather, we won't shard the embedding parameters. 
+ return param.to(self.activation_dtype.value) + + ######### Unembedding ######### + def make_unembedding_layer(self) -> None: + """ + Performs setup and creates an unembedding layer. This implementation assumes + normalization prior to the LM head projection. If this does not match the model's + implementation, override this method. This will set the ``self.unembed`` attribute. + """ + unembed_dim = sharded_unembed_dim(self.vocab_size, self.tp_rank, self.tp_size) + + unembed_config = DSUnembedConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + dtype=self.activation_dtype, + model_dim=self.model_dim, + vocab_size=unembed_dim, + norm_type=self.norm_type, + ) + + self.unembed = heuristics.instantiate_unembed(unembed_config, self._engine_config) + + if self.tp_size > 1: + self._comm_logits = torch.empty(self.tp_size, + self._engine_config.state_manager.max_ragged_sequence_count, + unembed_dim, + device=get_accelerator().current_device(), + dtype=self.activation_dtype.value) + self._return_logits = torch.empty(self._engine_config.state_manager.max_ragged_sequence_count, + self.vocab_size, + device=get_accelerator().current_device(), + dtype=self.activation_dtype.value) + + @on_device + def transform_unembed_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Performs sharding along the vocab dimension. + """ + return shard_unembed_param(param, self.tp_rank, self.tp_size).to(self.activation_dtype.value) + + ######### QKV ######### + def make_qkv_layer(self) -> None: + """ + Instantiates the linear projection layer for the QKV linear layer. This sets the + `self.qkv` attribute. + """ + out_features = qkv_out_features(self.model_dim, self.tp_rank, self.tp_size, self.head_size, self.n_heads_q, + self.n_heads_kv) + + linear_config = DSLinearConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + in_channels=self.model_dim, + out_channels=out_features, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.qkv = heuristics.instantiate_linear(linear_config, self._engine_config) + + @on_device + def transform_qkv_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Passes a QKV parameter to the underlying implementation for any necessary + transformations. + + Args: + param (torch.Tensor): The parameter to transform. This may be either a bias or weight and should have + the shape (out_neurons, in_neurons) + """ + param = shard_qkv_param(param, self.tp_rank, self.tp_size, self.head_size, self.n_heads_q, self.n_heads_kv) + return self.qkv.transform_param(param) + + ######### Attention ######### + def make_attn_layer(self) -> None: + """ + Builds the attention layer for the model. This sets the `self.attn` attribute. 
+ """ + softmax_scale = 1.0 / (self.head_size**0.5) + + attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + n_heads_q=self.n_heads_q_local, + n_heads_kv=self.n_heads_kv_local, + head_size=self.head_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + scale_factor=softmax_scale, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + positional_embedding_type=self.positional_embedding_type) + + self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + + def get_kv_requirements(self, sequence: DSSequenceDescriptor, max_new_tokens: int, + max_new_blocks: int) -> Tuple[int, int]: + """ + See ``DSInferenceModelBase.get_kv_requirements`` for documentation. + + This method assumes an autoregressive dense attention pattern. Override this method + if this does not match the model's attention pattern. + """ + total_tokens = sequence.seen_tokens + max_new_tokens + req_blocks = ceil_div(total_tokens, self.attn.kv_block_size) + block_lim = req_blocks - sequence.cur_allocated_blocks + + if block_lim <= max_new_blocks: + return max_new_tokens, block_lim + + token_capacity = (max_new_blocks + + sequence.cur_allocated_blocks) * self.attn.kv_block_size - sequence.seen_tokens + + return token_capacity, max_new_blocks + + def maybe_allocate_kv(self, sequence: DSSequenceDescriptor, n_new_tokens: int) -> None: + """ + See ``DSInferenceModelBase.maybe_allocate_kv`` for documentation. + + This method assumes an autoregressive dense attention pattern. Override this method + if this does not match the model's attention pattern. + """ + _, n_needed_blocks = self.get_kv_requirements(sequence, n_new_tokens, self.state_manager.free_blocks) + + if n_needed_blocks > 0: + new_blocks = self.state_manager.allocate_blocks(n_needed_blocks) + sequence.extend_kv_cache(new_blocks) + + def kv_cache_config(self) -> KVCacheConfig: + """ + See ``DSInferenceModelBase.kv_cache_config`` for documentation. + + This method assumes an autoregressive dense attention pattern. Override this method + if this does not match the model's attention pattern. + """ + if self._kv_cache_config is None: + cache_shape = (self.num_layers, self.n_heads_kv_local, self.head_size) + max_blocks = ceil_div(self.max_sequence_length, self.attn.kv_block_size) + self._kv_cache_config = KVCacheConfig(block_size=self.attn.kv_block_size, + cache_shape=cache_shape, + cache_dtype=self.activation_dtype, + max_blocks_per_allocation_group=max_blocks) + return self._kv_cache_config + + def prepare_batch(self, wrapped_batch: RaggedBatchWrapper) -> None: + """ + See ``DSInferenceModelBase.prepare_batch`` for documentation. + + This method assumes an autoregressive dense attention pattern. Override this method + if this does not match the model's attention pattern. + """ + self.attn.build_atoms(wrapped_batch) + + ######### Attention output ######### + def make_attn_out_layer(self) -> None: + """ + Instantiates the linear projection layer for the attention output linear layer. This sets the + `self.attn_out` attribute. 
+ """ + in_features = attn_out_in_features(self.model_dim, self.tp_rank, self.tp_size, self.head_size, self.n_heads_q, + self.n_heads_kv) + + linear_config = DSLinearConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + in_channels=in_features, + out_channels=self.model_dim, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.attn_out = heuristics.instantiate_linear(linear_config, self._engine_config) + + @on_device + def transform_attn_out_param(self, param: torch.Tensor) -> Optional[torch.Tensor]: + """ + Shards an attention output projection parameter and passes it to the underlying + implementation for any necessary transformations. This will return `None` for bias parameters + if they are not on TP rank 0. + + Args: + param (torch.Tensor): The parameter to transform. This may be either a bias or weight and should have + the shape (out_neurons, in_neurons). + """ + param = shard_attn_out_param(param, self.tp_rank, self.tp_size, self.head_size, self.n_heads_q, + self.n_heads_kv) + + if param is not None: + param = self.attn_out.transform_param(param) + + return param + + ######### MLP ######### + def make_mlp_1_layer(self) -> None: + """ + Instantiates the linear projection layer for the first MLP in the feedforward network. + This sets the `self.mlp_1` attribute. + """ + shard_size = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank) + + linear_config = DSLinearConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + in_channels=self.model_dim, + out_channels=shard_size, + activation=self.mlp_activation_fn, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.mlp_1 = heuristics.instantiate_linear(linear_config, self._engine_config) + + @on_device + def transform_mlp_1_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Shards the first MLP parameter and passes it to the underlying implementation + for any necessary transformations. + + Args: + param (torch.Tensor): The parameter to transform. This may be either a bias or weight and should have + the shape (out_neurons, in_neurons). + """ + param = shard_mlp_1_param(param, self.tp_rank, self.tp_size, gated=self.gated_mlp) + + return self.mlp_1.transform_param(param) + + def make_mlp_2_layer(self) -> None: + """ + Instantiates the linear projection layer for the second MLP in the feedforward network. + This sets the `self.mlp_2` attribute. + """ + shard_size = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank) + + linear_config = DSLinearConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + in_channels=shard_size, + out_channels=self.model_dim, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.mlp_2 = heuristics.instantiate_linear(linear_config, self._engine_config) + + @on_device + def transform_mlp_2_param(self, param: torch.Tensor) -> Optional[torch.Tensor]: + """ + Shards the second MLP parameter and passes it to the underlying implementation + for any necessary transformations. This will return `None` for bias parameters + if they are not on TP rank 0. + + Args: + param (torch.Tensor): The parameter to transform. This may be either a bias or weight and should have + the shape (out_neurons, in_neurons). 
+ """ + param = shard_mlp_2_param(param, self.tp_rank, self.tp_size) + + if param is not None: + param = self.mlp_2.transform_param(param) + + return param + + ######### Norm ######### + def make_norm_layer(self) -> None: + """ + Instantiates the normalization layer for the model. This sets the `self.norm` attribute. + + TODO(cmikeh2): In the future we'll distinguish between the different norm objects, + but for now we'll just use the same one for all of them. + """ + norm_config = DSNormConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + type=self.norm_type, + channels=self.model_dim, + residual_dtype=self.activation_dtype, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config) + + @on_device + def transform_norm_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Passes a normalization parameter to the underlying implementation for any + necessary transformations. + + TODO(cmikeh2): In the future we'll distinguish between the different norm objects, + but for now we'll just use the same one for all of them. + + Args: + param (torch.Tensor): The parameter to transform. This may be either a bias or weight and should have + shape (model_dim,) + """ + return self.norm.transform_param(param) + + +class DSMoETransformerModelBase(DSTransformerModelBase): + + @property + def num_experts(self) -> int: + """ + Return the number of experts in the model. + """ + raise NotImplementedError("Attempted to access an unimplemented number of experts") + + def make_moe_layer(self) -> None: + """ + Instantiates the MoE layer for the model. This sets the `self.moe` attribute. + """ + sharded_dim = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank) + + moe_config = DSMoEConfig( + max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + model_dim=self.model_dim, + intermediate_features=sharded_dim, + activation=self.mlp_activation_fn, + n_experts=self.num_experts, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + ) + + self.moe = heuristics.instantiate_moe(moe_config, self._engine_config) + + @on_device + def transform_moe_gate_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Passes a MoE gate parameter to the underlying implementation for any necessary transformations. + + TODO(cmikeh2): This will need to be updated/overridden for expert parallelism. + """ + return self.moe.transform_gate_param(param) + + @on_device + def transform_moe_mlp_1_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Shards the first MoE param and passes it to the underlying implementation. Since it's possible for an architecture + to have both MoE and non-MoE layers, this can't be overloaded on the MLP1 transform. Furthermore, since both + the MoE DSModule owns both MLP1 and MLP2, under certain sharding conditions it's not possible for the model implementation + to infer from the shape whether to perform a different transformation based on MLP1 or MLP2. This (and the below) + separations are intended to solve both these issues. + + Args: + param (torch.Tensor): The parameter to transform. This should have shape (n_experts, out_neurons, in_neurons). 
+ """ + param = shard_mlp_1_param(param, self.tp_rank, self.tp_size, gated=self.gated_mlp, is_moe=True) + + return self.moe.transform_moe_mlp_1_param(param) + + @on_device + def transform_moe_mlp_2_param(self, param: torch.Tensor) -> Optional[torch.Tensor]: + """ + Shards the second MoE param and passes it to the underlying implementation. See the above for context on why this API + exists. + + This will return `None` for expert bias params not on TP rank 0. NOTE(cmikeh2): Does it make sense to round-robin assign? + My intuition is that this will make debugging much more difficult for minimal memory reduction. + + Args: + param (torch.Tensor): The parameter to transform. This should have shape (n_experts, out_neurons, in_neurons). + """ + param = shard_mlp_2_param(param, self.tp_rank, self.tp_size, is_moe=True) + + if param is not None: + param = self.moe.transform_moe_mlp_2_param(param) + + return param diff --git a/deepspeed/inference/v2/model_implementations/layer_container_base.py b/deepspeed/inference/v2/model_implementations/layer_container_base.py new file mode 100644 index 000000000000..e0ec19372569 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/layer_container_base.py @@ -0,0 +1,289 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Type + +import torch + +from deepspeed.accelerator import get_accelerator +from .parameter_base import ParameterBase, ParametrizedList + +# Currently have dependency loops for the type hints. +InferenceModel = Type["InferenceModel"] +LayerContainer = Type["LayerContainer"] + +MAPPING_KEY = "PARAM_MAPPING" +PLIST_HELPERS = "_ds_plist_strip_vals" + + +def make_finalization_callback(all_names: str): + """ + Helper method for building the finalization callback for a LayerContainer. This + is not client code and should not be used or called directly. + """ + + def finalization_callback(self, param: ParameterBase, finalized_param: torch.Tensor) -> None: + """ + Callback for when a parameter is finalized. + """ + self._finalized_params += 1 + + for name in all_names: + if getattr(self, name) is param: + setattr(self, name, finalized_param) + + return finalization_callback + + +class LayerMetaclass(type): + """ + MetaClass for the LayerContainer base class. This class will parse the annotations + of the class that correspond to `ParameterBase` and create None initializers for each + as well as a finalization callback that for when each `ParameterBase` is finalized + and should be replaced with a Tensor. + """ + + def __new__(cls, clsname, bases, attrs): + + annotations = attrs.get("__annotations__", {}) + + for base in bases: + # We'll pick up all annotations on any base classes. This will allow us to + # to use inheritance to share common parameter groups in base classes. + if hasattr(base, "__annotations__"): + annotations.update(base.__annotations__) + + if hasattr(base, MAPPING_KEY): + if MAPPING_KEY not in attrs: + # This is likely a fail state. If a parent has MAPPING KEY but the child does + # not, then we're guaranteed only a subset of the parameters will be mapped. + attrs[MAPPING_KEY] = {} + attrs[MAPPING_KEY].update(getattr(base, MAPPING_KEY)) + + all_names = [name for name, annotation in annotations.items() if issubclass(annotation, ParameterBase)] + + if MAPPING_KEY in attrs: + # If we have a mapping key at all, then we will enter the validation mode for building + # helpers for mapping and ensuring we have complete mapping. 
+ + # First we'll build a flat list of every dependency for this layer. + all_deps = set() + for name in all_names: + parameter_deps = [ + name for name, annotation in annotations[name].__annotations__.items() + if issubclass(annotation, (torch.Tensor, ParametrizedList)) + ] + + all_deps.update([f"{name}.{dep}" for dep in parameter_deps]) + + # Create static helper for doing the string processing only once. + attrs[PLIST_HELPERS] = [] + + # Iterate over all the mappings + for src_name, target_or_targets in attrs[MAPPING_KEY].items(): + if isinstance(target_or_targets, str): + target_or_targets = [target_or_targets] + + actual_targets = [] + for target_name in target_or_targets: + base_dependency, dependency_attr = target_name.split(".") + + # Check for invalid mappings + if base_dependency not in all_names: + raise ValueError( + "Target parameter \"{}\" not found in this layer. Valid targets are {}".format( + base_dependency, all_names)) + if dependency_attr not in annotations[base_dependency].__annotations__: + # This check is not universal (see below) if a single dependency is being + # mapped to by a single row. + raise ValueError( + "Target dependency \"{}\" not found on parameter \"{}\". Valid targets are {}".format( + dependency_attr, base_dependency, annotations[base_dependency].__annotations__.keys())) + if target_name not in all_deps: + raise ValueError( + "Target dependency \"{}\" was targeted with multiple mapping rules.".format(target_name)) + + # If we've made it this far, the dependency definitely exists. + actual_targets.append(annotations[base_dependency].__annotations__[dependency_attr]) + + all_deps.remove(target_name) + + are_plists = [issubclass(target, ParametrizedList) for target in actual_targets] + if all(are_plists): + # We can do direct sets on everything but ParametrizedLists, so we'll only explicitly + # handle these here. + # TODO(cmikeh2): SPLIT, error if more than 1 + glob_count = src_name.count("*") + if glob_count > 1: + raise ValueError( + "ParametrizedList index inference can only work with a single glob: {}".format(src_name)) + elif glob_count == 0: + raise ValueError( + "Must have wildcard (*) in source name for ParametrizedList mapping: {}".format(src_name)) + + wildcard_idx = src_name.find("*") + prefix = src_name[:wildcard_idx] + suffix = src_name[wildcard_idx + 1:] + attrs[PLIST_HELPERS].append((prefix, suffix, target_or_targets)) + elif any(are_plists): + raise ValueError("Cannot mix ParametrizedLists and Tensors in a single mapping rule.") + + if len(all_deps) > 0: + raise ValueError( + "A parameter mapping was provided for {}, but the following dependencies were not mapped: {}". + format(clsname, all_deps)) + + attrs["finalization_callback"] = make_finalization_callback(all_names) + + new_obj = super().__new__(cls, clsname, bases, attrs) + + setattr(new_obj, "_n_params", len(all_names)) + setattr(new_obj, "_annotation_attrs", all_names) + + return new_obj + + def __call__(cls, *args, **kwargs): + instance = cls.__new__(cls, *args, **kwargs) + instance.__init__(*args, **kwargs) + + for name, annotation in instance.__annotations__.items(): + if issubclass(annotation, ParameterBase): + # TODO(cmikeh2): Do we want to make this a property + # It might also make sense to do this in the base class __init__ + # but since it is tied with the changes made in __new__ it feels + # to me like it should be here. 
+ setattr(instance, name, annotation(instance.inference_model, instance)) + + return instance + + +class LayerContainer(metaclass=LayerMetaclass): + """ + Abstract base class for containing model parameters. + + This is primarily a guidance abstraction since we do not put any restrictions + on how the parameters are stored. + + To use this class, annotate the class with `ParameterBase` subclasses and give them + names. As a checkpoint is loaded into this container, the `ParameterBase` instances + will be replaced with realized Tensors as soon as each of their dependencies are met. + + To enable automatic mapping, add a static attribute `PARAM_MAPPING` to the class + definition. This should be a dictionary mapping from a source string to one or + more dependencies. + + ```python + class MyLayer(LayerContainer): + PARAM_MAPPING = { + "path.to.param.dependency", "container_param_1.dependency", + "path.to.param2.dependency", "container_param_2.dependency", + "path.to.param3.*.dependency", "container_param_3.list_dependency" + } + + ... + ``` + """ + + def __init__(self, model: InferenceModel) -> None: + """ + Initialization of the LayerContainer. This method does not need to be overridden + for any children classes. + + Args: + model (InferenceModel): Inference model that will be used to shard and transform + parameters correctly, as well as provide specific information about the model + for `ParameterizedList`s that may be part of one of the member `ParameterBase`s. + """ + self.inference_model = model + self._finalized_params = 0 + + @property + def is_initialized(self) -> bool: + """ + Returns whether or not all parameters have been initialized and transformed by + the model. Once this returns True, all the `ParameterBase` instances will be + torch.Tensors. + """ + if self._finalized_params != self.n_params: + return False + + for name in self._annotation_attrs: + tensor = getattr(self, name) + if tensor is None: + continue + elif not isinstance(tensor, torch.Tensor): + raise ValueError("Layer should be finalized, but {} is neither Tensor or None".format(name)) + elif tensor.device != torch.device(get_accelerator().current_device()): + raise RuntimeError("Layer should be finalized, but {} is not on device {}".format( + name, + get_accelerator().current_device())) + return True + + @property + def n_params(self) -> int: + """ + The number of parameters this container holds. This is a read-only value + that is set by the metaclass. + """ + return self._n_params + + @property + def mapping_params(self) -> dict: + return getattr(self.__class__, MAPPING_KEY, {}) + + @property + def plist_helpers(self) -> list: + return getattr(self.__class__, PLIST_HELPERS, []) + + def set_dependency(self, dep_name: str, dep_value: torch.Tensor) -> None: + """ + Set dependency can be used for managing dependencies when a mapping is provided + in the class definition for the layer. The dep_name here should have any prefix + for transformer layers removed (such as model.layers.*.attn.qkv.weight -> attn.qkv.weight). + + Args: + dep_name (str): The name of the dependency to set. + dep_value (torch.Tensor): The value to set the dependency to. + """ + if dep_name in self.mapping_params: + # If we have an exact match, it's a direct mapping and we can immediately set + # the value. 
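Note that the `PARAM_MAPPING` example in the docstring above separates sources and targets with commas, which would build a set of strings rather than a mapping; judging from the containers defined later in this diff (Llama-2, Mistral, OPT), a dictionary literal like the following is the intended form.

```python
# Corrected form of the docstring example: PARAM_MAPPING maps a checkpoint name
# (source) to one or more "<container_param>.<dependency>" targets.
PARAM_MAPPING = {
    "path.to.param.dependency": "container_param_1.dependency",
    "path.to.param2.dependency": "container_param_2.dependency",
    # A single glob (*) maps a family of names onto a ParametrizedList dependency.
    "path.to.param3.*.dependency": "container_param_3.list_dependency",
}
```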
+ target = self.mapping_params[dep_name] + + # Convert single targets to a list for consistency + if isinstance(target, str): + target = [target] + + for target_name in target: + # Double setting doesn't set the attribute correctly, so we do a getattr then setattr + target_param_name, target_dependency_name = target_name.split(".") + target_param = getattr(self, target_param_name) + setattr(target_param, target_dependency_name, dep_value) + return + + # Otherwise we need to map to one of the parameter lists. + for prefix, suffix, dests in self.plist_helpers: + if dep_name.startswith(prefix) and dep_name.endswith(suffix): + # We have a match, so we can set the value. + target_idx = int(dep_name[len(prefix):-len(suffix)]) + + # Convert single targets to a list for consistency + if isinstance(dests, str): + dests = [dests] + + for dest in dests: + target_param_name, target_dependency_name = dest.split(".") + target_param = getattr(self, target_param_name) + target_dependency = getattr(target_param, target_dependency_name) + target_dependency[target_idx] = dep_value + return + + raise ValueError( + "Could not find a mapping for dependency \"{}\". Check that it is included in the ``MAPPING_PARAMS``. See docstring for more on ``MAPPING_PARAMS``" + .format(dep_name)) + + +class ContainerMap: + pass diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py b/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_containers.py b/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_containers.py new file mode 100644 index 000000000000..ec39866d0d8d --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_containers.py @@ -0,0 +1,80 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from ...model_implementations.common_parameters import * +from ...model_implementations.layer_container_base import LayerContainer +''' + # HF Llama model looks like this: + +LlamaForCausalLM( + (model): LlamaModel( + (embed_tokens): Embedding(32000, 4096, padding_idx=0) + (layers): ModuleList( + (0-31): 32 x LlamaDecoderLayer( + (self_attn): LlamaAttention( + (q_proj): Linear(in_features=4096, out_features=4096, bias=False) + (k_proj): Linear(in_features=4096, out_features=4096, bias=False) + (v_proj): Linear(in_features=4096, out_features=4096, bias=False) + (o_proj): Linear(in_features=4096, out_features=4096, bias=False) + (rotary_emb): LlamaRotaryEmbedding() + ) + (mlp): LlamaMLP( + (gate_proj): Linear(in_features=4096, out_features=11008, bias=False) + (up_proj): Linear(in_features=4096, out_features=11008, bias=False) + (down_proj): Linear(in_features=11008, out_features=4096, bias=False) + (act_fn): SiLUActivation() + ) + (input_layernorm): LlamaRMSNorm() + (post_attention_layernorm): LlamaRMSNorm() + ) + ) + (norm): LlamaRMSNorm() + ) + (lm_head): Linear(in_features=4096, out_features=32000, bias=False) +) +''' + + +class Llama2TransformerContainer(LayerContainer): + """ + Transformer layer container for the Llama-2 model. 
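The wildcard branch of `set_dependency` above reduces to prefix/suffix string matching plus an integer parse of whatever the glob consumed. A standalone illustration of that lookup (the rule name used here is hypothetical):

```python
def resolve_glob(dep_name: str, prefix: str, suffix: str):
    """Return the index a globbed mapping rule would extract, or None if no match."""
    if dep_name.startswith(prefix) and dep_name.endswith(suffix):
        return int(dep_name[len(prefix):-len(suffix)])
    return None

# A rule registered from "moe.experts.*.down_proj.weight" is stored as
# (prefix="moe.experts.", suffix=".down_proj.weight", targets=[...]).
assert resolve_glob("moe.experts.7.down_proj.weight",
                    "moe.experts.", ".down_proj.weight") == 7
```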
+ """ + qkv_w: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + mlp_1_w: GatedMLPParameter + mlp_2_w: MLP2Parameter + attn_norm_gamma: NormParameter + mlp_norm_gamma: NormParameter + #rotary_emb: InvFreqParameter + + PARAM_MAPPING = { + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.o_proj.weight": "attn_out_w.params", + "mlp.gate_proj.weight": "mlp_1_w.gate_params", + "mlp.up_proj.weight": "mlp_1_w.up_params", + "mlp.down_proj.weight": "mlp_2_w.params", + "input_layernorm.weight": "attn_norm_gamma.params", + "post_attention_layernorm.weight": "mlp_norm_gamma.params", + #"self_attn.rotary_emb.inv_freq": "rotary_emb.params", + } + + +class Llama2NonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Llama-2 model. + """ + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "model.norm.weight": "final_norm.params", + "lm_head.weight": "word_unembed.params", + } diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_model.py b/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_model.py new file mode 100644 index 000000000000..9b628f77de01 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_model.py @@ -0,0 +1,204 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from ...model_implementations import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...ragged import RaggedBatchWrapper + +from .llama_v2_containers import Llama2NonTransformerContainer, Llama2TransformerContainer + + +class Llama2InferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for Llama-2 models. + """ + + _non_transformer: Optional[Llama2NonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[Llama2TransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. 
+ """ + """ + Properties ineherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties ineherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + activation = self._config.hidden_act.lower() + # llama model family is special and is always gated so force gated versions of relu, gelu, silu + if activation == "gelu": + return ActivationType.GEGLU + elif activation == "relu": + return ActivationType.ReGLU + elif activation == "gegelu": + return ActivationType.GEGLU + elif activation == "silu": + return ActivationType.SiGLU + else: + raise NotImplementedError(f"Activation {activation} not supported") + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.RMSNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. 
+ """ + # TODO(cmikeh2): Distribute ragged_batch_info to all modules + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None) + hidden_states = self.attn(hidden_states, kv_cache, + ragged_batch_info) #, inv_freqs=None) #cur_params.rotary_emb) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None) + + # Should be configurable in the future + hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) + hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info, + self._non_transformer.final_norm) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_policy.py b/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_policy.py new file mode 100644 index 000000000000..65fe7b705e53 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_policy.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse + +from typing import Any + +from ...checkpoint import CheckpointEngineBase +from ...config_v2 import RaggedInferenceEngineConfig +from ...model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy +from ...model_implementations.llama_v2.llama_v2_containers import Llama2NonTransformerContainer, Llama2TransformerContainer +from ...model_implementations.llama_v2.llama_v2_model import Llama2InferenceModel + + +class Llama2Policy(InferenceV2Policy): + + def __init__(self, checkpoint_engine: CheckpointEngineBase, model_config: argparse.Namespace) -> None: + super().__init__(checkpoint_engine, model_config) + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Llama2InferenceModel: + return Llama2InferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [Llama2TransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(Llama2NonTransformerContainer(self.model)) + + map.set_unmapped_params( + [f'model.layers.{i}.self_attn.rotary_emb.inv_freq' for i in range(self.model.num_layers)]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/mistral/__init__.py b/deepspeed/inference/v2/model_implementations/mistral/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mistral/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/deepspeed/inference/v2/model_implementations/mistral/container.py b/deepspeed/inference/v2/model_implementations/mistral/container.py new file mode 100644 index 000000000000..b4c0956f4049 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mistral/container.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. 
+ +from deepspeed.inference.v2.model_implementations.common_parameters import * +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer +''' + # HF Mistral model (mistralai/Mistral-7B-v0.1) looks like this: +MistralForCausalLM( + (model): MistralModel( + (embed_tokens): Embedding(32000, 4096) + (layers): ModuleList( + (0-31): 32 x MistralDecoderLayer( + (self_attn): MistralAttention( + (q_proj): Linear(in_features=4096, out_features=4096, bias=False) + (k_proj): Linear(in_features=4096, out_features=1024, bias=False) + (v_proj): Linear(in_features=4096, out_features=1024, bias=False) + (o_proj): Linear(in_features=4096, out_features=4096, bias=False) + (rotary_emb): MistralRotaryEmbedding() + ) + (mlp): MistralMLP( + (gate_proj): Linear(in_features=4096, out_features=14336, bias=False) + (up_proj): Linear(in_features=4096, out_features=14336, bias=False) + (down_proj): Linear(in_features=14336, out_features=4096, bias=False) + (act_fn): SiLUActivation() + ) + (input_layernorm): MistralRMSNorm() + (post_attention_layernorm): MistralRMSNorm() + ) + ) + (norm): MistralRMSNorm() + ) + (lm_head): Linear(in_features=4096, out_features=32000, bias=False) +) +''' + + +class MistralTransformerContainer(LayerContainer): + """ + Transformer layer container for the Mistral model. + """ + qkv_w: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + mlp_1_w: GatedMLPParameter + mlp_2_w: MLP2Parameter + attn_norm_gamma: NormParameter + mlp_norm_gamma: NormParameter + + PARAM_MAPPING = { + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.o_proj.weight": "attn_out_w.params", + "mlp.gate_proj.weight": "mlp_1_w.gate_params", + "mlp.up_proj.weight": "mlp_1_w.up_params", + "mlp.down_proj.weight": "mlp_2_w.params", + "input_layernorm.weight": "attn_norm_gamma.params", + "post_attention_layernorm.weight": "mlp_norm_gamma.params", + } + + +class MistralNonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the Mistral model. + """ + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "model.norm.weight": "final_norm.params", + "lm_head.weight": "word_unembed.params", + } diff --git a/deepspeed/inference/v2/model_implementations/mistral/model.py b/deepspeed/inference/v2/model_implementations/mistral/model.py new file mode 100644 index 000000000000..d9b06b91e308 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mistral/model.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from ...model_implementations import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...ragged import RaggedBatchWrapper + +from .container import MistralNonTransformerContainer, MistralTransformerContainer + + +class MistralInferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for Mistral models. + """ + + _non_transformer: Optional[MistralNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. 
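The shapes in the Mistral module printout above already encode its grouped-query attention: with a head size of 128, `q_proj`'s 4096 output features correspond to 32 query heads, while `k_proj`/`v_proj`'s 1024 output features correspond to 8 key/value heads (4 query heads per KV group). A quick check:

```python
head_size = 4096 // 32          # hidden_size / num_attention_heads = 128
n_heads_q = 4096 // head_size   # 32 query heads
n_heads_kv = 1024 // head_size  # 8 key/value heads
assert (head_size, n_heads_q, n_heads_kv) == (128, 32, 8)
```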
+ """ + + _transformer: Optional[Iterable[MistralTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties ineherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties ineherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + activation = self._config.hidden_act.lower() + if activation == "gelu": + return ActivationType.GEGLU + elif activation == "relu": + return ActivationType.ReGLU + elif activation == "gegelu": + return ActivationType.GEGLU + elif activation == "silu": + return ActivationType.SiGLU + else: + raise NotImplementedError(f"Activation {activation} not supported") + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.RMSNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. 
+ """ + # TODO(cmikeh2): Distribute ragged_batch_info to all modules + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None) + hidden_states = self.attn(hidden_states, kv_cache, + ragged_batch_info) #, inv_freqs=None) #cur_params.rotary_emb) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None) + + # Should be configurable in the future + hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=None) + hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=None) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info, + self._non_transformer.final_norm) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer(layer_idx, residual, hidden_states, wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/mistral/policy.py b/deepspeed/inference/v2/model_implementations/mistral/policy.py new file mode 100644 index 000000000000..93458827aa96 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mistral/policy.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse + +from typing import Any + +from deepspeed.inference.v2.checkpoint import CheckpointEngineBase +from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig +from deepspeed.inference.v2.model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy +from deepspeed.inference.v2.model_implementations.mistral.container import MistralNonTransformerContainer, MistralTransformerContainer +from deepspeed.inference.v2.model_implementations.mistral.model import MistralInferenceModel + + +class MistralPolicy(InferenceV2Policy): + + def __init__(self, checkpoint_engine: CheckpointEngineBase, model_config: argparse.Namespace) -> None: + super().__init__(checkpoint_engine, model_config) + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> MistralInferenceModel: + return MistralInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [MistralTransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(MistralNonTransformerContainer(self.model)) + + map.set_unmapped_params([]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/opt/__init__.py b/deepspeed/inference/v2/model_implementations/opt/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/opt/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/deepspeed/inference/v2/model_implementations/opt/container.py b/deepspeed/inference/v2/model_implementations/opt/container.py new file mode 100644 index 000000000000..5b1c9ce4c8a3 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/opt/container.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. 
+ +from ...model_implementations.common_parameters import * +from ...model_implementations.layer_container_base import LayerContainer +''' + # HF OPT model looks like this: + +OPTForCausalLM( + (model): OPTModel( + (decoder): OPTDecoder( + (embed_tokens): Embedding(50272, 768, padding_idx=1) + (embed_positions): OPTLearnedPositionalEmbedding(2050, 768) + (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) + (layers): ModuleList( + (0-11): 12 x OPTDecoderLayer( + (self_attn): OPTAttention( + (k_proj): Linear(in_features=768, out_features=768, bias=True) + (v_proj): Linear(in_features=768, out_features=768, bias=True) + (q_proj): Linear(in_features=768, out_features=768, bias=True) + (out_proj): Linear(in_features=768, out_features=768, bias=True) + ) + (activation_fn): ReLU() + (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) + (fc1): Linear(in_features=768, out_features=3072, bias=True) + (fc2): Linear(in_features=3072, out_features=768, bias=True) + (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) + ) + ) + ) + ) + (lm_head): Linear(in_features=768, out_features=50272, bias=False) +) + +''' + + +class OPTTransformerContainer(LayerContainer): + """ + Transformer layer container for the OPT model. + """ + qkv_w: UnfusedQKVParameter + qkv_b: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + attn_out_b: AttentionOutputParameter + mlp_1_w: MLP1Parameter + mlp_1_b: MLP1Parameter + mlp_2_w: MLP2Parameter + mlp_2_b: MLP2Parameter + attn_norm_beta: NormParameter + attn_norm_gamma: NormParameter + mlp_norm_beta: NormParameter + mlp_norm_gamma: NormParameter + + PARAM_MAPPING = { + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.q_proj.bias": "qkv_b.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.k_proj.bias": "qkv_b.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.v_proj.bias": "qkv_b.v_params", + "self_attn.out_proj.weight": "attn_out_w.params", + "self_attn.out_proj.bias": "attn_out_b.params", + "fc1.weight": "mlp_1_w.params", + "fc1.bias": "mlp_1_b.params", + "fc2.weight": "mlp_2_w.params", + "fc2.bias": "mlp_2_b.params", + "self_attn_layer_norm.weight": "attn_norm_gamma.params", + "self_attn_layer_norm.bias": "attn_norm_beta.params", + "final_layer_norm.weight": "mlp_norm_gamma.params", + "final_layer_norm.bias": "mlp_norm_beta.params", + } + + +class OPTNonTransformerContainer(LayerContainer): + """ + Non-Transformer layer container for the OPT model. + """ + word_emb: EmbeddingParameter + word_emb_pos: EmbeddingParameter + word_unembed: UnembedParameter + final_norm_w: NormParameter + final_norm_b: NormParameter + + PARAM_MAPPING = { + "model.decoder.embed_tokens.weight": "word_emb.params", + "model.decoder.embed_positions.weight": "word_emb_pos.params", + "model.decoder.final_layer_norm.weight": "final_norm_w.params", + "model.decoder.final_layer_norm.bias": "final_norm_b.params", + "lm_head.weight": "word_unembed.params", + } diff --git a/deepspeed/inference/v2/model_implementations/opt/model.py b/deepspeed/inference/v2/model_implementations/opt/model.py new file mode 100644 index 000000000000..fa221e15a0b7 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/opt/model.py @@ -0,0 +1,195 @@ +# Copyright (c) Microsoft Corporation. 
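As with the Llama and Mistral containers, the three separate projection matrices in the OPT checkpoint are mapped onto a single fused QKV parameter. A minimal sketch of the kind of fusion such a parameter can perform, using the 768-dim shapes from the printout above; the real transform is performed by the model's QKV parameter handling and may additionally shard per rank.

```python
import torch

def fuse_qkv(q_w: torch.Tensor, k_w: torch.Tensor, v_w: torch.Tensor) -> torch.Tensor:
    # Stack the three [out, in] projections into one [3 * out, in] matrix so a
    # single GEMM produces Q, K and V for the attention kernel.
    return torch.cat((q_w, k_w, v_w), dim=0)

q = torch.randn(768, 768)
k = torch.randn(768, 768)
v = torch.randn(768, 768)
assert fuse_qkv(q, k, v).shape == (2304, 768)
```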
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...inference_utils import ActivationType, DtypeEnum +from ...model_implementations import * +from ...modules.configs import ( + DSEmbeddingsConfig, + NormTypeEnum, + PositionalEmbeddingType, +) +from ...ragged import RaggedBatchWrapper +from .container import OPTNonTransformerContainer, OPTTransformerContainer + +from ...modules.heuristics import instantiate_embed + + +class OPTInferenceModel(DSTransformerModelBase): + """ + Inference model implementation for ragged batching for OPT models. + """ + + _non_transformer: Optional[OPTNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[OPTTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties ineherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_seq_length + + """ + Properties ineherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.ffn_dim + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + return ActivationType.RELU + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.LayerNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.none + + """ + Overrides of ``DSTransformerModelBase`` methods + """ + + def make_embedding_layer(self) -> None: + """ + Performs setup and creates embedding DSModule. Since OPT includes trained + positional embeddings, we will override the base model implementation. 
+ """ + + embed_config = DSEmbeddingsConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + residual_dtype=self.activation_dtype, + embedding_dim=self.model_dim, + positional_embedding=True, + positional_offset=2) + + self.embed = instantiate_embed(embed_config, self._engine_config) + + """ + Forward implementations + """ + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + embed = self.embed(ragged_batch, self._non_transformer.word_emb, self._non_transformer.word_emb_pos) + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + # TODO(cmikeh2): Distribute ragged_batch_info to all modules + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b) + hidden_states = self.attn(hidden_states, kv_cache, + ragged_batch_info) #, inv_freqs=None) #cur_params.rotary_emb) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=cur_params.attn_out_b) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, + hidden_states, + cur_params.mlp_norm_gamma, + beta=cur_params.mlp_norm_beta) + + # Should be configurable in the future + hidden_states = self.mlp_1(hidden_states, cur_params.mlp_1_w, b=cur_params.mlp_1_b) + hidden_states = self.mlp_2(hidden_states, cur_params.mlp_2_w, b=cur_params.mlp_2_b) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, + hidden_states, + next_params.attn_norm_gamma, + beta=next_params.attn_norm_beta) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. 
+ residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info, + self._non_transformer.final_norm_w, self._non_transformer.final_norm_b) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, + None, + self._transformer[0].attn_norm_gamma, + beta=self._transformer[0].attn_norm_beta) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer_layer(layer_idx, residual, hidden_states, + wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/opt/policy.py b/deepspeed/inference/v2/model_implementations/opt/policy.py new file mode 100644 index 000000000000..0f5002cdaa54 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/opt/policy.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import argparse + +from typing import Any + +from ...checkpoint import CheckpointEngineBase +from ...config_v2 import RaggedInferenceEngineConfig +from ...model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy +from ...model_implementations.opt.container import OPTNonTransformerContainer, OPTTransformerContainer +from ...model_implementations.opt.model import OPTInferenceModel + + +class OPTPolicy(InferenceV2Policy): + + def __init__(self, checkpoint_engine: CheckpointEngineBase, model_config: argparse.Namespace) -> None: + super().__init__(checkpoint_engine, model_config) + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> OPTInferenceModel: + return OPTInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + map = ContainerMap() + + transformer_containers = [OPTTransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.decoder.layers'], transformer_containers) + + map.set_non_transformer_params(OPTNonTransformerContainer(self.model)) + + map.set_unmapped_params([]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/parameter_base.py b/deepspeed/inference/v2/model_implementations/parameter_base.py new file mode 100644 index 000000000000..a413c6c4027a --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/parameter_base.py @@ -0,0 +1,257 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import weakref +from abc import abstractmethod +from typing import Type + +import torch + +# Currently have dependency loops for the type hints. 
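The `_forward_unembed` implementations above (OPT here, and the identical Llama/Mistral paths) reconstruct vocab-parallel logits by gathering each rank's vocabulary slice and interleaving the slices back along the vocab dimension. A single-process sketch of the permute/reshape step, with the collective replaced by a plain stack:

```python
import torch

tp_size, n_tokens, vocab = 4, 3, 32000
vocab_shard = vocab // tp_size

# Pretend each rank produced logits for its slice of the vocabulary.
per_rank = [torch.randn(n_tokens, vocab_shard) for _ in range(tp_size)]

# all_gather_into_tensor would fill a [tp_size, n_tokens, vocab_shard] buffer.
comm_buffer = torch.stack(per_rank, dim=0)

# permute + reshape stitches the shards back into [n_tokens, vocab].
full_logits = comm_buffer.permute(1, 0, 2).reshape(n_tokens, vocab)
assert torch.equal(full_logits[:, :vocab_shard], per_rank[0])
```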
+InferenceModel = Type["InferenceModel"] +LayerContainer = Type["LayerContainer"] +ParametrizedList = Type["ParametrizedList"] + +MAPPING_KEY = "PARAM_MAPPING" +PLIST_HELPERS = "_ds_plist_strip_vals" + + +def make_param_getter(clsname, param): + """ + Normal getter implementation for a property. + """ + + def param_getter(self): + return getattr(self, f"__{clsname}__{param}") + + return param_getter + + +def make_param_setter(clsname, param): + """ + Setter implementation that will call complete component to potentially + finalize the parameter. + """ + + def param_setter(self, value): + setattr(self, f"__{clsname}__{param}", value) + self.complete_component() + + return param_setter + + +def make_readonly_setter(): + """ + Setter implementation that will raise an error if called. + """ + + def paramlist_setter(self, value): + raise ValueError("Cannot set a ParametrizedList directly.") + + return paramlist_setter + + +class ParameterMetaclass(type): + """ + MetaClass for the ParameterBase base class. This class will parse the `src_params` + attribute and create properties for each of the dependencies. A dependency can either + be represented as a string, which is interpreted as a named Tensor, or a `ParametrizedList` + subclass. + """ + + def __new__(cls, clsname, bases, attrs): + + annotations = attrs.get("__annotations__", {}) + dependencies = { + name: annotation + for name, annotation in annotations.items() if issubclass(annotation, (torch.Tensor, ParametrizedList)) + } + n_dependencies = len(dependencies) + + # Create properties for each of our dependencies + for d_name, d_type in dependencies.items(): + if issubclass(d_type, ParametrizedList): + assert hasattr( + d_type, "count_attr" + ), "ParametrizedList must have a count_attr attribute to access on the inference module." + attrs[d_name] = property(make_param_getter(clsname, d_name), make_readonly_setter()) + else: # torch.Tensor + attrs[d_name] = property(make_param_getter(clsname, d_name), make_param_setter(clsname, d_name)) + + new_cls = super().__new__(cls, clsname, bases, attrs) + new_cls.n_dependencies = n_dependencies + + return new_cls + + def __call__(cls, *args, **kwargs): + new_obj = super().__call__(*args, **kwargs) + new_obj.__init__(*args, **kwargs) + + setattr(new_obj, "dest_param", None) + + # Initialize our dependences to None/empty `ParametrizedList`s + for name, annotation in new_obj.__annotations__.items(): + if issubclass(annotation, ParametrizedList): + #TODO(jeff): update assert with this, model implementation attribute does not align or missing wrt the ParametrizedList attributes + assert hasattr( + new_obj.inference_model, annotation.count_attr + ), f"new_obj={new_obj.__class__.__name__}, name={name}, annotation.count_attr={annotation.count_attr}" + param_list = annotation(new_obj, getattr(new_obj.inference_model, annotation.count_attr)) + setattr(new_obj, f"__{new_obj.__class__.__name__}__{name}", param_list) + else: # torch.Tensor + setattr(new_obj, f"__{new_obj.__class__.__name__}__{name}", None) + + return new_obj + + +class ParameterBase(metaclass=ParameterMetaclass): + """ + A ParameterBase allows us to consolidate tracking the dependencies of loading a parameter from + a checkpoint into a single object. This class should not be used directly, but rather subclassed + and the `src_params` attribute set to a list of strings and/or `ParametrizedList`s. + """ + + # inference_model: InferenceModel + """ + Inference model that will provide context on how to shard and transform the parameter. 
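To make the metaclass mechanics above concrete, here is a hedged example of how a parameter with two tensor dependencies would typically be declared. The class name and its `finalize` body are illustrative only; the concrete parameters in this PR live in `common_parameters` and route through model-specific shard/transform methods.

```python
import torch

from deepspeed.inference.v2.model_implementations.parameter_base import ParameterBase

class FusedGateUpParameter(ParameterBase):
    """Illustrative only: a parameter whose finalized form fuses two checkpoint tensors."""

    gate: torch.Tensor  # becomes a property; assigning it calls complete_component()
    up: torch.Tensor

    def finalize(self) -> torch.Tensor:
        # Invoked automatically once both dependencies have been set. A real
        # parameter would call a model-specific transform/shard here instead.
        return torch.cat((self.gate, self.up), dim=0)
```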
+ """ + + #completed_components: int + """ + How many of the layer dependencies have been met. This is used to determine when the parameter + is ready to be finalized. A ParametrizedList counts as a single dependency for the purposes + of this counter. + """ + + def __init__(self, model: InferenceModel, parent_container: LayerContainer) -> None: + """ + Direct constructor. This should not be called from client code. + + Args: + model (InferenceModel): Inference model that will be used to shard and transform the + parameter in `finalize`. + parent_container (LayerContainer): The parent container that this parameter is a member + of. We will build a weakref to this container to call the finalization callback. + """ + self.inference_model = model + self.completed_components = 0 + self.parent_container = weakref.ref(parent_container) + + @abstractmethod + def finalize(self) -> torch.Tensor: + """ + Finalize the parameter after all of its source parameters have been set. This method + will be automatically called when all inputs have been set. It should return the Tensor + with all transformations performed on it. + """ + pass + + def complete_component(self) -> None: + """ + Mark a component as completed. This should be called by the relevant setter of a direct + property or a ParametrizedList. This method will automatically call `finalize` when all + dependencies have been met and then call the finalization callback on the parent container. + + Once the finalization callback has been called, the parameter will be replaced with the + `dst_param` attribute on the parent container, and this instance will be destroyed. + """ + self.completed_components += 1 + + if self.completed_components != self.n_dependencies: + return + + finalized_param = self.finalize() + self.parent_container().finalization_callback(self, finalized_param) + + +class ParametrizedList: + """ + A ParametrizedList is a list of parameters that are dependencies + of a `ParameterBase` but may vary in length depending on the model + configuration (rather than architecture). For example, a MoE layer + may have different number of experts depending on the size of the model. + + This class is used to manage these lists and provide integer indexing + of a single component rather than accessing names directly. For example, + it tends to be more natural to access the 8th expert with `experts[8]` + rather than a name like `expert_8`, especially as an attribute. + + To inherit from this class, set static variables `name` and `count_attr`. + + ```python + class MyParametrizedList(ParametrizedList): + count_attr: str = "my_list_count" + ``` + + In the above example, `my_list_count` should be an accessible attribute + of the inference model (i.e. via `self.inference_model.my_list_count`). + + NOTE: There are some APIs in which this type cannot be used as if it is + just a list of Tensors. For example, `torch.cat(param_list)` will not work. + However, you can make it compatible with a tuple wrapper: + `torch.cat(tuple(param_list))` + """ + + n_params: int + """ + Number of params this list contains. + """ + + param: ParameterBase + """ + WeakRef to the owning parameter. + """ + + def __init__(self, param: ParameterBase, n_params: int) -> None: + """ + Constructor. Should not be called from client code. + + Args: + param (ParameterBase): The owning parameter. + n_params (int): The number of parameters this list contains. 
This should be + """ + self.n_params = n_params + self.set_params = 0 + self.param = weakref.ref(param) + self._params = [None] * n_params + + def __getitem__(self, index): + return self._params[index] + + def __setitem__(self, index, value): + if self._params[index] is not None: + raise ValueError("Cannot set a parameter twice.") + + self._params[index] = value + self.set_params += 1 + + if self.set_params != self.n_params: + return + + self.param().complete_component() + + def __iter__(self): + return iter(self._params) + + +def ParamList(attr: str): + """ + Helper to create a subclass of ParametrizedList with the desired `count_attr`. + + In this manner, we can annotate the type of a Parameter dependency with the + following: + + ```python + class CustomParameter(ParameterBase): + dependency_list: ParamList("dependencies_count_name") + ``` + + where "dependencies_count_name" is the name of the attribute on the inference model. + """ + + class ParametrizedListInstance(ParametrizedList): + count_attr: str = attr + + return ParametrizedListInstance diff --git a/deepspeed/inference/v2/model_implementations/sharding/__init__.py b/deepspeed/inference/v2/model_implementations/sharding/__init__.py new file mode 100644 index 000000000000..63421bc1c622 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .attn import * +from .attn_out import * +from .embedding import * +from .mlp import * +from .qkv import * +from .types import * +from .unembed import * diff --git a/deepspeed/inference/v2/model_implementations/sharding/attn.py b/deepspeed/inference/v2/model_implementations/sharding/attn.py new file mode 100644 index 000000000000..de8d6f6ac4c5 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/attn.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional, Tuple + + +def get_local_heads(shard_rank: int, + num_shards: int, + n_heads_q: int, + n_heads_kv: Optional[int] = None) -> Tuple[int, int]: + """ + Helper to determine the number of local heads of a given shard. + + Args: + shard_rank (int): The rank of the shard. + num_shards (int): The total number of shards that attention is distributed over. + n_heads_q (int): The number of query heads. + n_heads_kv (int): The number of key/value heads. If not passed, it is assumed that + the number of query and key/value heads are the same. 
+ """ + if n_heads_q < num_shards: + raise ValueError("There must be at least as many attention heads as there are shards.") + + if n_heads_kv is None or n_heads_kv == n_heads_q: + # MHA attention + base_heads = n_heads_q // num_shards + extra_heads = n_heads_q % num_shards + + if shard_rank < extra_heads: + return (base_heads + 1), (base_heads + 1) + else: + return base_heads, base_heads + else: + # GQA attention + if n_heads_q % n_heads_kv != 0: + raise ValueError("Must be an even ratio between query and key/value heads.") + + if n_heads_kv < num_shards and num_shards % n_heads_kv != 0: + raise ValueError( + "If splitting a group across multiple shards, we must be able to distribute the groups evenly.") + + if n_heads_kv >= num_shards and n_heads_kv % num_shards != 0: + raise ValueError("If parallelizing groups, must be able to evenly distribute them.") + + q_ratio = n_heads_q // n_heads_kv + + if n_heads_kv >= num_shards: + local_kv_heads = n_heads_kv // num_shards + local_q_heads = local_kv_heads * q_ratio + return local_q_heads, local_kv_heads + else: + group_sharding_size = num_shards // n_heads_kv + group_rank_idx = shard_rank % group_sharding_size + + base_heads = q_ratio // group_sharding_size + extra_heads = q_ratio % group_sharding_size + + if group_rank_idx < extra_heads: + return (base_heads + 1), 1 + else: + return base_heads, 1 diff --git a/deepspeed/inference/v2/model_implementations/sharding/attn_out.py b/deepspeed/inference/v2/model_implementations/sharding/attn_out.py new file mode 100644 index 000000000000..ce7c105531ea --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/attn_out.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from .types import ShardingType +from .utils import shard_param, get_shard_endpoints + + +def shard_attn_out_param(param: torch.Tensor, + shard_rank: int, + num_shards: int, + head_size: int, + n_heads_q: Optional[int] = None, + n_heads_kv: Optional[int] = None) -> Optional[torch.Tensor]: + """ + Utility method for sharding an attention output parameter. + """ + if len(param.shape) == 1: + # We will do the bias addition on the 0th rank only rather than scale the parameter and + # implicitly reconstruct this in the distributed reduce. + return param if shard_rank == 0 else None + + assert n_heads_kv is None or (n_heads_q is not None + and n_heads_kv is not None), "n_heads_kv should not be passed without n_heads_q" + + mha_sharding = n_heads_kv is None or n_heads_q == n_heads_kv + + if mha_sharding: + return shard_param(param, ShardingType.INNER_DIMENSION, shard_rank, num_shards, granularity=head_size) + else: + assert param.shape[0] == head_size * n_heads_q, "GQA param shape is not correct" + + # 32 KV heads, 16 shards for example + even_kv_sharding = n_heads_kv % num_shards == 0 + + # 8 KV heads, 16 shards for example + even_kv_distribution = num_shards % n_heads_kv == 0 + + assert even_kv_sharding or even_kv_distribution, "No partitioning algorithm for this yet." + + if even_kv_sharding: + # Same as original sharding scenario + return shard_param(param, ShardingType.INNER_DIMENSION, shard_rank, num_shards, granularity=head_size) + else: + # We will first do a sharding on the KV and Q to map to the one KV shard per group of Q. 
+ q_sharding_degree = num_shards // n_heads_kv + + kv_head = shard_rank // q_sharding_degree + + q_sharding_rank = shard_rank % q_sharding_degree + q_factor = n_heads_q // n_heads_kv + + q_chunk = param[..., q_factor * kv_head * head_size:q_factor * (kv_head + 1) * head_size] + + return shard_param(q_chunk, + ShardingType.INNER_DIMENSION, + q_sharding_rank, + q_sharding_degree, + granularity=head_size) + + +def attn_out_in_features(out_features: int, + shard_rank: int, + num_shards: int, + head_size: int, + n_heads_q: Optional[int] = None, + n_heads_kv: Optional[int] = None) -> int: + """ + Helper to calculate the expected output projection dimension of a QKV projection matrix. + + Args: + in_features (int): The model dimension. + shard_rank (int): Which rank to return the corresponding size for. + num_shards (int): The total number of shards the parameter is distributed across. + head_size (int): The size of each attention head. + n_heads_q (int): The number of query heads on the model. This only needs to be passed if the number + of query and key/value heads are different. If passed without n_heads_kv, default + MHA partitioning will be used. + n_heads_kv (int): The number of key and value heads on the model. This only needs to be passed + if the number of query and key/value heads are different. This argument cannot be passed without + also passing n_heads_q (we want to explicitly opt into GQA sharding). + """ + assert n_heads_kv is None or (n_heads_q is not None + and n_heads_kv is not None), "n_heads_kv should not be passed without n_heads_q" + + mha_sharding = n_heads_kv is None or n_heads_q == n_heads_kv + + if mha_sharding: + endpoints = get_shard_endpoints(out_features, shard_rank, num_shards, granularity=head_size) + return endpoints[1] - endpoints[0] + else: + if n_heads_kv >= num_shards: + assert n_heads_kv % num_shards == 0, "No partitioning algorithm for this yet." + n_local_groups = n_heads_kv // num_shards + group_size = n_heads_q // n_heads_kv + + return n_local_groups * head_size * group_size + else: + assert num_shards % n_heads_kv == 0, "No partitioning algorithm for this yet." + q_split_degree = num_shards // n_heads_kv + q_split_rank = shard_rank % q_split_degree + split_granularity = (n_heads_q // n_heads_kv) * head_size + + q_endpoints = get_shard_endpoints(split_granularity, q_split_rank, q_split_degree, granularity=head_size) + + return q_endpoints[1] - q_endpoints[0] diff --git a/deepspeed/inference/v2/model_implementations/sharding/embedding.py b/deepspeed/inference/v2/model_implementations/sharding/embedding.py new file mode 100644 index 000000000000..00d335768ae6 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/embedding.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from .types import ShardingType +from .utils import shard_param, get_shard_endpoints + + +def shard_embedding_param(param: torch.Tensor, shard_rank: int, num_shards: int) -> torch.Tensor: + """ + Utility method for sharding an embedding parameter. + + Args: + param (torch.Tensor): The parameter to shard. Should be of shape [vocab_size, model_dim] + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. 
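+
+    For example, a hypothetical [32000, 4096] embedding split across 4 shards yields a
+    [32000, 1024] slice on every rank (the model dimension is sharded, not the vocabulary).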
+ """ + return shard_param(param, ShardingType.INNER_DIMENSION, shard_rank, num_shards) + + +def sharded_embedding_dim(embedding_size: int, shard_rank: int, num_shards: int) -> int: + """ + Utility method for getting the size of the embedding dimension of a sharded embedding. + + Args: + embedding_size (int): The size of the embedding. + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. + """ + start_idx, end_idx = get_shard_endpoints(embedding_size, shard_rank, num_shards) + return end_idx - start_idx diff --git a/deepspeed/inference/v2/model_implementations/sharding/mlp.py b/deepspeed/inference/v2/model_implementations/sharding/mlp.py new file mode 100644 index 000000000000..8abd0ff8622d --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/mlp.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from .types import ShardingType, DEFAULT_SHARD_GRANULARITY +from .utils import shard_param, get_shard_endpoints + + +def shard_mlp_1_param(param: torch.Tensor, + shard_rank: int, + num_shards: int, + gated: bool = False, + is_moe: bool = False) -> torch.Tensor: + """ + Utility method for sharding an MLP 1 parameter. Both biases and weights are supported, as well + as for fused weights for MoE. + + Args: + param (torch.Tensor): The parameter to shard. + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. + gated (bool): Whether or not the parameter is from a gated MLP. + """ + bias_dims = 2 if is_moe else 1 + + if gated: + return shard_param(param, + ShardingType.OUTER_DIMENSION, + shard_rank, + num_shards, + granularity=DEFAULT_SHARD_GRANULARITY * 2, + bias_dims=bias_dims) + else: + return shard_param(param, ShardingType.OUTER_DIMENSION, shard_rank, num_shards, bias_dims=bias_dims) + + +def shard_mlp_2_param(param: torch.Tensor, + shard_rank: int, + num_shards: int, + is_moe: bool = False) -> Optional[torch.Tensor]: + """ + Utility method for sharding an MLP 2 parameter. + + Args: + param (torch.Tensor): The parameter to shard. + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. + is_moe (bool): Whether or not the parameter is from a MoE model. + """ + bias_dim_size = 2 if is_moe else 1 + + if len(param.shape) == bias_dim_size: + # We will do the bias addition on the 0th rank only rather than scale the parameter and + # implicitly reconstruct this in the distributed reduce. + return param if shard_rank == 0 else None + + return shard_param(param, ShardingType.INNER_DIMENSION, shard_rank, num_shards) + + +def sharded_intermediate_dim(intermediate_size: int, num_shards: int, shard_rank: int) -> int: + """ + Utility method for getting the size of the intermediate dimension of a sharded MLP. + + Args: + intermediate_size (int): The size of the intermediate dimension. + num_shards (int): The total number of shards the parameter is distributed across. + shard_rank (int): Which shard of the partitioned tensor to return. 
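+
+    For example, a hypothetical intermediate dimension of 11008 split across 4 shards
+    gives 2752 features on every rank.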
+ """ + endpoints = get_shard_endpoints(intermediate_size, shard_rank, num_shards) + return endpoints[1] - endpoints[0] diff --git a/deepspeed/inference/v2/model_implementations/sharding/qkv.py b/deepspeed/inference/v2/model_implementations/sharding/qkv.py new file mode 100644 index 000000000000..2b6d7f40836e --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/qkv.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from .types import ShardingType +from .utils import shard_param, get_shard_endpoints + + +def shard_qkv_param(param: torch.Tensor, + shard_rank: int, + num_shards: int, + head_size: int, + n_heads_q: Optional[int] = None, + n_heads_kv: Optional[int] = None) -> Optional[torch.Tensor]: + """ + Utility method for sharding a QKV parameter. Both biases and weights are supported. It is assumed + that the layout of the parameter is such that all Q heads, all K heads, and all V heads + are contiguous with respect to each other. + + Args: + param (torch.Tensor): The parameter to shard. + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. + head_size (int): The size of each head. + n_heads_q (int): The number of query heads. This only needs to be passed if the number + of query and key/value heads are different. If passed without n_heads_kv, default + MHA partitioning will be used. + n_heads_kv (int): The number of key/value heads. This only needs to be passed if the number + of query and key/value heads are different. This argument should not be passed without + n_heads_q (we want to explicitly opt into GQA sharding). + """ + if n_heads_kv is not None and n_heads_q is None: + raise ValueError("n_heads_kv should not be passed without n_heads_q") + + if n_heads_q is None: + # Guaranteed to be in MHA + if param.shape[0] // 3 % head_size != 0: + raise ValueError("MHA param shape is not correct") + n_heads_q = param.shape[0] // head_size // 3 + mha_sharding = True + else: + mha_sharding = n_heads_q == n_heads_kv + + if n_heads_q < num_shards: + raise ValueError("There must be at least as many query heads as there are shards.") + + if mha_sharding: + return shard_param(param, + ShardingType.OUTER_DIMENSION, + shard_rank, + num_shards, + num_concatenated_matrices=3, + granularity=head_size) + else: + if n_heads_q % n_heads_kv != 0: + raise ValueError("Must be an even ratio between query and key/value heads.") + + if param.shape[0] != head_size * (n_heads_q + 2 * n_heads_kv): + raise ValueError("GQA param shape is not correct") + + # 32 KV heads, 16 shards for example + if n_heads_kv >= num_shards and n_heads_kv % num_shards != 0: + raise ValueError("Currently do not support uneven partitioning of KV heads for GQA.") + + # 8 KV heads, 16 shards for example + if n_heads_kv < num_shards and num_shards % n_heads_kv != 0: + raise ValueError("Currently do not support distributing KV heads across different numbers of shards.") + else: + even_kv_sharding = n_heads_kv >= num_shards + + if param is None: + return None + + q_param = param[:head_size * n_heads_q] + kv_param = param[head_size * n_heads_q:] + + if even_kv_sharding: + # This is equivalent to the original sharding algorithm since n_heads_q = C * n_heads_kv. + # If n_heads_kv % num_shards == 0, then n_heads_q % num_shards == 0. 
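+            # Hypothetical example: 64 query heads, 32 KV heads, 16 shards -> every shard
+            # keeps 4 query heads, 2 key heads and 2 value heads, concatenated as [Q | K | V].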
+ q_param = shard_param(q_param, ShardingType.OUTER_DIMENSION, shard_rank, num_shards, granularity=head_size) + kv_param = shard_param(kv_param, + ShardingType.OUTER_DIMENSION, + shard_rank, + num_shards, + num_concatenated_matrices=2, + granularity=head_size) + return torch.cat([q_param, kv_param], dim=0) + else: + # We will first do a sharding on the KV and Q to map to the one KV shard per group of Q. + q_sharding_degree = num_shards // n_heads_kv + + kv_head = shard_rank // q_sharding_degree + k_param = kv_param[kv_head * head_size:(kv_head + 1) * head_size] + v_param = kv_param[(n_heads_kv + kv_head) * head_size:(n_heads_kv + kv_head + 1) * head_size] + + q_sharding_rank = shard_rank % q_sharding_degree + q_factor = n_heads_q // n_heads_kv + + q_chunk = q_param[q_factor * kv_head * head_size:q_factor * (kv_head + 1) * head_size] + + q_param = shard_param(q_chunk, + ShardingType.OUTER_DIMENSION, + q_sharding_rank, + q_sharding_degree, + granularity=head_size) + + return torch.cat([q_param, k_param, v_param], dim=0) + + +def qkv_out_features(in_features: int, + shard_rank: int, + num_shards: int, + head_size: int, + n_heads_q: Optional[int] = None, + n_heads_kv: Optional[int] = None) -> int: + """ + Helper to calculate the expected output projection dimension of a QKV projection matrix. + + Args: + in_features (int): The model dimension. + shard_rank (int): Which rank to return the corresponding size for. + num_shards (int): The total number of shards the parameter is distributed across. + head_size (int): The size of each head. + n_heads_q (int): The number of query heads. This only needs to be passed if the number + of query and key/value heads are different. If passed without n_heads_kv, default + MHA partitioning will be used. + n_heads_kv (int): The number of key/value heads. This only needs to be passed if the number + of query and key/value heads are different. This argument cannot be passed without also + passing n_heads_q (we want to explicitly opt into GQA sharding). 
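+
+    For example, with hypothetical GQA sizes (head_size=128, n_heads_q=32, n_heads_kv=8)
+    and 4 shards, every rank projects to 8 query heads plus 2 key and 2 value heads:
+
+        >>> qkv_out_features(4096, 0, 4, 128, n_heads_q=32, n_heads_kv=8)
+        1536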
+ """ + if n_heads_kv is not None and n_heads_q is None: + raise ValueError("n_heads_kv should not be passed without n_heads_q") + + mha_sharding = n_heads_kv is None or n_heads_q == n_heads_kv + + if n_heads_q is not None and in_features != head_size * n_heads_q: + raise ValueError("in_features is not consistent with n_heads_q and head_size") + + if mha_sharding: + endpoints = get_shard_endpoints(in_features, shard_rank, num_shards, granularity=head_size) + return (endpoints[1] - endpoints[0]) * 3 + else: + if n_heads_kv >= num_shards: + if n_heads_kv % num_shards != 0: + raise ValueError("The KV heads must be evenly distributed across the shards.") + + n_local_groups = n_heads_kv // num_shards + group_size = n_heads_q // n_heads_kv + + return n_local_groups * head_size * (2 + group_size) + else: + if num_shards % n_heads_kv != 0: + raise ValueError("A shared KV head must always partition across the same number of shards.") + + q_split_degree = num_shards // n_heads_kv + q_split_rank = shard_rank % q_split_degree + split_granularity = (n_heads_q // n_heads_kv) * head_size + + q_endpoints = get_shard_endpoints(split_granularity, q_split_rank, q_split_degree, granularity=head_size) + + return (q_endpoints[1] - q_endpoints[0]) + 2 * head_size diff --git a/deepspeed/inference/v2/model_implementations/sharding/types.py b/deepspeed/inference/v2/model_implementations/sharding/types.py new file mode 100644 index 000000000000..01dce0db523a --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/types.py @@ -0,0 +1,18 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from enum import Enum + +DEFAULT_SHARD_GRANULARITY = 32 + + +class ShardingType(Enum): + # Inner dimension sharding corresponds to splitting the Tensor along the K-dimension + # of a matrix multiplication. This would be used for attention_output or MLP2. + INNER_DIMENSION = 1 + + # Outer dimension sharding corresponds to splitting the Tensor along the N-dimension + # of a matrix multiplication. This would be used for the QKV and MLP1 projections. + OUTER_DIMENSION = 0 diff --git a/deepspeed/inference/v2/model_implementations/sharding/unembed.py b/deepspeed/inference/v2/model_implementations/sharding/unembed.py new file mode 100644 index 000000000000..6cc771969ad9 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/unembed.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from .types import ShardingType +from .utils import shard_param, get_shard_endpoints + + +def shard_unembed_param(param: torch.Tensor, shard_rank: int, num_shards: int) -> torch.Tensor: + """ + Utility method for sharding an unembed parameter. We shard unembeddings on the vocab dimension + with the expectation of an all-gather to produce the full results. + + TODO(cmikeh2): Really ideal would be if MII could have access to the comm and we would do + an A2A and sharded sampling. + + Args: + param (torch.Tensor): The parameter to shard. Should be of shape [vocab_size, model_dim] + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. 
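+
+    Note that this sharding uses a granularity of 1, so the vocabulary size does not need
+    to be aligned to the default 32-channel sharding granularity.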
+ + Returns: + torch.Tensor: The sharded parameter of shape [sharded_vocab_size, model_dim] + """ + return shard_param(param, ShardingType.OUTER_DIMENSION, shard_rank, num_shards, granularity=1) + + +def sharded_unembed_dim(vocab_size: int, shard_rank: int, num_shards: int) -> int: + """ + Utility method for determining the sharded vocab size of a sharded unembed parameter. + + Args: + vocab_size (int): The size of the vocabulary. + shard_rank (int): Which shard of the partitioned tensor to return. + num_shards (int): The total number of shards the parameter is distributed across. + """ + start_idx, end_idx = get_shard_endpoints(vocab_size, shard_rank, num_shards, granularity=1) + return end_idx - start_idx diff --git a/deepspeed/inference/v2/model_implementations/sharding/utils.py b/deepspeed/inference/v2/model_implementations/sharding/utils.py new file mode 100644 index 000000000000..fd0eb51873f8 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/sharding/utils.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional, Tuple + +import torch + +from .types import ShardingType, DEFAULT_SHARD_GRANULARITY + + +def get_shard_endpoints(dim_size: int, + shard_rank: int, + num_shards: int, + granularity: int = DEFAULT_SHARD_GRANULARITY) -> Tuple[int, int]: + """ + Given a dimension to shard with size dim_size, return the start and end indices of the slice + that belong to the given rank. + + The typical use of this is as an internal helper function, so see if there is a higher level + API that better suits the application. + + Args: + dim_size (int): The size of the dimension to shard. + shard_rank (int): The rank of the shard to return. + num_shards (int): Total number of shards the dimension will be distributed across. + granularity (int): The minimum alignment of the shard endpoints. This is used to support + non-even head counts as well as align dimensions to cleaner GEMM boundaries. + """ + assert dim_size % granularity == 0, "Dimension size must be divisible by granularity" + + total_chunks = dim_size // granularity + base_chunks_per_rank = total_chunks // num_shards + remainder_chunks = total_chunks % num_shards + + start_chunk_id = shard_rank * base_chunks_per_rank + min(shard_rank, remainder_chunks) + end_chunk_id = start_chunk_id + base_chunks_per_rank + (1 if shard_rank < remainder_chunks else 0) + + return start_chunk_id * granularity, end_chunk_id * granularity + + +def shard_param(param: Optional[torch.Tensor], + shard_mode: ShardingType, + shard_rank: int, + num_shards: int, + num_concatenated_matrices: int = 1, + granularity: int = 32, + bias_dims: int = 1) -> torch.Tensor: + """ + Utility for sharding a parameter. This will return the slice of the parameter that should + exist on the given shard_rank given the sharding configuration. The workflow here is + to find the minimum bounded Tensor to shard, get the slicing endpoints, and then concatenate + as needed. + + The typical use of this is as an internal helper function, so see if there is a higher level + API that better suits the application. + + Args: + param (torch.Tensor): The parameter to shard. + shard_mode (ShardingType): The type of sharding to apply. See ShardingType for more context. + shard_rank (int): The rank of the shard to return. + num_shards (int): Total number of shards the parameter will be distrbuted across. 
+ num_concatenated_matrices (int): The number of matrices that have been concatenated together in the original + parameter. An example of this is a fused QKV projection matrix, where the `num_concatenated_matrices` + argument would be 3. + granularity (int): The minimum alignment of the shard endpoints. For attention projection matrices, this + should be set to the head size to support non-even sharding. + bias_dims (int): The number of dimensions that are considered bias dimensions. This is used to support + sharding of MoE and non-MoE biases on the same codepath. + """ + assert shard_rank < num_shards, "Shard rank must be less than num_shards" + + # Easier to hide this inside of the sharding logic than to add checks in every model + # implementation. + if param is None: + return None + + if num_shards == 1: + # Trivial case of no sharding. + return param + + if shard_mode == ShardingType.OUTER_DIMENSION: + + def get_matrices(dim_idx: int) -> torch.Tensor: + dim_size = param.size(dim_idx) // num_concatenated_matrices + start_channel_id, end_channel_id = get_shard_endpoints(dim_size, shard_rank, num_shards, granularity) + return torch.chunk(param, num_concatenated_matrices, dim=dim_idx), start_channel_id, end_channel_id + + if param.ndim == bias_dims: + # Special case for bias parameters. + matrices, start_channel_id, end_channel_id = get_matrices(dim_idx=-1) + return torch.cat([mat[..., start_channel_id:end_channel_id] for mat in matrices], dim=-1) + else: + # General case for weight parameters. This assumes MoE parameters are stored in the format of + # [num_experts, out_features, in_features] + matrices, start_channel_id, end_channel_id = get_matrices(dim_idx=-2) + return torch.cat([mat[..., start_channel_id:end_channel_id, :] for mat in matrices], dim=-2) + + elif shard_mode == ShardingType.INNER_DIMENSION: + dim_size = param.size(-1) // num_concatenated_matrices + start_channel_id, end_channel_id = get_shard_endpoints(dim_size, shard_rank, num_shards, granularity) + matrices = torch.chunk(param, num_concatenated_matrices, dim=-1) + return torch.cat([mat[..., start_channel_id:end_channel_id] for mat in matrices], dim=-1) diff --git a/deepspeed/inference/v2/modules/__init__.py b/deepspeed/inference/v2/modules/__init__.py new file mode 100644 index 000000000000..917c1599de2e --- /dev/null +++ b/deepspeed/inference/v2/modules/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from . import implementations +from . import interfaces +from .module_registry import ConfigBundle diff --git a/deepspeed/inference/v2/modules/configs/__init__.py b/deepspeed/inference/v2/modules/configs/__init__.py new file mode 100644 index 000000000000..19b9fb99ddea --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .attention_configs import (DSSelfAttentionConfig, PositionalEmbeddingType, MaskingType) +from .embedding_config import DSEmbeddingsConfig +from .linear_config import DSLinearConfig +from .moe_config import DSMoEConfig +from .norm_config import DSNormConfig, NormTypeEnum +from .unembed_config import DSUnembedConfig diff --git a/deepspeed/inference/v2/modules/configs/attention_configs.py b/deepspeed/inference/v2/modules/configs/attention_configs.py new file mode 100644 index 000000000000..bcdc3d2613d5 --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/attention_configs.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from enum import Enum +from typing import Dict + +from ...inference_utils import DtypeEnum +from ...modules.ds_module import DSModuleConfig + + +class PositionalEmbeddingType(Enum): + + # No positional embeddings + none = "none" + + # Rotary positional embeddings - every half + rotate_half = "rotate_half" + + # Rotary positional embeddings - every other + rotate_every_other = "rotate_every_other" + + # Alibi + alibi = "alibi" + + +class MaskingType(Enum): + + # No masking + none = "none" + + # Causal masking + causal = "causal" + + # Local masking + local = "local" + + # Symmetric masking (this is a 1D tensor mask) + symmetric = "symmetric" + + # Arbitrary masking (this would correspond to a 2D tensor mask) + asymmetric = "asymmetric" + + +class DSSelfAttentionConfig(DSModuleConfig): + """ + Config class for attention. + """ + + # Number of query attention heads on this shard + n_heads_q: int + + # Number of KV attention heads on this shard + n_heads_kv: int + + # Size of each attention head + head_size: int + + # Max number of sequences that may compose a ragged batch + max_sequences: int + + # Scale factor for attention scores + scale_factor: float = 1.0 + + # Input data type + input_dtype: DtypeEnum = DtypeEnum.fp16 + + # Output data type + output_dtype: DtypeEnum = DtypeEnum.fp16 + + # Masking type + masking_type: MaskingType = MaskingType.causal + + # Masking args + masking_args: Dict = {} + + # Positional embedding type + positional_embedding_type: PositionalEmbeddingType = PositionalEmbeddingType.none + + # Positional embedding args + positional_embedding_args: Dict = {} diff --git a/deepspeed/inference/v2/modules/configs/embedding_config.py b/deepspeed/inference/v2/modules/configs/embedding_config.py new file mode 100644 index 000000000000..2486c5986e95 --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/embedding_config.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +from ...inference_utils import DtypeEnum, NormTypeEnum +from ...modules.ds_module import DSModuleConfig +""" +Trying to define the space we need to support here right now: + +Types of embeddings I've found so far: + 1. Token embedding + 2. Position embedding + 3. Token type embedding + 4. LN + +GPTNeo: 1, 2, 3 (shared with 1) +GPTNeoX: 1 +GPTJ: 1, 3 +LLaMA: 1 +BERT: 1, 2, 3, 4 +GPT2: 1, 2, 3 (shared with 1) + +Sidebar for OPT: +OPT: 1, 2 +1 may not actually project to the actual hidden dimension according to the raw +code, but for the model configs we care about it does. +2 has a weird offset associated with it that the others do not. +""" + + +class DSEmbeddingsConfig(DSModuleConfig): + """ + Config class for DSEmbeddings. 
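+
+    A minimal configuration, with hypothetical values for a model that only uses token
+    embeddings, might look like::
+
+        DSEmbeddingsConfig(max_tokens=2048, embedding_dim=4096)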
+ """ + + residual_dtype: DtypeEnum = DtypeEnum.fp16 + """ + Data type the module should use for its output. + """ + + embedding_dim: int + """ + Dimensionality of the embedding projections. + """ + + positional_embedding: bool = False + """ + Whether the module should expect a positional embedding matrix. The shape of this + matrix should be of shape [max_seq_len + positional_offset, embedding_dim] + """ + + positional_offset: int = 0 + """ + Whether the linearized token IDs should be offset by a certain amount. For an example + of this, see the OPT model implementation. + """ + + use_token_type: bool = False + """ + Whether the module should expect a token type embedding matrix. + """ + + output_normalization: Optional[NormTypeEnum] = None + """ + If a the output of the embedding module should be normalized, specify here. See + ``inference.inference_utils.NormTypeEnum`` for supported values. + """ diff --git a/deepspeed/inference/v2/modules/configs/linear_config.py b/deepspeed/inference/v2/modules/configs/linear_config.py new file mode 100644 index 000000000000..40fe0773aeee --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/linear_config.py @@ -0,0 +1,43 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ...inference_utils import ActivationType, DtypeEnum +from ...modules.ds_module import DSModuleConfig + + +class DSLinearConfig(DSModuleConfig): + """ + Config class for DSLinearBase. + """ + + in_channels: int + """ + Number of input channels + """ + + out_channels: int + """ + Number of output channels. NOTE: If this linear layer is using a gated activation function, + the value for ``out_channels`` passed here should refer to the number of channels after + gating (i.e., the expected weight shape before transformations will be ``[out_channels * 2, in_channels]``). + """ + + activation: ActivationType = ActivationType.IDENTITY + """ + The activation function for this layer. See :class:`deepspeed.inference.inference_utils.ActivationType` for + supported activation functions. + """ + + input_dtype: DtypeEnum = DtypeEnum.fp16 + """ + The data type of the input tensor. See :class:`deepspeed.inference.inference_utils.DtypeEnum` for supported + data types. + """ + + output_dtype: DtypeEnum = DtypeEnum.fp16 + """ + The data type of the output tensor. See :class:`deepspeed.inference.inference_utils.DtypeEnum` for supported + data types. + """ diff --git a/deepspeed/inference/v2/modules/configs/moe_config.py b/deepspeed/inference/v2/modules/configs/moe_config.py new file mode 100644 index 000000000000..1a88d54af19f --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/moe_config.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ...inference_utils import ActivationType, DtypeEnum +from ...modules.ds_module import DSModuleConfig + + +class DSMoEConfig(DSModuleConfig): + """ + Config class for DSMoEBase + """ + + model_dim: int + """ + Size of input activation. + """ + + intermediate_features: int + """ + Size of intermediate activation. Specifically, this is the number of input features + in the second linear layer. Depending on the activation function, the output of the first + linear layer may have increased dimensionality. + """ + + n_experts: int + """ + Number of experts. + """ + + top_k: int = 1 + """ + top-k gating function (like top-1 or top-2) + """ + + input_dtype: DtypeEnum = DtypeEnum.fp16 + """ + Data type for the input activations. 
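+    Note that the CUTLASS multi-GEMM MoE implementation currently only supports fp16 and
+    bf16 activations here.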
+ """ + + output_dtype: DtypeEnum = DtypeEnum.fp16 + """ + Data type for the output activations. + """ + + activation: ActivationType = ActivationType.IDENTITY + """ + Activation function of the first MLP1 + """ diff --git a/deepspeed/inference/v2/modules/configs/norm_config.py b/deepspeed/inference/v2/modules/configs/norm_config.py new file mode 100644 index 000000000000..358982253756 --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/norm_config.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ...inference_utils import DtypeEnum, NormTypeEnum +from ...modules.ds_module import DSModuleConfig + + +class DSNormConfig(DSModuleConfig): + """ + Config class for both DSPreLN and DSPostLN. + """ + + # Type of normalization + type: NormTypeEnum + + # Number of channels in the model embedding + channels: int + + # Data type of the residual input/outputs (we assume the residual must + # be the same data type for the entire model). + residual_dtype: DtypeEnum = DtypeEnum.fp16 + + # Data type of the hidden states input + input_dtype: DtypeEnum = DtypeEnum.fp16 + + # Data type of the hidden states output + output_dtype: DtypeEnum = DtypeEnum.fp16 + + # Epsilon value for numerical stability + eps: float = 1e-5 diff --git a/deepspeed/inference/v2/modules/configs/unembed_config.py b/deepspeed/inference/v2/modules/configs/unembed_config.py new file mode 100644 index 000000000000..ea4cc3cc99c1 --- /dev/null +++ b/deepspeed/inference/v2/modules/configs/unembed_config.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ...inference_utils import DtypeEnum, NormTypeEnum +from ...modules.ds_module import DSModuleConfig +from typing import Optional + + +class DSUnembedConfig(DSModuleConfig): + """ + Config class for DSUnembed + """ + + dtype: DtypeEnum = DtypeEnum.fp16 + """ + Expected data type. + """ + + norm_type: Optional[NormTypeEnum] = None + """ + Whether the input to the unembed is normalized prior to the unembedding projection. + """ + + model_dim: int + """ + Model embedding size. + """ + + max_sequences: int + """ + Max sequences composing the ragged batch. + """ + + vocab_size: int + """ + Local vocab size (the full vocab size may have been sharded across model parallel ranks) + """ diff --git a/deepspeed/inference/v2/modules/ds_module.py b/deepspeed/inference/v2/modules/ds_module.py new file mode 100644 index 000000000000..2a6d294f3266 --- /dev/null +++ b/deepspeed/inference/v2/modules/ds_module.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import ABC, abstractstaticmethod +from typing import Any, Dict, Type + +import torch + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel + + +class DSModuleConfig(DeepSpeedConfigModel): + + max_tokens: int + + +class DSModuleBase(torch.nn.Module, ABC): + """ + Base class for all DeepSpeed Inference modules. This class establishes + the basic attributes of a DSModule. Only abstract functionality modules should inherit + directly from this class, not specific implementations. + """ + + @abstractstaticmethod + def name() -> str: + """ + Return a memorable, human-readable name for this module. + + This will be used as a key in custom inference configurations and should only + be implemented by the children of functionality modules. + """ + ... 
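+        # For example, the dense blocked attention implementation registers itself under
+        # the name "dense_blocked_attention", which is what the `name` field of a
+        # ConfigBundle refers to.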
+ + @abstractstaticmethod + def config_class() -> Type[DSModuleConfig]: + """ + Return the associated config class for this module. + + This should be implemented (along with the config class) by an abstract functionality + module. + """ + ... + + @abstractstaticmethod + def supports_config(config: DSModuleConfig) -> bool: + """ + Return whether or not this module supports the given config. + + This should be implemented by the children of functionality modules and should report + whether it would be feasible to instantiate this module with the given config. + """ + ... + + def __init__(self, config: DSModuleConfig, implementation_config: Dict[str, Any] = {}) -> None: + """ + Initialize the module with the given config. + """ + super().__init__() + self._config = config + self._implementation_config = implementation_config diff --git a/deepspeed/inference/v2/modules/heuristics.py b/deepspeed/inference/v2/modules/heuristics.py new file mode 100644 index 000000000000..b89e95c0d834 --- /dev/null +++ b/deepspeed/inference/v2/modules/heuristics.py @@ -0,0 +1,179 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ..config_v2 import RaggedInferenceEngineConfig +from ..inference_utils import NormTypeEnum + +from .module_registry import ConfigBundle +from ..modules.configs import ( + DSEmbeddingsConfig, + DSLinearConfig, + DSMoEConfig, + DSNormConfig, + DSSelfAttentionConfig, + DSUnembedConfig, +) +from ..modules.interfaces import ( + DSEmbeddingBase, + DSEmbeddingRegistry, + DSLinearBase, + DSLinearRegistry, + DSMoEBase, + DSMoERegistry, + DSPostNormBase, + DSPostNormRegistry, + DSPreNormBase, + DSPreNormRegistry, + DSSelfAttentionBase, + DSSelfAttentionRegistry, + DSUnembedBase, + DSUnembedRegistry, +) + + +def instantiate_attention(attention_config: DSSelfAttentionConfig, + engine_config: RaggedInferenceEngineConfig) -> DSSelfAttentionBase: + """ + Choose an appropriate attention implementation based on the given configurations. This + method is currently a stub, but as more implementations may be developed we can centralize + the logic for choosing between them here. + + Arguments: + attention_config (DSSelfAttentionConfig): Configuration for the attention module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + An attention module implementing the given configuration. + """ + + # Currently, we only have one implementation, so we just return it. + config = ConfigBundle(name="dense_blocked_attention", config=attention_config) + return DSSelfAttentionRegistry.instantiate_config(config) + + +def instantiate_embed(embed_config: DSEmbeddingsConfig, engine_config: RaggedInferenceEngineConfig) -> DSEmbeddingBase: + """ + Choose an appropriate embedding implementation based on the given configurations. This + method is currently a stub, but as more implementations may be developed we can centralize + the logic for choosing between them here. + + Arguments: + embed_config (DSEmbeddingsConfig): Configuration for the embedding module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + An embedding module implementing the given configuration. + """ + + # Currently, we only have one implementation, so we just return it. 
+ config = ConfigBundle(name="ragged_embedding", config=embed_config) + return DSEmbeddingRegistry.instantiate_config(config) + + +def instantiate_linear(linear_config: DSLinearConfig, engine_config: RaggedInferenceEngineConfig) -> DSLinearBase: + """ + Choose an appropriate linear implementation based on the given configurations. This + method is currently a stub, but as more implementations may be developed we can centralize + the logic for choosing between them here. + + Arguments: + linear_config (DSLinearConfig): Configuration for the linear module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + A linear module implementing the given configuration. + """ + + # Currently, we only have one implementation, so we just return it. + config = ConfigBundle(name="blas_fp_linear", config=linear_config) + return DSLinearRegistry.instantiate_config(config) + + +def instantiate_moe(moe_config: DSMoEConfig, engine_config: RaggedInferenceEngineConfig) -> DSMoEBase: + """ + Choose an appropriate MoE implementation based on the given configurations. This + method is currently a stub, but as more implementations may be developed we can centralize + the logic for choosing between them here. + + Arguments: + moe_config (DSMoEConfig): Configuration for the MoE module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + A MoE module implementing the given configuration. + """ + + moe_type = "cutlass_multi_gemm_moe" + + if moe_type == "cutlass_multi_gemm_moe": + # TODO: Get this off an engine config + implementation_config = { + "weight_dtype": moe_config.input_dtype, + } + + # Currently, we only have one implementation, so we just return it. + config = ConfigBundle(name="cutlass_multi_gemm_moe", + config=moe_config, + implementation_config=implementation_config) + return DSMoERegistry.instantiate_config(config) + + +def instantiate_post_norm(norm_config: DSNormConfig, engine_config: RaggedInferenceEngineConfig) -> DSPostNormBase: + """ + Choose an appropriate post-norm implementation based on the given configurations. This + method is currently a stub, but as more implementations may be developed we can centralize + the logic for choosing between them here. + + Arguments: + norm_config (DSNormConfig): Configuration for the post-norm module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + A post-norm module implementing the given configuration. + """ + + # Currently, we only have one implementation, so we just return it. + config = ConfigBundle(name="cuda_post_ln", config=norm_config) + return DSPostNormRegistry.instantiate_config(config) + + +def instantiate_pre_norm(norm_config: DSNormConfig, engine_config: RaggedInferenceEngineConfig) -> DSPreNormBase: + """ + Choose an appropriate pre-norm implementation based on the given configurations. Currently, + this will select between two CUDA implementations, one for LayerNorm and one for RMSNorm. + + Arguments: + norm_config (DSNormConfig): Configuration for the pre-norm module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + A pre-norm module implementing the given configuration. 
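+
+    For example, a DSNormConfig with ``type=NormTypeEnum.RMSNorm`` resolves to the
+    "cuda_pre_rms" implementation, while LayerNorm resolves to "cuda_pre_ln".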
+ """ + if NormTypeEnum(norm_config.type) == NormTypeEnum.LayerNorm: + module_name = "cuda_pre_ln" + elif NormTypeEnum(norm_config.type) == NormTypeEnum.RMSNorm: + module_name = "cuda_pre_rms" + + config = ConfigBundle(name=module_name, config=norm_config) + return DSPreNormRegistry.instantiate_config(config) + + +def instantiate_unembed(unembed_config: DSUnembedConfig, engine_config: RaggedInferenceEngineConfig) -> DSUnembedBase: + """ + Choose an appropriate unembedding implementation based on the given configurations. This + method is currently a stub, but as more implementations may be developed we can centralize + the logic for choosing between them here. + + Arguments: + unembed_config (DSUnembedConfig): Configuration for the unembed module. + engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. + + Returns: + An unembed module implementing the given configuration. + """ + + # Currently, we only have one implementation, so we just return it. + config = ConfigBundle(name="ragged_unembed", config=unembed_config) + return DSUnembedRegistry.instantiate_config(config) diff --git a/deepspeed/inference/v2/modules/implementations/__init__.py b/deepspeed/inference/v2/modules/implementations/__init__.py new file mode 100644 index 000000000000..1b500a9a0b5a --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Imports for registering ops +from .attention import * +from .linear import * +from .post_norm import * +from .pre_norm import * +from .embedding import * +from .unembed import * +from .moe import * diff --git a/deepspeed/inference/v2/modules/implementations/attention/__init__.py b/deepspeed/inference/v2/modules/implementations/attention/__init__.py new file mode 100644 index 000000000000..1ff593ac1368 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/attention/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .dense_blocked_attention import DSDenseBlockedAttention diff --git a/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py b/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py new file mode 100644 index 000000000000..bb482f0c58d6 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py @@ -0,0 +1,167 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from ....allocator import empty_from +from ....inference_utils import DtypeEnum +from ....kernels.ragged_ops import ( + AtomBuilder, + BlockedFlashAttn, + BlockedRotaryEmbeddings, + BlockedTrainedRotaryEmbeddings, + get_q_block_size, + get_kv_block_size, + LinearBlockedKVCopy, +) +from ....ragged import RaggedBatchWrapper, split_kv +from deepspeed.ops.op_builder import RaggedUtilsBuilder + +from ...interfaces import DSSelfAttentionBase, DSSelfAttentionRegistry +from ...configs import DSSelfAttentionConfig, PositionalEmbeddingType, MaskingType + +try: + from functools import cached_property +except ImportError: + + def cached_property(func): + return property(func) + + +@DSSelfAttentionRegistry.register_module +class DSDenseBlockedAttention(DSSelfAttentionBase): + """ + Self attention implementation for dense, blocked self attention. + """ + + @staticmethod + def name() -> str: + return 'dense_blocked_attention' + + @staticmethod + def supports_config(config: DSSelfAttentionConfig) -> bool: + + if config.input_dtype != config.output_dtype: + return False + + if DtypeEnum(config.input_dtype) not in (DtypeEnum.fp16, DtypeEnum.bf16): + return False + + if PositionalEmbeddingType(config.positional_embedding_type) not in [ + PositionalEmbeddingType.none, PositionalEmbeddingType.rotate_half + ]: + return False + + if MaskingType(config.masking_type) != MaskingType.causal: + return False + + return True + + def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[str, Any]) -> None: + """ + Create the Attention DSModule. + + Args: + config (DSSelfAttentionConfig): The self attention config for all attention DSModules. + implementation_config (Dict[str, Any]): The implementation config for this DSModule may + contain a `trained_freqs` key. If passed, the implementation will expect a `trained_freqs` + tensor in the `forward` method and will not synthesize the frequencies internally. + """ + super().__init__(config, implementation_config) + + embed_type = PositionalEmbeddingType(config.positional_embedding_type) + if embed_type == PositionalEmbeddingType.none: + self._kv_copy = LinearBlockedKVCopy(self._config.head_size, self._config.n_heads_q, + self._config.n_heads_kv, self._config.input_dtype) + elif embed_type == PositionalEmbeddingType.rotate_half: + use_trained_freqs = "trained_freqs" in self._config.positional_embedding_args and self._config.positional_embedding_args[ + "trained_freqs"] + if use_trained_freqs: + self._kv_copy = BlockedTrainedRotaryEmbeddings(self._config.head_size, self._config.n_heads_q, + self._config.n_heads_kv, self._config.input_dtype) + else: + self._kv_copy = BlockedRotaryEmbeddings(self._config.head_size, self._config.n_heads_q, + self._config.n_heads_kv, self._config.input_dtype) + + self._softmax_scale = self._config.scale_factor + + # TODO(cmikeh2): Attention kernel gets created here. + self._attn_kernel = BlockedFlashAttn(self._config.head_size, self._config.input_dtype) + self._atom_builder = AtomBuilder() + + self.model_dim = self._config.head_size * self._config.n_heads_q + self._output = torch.empty((self._config.max_tokens, self._config.head_size * self._config.n_heads_q), + dtype=self._config.output_dtype, + device=get_accelerator().current_device()) + + # TODO(cmikeh2): Pre-allocate storage buffer for the attention atoms. 
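+        # Each attention atom is a row of 8 int32 values built by the AtomBuilder from the
+        # ragged batch metadata and consumed by the blocked flash attention kernel.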
+ self._max_atoms = self._config.max_sequences + self._atoms = torch.empty((self._max_atoms, 8), dtype=torch.int32, device=get_accelerator().current_device()) + + alloc_func = RaggedUtilsBuilder().load().allocate_fast_host_buffer + self._atoms_shadow = alloc_func(self._atoms) + self._cur_atoms = 0 + + @cached_property + def kv_block_size(self) -> int: + """ + Return preferred granulatity for blocked KV-cache implementation. + """ + return get_kv_block_size(self._config.head_size) + + @cached_property + def q_block_size(self) -> int: + """ + Property to calculate blocking granularity for the query dimension. + This has no impact on the KV-cache structure, but will affect the + number of attention atoms associated with a batch. + """ + return get_q_block_size(self._config.head_size) + + def build_atoms(self, ragged_batch: RaggedBatchWrapper) -> None: + """ + Build the atoms for the attention kernel. + + Args: + ragged_batch (RaggedBatchWrapper): The input ids and associated ragged batch metadata. + """ + host_atoms, n_atoms = self._atom_builder(self._atoms_shadow, ragged_batch, self.q_block_size, + self.kv_block_size) + + self._cur_atoms = n_atoms + self._atoms[:n_atoms].copy_(host_atoms[:n_atoms], non_blocking=True) + + def forward(self, + q_k_v: torch.Tensor, + kv_cache: torch.Tensor, + batch: RaggedBatchWrapper, + inv_freqs: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Forward implementation. + + Args: + q_k_v (torch.Tensor): Query/Key/Value projection Tensor of shape + [n_heads, (n_heads_q + 2 * n_heads_kv) * head_size]. + kv_cache (torch.Tensor): Blocked persistent cache of shape + [2, batch, block_size, n_heads_kv, head_size]. + batch (RaggedBatchWrapper): The input ids and associated ragged batch metadata. + inv_freqs (Optional[torch.Tensor]): The inverse frequencies for the rotary embeddings if they + have been modified from synthesizable values. + """ + if inv_freqs is not None: + self._kv_copy(kv_cache, q_k_v, batch, inv_freqs) + else: + self._kv_copy(kv_cache, q_k_v, batch) + + q = q_k_v[:, :self._config.head_size * self._config.n_heads_q] + output = empty_from(self._output, q.shape) + k_cache, v_cache = split_kv(kv_cache) + + self._attn_kernel(output, q, k_cache, v_cache, self._atoms[:self._cur_atoms], self._softmax_scale) + + return output diff --git a/deepspeed/inference/v2/modules/implementations/embedding/__init__.py b/deepspeed/inference/v2/modules/implementations/embedding/__init__.py new file mode 100644 index 000000000000..5458a855abf4 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/embedding/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .ragged_embedding import DSRaggedEmbedding diff --git a/deepspeed/inference/v2/modules/implementations/embedding/ragged_embedding.py b/deepspeed/inference/v2/modules/implementations/embedding/ragged_embedding.py new file mode 100644 index 000000000000..6782bcae81c8 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/embedding/ragged_embedding.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from ....allocator import empty_from +from ....inference_utils import DtypeEnum +from ....kernels.ragged_ops import RaggedEmbeddingKernel +from ....ragged import RaggedBatchWrapper +from ...interfaces import DSEmbeddingBase, DSEmbeddingRegistry +from ...configs import DSEmbeddingsConfig + + +@DSEmbeddingRegistry.register_module +class DSRaggedEmbedding(DSEmbeddingBase): + + @staticmethod + def name(): + return 'ragged_embedding' + + @staticmethod + def supports_config(config: DSEmbeddingsConfig) -> bool: + + if DtypeEnum(config.residual_dtype) not in [DtypeEnum.fp16, DtypeEnum.bf16, DtypeEnum.fp32]: + return False + + if config.use_token_type: + return False + + if config.output_normalization != None: + return False + + try: + _ = RaggedEmbeddingKernel(config.residual_dtype, torch.int32, config.embedding_dim) + except ValueError: + return False + + return True + + def __init__(self, config: DSEmbeddingsConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + self.embed_offset = self._config.positional_offset + + # TODO(cmikeh2): How do we want to avoid the int32 vs int64 issue? + self._ragged_embed = RaggedEmbeddingKernel(self._config.residual_dtype, torch.int32, + self._config.embedding_dim) + + self._output = torch.empty((self._config.max_tokens, self._config.embedding_dim), + dtype=self._config.residual_dtype, + device=get_accelerator().current_device()) + + @property + def output(self) -> torch.Tensor: + return self._output + + def forward(self, + ragged_batch: RaggedBatchWrapper, + word_embeddings: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Parameters: + ragged_batch (RaggedBatchWrapper): The input ids and associated ragged batch metadata. + word_embeddings (torch.Tensor): The word embedding table + """ + output = empty_from(self._output, (ragged_batch.tensor_toks, self._config.embedding_dim)) + self._ragged_embed(output, + ragged_batch, + word_embeddings, + position_embed_weight=position_embeddings, + position_embed_offset=self.embed_offset) + return output diff --git a/deepspeed/inference/v2/modules/implementations/linear/__init__.py b/deepspeed/inference/v2/modules/implementations/linear/__init__.py new file mode 100644 index 000000000000..5acdc69dba7c --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/linear/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .blas_fp_linear import BlasFPLinear +from .cutlass_fp_linear import DSCutlassFPLinear diff --git a/deepspeed/inference/v2/modules/implementations/linear/blas_fp_linear.py b/deepspeed/inference/v2/modules/implementations/linear/blas_fp_linear.py new file mode 100644 index 000000000000..cbb8801e7193 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/linear/blas_fp_linear.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from ....allocator import empty_from +from ....inference_utils import is_gated +from ....kernels.core_ops import ( + BlasLibLinear, + CUDABiasActivation, + CUDAGatedActivation, +) + +from ...interfaces import DSLinearBase, DSLinearRegistry +from ...configs import DSLinearConfig + + +@DSLinearRegistry.register_module +class BlasFPLinear(DSLinearBase): + """ + Linear DSModule based on BLAS library and standalone bias + activation kernel implementation. + """ + + @staticmethod + def name(): + return 'blas_fp_linear' + + @staticmethod + def supports_config(config: DSLinearConfig) -> bool: + if config.input_dtype != config.output_dtype: + return False + + if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16: + return False + + if is_gated(config.activation): + try: + _ = CUDAGatedActivation(config.out_channels, config.output_dtype, config.activation) + except ValueError: + return False + else: + try: + _ = CUDABiasActivation(config.out_channels, config.output_dtype, config.activation) + except ValueError: + return False + + return True + + def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + self._linear_impl = BlasLibLinear(self._config.input_dtype) + + if is_gated(config.activation): + self._is_gated = True + self._act_fn = CUDAGatedActivation(config.out_channels, config.output_dtype, config.activation) + self._double_buffer = torch.empty((config.max_tokens, config.out_channels * 2), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + else: + self._is_gated = False + self._act_fn = CUDABiasActivation(config.out_channels, config.output_dtype, config.activation) + + self._output = torch.empty((config.max_tokens, config.out_channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + + def transform_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Converts param to same data type as input and output. + + Parameters: + param (torch.Tensor): Weight or bias tensor. + """ + return param.to(self._config.input_dtype) + + def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor: + + output = empty_from(self._output, (hidden_states.shape[0], self._config.out_channels)) + + if self._is_gated: + staging_output = empty_from(self._double_buffer, (hidden_states.shape[0], self._config.out_channels * 2)) + self._linear_impl(staging_output, hidden_states, w) + self._act_fn(output, staging_output, b) + else: + self._linear_impl(output, hidden_states, w) + self._act_fn(output, b) + + return output + + @property + def output(self) -> torch.Tensor: + """ + Return the padded, pre-allocated output Tensor. + """ + return self._output diff --git a/deepspeed/inference/v2/modules/implementations/linear/cutlass_fp_linear.py b/deepspeed/inference/v2/modules/implementations/linear/cutlass_fp_linear.py new file mode 100644 index 000000000000..a3704eaa82a2 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/linear/cutlass_fp_linear.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from ....allocator import empty_from +from ....inference_utils import ActivationType +from ....kernels.core_ops import CUDAGatedActivation + +from ...interfaces import DSLinearBase, DSLinearRegistry +from ...configs import DSLinearConfig + + +@DSLinearRegistry.register_module +class DSCutlassFPLinear(DSLinearBase): + """ + Linear DSModule based on CUTLASS floating point kernel implementation. + """ + + @staticmethod + def name(): + return 'cutlass_fp_linear' + + @staticmethod + def supports_config(config: DSLinearConfig) -> bool: + if config.input_dtype != config.output_dtype: + return False + + if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16: + return False + + return True + + def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + # TODO: Load kernel + + if config.activation == ActivationType.GEGLU: + self._geglu = CUDAGatedActivation(config.out_channels, config.output_dtype, ActivationType.GEGLU) + self._activation_int = torch.empty((config.max_tokens, config.out_channels * 2), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + + self._output = torch.empty((config.max_tokens, config.out_channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + + def transform_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Converts param to same data type as input and output. + + Parameters: + param (torch.Tensor): Weight or bias tensor. + """ + return param.to(self._config.input_dtype) + + def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor: + + output = empty_from(self._output, (hidden_states.shape[0], self._config.out_channels)) + + if self._config.activation == ActivationType.GEGLU: + intermediate = empty_from(self._activation_int, (hidden_states.shape[0], self._config.out_channels * 2)) + self._linear_impl(intermediate, hidden_states, w, b) + self._geglu(output, intermediate) + else: + self._linear_impl(output, hidden_states, w, b) + + return output + + @property + def output(self) -> torch.Tensor: + """ + Return the padded, pre-allocated output Tensor. + """ + return self._output diff --git a/deepspeed/inference/v2/modules/implementations/moe/__init__.py b/deepspeed/inference/v2/modules/implementations/moe/__init__.py new file mode 100644 index 000000000000..053ad5da7746 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/moe/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .cutlass_multi_gemm import DSMultiGemmMoE diff --git a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py new file mode 100644 index 000000000000..fb2388c450f0 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py @@ -0,0 +1,225 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional, Tuple + +import torch + +from deepspeed.accelerator import get_accelerator +from ....allocator import empty_from +from ....inference_utils import ActivationType +from ....kernels.core_ops import BlasLibLinear +from ....kernels.ragged_ops import ( + MoEGather, + MoEScatter, + RaggedTop1Gating, +) +from ....ragged import RaggedBatchWrapper + +from ...interfaces import DSMoEBase, DSMoERegistry +from ...configs import DSMoEConfig +from ....kernels.cutlass_ops import MoEGEMM + + +@DSMoERegistry.register_module +class DSMultiGemmMoE(DSMoEBase): + """ + MoE implementation based on the CUTLASS multi-GEMM. + """ + + @staticmethod + def name(): + return 'cutlass_multi_gemm_moe' + + @staticmethod + def supports_config(config: DSMoEConfig) -> bool: + if config.input_dtype != config.output_dtype: + return False + + if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16: + return False + + if config.top_k != 1: + return False + + if config.activation in [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU]: + # Currently not supporting gated activations in MoE + return False + + return True + + def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + # Convenience variables for frequently accessed items. + self.max_tokens = self._config.max_tokens + self.n_experts = self._config.n_experts + self.intermediate_dim = self._config.intermediate_features + + self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=config.activation) + self._mlp_2 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=ActivationType.IDENTITY) + + self._gate_proj = BlasLibLinear(self._config.input_dtype) + self._top_1_gate = RaggedTop1Gating(config.input_dtype) + self._moe_scatter = MoEScatter(config.input_dtype, config.model_dim) + self._moe_gather = MoEGather(config.input_dtype, config.model_dim) + + self._create_buffers() + + def _create_buffers(self): + + # Gating buffers + self._logits = torch.empty((self._config.max_tokens, self.n_experts), + dtype=self._config.input_dtype, + device=get_accelerator().current_device()) + self._expert_counts = torch.empty((self.n_experts, ), + dtype=torch.int32, + device=get_accelerator().current_device()) + self._scores = torch.empty((self._config.max_tokens, ), + dtype=torch.float32, + device=get_accelerator().current_device()) + self._assignments = torch.empty((self._config.max_tokens, ), + dtype=torch.int32, + device=get_accelerator().current_device()) + self._offsets = torch.empty((self._config.max_tokens, ), + dtype=torch.int32, + device=get_accelerator().current_device()) + + # Scatter buffers + self._moe_input = torch.empty((self._config.max_tokens, self._config.model_dim), + dtype=self._config.input_dtype, + device=get_accelerator().current_device()) + self._expert_cumsum = torch.empty((self._config.n_experts, ), + dtype=torch.int64, + device=get_accelerator().current_device()) + self._mapped_slots = torch.empty((self._config.max_tokens, ), + dtype=torch.int32, + device=get_accelerator().current_device()) + + # GEMM Buffers + self._intermediate = torch.empty((self._config.max_tokens, self._config.intermediate_features), + dtype=self._config.output_dtype, + device=get_accelerator().current_device()) + self._output_unordered = torch.empty((self._config.max_tokens, self._config.model_dim), + dtype=self._config.output_dtype, + 
device=get_accelerator().current_device()) + + # Gather buffer + self._output = torch.empty((self._config.max_tokens, self._config.model_dim), + dtype=self._config.output_dtype, + device=get_accelerator().current_device()) + + def transform_gate_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Ensures gate param is going to match the activation data type. + """ + return param.to(self._config.input_dtype) + + def transform_moe_mlp_1_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Converts param to same data type as input and output. + + Parameters: + param (torch.Tensor): Weight or bias tensor. + """ + param = param.to(self._config.input_dtype) + + if len(param.shape) == 3: + return param.permute(0, 2, 1).contiguous() + else: + return param + + def transform_moe_mlp_2_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Converts param to same data type as input and output. + + Parameters: + param (torch.Tensor): Weight or bias tensor. + """ + param = param.to(self._config.input_dtype) + + if len(param.shape) == 3: + return param.permute(0, 2, 1).contiguous() + else: + return param + + @property + def output(self) -> torch.Tensor: + return self._output + + def _gate(self, hidden_states: torch.Tensor, batch_metadata: RaggedBatchWrapper, + gate_w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Helper function to isolate the logit for gating. This will take the hidden states + and produce the metadata + tensors for the CUTLASS ragged GEMMs. If the input has + been padded for CG, this will strip the padding for MoE. + + Parameters: + hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [n_tokens, model_dim]. + batch_metadata (RaggedBatchWrapper): Batch metadata for the hidden states. + gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim]. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The MoE input, the cumsum of the offsets (for the MoE kernels themselves), the scores, and the mapped slots (to recover the original order of the tokens) + """ + + # Get views on the buffers for gating + logits = empty_from(self._logits, (hidden_states.shape[0], self._logits.shape[-1])) + scores = empty_from(self._scores, (hidden_states.shape[0], )) + assignments = empty_from(self._assignments, (hidden_states.shape[0], )) + offsets = empty_from(self._offsets, (hidden_states.shape[0], )) + mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], )) + moe_input = empty_from(self._moe_input, (hidden_states.shape[0], self._moe_input.shape[-1])) + + self._gate_proj(logits, hidden_states, gate_w) + self._expert_counts.zero_() + self._top_1_gate(self._expert_counts, scores, assignments, offsets, logits, batch_metadata) + self._moe_scatter(moe_input, self._expert_cumsum, mapped_slots, hidden_states, self._expert_counts, + assignments, offsets) + + return moe_input, self._expert_cumsum, scores, mapped_slots + + def forward(self, + hidden_states: torch.Tensor, + batch_metadata: RaggedBatchWrapper, + gate_w: torch.Tensor, + mlp_1_w: torch.Tensor, + mlp_2_w: torch.Tensor, + mlp_1_b: Optional[torch.Tensor] = None, + mlp_2_b: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + MoE forward pass built on top of CUTLASS multi-GEMM. + + Parameters: + hidden_states (torch.Tensor): Hidden states tensor. Expected shape is [batch, seq_len, model_dim]. + gate_w (torch.Tensor): Gate weight tensor. Expected shape is [num_experts, model_dim]. 
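+            batch_metadata (RaggedBatchWrapper): Ragged batch metadata for the hidden states.
+            mlp_1_w (torch.Tensor): Expert weights for the first MoE GEMM.
+            mlp_2_w (torch.Tensor): Expert weights for the second MoE GEMM.
+            mlp_1_b (Optional[torch.Tensor]): Optional expert biases for the first MoE GEMM.
+            mlp_2_b (Optional[torch.Tensor]): Optional expert biases for the second MoE GEMM.
+
+        Returns:
+            torch.Tensor: Gathered MoE output with the same leading (token) dimension as
+                hidden_states and a trailing dimension of model_dim.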
+ """ + + moe_input, expert_cumsum, scores, mapped_slots = self._gate(hidden_states, batch_metadata, gate_w) + + # Get views on the buffers for GEMM + intermediate = empty_from(self._intermediate, (hidden_states.shape[0], self._intermediate.shape[-1])) + output_unordered = empty_from(self._output_unordered, + (hidden_states.shape[0], self._output_unordered.shape[-1])) + output = empty_from(self._output, (hidden_states.shape[0], self._output.shape[-1])) + + self._mlp_1( + intermediate, + moe_input, + mlp_1_w, + expert_cumsum, + mlp_1_b, + ) + + self._mlp_2( + output_unordered, + intermediate, + mlp_2_w, + expert_cumsum, + mlp_2_b, + ) + + self._moe_gather(output, output_unordered, scores, mapped_slots, self._expert_counts) + return output diff --git a/deepspeed/inference/v2/modules/implementations/moe/gate_fn.py b/deepspeed/inference/v2/modules/implementations/moe/gate_fn.py new file mode 100644 index 000000000000..9eceaab156e4 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/moe/gate_fn.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import torch.nn.functional as F +from torch import Tensor + +from typing import Tuple + +#TODO(cmikeh2): DELETE + + +@torch.jit.script +def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor: + # gates has shape of SE + num_tokens = gates.shape[0] + num_experts = gates.shape[-1] + # to(torch.int64) works around a bug in torch.onnx.export: + # it should cast k to int64 when converting torch.topk but it doesn't. + capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64) + if capacity < min_capacity: + capacity = min_capacity.to(torch.int64) + return capacity + + +@torch.jit.script +def _top_idx(source, k): + return torch.topk(source, k=k, dim=0)[1] + + +def top1gating(logits: Tensor, + capacity_factor: float, + min_capacity: int, + drop_tokens: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + # everything is in fp32 in this function + gates = F.softmax(logits, dim=1) + + capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity)) + + # Create a mask for 1st's expert per token + indices1_s = torch.argmax(gates, dim=1) + num_experts = int(gates.shape[1]) + mask1 = F.one_hot(indices1_s, num_classes=num_experts) + + # gating decisions + exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') + + assert logits.shape[ + 0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size." + + top_idx = _top_idx(mask1, capacity) + + mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) + + indices_mask = mask1.sum(dim=1) * num_experts - 1 + indices1_s = torch.min(indices1_s, indices_mask) + + gates1_s = (gates * mask1).sum(dim=1) + + return indices1_s, gates1_s diff --git a/deepspeed/inference/v2/modules/implementations/moe/test.py b/deepspeed/inference/v2/modules/implementations/moe/test.py new file mode 100644 index 000000000000..b714366d32ec --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/moe/test.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+from deepspeed.accelerator import get_accelerator
+from deepspeed.inference.v2.modules.interfaces import DSMoERegistry
+from deepspeed.inference.v2.modules.configs import DSMoEConfig
+from deepspeed.inference.v2.modules.module_registry import ConfigBundle
+
+
+def run_multi_gemm_with_gating(inputs, gate_weight, moe_weight1, moe_bias1, moe_weight2):
+    config = DSMoEConfig(model_dim=4096, intermediate_features=4096, n_experts=64, max_tokens=128)
+    moe = DSMoERegistry.instantiate_config(
+        ConfigBundle(name='cutlass_multi_gemm_moe',
+                     config=config,
+                     implementation_config={
+                         "weight_dtype": torch.bfloat16,
+                         "transpose_weight": True,
+                         "min_capacity": 8,
+                         "capacity_factor": 1.0
+                     }))
+    out = moe(inputs, gate_weight, moe_weight1, moe_weight2, moe_bias1)
+    return out
+
+
+a = torch.randn(
+    128,
+    4096,
+).bfloat16().to(get_accelerator().current_device())
+weight1 = torch.randn(64, 4096, 4096).bfloat16().to(get_accelerator().current_device())
+bias1 = torch.randn(64, 4096).bfloat16().to(get_accelerator().current_device())
+weight2 = torch.randn(64, 4096, 4096).bfloat16().to(get_accelerator().current_device())
+gate_weight = torch.randn(64, 4096).bfloat16().to(get_accelerator().current_device())
+
+out = run_multi_gemm_with_gating(a, gate_weight, weight1, bias1, weight2)
+print(out)
diff --git a/deepspeed/inference/v2/modules/implementations/post_norm/__init__.py b/deepspeed/inference/v2/modules/implementations/post_norm/__init__.py
new file mode 100644
index 000000000000..653a2fe4fb5b
--- /dev/null
+++ b/deepspeed/inference/v2/modules/implementations/post_norm/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .cuda_post_ln import DSPostLNCUDAModule
diff --git a/deepspeed/inference/v2/modules/implementations/post_norm/cuda_post_ln.py b/deepspeed/inference/v2/modules/implementations/post_norm/cuda_post_ln.py
new file mode 100644
index 000000000000..b30c5b937ed2
--- /dev/null
+++ b/deepspeed/inference/v2/modules/implementations/post_norm/cuda_post_ln.py
@@ -0,0 +1,54 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Tuple + +import torch + +from deepspeed.accelerator import get_accelerator +from ...interfaces import DSPostNormBase, DSPostNormRegistry +from ...configs import DSNormConfig +from ....kernels.core_ops.cuda_layer_norm.cuda_post_ln import CUDAFPPostLN +from ....allocator import empty_from + + +@DSPostNormRegistry.register_module +class DSPostLNCUDAModule(DSPostNormBase): + + @staticmethod + def name(): + return 'cuda_post_ln' + + @staticmethod + def supports_config(config: DSNormConfig): + if len(set([config.residual_dtype, config.input_dtype, config.output_dtype])) != 1: + return False + + try: + _ = CUDAFPPostLN(config.channels, config.residual_dtype) + except ValueError: + return False + return True + + def __init__(self, config: DSNormConfig, implementation_config: Dict[str, Any]): + super().__init__(config, implementation_config) + self._fp_post_ln = CUDAFPPostLN(self._config.channels, self._config.residual_dtype, epsilon=self._config.eps) + + self._output = torch.empty((config.max_tokens, config.channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + + def transform_param(self, param: torch.Tensor) -> torch.Tensor: + return param.to(self._config.input_dtype) + + def forward(self, residual: torch.Tensor, hidden_in: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Since the CUDA FP only supports all data types being the same, we will alias the residual + with our output. + """ + self._residual_output = empty_from(self._output, residual.shape) + self._fp_post_ln(residual, residual, hidden_in, gamma, beta) + return residual, residual diff --git a/deepspeed/inference/v2/modules/implementations/pre_norm/__init__.py b/deepspeed/inference/v2/modules/implementations/pre_norm/__init__.py new file mode 100644 index 000000000000..12605f13f955 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/pre_norm/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .cuda_pre_ln import DSPreLNCUDAModule +from .cuda_pre_rms import DSPreRMSCUDAModule diff --git a/deepspeed/inference/v2/modules/implementations/pre_norm/cuda_pre_ln.py b/deepspeed/inference/v2/modules/implementations/pre_norm/cuda_pre_ln.py new file mode 100644 index 000000000000..f50f4a3d2db6 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/pre_norm/cuda_pre_ln.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional, Tuple + +import torch + +from deepspeed.accelerator import get_accelerator +from ...interfaces import DSPreNormBase, DSPreNormRegistry +from ...configs import DSNormConfig, NormTypeEnum +from ....kernels.core_ops.cuda_layer_norm.cuda_pre_ln import CUDAFPPreLN +from ....kernels.core_ops.cuda_layer_norm.cuda_ln import CUDAFPLN +from ....allocator import empty_from + + +@DSPreNormRegistry.register_module +class DSPreLNCUDAModule(DSPreNormBase): + + @staticmethod + def name(): + return 'cuda_pre_ln' + + @staticmethod + def supports_config(config: DSNormConfig): + type = NormTypeEnum(config.type) + if type != NormTypeEnum.LayerNorm: + return False + + if len(set([config.residual_dtype, config.input_dtype, config.output_dtype])) != 1: + return False + + try: + _ = CUDAFPPreLN(config.channels, config.residual_dtype) + except ValueError: + return False + return True + + def __init__(self, config: DSNormConfig, implementation_config: Dict[str, Any]): + super().__init__(config, implementation_config) + self._fp_pre_ln = CUDAFPPreLN(self._config.channels, self._config.residual_dtype, epsilon=self._config.eps) + self._fp_ln = CUDAFPLN(self._config.channels, self._config.residual_dtype, epsilon=self._config.eps) + + # Buffers for the hidden output (residual is updated in-place) + self._hidden_output = torch.empty((config.max_tokens, config.channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + + def transform_param(self, param: torch.Tensor) -> torch.Tensor: + return param.to(self._config.input_dtype) + + def forward(self, residual: torch.Tensor, hidden_in: Optional[torch.Tensor], gamma: torch.Tensor, + beta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Since the CUDA FP only supports all data types being the same, we will alias the residual + with our output. + + If hidden_in is None, that means we do not need to perform the residual add and will + only return the hidden output modified. + """ + hidden_out = empty_from(self._hidden_output, residual.shape) + if hidden_in is None: + self._fp_ln(hidden_out, residual, gamma, beta) + else: + self._fp_pre_ln(residual, hidden_out, residual, hidden_in, gamma, beta) + return residual, hidden_out diff --git a/deepspeed/inference/v2/modules/implementations/pre_norm/cuda_pre_rms.py b/deepspeed/inference/v2/modules/implementations/pre_norm/cuda_pre_rms.py new file mode 100644 index 000000000000..7aeea4b2d386 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/pre_norm/cuda_pre_rms.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional, Tuple + +import torch + +from deepspeed.accelerator import get_accelerator +from ...interfaces import DSPreNormBase, DSPreNormRegistry +from ...configs import DSNormConfig, NormTypeEnum +from ....kernels.core_ops import CUDARMSNorm, CUDARMSPreNorm +from ....allocator import empty_from + + +@DSPreNormRegistry.register_module +class DSPreRMSCUDAModule(DSPreNormBase): + + @staticmethod + def name(): + return 'cuda_pre_rms' + + @staticmethod + def supports_config(config: DSNormConfig): + type = NormTypeEnum(config.type) + if type != NormTypeEnum.RMSNorm: + return False + + if len(set([config.residual_dtype, config.input_dtype, config.output_dtype])) != 1: + return False + + try: + # Only need to check one since the support matrix for the two rms kernels is the same + _ = CUDARMSPreNorm(config.channels, config.residual_dtype) + except ValueError: + return False + return True + + def __init__(self, config: DSNormConfig, implementation_config: Dict[str, Any]): + super().__init__(config, implementation_config) + self._fp_rms = CUDARMSNorm(self._config.channels, self._config.residual_dtype, epsilon=self._config.eps) + self._fp_rms_pre = CUDARMSPreNorm(self._config.channels, self._config.residual_dtype, epsilon=self._config.eps) + + # Buffers for both the hidden and residual outputs + self._hidden_output = torch.empty((config.max_tokens, config.channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + self._residual_output = torch.empty((config.max_tokens, config.channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + + def transform_param(self, param: torch.Tensor) -> torch.Tensor: + return param.to(self._config.input_dtype) + + def forward(self, + residual: torch.Tensor, + hidden_in: Optional[torch.Tensor], + gamma: torch.Tensor, + beta: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Since the CUDA FP only supports all data types being the same, we will alias the residual + with our output. + + If hidden_in is None, that means we do not need to perform the residual add and will + only return the hidden output modified. + """ + assert beta is None, "Beta is not supported for RMSNorm" + + hidden_out = empty_from(self._hidden_output, residual.shape) + if hidden_in is None: + self._fp_rms(hidden_out, residual, gamma) + residual_out = residual + else: + residual_out = empty_from(self._residual_output, residual.shape) + self._fp_rms_pre(residual_out, hidden_out, residual, hidden_in, gamma) + return residual_out, hidden_out diff --git a/deepspeed/inference/v2/modules/implementations/unembed/__init__.py b/deepspeed/inference/v2/modules/implementations/unembed/__init__.py new file mode 100644 index 000000000000..4a5fd24d518b --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/unembed/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .ragged_unembed import DSRaggedUnembed diff --git a/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py b/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py new file mode 100644 index 000000000000..40d70cbd4df7 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/unembed/ragged_unembed.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from ....allocator import empty_from +from ....inference_utils import DtypeEnum +from ....kernels.core_ops import CUDAFPLN, BlasLibLinear, CUDARMSNorm +from ....kernels.ragged_ops import RaggedLogitsGather +from ....ragged import RaggedBatchWrapper +from ...interfaces import DSUnembedBase, DSUnembedRegistry +from ...configs import DSUnembedConfig + + +@DSUnembedRegistry.register_module +class DSRaggedUnembed(DSUnembedBase): + """ + Ragged unembedding implementation. This implementation will gather only the last token + of each sequence in the ragged inflight batch and calculate the logits only for those rows. + """ + + @staticmethod + def name(): + return 'ragged_unembed' + + @staticmethod + def supports_config(config: DSUnembedConfig): + + if DtypeEnum(config.dtype) not in [DtypeEnum.fp16, DtypeEnum.bf16, DtypeEnum.fp32]: + return False + + try: + _ = RaggedLogitsGather(config.model_dim, config.dtype) + except ValueError: + return False + + if config.norm_type == 'rms_norm': + try: + _ = CUDARMSNorm(config.model_dim, config.dtype) + except ValueError: + return False + elif config.norm_type == 'layer_norm': + try: + _ = CUDAFPLN(config.model_dim, config.dtype) + except ValueError: + return False + + return True + + def __init__(self, config: DSUnembedConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + self._logits_gather = RaggedLogitsGather(config.model_dim, self._config.dtype) + + if self._config.norm_type == 'layer_norm': + self._norm = CUDAFPLN(self._config.model_dim, self._config.dtype) + elif self._config.norm_type == 'rms_norm': + self._norm = CUDARMSNorm(self._config.model_dim, self._config.dtype) + else: + self._norm = None + + self._linear = BlasLibLinear(self._config.dtype) + + self._intermediate = torch.empty((self._config.max_sequences, self._config.model_dim), + dtype=self._config.dtype, + device=get_accelerator().current_device()) + + self._output = torch.empty((self._config.max_sequences, self._config.vocab_size), + dtype=self._config.dtype, + device=get_accelerator().current_device()) + + @property + def output(self) -> torch.Tensor: + return self._output + + def forward(self, + hidden_states: torch.Tensor, + vocab_embedding: torch.Tensor, + ragged_metadata: RaggedBatchWrapper, + gamma: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Return final model logits. + + Args: + hidden_states (torch.Tensor): The hidden states from the model. This is the output of the + final layer of the model. + vocab_embedding (torch.Tensor): The vocab embedding table. + raged_metadata (RaggedBatchWrapper): The ragged batch metadata. + gamma (Optional[torch.Tensor]): The gamma tensor for normalization. + beta (Optional[torch.Tensor]): The beta tensor for normalization. 
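+
+        Returns:
+            torch.Tensor: Logits for the final token of each in-flight sequence, of shape
+                [n_seqs, vocab_size].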
+ """ + + cut_down_hidden_states = empty_from(self._intermediate, + (ragged_metadata.current_sequences, self._config.model_dim)) + self._logits_gather(cut_down_hidden_states, hidden_states, ragged_metadata) + + if self._config.norm_type == 'rms_norm': + if gamma is None: + raise ValueError('RMS Normalization enabled but gamma not provided.') + self._norm(cut_down_hidden_states, cut_down_hidden_states, gamma) + elif self._config.norm_type == 'layer_norm': + if gamma is None or beta is None: + raise ValueError('Normalization enabled but gamma and/or beta not provided.') + self._norm(cut_down_hidden_states, cut_down_hidden_states, gamma, beta) + + output = empty_from(self._output, (ragged_metadata.current_sequences, self._config.vocab_size)) + self._linear(output, cut_down_hidden_states, vocab_embedding) + + return output diff --git a/deepspeed/inference/v2/modules/interfaces/__init__.py b/deepspeed/inference/v2/modules/interfaces/__init__.py new file mode 100644 index 000000000000..13b556789e4e --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .attention_base import DSSelfAttentionRegistry, DSSelfAttentionBase +from .embedding_base import DSEmbeddingRegistry, DSEmbeddingBase +from .linear_base import DSLinearRegistry, DSLinearBase +from .moe_base import DSMoERegistry, DSMoEBase +from .post_norm_base import DSPostNormRegistry, DSPostNormBase +from .pre_norm_base import DSPreNormRegistry, DSPreNormBase +from .unembed_base import DSUnembedRegistry, DSUnembedBase diff --git a/deepspeed/inference/v2/modules/interfaces/attention_base.py b/deepspeed/inference/v2/modules/interfaces/attention_base.py new file mode 100644 index 000000000000..c67dc033f92a --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/attention_base.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional, Tuple, Type + +import torch + +from ...ragged import RaggedBatchWrapper +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ..ds_module import DSModuleBase +from ..module_registry import DSModuleRegistryBase +from ..configs import DSSelfAttentionConfig + + +class DSSelfAttentionBase(DSModuleBase): + """ + Base mixin for all attention modules. The interface represented by this module + is broadly: + + output = attention(query_key_value, + Optional[kv_cache], + Optional[attention_mask], + Optional[attention_bias]) + """ + + @staticmethod + def config_class() -> Type[DeepSpeedConfigModel]: + return DSSelfAttentionConfig + + def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + @property + def kv_block_size(self) -> int: + """ + Return preferred granulatity for blocked KV-cache implementation. + """ + raise NotImplementedError() + + @property + def q_block_size(self) -> int: + """ + Property to calculate blocking granularity for the query dimension. + This has no impact on the KV-cache structure, but will affect the + number of attention atoms associated with a batch. + """ + raise NotImplementedError() + + def build_atoms(self, ragged_batch: RaggedBatchWrapper) -> None: + """ + Build the atoms for this module. This is not a strict requirement for the class, + so this method is a no-op by default rather than abstract. 
+ """ + pass + + def forward(self, + q_k_v: torch.Tensor, + kv_cache: torch.Tensor, + batch: RaggedBatchWrapper, + attention_mask: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + inv_freqs: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parameters: + q_k_v (torch.Tensor): Query, key, and value tensors. Expected shape is: + [ + batch, + seq_len, + 2 * self._config.n_heads_kv + self._config.n_heads_q, + self._config.head_size + ]. + kv_cache (Optional[torch.Tensor]): Key and value cache tensor. Expected shape is + [ + 2, + batch, + kv_cache_len, + self._config.n_heads_kv, + self._config.head_size + ]. If None, cache is disabled. The `kv_cache_len` dimension does not need to + be contiguous (it should expand stride by `max_out_tokens`). + batch (RaggedBatchWrapper): Ragged batch metadata. + attention_mask (Optional[torch.Tensor]): Attention mask tensor. If None, masking is + disabled. This will defer to the config in the case of conflicting information. + This means if the config class is implying causal attention, the mask will be ignored. + attention_bias (Optional[torch.Tensor]): Attention bias tensor. If None, bias is disabled. + """ + raise NotImplementedError() + + +class DSSelfAttentionRegistry(DSModuleRegistryBase): + registry: Dict = {} + + @staticmethod + def associated_class() -> Type[DSModuleBase]: + return DSSelfAttentionBase diff --git a/deepspeed/inference/v2/modules/interfaces/embedding_base.py b/deepspeed/inference/v2/modules/interfaces/embedding_base.py new file mode 100644 index 000000000000..8078013e36b6 --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/embedding_base.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import abstractmethod +from typing import Any, Dict, Optional, Type + +import torch + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ...ragged import RaggedBatchWrapper +from ..ds_module import DSModuleBase +from ..module_registry import DSModuleRegistryBase +from ..configs import DSEmbeddingsConfig + + +class DSEmbeddingBase(DSModuleBase): + """ + Base mixin for embedding modules. The interface represented by this module is: + + hidden_out = embedding(input_ids) + + position_embedding(position_ids) + + token_type_embedding(token_type_ids) + with optional normalization. + """ + + @staticmethod + def config_class() -> Type[DeepSpeedConfigModel]: + return DSEmbeddingsConfig + + def __init__(self, config: DSEmbeddingsConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + def transform_param(self, embed_param: torch.Tensor) -> torch.Tensor: + """ + Perform any necessary transformations on an embedding parameter. This module assumes + that all embedding parameters would require the same set of transformations. + + Parameters: + embed_param (torch.Tensor): Embedding parameter. Shape is of [vocab_size, hidden_size] + """ + raise NotImplementedError() + + @property + @abstractmethod + def output(self) -> torch.Tensor: + """ + Pre-allocated output Tensor. This currently needs to be exposed for gather operations + on the output. + + TODO(cmikeh2): This is not ideal. We need a better abstraction for this, such as giving + access to the inference comm object to the DSModule. 
+ """ + raise NotImplementedError() + + def forward(self, + ragged_batch: RaggedBatchWrapper, + word_embeddings: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + token_type_embeddings: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Parameters: + ragged_batch (torch.Tensor): Ragged batch of token ids + associated metadata. + word_embeddings (torch.Tensor): Word embeddings. + position_embeddings (torch.Tensor): Position embeddings. If passed, IDs will be + inferred from the ragged batch itself. + token_type_ids (torch.Tensor): Token type ids. + token_type_embeddings (torch.Tensor): Token type embeddings. + + Returns: + torch.Tensor: Hidden states. This should be the sum of the relevant + encodings for the model. + """ + raise NotImplementedError() + + +class DSEmbeddingRegistry(DSModuleRegistryBase): + registry: Dict = {} + + @staticmethod + def associated_class() -> Type[DSModuleBase]: + return DSEmbeddingBase diff --git a/deepspeed/inference/v2/modules/interfaces/linear_base.py b/deepspeed/inference/v2/modules/interfaces/linear_base.py new file mode 100644 index 000000000000..bcaad6fe269a --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/linear_base.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import abstractmethod +from typing import Any, Dict, Optional, Type + +import torch + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ..ds_module import DSModuleBase +from ..module_registry import DSModuleRegistryBase +from ..configs import DSLinearConfig + + +class DSLinearBase(DSModuleBase): + """ + Base mixin for all Linear modules. The interface represented by this module + is: + + hidden_out = activation(hidden_in * weight + bias) + + The format and dtype of the weight and bias tensors are not defined and implementations + may compress as necessary. Must support a bias. + """ + + @staticmethod + def config_class() -> Type[DeepSpeedConfigModel]: + return DSLinearConfig + + def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + @abstractmethod + def transform_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Perform any necessary transformations of the parameters of this module. + + Parameters: + param (torch.Tensor): Weight or bias tensor. + """ + ... + + def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Parameters: + hidden_states (torch.Tensor): Hidden states tensor. Expected shape is either + [batch, seq_len, in_channels] or [batch, in_channels]. + + Returns: + torch.Tensor: Output tensor. Tensor should have same number of dimensions as + input tensor. + """ + raise NotImplementedError() + + @property + @abstractmethod + def output(self) -> torch.Tensor: + """ + Return the padded, pre-allocated output Tensor. + """ + ... + + +class DSLinearRegistry(DSModuleRegistryBase): + registry: Dict = {} + + @staticmethod + def associated_class() -> Type[DSModuleBase]: + return DSLinearBase diff --git a/deepspeed/inference/v2/modules/interfaces/moe_base.py b/deepspeed/inference/v2/modules/interfaces/moe_base.py new file mode 100644 index 000000000000..cc80ca55f60a --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/moe_base.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from abc import abstractmethod
+from typing import Any, Dict, Optional, Type
+
+import torch
+
+from deepspeed.runtime.config_utils import DeepSpeedConfigModel
+from ..ds_module import DSModuleBase
+from ..module_registry import DSModuleRegistryBase
+from ..configs import DSMoEConfig
+
+
+class DSMoEBase(DSModuleBase):
+    """
+    Base mixin for MoE modules. The interface represented by this module is:
+
+    expert_assignments = gate(hidden_states)
+    intermediate = ragged_linear(hidden_states, expert_assignments)
+    output = ragged_linear(intermediate, expert_assignments)
+    """
+
+    @staticmethod
+    def config_class() -> Type[DeepSpeedConfigModel]:
+        return DSMoEConfig
+
+    def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) -> None:
+        super().__init__(config, implementation_config)
+
+    @abstractmethod
+    def transform_gate_param(self, param: torch.Tensor) -> torch.Tensor:
+        """
+        Perform any necessary transformations of the gate parameter.
+
+        Args:
+            param (torch.Tensor): gate_w (shape: [num_experts, model_dim])
+        """
+        ...
+
+    @abstractmethod
+    def transform_moe_mlp_1_param(self, param: torch.Tensor) -> torch.Tensor:
+        """
+        Perform any necessary transformations of the parameter. The specific component
+        being transformed should be inferred from the shape of the parameter.
+
+        Args:
+            param (torch.Tensor): One of either mlp_1_w, mlp_1_b
+        """
+        ...
+
+    @abstractmethod
+    def transform_moe_mlp_2_param(self, param: torch.Tensor) -> torch.Tensor:
+        """
+        Perform any necessary transformations of the parameter. The specific component being
+        transformed should be inferred from the shape of the parameter. This interface is
+        separate from transform_moe_mlp_1_param because the two components may have identical
+        shapes.
+
+        Args:
+            param (torch.Tensor): One of either mlp_2_w or mlp_2_b
+        """
+        ...
+
+    def forward(self,
+                hidden_states: torch.Tensor,
+                gate_w: torch.Tensor,
+                mlp_1_w: torch.Tensor,
+                mlp_2_w: torch.Tensor,
+                mlp_1_b: Optional[torch.Tensor] = None,
+                mlp_2_b: Optional[torch.Tensor] = None) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def output(self) -> torch.Tensor:
+        """
+        Returns the pre-allocated, padded output Tensor.
+        """
+        ...
+
+
+class DSMoERegistry(DSModuleRegistryBase):
+    registry: Dict = {}
+
+    @staticmethod
+    def associated_class() -> Type[DSModuleBase]:
+        return DSMoEBase
diff --git a/deepspeed/inference/v2/modules/interfaces/post_norm_base.py b/deepspeed/inference/v2/modules/interfaces/post_norm_base.py
new file mode 100644
index 000000000000..c2a6bf69de8a
--- /dev/null
+++ b/deepspeed/inference/v2/modules/interfaces/post_norm_base.py
@@ -0,0 +1,68 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from abc import abstractmethod
+from typing import Any, Dict, Optional, Tuple, Type
+
+import torch
+
+from deepspeed.runtime.config_utils import DeepSpeedConfigModel
+from ..ds_module import DSModuleBase
+from ..configs.norm_config import DSNormConfig
+from ..module_registry import DSModuleRegistryBase
+
+
+class DSPostNormBase(DSModuleBase):
+    """
+    Base mixin for all Post-Normalization modules. The interface represented by this
+    module is:
+
+    residual, hidden_out = norm(residual + hidden_in)
+
+    If residual and hidden_out are the same data type, then they may alias each other.
+    Furthermore, residual should be updated in-place.
+ """ + + @staticmethod + def config_class() -> Type[DeepSpeedConfigModel]: + return DSNormConfig + + def __init__(self, config: DSNormConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + @abstractmethod + def transform_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Transform a gamma/beta parameter. It is assumed that both transformations are + the same. + + Parameters: + param (torch.Tensor): Gamma or beta parameter. + """ + ... + + def forward(self, + residual: torch.Tensor, + hidden_states: torch.Tensor, + gamma: torch.Tensor, + beta: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parameters: + residual (torch.Tensor): Residual tensor. + hidden_states (torch.Tensor): Hidden states tensor. + + Returns: + (torch.Tensor, torch.Tensor): Tuple of residual and hidden states. + Hidden states may alias with residual. + """ + raise NotImplementedError() + + +class DSPostNormRegistry(DSModuleRegistryBase): + registry: Dict = {} + + @staticmethod + def associated_class() -> Type[DSModuleBase]: + return DSPostNormBase diff --git a/deepspeed/inference/v2/modules/interfaces/pre_norm_base.py b/deepspeed/inference/v2/modules/interfaces/pre_norm_base.py new file mode 100644 index 000000000000..7d8b4ebf1587 --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/pre_norm_base.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import abstractmethod +from typing import Any, Dict, Optional, Tuple, Type + +import torch + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ..ds_module import DSModuleBase +from ..configs.norm_config import DSNormConfig +from ..module_registry import DSModuleRegistryBase + + +class DSPreNormBase(DSModuleBase): + """ + Base mixin for all Pre-Normalization modules. The interface represented by this module + is: + + if hidden_in is not None: + residual_out = residual + hidden_in + else: + residual_out = residual + + hidden_out = normalize(residual_out) + return residual_out, hidden_out + + Residual should be updated in-place. + """ + + @staticmethod + def config_class() -> Type[DeepSpeedConfigModel]: + return DSNormConfig + + def __init__(self, config: DSNormConfig, implementation_config: Dict[str, Any]): + super().__init__(config, implementation_config) + + @abstractmethod + def transform_param(self, param: torch.Tensor) -> torch.Tensor: + """ + Transform a gamma/beta parameter. It is assumed that both transformations are + the same. + + Parameters: + param (torch.Tensor): Gamma or beta parameter. + """ + ... + + def forward(self, + residual: torch.Tensor, + hidden_states: Optional[torch.Tensor], + gamma: torch.Tensor, + beta: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parameters: + residual (torch.Tensor): Residual tensor. + hidden_states (torch.Tensor): Hidden states tensor. + + Returns: + (torch.Tensor, torch.Tensor): Tuple of residual and hidden states. 
+ """ + raise NotImplementedError() + + +class DSPreNormRegistry(DSModuleRegistryBase): + registry: Dict = {} + + @staticmethod + def associated_class() -> Type[DSModuleBase]: + return DSPreNormBase diff --git a/deepspeed/inference/v2/modules/interfaces/unembed_base.py b/deepspeed/inference/v2/modules/interfaces/unembed_base.py new file mode 100644 index 000000000000..9eca6fcde768 --- /dev/null +++ b/deepspeed/inference/v2/modules/interfaces/unembed_base.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional, Type + +import torch + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ...ragged import RaggedBatchWrapper +from ..ds_module import DSModuleBase +from ..module_registry import DSModuleRegistryBase +from ..configs import DSUnembedConfig + + +class DSUnembedBase(DSModuleBase): + """ + Base mixin for unmebedding modules. The interface represented by this module is: + + if config.do_normalization + hidden = layer_norm(hidden) + logits = hidden @ projection + """ + + @staticmethod + def config_class() -> Type[DeepSpeedConfigModel]: + return DSUnembedConfig + + def __init__(self, config: DSUnembedConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + def forward(self, + hidden_states: torch.Tensor, + vocab_embedding: torch.Tensor, + ragged_metadata: RaggedBatchWrapper, + gamma: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Forward interface. Gamma and beta are optional parameters passed depending on + `self.config.do_normalization`. + + Args: + hidden_states (torch.Tensor): Hidden states of shape [tokens, model_dim] + vocab_embedding (torch.Tensor): Embedding matrix of shape [vocab_size, model_dim] + ragged_metadata (RaggedBatchWrapper): Metadata for the ragged batch. + gamma (Optional[torch.Tensor]): Gamma parameter for layer norm. + beta (Optional[torch.Tensor]): Beta parameter for layer norm. + + Returns: + torch.Tensor: Unembedded hidden states of shape [n_seqs, model_dim] + """ + raise NotImplementedError() + + +class DSUnembedRegistry(DSModuleRegistryBase): + registry: Dict = {} + + @staticmethod + def associated_class() -> Type[DSModuleBase]: + return DSUnembedBase diff --git a/deepspeed/inference/v2/modules/module_registry.py b/deepspeed/inference/v2/modules/module_registry.py new file mode 100644 index 000000000000..e04b8d734518 --- /dev/null +++ b/deepspeed/inference/v2/modules/module_registry.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from abc import ABC, abstractstaticmethod +from typing import Any, Dict, Type + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from .ds_module import DSModuleBase + + +class ConfigBundle(DeepSpeedConfigModel): + """ + A config bundle is a collection of configs that are used to instantiate a model implementation. + """ + name: str + config: DeepSpeedConfigModel + implementation_config: Dict[str, Any] = {} + + +class DSModuleRegistryBase(ABC): + """ + Class holding logic for tracking the DSModule implementations of a given interface. 
+ """ + + @classmethod + def instantiate_config(cls, config_bundle: ConfigBundle) -> DSModuleBase: + """ + Given a DSModule key, attempt to instantiate + """ + if config_bundle.name not in cls.registry: + raise KeyError(f"Unknown DSModule: {config_bundle.name}, cls.registry={cls.registry}") + + target_implementation = cls.registry[config_bundle.name] + if not target_implementation.supports_config(config_bundle.config): + raise ValueError(f"Config {config_bundle.config} is not supported by {target_implementation}") + + return cls.registry[config_bundle.name](config_bundle.config, config_bundle.implementation_config) + + @abstractstaticmethod + def associated_class() -> Type[DSModuleBase]: + """ + Return the class associated with this registry. + """ + raise NotImplementedError("Must associated a DSModule class with its registry.") + + @classmethod + def register_module(cls, child_class: DSModuleBase) -> None: + """ + Register a module with this registry. + """ + if not issubclass(child_class, cls.associated_class()): + raise TypeError( + f"Can only register subclasses of {cls.associated_class()}, {child_class} does not inherit from {cls.associated_class()}" + ) + cls.registry[child_class.name()] = child_class + return child_class diff --git a/deepspeed/inference/v2/ragged/__init__.py b/deepspeed/inference/v2/ragged/__init__.py new file mode 100644 index 000000000000..3af09cff4be5 --- /dev/null +++ b/deepspeed/inference/v2/ragged/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .kv_cache import split_kv +from .manager_configs import ( + AllocationMode, + DSStateManagerConfig, + KVCacheConfig, + MemoryConfig, +) +from .ragged_manager import DSStateManager +from .ragged_wrapper import RaggedBatchWrapper +from .sequence_descriptor import DSSequenceDescriptor, PlaceholderSequenceDescriptor diff --git a/deepspeed/inference/v2/ragged/blocked_allocator.py b/deepspeed/inference/v2/ragged/blocked_allocator.py new file mode 100644 index 000000000000..7884d8cccb47 --- /dev/null +++ b/deepspeed/inference/v2/ragged/blocked_allocator.py @@ -0,0 +1,105 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Union + +import torch + + +class BlockedAllocator: + """ + Allocator class for managing which blocks are free/used in the + blocked KV-cache. This is a simple allocator that uses a linked list + to keep track of which blocks are free/used. The cost of allocation/deallocation + is O(blocks), where blocks is the number of blocks to allocate/deallocate. + + TODO(cmikeh2): Evaluate performance of this allocator and migrate + to C++ if necessary. + """ + # Number of blocks in the KV-cache(s). + _num_blocks: int + + # Array of blocks, where each element is the next block in the linked list. + _blocks: torch.Tensor + + # Index of the head of the linked list. + _head: int + + # Number of free blocks in the KV-cache. + _free_blocks: int + + def __init__(self, num_blocks: int) -> None: + """ + Initialize an allocator with `num_blocks` blocks. This requires at least + `num_blocks` * 4 bytes of host memory. + + Parameters: + num_blocks (int): The number of blocks to allocate. 
+ """ + + if num_blocks < 1: + raise ValueError(f'Blocked KV-cache must have at least 1 block, provided {num_blocks}') + + self._num_blocks = num_blocks + self._blocks = torch.arange(1, num_blocks + 1, dtype=torch.int32, device='cpu', pin_memory=True) + self._head = 0 + self._free_blocks = num_blocks + + def allocate(self, num_blocks: int) -> torch.Tensor: + """ + Allocate a list of blocks from the associated KV-caches. This will + return `num_blocks` blocks from the KV-cache if they are available, + or raise an exception if there are not enough free blocks. + + Parameters: + num_blocks (int): The number of blocks to allocate. + + Returns: + List[int]: The list of blocks allocated. + """ + if num_blocks > self._free_blocks: + raise ValueError(f'Not enough free blocks in the KV-cache to allocate {num_blocks} blocks') + + allocated_blocks = torch.zeros(num_blocks, dtype=torch.int32) + for i in range(num_blocks): + allocated_blocks[i] = self._head + self._head = self._blocks[self._head].item() + self._blocks[allocated_blocks[i]] = -1 # Mark as used + self._free_blocks -= 1 + + return allocated_blocks + + def free(self, blocks: Union[Iterable[int], int]) -> None: + """ + Return a list of blocks to the free pool. If a single invalid block is provided (i.e., + one that is out of range of the allocator or is already free), then an exception is raised + and no blocks are freed. + + Parameters: + blocks (Union[Iterable[int], int]): The list of blocks to free. If only one block + is to be freed, this can be alone as an integer. + """ + if isinstance(blocks, int): + blocks = [blocks] + + for block in blocks: + # Parse all blocks for validity before mutating the list. + if block < 0 or block >= self._num_blocks: + raise ValueError(f'Invalid block {block} provided to free') + + if self._blocks[block] != -1: + raise ValueError(f'Block {block} is already free') + + for block in blocks: + self._blocks[block] = self._head + self._head = block + self._free_blocks += 1 + + @property + def free_blocks(self) -> int: + """ + Return the number of free blocks in the KV-cache. + """ + return self._free_blocks diff --git a/deepspeed/inference/v2/ragged/csrc/fast_host_buffer.cu b/deepspeed/inference/v2/ragged/csrc/fast_host_buffer.cu new file mode 100644 index 000000000000..31347636b50c --- /dev/null +++ b/deepspeed/inference/v2/ragged/csrc/fast_host_buffer.cu @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "ds_kernel_utils.h" +#include "fast_host_buffer.h" + +void* get_cuda_fast_buffer(int64_t size) +{ + void* buffer_ptr; + // Host allocation flags that should minimize the host -> accelerator copy latency + unsigned int alloc_flags = + cudaHostAllocPortable | cudaHostAllocMapped | cudaHostAllocWriteCombined; + + cudaHostAlloc(&buffer_ptr, size, alloc_flags); + return buffer_ptr; +} diff --git a/deepspeed/inference/v2/ragged/csrc/ragged_ops.cpp b/deepspeed/inference/v2/ragged/csrc/ragged_ops.cpp new file mode 100644 index 000000000000..0c5e8812c84c --- /dev/null +++ b/deepspeed/inference/v2/ragged/csrc/ragged_ops.cpp @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include + +#include "fast_host_buffer.h" + +/* +Similar to doing an empty_like to replicate a Tensor on the host, but will +attempt to optimize for faster host -> accelerator copies. Since this is on the critical +path for the forward pass, this should directly improve performance. 
+Allocates the shadow buffers for the input_ids, batch, seq and kv_ids tensors. + +Arguments: + device_mirror: A tensor on the accelerator that should be mirrored by the host. + +Returns: + A tensor on the host of the same size and datatype optimized for fast host -> accelerator +copies. +*/ +torch::Tensor allocate_fast_host_buffer(torch::Tensor device_mirror) +{ +#ifdef __HIP_PLATFORM_HCC__ + auto options = + torch::TensorOptions().device(torch::kCPU).pinned_memory(true).dtype(device_mirror.dtype()); + auto buffer = torch::empty(device_mirror.sizes(), options); +#else + + void* buffer_ptr = get_cuda_fast_buffer(device_mirror.numel() * device_mirror.element_size()); + + auto options = torch::TensorOptions().device(torch::kCPU).dtype(device_mirror.dtype()); + auto buffer = torch::from_blob(buffer_ptr, device_mirror.sizes(), options); +#endif + return buffer; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("allocate_fast_host_buffer", + &allocate_fast_host_buffer, + "Allocate a host mirror of an accelerator Tensor."); +} diff --git a/deepspeed/inference/v2/ragged/includes/fast_host_buffer.h b/deepspeed/inference/v2/ragged/includes/fast_host_buffer.h new file mode 100644 index 000000000000..81f24ed8fdaa --- /dev/null +++ b/deepspeed/inference/v2/ragged/includes/fast_host_buffer.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" + +/* +Wrapper around cudaHostAlloc with some specific flags. Returns a pointer to the +memory region of `size` bytes. +*/ +void* get_cuda_fast_buffer(int64_t size); diff --git a/deepspeed/inference/v2/ragged/kv_cache.py b/deepspeed/inference/v2/ragged/kv_cache.py new file mode 100644 index 000000000000..d743d6e92530 --- /dev/null +++ b/deepspeed/inference/v2/ragged/kv_cache.py @@ -0,0 +1,165 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import operator +from functools import reduce +from typing import Any, Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist +from deepspeed.comm.reduce_op import ReduceOp + +from deepspeed.accelerator import get_accelerator +from ..inference_utils import elem_size +from ..logging import inference_logger +from .blocked_allocator import BlockedAllocator +from .manager_configs import AllocationMode, KVCacheConfig, MemoryConfig + + +def split_kv(kv_cache: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Split a KV cache instance into its key and value components. + + Parameters: + kv_cache (torch.Tensor): The KV-cache to split. This should be a 5D tensor with the + following shape: [num_blocks, block_size, 2, num_heads, head_size] + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The key and value components of the KV-cache. Both + tensors will have the shape [num_blocks, block_size, num_heads, head_size]. + """ + if kv_cache.ndim != 5: + raise ValueError(f"KV-cache must have 5 dimensions, got {kv_cache.ndim}.") + + return kv_cache[:, :, 0, :, :], kv_cache[:, :, 1, :, :] + + +class BlockedKVCache: + + _caches: torch.Tensor + """ + Backing storage for all KV caches. This is a 6D tensor with the following shape: + (num_caches, num_blocks, block_size, 2, num_heads, head_size) + """ + + _allocator: BlockedAllocator + """ + Block allocator for tracking cache usage. This manages the GPU cache. + """ + + _config: KVCacheConfig + """ + Configuration of the KV cache. See ``KVCacheConfig`` for more details. 
+ """ + + def __init__(self, + config: KVCacheConfig, + memory_config: MemoryConfig, + mp_group: Optional[Any] = None, + offload: bool = False) -> None: + """ + Create a container that will maintain the storage and allocations for a set of + blocked KV-caches. + + Parameters: + config (KVCacheConfig): The configuration of the KV-cache. + slack (int): The amount of slack space to reserve in GPU memory for the cache. + enable_offload (bool): Whether to enable offloading of the cache to the host. + blocks (int): The number of blocks to pre-allocate for the cache. If this is set, + slack will be ignored. + """ + self._config = config + self._memory_config = memory_config + self._enable_offload = offload + + if self._enable_offload: + raise NotImplementedError("Offloading of KV-caches is not yet supported.") + + if AllocationMode(self._memory_config.mode) is AllocationMode.RESERVE: + per_block_footprint = reduce(operator.mul, self._config.cache_shape, self._config.block_size) + per_block_footprint *= 2 # for key and value + per_block_footprint *= elem_size(self._config.cache_dtype) + + # Perform a dummy nccl call before calculating available memory, on some systems (H100) we've observed higher memory allocations from NCCL + if dist.get_world_size(group=mp_group) > 1: + dummy_tensor = torch.tensor(0, dtype=torch.int32, device=get_accelerator().current_device()) + dist.all_reduce(dummy_tensor, op=ReduceOp.MIN, group=mp_group) + + get_accelerator().empty_cache() + available_kv_memory = get_accelerator().available_memory() - self._memory_config.size + total_memory = get_accelerator().total_memory() + + inference_logger().debug( + f"Memory usage before KV-cache allocation: total_memory={total_memory}, available_kv_memory={available_kv_memory}, per_block_footprint={per_block_footprint}" + ) + + if available_kv_memory < per_block_footprint: + raise ValueError( + f"Insufficient memory to allocate KV-caches. Required: {per_block_footprint}, Available: {available_kv_memory}" + ) + + num_blocks = available_kv_memory // per_block_footprint + + # In a multi-process setting, we need to ensure that all processes have the same + # KV cache capacity to ensure scheduling guarantees are equivalent on all ranks. + if dist.get_world_size(group=mp_group) > 1: + reduce_tensor = torch.tensor(num_blocks, dtype=torch.int32, device=get_accelerator().current_device()) + dist.all_reduce(reduce_tensor, op=ReduceOp.MIN, group=mp_group) + num_blocks = reduce_tensor.item() + + # This is ugly but don't want the fragmentation of the 8 byte Tensor maybe + # hanging around. + del reduce_tensor + get_accelerator().empty_cache() + else: # AllocationMode.ALLOCATE + num_blocks = self._memory_config.size + + num_caches = self._config.cache_shape[0] + num_heads = self._config.cache_shape[1] + head_size = self._config.cache_shape[2] + + alloc_shape = (num_caches, num_blocks, self._config.block_size, 2, num_heads, head_size) + inference_logger().info(f"Allocating KV-cache with shape: {alloc_shape} consisting of {num_blocks} blocks.") + self._caches = torch.empty(alloc_shape, + dtype=self._config.cache_dtype, + device=get_accelerator().current_device()) + self._allocator = BlockedAllocator(num_blocks) + + def reserve(self, num_blocks: int) -> torch.Tensor: + """ + Reserve a number of blocks from the cache. This will return a 1D tensor of + block_ids that have been marked as reserved. + """ + return self._allocator.allocate(num_blocks) + + def free(self, blocks: Iterable[int]) -> None: + """ + Free a set of blocks from the cache. 
This will mark the blocks as free in the + allocator. + """ + self._allocator.free(blocks) + + def offload(self, blocks: Iterable[int]) -> torch.Tensor: + """ + Offload KV-cache blocks from accelerator memory to the host. + """ + raise NotImplementedError("Offloading is not yet supported.") + + def restore(self, blocks: Iterable[int]) -> torch.Tensor: + """ + Restore KV-cache blocks from the host to accelerator memory. + """ + raise NotImplementedError("Offloading is not yet supported.") + + def get_cache(self, cache_id: int) -> torch.Tensor: + """ + Get the tensor associated with the given cache ID. + """ + return self._caches[cache_id] + + @property + def free_blocks(self): + return self._allocator.free_blocks diff --git a/deepspeed/inference/v2/ragged/manager_configs.py b/deepspeed/inference/v2/ragged/manager_configs.py new file mode 100644 index 000000000000..a35c1cfd2369 --- /dev/null +++ b/deepspeed/inference/v2/ragged/manager_configs.py @@ -0,0 +1,164 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from enum import Enum +from typing import Tuple + +from deepspeed.pydantic_v1 import PositiveInt, validator + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from ..inference_utils import DtypeEnum + + +class KVCacheConfig(DeepSpeedConfigModel): + + block_size: int = 128 + """ + Number of tokens that may be contained in each cache block. + """ + + num_allocation_groups: PositiveInt = 1 + """ + Allocation groups are assumed to be able to use the same allocation block size because + the allocation granularity is the same but the number of blocks required in each group + may differ. + + As a concrete example, consider a model with alternating layers of local and global + attention (such as GPTNeo). The local attention layers do not require the same number + of cache blocks as the global layer. However, a static partitioning scheme is sub-optimal since the ratio of local to global KV-cache blocks is not constant across + the range of sequence lengths that may be encountered. + + NOTE: In theory, this functionality could be used to do per-head and per-layer + KV-cache allocation, but it is likely the allocator will struggle with managing that + many blocks. + + NOTE: This will need to be primarily understood and handled by the model implementation + itself, rather than the KV cache manager. However, I'd like to make this explicit. + """ + + cache_shape: Tuple[PositiveInt, PositiveInt, PositiveInt] + """ + The shape of the cache per token. The first dimension is the number of individual + caches, the second is the number of heads, and the third is the head size. The number + of caches argument here is per allocation group. + """ + + cache_dtype: DtypeEnum = DtypeEnum.fp16 + """ + Data type of the KV-cache. + """ + + max_blocks_per_allocation_group: PositiveInt = 64 + """ + Maximum number of blocks that can be associated with an allocation group. + """ + + +""" +The config above is a little confusing so let's use a couple of concrete examples of +usage: + +Model 1: Llama-13B with a block size of 256 + +Llama is uniform attention so we have a single allocation group. The cache shape is +(40 layers, 40 heads, 128 head size) + +```python +llama_kv_config = KVCacheConfig(block_size=256, + num_allocation_groups=1, + cache_shape=(40, 40, 128)) +``` + +Model 2: GPTNeo-2.7B with a block size of 128 + +GPTNeo has alternating local and global attention layers. We have two allocation groups. 
+There are 16 layers of each type with 20 heads apiece at 128 head size. + +```python +gptneo_kv_config = KVCacheConfig(num_allocation_groups=2, cache_shape=(16, 20, 128)) +``` +""" + + +class AllocationMode(Enum): + """ + Helper class to describe memory allocation strategies for the KV-cache. + """ + + RESERVE = "reserve" + """ + Reserve a small amount of memory for non-KV cache allocations. + """ + + ALLOCATE = "allocate" + """ + Allocate an explicit number of KV blocks. + """ + + +class MemoryConfig(DeepSpeedConfigModel): + + mode: AllocationMode = AllocationMode.RESERVE + + size: PositiveInt = 1_000_000_000 + """ + Parameter for each of the modes. + + If mode is RESERVE, this is the amount of memory in bytes to reserve after allocating the + KV-cache. If in a tensor-parallel regime, this amount is guaranteed to be reserved on + all devices. + + If mode is ALLOCATE, this is the number of blocks to allocate for the KV-cache. This may + require tuning for model/GPU setups. + """ + + +class DSStateManagerConfig(DeepSpeedConfigModel): + + max_tracked_sequences: PositiveInt = 2048 + """ + How many sequences this engine will track simultaneously. This limit should be greater + than the ``max_ragged_sequence_count``. + """ + + max_ragged_batch_size: PositiveInt = 768 + """ + The maximum number of tokens that can be contained in a single ragged batch. Passing + a larger value than this will raise an exception that must be handled by the runtime. + """ + + max_ragged_sequence_count: PositiveInt = 512 + """ + The maximum number of sequences that can compose a batch. This limitation is only + relevant under CUDA graphing scenarios currently, where the maximum number of blocks + is largely bound by the total number of sequences in the ragged batch. This number cannot + be larger than ``max_tracked_sequences`` or ``max_ragged_batch_size``. + """ + + max_context: PositiveInt = 8192 + """ + The maximum number of tokens (inclusive of generation) that can be contained in a single + sequence. Currently used to bound the size of the KV cache metadata. + """ + + memory_config: MemoryConfig = MemoryConfig() + """ + Directive for how to manage the creation of the KV-cache. See MemoryConfig for more + details. + """ + + offload: bool = False + """ + Enable tracking for offloading KV-cache to host memory. Currently unsupported. + """ + + @validator("max_ragged_sequence_count") + def max_ragged_sequence_count_validator(cls, v: int, values: dict): + # If the attributes below failed their validation they won't appear in the values dict. + if "max_tracked_sequences" in values and v > values["max_tracked_sequences"]: + raise ValueError("max_ragged_sequence_count must be less than max_tracked_sequences") + if "max_ragged_batch_size" in values and v > values["max_ragged_batch_size"]: + raise ValueError("max_ragged_sequence_count must be less than max_ragged_batch_size") + return v diff --git a/deepspeed/inference/v2/ragged/ragged_manager.py b/deepspeed/inference/v2/ragged/ragged_manager.py new file mode 100644 index 000000000000..202f007707c5 --- /dev/null +++ b/deepspeed/inference/v2/ragged/ragged_manager.py @@ -0,0 +1,204 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from typing import Any, Dict, Optional, Tuple + +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import RaggedUtilsBuilder +from deepspeed.utils.logging import logger + +from .blocked_allocator import BlockedAllocator +from .kv_cache import BlockedKVCache +from .manager_configs import DSStateManagerConfig, KVCacheConfig +from .sequence_descriptor import DSSequenceDescriptor + + +class DSStateManager: + """ + Base abstract class for managing blocked KV caches. Will probably have a single + implementation for now. + """ + + _config: DSStateManagerConfig + """ + Config for state management. See DSStateManagerConfig for more details. The arguments here + should come from the engine config. + """ + + _kv_config: KVCacheConfig + """ + Config for the KV cache. See KVCacheConfig for more details. These arguments should derive + from the model implementation. + """ + + _kv_cache: BlockedKVCache + """ + Persistent KV cache store. + """ + + # Container for tracking all sequences in the system. + _seqs: Dict[int, DSSequenceDescriptor] + """ + Container for tracking all sequences in the system. + + TODO(cmikeh2): Evaluate if this has any performance implications. + """ + + # Allocator for tracking sequences. + _tracking_allocator: BlockedAllocator + _all_block_ids: torch.Tensor + _all_block_ids_shadow: torch.Tensor + + # TODO(cmikeh2): This interface both needs to be more flexible and more concrete. + def __init__(self, + config: DSStateManagerConfig, + kv_config: KVCacheConfig, + base_mp_group: Optional[Any] = None) -> None: + """ + The key + + Parameters: + block_size (int): The number of tokens to allocate in each block. + """ + self._config = config + self._kv_config = kv_config + + # Load our helpers for host allocation. + self._ragged_utils = RaggedUtilsBuilder().load() + + # Initialize the allocator for tracking sequences (so this doesn't need to be ad-hoc). + self._tracking_allocator = BlockedAllocator(self._config.max_tracked_sequences) + + # Storage to back tracking the KV cache allocation. + ids_shape = ( + self._config.max_tracked_sequences, + self._kv_config.num_allocation_groups, + self._kv_config.max_blocks_per_allocation_group, + ) + self._all_block_ids = torch.zeros(ids_shape, dtype=torch.int32, device=get_accelerator().current_device()) + self._all_block_ids_shadow = self._ragged_utils.allocate_fast_host_buffer(self._all_block_ids) + + # Initialize the sequence container. + self._seqs = {} + + # Finally initialize the KV cache. + self._kv_cache = BlockedKVCache(self._kv_config, + self._config.memory_config, + mp_group=base_mp_group, + offload=self._config.offload) + + def get_cache(self, cache_id: int) -> torch.Tensor: + """ + Return the Tensor associated with the given cache id. + """ + return self._kv_cache.get_cache(cache_id) + + def query(self, uid: Optional[int] = None) -> Tuple[int, int, Optional[int]]: + """ + Query the state of the KV cache for occupancy. + + Parameters: + seq_id (Optional[int]): The sequence id to query. If None, the last + return value will be None. + + Returns: + Tuple[int, int, Optional[Tuple[int, int]]: A tuple of the block size, the number of + free blocks, and the number of cached tokens for the given sequence. 
+ """ + if uid is not None: + cached_toks = self._seqs[uid].cached_tokens + free_toks = cached_toks % self._block_size + return (self._block_size, self._kv_cache.free_blocks, free_toks) + else: + return (self._block_size, self._kv_cache.free_blocks, None) + + def flush_sequence(self, uid: int) -> None: + """ + Free all resources associated with the given sequence id. + """ + if uid not in self._seqs: + logger.warning(f"Attempting to flush sequence {uid} which does not exist.") + return + + seq = self._seqs[uid] + self._kv_cache.free(seq.all_block_ids) + self._tracking_allocator.free(seq.tracking_id) + del self._seqs[uid] + + def get_sequence(self, uid: int) -> Optional[DSSequenceDescriptor]: + """ + Get the sequence descriptor for the given sequence id. If the sequence does not exist, + then None is returned. + """ + if uid not in self._seqs: + return None + + return self._seqs[uid] + + def get_or_create_sequence(self, uid: int) -> DSSequenceDescriptor: + """ + Get the existing sequence descriptor for a given uid or initialize one if + it does not exist. NOTE: This will always return a valid sequence descriptor + if one may be allocated and should not be used from APIs that are attempting + to test the schedulability of a hypothetical batch. + """ + if uid in self._seqs: + return self._seqs[uid] + else: + return self._create_sequence(uid) + + def _create_sequence(self, uid: int) -> DSSequenceDescriptor: + """ + Create a new sequence descriptor for the given sequence id. + """ + if uid in self._seqs: + raise ValueError(f"Sequence {uid} already exists.") + + try: + tracking_slot = self._tracking_allocator.allocate(1).item() + except ValueError: + raise RuntimeError( + f"Unable to create tracking slot for sequence {uid} since the metadata buffers are full.") + + seq_block_ids = self._all_block_ids[tracking_slot] + seq_block_ids_shadow = self._all_block_ids_shadow[tracking_slot] + self._seqs[uid] = DSSequenceDescriptor(tracking_slot, + seq_block_ids, + seq_block_ids_shadow, + max_context=self._config.max_context) + logger.debug(f"Created sequence {uid} with tracking slot {tracking_slot}.") + return self._seqs[uid] + + @property + def tracked_sequences(self) -> Dict[int, DSSequenceDescriptor]: + """ + Return the tracked sequences. + """ + return self._seqs + + @property + def n_tracked_sequences(self) -> int: + """ + Return the number of sequences currently tracked. + """ + return len(self._seqs) + + @property + def kv_block_size(self) -> int: + """ + Return the block size of the KV cache. + """ + return self._kv_config.block_size + + @property + def free_blocks(self) -> int: + """ + Return the number of free blocks in the KV cache. + """ + return self._kv_cache.free_blocks + + def allocate_blocks(self, n_blocks: int) -> torch.Tensor: + return self._kv_cache.reserve(n_blocks) diff --git a/deepspeed/inference/v2/ragged/ragged_wrapper.py b/deepspeed/inference/v2/ragged/ragged_wrapper.py new file mode 100644 index 000000000000..15abd1987aa1 --- /dev/null +++ b/deepspeed/inference/v2/ragged/ragged_wrapper.py @@ -0,0 +1,267 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import RaggedUtilsBuilder + +from .sequence_descriptor import DSSequenceDescriptor +from .manager_configs import DSStateManagerConfig + + +def to_padded(original_size: int) -> int: + """ + Pad to a backend friendly granularity. 
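+
+    A couple of worked values under the current placeholder granularities (64 for
+    requests of at most 512 tokens, 128 above that)::
+
+        to_padded(37)   # -> 64
+        to_padded(600)  # -> 640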
+    """
+
+    def _pad_to_mul_of_pow2(val: int, pow_2_val: int) -> int:
+        return val + (pow_2_val - 1) & ~(pow_2_val - 1)
+
+    # TODO(cmikeh2): Tune this approach. This is mainly a placeholder right now.
+    granularity = 64 if original_size <= 512 else 128
+
+    return _pad_to_mul_of_pow2(original_size, granularity)
+
+
+class RaggedBatchWrapper:
+    """
+    Container for all the auxiliary Tensors used in the management of a ragged batch.
+
+    For each Tensor, we maintain a shadow Tensor on the host. This Tensor is what is
+    directly populated when constructing the ragged batch. The shadow Tensors, when possible,
+    should be allocated so as to support fast host-to-accelerator copies.
+    """
+
+    # Tensors to populate the ragged batch into.
+    _input_ids_shadow: torch.Tensor
+    _input_ids: torch.Tensor
+    """
+    Forward pass input buffer.
+    """
+
+    _batch_metadata_storage: torch.Tensor
+    _batch_metadata_storage_shadow: torch.Tensor
+    """
+    Holds the number of inflight sequences and tokens for the ragged batch.
+    """
+
+    _token_to_seq_storage: torch.Tensor
+    _token_to_seq_storage_shadow: torch.Tensor
+    """
+    Linear mapping for each of the tokens. Let's say we have 8 tokens in the batch,
+    with the sequence breakdown being [4, 1, 3]. Then, the mapping would be:
+    [0, 0, 0, 0, 1, 2, 2, 2]
+    """
+
+    _inflight_seq_descriptors: torch.Tensor
+    _inflight_seq_descriptors_shadow: torch.Tensor
+    """
+    For each sequence in the batch, we store the index of its first token in the batch,
+    the number of tokens in the current pass, the number of tokens in the history of this
+    sequence, and an unused fourth value reserved for alignment.
+    For the above example this would give:
+    [[0, 4, H0, X], [4, 1, H1, X], [5, 3, H2, X]]
+    """
+
+    # Holds the block ids for each sequence in the ragged batch.
+    _kv_ptrs: torch.Tensor
+    _kv_ptrs_shadow: torch.Tensor
+    """
+    List of ptrs pointing to the GPU buffer that holds the KV-block ids for each sequence.
+    If there are multiple allocation groups associated with each of the sequences, then
+    accessing the Nth cache will require accessing the Nth block-id list behind that pointer.
+    """
+
+    def __init__(self, config: DSStateManagerConfig) -> None:
+        """
+        Convenience wrapper around the data structures used to represent a ragged
+        batch for inference. Only a single `RaggedBatchWrapper` should be used per
+        ragged inference engine.
+
+        The underlying data structures are implemented in `ragged_batch_descriptor.h`.
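+
+        A rough construction sketch (the sizes below are illustrative, not a
+        recommendation for any particular model)::
+
+            config = DSStateManagerConfig(max_ragged_batch_size=768,
+                                          max_ragged_sequence_count=512)
+            batch = RaggedBatchWrapper(config)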
+ """ + self._config = config + self._input_ids = torch.zeros((self._config.max_ragged_batch_size), + dtype=torch.int64, + device=get_accelerator().current_device()) + + self._batch_metadata_storage = torch.zeros(2, dtype=torch.int32, device=get_accelerator().current_device()) + + self._token_to_seq_storage = torch.zeros((self._config.max_ragged_batch_size), + dtype=torch.int32, + device=get_accelerator().current_device()) + self._inflight_seq_descriptors = torch.zeros((self._config.max_ragged_sequence_count, 4), + dtype=torch.int32, + device=get_accelerator().current_device()) + self._kv_ptrs = torch.zeros((self._config.max_ragged_sequence_count), + dtype=torch.int64, + device=get_accelerator().current_device()) + + self._utils_module = RaggedUtilsBuilder().load() + host_alloc = self._utils_module.allocate_fast_host_buffer + + self._input_ids_shadow = host_alloc(self._input_ids) + self._batch_metadata_storage_shadow = host_alloc(self._batch_metadata_storage) + self._token_to_seq_storage_shadow = host_alloc(self._token_to_seq_storage) + self._inflight_seq_descriptors_shadow = host_alloc(self._inflight_seq_descriptors) + self._kv_ptrs_shadow = host_alloc(self._kv_ptrs) + + # Default behavior should be no padding + self._is_padded = False + + def clear(self) -> None: + """ + Clear the ragged batch. This will reset the number of tokens and sequences to 0. + """ + self._batch_metadata_storage_shadow[0] = 0 + self._batch_metadata_storage_shadow[1] = 0 + + def insert_sequence(self, seq_descriptor: DSSequenceDescriptor, tokens: torch.Tensor, do_checks=True) -> None: + """ + Incrementally insert a sequence into the ragged batch. This will update the + metadata for the ragged batch and the sequence. + + Arguments: + seq_descriptor () + """ + if tokens.device != torch.device("cpu"): + # This doesn't really fall under schedulability, so we'll unconditionally check for it. + raise RuntimeError(f"Expected tokens to be on host but found device '{tokens.device}'") + + if do_checks and self.current_sequences == self._config.max_ragged_sequence_count: + raise RuntimeError(f"Ragged batch is full due to sequence limit: {self._config.max_ragged_sequence_count}") + + seq_tokens = tokens.numel() + + if do_checks and self.current_tokens + seq_tokens > self._config.max_ragged_batch_size: + raise RuntimeError(f"Ragged batch is full due to capacity limit: {self._config.max_ragged_batch_size})") + + self._input_ids_shadow[self.current_tokens:self.current_tokens + seq_tokens].copy_(tokens) + self._token_to_seq_storage_shadow[self.current_tokens:self.current_tokens + seq_tokens].fill_( + self.current_sequences) + + self._inflight_seq_descriptors_shadow[self.current_sequences][0] = self.current_tokens + self._inflight_seq_descriptors_shadow[self.current_sequences][1] = seq_tokens + self._inflight_seq_descriptors_shadow[self.current_sequences][2] = seq_descriptor.seen_tokens + + self._kv_ptrs_shadow[self.current_sequences] = seq_descriptor.kv_blocks_ptr + + self._batch_metadata_storage_shadow[0] += seq_tokens + self._batch_metadata_storage_shadow[1] += 1 + + @property + def tensor_toks(self) -> torch.Tensor: + """ + The number of tokens in the in-flight ragged batch. This will not trigger + synchronization with the device. + """ + cur_toks = self.current_tokens + if self._is_padded: + return to_padded(cur_toks) + else: + return cur_toks + + def finalize(self, padding: Optional[bool] = False) -> None: + """ + Completes construction of the ragged batch by flushing the host buffers to the device. 
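+
+        A hypothetical driver loop (the engine code that actually performs this lives
+        elsewhere, and ``requests``/``state_manager`` are placeholder names) would use
+        it roughly as::
+
+            batch.clear()
+            for uid, host_tokens in requests:   # host_tokens: 1D CPU tensor of token ids
+                seq = state_manager.get_or_create_sequence(uid)
+                batch.insert_sequence(seq, host_tokens)
+            batch.finalize(padding=False)       # flush shadow buffers to the device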
+ """ + cur_toks = self.current_tokens + + if padding: + padded_toks = to_padded(cur_toks) + self._input_ids_shadow[cur_toks:padded_toks].fill_(-1) + self._token_to_seq_storage_shadow[cur_toks:padded_toks].fill_(-1) + self._is_padded = True + else: + padded_toks = cur_toks + self._is_padded = False + + current_sequences = self.current_sequences + + def _noblock_copy(dst: torch.Tensor, src: torch.Tensor) -> None: + dst.copy_(src, non_blocking=True) + + _noblock_copy(self._input_ids[:padded_toks], self._input_ids_shadow[:padded_toks]) + _noblock_copy(self._batch_metadata_storage, self._batch_metadata_storage_shadow) + _noblock_copy(self._token_to_seq_storage[:padded_toks], self._token_to_seq_storage_shadow[:padded_toks]) + _noblock_copy(self._inflight_seq_descriptors[:current_sequences], + self._inflight_seq_descriptors_shadow[:current_sequences]) + _noblock_copy(self._kv_ptrs[:current_sequences], self._kv_ptrs_shadow[:current_sequences]) + + def input_ids(self, on_device: bool = True) -> torch.Tensor: + """ + The input ids tensor for the ragged batch. If the device Tensor is requested, the Tensor + is truncated to the number of tokens in the batch. + """ + if on_device: + return self._input_ids[:self.tensor_toks] + else: + return self._input_ids_shadow + + def batch_metadata_buffer(self, on_device: bool = True) -> torch.Tensor: + """ + Buffer associated with the batch metadata tensor that can + be populated in preparation for passing a new input to the device. + """ + if on_device: + return self._batch_metadata_storage + else: + return self._batch_metadata_storage_shadow + + def tokens_to_seq(self, on_device: bool = True) -> torch.Tensor: + """ + Mapping of token to which sequence it belongs to in the ragged batch. If the device Tensor + is requested, the Tensor is truncated to the number of tokens in the batch. + """ + if on_device: + return self._token_to_seq_storage[:self.tensor_toks] + else: + return self._token_to_seq_storage_shadow + + def inflight_seq_descriptors(self, on_device: bool = True) -> torch.Tensor: + """ + Buffer associated with the metadata of each sequence in the ragged batch. If the device Tensor + is requested, the Tensor is truncated to the number of sequences in the batch. + """ + if on_device: + return self._inflight_seq_descriptors[:self.current_sequences] + else: + return self._inflight_seq_descriptors_shadow + + def kv_ptrs(self, on_device: bool = True) -> torch.Tensor: + """ + Pointer to where the list of KV ids associated with a sequence are. If the device Tensor + is requested, the Tensor is truncated to the number of sequences in the batch. + """ + if on_device: + return self._kv_ptrs[:self.current_sequences] + else: + return self._kv_ptrs_shadow + + def masks(self, on_device: bool = True) -> Optional[torch.Tensor]: + """ + Placeholder for supporting complex masks. Currently not supported. + + Models that will need this will be BERT-like, not generative. + """ + return None + + @property + def current_tokens(self) -> int: + """ + The number of tokens in the in-flight ragged batch. This will not trigger + synchronization with the device. + """ + return self._batch_metadata_storage_shadow[0].item() + + @property + def current_sequences(self) -> int: + """ + The number of sequences in the in-flight ragged batch. This will not trigger + synchronization with the device. 
+ """ + return self._batch_metadata_storage_shadow[1].item() diff --git a/deepspeed/inference/v2/ragged/sequence_descriptor.py b/deepspeed/inference/v2/ragged/sequence_descriptor.py new file mode 100644 index 000000000000..b5cad9d2cd4d --- /dev/null +++ b/deepspeed/inference/v2/ragged/sequence_descriptor.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Union + +import torch + + +class BaseSequenceDescriptor: + + @property + def seen_tokens(self) -> int: + """ + The number of tokens for this sequence that have completed a forward pass. + """ + raise NotImplementedError() + + @property + def cur_allocated_blocks(self) -> int: + """ + The number of KV blocks currently allocated for this sequence. + """ + raise NotImplementedError() + + @property + def kv_blocks_ptr(self) -> int: + """ + The pointer to the KV blocks for this sequence. + """ + raise NotImplementedError() + + +class PlaceholderSequenceDescriptor(BaseSequenceDescriptor): + """ + The DummySequenceDescriptor is an empty object that allows us to perform schedulability + checks before formally tracking a sequence. + """ + + def __init__(self, seen_tokens=0, cur_allocated_blocks=0, kv_blocks_ptr=0) -> None: + self._seen_tokens = seen_tokens + self._cur_allocated_blocks = cur_allocated_blocks + self._kv_blocks_ptr = kv_blocks_ptr + + @property + def seen_tokens(self) -> int: + return self._seen_tokens + + @property + def cur_allocated_blocks(self) -> int: + return self._cur_allocated_blocks + + @property + def kv_blocks_ptr(self) -> int: + return self._kv_blocks_ptr + + +class DSSequenceDescriptor(BaseSequenceDescriptor): + + _seen_tokens: int + """ + Number of tokens in the sequence that have completed a forward pass. + """ + + _in_flight_tokens: int + """ + Number of tokens that have begun a forward pass but not yet completed it. + """ + + _max_context: int + """ + Maximum number of tokens this sequence may eventually include. Currently unused but + may be used in future implementations for speculative caching. + """ + + _num_allocation_groups: int + """ + Number of unique allocation groups associated with the sequence. + """ + + _blocks_per_allocation_group: torch.IntTensor + """ + Number of blocks allocated for each allocation group. + """ + + # Padded list of KV-cache IDs for the sequence. + _kv_cache_ids: torch.Tensor + _kv_cache_ids_shadow: torch.Tensor + """ + Padded list of KV-cache IDs for the sequence. The padded shape is [num_allocation_groups, max_blocks_per_allocation_group]. + """ + + # The location in the broader ID tensor where the KV-cache IDs for the sequence + # are stored. Used on flush. 
+ _tracking_id: int + + def __init__(self, + tracking_id: int, + kv_cache_ids: torch.Tensor, + kv_cache_ids_shadow: torch.Tensor, + max_context: int = -1) -> None: + self._tracking_id = tracking_id + self._kv_cache_ids = kv_cache_ids + self._kv_cache_ids_shadow = kv_cache_ids_shadow + self._max_context = max_context + + self._seen_tokens = 0 + self._in_flight_tokens = 0 + + self._num_allocation_groups = kv_cache_ids_shadow.shape[0] + self._blocks_per_allocation_group = torch.zeros(self._num_allocation_groups, dtype=torch.int32, device="cpu") + + assert kv_cache_ids.shape[0] == self._num_allocation_groups + assert len(kv_cache_ids.shape) == 2 + + @property + def seen_tokens(self) -> int: + return self._seen_tokens + + @property + def in_flight_tokens(self) -> int: + return self._in_flight_tokens + + @property + def max_context(self) -> int: + return self._max_context + + @property + def cur_allocated_blocks(self) -> int: + return self._blocks_per_allocation_group.sum() + + @property + def tracking_id(self) -> int: + return self._tracking_id + + def kv_cache_ids(self, on_device: bool = False) -> torch.Tensor: + """ + Returns the Tensor containing the block IDs for this sequence on the appropriate device. + """ + if on_device: + return self._kv_cache_ids + else: + return self._kv_cache_ids_shadow + + @property + def kv_blocks_ptr(self) -> int: + return self._kv_cache_ids.data_ptr() + + @property + def all_block_ids(self) -> torch.Tensor: + block_ids = [] + for allocation_group, num_blocks in zip(self._kv_cache_ids, self._blocks_per_allocation_group): + block_ids.append(allocation_group[:num_blocks]) + return torch.cat(block_ids) + + def pre_forward(self, num_tokens: int) -> None: + """ + Update the state of the sequence before a forward pass. + """ + self._in_flight_tokens = num_tokens + + def post_forward(self) -> None: + """ + Update the state of the sequence after a forward pass. + """ + self._seen_tokens += self._in_flight_tokens + self._in_flight_tokens = 0 + + def extend_kv_cache(self, new_ids: Union[List[torch.IntTensor], torch.IntTensor]) -> None: + """ + Extend the KV-cache for the sequence. + + Args: + new_ids (Union[List[torch.IntTensor], torch.IntTensor]): For each allocation group, the IDs + to add to the KV-cache. If there is only one allocation group, a single tensor can be + provided. Otherwise, a list of tensors should be provided. The tensors do not need + to have the same shape. + """ + if isinstance(new_ids, torch.Tensor): + new_ids = [new_ids] + + if len(new_ids) != self._num_allocation_groups: + raise ValueError(f"Only {len(new_ids)} allocation groups provided, expected {self._num_allocation_groups}") + + for group_id, new_group_ids in enumerate(new_ids): + new_blocks = new_group_ids.numel() + + if new_blocks == 0: + # If we have multiple groups, it's possible to have an empty group. + continue + + shadow_alloc_group = self._kv_cache_ids_shadow[group_id] + alloc_group = self._kv_cache_ids[group_id] + cur_blocks = self._blocks_per_allocation_group[group_id] + + shadow_alloc_group[cur_blocks:cur_blocks + new_blocks].copy_(new_group_ids) + alloc_group[cur_blocks:cur_blocks + new_blocks].copy_(shadow_alloc_group[cur_blocks:cur_blocks + + new_blocks], + non_blocking=True) + + self._blocks_per_allocation_group[group_id] += new_blocks + + def free_kv_cache(self, free_ids: Union[List[torch.IntTensor], torch.IntTensor]) -> None: + """ + Free blocks from the KV-cache for the sequence. 
+
+        Args:
+            free_ids (Union[List[torch.IntTensor], torch.IntTensor]): The ids of blocks to free
+                from the KV-cache. If there is only one allocation group, a single tensor can be
+                provided. Otherwise, a list of tensors should be provided. The tensors do not need
+                to have the same shape.
+        """
+        raise NotImplementedError("Partial KV-cache freeing is not yet supported.")
diff --git a/deepspeed/inference/v2/scheduling_utils.py b/deepspeed/inference/v2/scheduling_utils.py
new file mode 100644
index 000000000000..6d3818d46675
--- /dev/null
+++ b/deepspeed/inference/v2/scheduling_utils.py
@@ -0,0 +1,54 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from enum import Enum
+
+
+class SchedulingResult(Enum):
+
+    Success = 0
+    """
+    The proposed batch is valid and can be scheduled.
+    """
+
+    EngineSequenceLimitExceeded = 1
+    """
+    The proposed batch would overflow the number of concurrent sequences the engine may support.
+    """
+
+    BatchSequenceLimitExceeded = 2
+    """
+    The proposed batch contains more sequences than the engine was configured
+    to support in a single forward pass.
+    """
+
+    BatchTokenLimitExceeded = 3
+    """
+    The proposed batch contains more tokens than the engine was configured
+    to support in a single forward pass.
+    """
+
+    KVCacheLimitExceeded = 4
+    """
+    The proposed batch would require more KV cache to be allocated than the engine
+    currently has available.
+    """
+
+    SequenceTokenLimitExceeded = 5
+    """
+    The proposed batch contains a sequence that is longer than the engine/model can support.
+    """
+
+
+class SchedulingError(RuntimeError):
+
+    result: SchedulingResult
+    """
+    The failed result of the scheduling check. Guaranteed to not be SchedulingResult.Success.
+    """
+
+    def __init__(self, result: SchedulingResult) -> None:
+        self.result = result
+        super().__init__(f"Batch scheduling failed with result {result}")
diff --git a/docs/index.md b/docs/index.md
index fa99aa9ccaf3..60bcf19b84da 100755
--- a/docs/index.md
+++ b/docs/index.md
@@ -7,6 +7,7 @@ title: "Latest News"
---
 DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat).
+* [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen) * [2023/10] [DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md)] * [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](/deepspeed4science/)] [[White paper](https://arxiv.org/abs/2310.04610)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)] * [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md) diff --git a/op_builder/builder.py b/op_builder/builder.py index 62683774f20a..cf2d3311667e 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -70,7 +70,7 @@ def get_default_compute_capabilities(): "10.2", ], 11: ["11.0", "11.1", "11.2", "11.3", "11.4", "11.5", "11.6", "11.7", "11.8"], - 12: ["12.0", "12.1"], + 12: ["12.0", "12.1", "12.2", "12.3"], } @@ -102,6 +102,7 @@ def assert_no_cuda_mismatch(name=""): class OpBuilder(ABC): _rocm_version = None _is_rocm_pytorch = None + _loaded_ops = {} def __init__(self, name): self.name = name @@ -433,6 +434,9 @@ def builder(self): extra_link_args=self.strip_empty_entries(self.extra_ldflags())) def load(self, verbose=True): + if self.name in __class__._loaded_ops: + return __class__._loaded_ops[self.name] + from deepspeed.git_version_info import installed_ops, torch_info if installed_ops.get(self.name, False): # Ensure the op we're about to load was compiled with the same @@ -441,7 +445,9 @@ def load(self, verbose=True): if torch.cuda.is_available() and isinstance(self, CUDAOpBuilder): self.validate_torch_op_version(torch_info) - return importlib.import_module(self.absolute_name()) + op_module = importlib.import_module(self.absolute_name()) + __class__._loaded_ops[self.name] = op_module + return op_module else: return self.jit_load(verbose) @@ -456,11 +462,12 @@ def jit_load(self, verbose=True): raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.") if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch(): - try: - assert_no_cuda_mismatch(self.name) - self.build_for_cpu = False - except BaseException: - self.build_for_cpu = True + #TODO(jeff): need to come back and fix cpu-only builds, this came in on #3085 but is hiding real user env issues (eg. 
torch cuda != sys cuda) + #try: + assert_no_cuda_mismatch(self.name) + self.build_for_cpu = False + #except BaseException: + # self.build_for_cpu = True self.jit_mode = True from torch.utils.cpp_extension import load @@ -485,6 +492,8 @@ def jit_load(self, verbose=True): if not self.build_for_cpu and self.enable_bf16: cxx_args.append("-DBF16_AVAILABLE") nvcc_args.append("-DBF16_AVAILABLE") + nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__") + nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__") if self.is_rocm_pytorch(): cxx_args.append("-D__HIP_PLATFORM_AMD__=1") @@ -505,6 +514,8 @@ def jit_load(self, verbose=True): if torch_arch_list: os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list + __class__._loaded_ops[self.name] = op_module + return op_module diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py new file mode 100755 index 000000000000..b8ea54cd0b3f --- /dev/null +++ b/op_builder/inference_core_ops.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os + +from .builder import CUDAOpBuilder, installed_cuda_version + + +class InferenceCoreBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_INFERENCE_CORE_OPS" + NAME = "inference_core_ops" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.inference.v2.kernels{self.NAME}' + + def is_compatible(self, verbose=True): + try: + import torch + except ImportError: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + + cuda_okay = True + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 6: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in ccs: + if int(cc[0]) >= 6: + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def get_prefix(self): + ds_path = self.deepspeed_src_path("deepspeed") + return "deepspeed" if os.path.isdir(ds_path) else ".." 
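+
+    # Illustrative usage sketch (not part of the builder contract; assumes this class is
+    # exported from deepspeed.ops.op_builder like the other builders in this package):
+    #
+    #     from deepspeed.ops.op_builder import InferenceCoreBuilder
+    #     builder = InferenceCoreBuilder()
+    #     if builder.is_compatible():
+    #         core_ops = builder.load()  # JIT-builds on first call; cached in _loaded_ops afterwards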
+ + def sources(self): + sources = [ + "inference/v2/kernels/core_ops/core_ops.cpp", + "inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp", + "inference/v2/kernels/core_ops/bias_activations/bias_activation.cu", + "inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cpp", + "inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm.cu", + "inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cpp", + "inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm.cu", + "inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp", + "inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cu", + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + return sources + + def extra_ldflags(self): + return [] + + def include_paths(self): + sources = [ + 'inference/v2/kernels/core_ops/bias_activations', + 'inference/v2/kernels/core_ops/blas_kernels', + 'inference/v2/kernels/core_ops/cuda_layer_norm', + 'inference/v2/kernels/core_ops/cuda_rms_norm', + 'inference/v2/kernels/core_ops/gated_activations', + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + + sources.append('csrc/includes') + return sources diff --git a/op_builder/inference_cutlass_builder.py b/op_builder/inference_cutlass_builder.py new file mode 100644 index 000000000000..73400b2ce10e --- /dev/null +++ b/op_builder/inference_cutlass_builder.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import os + +from .builder import CUDAOpBuilder, installed_cuda_version + + +class InferenceCutlassBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_CUTLASS_OPS" + NAME = "cutlass_ops" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.inference.v2.kernels.cutlass_ops.{self.NAME}' + + def is_compatible(self, verbose=True): + try: + import torch + except ImportError: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + + cuda_okay = True + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 6: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in ccs: + if int(cc[0]) >= 8: + # Only support Ampere and newer + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def get_prefix(self): + ds_path = self.deepspeed_src_path("deepspeed") + return "deepspeed" if os.path.isdir(ds_path) else ".." 
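+
+    # Pre-compilation of these kernels is gated by BUILD_VAR above; an illustrative
+    # invocation (assuming the usual DS_BUILD_* environment-variable convention) would be:
+    #
+    #     DS_BUILD_OPS=1 DS_BUILD_CUTLASS_OPS=0 pip install .   # build all ops except these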
+ + def sources(self): + sources = [ + "inference/v2/kernels/cutlass_ops/cutlass_ops.cpp", + "inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu", + "inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu", + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + return sources + + def extra_ldflags(self): + import dskernels + lib_path = dskernels.library_path() + prefix = self.get_prefix() + lib_path = os.path.join(prefix, lib_path) + lib_path = self.deepspeed_src_path(lib_path) + + args = [f'-L{lib_path}', '-ldeepspeedft'] + if self.jit_load: + args.append(f'-Wl,-rpath,{lib_path}') + return args + + def include_paths(self): + sources = [ + 'inference/v2/kernels/cutlass_ops/mixed_gemm', + 'inference/v2/kernels/cutlass_ops/moe_gemm', + 'inference/v2/kernels/cutlass_ops/shared_resources/', + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + sources.append('csrc/includes') + return sources diff --git a/op_builder/ragged_ops.py b/op_builder/ragged_ops.py new file mode 100644 index 000000000000..505aaa9708cf --- /dev/null +++ b/op_builder/ragged_ops.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os + +from .builder import CUDAOpBuilder, installed_cuda_version + + +class RaggedOpsBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_RAGGED_DEVICE_OPS" + NAME = "ragged_device_ops" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.inference.v2.kernels.ragged_ops.{self.NAME}' + + def is_compatible(self, verbose=True): + try: + import torch + except ImportError: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + + cuda_okay = True + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 6: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in ccs: + if int(cc[0]) >= 8: + # Blocked flash has a dependency on Ampere + newer + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def get_prefix(self): + ds_path = self.deepspeed_src_path("deepspeed") + return "deepspeed" if os.path.isdir(ds_path) else ".." 
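+
+    # Note (illustrative): the blocked-flash kernels are shipped pre-built via the separate
+    # `deepspeed-kernels` dependency, which appears to provide the `dskernels` module used
+    # below, so extra_ldflags() resolves that library directory with `dskernels.library_path()`
+    # and links `-lblockedflash` rather than compiling flash attention from source here.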
+ + def sources(self): + sources = [ + "inference/v2/kernels/ragged_ops/ragged_ops.cpp", + "inference/v2/kernels/ragged_ops/atom_builder/atom_builder.cpp", + "inference/v2/kernels/ragged_ops/blocked_flash/blocked_flash.cpp", + "inference/v2/kernels/ragged_ops/embed/embed.cpp", + "inference/v2/kernels/ragged_ops/embed/embed.cu", + "inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp", + "inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu", + "inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cpp", + "inference/v2/kernels/ragged_ops/logits_gather/logits_gather.cu", + "inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp", + "inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu", + "inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp", + "inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu", + "inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.cpp", + "inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp", + "inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu", + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + return sources + + def extra_ldflags(self): + import dskernels + lib_path = dskernels.library_path() + + prefix = self.get_prefix() + lib_path = os.path.join(prefix, lib_path) + lib_path = self.deepspeed_src_path(lib_path) + + args = [f'-L{lib_path}', '-lblockedflash'] + if self.jit_load: + args.append(f'-Wl,-rpath,{lib_path}') + return args + + def include_paths(self): + sources = [ + 'inference/v2/kernels/ragged_ops', + 'inference/v2/kernels/ragged_ops/atom_builder', + 'inference/v2/kernels/ragged_ops/blocked_flash', + 'inference/v2/kernels/ragged_ops/embed', + 'inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary', + 'inference/v2/kernels/ragged_ops/logits_gather', + 'inference/v2/kernels/ragged_ops/moe_gather', + 'inference/v2/kernels/ragged_ops/moe_scatter', + 'inference/v2/kernels/ragged_ops/ragged_helpers', + 'inference/v2/kernels/ragged_ops/top_1_gating', + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + + sources.append('csrc/includes') + + return sources diff --git a/op_builder/ragged_utils.py b/op_builder/ragged_utils.py new file mode 100755 index 000000000000..c1d38eef1981 --- /dev/null +++ b/op_builder/ragged_utils.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os + +from .builder import CUDAOpBuilder, installed_cuda_version + + +class RaggedUtilsBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_RAGGED_OPS" + NAME = "ragged_ops" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.inference.v2.{self.NAME}' + + def is_compatible(self, verbose=True): + try: + import torch + except ImportError: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + + cuda_okay = True + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability < 6: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in ccs: + if int(cc[0]) >= 6: + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def get_prefix(self): + ds_path = self.deepspeed_src_path("deepspeed") + return "deepspeed" if os.path.isdir(ds_path) else ".." + + def sources(self): + sources = [ + "inference/v2/ragged/csrc/fast_host_buffer.cu", + "inference/v2/ragged/csrc/ragged_ops.cpp", + ] + + prefix = self.get_prefix() + sources = [os.path.join(prefix, src) for src in sources] + return sources + + def extra_ldflags(self): + return [] + + def include_paths(self): + include_dir = "inference/v2/ragged/includes" + prefix = self.get_prefix() + include_dir = os.path.join(prefix, include_dir) + + return ['csrc/includes', include_dir] diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 5d0e963779a3..105dd094f995 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,6 +1,7 @@ accelerate clang-format==16.0.2 coverage +deepspeed-kernels docutils<0.18 future importlib-metadata>=4 diff --git a/scripts/check-torchcuda.py b/scripts/check-torchcuda.py index 48581628930b..0723c9888369 100755 --- a/scripts/check-torchcuda.py +++ b/scripts/check-torchcuda.py @@ -19,6 +19,8 @@ def err(s: str) -> None: print(s, file=sys.stderr) +print(*sys.argv[1:]) + # There are many ways we could search for the string "torch.cuda", but `git # grep --no-index` is nice because # - it's very fast (as compared to iterating over the file in Python) diff --git a/tests/pytest.ini b/tests/pytest.ini index 20d6ca0624f2..cc6b6564daa8 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,9 +1,11 @@ [pytest] -addopts = -m "not sequential and not nightly and not inference and not seq_inference and not inference_ops" +addopts = -m "not sequential and not nightly and not inference and not seq_inference and not inference_ops and not inference_v2 and not inference_v2_ops" markers = sequential:Tests that need to be run sequentially inference:Inference model tests inference_ops:Individual inference operator tests + inference_v2: Inference tests for the v2 stack + inference_v2_ops: Op tests for 
the v2 stack seq_inference:Inference model tests to run sequentially nightly:Tests that should be run nightly world_size:Change world size of individual tests in a class diff --git a/tests/unit/inference/__init__.py b/tests/unit/inference/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/inference_test_utils.py b/tests/unit/inference/inference_test_utils.py new file mode 100644 index 000000000000..d63c51267e51 --- /dev/null +++ b/tests/unit/inference/inference_test_utils.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import torch +from deepspeed.accelerator import get_accelerator + +TOLERANCES = None + + +def get_tolerances(): + global TOLERANCES + if TOLERANCES is None: + TOLERANCES = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)} + if get_accelerator().is_bf16_supported(): + # Note: BF16 tolerance is higher than FP16 because of the lower precision (7 (+1) bits vs + # 10 (+1) bits) + TOLERANCES[torch.bfloat16] = (4.8e-1, 3.2e-2) + return TOLERANCES + + +DTYPES = None + + +def get_dtypes(include_float=True): + global DTYPES + if DTYPES is None: + DTYPES = [torch.float16, torch.float32] if include_float else [torch.float16] + try: + if get_accelerator().is_bf16_supported(): + DTYPES.append(torch.bfloat16) + except (AssertionError, AttributeError): + pass + return DTYPES + + +def allclose(x, y, tolerances: Tuple[int, int] = None): + assert x.dtype == y.dtype + if tolerances is None: + rtol, atol = get_tolerances()[x.dtype] + else: + rtol, atol = tolerances + return torch.allclose(x, y, rtol=rtol, atol=atol) diff --git a/tests/unit/inference/kernels/__init__.py b/tests/unit/inference/kernels/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/kernels/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/kernels/core_ops/__init__.py b/tests/unit/inference/kernels/core_ops/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/kernels/core_ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/kernels/core_ops/test_bias_activation.py b/tests/unit/inference/kernels/core_ops/test_bias_activation.py new file mode 100644 index 000000000000..2c6134991597 --- /dev/null +++ b/tests/unit/inference/kernels/core_ops/test_bias_activation.py @@ -0,0 +1,101 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum +from deepspeed.inference.v2.kernels.core_ops import CUDABiasActivation +from ...inference_test_utils import get_dtypes, allclose + + +def reference_bias_act_implementation(input: torch.Tensor, bias: Optional[torch.Tensor], + act_type: ActivationType) -> torch.Tensor: + bias_func_map = { + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + dtype = input.dtype + input_f = input.to(torch.float32) + if bias is not None: + bias_f = bias.to(torch.float32) + output_f = input_f + bias_f + else: + output_f = input_f + output_f = bias_func_map[act_type](output_f) + + return output_f.to(dtype) + + +def _bias_activation_test_helper(tokens: int, + channels: int, + act_fn: ActivationType, + dtype: DtypeEnum, + use_bias: bool = True) -> None: + """ + Fully parameterized testing entry point. + """ + # Input vals + input_tensor = torch.randn((tokens, channels), dtype=dtype.value, device=get_accelerator().current_device_name()) + if use_bias: + bias = torch.randn((channels), dtype=dtype.value, device=get_accelerator().current_device_name()) + else: + bias = None + + # Reference output + ref_output = reference_bias_act_implementation(input_tensor, bias, act_fn) + + bias_act = CUDABiasActivation(channels, dtype, act_fn) + + # New output + ds_tensor = input_tensor.clone() + bias_act(ds_tensor, bias) + + # Check + assert allclose(ds_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)]) +@pytest.mark.parametrize("dtype", get_dtypes(include_float=False)) +def test_token_channels_permutations(tokens: int, channels: int, dtype: torch.dtype) -> None: + """ + Validate bias activation kernel with different token and channel permutations when using the RELU + activation function. + """ + act_fn = ActivationType.RELU + dtype = DtypeEnum(dtype) + _bias_activation_test_helper(tokens, channels, act_fn, dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("act_fn", + [ActivationType.RELU, ActivationType.GELU, ActivationType.SILU, ActivationType.IDENTITY]) +def test_act_fns(act_fn: ActivationType) -> None: + """ + Validate bias activation kernel with different activation functions. + """ + tokens = 223 + channels = 4096 + dtype = DtypeEnum.fp16 + _bias_activation_test_helper(tokens, channels, act_fn, dtype) + + +@pytest.mark.inference_v2_ops +def test_no_bias() -> None: + """ + Validate bias activation kernel with no bias. + """ + tokens = 223 + channels = 4096 + dtype = DtypeEnum.fp16 + act_fn = ActivationType.IDENTITY + _bias_activation_test_helper(tokens, channels, act_fn, dtype, use_bias=False) diff --git a/tests/unit/inference/kernels/core_ops/test_blas_linear.py b/tests/unit/inference/kernels/core_ops/test_blas_linear.py new file mode 100644 index 000000000000..0f9f99b4f879 --- /dev/null +++ b/tests/unit/inference/kernels/core_ops/test_blas_linear.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.core_ops import BlasLibLinear +from ...inference_test_utils import allclose + +# Note: only testing with FP16 and BF16 because we use TF32 on Ampere and we don't have a good +# set of tolerances. Since this is just on top of BLAS though, the test is more about +# making sure the stride/contiguity is correct and that's data type agnostic. + + +def reference_implementation(hidden_states, weights): + return hidden_states @ weights.t() + + +problem_shapes = [ + (1, 1, 1024, 1024), + (1, 1024, 1024, 1024), + (2, 1024, 1024, 1024), + (1, 128, 768, 3072), + (1, 128, 3072, 768), + (1, 1024, 8192, 8192), + (1, 733, 8192, 32768), + (1, 13, 32768, 8192), +] + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("fp_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("problem_shape", problem_shapes) +def test_blas_linear(fp_dtype: torch.dtype, problem_shape: Tuple[int, int, int, int]): + batch, seq_len, in_features, out_features = problem_shape + hidden_states = torch.randn(batch, seq_len, in_features, dtype=fp_dtype, + device=get_accelerator().current_device()) * 0.1 + weights = torch.randn(out_features, in_features, dtype=fp_dtype, device=get_accelerator().current_device()) * 0.01 + ds_output = torch.empty(batch, seq_len, out_features, dtype=fp_dtype, device=get_accelerator().current_device()) + + ds_kernel = BlasLibLinear(fp_dtype) + + ds_output = ds_kernel(ds_output, hidden_states, weights) + ref_output = reference_implementation(hidden_states, weights) + + assert allclose(ds_output, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("fp_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("problem_shape", problem_shapes) +def test_blas_linear_t(fp_dtype: torch.dtype, problem_shape: Tuple[int, int, int, int]): + batch, seq_len, in_features, out_features = problem_shape + hidden_states = torch.randn(batch, seq_len, in_features, dtype=fp_dtype, + device=get_accelerator().current_device()) * 0.1 + weights = torch.randn(out_features, in_features, dtype=fp_dtype, device=get_accelerator().current_device()) * 0.01 + ds_output = torch.empty(batch, seq_len, out_features, dtype=fp_dtype, device=get_accelerator().current_device()) + + ds_kernel = BlasLibLinear(fp_dtype) + + # Transpose the weights then revert to the format we expect. + weights = weights.t().contiguous() + weights = weights.t() + ds_output = ds_kernel(ds_output, hidden_states, weights) + + ref_output = reference_implementation(hidden_states, weights) + + assert allclose(ds_output, ref_output) diff --git a/tests/unit/inference/kernels/core_ops/test_gated_activation.py b/tests/unit/inference/kernels/core_ops/test_gated_activation.py new file mode 100644 index 000000000000..ebfca4801eea --- /dev/null +++ b/tests/unit/inference/kernels/core_ops/test_gated_activation.py @@ -0,0 +1,133 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.core_ops import CUDAGatedActivation +from deepspeed.inference.v2.inference_utils import ActivationType +from ...inference_test_utils import get_dtypes, allclose + + +def reference_geglu_implementation(input: torch.Tensor, + bias: Optional[torch.Tensor] = None, + act_fn: Optional[ActivationType] = ActivationType.GEGLU) -> torch.Tensor: + act_func_map = { + ActivationType.ReGLU: torch.nn.functional.relu, + ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.SiGLU: torch.nn.functional.silu, + } + + dtype = input.dtype + input = input.to(torch.float32) + + if bias is not None: + bias = bias.to(torch.float32) + input = input + bias + + act_act = input[..., ::2] + act_linear = input[..., 1::2] + + act_act = act_func_map[act_fn](act_act) + + return (act_act * act_linear).to(dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("shape", [(1372, 16384), (2, 743, 22016)]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_dtypes(shape: Iterable[int], dtype: torch.dtype) -> None: + input_tensor = torch.randn(shape, dtype=dtype, device=get_accelerator().current_device_name()) + + # Reference output + ref_output = reference_geglu_implementation(input_tensor, act_fn=ActivationType.GEGLU) + + # Build kernel + geglu = CUDAGatedActivation(input_tensor.size(-1), input_tensor.dtype, ActivationType.GEGLU) + + # New output + output_shape = list(input_tensor.shape) + output_shape[-1] //= 2 + output_tensor = torch.empty(output_shape, dtype=input_tensor.dtype, device=get_accelerator().current_device_name()) + geglu(output_tensor, input_tensor) + + # Check + assert allclose(output_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("act_fn", [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU]) +def test_act_fn(act_fn: ActivationType) -> None: + input_tensor = torch.randn(832, 4096, dtype=torch.float16, device=get_accelerator().current_device()) + + # Reference output + ref_output = reference_geglu_implementation(input_tensor, act_fn=act_fn) + + cuda_act = CUDAGatedActivation(4096, torch.float16, act_fn) + + # New output + output_tensor = torch.empty(832, 2048, dtype=torch.float16, device=get_accelerator().current_device()) + cuda_act(output_tensor, input_tensor) + + assert allclose(output_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +def test_act_with_bias(): + input_tensor = torch.randn(832, 4096, dtype=torch.float16, device=get_accelerator().current_device()) + bias = torch.randn(4096, dtype=torch.float16, device=get_accelerator().current_device()) + + # Reference output + ref_output = reference_geglu_implementation(input_tensor, bias=bias, act_fn=ActivationType.GEGLU) + + cuda_act = CUDAGatedActivation(4096, torch.float16, ActivationType.GEGLU) + + # New output + output_tensor = torch.empty(832, 2048, dtype=torch.float16, device=get_accelerator().current_device()) + + cuda_act(output_tensor, input_tensor, bias) + + assert allclose(output_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +def test_max_channels(): + input_tensor = torch.randn(832, 48152, dtype=torch.float16, device=get_accelerator().current_device()) + + ref_output = reference_geglu_implementation(input_tensor, act_fn=ActivationType.GEGLU) + + cuda_act = CUDAGatedActivation(48152, torch.float16, 
ActivationType.GEGLU) + + output_tensor = torch.empty(832, 24076, dtype=torch.float16, device=get_accelerator().current_device()) + cuda_act(output_tensor, input_tensor) + + assert allclose(output_tensor, ref_output) + + +@pytest.mark.inference_v2_ops +def test_bad_dtype() -> None: + with pytest.raises(ValueError): + CUDAGatedActivation(128, torch.int8, ActivationType.GEGLU) + + +@pytest.mark.inference_v2_ops +def test_bad_act_fn() -> None: + with pytest.raises(ValueError): + CUDAGatedActivation(128, torch.float16, ActivationType.RELU) + + +@pytest.mark.inference_v2_ops +def test_bad_alignment() -> None: + with pytest.raises(ValueError): + CUDAGatedActivation(127, torch.float16, ActivationType.GEGLU) + + +@pytest.mark.inference_v2_ops +def test_too_many_channels() -> None: + with pytest.raises(ValueError): + CUDAGatedActivation(49160, torch.float16, ActivationType.GEGLU) diff --git a/tests/unit/inference/kernels/core_ops/test_post_ln.py b/tests/unit/inference/kernels/core_ops/test_post_ln.py new file mode 100644 index 000000000000..8b54e5651acb --- /dev/null +++ b/tests/unit/inference/kernels/core_ops/test_post_ln.py @@ -0,0 +1,47 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.core_ops import CUDAFPPostLN +from ...inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> torch.Tensor: + residual_f = residual.to(torch.float32) + hidden_states_f = hidden_states.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + return torch.nn.functional.layer_norm(residual_f + hidden_states_f, (hidden_states_f.size(-1), ), + weight=gamma_f, + bias=beta_f, + eps=epsilon).to(hidden_states.dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_cuda_post_ln(tokens: int, channels: int, dtype: torch.dtype) -> None: + + # Input vals + hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name()) + beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + post_ln_kernel = CUDAFPPostLN(hidden_states.size(-1), residual.dtype) + ds_output = torch.empty_like(residual) + post_ln_kernel(ds_output, residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_output, ref_output) diff --git a/tests/unit/inference/kernels/core_ops/test_pre_ln.py b/tests/unit/inference/kernels/core_ops/test_pre_ln.py new file mode 100644 index 000000000000..e5ac3ae1428f --- /dev/null +++ b/tests/unit/inference/kernels/core_ops/test_pre_ln.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.core_ops import CUDAFPPreLN +from ...inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> torch.Tensor: + residual_f = residual.to(torch.float32) + hidden_states_f = hidden_states.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + residual_out = residual_f + hidden_states_f + hidden_out = torch.nn.functional.layer_norm(residual_out, (hidden_states_f.size(-1), ), + weight=gamma_f, + bias=beta_f, + eps=epsilon) + return residual_out.to(hidden_states.dtype), hidden_out.to(hidden_states.dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_cuda_pre_ln(tokens: int, channels: int, dtype: torch.dtype) -> None: + + # Input vals + hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name()) + beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_output_res, ref_output_hid = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + pre_ln_kernel = CUDAFPPreLN(hidden_states.size(-1), residual.dtype) + ds_output_res = torch.empty_like(residual) + ds_output_hid = torch.empty_like(hidden_states) + pre_ln_kernel(ds_output_res, ds_output_hid, residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_output_res, ref_output_res) + assert allclose(ds_output_hid, ref_output_hid) diff --git a/tests/unit/inference/kernels/core_ops/test_rms_norm.py b/tests/unit/inference/kernels/core_ops/test_rms_norm.py new file mode 100644 index 000000000000..d2893a2115b7 --- /dev/null +++ b/tests/unit/inference/kernels/core_ops/test_rms_norm.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.core_ops import CUDARMSNorm, CUDARMSPreNorm +from ...inference_test_utils import get_dtypes, allclose + + +def reference_rms_norm(vals: torch.Tensor, gamma: torch.Tensor, epsilon: float = 1e-5) -> torch.Tensor: + variance = vals.to(torch.float32).pow(2).mean(-1, keepdim=True) + vals = vals * torch.rsqrt(variance + epsilon) + + if gamma.dtype in [torch.float16, torch.bfloat16]: + vals = vals.to(gamma.dtype) + + return gamma * vals + + +def reference_rms_pre_norm(vals: torch.Tensor, + residual: torch.Tensor, + gamma: torch.Tensor, + epsilon: float = 1e-5) -> torch.Tensor: + residual = residual + vals + return residual, reference_rms_norm(residual, gamma, epsilon) + + +def _rms_norm_testing_helper(rows: int, channels: int, do_residual: bool, dtype: DtypeEnum) -> None: + device = get_accelerator().current_device_name() + t_dtype = dtype.value + + vals = torch.randn((rows, channels), dtype=t_dtype, device=device) + gamma = torch.randn((channels), dtype=t_dtype, device=device) + epsilon = 1e-5 + + if do_residual: + residual_in = torch.randn((rows, channels), dtype=t_dtype, device=device) + ds_residual = residual_in.clone() + + ref_residual, ref_output = reference_rms_pre_norm(vals, residual_in, gamma, epsilon) + + kernel = CUDARMSPreNorm(channels, t_dtype, epsilon=epsilon) + ds_out = torch.empty_like(ds_residual) + + kernel(ds_residual, ds_out, residual_in, vals, gamma) + + assert allclose(ds_out, ref_output) + assert allclose(ds_residual, ref_residual) + else: + + ref_output = reference_rms_norm(vals, gamma, epsilon) + + kernel = CUDARMSNorm(channels, t_dtype, epsilon=epsilon) + ds_out = torch.empty_like(vals) + + kernel(ds_out, vals, gamma) + + assert allclose(ds_out, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", get_dtypes()) +@pytest.mark.parametrize("do_residual", [True, False]) +def test_rms_dtypes(dtype: DtypeEnum, do_residual: bool) -> None: + _rms_norm_testing_helper(883, 1024, do_residual, DtypeEnum(dtype)) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("rows, cols", [(1, 4096), (37, 2048), (112, 14432), (1024, 6144)]) +@pytest.mark.parametrize("do_residual", [True, False]) +def test_rms_shapes(rows: int, cols: int, do_residual: bool) -> None: + _rms_norm_testing_helper(rows, cols, do_residual, DtypeEnum.fp16) diff --git a/tests/unit/inference/kernels/cutlass_ops/__init__.py b/tests/unit/inference/kernels/cutlass_ops/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/kernels/cutlass_ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/kernels/cutlass_ops/test_moe_gemm.py b/tests/unit/inference/kernels/cutlass_ops/test_moe_gemm.py new file mode 100644 index 000000000000..4dd6d286fe00 --- /dev/null +++ b/tests/unit/inference/kernels/cutlass_ops/test_moe_gemm.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+
+from deepspeed.accelerator import get_accelerator
+from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum
+from deepspeed.inference.v2.kernels.cutlass_ops import MoEGEMM
+from ...inference_test_utils import allclose
+
+SINGLE_EXPERT_CASES = [(13, 2048, 2048), (256, 1024, 4096), (278, 5120, 2048), (893, 5120, 2560)]
+
+PYTORCH_ACT_FN_MAP = {
+    ActivationType.GELU: torch.nn.functional.gelu,
+    ActivationType.SILU: torch.nn.functional.silu,
+    ActivationType.RELU: torch.nn.functional.relu
+}
+
+
+@pytest.mark.inference_v2_ops
+@pytest.mark.parametrize("n_tokens, in_neurons, out_neurons", SINGLE_EXPERT_CASES)
+def test_single_expert(n_tokens: int, in_neurons: int, out_neurons: int) -> None:
+    """
+    Validate that the GEMM kernel produces identical results for a single GEMM instance.
+    """
+    device = get_accelerator().current_device()
+
+    activations = torch.rand((n_tokens, in_neurons), device=device, dtype=torch.float16) - 0.5
+    weights = torch.rand((1, in_neurons, out_neurons), device=device, dtype=torch.float16) - 0.5
+    biases = torch.randn((1, out_neurons), device=device, dtype=torch.float16)
+
+    weights_ref = weights.reshape(in_neurons, out_neurons)
+    biases_ref = biases.reshape(out_neurons)
+    ref_output = torch.matmul(activations, weights_ref) + biases_ref
+
+    moe_gemm = MoEGEMM(DtypeEnum.fp16, ActivationType.IDENTITY)
+    output = torch.empty((n_tokens, out_neurons), device=device, dtype=torch.float16)
+    cumsum_rows = torch.tensor([n_tokens], dtype=torch.int64, device=device)
+
+    moe_gemm(output, activations, weights, cumsum_rows, biases)
+    assert allclose(output, ref_output, tolerances=(1e-2, 1e-2))
+    get_accelerator().synchronize()
+
+
+def moe_test_helper(in_neurons: int, out_neurons: int, n_experts: int, max_tokens_per_expert: int,
+                    act_fn: ActivationType, dtype: DtypeEnum) -> None:
+    """
+    Helper function for validating the grouped GEMM kernel across multiple experts.
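+
+    Arguments:
+        in_neurons (int): Input feature dimension of each expert's GEMM.
+        out_neurons (int): Output feature dimension of each expert's GEMM.
+        n_experts (int): Number of experts in the grouped GEMM.
+        max_tokens_per_expert (int): Upper bound on the number of tokens randomly routed to each expert.
+        act_fn (ActivationType): Activation function applied to the GEMM output.
+        dtype (DtypeEnum): Data type of the activations, weights, and biases.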
+ """ + device = get_accelerator().current_device() + + expert_allocations = torch.randint(0, max_tokens_per_expert, (n_experts, ), device=device, dtype=torch.int32) + cumsum_rows = expert_allocations.cumsum(dim=0) + print(cumsum_rows.dtype) + + activations = torch.rand((cumsum_rows[-1], in_neurons), device=device, dtype=dtype.value) - 0.5 + weights = torch.rand((n_experts, in_neurons, out_neurons), device=device, dtype=dtype.value) - 0.5 + biases = torch.randn((n_experts, out_neurons), device=device, dtype=dtype.value) + + out_ref = torch.empty((cumsum_rows[-1], out_neurons), device=device, dtype=dtype.value) + + for expert_idx in range(n_experts): + start = cumsum_rows[expert_idx - 1] if expert_idx > 0 else 0 + end = cumsum_rows[expert_idx] + activations_slice = activations[start:end] + weights_slice = weights[expert_idx] + biases_slice = biases[expert_idx] + out_ref[start:end] = torch.matmul(activations_slice, weights_slice) + biases_slice + + if act_fn != ActivationType.IDENTITY: + act_fn_fn = PYTORCH_ACT_FN_MAP[act_fn] + out_ref = act_fn_fn(out_ref) + + moe_gemm = MoEGEMM(DtypeEnum.fp16, act_fn) + output = torch.empty((cumsum_rows[-1], out_neurons), device=device, dtype=dtype.value) + + moe_gemm(output, activations, weights, cumsum_rows, biases) + + if dtype == DtypeEnum.bf16: + assert allclose(output, out_ref, tolerances=(1e-1, 1e-1)) + else: + assert allclose(output, out_ref, tolerances=(1e-2, 1e-2)) + get_accelerator().synchronize() + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("max_tokens_per_expert", [1, 4, 16, 64, 128]) +def test_multi_expert(max_tokens_per_expert: int) -> None: + """ + Validate for multi-expert GEMM instances that the output is identical to the reference. + """ + moe_test_helper(5120, 2048, 64, max_tokens_per_expert, ActivationType.IDENTITY, DtypeEnum.fp16) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("act_fn", [ActivationType.GELU, ActivationType.SILU, ActivationType.RELU]) +def test_act_fns(act_fn: ActivationType) -> None: + """ + Validate activation function behavior. + """ + moe_test_helper(5120, 2048, 64, 32, act_fn, DtypeEnum.fp16) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", [DtypeEnum.fp16, DtypeEnum.bf16]) +def test_dtypes(dtype: DtypeEnum) -> None: + """ + Validate data type behavior. + """ + moe_test_helper(5120, 2048, 64, 32, ActivationType.IDENTITY, dtype) diff --git a/tests/unit/inference/kernels/ragged_ops/__init__.py b/tests/unit/inference/kernels/ragged_ops/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/kernels/ragged_ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/kernels/ragged_ops/ragged_testing_utils.py b/tests/unit/inference/kernels/ragged_ops/ragged_testing_utils.py new file mode 100644 index 000000000000..445c6c38b87f --- /dev/null +++ b/tests/unit/inference/kernels/ragged_ops/ragged_testing_utils.py @@ -0,0 +1,300 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import random
+from typing import List, Optional, Tuple
+
+import torch
+
+from deepspeed.accelerator import get_accelerator
+from deepspeed.inference.v2.ragged import (
+    AllocationMode,
+    DSSequenceDescriptor,
+    DSStateManager,
+    DSStateManagerConfig,
+    KVCacheConfig,
+    MemoryConfig,
+    PlaceholderSequenceDescriptor,
+    RaggedBatchWrapper,
+)
+from ...inference_test_utils import allclose
+
+
+def build_simple_batch(seq_lens: List[int],
+                       vocab_range: Optional[int] = 100,
+                       padding: Optional[bool] = False) -> RaggedBatchWrapper:
+    """
+    Construct a simple batch with the given sequence lengths. This method should not
+    be used for testing scenarios that require information about KV or sequence
+    history.
+    """
+    total_tokens = max(sum(seq_lens), 1024)
+    n_seqs = max(len(seq_lens), 128)
+
+    config = DSStateManagerConfig(max_tracked_sequences=n_seqs,
+                                  max_ragged_sequence_count=n_seqs,
+                                  max_ragged_batch_size=total_tokens)
+    batch = RaggedBatchWrapper(config)
+
+    batch.clear()
+
+    for seq_len in seq_lens:
+        seq_desc = PlaceholderSequenceDescriptor()
+        tokens = torch.randint(0, vocab_range, (seq_len, ))
+        batch.insert_sequence(seq_desc, tokens)
+
+    batch.finalize(padding=padding)
+
+    return batch
+
+
+def build_complex_batch(seq_params: List[Tuple[int, int, int]],
+                        kv_block_size: int,
+                        vocab_range: Optional[int] = 100,
+                        padding: Optional[bool] = False) -> Tuple[RaggedBatchWrapper, int]:
+    """
+    Construct a fully parametrized batch with the given sequence lengths. This method
+    can be used to construct more realistic inputs for testing scenarios that will interact
+    with all the members of the RaggedBatchWrapper.
+    """
+    seq_lens = [seq_param[0] for seq_param in seq_params]
+    total_tokens = max(sum(seq_lens), 1024)
+    n_seqs = max(len(seq_lens), 128)
+
+    config = DSStateManagerConfig(max_tracked_sequences=n_seqs,
+                                  max_ragged_sequence_count=n_seqs,
+                                  max_ragged_batch_size=total_tokens)
+    batch = RaggedBatchWrapper(config)
+
+    batch.clear()
+
+    total_kv_blocks = 0
+
+    for seq_len, n_seen_tokens, kv_ptr in seq_params:
+        n_kv_blocks = (seq_len + n_seen_tokens + kv_block_size - 1) // kv_block_size
+        seq_desc = PlaceholderSequenceDescriptor(seen_tokens=n_seen_tokens,
+                                                 cur_allocated_blocks=n_kv_blocks,
+                                                 kv_blocks_ptr=kv_ptr)
+        tokens = torch.randint(0, vocab_range, (seq_len, ))
+        batch.insert_sequence(seq_desc, tokens)
+        total_kv_blocks += n_kv_blocks
+
+    batch.finalize(padding=padding)
+
+    return batch, total_kv_blocks
+
+
+def build_batch_and_manager(
+        seq_params: List[Tuple[int, int]],
+        head_size: int,
+        n_heads_kv: int,
+        kv_block_size: int,
+        vocab_range: Optional[int] = 100,
+        padding: Optional[bool] = False,
+        kv_fill: Optional[List[torch.Tensor]] = None
+) -> Tuple[RaggedBatchWrapper, DSStateManager, List[DSSequenceDescriptor]]:
+    """
+    Construct and populate a batch and KV-cache with the given sequence parameters.
+
+    Arguments:
+        seq_params (List[Tuple[int, int]]): A list of tuples containing the sequence length and
+            the number of tokens that have already been seen for that sequence.
+        head_size (int): The size of each attention head.
+        n_heads_kv (int): The number of attention heads for the KV-cache.
+        kv_block_size (int): The size of each block in the KV-cache.
+        vocab_range (Optional[int]): The range of the vocabulary. Defaults to 100.
+        padding (Optional[bool]): Whether to pad the batch. Defaults to False.
+        kv_fill (Optional[List[torch.Tensor]]): A list of tensors to use to populate the KV-cache.
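+            Each tensor is expected to have shape (n_seen_tokens, 2 * n_heads_kv * head_size).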
+ If this is not provided, the KV-cache will be treated as empty and the contents should + not be relied upon. NOTE(cmikeh2): This functionality relies on the functionality + of LinearBlockedKVCopy. If tests relying on this feature are failing, make sure that + LinearBlockedKVCopy is working correctly. + """ + seq_lens = [seq_param[0] for seq_param in seq_params] + fill_lens = [seq_param[1] for seq_param in seq_params] + max_created_batch_len = max(sum(seq_lens), sum(fill_lens)) + total_tokens = max(max_created_batch_len, 1024) + n_seqs = max(len(seq_lens), 128) + + req_kv_blocks = [None] * n_seqs + total_kv_blocks = 0 + for i, (seq_len, n_seen_tokens) in enumerate(seq_params): + req_kv_blocks[i] = (seq_len + n_seen_tokens + kv_block_size - 1) // kv_block_size + total_kv_blocks += req_kv_blocks[i] + + kv_config = KVCacheConfig(block_size=kv_block_size, + num_allocation_groups=1, + cache_shape=(1, n_heads_kv, head_size)) + memory_config = MemoryConfig(mode=AllocationMode.ALLOCATE, size=total_kv_blocks) + + config = DSStateManagerConfig(max_tracked_sequences=n_seqs, + max_ragged_sequence_count=n_seqs, + max_ragged_batch_size=total_tokens, + memory_config=memory_config) + + batch = RaggedBatchWrapper(config) + state_manager = DSStateManager(config, kv_config) + + # At the beginning of operation, the design of the allocator is such that it will return + # linear blocks of memory. The following will "warm up" the allocator so that we can be + # more certain that code is not dependent on this behavior. + all_allocs = [] + for _ in range(20): + decision = random.randint(0, 1) + + if decision == 0: + blocks_to_allocate = random.randint(0, total_kv_blocks) + if blocks_to_allocate <= state_manager.free_blocks and blocks_to_allocate > 0: + all_allocs.append(state_manager.allocate_blocks(blocks_to_allocate)) + else: + if len(all_allocs) > 0: + idx = random.randint(0, len(all_allocs) - 1) + state_manager._kv_cache.free(all_allocs[idx]) + + del all_allocs[idx] + + for alloc in all_allocs: + state_manager._kv_cache.free(alloc) + + assert state_manager.free_blocks == total_kv_blocks + + batch.clear() + seq_descs = [] + + if kv_fill is None or sum(fill_lens) == 0: + for i, (seq_len, n_seen_tokens) in enumerate(seq_params): + # Create empty descriptor + seq_desc = state_manager.get_or_create_sequence(i) + + # Update `seen_tokens` in the descriptor + seq_desc.pre_forward(n_seen_tokens) + seq_desc.post_forward() + + # Ensure there's enough KV-cache for the sequence + kv_block_ids = state_manager.allocate_blocks(req_kv_blocks[i]) + print(f"Allocated {req_kv_blocks[i]} blocks for sequence {i}: {kv_block_ids}") + seq_desc.extend_kv_cache(kv_block_ids) + + # Insert sequence into batch + tokens = torch.randint(0, vocab_range, (seq_len, )) + batch.insert_sequence(seq_desc, tokens) + seq_desc.pre_forward(seq_len) + seq_descs.append(seq_desc) + else: + qkv = torch.empty((total_tokens, (n_heads_kv * 3) * head_size), + dtype=torch.float16, + device=get_accelerator().current_device()) + fills_as_tensor = torch.tensor(fill_lens, dtype=torch.int32) + fill_cumsum = torch.cat((torch.tensor([0], dtype=torch.int32), torch.cumsum(fills_as_tensor, dim=0))) + + for i, (_, n_seen_tokens) in enumerate(seq_params): + # Create empty descriptor + seq_desc = state_manager.get_or_create_sequence(i) + + # Update `seen_tokens` in the descriptor + if n_seen_tokens > 0: + dummy_fill_toks = torch.randint(0, vocab_range, (n_seen_tokens, )) + batch.insert_sequence(seq_desc, dummy_fill_toks) + seq_desc.pre_forward(n_seen_tokens) + + # Ensure 
there's enough KV-cache for the sequence + kv_block_ids = state_manager.allocate_blocks(req_kv_blocks[i]) + print(f"Allocated {req_kv_blocks[i]} blocks for sequence {i}: {kv_block_ids}") + seq_desc.extend_kv_cache(kv_block_ids) + seq_descs.append(seq_desc) + + if n_seen_tokens == 0: + continue + + assert kv_fill[i].shape[0] == n_seen_tokens + assert kv_fill[i].shape[1] == n_heads_kv * head_size * 2 + + local_q = torch.randn((n_seen_tokens, n_heads_kv * head_size), dtype=torch.float16, device=qkv.device) + local_qkv = torch.cat((local_q, kv_fill[i]), dim=1) + qkv[fill_cumsum[i]:fill_cumsum[i + 1]] = local_qkv + + batch.finalize(padding=padding) + + from deepspeed.inference.v2.kernels.ragged_ops import LinearBlockedKVCopy + kv_copy = LinearBlockedKVCopy(head_size, n_heads_kv, n_heads_kv, torch.float16) + kv_cache = state_manager.get_cache(0) + kv_copy(kv_cache, qkv, batch) + + for seq_desc in seq_descs: + if seq_desc.in_flight_tokens > 0: + seq_desc.post_forward() + + batch.clear() + + for i, (seq_len, _) in enumerate(seq_params): + seq_desc = state_manager.get_or_create_sequence(i) + tokens = torch.randint(0, vocab_range, (seq_len, )) + batch.insert_sequence(seq_desc, tokens) + seq_desc.pre_forward(seq_len) + + # We will skip KV cache allocation here because we did a lump allocation above + # for both the fill and the sequence itself. + + batch.finalize(padding=padding) + + return batch, state_manager, seq_descs + + +def validate_kv_cache(kv_cache: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_descs: List[DSSequenceDescriptor], + batch: RaggedBatchWrapper, + exact: bool = True) -> None: + """ + Given a QKV tensor and a KV cache, validate that the cache contains the correct values. + """ + block_size = kv_cache.shape[1] + n_kv_heads = kv_cache.shape[3] + head_size = kv_cache.shape[4] + + inflight_descs = batch.inflight_seq_descriptors(on_device=False)[:batch.current_sequences] + + if inflight_descs.shape[0] != len(seq_descs): + raise ValueError("The number of sequence descriptors does not match the number of sequences in the batch.") + + for seq_desc, inflight_seq in zip(seq_descs, inflight_descs): + start_idx = inflight_seq[0] + assigned_kv_blocks = seq_desc.kv_cache_ids(on_device=False) + + real_k_values = k[start_idx:start_idx + seq_desc.in_flight_tokens] + real_v_values = v[start_idx:start_idx + seq_desc.in_flight_tokens] + + start_block_idx = seq_desc.seen_tokens // block_size + local_start_idx = 0 + cur_start_idx = seq_desc.seen_tokens + + for block_idx in range(start_block_idx, seq_desc.cur_allocated_blocks): + block = kv_cache[assigned_kv_blocks[0, block_idx].item()] + block_start_idx = cur_start_idx % block_size + n_tokens_to_check = min(block_size - block_start_idx, seq_desc.in_flight_tokens - local_start_idx) + block_end_idx = block_start_idx + n_tokens_to_check + + if exact: + assert torch.equal( + block[block_start_idx:block_end_idx, 0, :, :], + real_k_values[local_start_idx:local_start_idx + n_tokens_to_check].reshape( + n_tokens_to_check, n_kv_heads, head_size)) + assert torch.equal( + block[block_start_idx:block_end_idx, 1, :, :], + real_v_values[local_start_idx:local_start_idx + n_tokens_to_check].reshape( + n_tokens_to_check, n_kv_heads, head_size)) + else: + assert allclose( + block[block_start_idx:block_end_idx, 0, :, :], + real_k_values[local_start_idx:local_start_idx + n_tokens_to_check].reshape( + n_tokens_to_check, n_kv_heads, head_size)) + assert allclose( + block[block_start_idx:block_end_idx, 1, :, :], + real_v_values[local_start_idx:local_start_idx + 
n_tokens_to_check].reshape( + n_tokens_to_check, n_kv_heads, head_size)) + + local_start_idx += n_tokens_to_check + cur_start_idx += n_tokens_to_check diff --git a/tests/unit/inference/kernels/ragged_ops/test_atom_builder.py b/tests/unit/inference/kernels/ragged_ops/test_atom_builder.py new file mode 100644 index 000000000000..a33c938a0608 --- /dev/null +++ b/tests/unit/inference/kernels/ragged_ops/test_atom_builder.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.kernels.ragged_ops import AtomBuilder +from .ragged_testing_utils import build_complex_batch + +Q_BLOCK_SIZE = 128 +KV_BLOCK_SIZE = 128 + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('seq_params', [(1, 0, 0), (1, 228, 0), (383, 0, 0), (1, 494, 0)]) +def test_single_sequence(seq_params) -> None: + seq_len, n_seen_tokens, _ = seq_params + + batch, _ = build_complex_batch([seq_params], kv_block_size=KV_BLOCK_SIZE, padding=False) + atom_builder = AtomBuilder() + + atoms = torch.empty((8, 8), dtype=torch.int32, device=torch.device("cpu")) + atoms, n_atoms = atom_builder(atoms, batch, Q_BLOCK_SIZE, KV_BLOCK_SIZE) + + calc_n_atoms = (seq_len + 127) // 128 + + assert n_atoms == calc_n_atoms + + for i, atom in enumerate(atoms[:n_atoms]): + # Since the ptr was 0, first 2 elements should be 0 + assert atom[0] == 0 + assert atom[1] == 0 + + # Since we have a single sequence, the q_start_idx should always be + # whichever atom we're on multiplied by the block size + assert atom[2] == i * Q_BLOCK_SIZE + assert atom[3] == min(Q_BLOCK_SIZE, seq_len - i * Q_BLOCK_SIZE) + total_toks = i * Q_BLOCK_SIZE + min(Q_BLOCK_SIZE, seq_len - i * Q_BLOCK_SIZE) + + assert atom[4] == (total_toks + n_seen_tokens + KV_BLOCK_SIZE - 1) // KV_BLOCK_SIZE + assert atom[5] == (total_toks + n_seen_tokens) + + assert atom[6] == n_seen_tokens + i * Q_BLOCK_SIZE diff --git a/tests/unit/inference/kernels/ragged_ops/test_blocked_flash.py b/tests/unit/inference/kernels/ragged_ops/test_blocked_flash.py new file mode 100644 index 000000000000..a16a7775e964 --- /dev/null +++ b/tests/unit/inference/kernels/ragged_ops/test_blocked_flash.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import itertools + +from typing import List, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.ragged_ops import ( + AtomBuilder, + BlockedFlashAttn, + get_q_block_size, + get_kv_block_size, + LinearBlockedKVCopy, +) +from deepspeed.inference.v2.ragged import split_kv +from deepspeed.ops.op_builder import RaggedUtilsBuilder + +from .ragged_testing_utils import build_batch_and_manager +from ...inference_test_utils import allclose + +try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func + validate_accuracy = True +except ImportError: + validate_accuracy = False +""" +NOTE(cmikeh2): These tests depend on atom construction and KV-cache copying to behave correctly. +If one or the other of those is not working, then these tests will fail. Before debugging here, +make sure that the atom construction and KV-cache copying tests are passing. +""" + + +def _blocked_flash_testing_helper(head_size: int, n_heads_q: int, n_heads_kv: int, + seq_params: List[Tuple[int, int]]) -> None: + """ + Helper function for testing blocked flash attention. 
Used to enable parametrize to only set up + a subset of parameters before being passed to the unified test function. + """ + q_block_size = get_q_block_size(head_size) + kv_block_size = get_kv_block_size(head_size) + + kvs = [] + for _, history_len in seq_params: + if history_len > 0: + kvs.append( + torch.randn((history_len, 2 * n_heads_kv * head_size), + device=get_accelerator().current_device(), + dtype=torch.float16)) + else: + kvs.append(None) + + batch, state_manager, _ = build_batch_and_manager(seq_params, head_size, n_heads_kv, kv_block_size, kv_fill=kvs) + + atom_builder = AtomBuilder() + kv_copy = LinearBlockedKVCopy(head_size, n_heads_q, n_heads_kv, DtypeEnum.fp16) + atom_flash = BlockedFlashAttn(head_size, DtypeEnum.fp16) + + total_atoms = sum((seq[0] + q_block_size - 1) // q_block_size for seq in seq_params) + atoms = torch.empty((total_atoms, 8), dtype=torch.int32, device=get_accelerator().current_device()) + alloc_func = RaggedUtilsBuilder().load().allocate_fast_host_buffer + atoms_host = alloc_func(atoms) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=get_accelerator().current_device(), + dtype=torch.float16) + + atoms_host, n_atoms = atom_builder(atoms_host, batch, q_block_size, kv_block_size) + atoms.copy_(atoms_host[:n_atoms]) + + kv_cache = state_manager.get_cache(0) + kv_copy(kv_cache, qkv, batch) + + out = torch.empty((batch.current_tokens, head_size * n_heads_q), + device=get_accelerator().current_device(), + dtype=torch.float16) + k_cache, v_cache = split_kv(kv_cache) + q = qkv[:, :head_size * n_heads_q] + + atom_flash(out, q, k_cache, v_cache, atoms, 1.0) + + if validate_accuracy: + cu_seqlens_q = torch.tensor([0] + list(itertools.accumulate([seq[0] for seq in seq_params])), + dtype=torch.int32, + device=get_accelerator().current_device()) + cu_seqlens_kv = torch.tensor([0] + list(itertools.accumulate([seq[1] + seq[0] for seq in seq_params])), + dtype=torch.int32, + device=get_accelerator().current_device()) + + inflight_kv = qkv[:, head_size * n_heads_q:] + full_kvs = [] + for i, kv in enumerate(kvs): + if kv is not None: + full_kvs.append(torch.cat([kv, inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]], dim=0)) + else: + full_kvs.append(inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]) + run_kvs = torch.cat(full_kvs, dim=0) + k = run_kvs[:, :head_size * n_heads_kv] + v = run_kvs[:, head_size * n_heads_kv:] + + q_ref = q.reshape((batch.current_tokens, n_heads_q, head_size)) + k_ref = k.reshape((k.shape[0], n_heads_kv, head_size)) + v_ref = v.reshape((v.shape[0], n_heads_kv, head_size)) + + max_seqlen_q = max([seq[0] for seq in seq_params]) + max_seqlen_kv = max([seq[1] + seq[0] for seq in seq_params]) + + ref_o = flash_attn_varlen_func(q_ref, + k_ref, + v_ref, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + softmax_scale=1.0, + causal=True) + + ref_o = ref_o.reshape(batch.current_tokens, head_size * n_heads_q) + + assert allclose(out, ref_o) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens", [2, 33, 65, 128, 256, 2037]) +def test_single_prompt(n_tokens: int) -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(n_tokens, 0)] + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("prompt_lengths", [(128, 128), (192, 38), (514, 713), (83, 312, 610)]) +def test_multiple_prompts(prompt_lengths: Tuple[int, int]) -> None: + """ + Test multiple prompts in a single batch. 
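+    Each entry of ``prompt_lengths`` becomes its own sequence with no KV history.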
+ """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(prompt_lengths[i], 0) for i in range(len(prompt_lengths))] + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("seq_params", [(1, 34), (43, 40), (1, 144), (64, 128), (332, 628)]) +def test_continuation(seq_params: Tuple[int, int]) -> None: + """ + Test continued generation/prompt processing. + """ + head_size = 64 + n_heads_q = 32 + n_heads_kv = 32 + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, [seq_params]) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_size", [64, 128]) +def test_head_size(head_size: int) -> None: + n_heads_q = 16 + n_heads_kv = 16 + seq_params = [(128, 128), (192, 38), (1, 814)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_config", [(32, 8), (64, 16), (40, 8)]) +def test_gqa(head_config: Tuple[int, int]) -> None: + head_size = 128 + n_heads_q = head_config[0] + n_heads_kv = head_config[1] + + seq_params = [(128, 128), (192, 38), (1, 814)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +def test_fully_composed() -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(332, 628), (1, 718), (1, 323), (180, 5), (224, 0)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) diff --git a/tests/unit/inference/kernels/ragged_ops/test_blocked_kv_copy.py b/tests/unit/inference/kernels/ragged_ops/test_blocked_kv_copy.py new file mode 100644 index 000000000000..90fe26eb4490 --- /dev/null +++ b/tests/unit/inference/kernels/ragged_ops/test_blocked_kv_copy.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.ragged_ops import LinearBlockedKVCopy +from .ragged_testing_utils import build_batch_and_manager, validate_kv_cache + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, history_size", [(1, 0), (17, 0), (33, 8), (63, 1)]) +def test_single_sequence_single_block(n_tokens: int, history_size: int): + """ + Validate that the copy works correctly + """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch, state_manager, seq_descs = build_batch_and_manager([(n_tokens, history_size)], head_size, n_heads_kv, + kv_block_size) + + assert batch.current_sequences == 1 + assert batch.current_tokens == n_tokens + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + + kv_cache = state_manager.get_cache(0) + + copy_impl = LinearBlockedKVCopy(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + k = qkv[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv[:, head_size * (n_heads_q + n_heads_kv):] + + validate_kv_cache(kv_cache, k, v, seq_descs, batch) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, history_size", [(128, 0), (177, 0), (169, 8), (117, 88)]) +def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int): + """ + Validate that the copy works correctly + """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch, state_manager, seq_descs = build_batch_and_manager([(n_tokens, history_size)], head_size, n_heads_kv, + kv_block_size) + + assert batch.current_sequences == 1 + assert batch.current_tokens == n_tokens + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + + kv_cache = state_manager.get_cache(0) + + copy_impl = LinearBlockedKVCopy(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + k = qkv[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv[:, head_size * (n_heads_q + n_heads_kv):] + + validate_kv_cache(kv_cache, k, v, seq_descs, batch) + + +@pytest.mark.inference_v2_ops +def test_multi_sequence() -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch_config = [ + (128, 0), + (177, 0), + (169, 8), + (117, 88), + (1, 293), + (1, 733), + (1, 33), + ] + + batch, state_manager, seq_descs = build_batch_and_manager(batch_config, head_size, n_heads_kv, kv_block_size) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + + kv_cache = state_manager.get_cache(0) + + copy_impl = LinearBlockedKVCopy(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + k = qkv[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv[:, head_size * (n_heads_q + n_heads_kv):] + + validate_kv_cache(kv_cache, k, v, seq_descs, batch) diff --git a/tests/unit/inference/kernels/ragged_ops/test_blocked_rotary_emb.py b/tests/unit/inference/kernels/ragged_ops/test_blocked_rotary_emb.py new file mode 100644 index 000000000000..35b92ef86305 --- /dev/null +++ 
b/tests/unit/inference/kernels/ragged_ops/test_blocked_rotary_emb.py @@ -0,0 +1,203 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.ragged_ops import BlockedRotaryEmbeddings, BlockedTrainedRotaryEmbeddings +from deepspeed.inference.v2.ragged import RaggedBatchWrapper, DSSequenceDescriptor +from .ragged_testing_utils import build_batch_and_manager, validate_kv_cache +from ...inference_test_utils import allclose +""" +NOTE(cmikeh2): It is very possible to see unit test failures (even on FP16) depending on when +certain values are casted up to or down from float32. If we are seeing accuracy issues, we should +make sure we are aligning on the training implementation's cast pattern here, given these tolerances +tend to be sufficient elsewhere. +""" + + +def rotary_pos_embs(q: torch.Tensor, k: torch.Tensor, seq_descs: List[DSSequenceDescriptor], batch: RaggedBatchWrapper, + head_size: int): + + def make_cos_sin_emb(seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]: + t = torch.arange(seq_len, dtype=torch.float32, device=get_accelerator().current_device()) + inv_freq = (1.0 / (10000.0**(torch.arange( + 0, head_size, 2, dtype=torch.float32, device=get_accelerator().current_device()) / head_size))).half() + + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + + return torch.cos(emb)[:, None, :], torch.sin(emb)[:, None, :], inv_freq + + def rotate_half(x: torch.Tensor) -> torch.Tensor: + return torch.cat((-x[..., x.shape[-1] // 2:], x[..., :x.shape[-1] // 2]), dim=-1) + + cos, sin, freqs = make_cos_sin_emb(1024) + + q_out = torch.empty_like(q) + k_out = torch.empty_like(k) + n_heads_q = q.shape[1] // head_size + n_heads_kv = k.shape[1] // head_size + + inflight_descs = batch.inflight_seq_descriptors(on_device=False)[:batch.current_sequences] + + if inflight_descs.shape[0] != len(seq_descs): + raise ValueError("The number of sequence descriptors does not match the number of sequences in the batch.") + + for seq_desc, inflight_seq in zip(seq_descs, inflight_descs): + start_idx = inflight_seq[0] + n_tokens = seq_desc.in_flight_tokens + + q_src = q[start_idx:start_idx + n_tokens].reshape(n_tokens, n_heads_q, head_size).float() + k_src = k[start_idx:start_idx + n_tokens].reshape(n_tokens, n_heads_kv, head_size).float() + freq_start_offset = seq_desc.seen_tokens + + cos_chunk = cos[range(freq_start_offset, freq_start_offset + n_tokens)] + sin_chunk = sin[range(freq_start_offset, freq_start_offset + n_tokens)] + + q_emb = q_src * cos_chunk + rotate_half(q_src) * sin_chunk + k_emb = k_src * cos_chunk + rotate_half(k_src) * sin_chunk + + q_out[start_idx:start_idx + n_tokens] = q_emb.reshape(n_tokens, n_heads_q * head_size).to(q_out.dtype) + k_out[start_idx:start_idx + n_tokens] = k_emb.reshape(n_tokens, n_heads_kv * head_size).to(k_out.dtype) + + return q_out, k_out, freqs + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, history_size", [(1, 0), (17, 0), (33, 15), (1, 63)]) +@pytest.mark.parametrize("trained_emb", [False, True]) +def test_single_sequence_single_block(n_tokens: int, history_size: int, trained_emb: bool): + """ + Validate that the copy works correctly + """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch, state_manager, seq_descs = 
build_batch_and_manager([(n_tokens, history_size)], head_size, n_heads_kv, + kv_block_size) + + assert batch.current_sequences == 1 + assert batch.current_tokens == n_tokens + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + qkv_ref = qkv.clone() + + q = qkv_ref[:, :head_size * n_heads_q] + k = qkv_ref[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv_ref[:, head_size * (n_heads_q + n_heads_kv):] + + q_ref, k, freqs = rotary_pos_embs(q, k, seq_descs, batch, head_size) + freqs = freqs.half() + + kv_cache = state_manager.get_cache(0) + + if trained_emb: + copy_impl = BlockedTrainedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch, freqs) + else: + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + assert allclose(qkv[:, :head_size * n_heads_q], q_ref) + validate_kv_cache(kv_cache, k, v, seq_descs, batch, exact=False) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, history_size", [(128, 0), (177, 0), (169, 8), (117, 88)]) +@pytest.mark.parametrize("trained_emb", [False, True]) +def test_single_sequence_multiple_blocks(n_tokens: int, history_size: int, trained_emb: bool): + """ + Validate that the copy works correctly + """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch, state_manager, seq_descs = build_batch_and_manager([(n_tokens, history_size)], head_size, n_heads_kv, + kv_block_size) + + assert batch.current_sequences == 1 + assert batch.current_tokens == n_tokens + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + qkv_ref = qkv.clone() + + q = qkv_ref[:, :head_size * n_heads_q] + k = qkv_ref[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv_ref[:, head_size * (n_heads_q + n_heads_kv):] + + q_ref, k, freqs = rotary_pos_embs(q, k, seq_descs, batch, head_size) + freqs = freqs.half() + + kv_cache = state_manager.get_cache(0) + + if trained_emb: + copy_impl = BlockedTrainedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch, freqs) + else: + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + assert allclose(qkv[:, :head_size * n_heads_q], q_ref) + validate_kv_cache(kv_cache, k, v, seq_descs, batch, exact=False) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("trained_emb", [False, True]) +def test_multi_sequences(trained_emb: bool) -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + kv_block_size = 64 + device = get_accelerator().current_device() + + batch_config = [ + (128, 0), + (177, 0), + (169, 8), + (117, 88), + (1, 293), + (1, 733), + (1, 33), + ] + + batch, state_manager, seq_descs = build_batch_and_manager(batch_config, head_size, n_heads_kv, kv_block_size) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=device, + dtype=torch.float16) + qkv_ref = qkv.clone() + + q = qkv_ref[:, :head_size * n_heads_q] + k = qkv_ref[:, head_size * n_heads_q:head_size * (n_heads_q + n_heads_kv)] + v = qkv_ref[:, head_size * (n_heads_q + n_heads_kv):] + + q_ref, k, freqs = rotary_pos_embs(q, k, seq_descs, batch, head_size) + freqs = freqs.half() + + kv_cache = state_manager.get_cache(0) + + if 
trained_emb: + copy_impl = BlockedTrainedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch, freqs) + else: + copy_impl = BlockedRotaryEmbeddings(head_size, n_heads_q, n_heads_kv, torch.float16) + copy_impl(kv_cache, qkv, batch) + + assert allclose(qkv[:, :head_size * n_heads_q], q_ref) + validate_kv_cache(kv_cache, k, v, seq_descs, batch, exact=False) diff --git a/tests/unit/inference/kernels/ragged_ops/test_logits_gather.py b/tests/unit/inference/kernels/ragged_ops/test_logits_gather.py new file mode 100644 index 000000000000..0208b733ab5b --- /dev/null +++ b/tests/unit/inference/kernels/ragged_ops/test_logits_gather.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.ragged_ops import RaggedLogitsGather +from ...inference_test_utils import allclose, get_dtypes +from .ragged_testing_utils import build_simple_batch + + +def baseline_implementation(hidden_states: torch.Tensor, seq_lens: List[int]) -> torch.Tensor: + output = torch.empty((len(seq_lens), hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device) + + offset = 0 + for i, seq_len in enumerate(seq_lens): + output[i] = hidden_states[offset + seq_len - 1] + offset += seq_len + + return output + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('dtype', get_dtypes()) +def test_supported_dtypes(dtype: torch.dtype) -> None: + """ + Validate support on nominally supported data types. + """ + model_dim = 4096 + + batch = build_simple_batch([256], padding=False) + hidden_states = torch.randn((batch.current_tokens, model_dim), + dtype=dtype, + device=get_accelerator().current_device()) + + reference_result = baseline_implementation(hidden_states, [256]) + + kernel = RaggedLogitsGather(model_dim, dtype) + output = torch.empty_like(reference_result) + kernel(output, hidden_states, batch) + + assert allclose(output, reference_result) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('seq_lens', [[128, 64, 192, 32], [57, 112, 63, 89, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], + [63, 27, 74, 83, 32, 17, 1, 1, 1, 1, 1]]) +def test_multiple_sequences(seq_lens: List[int]) -> None: + """ + Validate support on more multi-sequence inputs. + """ + model_dim = 4096 + dtype = torch.float16 + + batch = build_simple_batch(seq_lens, padding=False) + hidden_states = torch.randn((batch.current_tokens, model_dim), + dtype=dtype, + device=get_accelerator().current_device()) + + reference_result = baseline_implementation(hidden_states, seq_lens) + + kernel = RaggedLogitsGather(model_dim, dtype) + output = torch.empty_like(reference_result) + kernel(output, hidden_states, batch) + + assert allclose(output, reference_result) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("model_dim", [1024, 6144, 6784]) +def test_problem_size_permutations(model_dim: int) -> None: + """ + Validate for different embedding sizes. 
+ """ + dtype = torch.float16 + seq_lens = [128, 64, 192, 32] + + batch = build_simple_batch(seq_lens, padding=False) + hidden_states = torch.randn((batch.current_tokens, model_dim), + dtype=dtype, + device=get_accelerator().current_device()) + + reference_result = baseline_implementation(hidden_states, seq_lens) + + kernel = RaggedLogitsGather(model_dim, dtype) + output = torch.empty_like(reference_result) + kernel(output, hidden_states, batch) + + assert allclose(output, reference_result) diff --git a/tests/unit/inference/kernels/ragged_ops/test_moe_gather.py b/tests/unit/inference/kernels/ragged_ops/test_moe_gather.py new file mode 100644 index 000000000000..5fa375b49c19 --- /dev/null +++ b/tests/unit/inference/kernels/ragged_ops/test_moe_gather.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.ragged_ops import ( + MoEGather, + MoEScatter, + RaggedTop1Gating, +) +from .ragged_testing_utils import build_simple_batch +""" +For simplicity's sake, these tests do rely on ``RaggedTop1Gating`` and +``MoEScatter`` to produce correct inputs. If either of these kernels is broken +these tests will fail, so double check the unit test results there before +debugging here. +""" + + +def build_inputs(n_tokens, n_experts, do_padding): + + assert n_tokens <= 2048, "This test will break if n_tokens > 2048" + + # Sequence composition shouldn't matter here + batch = build_simple_batch([n_tokens], padding=do_padding) + + logits = torch.randn((batch.tensor_toks, n_experts), + dtype=torch.float16, + device=get_accelerator().current_device()) + + # This will make each token's value equal to its index. NOTE: This will break for + # tokens with index > 2048. 
+ hidden_states = torch.arange(batch.tensor_toks, dtype=torch.float16, + device=get_accelerator().current_device()).repeat_interleave(4096, dim=0).reshape( + batch.tensor_toks, 4096).contiguous() + + gate = RaggedTop1Gating(DtypeEnum.fp16) + + # Gating outputs + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((batch.tensor_toks, ), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((batch.tensor_toks, ), + dtype=torch.int32, + device=get_accelerator().current_device()) + expert_offset = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + # Scatter outputs + moe_input = torch.empty((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + expert_cumsum = torch.empty((n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) + mapped_slots = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + + scatter = MoEScatter(DtypeEnum.fp16, 4096) + scatter(moe_input, expert_cumsum, mapped_slots, hidden_states, expert_counts, expert_assignment, expert_offset) + + return batch, moe_input, scores, mapped_slots, expert_counts + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, n_experts", [(13, 64), (278, 64), (1977, 64)]) +@pytest.mark.parametrize("do_padding", [True, False]) +def test_moe_gather(n_tokens, n_experts, do_padding): + + batch, moe_input, scores, mapped_slots, expert_counts = build_inputs(n_tokens, n_experts, do_padding) + + output = torch.randn((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + + gather = MoEGather(DtypeEnum.fp16, 4096) + gather(output, moe_input, scores, mapped_slots, expert_counts) + + for token_idx in range(n_tokens): + assert torch.equal( + output[token_idx], + torch.full((4096, ), + token_idx * scores[token_idx], + dtype=torch.float16, + device=get_accelerator().current_device())) diff --git a/tests/unit/inference/kernels/ragged_ops/test_moe_scatter.py b/tests/unit/inference/kernels/ragged_ops/test_moe_scatter.py new file mode 100644 index 000000000000..4ca051410c1c --- /dev/null +++ b/tests/unit/inference/kernels/ragged_ops/test_moe_scatter.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.ragged_ops import MoEScatter, RaggedTop1Gating +from .ragged_testing_utils import build_simple_batch +""" +For simplicity's sake, these tests do rely on ``RaggedTop1Gating`` to produce correct +inputs. If ``RaggedTop1Gating`` is broken, these tests will fail, so double check +the unit test results there before debugging here. 
+""" + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens, n_experts", [(13, 64), (278, 64), (1977, 64)]) +@pytest.mark.parametrize("do_padding", [True, False]) +def test_moe_scatter(n_tokens, n_experts, do_padding): + + # Sequence composition shouldn't matter here + batch = build_simple_batch([n_tokens], padding=do_padding) + + logits = torch.randn((batch.tensor_toks, n_experts), + dtype=torch.float16, + device=get_accelerator().current_device()) + + # This will make each token's value equal to its index. NOTE: This will break for + # tokens with index > 2048. + hidden_states = torch.arange(batch.tensor_toks, dtype=torch.float16, + device=get_accelerator().current_device()).repeat_interleave(4096, dim=0).reshape( + batch.tensor_toks, 4096).contiguous() + + gate = RaggedTop1Gating(DtypeEnum.fp16) + + # Gating outputs + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((batch.tensor_toks, ), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((batch.tensor_toks, ), + dtype=torch.int32, + device=get_accelerator().current_device()) + expert_offset = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + # Scatter outputs + moe_input = torch.empty((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + expert_cumsum = torch.empty((n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) + mapped_slots = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + + scatter = MoEScatter(DtypeEnum.fp16, 4096) + scatter(moe_input, expert_cumsum, mapped_slots, hidden_states, expert_counts, expert_assignment, expert_offset) + assert torch.equal(expert_cumsum, torch.cumsum(expert_counts, dim=0).to(torch.int64)) + + for token_idx in range(batch.tensor_toks): + if token_idx < n_tokens: + expert_idx = expert_assignment[token_idx].item() + if expert_idx == 0: + expert_cumsum_val = 0 + else: + expert_cumsum_val = expert_cumsum[expert_idx - 1] + offset = expert_offset[token_idx] + total_offset = offset + expert_cumsum_val + + assert total_offset == mapped_slots[token_idx].item() + assert torch.equal(moe_input[total_offset], hidden_states[token_idx]) + else: + assert mapped_slots[token_idx].item() == -1 + + assert expert_cumsum[-1] == n_tokens diff --git a/tests/unit/inference/kernels/ragged_ops/test_ragged_embed.py b/tests/unit/inference/kernels/ragged_ops/test_ragged_embed.py new file mode 100644 index 000000000000..94f3f143274e --- /dev/null +++ b/tests/unit/inference/kernels/ragged_ops/test_ragged_embed.py @@ -0,0 +1,177 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Optional, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.kernels.ragged_ops import RaggedEmbeddingKernel +from ...inference_test_utils import allclose, get_dtypes +from .ragged_testing_utils import build_batch_and_manager + + +def baseline_implementation(token_ids: torch.Tensor, + embedding_table: torch.Tensor, + unpadded_size: int, + positional_embedding_table: Optional[torch.Tensor] = None, + positional_ids: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Baseline implementation for our ragged embedding kernel. 
+    """
+    if unpadded_size == token_ids.shape[0]:
+        token_embed = torch.nn.functional.embedding(token_ids, embedding_table)
+
+        if positional_embedding_table is not None:
+            pos_embed = torch.nn.functional.embedding(positional_ids, positional_embedding_table)
+            token_embed += pos_embed
+        return token_embed
+    else:
+        real_token_ids = token_ids[:unpadded_size]
+        output = torch.empty((token_ids.shape[0], embedding_table.shape[1]),
+                             dtype=embedding_table.dtype,
+                             device=get_accelerator().current_device())
+        unpadded_output = torch.nn.functional.embedding(real_token_ids, embedding_table)
+
+        # Positional embeddings only cover the real (unpadded) tokens, since the padding is simulated.
+        if positional_embedding_table is not None:
+            pos_embed = torch.nn.functional.embedding(positional_ids, positional_embedding_table)
+            unpadded_output += pos_embed
+
+        output[:unpadded_size] = unpadded_output
+        return output
+
+
+def _ragged_embed_test_helper(sequence_config: List[Tuple[int, int]],
+                              embed_dtype: torch.dtype,
+                              token_dtype: torch.dtype,
+                              embed_dim: int,
+                              vocab_size: int,
+                              do_padding: bool = False,
+                              pos_embed_size: int = -1,
+                              pos_embed_offset: int = 0) -> None:
+    """
+    Helper for the embedding tests that limits the number of parameter permutations to run.
+
+    Params:
+        embed_dim (int): Model dimension
+        vocab_size (int): Leading dimension on embedding weight
+        pos_embed_size (int): Size of positional embedding. If negative, no positional embedding
+            is used.
+        pos_embed_offset (int): Offset for positional embedding. Effectively, the raw offsets
+            of a token into a sequence are offset by this amount into the embedding matrix
+            (i.e. the shape of the positional embeddings is (pos_embed_size + pos_embed_offset, embed_dim)).
+    """
+    device = get_accelerator().current_device()
+
+    # Heads/Block size are irrelevant here but need something.
+    batch, _, _ = build_batch_and_manager(sequence_config, 64, 16, 64, vocab_range=vocab_size, padding=do_padding)
+
+    embedding_table = torch.randn((vocab_size, embed_dim), dtype=embed_dtype, device=device)
+
+    if pos_embed_size > 0:
+        pos_embedding_table = torch.randn((pos_embed_size + pos_embed_offset, embed_dim),
+                                          dtype=embed_dtype,
+                                          device=device)
+        positional_ids = torch.cat([
+            torch.arange(start_idx, start_idx + seq_len, dtype=token_dtype, device=device)
+            for seq_len, start_idx in sequence_config
+        ]) + pos_embed_offset
+    else:
+        pos_embedding_table = None
+        positional_ids = None
+
+    baseline_output = baseline_implementation(batch.input_ids().to(token_dtype), embedding_table, batch.current_tokens,
+                                              pos_embedding_table, positional_ids)
+
+    kernel = RaggedEmbeddingKernel(embed_dtype, token_dtype, embed_dim)
+    output = torch.empty_like(baseline_output)
+
+    kernel(output,
+           batch,
+           embedding_table,
+           position_embed_weight=pos_embedding_table,
+           position_embed_offset=pos_embed_offset)
+
+    if do_padding:
+        assert output.shape[0] != batch.current_tokens
+
+    assert allclose(output[:batch.current_tokens], baseline_output[:batch.current_tokens])
+
+
+@pytest.mark.inference_v2_ops
+@pytest.mark.parametrize('token_dtype', [torch.int32, torch.int64])
+@pytest.mark.parametrize('embed_dtype', get_dtypes())
+def test_dtype_permutations(token_dtype: torch.dtype, embed_dtype: torch.dtype) -> None:
+    """
+    Validate (on a single problem size) that the kernel's support for different data types
+    is correct.
+ """ + embed_dim = 4096 + vocab_size = 50304 + + _ragged_embed_test_helper([(256, 0)], embed_dtype, token_dtype, embed_dim, vocab_size) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('vocab_size, embed_dim', [(1024, 1024), (32000, 5120), (50304, 6144)]) +def test_problem_size_permutations(vocab_size: int, embed_dim: int) -> None: + """ + Validate on wider range of problem sizes. + """ + + _ragged_embed_test_helper([(256, 0)], torch.float16, torch.int32, embed_dim, vocab_size) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('seq_lens', [[128, 64, 192, 32], [57, 112, 63, 89, 1, 1, 1, 1]]) +@pytest.mark.parametrize('do_padding', [True, False]) +def test_complex_sequences(seq_lens: List[int], do_padding: bool) -> None: + """ + Validate on different ragged batch construction scenarios. + """ + embed_dim = 4096 + vocab_size = 50304 + + _ragged_embed_test_helper([(seq_len, 0) for seq_len in seq_lens], + torch.float16, + torch.int32, + embed_dim, + vocab_size, + do_padding=do_padding) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("seq_lens", [[(256, 0)], [(256, 0), + (128, 0)], [(256, 0), (128, 0), + (64, 0)], [(1, 877), (619, 0), (213, 372), (1, 45)]]) +def test_positional_embedding(seq_lens: List[Tuple[int, int]]) -> None: + """ + Validate that positional embedding works correctly. + """ + embed_dim = 4096 + vocab_size = 50304 + + _ragged_embed_test_helper(seq_lens, torch.float16, torch.int32, embed_dim, vocab_size, pos_embed_size=2048) + + +@pytest.mark.inference_v2_ops +def test_positional_embedding_offset() -> None: + """ + Validate that positional embedding works correctly with an offset. + """ + embed_dim = 4096 + vocab_size = 50304 + seq_config = [(1, 877), (619, 0), (213, 372), (1, 45)] + + _ragged_embed_test_helper(seq_config, + torch.float16, + torch.int32, + embed_dim, + vocab_size, + pos_embed_size=2048, + pos_embed_offset=2) diff --git a/tests/unit/inference/kernels/ragged_ops/test_top_1_gating.py b/tests/unit/inference/kernels/ragged_ops/test_top_1_gating.py new file mode 100644 index 000000000000..96bf28eea7ad --- /dev/null +++ b/tests/unit/inference/kernels/ragged_ops/test_top_1_gating.py @@ -0,0 +1,120 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import torch.nn.functional as F + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import DtypeEnum +from deepspeed.inference.v2.kernels.ragged_ops import RaggedTop1Gating +from .ragged_testing_utils import build_simple_batch +from ...inference_test_utils import allclose + + +def _test_single_mapping_helper(n_tokens: int, + n_experts: int, + assigned_expert: int, + logit_fill: float = 0.0, + match_fill: float = 1.0) -> None: + logits = torch.full((n_tokens, n_experts), + logit_fill, + dtype=torch.float16, + device=get_accelerator().current_device()) + + logits[:, assigned_expert] = match_fill + + gate = RaggedTop1Gating(DtypeEnum.fp16) + + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + assert expert_counts[assigned_expert] == n_tokens + assert torch.all(expert_assignment == assigned_expert) + assert torch.unique(expert_offset).shape[0] == n_tokens + assert allclose(scores, F.softmax(logits.float(), dim=1)[:, assigned_expert]) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('n_tokens, n_experts', [(1, 16), (17, 16), (32, 128), (89, 128), (433, 128)]) +def test_single_mapping_gating(n_tokens: int, n_experts: int) -> None: + """ + Evaluate our expert stacking behavior in complete isolation. This ensures all tokens + mapped to the same expert are getting unique offsets and identical scores. + """ + assigned_expert = 13 + _test_single_mapping_helper(n_tokens, n_experts, assigned_expert) + + +@pytest.mark.inference_v2_ops +def test_negative_logits(): + """ + Ensure that scores/values are propagated correctly when all the logits are negative. An + earlier implementation of the scoring would return NaN for this case. + """ + _test_single_mapping_helper(128, 32, 13, logit_fill=-2.0, match_fill=-1.0) + + +@pytest.mark.inference_v2_ops +def test_determinism(): + """ + Ensure that ties between two logits are broken deterministically. This is essential when + the gating is distributed across multiple devices that need to map the same token to + the same expert. 
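+
+    In the tie constructed below, experts 19 and 26 receive identical logits; the assertions
+    expect the lower-indexed expert (19) to win on every iteration.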
+ """ + + n_tokens = 512 + n_experts = 64 + + logits = torch.zeros((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + + logits[:, 19] = 1.0 + logits[:, 26] = 1.0 + + gate = RaggedTop1Gating(DtypeEnum.fp16) + + for _ in range(1024): + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + assert expert_counts[19] == n_tokens + assert expert_counts[26] == 0 + assert torch.all(expert_assignment == 19) + assert torch.unique(expert_offset).shape[0] == n_tokens + assert allclose(scores, F.softmax(logits.float(), dim=1)[:, 19]) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('n_tokens, n_experts', [(1, 16), (17, 16), (32, 128), (89, 128), (433, 2)]) +def test_score_accuracy(n_tokens: int, n_experts: int) -> None: + """ + Validate expert scores are correct. + """ + logits = torch.randn((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + + gate = RaggedTop1Gating(DtypeEnum.fp16) + + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + + ref_scores = F.softmax(logits.float(), dim=1).max(dim=1).values + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + assert allclose(scores, ref_scores) + assert expert_counts.sum() == n_tokens diff --git a/tests/unit/inference/model_implementations/__init__.py b/tests/unit/inference/model_implementations/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/model_implementations/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/model_implementations/parameters/__init__.py b/tests/unit/inference/model_implementations/parameters/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/model_implementations/parameters/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/model_implementations/parameters/test_layer_inheritance.py b/tests/unit/inference/model_implementations/parameters/test_layer_inheritance.py new file mode 100644 index 000000000000..20803e53a320 --- /dev/null +++ b/tests/unit/inference/model_implementations/parameters/test_layer_inheritance.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + +from .utils import validate_device, SimpleParam, DummyInferenceModel + + +class ParentLayer(LayerContainer): + """ + A layer that has a dependency on a simple parameter. + """ + + param_1: SimpleParam + + +class ChildLayer(ParentLayer): + """ + A layer that inherits from another layer. + """ + + param_2: SimpleParam + + +@pytest.mark.inference_v2 +def test_layer_inheritance(): + inference_model = DummyInferenceModel() + + multi_param_layer = ChildLayer(inference_model) + + assert multi_param_layer.n_params == 2 + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_1.param = torch.ones(16, 16) + + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_2.param = torch.full((16, 16), 2.0) + + assert multi_param_layer.is_initialized is True + assert isinstance(multi_param_layer.param_1, torch.Tensor) + assert isinstance(multi_param_layer.param_2, torch.Tensor) + + validate_device(multi_param_layer.param_1) + validate_device(multi_param_layer.param_2) diff --git a/tests/unit/inference/model_implementations/parameters/test_mapping.py b/tests/unit/inference/model_implementations/parameters/test_mapping.py new file mode 100644 index 000000000000..3c74d7a0479a --- /dev/null +++ b/tests/unit/inference/model_implementations/parameters/test_mapping.py @@ -0,0 +1,165 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.allocator import on_device +from deepspeed.inference.v2.model_implementations.parameter_base import ParameterBase, ParamList +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + + +class MultiDependencyContainer(ParameterBase): + + dependency_1: torch.Tensor + + dependency_2: torch.Tensor + + @on_device + def finalize(self) -> torch.Tensor: + return torch.cat([self.dependency_1, self.dependency_2]) + + +class ListDependencyContainer(ParameterBase): + + dependencies: ParamList("list_items") # noqa: F821 + + @on_device + def finalize(self) -> torch.Tensor: + return torch.cat(tuple(self.dependencies)) + + +class MappingLayer(LayerContainer): + PARAM_MAPPING = { + "model.val.item.d_1": "multi_depend.dependency_1", + "model.val.item.d_2": "multi_depend.dependency_2", + "model.list_vals.*.d": "list_depend.dependencies" + } + + multi_depend: MultiDependencyContainer + + list_depend: ListDependencyContainer + + +class SubMappingLayer(MappingLayer): + PARAM_MAPPING = { + "model.val.item2.d_1": "multi_depend2.dependency_1", + "model.val.item2.d_2": "multi_depend2.dependency_2", + } + + multi_depend2: MultiDependencyContainer + + +class DoubleMappingLayer(LayerContainer): + PARAM_MAPPING = { + "model.val.item.d_1": ["multi_depend.dependency_1", "multi_depend.dependency_2"], + } + + multi_depend: MultiDependencyContainer + + +class InferenceModel: + + @property + def list_items(self) -> int: + return 16 + + +@pytest.mark.inference_v2 +def test_mapping_syntax(): + model = InferenceModel() + + mapping_layer = MappingLayer(model) + + mapping_layer.set_dependency("model.val.item.d_1", torch.ones(1)) + mapping_layer.set_dependency("model.val.item.d_2", torch.ones(1) * 2) + + assert isinstance(mapping_layer.multi_depend, torch.Tensor) + + for i in range(16): + mapping_layer.set_dependency(f"model.list_vals.{i}.d", 
torch.ones(1) * i) + if i != 16 - 1: + assert mapping_layer.is_initialized == False + + assert isinstance(mapping_layer.list_depend, torch.Tensor) + assert mapping_layer.is_initialized == True + + +@pytest.mark.inference_v2 +def test_sub_mapping_syntax(): + model = InferenceModel() + + mapping_layer = SubMappingLayer(model) + + mapping_layer.set_dependency("model.val.item.d_1", torch.ones(1)) + mapping_layer.set_dependency("model.val.item.d_2", torch.ones(1) * 2) + + assert isinstance(mapping_layer.multi_depend, torch.Tensor) + + mapping_layer.set_dependency("model.val.item2.d_1", torch.ones(1)) + mapping_layer.set_dependency("model.val.item2.d_2", torch.ones(1) * 2) + + assert isinstance(mapping_layer.multi_depend2, torch.Tensor) + + # We want to check into double digits to make sure that this isn't specific + # to single difit indexing. + for i in range(16): + mapping_layer.set_dependency(f"model.list_vals.{i}.d", torch.ones(1) * i) + if i != 16 - 1: + assert mapping_layer.is_initialized == False + + assert isinstance(mapping_layer.list_depend, torch.Tensor) + assert mapping_layer.is_initialized == True + + +@pytest.mark.inference_v2 +def test_double_mapping_syntax(): + model = InferenceModel() + + mapping_layer = DoubleMappingLayer(model) + mapping_layer.set_dependency("model.val.item.d_1", torch.ones(1)) + + # The single parameter setting should immediately make the parameter finalized + # and the whole layer initialized. + assert isinstance(mapping_layer.multi_depend, torch.Tensor) + assert mapping_layer.is_initialized == True + + +@pytest.mark.inference_v2 +def test_insufficient_mapping_syntax(): + """ + In the above example, we don't have a mapping for `multi_depend2.dependency_2`. + """ + + with pytest.raises(ValueError): + + class InsuffienctMappingLayer(LayerContainer): + PARAM_MAPPING = { + "model.val.item.d_1": "multi_depend1.dependency_1", + "model.val.item.d_2": "multi_depend1.dependency_2", + "model.val.item2.d_1": "multi_depend2.dependency_1", + } + + multi_depend1: MultiDependencyContainer + + multi_depend2: MultiDependencyContainer + + +@pytest.mark.inference_v2 +def test_unknown_target_mapping_syntax(): + """ + In the above example, `multi_depend_unknown` does not exist + """ + + with pytest.raises(ValueError): + + class UnknownTargetMappingLayer(LayerContainer): + PARAM_MAPPING = { + "model.val.item.d_1": "multi_depend1.dependency_1", + "model.val.item.d_2": "multi_depend1.dependency_2", + "model.val.item2.d_1": "multi_depend_unknown.dependency_1", + } + + multi_depend: MultiDependencyContainer diff --git a/tests/unit/inference/model_implementations/parameters/test_multi_parameter_layer.py b/tests/unit/inference/model_implementations/parameters/test_multi_parameter_layer.py new file mode 100644 index 000000000000..6bfc04e97c30 --- /dev/null +++ b/tests/unit/inference/model_implementations/parameters/test_multi_parameter_layer.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + +from .utils import validate_device, SimpleParam, ListParam, DummyInferenceModel + + +class MultiParameterLayer(LayerContainer): + """ + Two dependencies, both of which are simple parameters. + """ + + param_1: SimpleParam + + param_2: SimpleParam + + +class MixedMultiParameterLayer(LayerContainer): + """ + Two dependencies, one of which is a simple parameter, the other is a list parameter. 
+ """ + + param_1: SimpleParam + + param_2: ListParam + + +@pytest.mark.inference_v2 +def test_multi_parameter_layer(): + inference_model = DummyInferenceModel() + + multi_param_layer = MultiParameterLayer(inference_model) + + assert multi_param_layer.n_params == 2 + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_1.param = torch.ones(16, 16) + + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_2.param = torch.full((16, 16), 2.0) + + assert multi_param_layer.is_initialized is True + assert isinstance(multi_param_layer.param_1, torch.Tensor) + assert isinstance(multi_param_layer.param_2, torch.Tensor) + + validate_device(multi_param_layer.param_1) + validate_device(multi_param_layer.param_2) + + +@pytest.mark.inference_v2 +def test_mixed_multi_parameter_layer(): + inference_model = DummyInferenceModel() + + mixed_multi_param_layer = MixedMultiParameterLayer(inference_model) + + assert mixed_multi_param_layer.n_params == 2 + assert mixed_multi_param_layer.is_initialized is False + + mixed_multi_param_layer.param_2.params[1] = torch.full((16, 16), 2.0) + assert mixed_multi_param_layer.is_initialized is False + assert not isinstance(mixed_multi_param_layer.param_2, torch.Tensor) + + mixed_multi_param_layer.param_1.param = torch.ones(16, 16) + assert mixed_multi_param_layer.is_initialized is False + assert isinstance(mixed_multi_param_layer.param_1, torch.Tensor) + + validate_device(mixed_multi_param_layer.param_1) + + mixed_multi_param_layer.param_2.params[0] = torch.full((16, 16), 2.0) + + assert mixed_multi_param_layer.is_initialized is True + assert isinstance(mixed_multi_param_layer.param_2, torch.Tensor) + + validate_device(mixed_multi_param_layer.param_2) + + +class NoCopyInferenceModel: + + @property + def num_dependencies(self) -> int: + return 2 + + def transform(self, param: torch.Tensor) -> torch.Tensor: + return param + + +@pytest.mark.inference_v2 +def test_device_validation(): + inference_model = NoCopyInferenceModel() + + multi_param_layer = MultiParameterLayer(inference_model) + + assert multi_param_layer.n_params == 2 + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_1.param = torch.ones(16, 16) + + assert multi_param_layer.is_initialized is False + + multi_param_layer.param_2.param = torch.full((16, 16), 2.0) + + with pytest.raises(RuntimeError): + # NoCopyInference model did not copy the parameters, so the device validation should fail. + assert multi_param_layer.is_initialized is True diff --git a/tests/unit/inference/model_implementations/parameters/test_parameter_list.py b/tests/unit/inference/model_implementations/parameters/test_parameter_list.py new file mode 100644 index 000000000000..42edd90595fa --- /dev/null +++ b/tests/unit/inference/model_implementations/parameters/test_parameter_list.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.inference.v2.model_implementations.parameter_base import ParameterBase, ParamList +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer +from deepspeed.inference.v2.model_implementations.common_parameters import * +from deepspeed.inference.v2.allocator import on_device + +from .utils import validate_device + + +class SimpleMoELayer(LayerContainer): + + moe_mlp_1: UnfusedMoEMLP1Parameter + + +class DummyInferenceModel: + + def __init__(self, experts_per_rank: int) -> None: + self._num_experts = experts_per_rank + + @property + def num_experts(self) -> int: + return self._num_experts + + @on_device + def transform_moe_mlp_1_param(self, param: torch.Tensor) -> torch.Tensor: + return param + + +@pytest.mark.inference_v2 +def test_simple_moe_layer(): + + inference_model = DummyInferenceModel(experts_per_rank=2) + + simple_moe_layer = SimpleMoELayer(inference_model) + + assert simple_moe_layer.moe_mlp_1.experts[0] is None + assert simple_moe_layer.moe_mlp_1.experts[1] is None + + # Set the first expert + simple_moe_layer.moe_mlp_1.experts[0] = torch.zeros(16, 16) + + assert simple_moe_layer.moe_mlp_1.experts[0] is not None + assert simple_moe_layer.moe_mlp_1.experts[1] is None + + assert not simple_moe_layer.is_initialized + + # Set the second expert + simple_moe_layer.moe_mlp_1.experts[1] = torch.ones(16, 16) + + # We have all the experts, so the layer should be initialized + assert simple_moe_layer.is_initialized + assert isinstance(simple_moe_layer.moe_mlp_1, torch.Tensor) + + validate_device(simple_moe_layer.moe_mlp_1) + + +""" +Check that we can mix the number of elements in lists in the same context and have that +be tracked correctly. +""" + + +class CustomListParam1(ParameterBase): + + deps: ParamList("attr_1") + + +class CustomListParam2(ParameterBase): + + deps: ParamList("attr_2") + + +class MixedLayer(LayerContainer): + + list_1: CustomListParam1 + list_2: CustomListParam2 + + +class MixedInferenceModel: + + @property + def attr_1(self) -> int: + return 1 + + @property + def attr_2(self) -> int: + return 2 + + +@pytest.mark.inference_v2 +def test_mixed_param_lists(): + model = MixedInferenceModel() + + layer = MixedLayer(model) + + assert layer.list_1.deps.n_params == 1 + assert layer.list_2.deps.n_params == 2 diff --git a/tests/unit/inference/model_implementations/parameters/utils.py b/tests/unit/inference/model_implementations/parameters/utils.py new file mode 100644 index 000000000000..0d2cbb27d40e --- /dev/null +++ b/tests/unit/inference/model_implementations/parameters/utils.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.allocator import on_device +from deepspeed.inference.v2.model_implementations.parameter_base import ParameterBase, ParametrizedList + + +class SimpleParam(ParameterBase): + """ + Parameter with single dependency. + """ + + param: torch.Tensor + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform(self.param) + + +class SimpleParametrizedList(ParametrizedList): + """ + Parameter list based on `num_dependencies` attribute. + """ + + count_attr: str = "num_dependencies" + + +class ListParam(ParameterBase): + """ + Parameter with list dependency. 
+ + NOTE: This uses the tuple workaround for the `ParametrizedList` class + as described in the docstring of `ParametrizedList`. + """ + + params: SimpleParametrizedList + + def finalize(self) -> torch.Tensor: + return self.inference_model.transform(torch.cat(tuple(self.params))) + + +class DummyInferenceModel: + + @property + def num_dependencies(self) -> int: + return 2 + + @on_device + def transform(self, param: torch.Tensor) -> torch.Tensor: + return param + + +def validate_device(tensor: torch.Tensor): + assert tensor.device == torch.device(get_accelerator().current_device()) diff --git a/tests/unit/inference/model_implementations/sharding/__init__.py b/tests/unit/inference/model_implementations/sharding/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/model_implementations/sharding/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/model_implementations/sharding/test_attn_out_sharding.py b/tests/unit/inference/model_implementations/sharding/test_attn_out_sharding.py new file mode 100644 index 000000000000..850c4c24fde6 --- /dev/null +++ b/tests/unit/inference/model_implementations/sharding/test_attn_out_sharding.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.model_implementations.sharding import * + +# None of the logic should be dependent on head size. +HEAD_SIZE = 64 + + +def fill_with_head_ids(head_size: int, n_heads: int) -> torch.Tensor: + """ + Fills a tensor with the associated head ids. All columns should have the same value. + """ + head_ids = torch.arange(n_heads, dtype=torch.half, device=get_accelerator().current_device()) + + head_ids = head_ids.repeat_interleave(head_size).repeat(head_size * n_heads).reshape(n_heads * head_size, -1) + return head_ids + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("n_heads, n_shards", [(1, 1), (8, 4), (32, 8)]) +def test_mha_even_sharding(n_heads: int, n_shards: int): + """ + Even head sharding for MHA. + + Args: + n_heads (int): The number QKV heads. + n_shards (int): The number of shards to test for. + """ + param = fill_with_head_ids(HEAD_SIZE, n_heads) + + n_local_heads = n_heads // n_shards + sharded_shape = (HEAD_SIZE * n_heads, HEAD_SIZE * n_local_heads) + + for shard_rank in range(n_shards): + sharded_param = shard_attn_out_param(param, shard_rank, n_shards, HEAD_SIZE) + n_heads_local_q, _ = get_local_heads(shard_rank, n_shards, n_heads) + + assert sharded_param.shape[-1] == HEAD_SIZE * n_heads_local_q + assert sharded_param.shape == sharded_shape + + heads = torch.chunk(sharded_param, n_local_heads, dim=1) + + for i, head in enumerate(heads): + assert torch.all(head == i + shard_rank * n_local_heads) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("n_heads, n_shards", [(3, 2), (20, 8)]) +def test_mha_unbalanced_sharding(n_heads: int, n_shards: int): + """ + Unbalanced head sharding for MHA. + + Args: + n_heads (int): The number QKV heads. + n_shards (int): The number of shards to test for. 
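+
+    For these configurations the head counts across shards are expected to differ by exactly
+    one, with every global head appearing exactly once (see the assertions below).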
+ """ + param = fill_with_head_ids(HEAD_SIZE, n_heads) + + max_heads = 0 + min_heads = n_heads + seen_heads = set() + total_heads = 0 + + for shard_rank in range(n_shards): + sharded_param = shard_attn_out_param(param, shard_rank, n_shards, HEAD_SIZE) + n_heads_local_q, _ = get_local_heads(shard_rank, n_shards, n_heads) + + assert sharded_param.shape[-1] == HEAD_SIZE * n_heads_local_q + + n_local_heads = sharded_param.shape[1] // HEAD_SIZE + total_heads += n_local_heads + max_heads = max(max_heads, n_local_heads) + min_heads = min(min_heads, n_local_heads) + + for i in range(n_local_heads): + head_ids = torch.unique_consecutive(sharded_param[:, i * HEAD_SIZE:(i + 1) * HEAD_SIZE]) + assert len(head_ids) == 1 + seen_heads.add(head_ids.item()) + + assert max_heads == min_heads + 1 + assert total_heads == n_heads + assert len(seen_heads) == n_heads + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("n_heads_q, n_heads_kv, n_shards", [(20, 4, 8)]) +def test_gqa_uneven_sharding(n_heads_q: int, n_heads_kv: int, n_shards: int): + """ + We only test the uneven GQA test case because even GQA shards the attention output + in the exact same manner as MHA. + + Args: + n_heads_q (int): The number of query heads. + n_heads_kv (int): The number of key/value heads. + n_shards (int): The number of shards to test for. + """ + param = fill_with_head_ids(HEAD_SIZE, n_heads_q) + + min_heads = n_heads_q + max_heads = 0 + seen_heads = set() + total_heads = 0 + + for shard_rank in range(n_shards): + sharded_param = shard_attn_out_param(param, shard_rank, n_shards, HEAD_SIZE, n_heads_q, n_heads_kv) + n_heads_local_q, _ = get_local_heads(shard_rank, n_shards, n_heads_q, n_heads_kv) + + assert sharded_param.shape[-1] == HEAD_SIZE * n_heads_local_q + + n_local_heads = sharded_param.shape[1] // HEAD_SIZE + total_heads += n_local_heads + max_heads = max(max_heads, n_local_heads) + min_heads = min(min_heads, n_local_heads) + + for i in range(n_local_heads): + head_id = torch.unique_consecutive(sharded_param[:, i * HEAD_SIZE:(i + 1) * HEAD_SIZE]) + assert len(head_id) == 1 + seen_heads.add(head_id.item()) + + assert max_heads == min_heads + 1 + assert total_heads == n_heads_q + assert len(seen_heads) == n_heads_q diff --git a/tests/unit/inference/model_implementations/sharding/test_mlp_sharding.py b/tests/unit/inference/model_implementations/sharding/test_mlp_sharding.py new file mode 100644 index 000000000000..aac7e5391d8f --- /dev/null +++ b/tests/unit/inference/model_implementations/sharding/test_mlp_sharding.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.model_implementations.sharding import * + + +def round_up_to_256(x: int) -> int: + """ + Round up to the nearest multiple of 256. 
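+
+    For example, ``round_up_to_256(5461) == 5632``.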
+    """
+    return ((x + 255) // 256) * 256
+
+
+def make_params(model_dim: int, ffn_multiplier: int, n_experts: int, gated: bool = False) -> tuple:
+    """
+    Build MLP weights/biases whose values encode their index along the intermediate dimension,
+    so the sharding tests can verify which neurons landed on which shard. Returns the tuple
+    (mlp_1_w, mlp_1_b, mlp_2_w, mlp_2_b), expanded along a leading expert dimension when
+    n_experts > 1.
+    """
+    if gated:
+        mlp_1_intermediate = round_up_to_256(int(model_dim * ffn_multiplier * 4 / 3))
+        mlp_2_intermediate = mlp_1_intermediate // 2
+    else:
+        mlp_1_intermediate = ffn_multiplier * model_dim
+        mlp_2_intermediate = ffn_multiplier * model_dim
+
+    mlp_1_shared_dim = torch.arange(mlp_1_intermediate, dtype=torch.float32, device=get_accelerator().current_device())
+
+    mlp_1_w = mlp_1_shared_dim.repeat_interleave(model_dim).reshape(mlp_1_intermediate, model_dim)
+    mlp_1_b = mlp_1_shared_dim
+
+    mlp_2_shared_dim = torch.arange(mlp_2_intermediate, dtype=torch.float32, device=get_accelerator().current_device())
+    mlp_2_w = mlp_2_shared_dim.repeat(model_dim).reshape(model_dim, mlp_2_intermediate)
+    mlp_2_b = torch.ones(model_dim, dtype=torch.float32, device=get_accelerator().current_device())
+
+    if n_experts > 1:
+        mlp_1_w = mlp_1_w.expand(n_experts, -1, -1)
+        mlp_1_b = mlp_1_b.expand(n_experts, -1)
+        mlp_2_w = mlp_2_w.expand(n_experts, -1, -1)
+        mlp_2_b = mlp_2_b.expand(n_experts, -1)
+
+    return (mlp_1_w, mlp_1_b, mlp_2_w, mlp_2_b)
+
+
+@pytest.mark.inference_v2
+@pytest.mark.parametrize("model_dim, ffn_multiplier, n_shards", [(1024, 4, 1), (1024, 4, 8), (1024, 4, 6)])
+@pytest.mark.parametrize("n_experts", [1, 16])
+def test_even_ffn_sharding(model_dim: int, ffn_multiplier: int, n_shards: int, n_experts: int):
+    """
+    FFN sharding tends to be much simpler than attention sharding since it works on larger granularities.
+    While the test case of (1024, 4, 6) is not a use case we're likely to see, this does ensure that
+    the sharding logic will round correctly for the alignments we care about.
+    """
+    mlp_1_w, mlp_1_b, mlp_2_w, mlp_2_b = make_params(model_dim, ffn_multiplier, n_experts)
+
+    total_ffn_dim = model_dim * ffn_multiplier
+    mapped_neurons = 0
+
+    is_moe = n_experts > 1
+
+    for shard_rank in range(n_shards):
+        shard_1_w = shard_mlp_1_param(mlp_1_w, shard_rank, n_shards, is_moe=is_moe)
+        shard_1_b = shard_mlp_1_param(mlp_1_b, shard_rank, n_shards, is_moe=is_moe)
+        shard_2_w = shard_mlp_2_param(mlp_2_w, shard_rank, n_shards, is_moe=is_moe)
+        shard_2_b = shard_mlp_2_param(mlp_2_b, shard_rank, n_shards, is_moe=is_moe)
+
+        assert shard_1_w.shape[-2] == shard_2_w.shape[-1]
+        assert shard_1_w.shape[-2] % DEFAULT_SHARD_GRANULARITY == 0
+        assert shard_1_w.shape[-2] == shard_1_b.shape[-1]
+
+        mapped_neurons += shard_1_w.shape[-2]
+
+        if shard_rank != 0:
+            assert shard_2_b is None
+        else:
+            assert shard_2_b.shape[-1] == model_dim
+
+    assert mapped_neurons == total_ffn_dim
+
+
+@pytest.mark.inference_v2
+@pytest.mark.parametrize("model_dim, ffn_multiplier, n_shards", [(1024, 4, 1), (1024, 4, 8), (1024, 4, 6)])
+@pytest.mark.parametrize("n_experts", [1, 16])
+def test_gated_ffn_sharding(model_dim: int, ffn_multiplier: int, n_shards: int, n_experts: int):
+    """
+    Test the same cases assuming a gated regime.
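+    In the gated case mlp_1 is assumed to pack both the gate and value projections (as
+    reflected in ``make_params``), so each mlp_1 shard's intermediate dimension is expected
+    to be twice the mlp_2 shard's trailing dimension (the
+    ``shard_1_w.shape[-2] == shard_2_w.shape[-1] * 2`` assertion below).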
+ """ + mlp_1_w, mlp_1_b, mlp_2_w, mlp_2_b = make_params(model_dim, ffn_multiplier, n_experts, gated=True) + + total_ffn_dim = round_up_to_256(int(model_dim * ffn_multiplier * 4 / 3)) + mapped_neurons = 0 + + is_moe = n_experts > 1 + + for shard_rank in range(n_shards): + shard_1_w = shard_mlp_1_param(mlp_1_w, shard_rank, n_shards, gated=True, is_moe=is_moe) + shard_1_b = shard_mlp_1_param(mlp_1_b, shard_rank, n_shards, gated=True, is_moe=is_moe) + shard_2_w = shard_mlp_2_param(mlp_2_w, shard_rank, n_shards, is_moe=is_moe) + shard_2_b = shard_mlp_2_param(mlp_2_b, shard_rank, n_shards, is_moe=is_moe) + + assert shard_1_w.shape[-2] == shard_2_w.shape[-1] * 2 + assert shard_1_w.shape[-2] % DEFAULT_SHARD_GRANULARITY == 0 + assert shard_1_w.shape[-2] == shard_1_b.shape[-1] + + mapped_neurons += shard_1_w.shape[-2] + + if shard_rank != 0: + assert shard_2_b is None + else: + assert shard_2_b.shape[-1] == model_dim + + assert mapped_neurons == total_ffn_dim diff --git a/tests/unit/inference/model_implementations/sharding/test_qkv_sharding.py b/tests/unit/inference/model_implementations/sharding/test_qkv_sharding.py new file mode 100644 index 000000000000..9a1cb9c09c64 --- /dev/null +++ b/tests/unit/inference/model_implementations/sharding/test_qkv_sharding.py @@ -0,0 +1,251 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.model_implementations.sharding import * + + +def fill_with_head_ids(head_size: int, n_heads_q: int, n_heads_kv: Optional[int] = None) -> torch.Tensor: + """ + + """ + head_ids_q = torch.arange(n_heads_q, dtype=torch.half, device=get_accelerator().current_device()) + head_vals_q = head_ids_q.repeat_interleave(head_size * head_size * n_heads_q).reshape(n_heads_q * head_size, -1) + + if n_heads_kv is None: + return torch.cat([head_vals_q, head_vals_q, head_vals_q], dim=0) + + head_ids_k = torch.arange(n_heads_kv, dtype=torch.half, device=get_accelerator().current_device()) + head_vals_k = head_ids_k.repeat_interleave(head_size * head_size * n_heads_q).reshape(n_heads_kv * head_size, -1) + + return torch.cat([head_vals_q, head_vals_k, head_vals_k], dim=0) + + +def validate_inferred_shape(shard: torch.Tensor, head_size: int, n_local_q_heads: int, n_local_kv_heads: int): + """ + Validate that the leading dim of the shard is of the expected size and aligns with the sharding + logic for the attention computation itself. + """ + inferred_leading_dim = head_size * (n_local_q_heads + 2 * n_local_kv_heads) + assert shard.shape[0] == inferred_leading_dim + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads,n_shards", [(1, 1), (32, 1), (32, 8)]) +def test_even_mha_sharding(head_size: int, n_heads: int, n_shards: int): + """ + Test for MHA sharding. In these scenarios, we expect that each of the shards + should be the same size. 
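+    Concretely, every shard is expected to hold ``n_heads // n_shards`` query, key, and value
+    heads (see the shape assertion below).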
+ """ + param = fill_with_head_ids(head_size, n_heads) + + heads_per_shard = n_heads // n_shards + + for shard_rank in range(n_shards): + + shard = shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads, n_heads) + n_local_q_heads, n_local_kv_heads = get_local_heads(shard_rank, n_shards, n_heads, n_heads) + validate_inferred_shape(shard, head_size, n_local_q_heads, n_local_kv_heads) + + assert shard.shape == (3 * head_size * heads_per_shard, head_size * n_heads) + + heads = shard.chunk(heads_per_shard * 3, dim=0) + for i in range(heads_per_shard): + assert torch.all(heads[i] == i + shard_rank * heads_per_shard) + assert torch.all(heads[i + heads_per_shard] == i + shard_rank * heads_per_shard) + assert torch.all(heads[i + heads_per_shard * 2] == i + shard_rank * heads_per_shard) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads, n_shards", [(3, 2), (20, 8)]) +def test_unbalanced_mha_sharding(head_size: int, n_heads: int, n_shards: int): + """ + Test MHA sharding when the distribution of heads will not be equal across all ranks. + """ + param = fill_with_head_ids(head_size, n_heads) + + max_heads = 0 + min_heads = n_heads + total_heads = 0 + seen_heads = set() + + for shard_rank in range(n_shards): + shard = shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads, n_heads) + n_local_q_heads, n_local_kv_heads = get_local_heads(shard_rank, n_shards, n_heads, n_heads) + validate_inferred_shape(shard, head_size, n_local_q_heads, n_local_kv_heads) + + n_heads_in_shard = shard.shape[0] // head_size // 3 + + max_heads = max(max_heads, n_heads_in_shard) + min_heads = min(min_heads, n_heads_in_shard) + total_heads += n_heads_in_shard + + heads = shard.chunk(n_heads_in_shard * 3, dim=0) + + for local_head_id in range(n_heads_in_shard): + head_qkv = torch.cat([ + heads[local_head_id], heads[local_head_id + n_heads_in_shard], + heads[local_head_id + 2 * n_heads_in_shard] + ], + dim=0) + assert head_qkv.shape == (3 * head_size, head_size * n_heads) + + global_head_id = torch.unique_consecutive(head_qkv) + assert len(global_head_id) == 1 + + seen_heads.add(global_head_id.item()) + + assert max_heads - min_heads <= 1 + assert total_heads == n_heads + assert len(seen_heads) == n_heads + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads_q, n_heads_kv, n_shards", [(4, 2, 1), (8, 2, 1), (64, 16, 8)]) +def test_gqa_even_sharding(head_size: int, n_heads_q: int, n_heads_kv: int, n_shards: int): + """ + Test GQA sharding when the KV heads are evenly divisible by the number of shards. 
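+    The fused shard is laid out as ``[q | k | v]``, with ``n_heads_q // n_shards`` query heads
+    followed by ``n_heads_kv // n_shards`` key and value heads, matching the slicing used below.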
+ """ + param = fill_with_head_ids(head_size, n_heads_q, n_heads_kv) + + n_kv_heads_in_shard = n_heads_kv // n_shards + n_q_heads_in_shard = n_heads_q // n_shards + + for shard_rank in range(n_shards): + shard = shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads_q, n_heads_kv) + n_local_q_heads, n_local_kv_heads = get_local_heads(shard_rank, n_shards, n_heads_q, n_heads_kv) + validate_inferred_shape(shard, head_size, n_local_q_heads, n_local_kv_heads) + + assert shard.shape[0] == (n_q_heads_in_shard + n_kv_heads_in_shard * 2) * head_size + + q = shard[:n_q_heads_in_shard * head_size] + k = shard[n_q_heads_in_shard * head_size:(n_q_heads_in_shard + n_kv_heads_in_shard) * head_size] + v = shard[(n_q_heads_in_shard + n_kv_heads_in_shard) * head_size:] + + for local_head_id in range(n_q_heads_in_shard): + assert torch.all(q[local_head_id * head_size:(local_head_id + 1) * head_size] == local_head_id + + shard_rank * n_q_heads_in_shard) + + for local_head_id in range(n_kv_heads_in_shard): + assert torch.all(k[local_head_id * head_size:(local_head_id + 1) * head_size] == local_head_id + + shard_rank * n_kv_heads_in_shard) + assert torch.all(v[local_head_id * head_size:(local_head_id + 1) * head_size] == local_head_id + + shard_rank * n_kv_heads_in_shard) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads_q, n_heads_kv, n_shards", [(4, 2, 4), (20, 4, 8)]) +def test_gqa_uneven_sharding(head_size: int, n_heads_q: int, n_heads_kv: int, n_shards: int): + """ + Test GQA sharding when there are more shards than KV heads. + """ + param = fill_with_head_ids(head_size, n_heads_q, n_heads_kv) + + n_kv_heads_in_shard = 1 + n_shards_per_kv_head = n_shards // n_heads_kv + + max_heads = 0 + min_heads = n_heads_q + total_heads = 0 + seen_heads = set() + + for shard_rank in range(n_shards): + shard = shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads_q, n_heads_kv) + n_local_q_heads, n_local_kv_heads = get_local_heads(shard_rank, n_shards, n_heads_q, n_heads_kv) + validate_inferred_shape(shard, head_size, n_local_q_heads, n_local_kv_heads) + + local_n_heads_q = (shard.shape[0] - 2 * n_kv_heads_in_shard * head_size) // head_size + + max_heads = max(max_heads, local_n_heads_q) + min_heads = min(min_heads, local_n_heads_q) + total_heads += local_n_heads_q + + q = shard[:local_n_heads_q * head_size] + kv = shard[local_n_heads_q * head_size:] + + for local_head_id in range(local_n_heads_q): + q_head_id = torch.unique_consecutive(q[local_head_id * head_size:(local_head_id + 1) * head_size]) + assert len(q_head_id) == 1 + + seen_heads.add(q_head_id.item()) + + kv_id_calc = shard_rank // n_shards_per_kv_head + kv_id = torch.unique_consecutive(kv) + assert len(kv_id) == 1 + assert kv_id.item() == kv_id_calc + + assert max_heads - min_heads <= 1 + assert total_heads == n_heads_q + assert len(seen_heads) == n_heads_q + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("n_heads, n_shards", [(6, 8)]) +def test_unsupported_mha_configs(head_size: int, n_heads: int, n_shards: int): + """ + Sharding should fail if there are fewer heads than shards. + + TODO(cmikeh2): Look to support this configuration. 
+    """
+    param = fill_with_head_ids(head_size, n_heads)
+
+    for shard_rank in range(n_shards):
+        with pytest.raises(ValueError):
+            shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads, n_heads)
+
+
+@pytest.mark.inference_v2
+@pytest.mark.parametrize("head_size", [64])
+@pytest.mark.parametrize("n_heads_q, n_heads_kv, n_shards", [(5, 2, 1), (40, 10, 8), (30, 5, 8)])
+def test_unsupported_gqa_configs(head_size: int, n_heads_q: int, n_heads_kv: int, n_shards: int):
+    """
+    GQA has stricter requirements. We must be able to evenly shard or distribute the KV heads.
+
+    The test cases specifically exercise the following preconditions:
+        1. n_heads_q % n_heads_kv == 0
+        2. We must be able to evenly distribute KV heads
+        3. We must be able to evenly split KV heads
+    """
+    param = fill_with_head_ids(head_size, n_heads_q, n_heads_kv)
+
+    for shard_rank in range(n_shards):
+        with pytest.raises(ValueError):
+            shard_qkv_param(param, shard_rank, n_shards, head_size, n_heads_q, n_heads_kv)
+
+
+@pytest.mark.inference_v2
+def test_mha_input_shape_error():
+
+    param = torch.empty(256, 128)
+
+    n_heads = 2
+    head_size = 64
+
+    with pytest.raises(ValueError):
+        shard_qkv_param(param, 0, 1, 64)
+
+
+@pytest.mark.inference_v2
+def test_gqa_input_shape_error():
+
+    head_size = 64
+    n_heads_q = 16
+    n_heads_kv = 4
+
+    # Correct shape is 1536 (=16 * 64 + 2 * 4 * 64), 1024
+    param = torch.empty(2048, 1024)
+
+    with pytest.raises(ValueError):
+        shard_qkv_param(param, 0, 1, head_size, n_heads_q, n_heads_kv)
diff --git a/tests/unit/inference/modules/__init__.py b/tests/unit/inference/modules/__init__.py
new file mode 100644
index 000000000000..208299fb8c50
--- /dev/null
+++ b/tests/unit/inference/modules/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
diff --git a/tests/unit/inference/modules/test_blas_linear_module.py b/tests/unit/inference/modules/test_blas_linear_module.py
new file mode 100644
index 000000000000..18b546bab6bd
--- /dev/null
+++ b/tests/unit/inference/modules/test_blas_linear_module.py
@@ -0,0 +1,111 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum, is_gated +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSLinearConfig +from deepspeed.inference.v2.modules.interfaces import DSLinearRegistry +from ..inference_test_utils import allclose + + +def reference_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], + act_type: ActivationType) -> torch.Tensor: + dtype = hidden_states.dtype + out_states = torch.nn.functional.linear(hidden_states, weight, bias) + out_states.float() + + if is_gated(act_type): + act_func_map = { + ActivationType.ReGLU: torch.nn.functional.relu, + ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.SiGLU: torch.nn.functional.silu, + } + + act_act = out_states[..., ::2] + act_linear = out_states[..., 1::2] + + act_act = act_func_map[act_type](act_act) + out_states = act_act * act_linear + else: + act_func_map = { + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + out_states = act_func_map[act_type](out_states) + return out_states.to(dtype) + + +def _blas_linear_helper(tokens: int, + in_channels: int, + out_channels: int, + dtype: DtypeEnum, + act_fn: ActivationType, + use_bias: bool = True) -> None: + linear_config = DSLinearConfig(max_tokens=2048, + in_channels=in_channels, + out_channels=out_channels, + activation=act_fn, + input_dtype=dtype, + output_dtype=dtype) + + bundle = ConfigBundle(name='blas_fp_linear', config=linear_config) + + module = DSLinearRegistry.instantiate_config(bundle) + + # Input vals + hidden_states = torch.randn( + (tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + + weight_out_channels = 2 * out_channels if is_gated(act_fn) else out_channels + weight = torch.randn( + (weight_out_channels, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + if use_bias: + bias = torch.randn( + (weight_out_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + else: + bias = None + + # Reference output + ref_output = reference_implementation(hidden_states, weight, bias, act_fn) + + # New output + ds_output = module(hidden_states, weight, bias) + + # Check + assert allclose(ds_output, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, in_channels, out_channels", [(1, 4608, 1728), (37, 8192, 4096), (1280, 3072, 6144)]) +def test_blas_linear_shapes(tokens: int, in_channels: int, out_channels: int) -> None: + + _blas_linear_helper(tokens, in_channels, out_channels, DtypeEnum.fp16, ActivationType.IDENTITY) + + +all_acts = [ + ActivationType.RELU, + ActivationType.GELU, + ActivationType.SILU, + ActivationType.GEGLU, + ActivationType.ReGLU, + ActivationType.SiGLU, +] + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("act_fn", all_acts) +@pytest.mark.parametrize("use_bias", [True, False]) +def test_blas_linear_act_fn(act_fn: ActivationType, use_bias: bool) -> None: + + _blas_linear_helper(283, 512, 4096, DtypeEnum.fp16, act_fn, use_bias=use_bias) diff --git a/tests/unit/inference/modules/test_blocked_attn.py 
b/tests/unit/inference/modules/test_blocked_attn.py new file mode 100644 index 000000000000..1f03b46bd002 --- /dev/null +++ b/tests/unit/inference/modules/test_blocked_attn.py @@ -0,0 +1,210 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import itertools + +from typing import List, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSSelfAttentionConfig, PositionalEmbeddingType +from deepspeed.inference.v2.modules.interfaces import DSSelfAttentionRegistry, DSSelfAttentionBase + +from ..kernels.ragged_ops.ragged_testing_utils import build_batch_and_manager +from ..inference_test_utils import allclose + +try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func + validate_accuracy = True +except ImportError: + validate_accuracy = False + + +def _blocked_flash_testing_helper(head_size: int, + n_heads_q: int, + n_heads_kv: int, + seq_params: List[Tuple[int, int]], + trained_freqs: bool = None) -> None: + """ + Helper function for testing blocked flash attention. This implementation is based on + the implemnentation in ``unit.inference.kernels.ragged_ops.test_blocked_flash`` but + integrates functionality to validate the composability. + """ + if trained_freqs is None: + embed_type = PositionalEmbeddingType.none + embed_args = {} + else: + embed_type = PositionalEmbeddingType.rotate_half + if trained_freqs: + embed_args = {'trained_freqs': True} + else: + embed_args = {'trained_freqs': False} + + attn_config = DSSelfAttentionConfig(max_tokens=2048, + n_heads_q=n_heads_q, + n_heads_kv=n_heads_kv, + head_size=head_size, + max_sequences=32, + positional_embedding_type=embed_type, + positional_embedding_args=embed_args) + + config = ConfigBundle(name='dense_blocked_attention', config=attn_config) + attn_module: DSSelfAttentionBase = DSSelfAttentionRegistry.instantiate_config(config) + + kv_block_size = attn_module.kv_block_size + + kvs = [] + for _, history_len in seq_params: + if history_len > 0: + kvs.append( + torch.randn((history_len, 2 * n_heads_kv * head_size), + device=get_accelerator().current_device(), + dtype=torch.float16)) + else: + kvs.append(None) + + batch, state_manager, _ = build_batch_and_manager(seq_params, head_size, n_heads_kv, kv_block_size, kv_fill=kvs) + + qkv = torch.randn((batch.current_tokens, (n_heads_q + 2 * n_heads_kv) * head_size), + device=get_accelerator().current_device(), + dtype=torch.float16) + + kv_cache = state_manager.get_cache(0) + + attn_module.build_atoms(batch) + if not trained_freqs: + out = attn_module(qkv, kv_cache, batch) + else: + inv_freqs = torch.randn((head_size // 2, ), device=get_accelerator().current_device(), dtype=torch.float16) + out = attn_module(qkv, kv_cache, batch, inv_freqs) + + if validate_accuracy and trained_freqs is None: + cu_seqlens_q = torch.tensor([0] + list(itertools.accumulate([seq[0] for seq in seq_params])), + dtype=torch.int32, + device=get_accelerator().current_device()) + cu_seqlens_kv = torch.tensor([0] + list(itertools.accumulate([seq[1] + seq[0] for seq in seq_params])), + dtype=torch.int32, + device=get_accelerator().current_device()) + + inflight_kv = qkv[:, head_size * n_heads_q:] + full_kvs = [] + for i, kv in enumerate(kvs): + if kv is not None: + full_kvs.append(torch.cat([kv, inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]], dim=0)) + else: + 
full_kvs.append(inflight_kv[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]) + run_kvs = torch.cat(full_kvs, dim=0) + k = run_kvs[:, :head_size * n_heads_kv] + v = run_kvs[:, head_size * n_heads_kv:] + + q = qkv[:, :head_size * n_heads_q] + q_ref = q.reshape((batch.current_tokens, n_heads_q, head_size)) + k_ref = k.reshape((k.shape[0], n_heads_kv, head_size)) + v_ref = v.reshape((v.shape[0], n_heads_kv, head_size)) + + max_seqlen_q = max([seq[0] for seq in seq_params]) + max_seqlen_kv = max([seq[1] + seq[0] for seq in seq_params]) + + ref_o = flash_attn_varlen_func(q_ref, + k_ref, + v_ref, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + softmax_scale=1.0, + causal=True) + + ref_o = ref_o.reshape(batch.current_tokens, head_size * n_heads_q) + + assert allclose(out, ref_o) + + get_accelerator().synchronize() + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("n_tokens", [2, 33, 65, 128, 256, 2037]) +def test_single_prompt(n_tokens: int) -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(n_tokens, 0)] + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("prompt_lengths", [(128, 128), (192, 38), (514, 713), (83, 312, 610)]) +def test_multiple_prompts(prompt_lengths: Tuple[int, int]) -> None: + """ + Test multiple prompts in a single batch. + """ + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(prompt_lengths[i], 0) for i in range(len(prompt_lengths))] + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("seq_params", [(1, 34), (43, 40), (1, 144), (64, 128), (332, 628)]) +def test_continuation(seq_params: Tuple[int, int]) -> None: + """ + Test continued generation/prompt processing. 
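+
+    Each ``seq_params`` pair is (new tokens to process, length of the pre-existing KV history).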
+ """ + head_size = 64 + n_heads_q = 32 + n_heads_kv = 32 + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, [seq_params]) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_size", [64, 128]) +def test_head_size(head_size: int) -> None: + n_heads_q = 16 + n_heads_kv = 16 + seq_params = [(128, 128), (192, 38), (1, 814)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("head_config", [(32, 8), (64, 16), (40, 8)]) +def test_gqa(head_config: Tuple[int, int]) -> None: + head_size = 128 + n_heads_q = head_config[0] + n_heads_kv = head_config[1] + + seq_params = [(128, 128), (192, 38), (1, 814)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +def test_fully_composed() -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(332, 628), (1, 718), (1, 323), (180, 5), (224, 0)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("trained_freqs", [True, False]) +def test_rotary_emb(trained_freqs: bool) -> None: + head_size = 64 + n_heads_q = 16 + n_heads_kv = 16 + + seq_params = [(332, 628), (1, 718), (1, 323), (180, 5), (224, 0)] + + _blocked_flash_testing_helper(head_size, n_heads_q, n_heads_kv, seq_params, trained_freqs=trained_freqs) diff --git a/tests/unit/inference/modules/test_cuda_pre_ln_module.py b/tests/unit/inference/modules/test_cuda_pre_ln_module.py new file mode 100644 index 000000000000..d6c42a3e1336 --- /dev/null +++ b/tests/unit/inference/modules/test_cuda_pre_ln_module.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSNormConfig +from deepspeed.inference.v2.modules.interfaces import DSPreNormRegistry +from ..inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: Optional[torch.Tensor], gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = residual.dtype + + residual = residual.to(torch.float32) + gamma = gamma.to(torch.float32) + beta = beta.to(torch.float32) + + if hidden_states is not None: + hidden_states = hidden_states.to(torch.float32) + residual = residual + hidden_states + hidden_states = torch.nn.functional.layer_norm(residual, (residual.size(-1), ), + weight=gamma, + bias=beta, + eps=epsilon) + return residual.to(dtype), hidden_states.to(dtype) + + +def _pre_ln_test_helper(n_tokens: int, n_channels: int, dtype: torch.dtype, res_add: bool = False): + config = DSNormConfig(max_tokens=2048, + type="layer_norm", + channels=n_channels, + residual_dtype=dtype, + input_dtype=dtype, + output_dtype=dtype, + eps=1e-5) + bundle = ConfigBundle(name='cuda_pre_ln', config=config) + + # Input vals + if res_add: + hidden_states = torch.randn((n_tokens, n_channels), + dtype=dtype, + device=get_accelerator().current_device_name()) + else: + hidden_states = None + + residual = torch.randn((n_tokens, n_channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((n_channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + 
beta = torch.rand((n_channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_residual, ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + pre_ln_module = DSPreNormRegistry.instantiate_config(bundle) + gamma = pre_ln_module.transform_param(gamma) + beta = pre_ln_module.transform_param(beta) + + ds_residual, ds_output = pre_ln_module(residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_residual, ref_residual) + assert allclose(ds_output, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 2048), (37, 8192), (1280, 768), (2048, 5120)]) +def test_token_channels(tokens: int, channels: int) -> None: + _pre_ln_test_helper(tokens, channels, torch.float16) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", get_dtypes(include_float=False)) +def test_dtype(dtype: torch.dtype) -> None: + _pre_ln_test_helper(733, 2560, dtype) + + +@pytest.mark.inference_v2_ops +def test_no_res_add(): + _pre_ln_test_helper(733, 2560, torch.float16, res_add=False) diff --git a/tests/unit/inference/modules/test_custom_module.py b/tests/unit/inference/modules/test_custom_module.py new file mode 100644 index 000000000000..e14ccd3f2244 --- /dev/null +++ b/tests/unit/inference/modules/test_custom_module.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.interfaces import DSPostNormRegistry +from deepspeed.inference.v2.modules.configs import DSNormConfig +from deepspeed.inference.v2.modules.implementations import cuda_post_ln +from ..inference_test_utils import allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> torch.Tensor: + residual_f = residual.to(torch.float32) + hidden_states_f = hidden_states.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + return torch.nn.functional.layer_norm(residual_f + hidden_states_f, (hidden_states_f.size(-1), ), + weight=gamma_f, + bias=beta_f, + eps=epsilon).to(hidden_states.dtype) + + +@DSPostNormRegistry.register_module +class CustomPostLNModule(cuda_post_ln.DSPostLNCUDAModule): + + @staticmethod + def name(): + return 'custom_post_ln' + + +""" +Here, we explicitly register an LN implementation outside the core deepspeed repo. This should +validate that the registry is working as expected and we can implement modules outside the core +repo. 
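+
+Once registered, the custom implementation is selected purely by name through a ConfigBundle
+(see ``custom_post_ln`` below), in the same way the built-in implementations are chosen.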
+""" + + +@pytest.mark.inference_v2_ops +def test_custom_registration(): + channels = 4096 + dtype = torch.float16 + tokens = 1024 + + config = DSNormConfig(max_tokens=2048, + type="layer_norm", + channels=channels, + residual_dtype=dtype, + input_dtype=dtype, + output_dtype=dtype, + eps=1e-5) + bundle = ConfigBundle(name='custom_post_ln', config=config) + + # Input vals + hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + beta = torch.rand((channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + post_ln_module = DSPostNormRegistry.instantiate_config(bundle) + gamma = post_ln_module.transform_param(gamma) + beta = post_ln_module.transform_param(beta) + ds_output, _ = post_ln_module(residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_output, ref_output) diff --git a/tests/unit/inference/modules/test_cutlass_moe.py b/tests/unit/inference/modules/test_cutlass_moe.py new file mode 100644 index 000000000000..98a48b5b149d --- /dev/null +++ b/tests/unit/inference/modules/test_cutlass_moe.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSMoEConfig +from deepspeed.inference.v2.modules.interfaces import DSMoERegistry + +from ..kernels.ragged_ops.ragged_testing_utils import build_simple_batch +from ..inference_test_utils import allclose, get_dtypes + + +def _gating_reference(logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Reference gating code. + """ + logits = logits.float() + probs = torch.nn.functional.softmax(logits, dim=1) + + indices1_s = torch.argmax(probs, dim=-1) + mask1 = torch.nn.functional.one_hot(indices1_s, num_classes=logits.shape[-1]) + indices_mask = mask1.sum(dim=1) * logits.shape[-1] - 1 + indices1_s = torch.min(indices1_s, indices_mask) + + gates1_s = (probs * mask1).sum(dim=1) + + sorted_indices = indices1_s.sort()[1] + original_indices = sorted_indices.sort()[1] + + exp_count = torch.bincount(indices1_s, minlength=logits.shape[-1]).long() + exp_count_cumsum = exp_count.cumsum(dim=0) + + return sorted_indices, original_indices, exp_count_cumsum, gates1_s + + +def _reference_impl(hidden_states: torch.Tensor, gate_weight: torch.Tensor, mlp_1_w: torch.Tensor, + mlp_2_w: torch.Tensor, mlp_1_b: torch.Tensor, mlp_2_b: torch.Tensor, + act_fn: ActivationType) -> torch.Tensor: + """ + Reference implementation of the MoE module. 
+ """ + + act_fn_dict = { + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + logits = torch.matmul(hidden_states, gate_weight.t()) + sorted_indices, original_indices, exp_count_cumsum, gate_scales = _gating_reference(logits) + + moe_input = hidden_states[sorted_indices] + + output_unordered = torch.empty_like(hidden_states) + + for expert_idx in range(mlp_1_w.shape[0]): + min_bound = 0 if expert_idx == 0 else exp_count_cumsum[expert_idx - 1] + max_bound = exp_count_cumsum[expert_idx] + + input_slice = moe_input[min_bound:max_bound] + intermediate = torch.nn.functional.linear(input_slice, mlp_1_w[expert_idx], mlp_1_b[expert_idx]) + + intermediate = act_fn_dict[act_fn](intermediate) + output_slice = torch.nn.functional.linear(intermediate, mlp_2_w[expert_idx], mlp_2_b[expert_idx]) + + output_unordered[min_bound:max_bound] = output_slice + + output = output_unordered[original_indices] + + output.mul_(gate_scales.unsqueeze(-1)).reshape(hidden_states.shape) + return output + + +def _cutlass_moe_testing_helper(tokens: int, + in_channels: int, + intermediate_dim: int, + experts: int, + dtype: int, + activation_type: ActivationType = ActivationType.GELU, + use_bias: bool = True, + iters: int = 1) -> None: + + config = DSMoEConfig(max_tokens=4096, + model_dim=in_channels, + intermediate_features=intermediate_dim, + n_experts=experts, + activation=activation_type, + input_dtype=dtype, + output_dtype=dtype) + + implementation_config = {"weight_dtype": DtypeEnum(dtype)} + + bundle = ConfigBundle(name='cutlass_multi_gemm_moe', config=config, implementation_config=implementation_config) + moe_module = DSMoERegistry.instantiate_config(bundle) + + batch = build_simple_batch([tokens]) + + # Parameters + gate_weight = torch.randn( + (experts, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + mlp_1_w = torch.randn( + (experts, intermediate_dim, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_2_w = torch.randn( + (experts, in_channels, intermediate_dim), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + if use_bias: + mlp_1_b = torch.randn( + (experts, intermediate_dim), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_2_b = torch.randn( + (experts, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + else: + mlp_1_b = None + mlp_2_b = None + + gate_ds = moe_module.transform_gate_param(gate_weight) + mlp_1_w_ds = moe_module.transform_moe_mlp_1_param(mlp_1_w) + mlp_1_b_ds = moe_module.transform_moe_mlp_1_param(mlp_1_b) + mlp_2_w_ds = moe_module.transform_moe_mlp_2_param(mlp_2_w) + mlp_2_b_ds = moe_module.transform_moe_mlp_2_param(mlp_2_b) + + for _ in range(iters): + # Input vals + hidden_states = torch.randn( + (tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + # Reference implementation + ref_output = _reference_impl(hidden_states, gate_weight, mlp_1_w, mlp_2_w, mlp_1_b, mlp_2_b, activation_type) + + output = moe_module(hidden_states, + batch, + gate_ds, + mlp_1_w_ds, + mlp_2_w_ds, + mlp_1_b=mlp_1_b_ds, + mlp_2_b=mlp_2_b_ds) + + # Increase the tolerance for larger meta ops since the error is additive + assert allclose(output, ref_output, tolerances=(1e-2, 1e-2)) + + get_accelerator().synchronize() + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("experts", [2, 32, 
64]) +def test_expert_variance(experts: int) -> None: + _cutlass_moe_testing_helper(tokens=876, + in_channels=4096, + intermediate_dim=2048, + experts=experts, + dtype=DtypeEnum.fp16, + activation_type=ActivationType.IDENTITY, + use_bias=True) + + +@pytest.mark.inference_v2_ops +def test_successive_inputs(): + """ + The CUTLASS MoE uses persistent state (expert counts) that is assumed to be cleared + on each forward pass. This ensures that the module is clearing that metadata. + """ + _cutlass_moe_testing_helper(tokens=876, + in_channels=4096, + intermediate_dim=2048, + experts=64, + dtype=DtypeEnum.fp16, + activation_type=ActivationType.IDENTITY, + use_bias=True, + iters=10) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", get_dtypes(include_float=False)) +def test_dtypes(dtype: torch.dtype) -> None: + _cutlass_moe_testing_helper(tokens=876, + in_channels=4096, + intermediate_dim=2048, + experts=64, + dtype=DtypeEnum(dtype), + activation_type=ActivationType.IDENTITY, + use_bias=True) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("activation_type", [ActivationType.GELU, ActivationType.RELU, ActivationType.SILU]) +def test_activation_types(activation_type: ActivationType) -> None: + _cutlass_moe_testing_helper(tokens=876, + in_channels=4096, + intermediate_dim=2048, + experts=64, + dtype=DtypeEnum.fp16, + activation_type=activation_type, + use_bias=True) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("in_channels, out_channels", [(4096, 2048), (2048, 8192), (6144, 3072)]) +def test_in_out_channels(in_channels: int, out_channels: int) -> None: + _cutlass_moe_testing_helper(tokens=876, + in_channels=in_channels, + intermediate_dim=out_channels, + experts=64, + dtype=DtypeEnum.fp16, + activation_type=ActivationType.IDENTITY, + use_bias=True) diff --git a/tests/unit/inference/modules/test_post_ln_module.py b/tests/unit/inference/modules/test_post_ln_module.py new file mode 100644 index 000000000000..238d8fa4d1b1 --- /dev/null +++ b/tests/unit/inference/modules/test_post_ln_module.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSNormConfig +from deepspeed.inference.v2.modules.interfaces import DSPostNormRegistry +from ..inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: torch.Tensor, gamma: torch.Tensor, + beta: torch.Tensor, epsilon: float) -> torch.Tensor: + residual_f = residual.to(torch.float32) + hidden_states_f = hidden_states.to(torch.float32) + gamma_f = gamma.to(torch.float32) + beta_f = beta.to(torch.float32) + return torch.nn.functional.layer_norm(residual_f + hidden_states_f, (hidden_states_f.size(-1), ), + weight=gamma_f, + bias=beta_f, + eps=epsilon).to(hidden_states.dtype) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 2048), (37, 8192), (1280, 768), (2048, 5120)]) +@pytest.mark.parametrize("dtype", get_dtypes()) +def test_cuda_post_ln_module(tokens: int, channels: int, dtype: torch.dtype) -> None: + config = DSNormConfig(max_tokens=2048, + type="layer_norm", + channels=channels, + residual_dtype=dtype, + input_dtype=dtype, + output_dtype=dtype, + eps=1e-5) + bundle = ConfigBundle(name='cuda_post_ln', config=config) + + # Input vals + hidden_states = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + residual = torch.randn((tokens, channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + beta = torch.rand((channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_output = reference_implementation(residual, hidden_states, gamma, beta, epsilon) + + # New output + post_ln_module = DSPostNormRegistry.instantiate_config(bundle) + gamma = post_ln_module.transform_param(gamma) + beta = post_ln_module.transform_param(beta) + ds_output, _ = post_ln_module(residual, hidden_states, gamma, beta) + + # Check + assert allclose(ds_output, ref_output) diff --git a/tests/unit/inference/modules/test_pre_rms_module.py b/tests/unit/inference/modules/test_pre_rms_module.py new file mode 100644 index 000000000000..bbbec2d15709 --- /dev/null +++ b/tests/unit/inference/modules/test_pre_rms_module.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional, Tuple + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSNormConfig +from deepspeed.inference.v2.modules.interfaces import DSPreNormRegistry +from ..inference_test_utils import get_dtypes, allclose + + +def reference_implementation(residual: torch.Tensor, hidden_states: Optional[torch.Tensor], gamma: torch.Tensor, + epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = residual.dtype + + if hidden_states is not None: + hidden_states = hidden_states + residual = residual + hidden_states + + rms_vals = residual.to(torch.float32) + variance = rms_vals.pow(2).mean(-1, keepdim=True) + rms_vals = rms_vals * torch.rsqrt(variance + epsilon) + + if gamma.dtype in [torch.float16, torch.bfloat16]: + rms_vals = rms_vals.to(gamma.dtype) + + hidden_states = gamma * rms_vals + + return residual.to(dtype), hidden_states.to(dtype) + + +def _pre_rms_test_helper(n_tokens: int, n_channels: int, dtype: torch.dtype, res_add: bool = False): + config = DSNormConfig(max_tokens=2048, + type="rms_norm", + channels=n_channels, + residual_dtype=dtype, + input_dtype=dtype, + output_dtype=dtype, + eps=1e-5) + bundle = ConfigBundle(name='cuda_pre_rms', config=config) + + # Input vals + if res_add: + hidden_states = torch.randn((n_tokens, n_channels), + dtype=dtype, + device=get_accelerator().current_device_name()) + else: + hidden_states = None + + residual = torch.randn((n_tokens, n_channels), dtype=dtype, device=get_accelerator().current_device_name()) + gamma = torch.randn((n_channels), dtype=torch.float32, device=get_accelerator().current_device_name()) + epsilon = 1e-5 + + # Reference output + ref_residual, ref_output = reference_implementation(residual, hidden_states, gamma, epsilon) + + # New output + pre_ln_module = DSPreNormRegistry.instantiate_config(bundle) + gamma = pre_ln_module.transform_param(gamma) + + ds_residual, ds_output = pre_ln_module(residual, hidden_states, gamma) + + # Check + assert allclose(ds_residual, ref_residual) + assert allclose(ds_output, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, channels", [(1, 2048), (37, 8192), (1280, 768), (2048, 5120)]) +def test_token_channels(tokens: int, channels: int) -> None: + _pre_rms_test_helper(tokens, channels, torch.float16) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("dtype", get_dtypes(include_float=False)) +def test_dtype(dtype: torch.dtype) -> None: + _pre_rms_test_helper(733, 2560, dtype) + + +@pytest.mark.inference_v2_ops +def test_no_res_add(): + _pre_rms_test_helper(733, 2560, torch.float16, res_add=False) diff --git a/tests/unit/inference/ragged/__init__.py b/tests/unit/inference/ragged/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/tests/unit/inference/ragged/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/tests/unit/inference/ragged/test_blocked_allocator.py b/tests/unit/inference/ragged/test_blocked_allocator.py new file mode 100644 index 000000000000..4596e81c5652 --- /dev/null +++ b/tests/unit/inference/ragged/test_blocked_allocator.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import random +from typing import List + +import pytest +import torch + +from deepspeed.inference.v2.ragged.blocked_allocator import BlockedAllocator + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('bad_size', [0, -1]) +def test_bad_initialization(bad_size: int) -> None: + with pytest.raises(ValueError): + BlockedAllocator(bad_size) + + +@pytest.mark.inference_v2 +def test_allocation() -> None: + + allocator = BlockedAllocator(16) + + a1 = allocator.allocate(4) + assert a1.numel() == 4 + assert allocator.free_blocks == 12 + + a2_allocs = [] + for i in range(3): + a2_allocs.append(allocator.allocate(2)) + assert allocator.free_blocks == 12 - (i + 1) * 2 + + a3 = allocator.allocate(6) + assert a3.numel() == 6 + + assert allocator.free_blocks == 0 + + # Test that we can't allocate more blocks than we have. + with pytest.raises(ValueError): + allocator.allocate(1) + + all_vals = torch.cat([a1, *a2_allocs, a3], dim=0) + unique_vals = torch.unique(all_vals, sorted=False) + assert unique_vals.numel() == all_vals.numel() + + +@pytest.mark.inference_v2 +def test_too_large_allocation(): + allocator = BlockedAllocator(16) + + with pytest.raises(ValueError): + allocator.allocate(17) + + +@pytest.mark.inference_v2 +def test_deallocation() -> None: + allocator = BlockedAllocator(16) + + # Allocate + all_blocks = allocator.allocate(16) + assert allocator.free_blocks == 0 + + # Deallocate all blocks + allocator.free(all_blocks) + assert allocator.free_blocks == 16 + + # Get all the blocks again + all_blocks = allocator.allocate(16) + + # Deallocate in chunks + c1 = all_blocks[:4] + c2 = all_blocks[4:8] + + allocator.free(c1) + assert allocator.free_blocks == 4 + + allocator.free(c2) + assert allocator.free_blocks == 8 + + with pytest.raises(ValueError): + allocator.free(c1) + + with pytest.raises(ValueError): + allocator.free(c2) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('index', [-1, 2]) +def test_invalid_dealloc_indices(index: int): + allocator = BlockedAllocator(1) + + with pytest.raises(ValueError): + allocator.free(torch.tensor([index])) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('index', [-1, 2]) +def test_invalid_alloc_indices(index: int): + allocator = BlockedAllocator(1) + allocator.allocate(1) + + to_free = [0, index] + + with pytest.raises(ValueError): + allocator.free(torch.tensor(to_free)) + + # Block 0 should not be freed if passed with an invalid index. + assert allocator.free_blocks == 0 + + allocator.free(torch.tensor([0])) + assert allocator.free_blocks == 1 + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('test_iters', [8192]) +def test_long_running_allocation(test_iters: int) -> None: + """ + Evaluate the stability of the allocator over a longer sequence of allocations/deallocations. 
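+
+    Each iteration randomly either allocates between 1 and 24 blocks (expecting a ValueError when
+    the request exceeds the free pool) or frees a previously returned allocation (when one exists),
+    and after every iteration re-validates that all outstanding block ids remain unique.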
+ """ + TOTAL_BLOCKS = 128 + + allocator = BlockedAllocator(TOTAL_BLOCKS) + + def validate_uniqueness(all_blocks: List[torch.Tensor]) -> None: + all_vals = torch.cat(all_blocks, dim=0) + assert all_vals.numel() <= TOTAL_BLOCKS + + unique_vals = torch.unique(all_vals, sorted=False) + assert unique_vals.numel() == all_vals.numel() + + all_allocs: List[torch.Tensor] = [] + num_allocs = 0 + num_frees = 0 + num_blocks_allocated = 0 + num_blocks_freed = 0 + + for _ in range(test_iters): + decision = random.randint(0, 1) + + if decision == 0: + blocks_to_allocate = random.randint(1, 24) + if blocks_to_allocate > allocator.free_blocks: + with pytest.raises(ValueError): + allocator.allocate(blocks_to_allocate) + else: + all_allocs.append(allocator.allocate(blocks_to_allocate)) + num_allocs += 1 + num_blocks_allocated += blocks_to_allocate + else: + if len(all_allocs) > 0: + idx = random.randint(0, len(all_allocs) - 1) + allocator.free(all_allocs[idx]) + + num_frees += 1 + num_blocks_freed += all_allocs[idx].numel() + + del all_allocs[idx] + + if len(all_allocs) > 0: + validate_uniqueness(all_allocs) + + assert num_allocs == num_frees + len(all_allocs) + assert num_blocks_allocated == num_blocks_freed + (TOTAL_BLOCKS - allocator.free_blocks) diff --git a/tests/unit/inference/ragged/test_manager_configs.py b/tests/unit/inference/ragged/test_manager_configs.py new file mode 100644 index 000000000000..bdd513445ddb --- /dev/null +++ b/tests/unit/inference/ragged/test_manager_configs.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest + +from pydantic import ValidationError + +from deepspeed.inference.v2.ragged import DSStateManagerConfig + + +@pytest.mark.inference_v2 +def test_negative_max_tracked_sequences() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_tracked_sequences=-1) + + +@pytest.mark.inference_v2 +def test_zero_max_tracked_sequences() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_tracked_sequences=0) + + +@pytest.mark.inference_v2 +def test_negative_max_ragged_batch_size() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_batch_size=-1) + + +@pytest.mark.inference_v2 +def test_zero_max_ragged_batch_size() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_batch_size=0) + + +@pytest.mark.inference_v2 +def test_negative_max_ragged_sequence_count() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_sequence_count=-1) + + +@pytest.mark.inference_v2 +def test_zero_max_ragged_sequence_count() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_sequence_count=0) + + +@pytest.mark.inference_v2 +def test_too_small_max_ragged_batch_size() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_ragged_batch_size=512, max_ragged_sequence_count=1024) + + +@pytest.mark.inference_v2 +def test_too_small_max_tracked_sequences() -> None: + with pytest.raises(ValidationError): + DSStateManagerConfig(max_tracked_sequences=512, max_ragged_sequence_count=1024) diff --git a/tests/unit/inference/ragged/test_ragged_wrapper.py b/tests/unit/inference/ragged/test_ragged_wrapper.py new file mode 100644 index 000000000000..3cb74f4c49d2 --- /dev/null +++ b/tests/unit/inference/ragged/test_ragged_wrapper.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.ragged import ( + PlaceholderSequenceDescriptor, + RaggedBatchWrapper, + DSStateManagerConfig, +) + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('max_ragged_sequence_count, max_ragged_batch_size', [(128, 512), (128, 1024)]) +def test_wrapper_initialization(max_ragged_sequence_count: int, max_ragged_batch_size: int) -> None: + config = DSStateManagerConfig(max_tracked_sequences=max_ragged_sequence_count, + max_ragged_batch_size=max_ragged_batch_size, + max_ragged_sequence_count=max_ragged_sequence_count) + + batch = RaggedBatchWrapper(config) + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('seq_len', [1, 37, 128, 512]) +def test_single_sequence_batch(seq_len: int) -> None: + """ + Test we successfully construct single sequence batches and the on device metadata is accurate. + """ + + config = DSStateManagerConfig() + batch = RaggedBatchWrapper(config) + + batch.clear() + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 + + seq_desc = PlaceholderSequenceDescriptor() + tokens = torch.randint(0, 100, (seq_len, )) + batch.insert_sequence(seq_desc, tokens) + + batch.finalize() + + assert batch.current_tokens == seq_len + assert batch.current_sequences == 1 + assert torch.equal(batch.input_ids(), tokens.to(get_accelerator().current_device())) + assert torch.equal(batch.tokens_to_seq(), torch.zeros_like(tokens, device=get_accelerator().current_device())) + assert torch.equal(batch.batch_metadata_buffer(), + torch.tensor([seq_len, 1], device=get_accelerator().current_device())) + + batch.clear() + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 + + +@pytest.mark.inference_v2 +@pytest.mark.parametrize('seq_lens', [[128, 128], [1, 32, 243], [64, 1, 1, 1, 1, 393, 27, 2]]) +def test_multi_sequence_batch(seq_lens: List[int]) -> None: + """ + Test sequentially adding new tokens to a batch and validate device data structures hold + the appropriate data. 
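+
+    After finalize(), input_ids, tokens_to_seq, the inflight sequence descriptors, and the batch
+    metadata buffer should all reflect the concatenated sequences.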
+ """ + config = DSStateManagerConfig() + batch = RaggedBatchWrapper(config) + + batch.clear() + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 + + all_toks = [torch.randint(0, 100, (seq_len, )) for seq_len in seq_lens] + + for i, toks in enumerate(all_toks): + seq_desc = PlaceholderSequenceDescriptor() + batch.insert_sequence(seq_desc, toks) + + assert batch.current_tokens == sum(seq_lens[:i + 1]) + assert batch.current_sequences == i + 1 + + batch.finalize() + + assert batch.current_tokens == sum(seq_lens) + assert batch.current_sequences == len(seq_lens) + + assert torch.equal(batch.input_ids(), torch.cat(all_toks, dim=0).to(get_accelerator().current_device())) + assert torch.equal( + batch.tokens_to_seq(), + torch.cat([torch.full((seq_len, ), i, dtype=torch.int32) for i, seq_len in enumerate(seq_lens)], + dim=0).to(get_accelerator().current_device())) + + for i, seq_len in enumerate(seq_lens): + assert batch.inflight_seq_descriptors()[i][0] == sum(seq_lens[:i]) + assert batch.inflight_seq_descriptors()[i][1] == seq_len + assert batch.inflight_seq_descriptors()[i][2] == 0 + + assert torch.equal(batch.batch_metadata_buffer(), + torch.tensor([sum(seq_lens), len(seq_lens)], device=get_accelerator().current_device())) + + batch.clear() + + assert batch.current_tokens == 0 + assert batch.current_sequences == 0 diff --git a/version.txt b/version.txt index bc859cbd6d99..aa22d3ce39b2 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.11.2 +0.12.3