From 4f9506729fda0e9b178761f58f5316018d77990b Mon Sep 17 00:00:00 2001 From: Reza Yazdani <152926435+sfc-gh-reyazda@users.noreply.github.com> Date: Mon, 29 Jul 2024 11:07:00 -0700 Subject: [PATCH 1/2] Add fp8-fused gemm kernel (#5764) This PR adds the new fused kernel for the Dense GeMM using fp8-quantized weight. --------- Co-authored-by: Jeff Rasley Co-authored-by: Jeff Rasley --- csrc/fp_quantizer/fp_quantize.cpp | 22 ++- deepspeed/linear/quantization.py | 6 +- deepspeed/ops/__init__.py | 2 +- deepspeed/ops/fp_quantizer/__init__.py | 6 + deepspeed/ops/fp_quantizer/fp8_gemm.py | 171 +++++++++++++++++++ deepspeed/ops/fp_quantizer/quantize.py | 25 ++- op_builder/fp_quantizer.py | 28 +++ tests/unit/linear/test_quant_param.py | 2 +- tests/unit/ops/fp_quantizer/test_fp8_gemm.py | 45 +++++ 9 files changed, 289 insertions(+), 18 deletions(-) create mode 100644 deepspeed/ops/fp_quantizer/fp8_gemm.py create mode 100644 tests/unit/ops/fp_quantizer/test_fp8_gemm.py diff --git a/csrc/fp_quantizer/fp_quantize.cpp b/csrc/fp_quantizer/fp_quantize.cpp index 6962b8050f51..1a887b50e1a3 100644 --- a/csrc/fp_quantizer/fp_quantize.cpp +++ b/csrc/fp_quantizer/fp_quantize.cpp @@ -22,25 +22,20 @@ stochastic_rounding); \ } -at::Tensor quantize(torch::Tensor& val, +at::Tensor quantize(torch::Tensor& out, + torch::Tensor& val, int group_size, int stochastic_rounding, int q_bits, int q_mantisa_bits) { int total_elems = at::numel(val); - auto options = at::TensorOptions() - .dtype(torch::kInt8) - .layout(val.layout()) - .device(val.device()) - .requires_grad(false); float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) : // fp8 ranges (q_bits == 12 ? 510.0 : // fp12 range (q_bits == 6 ? 28.0 : // fp6 range 6.0)); // fp4 range (using power 2); TODO (Reza): add the power-4 // in case accuracy is not matching! int num_groups = total_elems / group_size; - auto out = torch::empty({num_groups, group_size * q_bits / 8 + 4}, options); DISPATCH_QUANTIZE(kHalf, __half, 23, 8); #ifdef BF16_AVAILABLE @@ -108,9 +103,22 @@ void selective_dequantize(torch::Tensor& val, #endif } +at::Tensor get_scales(torch::Tensor& out, int num_groups) +{ + auto options = at::TensorOptions() + .dtype(torch::kFloat) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + auto scales = + torch::from_blob(out.data_ptr(), {num_groups, 1}, {out.stride(0) / 4, 1}, options); + return scales; +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("quantize", &quantize, "quantize function"); m.def("dequantize", &dequantize, "dequantize function"); + m.def("get_scales", &get_scales, "get scales function"); m.def("selective_dequantize", &selective_dequantize, "selective dequantize function"); } diff --git a/deepspeed/linear/quantization.py b/deepspeed/linear/quantization.py index f5343af45fb8..8e4f23dfba89 100644 --- a/deepspeed/linear/quantization.py +++ b/deepspeed/linear/quantization.py @@ -57,18 +57,18 @@ def __new__( def _ensure_quantized(self, tensor: torch.Tensor): # If the tensor is on the accelerator and is not quantized, then quantize it in-place. - if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.int8: + if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.uint8: with get_accelerator().stream(get_accelerator().current_stream(tensor.device)): tensor.data = self.quantizer.quantize(tensor.data, q_bits=self.quantization_config.q_bits, q_mantisa_bits=self.quantization_config.mantissa_bits) - assert tensor.dtype == torch.int8 + assert tensor.dtype == torch.uint8 def dequantized(self) -> torch.Tensor: """ Return a tensor containing the dequantized weights of this parameter. """ - if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.int8: + if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.uint8: with get_accelerator().stream(get_accelerator().current_stream(self.data.device)): return self.quantizer.dequantize(self.data, q_bits=self.quantization_config.q_bits, diff --git a/deepspeed/ops/__init__.py b/deepspeed/ops/__init__.py index 7ea5ce5af19e..15179984173c 100755 --- a/deepspeed/ops/__init__.py +++ b/deepspeed/ops/__init__.py @@ -9,7 +9,7 @@ from . import lion from . import sparse_attention from . import transformer - +from . import fp_quantizer from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from ..git_version_info import compatible_ops as __compatible_ops__ diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py index 995bbae4aeaf..51377bc6092c 100644 --- a/deepspeed/ops/fp_quantizer/__init__.py +++ b/deepspeed/ops/fp_quantizer/__init__.py @@ -4,3 +4,9 @@ # DeepSpeed Team from .quantize import FP_Quantize, Quantizer + +try: + import triton + from .fp8_gemm import matmul_fp8 +except ImportError: + pass diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm.py b/deepspeed/ops/fp_quantizer/fp8_gemm.py new file mode 100644 index 000000000000..55504e3af8c9 --- /dev/null +++ b/deepspeed/ops/fp_quantizer/fp8_gemm.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +######## Fused MoE kernel ######### +# These kernels are implemented for +# fusing GeMM with dequantization of +# fp8 weight data when using bit-16 +# activation. +################################### + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # Dequantize weight (fp8 -> bf16) + w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16) + w = (w + 0x3C00).to(tl.uint16) + w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K + weight = tl.load(weight_data, mask=weight_mask, other=0.0) + scale = tl.load(scale_ptr + (weight_ptrs_offset + + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)), + mask=weight_mask, + other=0.0) + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.bfloat16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +@triton.jit +def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # Dequantize weight (fp8 -> fp16) + w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16) + w = (w + 0x2000).to(tl.uint16) + w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + + weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0) + scale = tl.load(scale_ptr + (weight_ptrs_offset + + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size))) + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +def matmul_fp8(inp, weight, scale, quantization_group_size): + + assert inp.shape[1] == weight.shape[0], \ + f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})" + + M, K = inp.shape + K, N = weight.shape + + out = torch.empty((M, N), device=inp.device, dtype=inp.dtype) + + # GEMM tuning parameters! + # TODO: Add a more configurable tuning for selecting the best GeMM + BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = max(64, quantization_group_size) + GROUP_SIZE_M = 8 + num_stages = 4 + num_warps = 4 + if M >= 256: + BLOCK_SIZE_M = 256 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = max(128, quantization_group_size) + num_stages = 3 + num_warps = 8 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16 + kernel[grid](inp, + weight, + out, + scale, + M, + N, + K, + inp.stride(0), + inp.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + quantization_group_size=quantization_group_size, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + num_stages=num_stages, + num_warps=num_warps) + return out diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py index f8435bda16c1..170954e0cf71 100644 --- a/deepspeed/ops/fp_quantizer/quantize.py +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -7,7 +7,9 @@ import abc from abc import ABC +import gc from deepspeed.ops.op_builder import FPQuantizerBuilder +from deepspeed.accelerator import get_accelerator fp_quant_module = None @@ -71,15 +73,27 @@ def quantize(self, else: assert (0), \ f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!" - - out = fp_quant_module.quantize(input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits) - + self.num_groups = input.numel() // self.group_size + self.input_q = torch.ones(self.num_groups, + int(self.group_size * q_bits) // 8 + 4, + dtype=torch.uint8, + device=input.device) + out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits) if return_meta_tensor: - data, scale = out.split(self.group_size, dim=-1) - return data.contiguous().reshape(input.shape), scale.contiguous() + data, self.scale = out.split(self.group_size, dim=-1) + data = data.contiguous().reshape(input.shape) + self.scale = self.scale.contiguous() + del self.input_q + del out + gc.collect() + get_accelerator().empty_cache() + return data, self.scale return out + def get_scales(self): + return fp_quant_module.get_scales(self.scale, self.num_groups) + def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor: assert (self.orig_dtype is not None), \ "[De-quantization Error]: you need to call quantize before dequantizing!" @@ -101,7 +115,6 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non assert input_q.numel() == fp_out.numel(), \ f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!' input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous() - fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1) return fp_out diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py index 9f07ec3d1275..c7d2e72b5408 100644 --- a/op_builder/fp_quantizer.py +++ b/op_builder/fp_quantizer.py @@ -3,6 +3,11 @@ # DeepSpeed Team +try: + from packaging import version as pkg_version +except ImportError: + pkg_version = None + from .builder import CUDAOpBuilder, installed_cuda_version @@ -36,6 +41,29 @@ def is_compatible(self, verbose=True): if torch_cuda_major < 11 or sys_cuda_major < 11: self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False + + try: + import triton + except ImportError: + self.warning(f"please install triton==2.3.0 or 2.3.1 if you want to use the FP Quantizer Kernels") + return False + + # triton 2.3.0 and 2.3.1 are okay and the only versions released in 2.3.x before 3.x was released + if pkg_version: + allowed = pkg_version.parse("2.3") + installed_triton = pkg_version.parse(triton.__version__) + triton_mismatch = installed_triton.major != allowed.major or installed_triton.minor != allowed.minor + else: + installed_triton = triton.__version__ + major, minor, _ = installed_triton.split(".") + triton_mismatch = major != "2" or minor != "3" + + if triton_mismatch: + self.warning( + f"FP Quantizer is using an untested triton version ({installed_triton}), only 2.3.0 and 2.3.1 are known to be compatible with these kernels" + ) + return False + return super().is_compatible(verbose) and cuda_okay def filter_ccs(self, ccs): diff --git a/tests/unit/linear/test_quant_param.py b/tests/unit/linear/test_quant_param.py index 9479b3cba8a0..84a9f766ef74 100644 --- a/tests/unit/linear/test_quant_param.py +++ b/tests/unit/linear/test_quant_param.py @@ -42,7 +42,7 @@ def test_move_to_accelerator(self): assert qp.device == torch.device('cpu') qp = qp.to(get_accelerator().current_device_name()) assert qp.device == torch.device(device) - assert qp.dtype == torch.int8 + assert qp.dtype == torch.uint8 def test_hf_clone(self): device = get_accelerator().current_device_name() diff --git a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py new file mode 100644 index 000000000000..d66f7c8cb4cc --- /dev/null +++ b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed + +from deepspeed.ops.op_builder import FPQuantizerBuilder + +if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]: + pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True) + +from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8 + + +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) +@pytest.mark.parametrize("q_bits", [8], ids=[ + "qbits8", +]) +@pytest.mark.parametrize("M", [1, 2, 4, 8, 32, 64, 128, 256, 512, 1024, 2048]) +def test_fp_quant(dtype, q_bits, M): + quantization_group_size = 128 + fpq = FP_Quantize(group_size=quantization_group_size) + + N = 8192 + H = 4096 + + x = torch.randn(M, H, dtype=dtype, device='cuda') + weight_bf16 = torch.randn(H, N, dtype=dtype, device='cuda') + + weight, _ = fpq.quantize(weight_bf16.data, q_bits=8, return_meta_tensor=True) + scale = fpq.get_scales() + out = matmul_fp8( + x, + weight, + scale, + quantization_group_size, + ) + + out_q = torch.matmul(x, fpq.dequantize(weight, scale=fpq.scale)) + + error = ((out - out_q).abs() / (out.abs() + 1e-5)).sum() / out.numel() + assert 0.004 > error, f"failed on batch-size {M} with error {error}" From afe1b9ede1794b0bf221607be8051728cf4afa18 Mon Sep 17 00:00:00 2001 From: Liangliang Ma Date: Tue, 30 Jul 2024 02:38:03 +0800 Subject: [PATCH 2/2] Add doc of compressed backend in Onebit optimizers (#5782) This one is document supplement for https://github.com/microsoft/DeepSpeed/pull/5473. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- docs/_tutorials/onebit-adam.md | 8 +++++++- docs/_tutorials/onebit-lamb.md | 6 +++++- docs/_tutorials/zero-one-adam.md | 6 +++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/docs/_tutorials/onebit-adam.md b/docs/_tutorials/onebit-adam.md index 932bb355cf26..b1a8b5369761 100644 --- a/docs/_tutorials/onebit-adam.md +++ b/docs/_tutorials/onebit-adam.md @@ -75,6 +75,12 @@ Alternatively, the standard mpirun launcher can also be used as follows: mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flags] python [training_script.py] ``` +#### 1.2.3 Compressed implementation + +This backend provides an approach to abstract the generic part of one-bit optimizers and implements accelerator dependent part with DeepSpeed custom op builder. To use this `CompressedBackend`, you should make sure that your current accelerator supports `PackbitsBuilder`, so that it could be loaded to do high performance packing and unpacking between float and Byte datatype, which is utilized in one-bit algorithm. An example can be found in `Deepspeed/op_builder/xpu/packbits.py`. + +This approach does not require NCCL or MPI based communication library. It will automatically use your default communication library selected by your accelerator in `deepspeed/comm`. + ### 1.3 1-bit Algorithm The detailed description of the 1-bit Algorithm can be seen from our [blog post](https://www.deepspeed.ai/2020/09/08/onebit-adam-blog-post.html) and our [paper](https://arxiv.org/abs/2102.02888). @@ -106,7 +112,7 @@ Please note three new parameters `freeze_step`, `cuda_aware`, and `comm_backend_ `cuda_aware` is used for MPI-based implementation to indicate that the underlying MPI library supports CUDA-Aware communication. This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False will allow training on Ethernet based systems. However, the communication will happen using sender as well as receiver side memory copies between CPU and GPU buffers before and after communication. -(New in v2) `comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL and MPI-based implementations by setting `comm_backend_name` to "nccl" and "mpi". When using NCCL-based implementation, there is no need to set `cuda_aware`. +(New in v2) `comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL, MPI-based and compressed implementations by setting `comm_backend_name` to "nccl", "mpi" or "compressed". When using NCCL-based implementation, there is no need to set `cuda_aware`. #### 1.4.1 (New in v2) Momentum masks for parameters with constant zero gradients Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter have constant zero gradients during training. For example, for BERT pre-training seq length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for row 129 to 512, because it only learns up to seq length 128 while the model supports up to seq length 512. Thus in 1-bit Adam v2 we added support of a momentum mask for users to specify those params that have constant exact zeros in their gradients. See [example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. One thing to note is that we don't use momentum mask saved in checkpoints since this mask could change during training (e.g., BERT seqlen 128 and 512 require different masks). So you have to provide this mask every time in your training script. diff --git a/docs/_tutorials/onebit-lamb.md b/docs/_tutorials/onebit-lamb.md index 4873f1f35c17..b6c6ef075036 100644 --- a/docs/_tutorials/onebit-lamb.md +++ b/docs/_tutorials/onebit-lamb.md @@ -61,6 +61,10 @@ Alternatively, the standard mpirun launcher can also be used as follows: mpirun -np [num processes] -ppn [num GPUs on each node] -hostfile [hostfile] [MPI flags] python [training_script.py] ``` +#### 1.2.3 Compressed implementation +This backend provides an approach to abstract the generic part of one-bit optimizers and implements accelerator dependent part with DeepSpeed custom op builder. To use this `CompressedBackend`, you should make sure that your current accelerator supports `PackbitsBuilder`, so that it could be loaded to do high performance packing and unpacking between float and Byte datatype, which is utilized in one-bit algorithm. An example can be found in `Deepspeed/op_builder/xpu/packbits.py`. +This approach does not require NCCL or MPI based communication library. It will automatically use your default communication library selected by your accelerator in `deepspeed/comm`. + ### 1.3 1-bit LAMB Algorithm The detailed description of the 1-bit LAMB algorithm can be seen from our [paper](https://arxiv.org/abs/2104.06069). @@ -101,7 +105,7 @@ Please note the new parameters `freeze_step`, `cuda_aware`, `comm_backend_name`, `cuda_aware` is used for MPI-based implementation to indicate that the underlying MPI library supports CUDA-Aware communication. This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False will allow training on Ethernet based systems. However, the communication will happen using sender as well as receiver side memory copies between CPU and GPU buffers before and after communication. -`comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL and MPI-based implementations by setting `comm_backend_name` to "nccl" or "mpi". When using NCCL-based implementation, there is no need to set `cuda_aware`. +`comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL, MPI-based and compressed implementations by setting `comm_backend_name` to "nccl", "mpi" or "compressed". When using NCCL-based implementation, there is no need to set `cuda_aware`. `coeff_beta` is used when calculating a moving average of the LAMB scaling coefficient during the warmup stage. This moving average is then used as the frozen base scaling coefficient during the compression stage. diff --git a/docs/_tutorials/zero-one-adam.md b/docs/_tutorials/zero-one-adam.md index 2dd956e802fd..055c685faf89 100644 --- a/docs/_tutorials/zero-one-adam.md +++ b/docs/_tutorials/zero-one-adam.md @@ -62,6 +62,10 @@ Alternatively, the standard mpirun launcher can also be used as follows: mpirun -np [num processes] -ppn [num GPUs on each node] -hostfile [hostfile] [MPI flags] python [training_script.py] ``` +#### 1.2.3 Compressed implementation +This backend provides an approach to abstract the generic part of one-bit optimizers and implements accelerator dependent part with DeepSpeed custom op builder. To use this `CompressedBackend`, you should make sure that your current accelerator supports `PackbitsBuilder`, so that it could be loaded to do high performance packing and unpacking between float and Byte datatype, which is utilized in one-bit algorithm. An example can be found in `Deepspeed/op_builder/xpu/packbits.py`. +This approach does not require NCCL or MPI based communication library. It will automatically use your default communication library selected by your accelerator in `deepspeed/comm`. + ### 1.3 0/1 Adam Algorithm The detailed description of the 0/1 Adam algorithm can be seen from our [paper](https://arxiv.org/abs/2202.06009). @@ -107,7 +111,7 @@ The learning rate policy is the default policy used in 0/1 Adam, and the value o `cuda_aware` is used for MPI-based implementation to indicate that the underlying MPI library supports CUDA-Aware communication. This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False will allow training on Ethernet based systems. However, the communication will happen using sender as well as receiver side memory copies between CPU and GPU buffers before and after communication. -`comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL and MPI-based implementations by setting `comm_backend_name` to "nccl" or "mpi". When using NCCL-based implementation, there is no need to set `cuda_aware`. +`comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL, MPI-based and compressed implementations by setting `comm_backend_name` to "nccl", "mpi" or "compressed". When using NCCL-based implementation, there is no need to set `cuda_aware`. #### 1.4.1 Momentum masks for parameters with constant zero gradients Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter have constant zero gradients during training. For example, for BERT pre-training seq length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for row 129 to 512, because it only learns up to seq length 128 while the model supports up to seq length 512. Thus in 0/1 Adam we added support of a momentum mask for users to specify those params that have constant exact zeros in their gradients. See [example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. One thing to note is that we don't use momentum mask saved in checkpoints since this mask could change during training (e.g., BERT seqlen 128 and 512 require different masks). So you have to provide this mask every time in your training script.