[gptq] add gptq tensor parallel (hpcaitech#4538)
* add gptq tensor parallel

* add gptq tp

* delete print

* add test gptq check

* add test auto gptq check
Xu-Kai committed Sep 19, 2023
1 parent 5bd381d commit 145ff94
Showing 7 changed files with 366 additions and 16 deletions.
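
Note: the patch shards each quantized linear in one of two ways. Column-split layers keep all input rows but only outfeatures // tp_size output columns; row-split layers keep infeatures // tp_size input rows and need an all-reduce over their partial outputs. A minimal sketch of the resulting per-rank buffer shapes under 4-bit packing (the sizes below are illustrative, not taken from the patch):

# Per-rank buffer shapes for a GPTQ linear under tensor parallelism.
# Assumes 4-bit quantization, i.e. 32 // bits = 8 values packed per int32.
import math

bits, groupsize = 4, 128
infeatures, outfeatures = 4096, 11008    # illustrative sizes
tp_size = 2
pack = 32 // bits

def column_split_shapes(infeatures, outfeatures):
    # Column parallel: shard the output dimension.
    out = outfeatures // tp_size
    return {"qweight": (infeatures // pack, out),
            "qzeros": (math.ceil(infeatures / groupsize), out // pack),
            "scales": (math.ceil(infeatures / groupsize), out)}

def row_split_shapes(infeatures, outfeatures):
    # Row parallel: shard the input dimension; outputs are summed with an all-reduce.
    inf = infeatures // tp_size
    return {"qweight": (inf // pack, outfeatures),
            "qzeros": (math.ceil(inf / groupsize), outfeatures // pack),
            "scales": (math.ceil(inf / groupsize), outfeatures)}

print(column_split_shapes(infeatures, outfeatures))
print(row_split_shapes(infeatures, outfeatures))
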
60 changes: 51 additions & 9 deletions colossalai/gptq/cai_gptq/cai_quant_linear.py
@@ -30,7 +30,7 @@ class CaiQuantLinear(nn.Module):
"temp_dq": None,
}

def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__()
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
@@ -46,7 +46,14 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
self.register_buffer('scales',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
if row_split:
self.register_buffer(
'g_idx',
torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
dtype=torch.int32))
else:
self.register_buffer('g_idx',
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))

if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
@@ -57,6 +64,9 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias):

self.q4 = None
self.empty_tensor = torch.empty((1, 1), device="meta")
self.tp_size = tp_size
self.tp_rank = tp_rank
self.row_split = row_split

def pack(self, linear, scales, zeros, g_idx=None):

@@ -137,17 +147,30 @@ def pack(self, linear, scales, zeros, g_idx=None):
else:
self.g_idx = g_idx

def prepare_buffers(self):
assert self.qweight.device.type == "cuda"
device = self.qweight.device
if self.g_idx is not None:
if self.row_split and torch.equal(
self.g_idx,
torch.tensor(
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx = None
elif torch.equal(
self.g_idx,
torch.tensor([i // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx = None

CaiQuantLinear.max_dq_buffer_size = max(CaiQuantLinear.max_dq_buffer_size, self.qweight.numel() * 8)

if self.g_idx is not None:
CaiQuantLinear.max_inner_outer_dim = max(CaiQuantLinear.max_inner_outer_dim, self.infeatures,
self.outfeatures)
CaiQuantLinear.max_input_len = 4096

def prepare_buffers(self):
assert self.qweight.device.type == "cuda"
device = self.qweight.device

# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
CaiQuantLinear.device_to_buffers['temp_state'] = torch.zeros(
@@ -170,6 +193,21 @@ def prepare_buffers(self):
def init_q4(self):
assert self.qweight.device.type == "cuda"
self.q4_width = self.qweight.shape[1]
if self.g_idx is not None:
if self.row_split and torch.equal(
self.g_idx,
torch.tensor(
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx = None
elif torch.equal(
self.g_idx,
torch.tensor([i // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx = None

if self.g_idx is not None:
g_idx = self.g_idx.to("cpu")
else:
@@ -192,16 +230,20 @@ def forward(self, x):
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device)
gptq_cuda.q4_matmul(x, self.q4, output)
if self.bias is not None:
if self.bias is not None and (not self.row_split or self.tp_size == 1):
output.add_(self.bias)
else:
if self.bias is not None and (not self.row_split or self.tp_size == 1):
bias = self.bias
else:
bias = None
output = self.gptq_linear(
x,
self.qweight,
self.scales,
self.qzeros,
g_idx=self.g_idx,
bias=self.bias,
bias=bias,
)
return output.view(outshape)

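Note: the row_split branch added above offsets the default group indices by tp_rank * infeatures (the local, already-sharded infeatures) so each rank's rows keep their global group ids; the same expression is reused in prepare_buffers and init_q4 to detect the trivial act-order case and drop g_idx. A small worked example of that offset, with toy sizes chosen only for illustration:

# Group-index layout for a row-parallel shard (mirrors the row_split branch above).
groupsize = 4
full_infeatures = 8                               # global in_features of the original linear
tp_size = 2
local_infeatures = full_infeatures // tp_size     # rows held by each rank

for tp_rank in range(tp_size):
    # Same expression as the patch: shift local indices by tp_rank * local_infeatures.
    g_idx = [(i + tp_rank * local_infeatures) // groupsize for i in range(local_infeatures)]
    print(f"rank {tp_rank}: g_idx = {g_idx}")
# rank 0: g_idx = [0, 0, 0, 0]
# rank 1: g_idx = [1, 1, 1, 1]
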
180 changes: 180 additions & 0 deletions colossalai/gptq/gptq_tp.py
@@ -0,0 +1,180 @@
import warnings

import torch
import torch.distributed as dist

HAS_AUTO_GPTQ = False
try:
import auto_gptq
HAS_AUTO_GPTQ = True
except ImportError:
warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ')
HAS_AUTO_GPTQ = False

from .cai_gptq import CaiQuantLinear
from .models import GPTQBloomConfig, GPTQLlamaConfig, reset_bloom_attention_params, reset_llama_attention_params

model_config_map = {
"llama": GPTQLlamaConfig,
"bloom": GPTQBloomConfig,
}
attention_proc_map = {
"llama": reset_llama_attention_params,
"bloom": reset_bloom_attention_params,
}
if HAS_AUTO_GPTQ:

def get_module_by_name_prefix(model, module_name: str):
for name, module in model.named_modules():
if name.startswith(module_name):
return module

def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):

qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
g_idx = gptq_linear.g_idx
if gptq_linear.bias is not None:
bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1)

cai_split_out_features = cai_linear.outfeatures // split_num
zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num

for i in range(split_num):
cai_linear.qweight[:, i * cai_split_out_features:(i + 1) *
cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]
cai_linear.qzeros[:, i * zero_split_block:(i + 1) *
zero_split_block] = qzeros[i][:,
tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block]
cai_linear.scales[:, i * cai_split_out_features:(i + 1) *
cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]
if cai_linear.bias is not None:
cai_linear.bias[i * cai_split_out_features:(i + 1) *
cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]

cai_linear.g_idx.copy_(g_idx)

def split_row_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):

qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0)

cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num
zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num
idx_split_features = cai_linear.infeatures // split_num

for i in range(split_num):
cai_linear.qweight[i * cai_split_in_features:(i + 1) *
cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) *
cai_split_in_features, :]
cai_linear.qzeros[i * zero_split_block:(i + 1) *
zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) *
zero_split_block, :]
cai_linear.scales[i * zero_split_block:(i + 1) *
zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) *
zero_split_block, :]
cai_linear.g_idx[i * idx_split_features:(i + 1) *
idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) *
idx_split_features]
if cai_linear.bias is not None:
cai_linear.bias.copy_(gptq_linear.bias)

def replace_autogptq_linear(model, tp_size=1, tp_rank=0, tp_group=None):

def all_reduce_hook(cai_linear, input, output):
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=tp_group)
if cai_linear.bias is not None:
output.add_(cai_linear.bias)

model_type_name = model.config.model_type

gptq_model_config = model_config_map[model_type_name]
layers = get_module_by_name_prefix(model.model, gptq_model_config.layer_blocks)

for layer in layers:

attention_proc_map[model_type_name](layer, tp_size=tp_size)
for linear_name in gptq_model_config.linear_names[0]:
gptq_linear = get_module_by_name_prefix(layer, linear_name)
#column split copy
cai_linear = CaiQuantLinear(
gptq_linear.bits,
gptq_linear.group_size,
gptq_linear.in_features,
gptq_linear.out_features // tp_size,
gptq_linear.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
)
cai_linear.to(gptq_linear.qweight.device)
if len(gptq_model_config.linear_names[0]) == 1:
split_column_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank, split_num=3)
else:
split_column_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank, split_num=1)
name1, name2 = linear_name.split(".")
parent_module = get_module_by_name_prefix(layer, name1)
setattr(parent_module, name2, cai_linear)

for linear_name in gptq_model_config.linear_names[1]:
gptq_linear = get_module_by_name_prefix(layer, linear_name)
#row split copy
cai_linear = CaiQuantLinear(gptq_linear.bits,
gptq_linear.group_size,
gptq_linear.in_features // tp_size,
gptq_linear.out_features,
gptq_linear.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=True)
cai_linear.to(gptq_linear.qweight.device)
split_row_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank)

if tp_size > 1:
cai_linear.register_forward_hook(all_reduce_hook)
name1, name2 = linear_name.split(".")
parent_module = get_module_by_name_prefix(layer, name1)
setattr(parent_module, name2, cai_linear)

for linear_name in gptq_model_config.linear_names[2]:
gptq_linear = get_module_by_name_prefix(layer, linear_name)
#column split copy
cai_linear = CaiQuantLinear(
gptq_linear.bits,
gptq_linear.group_size,
gptq_linear.in_features,
gptq_linear.out_features // tp_size,
gptq_linear.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
)
cai_linear.to(gptq_linear.qweight.device)
split_column_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank)
name1, name2 = linear_name.split(".")
parent_module = get_module_by_name_prefix(layer, name1)
setattr(parent_module, name2, cai_linear)

for linear_name in gptq_model_config.linear_names[3]:
gptq_linear = get_module_by_name_prefix(layer, linear_name)
#row split copy
cai_linear = CaiQuantLinear(gptq_linear.bits,
gptq_linear.group_size,
gptq_linear.in_features // tp_size,
gptq_linear.out_features,
gptq_linear.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=True)
cai_linear.to(gptq_linear.qweight.device)
split_row_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank)

if tp_size > 1:
cai_linear.register_forward_hook(all_reduce_hook)
name1, name2 = linear_name.split(".")
parent_module = get_module_by_name_prefix(layer, name1)
setattr(parent_module, name2, cai_linear)
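
Note: a minimal usage sketch of the new replace_autogptq_linear entry point, assuming one process per GPU and an auto-gptq checkpoint. The checkpoint path is a placeholder, the import path is inferred from the file location, and the exact wrapper handling (the patch reads model.config.model_type and walks model.model) may need adjusting for other auto-gptq versions:

# Hypothetical driver: shard an auto-gptq quantized model across tensor-parallel ranks.
import torch
import torch.distributed as dist
from auto_gptq import AutoGPTQForCausalLM

from colossalai.gptq.gptq_tp import replace_autogptq_linear  # path assumed from this file

dist.init_process_group(backend="nccl")      # assumes launch via torchrun
tp_rank = dist.get_rank()
tp_size = dist.get_world_size()
torch.cuda.set_device(tp_rank)

# Placeholder path to a 4-bit GPTQ LLaMA or BLOOM checkpoint.
model = AutoGPTQForCausalLM.from_quantized("path/to/gptq-checkpoint",
                                           device=f"cuda:{tp_rank}")

# Swap auto-gptq linears for sharded CaiQuantLinear modules; row-split layers
# get an all-reduce forward hook so partial outputs are summed across ranks.
replace_autogptq_linear(model, tp_size=tp_size, tp_rank=tp_rank,
                        tp_group=dist.group.WORLD)
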
2 changes: 2 additions & 0 deletions colossalai/gptq/models/__init__.py
@@ -0,0 +1,2 @@
from .bloom import GPTQBloomConfig, reset_bloom_attention_params
from .llama import GPTQLlamaConfig, reset_llama_attention_params
18 changes: 18 additions & 0 deletions colossalai/gptq/models/bloom.py
@@ -0,0 +1,18 @@
from dataclasses import dataclass, field, fields


@dataclass
class GPTQBloomConfig():
layer_name = "BloomBlock"
layer_blocks = "transformer.h"
linear_names = [["self_attention.query_key_value"], ["self_attention.dense"], ["mlp.dense_h_to_4h"],
["mlp.dense_4h_to_h"]]
model_names = ["transformer.word_embeddings", "transformer.word_embeddings_layernorm", "transformer.ln_f"]
attention = "self_attention"
mlp = "mlp"


def reset_bloom_attention_params(layer, tp_size=1):
attention = getattr(layer, "self_attention")
attention.hidden_size = attention.hidden_size // tp_size
attention.num_heads = attention.num_heads // tp_size
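
Note: BLOOM's attention uses a single fused self_attention.query_key_value projection, which is why replace_autogptq_linear calls split_column_copy with split_num=3 when a group contains only one name: each of the Q, K and V sub-blocks is sliced per rank separately so the fused layout stays consistent after sharding. A tiny index sketch of that slicing (sizes are illustrative):

# Fused QKV column split: slice each of the three sub-blocks independently.
out_features = 12                          # illustrative fused QKV width (3 sub-blocks of 4)
tp_size, tp_rank, split_num = 2, 0, 3
block = out_features // split_num          # width of each Q/K/V sub-block
shard = block // tp_size                   # columns this rank keeps per sub-block

cols = []
for i in range(split_num):
    start = i * block + tp_rank * shard    # rank-local slice inside sub-block i
    cols.extend(range(start, start + shard))
print(cols)                                # rank 0 keeps columns [0, 1, 4, 5, 8, 9]
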
19 changes: 19 additions & 0 deletions colossalai/gptq/models/llama.py
@@ -0,0 +1,19 @@
from dataclasses import dataclass, field, fields


@dataclass
class GPTQLlamaConfig():
layer_name = "LlamaDecoderLayer"
layer_blocks = "model.layers"
linear_names = [["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], ["self_attn.o_proj"],
["mlp.up_proj", "mlp.gate_proj"], ["mlp.down_proj"]]
model_names = ["model.embed_tokens", "model.norm"]
attention = "self_attn"
mlp = "mlp"


def reset_llama_attention_params(layer, tp_size=1):
attention = getattr(layer, "self_attn")
attention.hidden_size = attention.hidden_size // tp_size
attention.num_heads = attention.num_heads // tp_size
attention.num_key_value_heads = attention.num_key_value_heads // tp_size
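
Note: replace_autogptq_linear treats the four linear_names groups positionally: groups 0 and 2 (q/k/v and up/gate projections) are column-split, groups 1 and 3 (o_proj and down_proj) are row-split, and the reset helper above shrinks the per-rank attention bookkeeping to match. A short illustration with assumed 7B-style sizes:

# How the LLaMA config's linear_names groups map to split types in replace_autogptq_linear.
split_plan = [
    ("column", ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"]),
    ("row",    ["self_attn.o_proj"]),
    ("column", ["mlp.up_proj", "mlp.gate_proj"]),
    ("row",    ["mlp.down_proj"]),
]

tp_size = 2
num_heads, hidden_size = 32, 4096                    # illustrative LLaMA-7B values
print(num_heads // tp_size, hidden_size // tp_size)  # per-rank: 16 heads, hidden_size 2048
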