Skip to content

Commit

Permalink
[GPTQ] fix tests (#1598)
Browse files Browse the repository at this point in the history
fix tests
  • Loading branch information
SunMarc authored Dec 14, 2023
1 parent 40a942b commit 4a0f16f
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion tests/gptq/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ def test_serialization(self):
disable_exllama=self.disable_exllama,
exllama_config=self.exllama_config,
)
self.check_quantized_layers_type(quantized_model_from_saved, "cuda-old")
if self.disable_exllama:
self.check_quantized_layers_type(quantized_model_from_saved, "cuda-old")
else:
self.check_quantized_layers_type(quantized_model_from_saved, "exllama")

with torch.device("cuda"):
_ = AutoModelForCausalLM.from_pretrained(tmpdirname)
Expand All @@ -172,13 +175,15 @@ class GPTQTestExllama(GPTQTest):
EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I")
EXPECTED_OUTPUTS.add("Hello my name is jay and i am a student at university.")
EXPECTED_OUTPUTS.add("Hello my name is John, I am a student in the University of")
EXPECTED_OUTPUTS.add("Hello my name is Nate and I am a new member of the")


class GPTQTestActOrder(GPTQTest):
EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is jay and i am a student at university.")
EXPECTED_OUTPUTS.add("Hello my name is jessie and i am a very sweet and")
EXPECTED_OUTPUTS.add("Hello my name is nathalie, I am a young girl from")
EXPECTED_OUTPUTS.add("Hello my name is\nI am a student of the University of the'")

disable_exllama = True
desc_act = True
Expand Down Expand Up @@ -256,6 +261,11 @@ def test_exllama_max_input_length(self):
class GPTQTestExllamav2(GPTQTest):
desc_act = False
disable_exllama = True
EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I")
EXPECTED_OUTPUTS.add("Hello my name is jay and i am a student at university.")
EXPECTED_OUTPUTS.add("Hello my name is John, I am a student in the University of")
EXPECTED_OUTPUTS.add("Hello my name is Nate and I am a new member of the")

def test_generate_quality(self):
# don't need to test
Expand Down Expand Up @@ -300,6 +310,7 @@ class GPTQTestNoBlockCaching(GPTQTest):
EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I")
EXPECTED_OUTPUTS.add("Hello my name is jay and i am a student at university.")
EXPECTED_OUTPUTS.add("Hello my name is John, I am a student in the University of")
EXPECTED_OUTPUTS.add("Hello my name is Aiden and I am a very good looking")


class GPTQTestModuleQuant(GPTQTest):
Expand Down

0 comments on commit 4a0f16f

Please sign in to comment.