diff --git a/.github/workflows/amd-mi200.yml b/.github/workflows/amd-mi200.yml index cd1cafe8e679..ea8d2f5f806f 100644 --- a/.github/workflows/amd-mi200.yml +++ b/.github/workflows/amd-mi200.yml @@ -32,7 +32,7 @@ jobs: - name: Install pytorch run: | - pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/rocm5.6 + pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/rocm6.0 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" diff --git a/.github/workflows/nv-mii.yml b/.github/workflows/nv-mii.yml index aab9d28b769c..d394b7e24bd6 100644 --- a/.github/workflows/nv-mii.yml +++ b/.github/workflows/nv-mii.yml @@ -37,7 +37,7 @@ jobs: - name: Install pytorch run: | - pip3 install -U --cache-dir $TORCH_CACHE torch --index-url https://download.pytorch.org/whl/cu118 + pip3 install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu118 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -46,7 +46,7 @@ jobs: git clone https://github.com/huggingface/transformers cd transformers # if needed switch to the last known good SHA until transformers@master is fixed - # git checkout bdf36dc + git checkout v4.42.4 git rev-parse --short HEAD pip install . diff --git a/blogs/deepspeed-gds/README.md b/blogs/deepspeed-gds/README.md new file mode 100644 index 000000000000..a00b5083c3ea --- /dev/null +++ b/blogs/deepspeed-gds/README.md @@ -0,0 +1,84 @@ +
+ +# DeepNVMe: Improving DL Applications through I/O Optimizations + +
+ +# Introduction + +Deep Learning (DL) continues to drive unprecedented advancements across important +Artificial Intelligence domains including language, speech, video, and multimodal applications. +A key factor to these advancements is dramatic scalability on multiple dimensions including model size, +sequence length, and hardware parallelism. From a system perspective, DL scalability puts significant +pressure on essential subsystems including computation, memory, communication, and storage. However, +existing DL optimization efforts have mostly neglected the storage subsystem, making I/O operations such +as data loading, model checkpointing, and offloading the main bottlenecks of large-scale DL. To address +this problem, DeepSpeed has created a suite of I/O optimizations collectively called DeepNVMe. + +DeepNVMe improves the performance and efficiency of I/O-bound DL applications by accelerating I/O operations +and reducing hardware requirements. It achieves this by leveraging storage innovations such as Non-Volatile +Memory Express (NVMe) Solid Storage Devices (SSDs) and Nvidia Magnum IO^TM GPUDirect® Storage (GDS). In this +blog we show the benefits of DeepNVMe using microbenchmarks and an inference application. In experiments +conducted on an Azure NC96ads\_A100\_v4 VM, we observed that DeepNVMe saturates available NVMe bandwidth for +data transfers with GPU or CPU memory, achieving up to 10GB/sec reads and 5 GB/secs writes. + +# Background +High-performance access to persistent storage is a common challenge in many computing domains, including DL. Thus, a significant number of hardware and software solutions have been proposed. DeepNVMe builds on three such solutions: (1) NVMe SSDs, (2) Nvidia GDS, and (3) Linux Asynchronous I/O (libaio). We will briefly describe each of these technologies. + +NVMe SSDs are Flash-based storage devices that are replacing much slower hard disk drives (HDD) as primary persistent storage in modern servers. For example, an Azure NC96ads\_A100\_v4 VM is equipped with four NVMe SSDs which are individually capable of 3.25 GB/sec reads and can be combined in a RAID-0 configuration for a theoretical aggregate read bandwidth of 13 GB/sec. Nvidia GDS enables direct transfers between NVMe and GPU memory thus avoiding the inefficiencies of the traditional approach of using intermediate CPU memory (bounce buffer). Nvidia GDS is generally available in CUDA versions 11.4 and above. Finally, libaio is an asynchronous I/O stack introduced in Linux to better extract raw performance of fast storage devices like NVMe SSDs compared to the traditional I/O stack. + +# DeepNVMe: an Optimization Module for DeepLearning I/O + +DeepNVMe is a Python module that we developed with two key design principles. First, it leverages the above discussed storage technologies to implement powerful optimizations such as non-blocking I/O operations, bulk submission of I/O operations, parallelization of an individual I/O operation, and a lightweight runtime. Second, it exposes these I/O optimizations through a simple POSIX-like interface to foster easy integration into DL applications while avoiding the complexities of the underlying technologies. + +# Evaluation + +Our experiments are conducted on an Azure NC96ads\_A100\_v4 VM with setup details summarized in Table 1. For multi-device experiments, the SSDs are combined in a RAID-0 configuration. + + + +
+Table 1: Experimental setup details +
+ +## Microbenchmark Performance + +We used three benchmarking tools for our evaluations. The first is fio, the popular I/O benchmarking tool written in C. The second is gdsio from Nvidia for benchmarking GDS performance. The third is ds\_io, a Python tool that we created for easy integration with DeepNVMe and to be more representative of DL applications which are commonly Python-based. + +## High-Performance I/O with CPU Buffers via NVMe Scaling + +Our first set of microbenchmark evaluations used fio and ds\_io to measure the performance of transferring 1GB data between NVMe and CPU memory. We configure fio to use the libaio backend for these experiments1. The results are summarized in Figure 1, from which we make two observations. First, DeepNVMe demonstrates high performance as it roughly matches fio, despite being more representative of DL applications. Second, DeepNVMe scales I/O performance almost linearly with available NVMe bandwidth, achieving rates of 10GB/sec reads and 5GB/sec writes. + + + +
+Figure 1: Using DeepNVMe to scale data transfers between NVMe and CPU buffer +
+ +## High-Performance I/O with GPU Buffers via NVMe Scaling + +Our second set of microbenchmark evaluations used gdsio and ds\_io to measure the performance of 1GB data transfer between NVMe and GPU memory. For this experiment, we configure ds\_io to use both the traditional bounce buffer approach and the more efficient GDS approach. The results are summarized in Figure 2, from which we make three observations. First, we see that GDS improves performance in DeepNVMe compared to the traditional bounce buffer approach, with up to 37% speedup. Second, DeepNVMe demonstrates high performance by matching (and sometimes surpassing) gdsio despite being more representative of DL applications. Third, we see that DeepNVMe, with and without GDS, scales I/O performance with available NVMe bandwidth. With GDS, DeepNVMe achieves a maximum of 9.6GB/sec reads and 5GB/sec writes, and without GDS achieves 7GB/sec reads and 4GB/sec writes. + + + +
+Figure 2: Using DeepNVMe to scale data transfers between NVMe and GPU memory +
+ +## ZeRO-Inference: Generative AI Performance + +ZeRO-Inference is an AI democratization technology that reduces the hardware cost of inferencing massive models by using DeepNVMe to offload model weights to CPU or NVMe memory. ZeRO-Inference is well suited for throughput-oriented applications, such as offline inferencing, and for scenarios with limited hardware budget. We use token generation workload to evaluate DeepNVMe performance for NVMe offloading. + +## High-Performance Offloading via NVMe Scaling + +We measure the generation throughput of inferencing a LLAMA3-70B model on a single NVIDIA A100-80GB with a prompt length of 512, generation length of 32, and batch size of 96. We scale the number of NVMe SSDs from 1 to 4 and present the results for ZeRO-Inference with and without GDS in Figure 3. We make two observations from these results. First, GDS consistently provides better performance compared to the bounce buffer approach, achieving 10-18% faster token generation. Second, DeepNVMe, with and without GDS, scales generation performance with available NVMe bandwidth. With four NVMe SSDs, DeepNVMe achieves generation throughput rates of 7 tokens per second with GDS and 6 tokens per second without GDS. Our profiling results suggest that DeepNVMe will continue to scale with more NVMe bandwidth, making it an economic option for boosting generative application performance. + + + +
+Figure 3: Using DeepNVMe to scale LLAMA3-70B token generation performance with NVMe offloading. +
+ +# Conclusion + +In this blog post, we introduced DeepNVMe, an I/O optimization technology created to tackle the emergence of I/O operations as key bottlenecks of Deep Learning scalability. DeepNVMe enables fast and efficient data transfers between persistent storage and DL application memory through optimizations built on popular storage technologies such as NVMe SSDs and NVIDIA GDS. We showed benefits of using DeepNVMe for LLAMA3-70B token generation on single A100-80GB GPU with NVMe offloading, for which it achieves up to 7 tokens per second in generation throughput on an Azure NC96ads\_A100\_v4 VM. DeepNVMe will be generally available in DeepSpeed versions >= [0.15.0](https://github.com/microsoft/DeepSpeed/releases/tag/v0.15.0). In future blogs, we will report DeepNVMe improvements for other I/O bound DL applications such as model checkpointing and data loading. diff --git a/blogs/deepspeed-gds/media/figure1.png b/blogs/deepspeed-gds/media/figure1.png new file mode 100755 index 000000000000..08db7d2f8afa Binary files /dev/null and b/blogs/deepspeed-gds/media/figure1.png differ diff --git a/blogs/deepspeed-gds/media/figure2.png b/blogs/deepspeed-gds/media/figure2.png new file mode 100755 index 000000000000..35be5d4c4015 Binary files /dev/null and b/blogs/deepspeed-gds/media/figure2.png differ diff --git a/blogs/deepspeed-gds/media/figure3.png b/blogs/deepspeed-gds/media/figure3.png new file mode 100755 index 000000000000..7175236f886b Binary files /dev/null and b/blogs/deepspeed-gds/media/figure3.png differ diff --git a/blogs/deepspeed-gds/media/table1.png b/blogs/deepspeed-gds/media/table1.png new file mode 100755 index 000000000000..bba571369932 Binary files /dev/null and b/blogs/deepspeed-gds/media/table1.png differ diff --git a/csrc/fp_quantizer/fp_quantize.cpp b/csrc/fp_quantizer/fp_quantize.cpp index 6962b8050f51..1a887b50e1a3 100644 --- a/csrc/fp_quantizer/fp_quantize.cpp +++ b/csrc/fp_quantizer/fp_quantize.cpp @@ -22,25 +22,20 @@ stochastic_rounding); \ } -at::Tensor quantize(torch::Tensor& val, +at::Tensor quantize(torch::Tensor& out, + torch::Tensor& val, int group_size, int stochastic_rounding, int q_bits, int q_mantisa_bits) { int total_elems = at::numel(val); - auto options = at::TensorOptions() - .dtype(torch::kInt8) - .layout(val.layout()) - .device(val.device()) - .requires_grad(false); float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) : // fp8 ranges (q_bits == 12 ? 510.0 : // fp12 range (q_bits == 6 ? 28.0 : // fp6 range 6.0)); // fp4 range (using power 2); TODO (Reza): add the power-4 // in case accuracy is not matching! int num_groups = total_elems / group_size; - auto out = torch::empty({num_groups, group_size * q_bits / 8 + 4}, options); DISPATCH_QUANTIZE(kHalf, __half, 23, 8); #ifdef BF16_AVAILABLE @@ -108,9 +103,22 @@ void selective_dequantize(torch::Tensor& val, #endif } +at::Tensor get_scales(torch::Tensor& out, int num_groups) +{ + auto options = at::TensorOptions() + .dtype(torch::kFloat) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + auto scales = + torch::from_blob(out.data_ptr(), {num_groups, 1}, {out.stride(0) / 4, 1}, options); + return scales; +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("quantize", &quantize, "quantize function"); m.def("dequantize", &dequantize, "dequantize function"); + m.def("get_scales", &get_scales, "get scales function"); m.def("selective_dequantize", &selective_dequantize, "selective dequantize function"); } diff --git a/deepspeed/linear/quantization.py b/deepspeed/linear/quantization.py index f5343af45fb8..8e4f23dfba89 100644 --- a/deepspeed/linear/quantization.py +++ b/deepspeed/linear/quantization.py @@ -57,18 +57,18 @@ def __new__( def _ensure_quantized(self, tensor: torch.Tensor): # If the tensor is on the accelerator and is not quantized, then quantize it in-place. - if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.int8: + if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.uint8: with get_accelerator().stream(get_accelerator().current_stream(tensor.device)): tensor.data = self.quantizer.quantize(tensor.data, q_bits=self.quantization_config.q_bits, q_mantisa_bits=self.quantization_config.mantissa_bits) - assert tensor.dtype == torch.int8 + assert tensor.dtype == torch.uint8 def dequantized(self) -> torch.Tensor: """ Return a tensor containing the dequantized weights of this parameter. """ - if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.int8: + if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.uint8: with get_accelerator().stream(get_accelerator().current_stream(self.data.device)): return self.quantizer.dequantize(self.data, q_bits=self.quantization_config.q_bits, diff --git a/deepspeed/ops/__init__.py b/deepspeed/ops/__init__.py index 7ea5ce5af19e..15179984173c 100755 --- a/deepspeed/ops/__init__.py +++ b/deepspeed/ops/__init__.py @@ -9,7 +9,7 @@ from . import lion from . import sparse_attention from . import transformer - +from . import fp_quantizer from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from ..git_version_info import compatible_ops as __compatible_ops__ diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py index 995bbae4aeaf..51377bc6092c 100644 --- a/deepspeed/ops/fp_quantizer/__init__.py +++ b/deepspeed/ops/fp_quantizer/__init__.py @@ -4,3 +4,9 @@ # DeepSpeed Team from .quantize import FP_Quantize, Quantizer + +try: + import triton + from .fp8_gemm import matmul_fp8 +except ImportError: + pass diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm.py b/deepspeed/ops/fp_quantizer/fp8_gemm.py new file mode 100644 index 000000000000..55504e3af8c9 --- /dev/null +++ b/deepspeed/ops/fp_quantizer/fp8_gemm.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +######## Fused MoE kernel ######### +# These kernels are implemented for +# fusing GeMM with dequantization of +# fp8 weight data when using bit-16 +# activation. +################################### + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # Dequantize weight (fp8 -> bf16) + w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16) + w = (w + 0x3C00).to(tl.uint16) + w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K + weight = tl.load(weight_data, mask=weight_mask, other=0.0) + scale = tl.load(scale_ptr + (weight_ptrs_offset + + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)), + mask=weight_mask, + other=0.0) + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.bfloat16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +@triton.jit +def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # Dequantize weight (fp8 -> fp16) + w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16) + w = (w + 0x2000).to(tl.uint16) + w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + + weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0) + scale = tl.load(scale_ptr + (weight_ptrs_offset + + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size))) + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +def matmul_fp8(inp, weight, scale, quantization_group_size): + + assert inp.shape[1] == weight.shape[0], \ + f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})" + + M, K = inp.shape + K, N = weight.shape + + out = torch.empty((M, N), device=inp.device, dtype=inp.dtype) + + # GEMM tuning parameters! + # TODO: Add a more configurable tuning for selecting the best GeMM + BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = max(64, quantization_group_size) + GROUP_SIZE_M = 8 + num_stages = 4 + num_warps = 4 + if M >= 256: + BLOCK_SIZE_M = 256 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = max(128, quantization_group_size) + num_stages = 3 + num_warps = 8 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16 + kernel[grid](inp, + weight, + out, + scale, + M, + N, + K, + inp.stride(0), + inp.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + quantization_group_size=quantization_group_size, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + num_stages=num_stages, + num_warps=num_warps) + return out diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py index f8435bda16c1..170954e0cf71 100644 --- a/deepspeed/ops/fp_quantizer/quantize.py +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -7,7 +7,9 @@ import abc from abc import ABC +import gc from deepspeed.ops.op_builder import FPQuantizerBuilder +from deepspeed.accelerator import get_accelerator fp_quant_module = None @@ -71,15 +73,27 @@ def quantize(self, else: assert (0), \ f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!" - - out = fp_quant_module.quantize(input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits) - + self.num_groups = input.numel() // self.group_size + self.input_q = torch.ones(self.num_groups, + int(self.group_size * q_bits) // 8 + 4, + dtype=torch.uint8, + device=input.device) + out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits) if return_meta_tensor: - data, scale = out.split(self.group_size, dim=-1) - return data.contiguous().reshape(input.shape), scale.contiguous() + data, self.scale = out.split(self.group_size, dim=-1) + data = data.contiguous().reshape(input.shape) + self.scale = self.scale.contiguous() + del self.input_q + del out + gc.collect() + get_accelerator().empty_cache() + return data, self.scale return out + def get_scales(self): + return fp_quant_module.get_scales(self.scale, self.num_groups) + def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor: assert (self.orig_dtype is not None), \ "[De-quantization Error]: you need to call quantize before dequantizing!" @@ -101,7 +115,6 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non assert input_q.numel() == fp_out.numel(), \ f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!' input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous() - fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1) return fp_out diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index a715843d8eff..f17cfa883cc6 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team - import torch from typing import Any, Tuple @@ -10,9 +9,21 @@ from torch.nn import Module import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator + + +def post_all2all(transpose, res_shape): + + def post_func(input): + if transpose: + input = input.transpose(0, 2).contiguous() + input = input.reshape(res_shape) + return input + return post_func -def single_all_to_all(input, scatter_idx, gather_idx, group): + +def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, handle=None, type=None): seq_world_size = dist.get_world_size(group) inp_shape = list(input.shape) inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size @@ -29,32 +40,76 @@ def single_all_to_all(input, scatter_idx, gather_idx, group): ).transpose(0, 1).contiguous() output = torch.empty_like(input_t) - dist.all_to_all_single(output, input_t, group=group) - - # if scattering the seq-dim, transpose the heads back to the original dimension - if scatter_idx < 2: - output = output.transpose(0, 2).contiguous() + work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op) - return output.reshape( - inp_shape[: gather_idx] + \ + res_shape=( inp_shape[: gather_idx] + \ [inp_shape[gather_idx] * seq_world_size,] + \ - inp_shape[gather_idx + 1:]).contiguous() + inp_shape[gather_idx + 1:]) + transpose = True if scatter_idx < 2 else False + post_all2all_fun = post_all2all(transpose, res_shape) + + if async_op: + if type in ('dq', 'dk'): + handle[type + '_work'] = work + handle[type + '_grad'] = output + handle[type + '_post_all2all_func'] = post_all2all_fun + return output.view(res_shape) + + res = post_all2all_fun(output) + return res class _SeqAllToAll(torch.autograd.Function): @staticmethod - def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor: - + def forward(ctx: Any, + group: dist.ProcessGroup, + input: Tensor, + scatter_idx: int, + gather_idx: int, + stream=None, + handle=None, + type=None, + is_fwd=True) -> Tensor: ctx.group = group ctx.scatter_idx = scatter_idx ctx.gather_idx = gather_idx - - return single_all_to_all(input, scatter_idx, gather_idx, group) + ctx.stream = stream + ctx.handle = handle + ctx.type = type + if ctx.handle is None: + res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + + else: + # overlap communication path + if not is_fwd and type == 'o': + assert ctx.stream != None + res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + get_accelerator().current_stream().wait_stream(ctx.stream) + del ctx.stream.activation_buffer_list + # The computation of d o_weight can overlap with the communication of d o_input + + elif not is_fwd and type in ('q', 'k'): + # Achieve communication overlap by pipelining the matrix computation and communication of dq, dk, and dv + type = 'd' + type + res = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) + + elif is_fwd and type in ('q', 'k'): + # Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v + type = 'fwd_' + type + res = single_all_to_all(input, scatter_idx, gather_idx, group, False, handle, type) + + else: + res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + + return res @staticmethod def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: - return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None) + + return (None, + _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream, ctx.handle, + ctx.type, False), None, None, None, None, None, None) class DistributedAttention(torch.nn.Module): @@ -73,6 +128,7 @@ def __init__( sequence_process_group: dist.ProcessGroup, scatter_idx: int = 2, gather_idx: int = 0, + sp_stream=None, ) -> None: super(DistributedAttention, self).__init__() @@ -80,6 +136,17 @@ def __init__( self.spg = sequence_process_group self.scatter_idx = scatter_idx self.gather_idx = gather_idx + self.sp_overlap_comm = False + self.overlap_handles = None + self.sp_stream = sp_stream + if sp_stream is not None: + self.overlap_handles = {} + self.sp_overlap_comm = True + self.dafult_stream = get_accelerator().default_stream() + + def layer_sync(self, layer): + if self.sp_overlap_comm and hasattr(layer, 'done_event'): + self.dafult_stream.wait_event(layer.done_event) def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor: """ forward @@ -93,17 +160,53 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwarg Returns: * output (Tensor): context output """ + # TODO Merge three alltoall calls into one # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! #in shape : e.g., [s/p:h:] - query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx) - key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx) - value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx) + + def bwd_hook(layer_type): + + def pre_hook_fun(grad): + type = 'd' + layer_type + self.overlap_handles[type + '_work'].wait() + self.sp_stream.wait_stream(self.dafult_stream) + all2all_output = self.overlap_handles[type + '_grad'] + grad = list(grad) + grad[0] = self.overlap_handles[type + '_post_all2all_func'](all2all_output) + grad = tuple(grad) + + return pre_hook_fun + + self.layer_sync(query) + query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, None, + self.overlap_handles, 'q') + self.layer_sync(key) + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, None, self.overlap_handles, + 'k') + if self.sp_overlap_comm: + self.dafult_stream.wait_stream(self.sp_stream) + + value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, None, + self.overlap_handles, 'v') + + if self.sp_overlap_comm: + # Register a hook to synchronize dq and dk after the all-to-all + # operation when the gradient data is used. + # Place this logic after the q, k, v all-to-all operation to + # improve interpreter speed to + # call and launch of the forward all-to-all communication. + grad_fn_q = query.grad_fn.next_functions[0][0] + grad_fn_q.register_prehook(bwd_hook(layer_type='q')) + grad_fn_k = key.grad_fn.next_functions[0][0] + grad_fn_k.register_prehook(bwd_hook(layer_type='k')) #out shape : e.g., [s:h/p:] + context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) - output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx) + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, self.sp_stream, + self.overlap_handles, 'o') #out e.g., [s/p::h] return output diff --git a/docs/_tutorials/onebit-adam.md b/docs/_tutorials/onebit-adam.md index 932bb355cf26..b1a8b5369761 100644 --- a/docs/_tutorials/onebit-adam.md +++ b/docs/_tutorials/onebit-adam.md @@ -75,6 +75,12 @@ Alternatively, the standard mpirun launcher can also be used as follows: mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flags] python [training_script.py] ``` +#### 1.2.3 Compressed implementation + +This backend provides an approach to abstract the generic part of one-bit optimizers and implements accelerator dependent part with DeepSpeed custom op builder. To use this `CompressedBackend`, you should make sure that your current accelerator supports `PackbitsBuilder`, so that it could be loaded to do high performance packing and unpacking between float and Byte datatype, which is utilized in one-bit algorithm. An example can be found in `Deepspeed/op_builder/xpu/packbits.py`. + +This approach does not require NCCL or MPI based communication library. It will automatically use your default communication library selected by your accelerator in `deepspeed/comm`. + ### 1.3 1-bit Algorithm The detailed description of the 1-bit Algorithm can be seen from our [blog post](https://www.deepspeed.ai/2020/09/08/onebit-adam-blog-post.html) and our [paper](https://arxiv.org/abs/2102.02888). @@ -106,7 +112,7 @@ Please note three new parameters `freeze_step`, `cuda_aware`, and `comm_backend_ `cuda_aware` is used for MPI-based implementation to indicate that the underlying MPI library supports CUDA-Aware communication. This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False will allow training on Ethernet based systems. However, the communication will happen using sender as well as receiver side memory copies between CPU and GPU buffers before and after communication. -(New in v2) `comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL and MPI-based implementations by setting `comm_backend_name` to "nccl" and "mpi". When using NCCL-based implementation, there is no need to set `cuda_aware`. +(New in v2) `comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL, MPI-based and compressed implementations by setting `comm_backend_name` to "nccl", "mpi" or "compressed". When using NCCL-based implementation, there is no need to set `cuda_aware`. #### 1.4.1 (New in v2) Momentum masks for parameters with constant zero gradients Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter have constant zero gradients during training. For example, for BERT pre-training seq length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for row 129 to 512, because it only learns up to seq length 128 while the model supports up to seq length 512. Thus in 1-bit Adam v2 we added support of a momentum mask for users to specify those params that have constant exact zeros in their gradients. See [example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. One thing to note is that we don't use momentum mask saved in checkpoints since this mask could change during training (e.g., BERT seqlen 128 and 512 require different masks). So you have to provide this mask every time in your training script. diff --git a/docs/_tutorials/onebit-lamb.md b/docs/_tutorials/onebit-lamb.md index 4873f1f35c17..b6c6ef075036 100644 --- a/docs/_tutorials/onebit-lamb.md +++ b/docs/_tutorials/onebit-lamb.md @@ -61,6 +61,10 @@ Alternatively, the standard mpirun launcher can also be used as follows: mpirun -np [num processes] -ppn [num GPUs on each node] -hostfile [hostfile] [MPI flags] python [training_script.py] ``` +#### 1.2.3 Compressed implementation +This backend provides an approach to abstract the generic part of one-bit optimizers and implements accelerator dependent part with DeepSpeed custom op builder. To use this `CompressedBackend`, you should make sure that your current accelerator supports `PackbitsBuilder`, so that it could be loaded to do high performance packing and unpacking between float and Byte datatype, which is utilized in one-bit algorithm. An example can be found in `Deepspeed/op_builder/xpu/packbits.py`. +This approach does not require NCCL or MPI based communication library. It will automatically use your default communication library selected by your accelerator in `deepspeed/comm`. + ### 1.3 1-bit LAMB Algorithm The detailed description of the 1-bit LAMB algorithm can be seen from our [paper](https://arxiv.org/abs/2104.06069). @@ -101,7 +105,7 @@ Please note the new parameters `freeze_step`, `cuda_aware`, `comm_backend_name`, `cuda_aware` is used for MPI-based implementation to indicate that the underlying MPI library supports CUDA-Aware communication. This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False will allow training on Ethernet based systems. However, the communication will happen using sender as well as receiver side memory copies between CPU and GPU buffers before and after communication. -`comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL and MPI-based implementations by setting `comm_backend_name` to "nccl" or "mpi". When using NCCL-based implementation, there is no need to set `cuda_aware`. +`comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL, MPI-based and compressed implementations by setting `comm_backend_name` to "nccl", "mpi" or "compressed". When using NCCL-based implementation, there is no need to set `cuda_aware`. `coeff_beta` is used when calculating a moving average of the LAMB scaling coefficient during the warmup stage. This moving average is then used as the frozen base scaling coefficient during the compression stage. diff --git a/docs/_tutorials/zero-one-adam.md b/docs/_tutorials/zero-one-adam.md index 2dd956e802fd..055c685faf89 100644 --- a/docs/_tutorials/zero-one-adam.md +++ b/docs/_tutorials/zero-one-adam.md @@ -62,6 +62,10 @@ Alternatively, the standard mpirun launcher can also be used as follows: mpirun -np [num processes] -ppn [num GPUs on each node] -hostfile [hostfile] [MPI flags] python [training_script.py] ``` +#### 1.2.3 Compressed implementation +This backend provides an approach to abstract the generic part of one-bit optimizers and implements accelerator dependent part with DeepSpeed custom op builder. To use this `CompressedBackend`, you should make sure that your current accelerator supports `PackbitsBuilder`, so that it could be loaded to do high performance packing and unpacking between float and Byte datatype, which is utilized in one-bit algorithm. An example can be found in `Deepspeed/op_builder/xpu/packbits.py`. +This approach does not require NCCL or MPI based communication library. It will automatically use your default communication library selected by your accelerator in `deepspeed/comm`. + ### 1.3 0/1 Adam Algorithm The detailed description of the 0/1 Adam algorithm can be seen from our [paper](https://arxiv.org/abs/2202.06009). @@ -107,7 +111,7 @@ The learning rate policy is the default policy used in 0/1 Adam, and the value o `cuda_aware` is used for MPI-based implementation to indicate that the underlying MPI library supports CUDA-Aware communication. This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False will allow training on Ethernet based systems. However, the communication will happen using sender as well as receiver side memory copies between CPU and GPU buffers before and after communication. -`comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL and MPI-based implementations by setting `comm_backend_name` to "nccl" or "mpi". When using NCCL-based implementation, there is no need to set `cuda_aware`. +`comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL, MPI-based and compressed implementations by setting `comm_backend_name` to "nccl", "mpi" or "compressed". When using NCCL-based implementation, there is no need to set `cuda_aware`. #### 1.4.1 Momentum masks for parameters with constant zero gradients Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter have constant zero gradients during training. For example, for BERT pre-training seq length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for row 129 to 512, because it only learns up to seq length 128 while the model supports up to seq length 512. Thus in 0/1 Adam we added support of a momentum mask for users to specify those params that have constant exact zeros in their gradients. See [example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. One thing to note is that we don't use momentum mask saved in checkpoints since this mask could change during training (e.g., BERT seqlen 128 and 512 require different masks). So you have to provide this mask every time in your training script. diff --git a/op_builder/builder.py b/op_builder/builder.py index 03611bf56284..8998fc0eddb8 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -4,6 +4,7 @@ # DeepSpeed Team import os +import re import sys import time import importlib @@ -215,19 +216,31 @@ def installed_rocm_version(): ROCM_MAJOR = '0' ROCM_MINOR = '0' + ROCM_VERSION_DEV_RAW = "" if OpBuilder.is_rocm_pytorch(): from torch.utils.cpp_extension import ROCM_HOME - rocm_ver_file = Path(ROCM_HOME).joinpath(".info/version-dev") + rocm_ver_file = Path(ROCM_HOME).joinpath(".info/version") if rocm_ver_file.is_file(): with open(rocm_ver_file, 'r') as file: ROCM_VERSION_DEV_RAW = file.read() elif "rocm" in torch.__version__: ROCM_VERSION_DEV_RAW = torch.__version__.split("rocm")[1] + if ROCM_VERSION_DEV_RAW != "": + ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0] + ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1] else: + # Look in /usr/include/rocm-version.h + rocm_ver_file = Path("/usr/include/rocm_version.h") + if rocm_ver_file.is_file(): + with open(rocm_ver_file, 'r') as file: + for ln in file.readlines(): + if "#define ROCM_VERSION_MAJOR" in ln: + ROCM_MAJOR = re.findall(r'\S+', ln)[2] + elif "#define ROCM_VERSION_MINOR" in ln: + ROCM_MINOR = re.findall(r'\S+', ln)[2] + if ROCM_MAJOR == '0': assert False, "Could not detect ROCm version" - assert ROCM_VERSION_DEV_RAW != "", "Could not detect ROCm version" - ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0] - ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1] + OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR)) return OpBuilder._rocm_version @@ -235,7 +248,10 @@ def installed_rocm_version(): def get_rocm_gpu_arch(): if OpBuilder._rocm_gpu_arch: return OpBuilder._rocm_gpu_arch - rocm_gpu_arch_cmd = "/opt/rocm/bin/rocminfo | grep -o -m 1 'gfx.*'" + rocm_info = Path("/opt/rocm/bin/rocminfo") + if (not rocm_info.is_file()): + rocm_info = Path("rocminfo") + rocm_gpu_arch_cmd = str(rocm_info) + " | grep -o -m 1 'gfx.*'" try: result = subprocess.check_output(rocm_gpu_arch_cmd, shell=True) rocm_gpu_arch = result.decode('utf-8').strip() @@ -248,7 +264,12 @@ def get_rocm_gpu_arch(): def get_rocm_wavefront_size(): if OpBuilder._rocm_wavefront_size: return OpBuilder._rocm_wavefront_size - rocm_wavefront_size_cmd = "/opt/rocm/bin/rocminfo | grep -Eo -m1 'Wavefront Size:[[:space:]]+[0-9]+' | grep -Eo '[0-9]+'" + + rocm_info = Path("/opt/rocm/bin/rocminfo") + if (not rocm_info.is_file()): + rocm_info = Path("rocminfo") + rocm_wavefront_size_cmd = str( + rocm_info) + " | grep -Eo -m1 'Wavefront Size:[[:space:]]+[0-9]+' | grep -Eo '[0-9]+'" try: result = subprocess.check_output(rocm_wavefront_size_cmd, shell=True) rocm_wavefront_size = result.decode('utf-8').strip() @@ -749,11 +770,18 @@ def nvcc_args(self): except ValueError: nvcc_threads = min(os.cpu_count(), 8) - cuda_major, _ = installed_cuda_version() + cuda_major, cuda_minor = installed_cuda_version() + if cuda_major > 10: + if cuda_major == 12 and cuda_minor >= 5: + std_lib = '-std=c++20' + else: + std_lib = '-std=c++17' + else: + std_lib = '-std=c++14' args += [ - '-allow-unsupported-compiler' if sys.platform == "win32" else '', '--use_fast_math', - '-std=c++17' if cuda_major > 10 else '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', f'--threads={nvcc_threads}' + '-allow-unsupported-compiler' if sys.platform == "win32" else '', '--use_fast_math', std_lib, + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + f'--threads={nvcc_threads}' ] if os.environ.get('DS_DEBUG_CUDA_BUILD', '0') == '1': args.append('--ptxas-options=-v') @@ -772,25 +800,32 @@ def libraries_args(self): class TorchCPUOpBuilder(CUDAOpBuilder): + def get_cuda_lib64_path(self): + import torch + if not self.is_rocm_pytorch(): + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64") + if not os.path.exists(CUDA_LIB64): + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib") + else: + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib") + return CUDA_LIB64 + def extra_ldflags(self): if self.build_for_cpu: return ['-fopenmp'] if not self.is_rocm_pytorch(): - return ['-lcurand'] + ld_flags = ['-lcurand'] + if not self.build_for_cpu: + ld_flags.append(f'-L{self.get_cuda_lib64_path()}') + return ld_flags return [] def cxx_args(self): - import torch args = [] if not self.build_for_cpu: - if not self.is_rocm_pytorch(): - CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64") - if not os.path.exists(CUDA_LIB64): - CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib") - else: - CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib") + CUDA_LIB64 = self.get_cuda_lib64_path() args += super().cxx_args() args += [ diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py index 9f07ec3d1275..c7d2e72b5408 100644 --- a/op_builder/fp_quantizer.py +++ b/op_builder/fp_quantizer.py @@ -3,6 +3,11 @@ # DeepSpeed Team +try: + from packaging import version as pkg_version +except ImportError: + pkg_version = None + from .builder import CUDAOpBuilder, installed_cuda_version @@ -36,6 +41,29 @@ def is_compatible(self, verbose=True): if torch_cuda_major < 11 or sys_cuda_major < 11: self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False + + try: + import triton + except ImportError: + self.warning(f"please install triton==2.3.0 or 2.3.1 if you want to use the FP Quantizer Kernels") + return False + + # triton 2.3.0 and 2.3.1 are okay and the only versions released in 2.3.x before 3.x was released + if pkg_version: + allowed = pkg_version.parse("2.3") + installed_triton = pkg_version.parse(triton.__version__) + triton_mismatch = installed_triton.major != allowed.major or installed_triton.minor != allowed.minor + else: + installed_triton = triton.__version__ + major, minor, _ = installed_triton.split(".") + triton_mismatch = major != "2" or minor != "3" + + if triton_mismatch: + self.warning( + f"FP Quantizer is using an untested triton version ({installed_triton}), only 2.3.0 and 2.3.1 are known to be compatible with these kernels" + ) + return False + return super().is_compatible(verbose) and cuda_okay def filter_ccs(self, ccs): diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 76eb88eea560..eadf670d9328 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -94,7 +94,10 @@ def _hf_model_list() -> List[ModelInfo]: model_data = {"cache_time": 0, "model_list": []} if os.path.isfile(cache_file_path): with open(cache_file_path, 'rb') as f: - model_data = pickle.load(f) + try: + model_data = pickle.load(f) + except Exception as e: + print(f"Error loading cache file {cache_file_path}: {e}") current_time = time.time() diff --git a/tests/unit/linear/test_quant_param.py b/tests/unit/linear/test_quant_param.py index 9479b3cba8a0..84a9f766ef74 100644 --- a/tests/unit/linear/test_quant_param.py +++ b/tests/unit/linear/test_quant_param.py @@ -42,7 +42,7 @@ def test_move_to_accelerator(self): assert qp.device == torch.device('cpu') qp = qp.to(get_accelerator().current_device_name()) assert qp.device == torch.device(device) - assert qp.dtype == torch.int8 + assert qp.dtype == torch.uint8 def test_hf_clone(self): device = get_accelerator().current_device_name() diff --git a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py new file mode 100644 index 000000000000..d66f7c8cb4cc --- /dev/null +++ b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed + +from deepspeed.ops.op_builder import FPQuantizerBuilder + +if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]: + pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True) + +from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8 + + +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +@pytest.mark.parametrize("q_bits", [8], ids=[ + "qbits8", +]) +@pytest.mark.parametrize("M", [1, 2, 4, 8, 32, 64, 128, 256, 512, 1024, 2048]) +def test_fp_quant(dtype, q_bits, M): + quantization_group_size = 128 + fpq = FP_Quantize(group_size=quantization_group_size) + + N = 8192 + H = 4096 + + x = torch.randn(M, H, dtype=dtype, device='cuda') + weight_bf16 = torch.randn(H, N, dtype=dtype, device='cuda') + + weight, _ = fpq.quantize(weight_bf16.data, q_bits=8, return_meta_tensor=True) + scale = fpq.get_scales() + out = matmul_fp8( + x, + weight, + scale, + quantization_group_size, + ) + + out_q = torch.matmul(x, fpq.dequantize(weight, scale=fpq.scale)) + + error = ((out - out_q).abs() / (out.abs() + 1e-5)).sum() / out.numel() + assert 0.004 > error, f"failed on batch-size {M} with error {error}"