Merge pull request #99 from pytorch-labs/int4_take2
int4 from gpt-fast
mikekgfb authored Apr 10, 2024
2 parents f5618d0 + 62ec820 commit 8c80c17
Showing 2 changed files with 277 additions and 4 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/compile.yml
@@ -95,6 +95,17 @@ jobs:
python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
cat ./output_aoti
echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"
python generate.py --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
cat ./output_eager
python generate.py --compile --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
cat ./output_compiled
python export.py --quant '{"linear:int4" : {"group_size": 32}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
python generate.py --checkpoint-path ${MODEL_PATH} --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so > ./output_aoti
cat ./output_aoti
echo "tests complete"
echo "******************************************"
# echo "********* EAGER vs TORCH.COMPILE *********"
270 changes: 266 additions & 4 deletions quantize.py
@@ -53,6 +53,12 @@ def quantize_model(model: nn.Module, quantize_options):
**q_kwargs
).quantized_model()
elif quantizer == "linear:int4":
linears_quantized = True
model = WeightOnlyInt4QuantHandler(
model,
**q_kwargs
).quantized_model()
elif quantizer == "linear:a8w4dq":
linears_quantized = True
model = Int8DynActInt4WeightQuantHandler(
model,
@@ -70,6 +76,9 @@ def quantize_model(model: nn.Module, quantize_options):
assert 0 == 1, f"quantizer {quantizer} not supported"
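For context, the --quant JSON used in the workflow above arrives here as a dict mapping quantizer names to their keyword arguments; a minimal sketch of the "linear:int4" case (values taken from the workflow, structure inferred from the dispatch above):

# Illustrative: what --quant '{"linear:int4" : {"group_size": 32}}' decodes to,
# so q_kwargs becomes {"group_size": 32} by the time it reaches WeightOnlyInt4QuantHandler.
quantize_options = {"linear:int4": {"group_size": 32}}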


#########################################################################
##### Quantization Primitives ######

def dynamically_quantize_per_channel(
x,
quant_min,
@@ -164,6 +173,115 @@ def dynamically_quantize_per_channel(
return quant, scales, zero_points



def get_group_qparams(w, n_bit=4, groupsize=128, *, scales_dtype=torch.float):
# needed for GPTQ with padding
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
assert groupsize > 1
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2

to_quant = w.reshape(-1, groupsize)
assert torch.isnan(to_quant).sum() == 0

max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
return scales.to(scales_dtype).reshape(w.shape[0], -1), zeros.to(
scales_dtype
).reshape(w.shape[0], -1)


def pack_scales_and_zeros(scales, zeros, *, scales_dtype=torch.float):
assert scales.shape == zeros.shape
assert scales.dtype == scales_dtype
assert zeros.dtype == scales_dtype
return (
torch.cat(
[
scales.reshape(scales.size(0), scales.size(1), 1),
zeros.reshape(zeros.size(0), zeros.size(1), 1),
],
2,
)
.transpose(0, 1)
.contiguous()
)


def unpack_scales_and_zeros(scales_and_zeros):
assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
assert scales_and_zeros.dtype == torch.float
return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)


def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
assert groupsize > 1
# needed for GPTQ single column quantize
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
groupsize = w.shape[-1]

assert w.shape[-1] % groupsize == 0
assert w.dim() == 2

to_quant = w.reshape(-1, groupsize)
assert torch.isnan(to_quant).sum() == 0

scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)
min_val = zeros - scales * (2 ** (n_bit - 1))
max_int = 2**n_bit - 1
min_int = 0
w_int32 = (
to_quant.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)

return w_int32


def group_quantize_tensor(w, n_bit=4, groupsize=128):
scales, zeros = get_group_qparams(w, n_bit, groupsize)
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
return w_int32, scales_and_zeros


def group_dequantize_tensor_from_qparams(
w_int32, scales, zeros, n_bit=4, groupsize=128
):
assert groupsize > 1
# needed for GPTQ single column dequantize
if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
groupsize = w_int32.shape[-1]
assert w_int32.shape[-1] % groupsize == 0
assert w_int32.dim() == 2

w_int32_grouped = w_int32.reshape(-1, groupsize)
scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)

w_dq = (
w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
)
return w_dq


def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
return group_dequantize_tensor_from_qparams(
w_int32, scales, zeros, n_bit, groupsize
)
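An illustrative round trip through these group quantization primitives (hypothetical shapes and groupsize; assumes the functions above are in scope):

# Hypothetical 2-D weight of shape [out_features, in_features].
w = torch.randn(64, 128)
w_int32, scales_and_zeros = group_quantize_tensor(w, n_bit=4, groupsize=32)
w_dq = group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=32)
# Round-trip error stays within about half a quantization step (scale / 2) per group.
print((w - w_dq).abs().max())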

#########################################################################

class QuantHandler:
def __init__(self, mod):
self.mod = mod
@@ -173,6 +291,12 @@ def create_quantized_state_dict(self) -> Dict: # "StateDict"

def convert_for_runtime(self) -> nn.Module:
pass

def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
self.convert_for_runtime()
self.mod.load_state_dict(model_updated_state_dict)
return self.mod


##### Weight-only int8 per-channel quantized code ######
@@ -202,7 +326,7 @@ def replace_linear_weight_only_int8_per_channel(module, node_type, group_size=No
replace_linear_weight_only_int8_per_channel(child, node_type, group_size)


class WeightOnlyInt8QuantHandler:
class WeightOnlyInt8QuantHandler(QuantHandler):
def __init__(
self,
mod,
@@ -349,7 +473,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
)


class EmbeddingOnlyInt8QuantHandler:
class EmbeddingOnlyInt8QuantHandler(QuantHandler):
def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None):
self.mod = mod
self.group_size = group_size
@@ -466,6 +590,145 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
##################################################################
##### weight only int4 per channel groupwise quantized code ######

def _int4_prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
weight_int32, scales_and_zeros = group_quantize_tensor(
weight_bf16, n_bit=4, groupsize=groupsize
)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
return weight_int4pack, scales_and_zeros

def _int4_calc_padded_size(k, groupsize=1, inner_k_tiles=1):
from model import find_multiple
return find_multiple(k, 1024)

def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])
c = torch.ops.aten._weight_int4pack_mm(
x.to(dtype=torch.bfloat16),
weight_int4pack,
groupsize,
scales_and_zeros.to(dtype=torch.bfloat16)
).to(dtype=x.dtype)
new_shape = origin_x_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c


def _int4_check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
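
A quick check of the divisibility constraint with illustrative values (in_features must be divisible by both groupsize and inner_k_tiles * 16):

# With groupsize=32 and inner_k_tiles=8, in_features must be a multiple of 32 and of 128.
assert _int4_check_linear_int4_k(4096, groupsize=32, inner_k_tiles=8)
assert not _int4_check_linear_int4_k(4100, groupsize=32, inner_k_tiles=8)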

def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=False):
for name, child in module.named_children():
if isinstance(child, nn.Linear):
if _int4_check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
setattr(module, name, WeightOnlyInt4Linear(
child.in_features, child.out_features, bias=False,
groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
))
else:
replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda)


class WeightOnlyInt4QuantHandler(QuantHandler):
def __init__(self, mod, group_size=128, inner_k_tiles=8, padding_allowed=True):
self.mod = mod
self.groupsize = group_size
self.inner_k_tiles = inner_k_tiles
self.padding_allowed = padding_allowed
assert group_size in [32, 64, 128, 256]
assert inner_k_tiles in [2, 4, 8]

@torch.no_grad()
def create_quantized_state_dict(self):
cur_state_dict = self.mod.state_dict()
for fqn, mod in self.mod.named_modules():
if isinstance(mod, torch.nn.Linear):
assert not mod.bias
out_features = mod.out_features
in_features = mod.in_features
assert out_features % 8 == 0, "require out_features % 8 == 0"
print(f"linear: {fqn}, in={in_features}, out={out_features}")

weight = mod.weight.data
if not _int4_check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
if self.padding_allowed:
from model import find_multiple
import torch.nn.functional as F
print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
padded_in_features = find_multiple(in_features, 1024)
weight = F.pad(weight, pad=(0, padded_in_features - in_features))
else:
print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
"and that groupsize and inner_k_tiles*16 evenly divide into it")
continue
weight_int4pack, scales_and_zeros = _int4_prepare_int4_weight_and_scales_and_zeros(
weight.to(torch.float), self.groupsize, self.inner_k_tiles
)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')

return cur_state_dict

def convert_for_runtime(self, use_cuda=False):
replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
return self.mod

def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
self.convert_for_runtime()
self.mod.load_state_dict(model_updated_state_dict)
return self.mod
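
A minimal usage sketch of this handler, mirroring the "linear:int4" branch of quantize_model above (assumes model holds bias-free nn.Linear layers with out_features divisible by 8):

# group_size must be one of 32/64/128/256; padding to a multiple of 1024 is allowed by default.
quantized = WeightOnlyInt4QuantHandler(model, group_size=32).quantized_model()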


class WeightOnlyInt4Linear(torch.nn.Module):
__constants__ = ['in_features', 'out_features']
in_features: int
out_features: int
weight: torch.Tensor

def __init__(
self, in_features: int, out_features: int,
bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
) -> None:
super().__init__()
self.padding = not _int4_check_linear_int4_k(in_features, groupsize, inner_k_tiles)
if self.padding:
from model import find_multiple
self.origin_in_features = in_features
in_features = find_multiple(in_features, 1024)

self.in_features = in_features
self.out_features = out_features
assert not bias, "require bias=False"
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles

assert out_features % 8 == 0, "require out_features % 8 == 0"
assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
self.register_buffer(
"weight",
torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
)
# MKG: torch.float
self.register_buffer(
"scales_and_zeros",
torch.empty((in_features // groupsize, out_features, 2), dtype=torch.float)
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
# MKG torch.float
input = input.to(torch.float)
if self.padding:
import torch.nn.functional as F
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
return linear_forward_int4(
input,
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
)

########################################################################
### Int8 Dynamic Activations 4 Bit Weights

def prepare_int4_weight_and_scales_and_zeros(weight, group_size, precision):
weight_int8, scales, zeros = group_quantize_tensor_symmetric(
@@ -523,7 +786,6 @@ def find_multiple(n: int, *args: Tuple[int]) -> int:
def _check_linear_int4_k(k, group_size=1):
return k % group_size == 0


def _calc_padded_size_linear_int4(k, groupsize=1):
return find_multiple(k, groupsize)

@@ -560,7 +822,7 @@ def replace_linear_8da4w(
)


class Int8DynActInt4WeightQuantHandler:
class Int8DynActInt4WeightQuantHandler(QuantHandler):
def __init__(
self,
mod,
