From e82a632c414b1fc731761b5013fe8a0e7a311473 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 29 Aug 2023 16:47:40 +0800 Subject: [PATCH 1/5] add gptq tensor parallel --- colossalai/gptq/cai_gptq/cai_quant_linear.py | 61 ++++++- colossalai/gptq/gptq_tp.py | 180 +++++++++++++++++++ colossalai/gptq/models/__init__.py | 2 + colossalai/gptq/models/bloom.py | 18 ++ colossalai/gptq/models/llama.py | 19 ++ 5 files changed, 271 insertions(+), 9 deletions(-) create mode 100644 colossalai/gptq/gptq_tp.py create mode 100644 colossalai/gptq/models/__init__.py create mode 100644 colossalai/gptq/models/bloom.py create mode 100644 colossalai/gptq/models/llama.py diff --git a/colossalai/gptq/cai_gptq/cai_quant_linear.py b/colossalai/gptq/cai_gptq/cai_quant_linear.py index c65b325d54ee..1fc88904cac5 100644 --- a/colossalai/gptq/cai_gptq/cai_quant_linear.py +++ b/colossalai/gptq/cai_gptq/cai_quant_linear.py @@ -30,7 +30,7 @@ class CaiQuantLinear(nn.Module): "temp_dq": None, } - def __init__(self, bits, groupsize, infeatures, outfeatures, bias): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): super().__init__() if bits not in [2, 4, 8]: raise NotImplementedError("Only 2,4,8 bits are supported.") @@ -46,7 +46,14 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias): torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) - self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) + if row_split: + self.register_buffer( + 'g_idx', + torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], + dtype=torch.int32)) + else: + self.register_buffer('g_idx', + torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) if bias: self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) @@ -57,6 +64,9 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias): self.q4 = None self.empty_tensor = torch.empty((1, 1), device="meta") + self.tp_size = tp_size + self.tp_rank = tp_rank + self.row_split = row_split def pack(self, linear, scales, zeros, g_idx=None): @@ -137,17 +147,31 @@ def pack(self, linear, scales, zeros, g_idx=None): else: self.g_idx = g_idx + def prepare_buffers(self): + assert self.qweight.device.type == "cuda" + device = self.qweight.device + print(self.g_idx) + if self.g_idx is not None: + if self.row_split and torch.equal( + self.g_idx, + torch.tensor( + [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device)): + self.g_idx = None + elif torch.equal( + self.g_idx, + torch.tensor([i // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device)): + self.g_idx = None + CaiQuantLinear.max_dq_buffer_size = max(CaiQuantLinear.max_dq_buffer_size, self.qweight.numel() * 8) if self.g_idx is not None: CaiQuantLinear.max_inner_outer_dim = max(CaiQuantLinear.max_inner_outer_dim, self.infeatures, self.outfeatures) CaiQuantLinear.max_input_len = 4096 - - def prepare_buffers(self): - assert self.qweight.device.type == "cuda" - device = self.qweight.device - # The temp_state buffer is required to reorder X in the act-order case. 
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. CaiQuantLinear.device_to_buffers['temp_state'] = torch.zeros( @@ -170,6 +194,21 @@ def prepare_buffers(self): def init_q4(self): assert self.qweight.device.type == "cuda" self.q4_width = self.qweight.shape[1] + if self.g_idx is not None: + if self.row_split and torch.equal( + self.g_idx, + torch.tensor( + [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device)): + self.g_idx = None + elif torch.equal( + self.g_idx, + torch.tensor([i // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device)): + self.g_idx = None + if self.g_idx is not None: g_idx = self.g_idx.to("cpu") else: @@ -192,16 +231,20 @@ def forward(self, x): x = x.view(-1, x.shape[-1]) output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device) gptq_cuda.q4_matmul(x, self.q4, output) - if self.bias is not None: + if (self.bias is not None and not self.row_split) or self.tp_size == 1: output.add_(self.bias) else: + if (self.bias is not None and not self.row_split) or self.tp_size == 1: + bias = self.bias + else: + bias = None output = self.gptq_linear( x, self.qweight, self.scales, self.qzeros, g_idx=self.g_idx, - bias=self.bias, + bias=bias, ) return output.view(outshape) diff --git a/colossalai/gptq/gptq_tp.py b/colossalai/gptq/gptq_tp.py new file mode 100644 index 000000000000..e8d1d7f00fe8 --- /dev/null +++ b/colossalai/gptq/gptq_tp.py @@ -0,0 +1,180 @@ +import warnings + +import torch +import torch.distributed as dist + +HAS_AUTO_GPTQ = False +try: + import auto_gptq + HAS_AUTO_GPTQ = True +except ImportError: + warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ') + HAS_AUTO_GPTQ = False + +from .cai_gptq import CaiQuantLinear +from .models import GPTQBloomConfig, GPTQLlamaConfig, reset_bloom_attention_params, reset_llama_attention_params + +model_config_map = { + "llama": GPTQLlamaConfig, + "bloom": GPTQBloomConfig, +} +attention_proc_map = { + "llama": reset_llama_attention_params, + "bloom": reset_bloom_attention_params, +} +if HAS_AUTO_GPTQ: + + def get_module_by_name_prefix(model, module_name: str): + for name, module in model.named_modules(): + if name.startswith(module_name): + return module + + def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1): + + qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1) + qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1) + scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1) + g_idx = gptq_linear.g_idx + if gptq_linear.bias is not None: + bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1) + + cai_split_out_features = cai_linear.outfeatures // split_num + zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num + + for i in range(split_num): + cai_linear.qweight[:, i * cai_split_out_features:(i + 1) * + cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] + cai_linear.qzeros[:, i * zero_split_block:(i + 1) * + zero_split_block] = qzeros[i][:, + tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block] + cai_linear.scales[:, i * cai_split_out_features:(i + 1) * + cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 
1) * + cai_split_out_features] + if cai_linear.bias is not None: + cai_linear.bias[i * cai_split_out_features:(i + 1) * + cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] + + cai_linear.g_idx.copy_(g_idx) + + def split_row_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1): + + qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0) + qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0) + scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0) + g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0) + + cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num + zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num + idx_split_features = cai_linear.infeatures // split_num + + for i in range(split_num): + cai_linear.qweight[i * cai_split_in_features:(i + 1) * + cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) * + cai_split_in_features, :] + cai_linear.qzeros[i * zero_split_block:(i + 1) * + zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) * + zero_split_block, :] + cai_linear.scales[i * zero_split_block:(i + 1) * + zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) * + zero_split_block, :] + cai_linear.g_idx[i * idx_split_features:(i + 1) * + idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) * + idx_split_features] + if cai_linear.bias is not None: + cai_linear.bias.copy_(gptq_linear.bias) + + def replace_autogptq_linear(model, tp_size=1, tp_rank=0, tp_group=None): + + def all_reduce_hook(cai_linear, input, output): + dist.all_reduce(output, op=dist.ReduceOp.SUM, group=tp_group) + if cai_linear.bias is not None: + output.add_(cai_linear.bias) + + model_type_name = model.config.model_type + + gptq_model_config = model_config_map[model_type_name] + layers = get_module_by_name_prefix(model.model, gptq_model_config.layer_blocks) + + for layer in layers: + + attention_proc_map[model_type_name](layer, tp_size=tp_size) + for linear_name in gptq_model_config.linear_names[0]: + gptq_linear = get_module_by_name_prefix(layer, linear_name) + #column split copy + cai_linear = CaiQuantLinear( + gptq_linear.bits, + gptq_linear.group_size, + gptq_linear.in_features, + gptq_linear.out_features // tp_size, + gptq_linear.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + ) + cai_linear.to(gptq_linear.qweight.device) + if len(gptq_model_config.linear_names[0]) == 1: + split_column_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank, split_num=3) + else: + split_column_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank, split_num=1) + name1, name2 = linear_name.split(".") + parent_module = get_module_by_name_prefix(layer, name1) + setattr(parent_module, name2, cai_linear) + + for linear_name in gptq_model_config.linear_names[1]: + gptq_linear = get_module_by_name_prefix(layer, linear_name) + #row split copy + cai_linear = CaiQuantLinear(gptq_linear.bits, + gptq_linear.group_size, + gptq_linear.in_features // tp_size, + gptq_linear.out_features, + gptq_linear.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=True) + cai_linear.to(gptq_linear.qweight.device) + split_row_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank) + + if tp_size > 1: + cai_linear.register_forward_hook(all_reduce_hook) + name1, name2 = 
linear_name.split(".") + parent_module = get_module_by_name_prefix(layer, name1) + setattr(parent_module, name2, cai_linear) + + for linear_name in gptq_model_config.linear_names[2]: + gptq_linear = get_module_by_name_prefix(layer, linear_name) + #column split copy + cai_linear = CaiQuantLinear( + gptq_linear.bits, + gptq_linear.group_size, + gptq_linear.in_features, + gptq_linear.out_features // tp_size, + gptq_linear.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + ) + cai_linear.to(gptq_linear.qweight.device) + split_column_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank) + name1, name2 = linear_name.split(".") + parent_module = get_module_by_name_prefix(layer, name1) + setattr(parent_module, name2, cai_linear) + + for linear_name in gptq_model_config.linear_names[3]: + gptq_linear = get_module_by_name_prefix(layer, linear_name) + #row split copy + cai_linear = CaiQuantLinear(gptq_linear.bits, + gptq_linear.group_size, + gptq_linear.in_features // tp_size, + gptq_linear.out_features, + gptq_linear.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=True) + cai_linear.to(gptq_linear.qweight.device) + split_row_copy(gptq_linear, cai_linear, tp_size=tp_size, tp_rank=tp_rank) + + if tp_size > 1: + cai_linear.register_forward_hook(all_reduce_hook) + name1, name2 = linear_name.split(".") + parent_module = get_module_by_name_prefix(layer, name1) + setattr(parent_module, name2, cai_linear) diff --git a/colossalai/gptq/models/__init__.py b/colossalai/gptq/models/__init__.py new file mode 100644 index 000000000000..ed444b4ed9cb --- /dev/null +++ b/colossalai/gptq/models/__init__.py @@ -0,0 +1,2 @@ +from .bloom import GPTQBloomConfig, reset_bloom_attention_params +from .llama import GPTQLlamaConfig, reset_llama_attention_params diff --git a/colossalai/gptq/models/bloom.py b/colossalai/gptq/models/bloom.py new file mode 100644 index 000000000000..b57fa3a5abbe --- /dev/null +++ b/colossalai/gptq/models/bloom.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass, field, fields + + +@dataclass +class GPTQBloomConfig(): + layer_name = "BloomBlock" + layer_blocks = "transformer.h" + linear_names = [["self_attention.query_key_value"], ["self_attention.dense"], ["mlp.dense_h_to_4h"], + ["mlp.dense_4h_to_h"]] + model_names = ["transformer.word_embeddings", "transformer.word_embeddings_layernorm", "transformer.ln_f"] + attention = "self_attention" + mlp = "mlp" + + +def reset_bloom_attention_params(layer, tp_size=1): + attention = getattr(layer, "self_attention") + attention.hidden_size = attention.hidden_size // tp_size + attention.num_heads = attention.num_heads // tp_size diff --git a/colossalai/gptq/models/llama.py b/colossalai/gptq/models/llama.py new file mode 100644 index 000000000000..71690ba748a5 --- /dev/null +++ b/colossalai/gptq/models/llama.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass, field, fields + + +@dataclass +class GPTQLlamaConfig(): + layer_name = "LlamaDecoderLayer" + layer_blocks = "model.layers" + linear_names = [["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], ["self_attn.o_proj"], + ["mlp.up_proj", "mlp.gate_proj"], ["mlp.down_proj"]] + model_names = ["model.embed_tokens", "model.norm"] + attention = "self_attn" + mlp = "mlp" + + +def reset_llama_attention_params(layer, tp_size=1): + attention = getattr(layer, "self_attn") + attention.hidden_size = attention.hidden_size // tp_size + attention.num_heads = attention.num_heads // tp_size + attention.num_key_value_heads = attention.num_key_value_heads // tp_size From 
a1a2ea5c05b3c3d5bc70ec043931d3f6ac47d5a6 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 29 Aug 2023 16:49:29 +0800 Subject: [PATCH 2/5] add gptq tp --- examples/inference/gptq_llama.py | 71 ++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 examples/inference/gptq_llama.py diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py new file mode 100644 index 000000000000..e2c0e057cc83 --- /dev/null +++ b/examples/inference/gptq_llama.py @@ -0,0 +1,71 @@ +import logging + +import torch +from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig +from auto_gptq.nn_modules.qlinear import GeneralQuantLinear +from torch import distributed as dist +from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, TextGenerationPipeline + +from colossalai.gptq import CaiQuantLinear +from colossalai.gptq.gptq_tp import replace_autogptq_linear + +logging.basicConfig(format="%(asctime)s %(levelname)s [%(name)s] %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S") +dist.init_process_group(backend="nccl") +pretrained_model_dir = "/data/scratch/llama-7b-hf" +# quantized_model_dir = "llama-7b-with-act-4bit" +quantized_model_dir = "/home/lcxk/data3/test_gptq_llama/llama-7b-no-act-4bit" +rank = dist.get_rank() +world_size = dist.get_world_size() +# rank = 1 +# world_size=2 +torch.cuda.set_device(rank) +print("world size {0} rank {1} deivce {2}".format(world_size, rank, torch.cuda.current_device())) +tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) +examples = [ + tokenizer( + "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm.") +] + +quantize_config = BaseQuantizeConfig( + bits=4, # quantize model to 4-bit + group_size=128, # it is recommended to set the value to 128 + desc_act=True, # set to False can significantly speed up inference but the perplexity may slightly bad +) + +# # load un-quantized model, by default, the model will always be loaded into CPU memory +# model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) + +# # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask" +# model.quantize(examples) + +# # save quantized model +# model.save_quantized(quantized_model_dir) + +# # save quantized model using safetensors +# model.save_quantized(quantized_model_dir, use_safetensors=True) + +# load quantized model to the first GPU +model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, + device=torch.cuda.current_device(), + inject_fused_attention=False) + +replace_autogptq_linear(model, tp_size=world_size, tp_rank=rank) + +# if rank == 0: +# print(model.config) +# print(model) +# download quantized model from Hugging Face Hub and load to the first GPU +# model = AutoGPTQForCausalLM.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False) + +# inference with model.generate +print("input is:", "auto-gptq is") +print( + tokenizer.decode( + model.generate(**tokenizer("auto-gptq is", return_tensors="pt").to(model.device), max_new_tokens=128)[0])) +dist.barrier() +print("input is:", "today is") +print( + tokenizer.decode( + model.generate(**tokenizer("today is ", return_tensors="pt").to(model.device), max_new_tokens=128)[0])) From 32a5d6ff506d9caba07d0f5ec9b26f2eef5ef4d2 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 29 Aug 2023 17:35:05 +0800 Subject: [PATCH 3/5] delete print --- colossalai/gptq/cai_gptq/cai_quant_linear.py | 1 
- examples/inference/gptq_llama.py | 10 +++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/colossalai/gptq/cai_gptq/cai_quant_linear.py b/colossalai/gptq/cai_gptq/cai_quant_linear.py index 1fc88904cac5..78a37e7bbfb3 100644 --- a/colossalai/gptq/cai_gptq/cai_quant_linear.py +++ b/colossalai/gptq/cai_gptq/cai_quant_linear.py @@ -150,7 +150,6 @@ def pack(self, linear, scales, zeros, g_idx=None): def prepare_buffers(self): assert self.qweight.device.type == "cuda" device = self.qweight.device - print(self.g_idx) if self.g_idx is not None: if self.row_split and torch.equal( self.g_idx, diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py index e2c0e057cc83..ae398740dcdb 100644 --- a/examples/inference/gptq_llama.py +++ b/examples/inference/gptq_llama.py @@ -28,11 +28,11 @@ "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm.") ] -quantize_config = BaseQuantizeConfig( - bits=4, # quantize model to 4-bit - group_size=128, # it is recommended to set the value to 128 - desc_act=True, # set to False can significantly speed up inference but the perplexity may slightly bad -) +# quantize_config = BaseQuantizeConfig( +# bits=4, # quantize model to 4-bit +# group_size=128, # it is recommended to set the value to 128 +# desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad +# ) # # load un-quantized model, by default, the model will always be loaded into CPU memory # model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) From 836286fd9f1cf05e3ca2b4244138af02850a89b4 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 30 Aug 2023 09:43:56 +0800 Subject: [PATCH 4/5] add test gptq check --- tests/test_gptq/test_gptq_linear.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py index 7b3913928587..20a177378142 100644 --- a/tests/test_gptq/test_gptq_linear.py +++ b/tests/test_gptq/test_gptq_linear.py @@ -6,13 +6,30 @@ import torch import torch.nn as nn import transformers -from auto_gptq.modeling._utils import autogptq_post_init, find_layers, pack_model -from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear -from auto_gptq.quantization import GPTQ -from auto_gptq.quantization.quantizer import Quantizer +from packaging import version from colossalai.gptq import CaiGPTQLinearOp, CaiQuantLinear +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +try: + from auto_gptq.modeling._utils import autogptq_post_init, find_layers, pack_model + from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear + from auto_gptq.quantization import GPTQ + from auto_gptq.quantization.quantizer import Quantizer + HAS_AUTO_GPTQ = True +except: + HAS_AUTO_GPTQ = False + print("please install triton from https://github.com/PanQiWei/AutoGPTQ") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + wbits = 4 trits = False nsamples = 1 @@ -214,6 +231,8 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize return qweight, qscales, qzeros +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, + reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq") def test_gptq_linear(): 
infeature = 5120 @@ -265,7 +284,6 @@ def test_gptq_linear(): gptq_model.to(torch.cuda.current_device()) gptq_model = autogptq_post_init(gptq_model, False) - with torch.no_grad(): gptq_out = gptq_model(inps) batch_gptq_out = gptq_model(batch_inps) From 79a7fc4275de75b01650582865591a650a8971bb Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 30 Aug 2023 10:00:36 +0800 Subject: [PATCH 5/5] add test auto gptq check --- tests/test_gptq/test_gptq_linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py index 20a177378142..0d0343a5c407 100644 --- a/tests/test_gptq/test_gptq_linear.py +++ b/tests/test_gptq/test_gptq_linear.py @@ -8,8 +8,6 @@ import transformers from packaging import version -from colossalai.gptq import CaiGPTQLinearOp, CaiQuantLinear - try: import triton import triton.language as tl @@ -23,6 +21,8 @@ from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear from auto_gptq.quantization import GPTQ from auto_gptq.quantization.quantizer import Quantizer + + from colossalai.gptq import CaiGPTQLinearOp, CaiQuantLinear HAS_AUTO_GPTQ = True except: HAS_AUTO_GPTQ = False
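
Note: below is a minimal usage sketch for the tensor-parallel GPTQ path added in this series, condensed from examples/inference/gptq_llama.py in PATCH 2/5. It is not part of the patches themselves. The script name run_gptq_tp.py and the two model directories are placeholders, and it assumes an already-quantized 4-bit AutoGPTQ LLaMA checkpoint; launch one process per GPU, e.g. `torchrun --nproc_per_node=2 run_gptq_tp.py`.

    # run_gptq_tp.py -- hypothetical file name, paths below are placeholders
    import torch
    import torch.distributed as dist
    from auto_gptq import AutoGPTQForCausalLM
    from transformers import LlamaTokenizer

    from colossalai.gptq.gptq_tp import replace_autogptq_linear

    # One process per GPU; torchrun provides RANK/WORLD_SIZE/MASTER_ADDR for env:// init.
    dist.init_process_group(backend="nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)

    pretrained_model_dir = "/path/to/llama-7b-hf"            # placeholder
    quantized_model_dir = "/path/to/llama-7b-no-act-4bit"    # placeholder

    tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)

    # Load the already-quantized checkpoint onto this rank's GPU.
    model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,
                                               device=torch.cuda.current_device(),
                                               inject_fused_attention=False)

    # Shard every quantized linear layer in place. For the LLaMA config, the
    # q/k/v and up/gate projections are column-split, while o_proj and
    # down_proj are row-split and get an all-reduce forward hook.
    replace_autogptq_linear(model, tp_size=world_size, tp_rank=rank)

    inputs = tokenizer("auto-gptq is", return_tensors="pt").to(model.device)
    print(tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0]))

With tp_group left at its default of None, the all-reduce hook added for the row-split layers runs over the default process group; passing an explicit tp_group should only be needed when tensor parallelism spans a subgroup of the world.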