OptimizedLinear updates #5791

Merged · 20 commits · Aug 14, 2024
Changes from 13 commits
22 changes: 15 additions & 7 deletions csrc/fp_quantizer/fp_quantize.cpp
@@ -22,25 +22,20 @@
stochastic_rounding); \
}

at::Tensor quantize(torch::Tensor& val,
at::Tensor quantize(torch::Tensor& out,
torch::Tensor& val,
int group_size,
int stochastic_rounding,
int q_bits,
int q_mantisa_bits)
{
int total_elems = at::numel(val);
auto options = at::TensorOptions()
.dtype(torch::kInt8)
.layout(val.layout())
.device(val.device())
.requires_grad(false);
float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) : // fp8 ranges
(q_bits == 12 ? 510.0 : // fp12 range
(q_bits == 6 ? 28.0 : // fp6 range
6.0)); // fp4 range (using power 2); TODO (Reza): add the power-4
// in case accuracy is not matching!
int num_groups = total_elems / group_size;
auto out = torch::empty({num_groups, group_size * q_bits / 8 + 4}, options);

DISPATCH_QUANTIZE(kHalf, __half, 23, 8);
#ifdef BF16_AVAILABLE
@@ -108,9 +103,22 @@ void selective_dequantize(torch::Tensor& val,
#endif
}

at::Tensor get_scales(torch::Tensor& out, int num_groups)
{
auto options = at::TensorOptions()
.dtype(torch::kFloat)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto scales =
torch::from_blob(out.data_ptr(), {num_groups, 1}, {out.stride(0) / 4, 1}, options);
return scales;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("quantize", &quantize, "quantize function");
m.def("dequantize", &dequantize, "dequantize function");
m.def("get_scales", &get_scales, "get scales function");
m.def("selective_dequantize", &selective_dequantize, "selective dequantize function");
}
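
A minimal Python-side sketch of the new calling convention, where the caller owns the packed output buffer and reads the per-group scales back through get_scales. The row layout mirrors the allocation removed from the C++ side (group_size * q_bits / 8 data bytes plus a 4-byte scale per group), and the uint8 dtype matches the checks added in quantization.py below; loading the extension through FPQuantizerBuilder is an assumption about the surrounding plumbing, not part of this diff.

import torch
from deepspeed.ops.op_builder import FPQuantizerBuilder

fp_quant_module = FPQuantizerBuilder().load()

q_bits, q_mantisa_bits, group_size = 8, 3, 128
val = torch.randn(4096, 4096, dtype=torch.bfloat16, device='cuda')

num_groups = val.numel() // group_size
out = torch.empty((num_groups, group_size * q_bits // 8 + 4),
                  dtype=torch.uint8, device=val.device)

# quantize(out, val, group_size, stochastic_rounding, q_bits, q_mantisa_bits)
fp_quant_module.quantize(out, val, group_size, 0, q_bits, q_mantisa_bits)

# fp32 view over the 4-byte per-group scale embedded in the packed buffer
scales = fp_quant_module.get_scales(out, num_groups)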
2 changes: 2 additions & 0 deletions deepspeed/linear/config.py
@@ -21,6 +21,8 @@ class LoRAConfig:
lora_r: int = 64
lora_alpha: float = 16.
base_weight_sharding: int = 1
offload: bool = False
offload_ratio: float = 0.0


@dataclass
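A short sketch of the two new offload knobs next to the existing fields (names taken from the diff above). Reading offload_ratio as the fraction of the frozen base weight kept on CPU is an assumption, as is the re-export of LoRAConfig from deepspeed.linear.

from deepspeed.linear import LoRAConfig

lora_config = LoRAConfig(lora_r=64,
                         lora_alpha=16.,
                         base_weight_sharding=2,
                         offload=True,       # keep (part of) the frozen base weight on CPU
                         offload_ratio=0.5)  # assumed: fraction of the base weight to offload
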
57 changes: 38 additions & 19 deletions deepspeed/linear/optimized_linear.py
@@ -85,50 +85,69 @@ def __init__(self,
self.bias = bias
self.lora_config = lora_config
self.quantization_config = quantization_config
device = get_accelerator().current_device_name() if device is None else device
self.device = get_accelerator().current_device_name() if device is None else device
assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config"

self.zero_shards = self.lora_config.base_weight_sharding
self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype))
if self.zero_shards > 1:
w = torch.nn.Parameter(torch.empty(self.output_dim * self.sharded_weight_size, dtype=dtype), requires_grad=False)
else:
w = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=dtype), requires_grad=False)
torch.nn.init.xavier_uniform_(w)

if self.quantization_config is not None:
assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
self.base_weight = QuantizedParameter(w, quantization_config=quantization_config)
self.weight = QuantizedParameter(w, quantization_config=quantization_config)
else:
self.base_weight = w
self.weight = w

self.weight.requires_grad = False

self.base_weight.requires_grad = False
# Mark base weight to prevent broadcast and ensure proper offload behavior
self.weight.ds_optim_param = True

# Use RS lora for now.
self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r)
self.lora_scaling_factor = self.lora_config.lora_alpha / self.lora_config.lora_r
# Keeping lora weights in bf16 precision for ease of training.
self.lora_weight_1 = nn.Linear(self.input_dim,
self.lora_config.lora_r,
bias=self.bias,
device=device,
device=self.device,
dtype=dtype)
self.lora_weight_2 = nn.Linear(self.lora_config.lora_r,
self.output_dim,
bias=self.bias,
device=device,
device=self.device,
dtype=dtype)

# initialize "A" with kaiming uniform and "B" with zeros following this
# https://github.com/huggingface/peft/blob/62122b5add8d6892f70c82eaef2147a6ba33b90b/src/peft/tuners/lora/layer.py#L155
nn.init.kaiming_uniform_(self.lora_weight_1.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_weight_2.weight)
self.lora_weight_1.weight.requires_grad = True
self.lora_weight_2.weight.requires_grad = True

def full_weight(self):
# This assumes weights are evenly sharded across GPUs, which might not be correct;
# in that case, we should flatten before all_gather.
local_weight = self.base_weight.dequantized() if isinstance(self.base_weight,
QuantizedParameter) else self.base_weight
tensor_list = [
torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype)
for _ in range(self.zero_shards)
]
dist.all_gather(tensor_list, local_weight)
weight = nn.Parameter(torch.cat([tensor for tensor in tensor_list], dim=1))
return weight
base_weight = self.weight
if getattr(base_weight, 'ds_offload', False):
# move to gpu so we can dequant and all-gather
assert base_weight.device == torch.device('cpu'), \
f"expected base weight on cpu but found {base_weight.device}"
base_weight.offload(revert=True)
local_weight = base_weight.dequantized() if isinstance(base_weight,
QuantizedParameter) else base_weight
base_weight.offload()
else:
local_weight = base_weight.dequantized() if isinstance(base_weight,
QuantizedParameter) else base_weight

tensor_out = torch.empty(self.output_dim * self.input_dim,
dtype=local_weight.dtype, device=local_weight.device)
dist.all_gather_into_tensor(tensor_out, local_weight)
return tensor_out.reshape(self.output_dim, self.input_dim)

def linear_without_F_linear(self, input, weight):
output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
@@ -141,9 +160,9 @@ def forward(self, input_tensor):
with torch.no_grad():
base_weight = self.full_weight()
elif self.quantization_config:
base_weight = self.base_weight.dequantized()
base_weight = self.weight.dequantized()
else:
base_weight = self.base_weight
base_weight = self.weight

base_weight_output = F.linear(input_tensor, base_weight)
lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
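A hedged end-to-end sketch of how the updated layer is meant to be driven: the frozen base weight is stored as a flat per-rank shard, optionally quantized and offloaded, and reassembled with all_gather_into_tensor inside full_weight() at forward time, while only the two LoRA projections train. It assumes a multi-GPU job with deepspeed.comm initialized (base_weight_sharding=2 needs it) and treats the OptimizedLinear factory signature as an assumption based on the attributes set in __init__ above.

import torch
from deepspeed.linear import OptimizedLinear, LoRAConfig, QuantizationConfig

layer = OptimizedLinear(input_dim=4096,
                        output_dim=4096,
                        dtype=torch.bfloat16,
                        lora_config=LoRAConfig(lora_r=64,
                                               lora_alpha=16.,
                                               base_weight_sharding=2,
                                               offload=True,
                                               offload_ratio=0.5),
                        quantization_config=QuantizationConfig(q_bits=8))

x = torch.randn(8, 4096, dtype=torch.bfloat16, device='cuda')
y = layer(x)  # F.linear on the gathered base weight plus the scaled LoRA path
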
18 changes: 14 additions & 4 deletions deepspeed/linear/quantization.py
@@ -57,24 +57,31 @@

def _ensure_quantized(self, tensor: torch.Tensor):
# If the tensor is on the accelerator and is not quantized, then quantize it in-place.
if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.int8:
if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.uint8:
with get_accelerator().stream(get_accelerator().current_stream(tensor.device)):
tensor.data = self.quantizer.quantize(tensor.data,
q_bits=self.quantization_config.q_bits,
q_mantisa_bits=self.quantization_config.mantissa_bits)
assert tensor.dtype == torch.int8
assert tensor.dtype == torch.uint8

def dequantized(self) -> torch.Tensor:
"""
Return a tensor containing the dequantized weights of this parameter.
"""
if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.int8:
if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.uint8:
with get_accelerator().stream(get_accelerator().current_stream(self.data.device)):
return self.quantizer.dequantize(self.data,
q_bits=self.quantization_config.q_bits,
q_mantisa_bits=self.quantization_config.mantissa_bits)
return self.data

def offload(self, revert=False):
if getattr(self, 'ds_offload', False):
if revert:
self.data = self.to(get_accelerator().current_device_name())
else:
self.data = self.to('cpu')

def __getstate__(self):
state = self.__dict__
state["data"] = self.data
@@ -104,14 +111,17 @@ def __copy__(self):
return new_instance

def cuda(self, device=None, non_blocking=False):
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
device = "cuda" if device is None else device
self.quantizer.to(device, non_blocking=non_blocking)
return self.to(device, non_blocking=non_blocking)

def to(self, *args, **kwargs):
"""
Move the parameter to the given device. Then, if the device is a cuda device,
quantize it.
"""
tensor = super().to(*args, **kwargs)
self.quantizer.to(*args, **kwargs)
self._ensure_quantized(tensor)
return tensor

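A minimal sketch of the new offload()/revert round trip on a QuantizedParameter, mirroring how full_weight() above uses it. Setting ds_offload by hand stands in for whatever the training engine does when offload is enabled; that attribute name comes from the guard in offload(), the rest is illustrative.

import torch
from deepspeed.linear.quantization import QuantizedParameter

w = torch.randn(4096, 4096, dtype=torch.bfloat16, device='cuda')
p = QuantizedParameter(w)        # packed to uint8 once it sits on the accelerator
p.ds_offload = True

p.offload()                      # uint8 payload moves to CPU
p.offload(revert=True)           # ...and back to the accelerator for compute
full = p.dequantized()           # bf16 tensor, e.g. for the all-gather in full_weight()
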
2 changes: 1 addition & 1 deletion deepspeed/ops/__init__.py
@@ -9,7 +9,7 @@
from . import lion
from . import sparse_attention
from . import transformer

from . import fp_quantizer
from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig

from ..git_version_info import compatible_ops as __compatible_ops__
1 change: 1 addition & 0 deletions deepspeed/ops/fp_quantizer/__init__.py
@@ -4,3 +4,4 @@
# DeepSpeed Team

from .quantize import FP_Quantize, Quantizer
from .fp8_gemm import matmul_fp8
171 changes: 171 additions & 0 deletions deepspeed/ops/fp_quantizer/fp8_gemm.py
@@ -0,0 +1,171 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

######## Fused MoE kernel #########
# These kernels are implemented for
# fusing GeMM with dequantization of
# fp8 weight data when using bit-16
# activation.
###################################

import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
quantization_group_size: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)

inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
(pid_n * BLOCK_SIZE_N) // quantization_group_size)

weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
scale = tl.load(scale_ptr + weight_ptrs_offset)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# Dequantize weight (fp8 -> bf16)
w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16)
w = (w + 0x3C00).to(tl.uint16)
w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)

inp_data += BLOCK_SIZE_K * stride_ak
weight_data += BLOCK_SIZE_K * stride_bk
weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K
weight = tl.load(weight_data, mask=weight_mask, other=0.0)
scale = tl.load(scale_ptr + (weight_ptrs_offset +
(((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)),
mask=weight_mask,
other=0.0)

accumulator += tl.dot(inp, w)

out = accumulator.to(tl.bfloat16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))


@triton.jit
def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
quantization_group_size: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)

inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
(pid_n * BLOCK_SIZE_N) // quantization_group_size)

weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
scale = tl.load(scale_ptr + weight_ptrs_offset)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# Dequantize weight (fp8 -> fp16)
w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16)
w = (w + 0x2000).to(tl.uint16)
w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16)

inp_data += BLOCK_SIZE_K * stride_ak
weight_data += BLOCK_SIZE_K * stride_bk

weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0)
scale = tl.load(scale_ptr + (weight_ptrs_offset +
(((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)))

accumulator += tl.dot(inp, w)

out = accumulator.to(tl.float16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))


def matmul_fp8(inp, weight, scale, quantization_group_size):

assert inp.shape[1] == weight.shape[0], \
f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})"

M, K = inp.shape
K, N = weight.shape

out = torch.empty((M, N), device=inp.device, dtype=inp.dtype)

# GEMM tuning parameters!
# TODO: Add a more configurable tuning for selecting the best GeMM
BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128
BLOCK_SIZE_N = 64
BLOCK_SIZE_K = max(64, quantization_group_size)
GROUP_SIZE_M = 8
num_stages = 4
num_warps = 4
if M >= 256:
BLOCK_SIZE_M = 256
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = max(128, quantization_group_size)
num_stages = 3
num_warps = 8

grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16
kernel[grid](inp,
weight,
out,
scale,
M,
N,
K,
inp.stride(0),
inp.stride(1),
weight.stride(0),
weight.stride(1),
out.stride(0),
out.stride(1),
quantization_group_size=quantization_group_size,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
GROUP_SIZE_M=GROUP_SIZE_M,
num_stages=num_stages,
num_warps=num_warps)
return out
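
A usage sketch for the fused kernel. It assumes the packed fp8 weight is one byte per element with shape (K, N) and that scale holds one fp32 value per quantization group, as the FP_Quantize path produces; the random tensors only stand in for real quantized data so the call shape is visible.

import torch
from deepspeed.ops.fp_quantizer import matmul_fp8

M, K, N, group_size = 32, 4096, 4096, 128
inp = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
weight = torch.randint(0, 256, (K, N), dtype=torch.uint8, device='cuda')     # packed fp8 stand-in
scale = torch.rand(K * N // group_size, dtype=torch.float32, device='cuda')  # one scale per group

out = matmul_fp8(inp, weight, scale, quantization_group_size=group_size)     # (M, N), bfloat16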