replace auto-gptq
Xu-Kai committed Aug 16, 2023
1 parent a56d61b commit ef97b74
Showing 21 changed files with 98 additions and 2,564 deletions.
21 changes: 11 additions & 10 deletions colossalai/gptq/cai_gptq/cai_quant_linear.py
@@ -3,7 +3,6 @@
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd
from .gptq_op import CaiGPTQLinearOp
import triton

@@ -22,22 +21,18 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
self.register_buffer('qweight', torch.zeros((infeatures // 64 * self.bits, outfeatures), dtype=torch.int64))
self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 64 * self.bits), dtype=torch.int64))
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
# self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int64))
# self.order_qzeros = torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int64)
# self.register_buffer('input_idx', torch.zeros(infeatures], dtype=torch.int32))

self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))

if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
else:
self.bias = None

self.gptq_linear = CaiGPTQLinearOp(groupsize, bits)
self.printed = False
self.reorder_zeros = False
def pack(self, linear, scales, zeros, g_idx=None):


def pack(self, linear, scales, zeros, g_idx=None):

g_idx = g_idx.clone() if g_idx is not None else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)

scales = scales.t().contiguous()
@@ -103,8 +98,13 @@ def pack(self, linear, scales, zeros, g_idx=None):
raise NotImplementedError("Only 2,4,8 bits are supported.")
qzeros = qzeros.astype(sign_type)
qzeros = torch.from_numpy(qzeros)
qzeros = qzeros #.to(torch.cuda.current_device())
qzeros = qzeros
self.qzeros.data.copy_(qzeros)

if torch.equal(self.g_idx, g_idx):
self.g_idx = None
else:
self.g_idx = g_idx


def forward(self, x):
@@ -113,7 +113,8 @@ def forward(self, x):
self.qweight,
self.scales,
self.qzeros,
bias = self.bias)
g_idx = self.g_idx,
bias = self.bias,)
return cai_out

def make_cai_quant_linear(module, names, bits, groupsize, name=''):
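For orientation, a minimal usage sketch of the layer changed above, assuming the class is exported as CaiQuantLinear from this module (the class name and import path are not visible in the hunks, only the __init__, pack, and forward signatures); values are illustrative.

    # Illustrative sketch; class name and import path are assumptions, not shown in the diff.
    import torch
    from colossalai.gptq.cai_gptq.cai_quant_linear import CaiQuantLinear  # assumed export

    bits, groupsize, infeatures, outfeatures = 4, 128, 4096, 4096
    layer = CaiQuantLinear(bits, groupsize, infeatures, outfeatures, bias=True)

    # After this commit the layer registers a persistent int32 g_idx buffer with
    # g_idx[i] = i // groupsize (the commented-out int64 variant was removed).
    assert layer.g_idx.dtype == torch.int32
    assert int(layer.g_idx[groupsize]) == 1

    # pack() now takes an optional g_idx; when the supplied ordering equals the
    # default layout it sets self.g_idx = None, so forward() routes to the
    # non-indexed Triton kernel, otherwise the act-order (idx) kernel is used.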
5 changes: 3 additions & 2 deletions colossalai/gptq/cai_gptq/gptq_op.py
@@ -16,6 +16,7 @@ def forward(self,
weight: torch.Tensor,
weight_scales: torch.Tensor,
weight_zeros: torch.Tensor,
g_idx: torch.Tensor = None,
act_type = 0,
bias: torch.Tensor = None,
residual: torch.Tensor=None,
@@ -32,9 +33,9 @@
add_residual = False
x = input.view(-1, input.shape[-1])


out = gptq_fused_linear_triton(x, weight, weight_scales, weight_zeros, bias, residual,
self.bits, self.maxq, self.group_size, qkv_fused, add_bias, add_residual, act_type=act_type)
self.bits, self.maxq, self.group_size, qkv_fused, add_bias, add_residual,
act_type=act_type, g_idx=g_idx)
if qkv_fused:
out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
else:
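A sketch of how the new keyword reaches the operator, assuming CaiGPTQLinearOp is callable the way CaiQuantLinear.forward calls it above; tensor shapes mirror the buffers registered in cai_quant_linear.py, the values are dummies, and a CUDA device with Triton is required.

    # Illustrative call pattern; import path, shapes, and values are assumptions.
    import math
    import torch
    from colossalai.gptq.cai_gptq.gptq_op import CaiGPTQLinearOp  # assumed export

    bits, groupsize, infeatures, outfeatures = 4, 128, 4096, 4096
    op = CaiGPTQLinearOp(groupsize, bits)

    device = torch.device("cuda")
    x = torch.randn(1, 8, infeatures, dtype=torch.float16, device=device)
    qweight = torch.zeros(infeatures // 64 * bits, outfeatures, dtype=torch.int64, device=device)
    qzeros = torch.zeros(math.ceil(infeatures / groupsize), outfeatures // 64 * bits,
                         dtype=torch.int64, device=device)
    scales = torch.ones(math.ceil(infeatures / groupsize), outfeatures,
                        dtype=torch.float16, device=device)
    g_idx = torch.tensor([i // groupsize for i in range(infeatures)],
                         dtype=torch.int32, device=device)

    # g_idx=None keeps the original fast path; a tensor routes to the act-order kernel.
    out = op(x, qweight, scales, qzeros, g_idx=g_idx, bias=None)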
21 changes: 10 additions & 11 deletions colossalai/gptq/cai_gptq/gptq_triton.py
@@ -1,7 +1,7 @@
import triton
import triton.language as tl
import torch
from ..gptq_utils.quant import custom_autotune
from auto_gptq.nn_modules.triton_utils import custom_autotune
# from ..ops.triton.kernels.activations_kernels import relu, gelu, silu
# code based https://github.com/fpgaminer/GPTQ-triton
# triton.Config({
@@ -221,11 +221,11 @@ def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_

for k in range(0, num_pid_k):
# g_idx = tl.load(g_ptrs)
if (k + 1) * BLOCK_SIZE_K > currend_group_end:
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
# if (k + 1) * BLOCK_SIZE_K > currend_group_end:
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
@@ -391,8 +391,7 @@ def cai_gptq_idx_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, i

for k in range(0, num_pid_k):
# g_idx = tl.load(g_ptrs)
if (k + 1) * BLOCK_SIZE_K > currend_group_end:
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

zeros = (zeros >> zeros_shifter[None, :]) & maxq
@@ -438,7 +437,7 @@ def cai_gptq_idx_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, i


def gptq_fused_linear_triton(input, qweight, scales, qzeros, bias, residual,
bits, maxq, gptq_group_size, qkv_fused, add_bias, add_residual, idx = None, act_type = 0):
bits, maxq, gptq_group_size, qkv_fused, add_bias, add_residual, g_idx = None, act_type = 0):
# print("gptq fused ", qkv_fused, add_bias, add_residual)
with torch.cuda.device(input.device):
if qkv_fused:
@@ -448,15 +447,15 @@ def gptq_fused_linear_triton(input, qweight, scales, qzeros, bias, residual,
grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
# print("dtype, ", qweight.dtype, output.dtype, scales.dtype, qzeros.dtype, bias.dtype, residual.dtype)
if idx is None:
if g_idx is None:
cai_gptq_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, bias, residual,
input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
gptq_group_size,
input.stride(0), input.stride(1), qweight.stride(0),
qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0),
QKV_FUSED=qkv_fused, ADD_BIAS=add_bias, ADD_RESIDUAL=add_residual, ACT_TYPE=act_type)
else:
cai_gptq_idx_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, idx, bias, residual,
cai_gptq_idx_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, bias, residual,
input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
gptq_group_size,
input.stride(0), input.stride(1), qweight.stride(0),
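At the lowest level, the keyword formerly named idx is now g_idx. A self-contained sketch of a direct call follows; normally this function is reached through CaiGPTQLinearOp, the import path is assumed, and a CUDA device with Triton is required.

    # Direct-call sketch for the renamed keyword; shapes and values are illustrative only.
    import math
    import torch
    from colossalai.gptq.cai_gptq.gptq_triton import gptq_fused_linear_triton  # assumed path

    bits, groupsize, infeatures, outfeatures = 4, 128, 256, 256
    maxq = 2 ** bits - 1  # mask used by the kernels to unpack the packed zero points
    device = torch.device("cuda")

    x = torch.randn(4, infeatures, dtype=torch.float16, device=device)  # already flattened to 2-D
    qweight = torch.zeros(infeatures // 64 * bits, outfeatures, dtype=torch.int64, device=device)
    qzeros = torch.zeros(math.ceil(infeatures / groupsize), outfeatures // 64 * bits,
                         dtype=torch.int64, device=device)
    scales = torch.ones(math.ceil(infeatures / groupsize), outfeatures,
                        dtype=torch.float16, device=device)
    bias = torch.zeros(outfeatures, dtype=torch.float16, device=device)
    residual = torch.zeros(x.shape[0], outfeatures, dtype=torch.float16, device=device)
    g_idx = torch.tensor([i // groupsize for i in range(infeatures)],
                         dtype=torch.int32, device=device)

    out = gptq_fused_linear_triton(
        x, qweight, scales, qzeros, bias, residual,
        bits, maxq, groupsize,
        False, False, False,   # qkv_fused, add_bias, add_residual
        g_idx=g_idx,           # was `idx` before this commit; None selects the non-idx kernel
        act_type=0,
    )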
1 change: 0 additions & 1 deletion colossalai/gptq/gptq_utils/__init__.py

This file was deleted.

236 changes: 0 additions & 236 deletions colossalai/gptq/gptq_utils/gptq.py

This file was deleted.

5 changes: 0 additions & 5 deletions colossalai/gptq/gptq_utils/quant/__init__.py

This file was deleted.
