[gptq] add gptq tensor parallel #4538

Merged · 5 commits · Aug 31, 2023
Changes from 2 commits
61 changes: 52 additions & 9 deletions colossalai/gptq/cai_gptq/cai_quant_linear.py
@@ -30,7 +30,7 @@ class CaiQuantLinear(nn.Module):
"temp_dq": None,
}

def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
super().__init__()
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
@@ -46,7 +46,14 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
self.register_buffer('scales',
torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
if row_split:
self.register_buffer(
'g_idx',
torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
dtype=torch.int32))
else:
self.register_buffer('g_idx',
torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))

if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
@@ -57,6 +64,9 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias):

self.q4 = None
self.empty_tensor = torch.empty((1, 1), device="meta")
self.tp_size = tp_size
self.tp_rank = tp_rank
self.row_split = row_split

def pack(self, linear, scales, zeros, g_idx=None):

@@ -137,17 +147,31 @@ def pack(self, linear, scales, zeros, g_idx=None):
else:
self.g_idx = g_idx

def prepare_buffers(self):
assert self.qweight.device.type == "cuda"
device = self.qweight.device
print(self.g_idx)
if self.g_idx is not None:
if self.row_split and torch.equal(
self.g_idx,
torch.tensor(
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx = None
elif torch.equal(
self.g_idx,
torch.tensor([i // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx = None

CaiQuantLinear.max_dq_buffer_size = max(CaiQuantLinear.max_dq_buffer_size, self.qweight.numel() * 8)

if self.g_idx is not None:
CaiQuantLinear.max_inner_outer_dim = max(CaiQuantLinear.max_inner_outer_dim, self.infeatures,
self.outfeatures)
CaiQuantLinear.max_input_len = 4096

def prepare_buffers(self):
assert self.qweight.device.type == "cuda"
device = self.qweight.device

# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
CaiQuantLinear.device_to_buffers['temp_state'] = torch.zeros(
@@ -170,6 +194,21 @@ def prepare_buffers(self):
def init_q4(self):
assert self.qweight.device.type == "cuda"
self.q4_width = self.qweight.shape[1]
if self.g_idx is not None:
if self.row_split and torch.equal(
self.g_idx,
torch.tensor(
[(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx = None
elif torch.equal(
self.g_idx,
torch.tensor([i // self.groupsize for i in range(self.infeatures)],
dtype=torch.int32,
device=self.g_idx.device)):
self.g_idx = None

if self.g_idx is not None:
g_idx = self.g_idx.to("cpu")
else:
@@ -192,16 +231,20 @@ def forward(self, x):
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device)
gptq_cuda.q4_matmul(x, self.q4, output)
if self.bias is not None:
if (self.bias is not None and not self.row_split) or self.tp_size == 1:
output.add_(self.bias)
else:
if (self.bias is not None and not self.row_split) or self.tp_size == 1:
bias = self.bias
else:
bias = None
output = self.gptq_linear(
x,
self.qweight,
self.scales,
self.qzeros,
g_idx=self.g_idx,
bias=self.bias,
bias=bias,
)
return output.view(outshape)

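For reference, the row_split branch above offsets each shard's group indices by tp_rank * infeatures, so a row-parallel shard keeps pointing at its own slice of the global quantization groups; prepare_buffers and init_q4 then drop g_idx entirely when it matches this default layout (i.e. no act-order reordering is needed). A minimal standalone sketch of that indexing, using toy sizes that are not taken from the PR:

import torch

# Toy sizes for illustration only: 8 global input features, group size 4,
# row-split across 2 tensor-parallel ranks.
groupsize, tp_size = 4, 2
global_infeatures = 8
infeatures = global_infeatures // tp_size  # per-rank shard width

for tp_rank in range(tp_size):
    # Same formula as the row_split branch of CaiQuantLinear.__init__.
    g_idx = torch.tensor(
        [(i + tp_rank * infeatures) // groupsize for i in range(infeatures)],
        dtype=torch.int32)
    print(tp_rank, g_idx.tolist())

# rank 0 -> [0, 0, 0, 0]  (first quantization group)
# rank 1 -> [1, 1, 1, 1]  (second quantization group)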
180 changes: 180 additions & 0 deletions colossalai/gptq/gptq_tp.py
@@ -0,0 +1,180 @@
import warnings

import torch
import torch.distributed as dist

HAS_AUTO_GPTQ = False
try:
import auto_gptq
HAS_AUTO_GPTQ = True
except ImportError:
warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ')
HAS_AUTO_GPTQ = False

from .cai_gptq import CaiQuantLinear
from .models import GPTQBloomConfig, GPTQLlamaConfig, reset_bloom_attention_params, reset_llama_attention_params

model_config_map = {
"llama": GPTQLlamaConfig,
"bloom": GPTQBloomConfig,
}
attention_proc_map = {
"llama": reset_llama_attention_params,
"bloom": reset_bloom_attention_params,
}
if HAS_AUTO_GPTQ:

def get_module_by_name_prefix(model, module_name: str):
for name, module in model.named_modules():
if name.startswith(module_name):
return module

def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):

qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
g_idx = gptq_linear.g_idx
if gptq_linear.bias is not None:
bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1)

cai_split_out_features = cai_linear.outfeatures // split_num
zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num

for i in range(split_num):
cai_linear.qweight[:, i * cai_split_out_features:(i + 1) *
cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]
cai_linear.qzeros[:, i * zero_split_block:(i + 1) *
zero_split_block] = qzeros[i][:,
tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block]
cai_linear.scales[:, i * cai_split_out_features:(i + 1) *
cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]
if cai_linear.bias is not None:
cai_linear.bias[i * cai_split_out_features:(i + 1) *
cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) *
cai_split_out_features]

cai_linear.g_idx.copy_(g_idx)

def split_row_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):

qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0)

cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num
zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num
idx_split_features = cai_linear.infeatures // split_num

for i in range(split_num):
cai_linear.qweight[i * cai_split_in_features:(i + 1) *
cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) *
cai_split_in_features, :]
cai_linear.qzeros[i * zero_split_block:(i + 1) *
zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) *
zero_split_block, :]
cai_linear.scales[i * zero_split_block:(i + 1) *
zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) *
zero_split_block, :]
cai_linear.g_idx[i * idx_split_features:(i + 1) *
idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) *
idx_split_features]
if cai_linear.bias is not None:
cai_linear.bias.copy_(gptq_linear.bias)

def replace_autogptq_linear(model, tp_size=1, tp_rank=0, tp_group=None):

def all_reduce_hook(cai_linear, input, output):
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=tp_group)
if cai_linear.bias is not None:
output.add_(cai_linear.bias)

model_type_name = model.config.model_type

gptq_model_config = model_config_map[model_type_name]
layers = get_module_by_name_prefix(model.model, gptq_model_config.layer_blocks)

for layer in layers:

attention_proc_map[model_type_name](layer, tp_size=tp_size)
for linear_name in gptq_model_config.linear_names[0]:
gptq_linear = get_module_by_name_prefix(layer, linear_name)
#column split copy
cai_linear = CaiQuantLinear(
gptq_linear.bits,
gptq_linear.group_size,
gptq_linear.in_features,
gptq_linear.out_features // tp_size,
gptq_linear.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
)
cai_linear.to(gptq_linear.qweight.device)
if len(gptq_model_config.linear_names[0]) == 1:
split_column_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank, split_num=3)
else:
split_column_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank, split_num=1)
name1, name2 = linear_name.split(".")
parent_module = get_module_by_name_prefix(layer, name1)
setattr(parent_module, name2, cai_linear)

for linear_name in gptq_model_config.linear_names[1]:
gptq_linear = get_module_by_name_prefix(layer, linear_name)
#row split copy
cai_linear = CaiQuantLinear(gptq_linear.bits,
gptq_linear.group_size,
gptq_linear.in_features // tp_size,
gptq_linear.out_features,
gptq_linear.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=True)
cai_linear.to(gptq_linear.qweight.device)
split_row_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank)

if tp_size > 1:
cai_linear.register_forward_hook(all_reduce_hook)
name1, name2 = linear_name.split(".")
parent_module = get_module_by_name_prefix(layer, name1)
setattr(parent_module, name2, cai_linear)

for linear_name in gptq_model_config.linear_names[2]:
gptq_linear = get_module_by_name_prefix(layer, linear_name)
#column split copy
cai_linear = CaiQuantLinear(
gptq_linear.bits,
gptq_linear.group_size,
gptq_linear.in_features,
gptq_linear.out_features // tp_size,
gptq_linear.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
)
cai_linear.to(gptq_linear.qweight.device)
split_column_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank)
name1, name2 = linear_name.split(".")
parent_module = get_module_by_name_prefix(layer, name1)
setattr(parent_module, name2, cai_linear)

for linear_name in gptq_model_config.linear_names[3]:
gptq_linear = get_module_by_name_prefix(layer, linear_name)
#row split copy
cai_linear = CaiQuantLinear(gptq_linear.bits,
gptq_linear.group_size,
gptq_linear.in_features // tp_size,
gptq_linear.out_features,
gptq_linear.bias is not None,
tp_size=tp_size,
tp_rank=tp_rank,
row_split=True)
cai_linear.to(gptq_linear.qweight.device)
split_row_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank)

if tp_size > 1:
cai_linear.register_forward_hook(all_reduce_hook)
name1, name2 = linear_name.split(".")
parent_module = get_module_by_name_prefix(layer, name1)
setattr(parent_module, name2, cai_linear)
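A hedged usage sketch for the new entry point, assuming a multi-GPU host launched with torchrun, an AutoGPTQ-quantized Llama checkpoint (quantized_model_dir is a placeholder path), and that the module is importable as colossalai.gptq.gptq_tp as laid out in this diff:

# Hypothetical launch: torchrun --nproc_per_node=2 run_gptq_tp.py
import torch
import torch.distributed as dist
from auto_gptq import AutoGPTQForCausalLM

from colossalai.gptq.gptq_tp import replace_autogptq_linear

dist.init_process_group(backend="nccl")
tp_rank = dist.get_rank()
tp_size = dist.get_world_size()
torch.cuda.set_device(tp_rank)

# Placeholder path to an already-quantized Llama checkpoint.
quantized_model_dir = "/path/to/llama-gptq"
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device=f"cuda:{tp_rank}")

# Swap every auto-gptq linear in place for a sharded CaiQuantLinear.
# replace_autogptq_linear reads model.config.model_type and walks model.model,
# so the object passed in must expose both attributes; tp_group=None falls
# back to the default process group for the row-parallel all-reduce hook.
replace_autogptq_linear(model, tp_size=tp_size, tp_rank=tp_rank, tp_group=None)

With tp_size == 1 the hook is never registered and the call degenerates to a plain layer-by-layer replacement.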
2 changes: 2 additions & 0 deletions colossalai/gptq/models/__init__.py
@@ -0,0 +1,2 @@
from .bloom import GPTQBloomConfig, reset_bloom_attention_params
from .llama import GPTQLlamaConfig, reset_llama_attention_params
18 changes: 18 additions & 0 deletions colossalai/gptq/models/bloom.py
@@ -0,0 +1,18 @@
from dataclasses import dataclass, field, fields


@dataclass
class GPTQBloomConfig():
layer_name = "BloomBlock"
layer_blocks = "transformer.h"
linear_names = [["self_attention.query_key_value"], ["self_attention.dense"], ["mlp.dense_h_to_4h"],
["mlp.dense_4h_to_h"]]
model_names = ["transformer.word_embeddings", "transformer.word_embeddings_layernorm", "transformer.ln_f"]
attention = "self_attention"
mlp = "mlp"


def reset_bloom_attention_params(layer, tp_size=1):
attention = getattr(layer, "self_attention")
attention.hidden_size = attention.hidden_size // tp_size
attention.num_heads = attention.num_heads // tp_size
19 changes: 19 additions & 0 deletions colossalai/gptq/models/llama.py
@@ -0,0 +1,19 @@
from dataclasses import dataclass, field, fields


@dataclass
class GPTQLlamaConfig():
layer_name = "LlamaDecoderLayer"
layer_blocks = "model.layers"
linear_names = [["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], ["self_attn.o_proj"],
["mlp.up_proj", "mlp.gate_proj"], ["mlp.down_proj"]]
model_names = ["model.embed_tokens", "model.norm"]
attention = "self_attn"
mlp = "mlp"


def reset_llama_attention_params(layer, tp_size=1):
attention = getattr(layer, "self_attn")
attention.hidden_size = attention.hidden_size // tp_size
attention.num_heads = attention.num_heads // tp_size
attention.num_key_value_heads = attention.num_key_value_heads // tp_size
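The two config dataclasses share the same shape: linear_names groups the per-layer projections as [column-split, row-split, column-split, row-split], which is the order replace_autogptq_linear walks them. Supporting another architecture would mostly mean adding one more config and attention-reset helper; a purely hypothetical sketch for a GPT-NeoX-style model, whose blocks also use a fused query_key_value projection (class name, module paths, and helper below are illustrative, not part of this PR):

from dataclasses import dataclass


@dataclass
class GPTQNeoXConfig():
    # Hypothetical module paths; they would need to be verified against the
    # actual Hugging Face GPT-NeoX implementation before use.
    layer_name = "GPTNeoXLayer"
    layer_blocks = "gpt_neox.layers"
    linear_names = [["attention.query_key_value"],    # column split (fused QKV -> split_num=3 path)
                    ["attention.dense"],              # row split
                    ["mlp.dense_h_to_4h"],            # column split
                    ["mlp.dense_4h_to_h"]]            # row split
    model_names = ["gpt_neox.embed_in", "gpt_neox.final_layer_norm"]
    attention = "attention"
    mlp = "mlp"


def reset_gpt_neox_attention_params(layer, tp_size=1):
    attention = getattr(layer, "attention")
    attention.hidden_size = attention.hidden_size // tp_size
    attention.num_attention_heads = attention.num_attention_heads // tp_size

The new type would then be registered in model_config_map and attention_proc_map at the top of gptq_tp.py, keyed by its model_type string.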