Merge branch 'master' into master
loadams authored Jul 29, 2024
2 parents 2d766d9 + afe1b9e commit 0be147f
Showing 12 changed files with 306 additions and 21 deletions.
22 changes: 15 additions & 7 deletions csrc/fp_quantizer/fp_quantize.cpp
@@ -22,25 +22,20 @@
stochastic_rounding); \
}

at::Tensor quantize(torch::Tensor& val,
at::Tensor quantize(torch::Tensor& out,
torch::Tensor& val,
int group_size,
int stochastic_rounding,
int q_bits,
int q_mantisa_bits)
{
int total_elems = at::numel(val);
auto options = at::TensorOptions()
.dtype(torch::kInt8)
.layout(val.layout())
.device(val.device())
.requires_grad(false);
float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) : // fp8 ranges
(q_bits == 12 ? 510.0 : // fp12 range
(q_bits == 6 ? 28.0 : // fp6 range
6.0)); // fp4 range (using power 2); TODO (Reza): add the power-4
// in case accuracy is not matching!
int num_groups = total_elems / group_size;
auto out = torch::empty({num_groups, group_size * q_bits / 8 + 4}, options);

DISPATCH_QUANTIZE(kHalf, __half, 23, 8);
#ifdef BF16_AVAILABLE
@@ -108,9 +103,22 @@ void selective_dequantize(torch::Tensor& val,
#endif
}

at::Tensor get_scales(torch::Tensor& out, int num_groups)
{
auto options = at::TensorOptions()
.dtype(torch::kFloat)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto scales =
torch::from_blob(out.data_ptr(), {num_groups, 1}, {out.stride(0) / 4, 1}, options);
return scales;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("quantize", &quantize, "quantize function");
m.def("dequantize", &dequantize, "dequantize function");
m.def("get_scales", &get_scales, "get scales function");
m.def("selective_dequantize", &selective_dequantize, "selective dequantize function");
}
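
For readers following the new `get_scales` binding above: it reinterprets the 4 scale bytes stored per quantization group as a `float32` view without copying. A rough Python equivalent is sketched below as an illustration only, assuming the scale slice is a contiguous `uint8` tensor of shape `[num_groups, 4]`, as produced on the Python side of this change:

```python
import torch

def get_scales_py(scale_bytes: torch.Tensor) -> torch.Tensor:
    # scale_bytes: contiguous uint8 tensor of shape [num_groups, 4]; each row holds the
    # raw float32 scale bytes appended to that quantization group's payload.
    assert scale_bytes.dtype == torch.uint8 and scale_bytes.is_contiguous()
    # Reinterpret the bytes in place (no copy), mirroring the torch::from_blob call above.
    return scale_bytes.view(torch.float32)  # -> shape [num_groups, 1]
```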
6 changes: 3 additions & 3 deletions deepspeed/linear/quantization.py
@@ -57,18 +57,18 @@ def __new__(

def _ensure_quantized(self, tensor: torch.Tensor):
# If the tensor is on the accelerator and is not quantized, then quantize it in-place.
if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.int8:
if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.uint8:
with get_accelerator().stream(get_accelerator().current_stream(tensor.device)):
tensor.data = self.quantizer.quantize(tensor.data,
q_bits=self.quantization_config.q_bits,
q_mantisa_bits=self.quantization_config.mantissa_bits)
assert tensor.dtype == torch.int8
assert tensor.dtype == torch.uint8

def dequantized(self) -> torch.Tensor:
"""
Return a tensor containing the dequantized weights of this parameter.
"""
if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.int8:
if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.uint8:
with get_accelerator().stream(get_accelerator().current_stream(self.data.device)):
return self.quantizer.dequantize(self.data,
q_bits=self.quantization_config.q_bits,
2 changes: 1 addition & 1 deletion deepspeed/ops/__init__.py
@@ -9,7 +9,7 @@
from . import lion
from . import sparse_attention
from . import transformer

from . import fp_quantizer
from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig

from ..git_version_info import compatible_ops as __compatible_ops__
6 changes: 6 additions & 0 deletions deepspeed/ops/fp_quantizer/__init__.py
@@ -4,3 +4,9 @@
# DeepSpeed Team

from .quantize import FP_Quantize, Quantizer

try:
import triton
from .fp8_gemm import matmul_fp8
except ImportError:
pass
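
Since `matmul_fp8` is only exported when the Triton import succeeds, callers may want to guard against its absence. A minimal, illustrative check (not part of this diff):

```python
from deepspeed.ops import fp_quantizer

# matmul_fp8 is defined only when Triton is available (see the try/except above).
if hasattr(fp_quantizer, "matmul_fp8"):
    matmul_fp8 = fp_quantizer.matmul_fp8
else:
    matmul_fp8 = None  # caller code falls back to a non-Triton path
```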
171 changes: 171 additions & 0 deletions deepspeed/ops/fp_quantizer/fp8_gemm.py
@@ -0,0 +1,171 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

######## Fused MoE kernel #########
# These kernels are implemented for
# fusing GeMM with dequantization of
# fp8 weight data when using bit-16
# activation.
###################################

import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
quantization_group_size: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)

inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
(pid_n * BLOCK_SIZE_N) // quantization_group_size)

weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
scale = tl.load(scale_ptr + weight_ptrs_offset)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# Dequantize weight (fp8 -> bf16)
w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16)
w = (w + 0x3C00).to(tl.uint16)
w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)

inp_data += BLOCK_SIZE_K * stride_ak
weight_data += BLOCK_SIZE_K * stride_bk
weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K
weight = tl.load(weight_data, mask=weight_mask, other=0.0)
scale = tl.load(scale_ptr + (weight_ptrs_offset +
(((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)),
mask=weight_mask,
other=0.0)

accumulator += tl.dot(inp, w)

out = accumulator.to(tl.bfloat16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))


@triton.jit
def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
quantization_group_size: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)

inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
(pid_n * BLOCK_SIZE_N) // quantization_group_size)

weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
scale = tl.load(scale_ptr + weight_ptrs_offset)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# Dequantize weight (fp8 -> fp16)
w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16)
w = (w + 0x2000).to(tl.uint16)
w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16)

inp_data += BLOCK_SIZE_K * stride_ak
weight_data += BLOCK_SIZE_K * stride_bk

weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0)
scale = tl.load(scale_ptr + (weight_ptrs_offset +
(((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)))

accumulator += tl.dot(inp, w)

out = accumulator.to(tl.float16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))


def matmul_fp8(inp, weight, scale, quantization_group_size):

assert inp.shape[1] == weight.shape[0], \
f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})"

M, K = inp.shape
K, N = weight.shape

out = torch.empty((M, N), device=inp.device, dtype=inp.dtype)

# GEMM tuning parameters!
# TODO: Add a more configurable tuning for selecting the best GeMM
BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128
BLOCK_SIZE_N = 64
BLOCK_SIZE_K = max(64, quantization_group_size)
GROUP_SIZE_M = 8
num_stages = 4
num_warps = 4
if M >= 256:
BLOCK_SIZE_M = 256
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = max(128, quantization_group_size)
num_stages = 3
num_warps = 8

grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16
kernel[grid](inp,
weight,
out,
scale,
M,
N,
K,
inp.stride(0),
inp.stride(1),
weight.stride(0),
weight.stride(1),
out.stride(0),
out.stride(1),
quantization_group_size=quantization_group_size,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
GROUP_SIZE_M=GROUP_SIZE_M,
num_stages=num_stages,
num_warps=num_warps)
return out
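
As a usage illustration (not part of this diff): the wrapper above consumes an fp8-quantized weight plus its per-group scales, and the kernels rebias each fp8 (e4m3-style) value into the bf16/fp16 bit layout in registers before the `tl.dot`. A minimal sketch follows, assuming a CUDA device, that `FP_Quantize` takes the quantization group size at construction, and that `get_scales()` returns the scale layout the kernel expects; shapes and sizes are placeholders. Note that `matmul_fp8` is only exported from `deepspeed.ops.fp_quantizer` when Triton is available (per the `__init__.py` change above).

```python
import torch
from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8  # matmul_fp8 requires Triton

M, K, N, group_size = 32, 4096, 4096, 128   # illustrative shapes
dtype = torch.bfloat16

x = torch.randn(M, K, dtype=dtype, device="cuda")
w = torch.randn(K, N, dtype=dtype, device="cuda")

quantizer = FP_Quantize(group_size=group_size)                    # assumed constructor argument
w_fp8, _ = quantizer.quantize(w.clone(), q_bits=8, return_meta_tensor=True)
scales = quantizer.get_scales()                                   # helper added in this PR

y = matmul_fp8(x, w_fp8, scales, group_size)                      # (M, N) in the activation dtype
y_ref = torch.matmul(x, w)                                        # rough reference for comparison
```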
25 changes: 19 additions & 6 deletions deepspeed/ops/fp_quantizer/quantize.py
@@ -7,7 +7,9 @@
import abc
from abc import ABC

import gc
from deepspeed.ops.op_builder import FPQuantizerBuilder
from deepspeed.accelerator import get_accelerator

fp_quant_module = None

@@ -71,15 +73,27 @@ def quantize(self,
else:
assert (0), \
f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"

out = fp_quant_module.quantize(input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)

self.num_groups = input.numel() // self.group_size
self.input_q = torch.ones(self.num_groups,
int(self.group_size * q_bits) // 8 + 4,
dtype=torch.uint8,
device=input.device)
out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)
if return_meta_tensor:
data, scale = out.split(self.group_size, dim=-1)
return data.contiguous().reshape(input.shape), scale.contiguous()
data, self.scale = out.split(self.group_size, dim=-1)
data = data.contiguous().reshape(input.shape)
self.scale = self.scale.contiguous()
del self.input_q
del out
gc.collect()
get_accelerator().empty_cache()
return data, self.scale

return out

def get_scales(self):
return fp_quant_module.get_scales(self.scale, self.num_groups)

def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=None) -> torch.Tensor:
assert (self.orig_dtype is not None), \
"[De-quantization Error]: you need to call quantize before dequantizing!"
@@ -101,7 +115,6 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non
assert input_q.numel() == fp_out.numel(), \
f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()

fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
return fp_out
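
Complementing the GEMM sketch earlier, a short round-trip through the meta-tensor path above (illustrative only; parameter names follow this diff, shapes are placeholders): `quantize(..., return_meta_tensor=True)` hands back the uint8 payload and the raw scale slice separately, and `dequantize` re-concatenates the scale internally when it is passed back in.

```python
import torch
from deepspeed.ops.fp_quantizer import FP_Quantize

quantizer = FP_Quantize(group_size=512)                       # assumed constructor argument
x = torch.randn(4, 1024, dtype=torch.bfloat16, device="cuda")

# Pass a copy, as a cautious assumption that quantize may touch its input in place.
data, scale = quantizer.quantize(x.clone(), q_bits=8, return_meta_tensor=True)
# data: uint8 payload reshaped to x.shape; scale: raw per-group scale bytes

x_hat = quantizer.dequantize(data, scale=scale, q_bits=8)     # scale is re-attached internally
print((x - x_hat).abs().max())                                # quantization error, not exact zero
```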

8 changes: 7 additions & 1 deletion docs/_tutorials/onebit-adam.md
@@ -75,6 +75,12 @@ Alternatively, the standard mpirun launcher can also be used as follows:
mpirun -np [#processes] -ppn [#GPUs on each node] -hostfile [hostfile] [MPI flags] python [training_script.py]
```
#### 1.2.3 Compressed implementation
This backend abstracts the generic part of the one-bit optimizers and implements the accelerator-dependent part with the DeepSpeed custom op builder. To use this `CompressedBackend`, make sure that your current accelerator supports `PackbitsBuilder`, so that it can be loaded to perform high-performance packing and unpacking between float and byte data types, which the one-bit algorithms rely on. An example can be found in `Deepspeed/op_builder/xpu/packbits.py`.
This approach does not require an NCCL- or MPI-based communication library. It automatically uses the default communication library selected by your accelerator in `deepspeed/comm`.
### 1.3 1-bit Algorithm
The detailed description of the 1-bit Algorithm can be seen from our [blog post](https://www.deepspeed.ai/2020/09/08/onebit-adam-blog-post.html) and our [paper](https://arxiv.org/abs/2102.02888).
@@ -106,7 +112,7 @@ Please note three new parameters `freeze_step`, `cuda_aware`, and `comm_backend_
`cuda_aware` is used for the MPI-based implementation to indicate that the underlying MPI library supports CUDA-Aware communication. This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False allows training on Ethernet-based systems; however, the communication will then use sender- as well as receiver-side memory copies between CPU and GPU buffers before and after communication.
(New in v2) `comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL and MPI-based implementations by setting `comm_backend_name` to "nccl" and "mpi". When using NCCL-based implementation, there is no need to set `cuda_aware`.
(New in v2) `comm_backend_name` is used to indicate which backend implementation to use. You can choose among the NCCL-based, MPI-based, and compressed implementations by setting `comm_backend_name` to "nccl", "mpi", or "compressed". When using the NCCL-based implementation, there is no need to set `cuda_aware`.
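
For reference, an illustrative optimizer section selecting the compressed backend might look like the sketch below; values such as the learning rate and `freeze_step` are placeholders, and the dict can be passed as the DeepSpeed config (or written out as the JSON config file):

```python
ds_config = {
    "optimizer": {
        "type": "OneBitAdam",
        "params": {
            "lr": 2e-4,                         # placeholder value
            "freeze_step": 400,                 # placeholder warmup length
            "cuda_aware": False,                # only relevant for the MPI backend
            "comm_backend_name": "compressed",  # "nccl", "mpi", or "compressed"
        },
    },
}
```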
#### 1.4.1 (New in v2) Momentum masks for parameters with constant zero gradients
Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter has constant zero gradients during training. For example, for BERT pre-training with sequence length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for rows 129 to 512, because it only learns up to sequence length 128 while the model supports up to sequence length 512. Thus, in 1-bit Adam v2 we added support for a momentum mask for users to specify those parameters that have constant exact zeros in their gradients. See the [example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. One thing to note is that we don't use the momentum mask saved in checkpoints, since this mask could change during training (e.g., BERT seqlen 128 and 512 require different masks). So you have to provide this mask every time in your training script.
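
As a concrete illustration of the scenario above (shapes assumed for illustration; see the linked example script for how the mask is actually passed to the optimizer), the mask places zeros on the rows whose gradients are always exactly zero:

```python
import torch

hidden_size = 768                      # assumed hidden size for illustration
mask = torch.ones(512, hidden_size)    # shape of bert.embeddings.position_embeddings.weight
mask[128:, :] = 0.0                    # rows 129-512 receive constant zero gradients at seqlen 128
```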
6 changes: 5 additions & 1 deletion docs/_tutorials/onebit-lamb.md
@@ -61,6 +61,10 @@ Alternatively, the standard mpirun launcher can also be used as follows:
mpirun -np [num processes] -ppn [num GPUs on each node] -hostfile [hostfile] [MPI flags] python [training_script.py]
```

#### 1.2.3 Compressed implementation
This backend abstracts the generic part of the one-bit optimizers and implements the accelerator-dependent part with the DeepSpeed custom op builder. To use this `CompressedBackend`, make sure that your current accelerator supports `PackbitsBuilder`, so that it can be loaded to perform high-performance packing and unpacking between float and byte data types, which the one-bit algorithms rely on. An example can be found in `Deepspeed/op_builder/xpu/packbits.py`.
This approach does not require an NCCL- or MPI-based communication library. It automatically uses the default communication library selected by your accelerator in `deepspeed/comm`.

### 1.3 1-bit LAMB Algorithm

The detailed description of the 1-bit LAMB algorithm can be seen from our [paper](https://arxiv.org/abs/2104.06069).
@@ -101,7 +105,7 @@ Please note the new parameters `freeze_step`, `cuda_aware`, `comm_backend_name`,

`cuda_aware` is used for the MPI-based implementation to indicate that the underlying MPI library supports CUDA-Aware communication. This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False allows training on Ethernet-based systems; however, the communication will then use sender- as well as receiver-side memory copies between CPU and GPU buffers before and after communication.

`comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL and MPI-based implementations by setting `comm_backend_name` to "nccl" or "mpi". When using NCCL-based implementation, there is no need to set `cuda_aware`.
`comm_backend_name` is used to indicate which backend implementation to use. You can choose among the NCCL-based, MPI-based, and compressed implementations by setting `comm_backend_name` to "nccl", "mpi", or "compressed". When using the NCCL-based implementation, there is no need to set `cuda_aware`.

`coeff_beta` is used when calculating a moving average of the LAMB scaling coefficient during the warmup stage. This moving average is then used as the frozen base scaling coefficient during the compression stage.
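
Analogously to the 1-bit Adam example earlier, an illustrative 1-bit LAMB optimizer section with placeholder values might look like:

```python
ds_config = {
    "optimizer": {
        "type": "OneBitLamb",
        "params": {
            "lr": 1e-3,                         # placeholder value
            "freeze_step": 1000,                # placeholder warmup length
            "cuda_aware": False,                # only relevant for the MPI backend
            "comm_backend_name": "compressed",  # "nccl", "mpi", or "compressed"
            "coeff_beta": 0.9,                  # placeholder moving-average coefficient
        },
    },
}
```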
