unit tests
carlosgjs committed Feb 1, 2024
1 parent 90c1099 commit cc6bf7e
Showing 2 changed files with 35 additions and 9 deletions.
src/autora/doc/runtime/predict_hf.py (17 changes: 9 additions & 8 deletions)
@@ -25,7 +25,7 @@ def preprocess_code(code: str) -> str:
 
 class Predictor:
     def __init__(self, input_model_path: str):
-        model_path, config = self.get_config(input_model_path)
+        model_path, config = Predictor.get_config(input_model_path)
         if model_path != input_model_path:
             logger.info(f"Mapped requested model '{input_model_path}' to '{model_path}'")
 
@@ -89,7 +89,8 @@ def tokenize(self, input: List[str]) -> Dict[str, List[List[int]]]:
         tokens: Dict[str, List[List[int]]] = self.tokenizer(input)
         return tokens
 
-    def get_config(self, model_path: str) -> Tuple[str, Dict[str, str]]:
+    @staticmethod
+    def get_config(model_path: str) -> Tuple[str, Dict[str, str]]:
         if torch.cuda.is_available():
             from transformers import BitsAndBytesConfig
 
@@ -100,12 +101,12 @@ def get_config(self, model_path: str) -> Tuple[str, Dict[str, str]]:
                 return mapped_path, config
 
             # Load the model in 4bit quantization for faster inference on smaller GPUs
-            config ["quantization_config"] = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_use_double_quant=True,
-                bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=torch.bfloat16,
-            )
+            config["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )
             return model_path, config
         else:
             return model_path, {}
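
For context, the (model_path, config) pair returned by the now-static get_config can be consumed directly when loading a model; the sketch below assumes the keys in config are valid transformers from_pretrained keyword arguments. The AutoModelForCausalLM loading code is illustrative only and not part of this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

from autora.doc.runtime.predict_hf import Predictor

# Resolve the checkpoint (possibly a pre-quantized variant) and the
# device-dependent keyword arguments; on CPU-only machines config is {}.
model_path, config = Predictor.get_config("meta-llama/Llama-2-7b-chat-hf")

# Hypothetical load step: pass the 4-bit quantization settings (when present)
# through to transformers.
model = AutoModelForCausalLM.from_pretrained(model_path, **config)
tokenizer = AutoTokenizer.from_pretrained(model_path)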
tests/test_predict_hf.py (27 changes: 26 additions & 1 deletion)
@@ -1,4 +1,10 @@
-from autora.doc.runtime.predict_hf import Predictor
+from unittest import mock
+
+from autora.doc.runtime.predict_hf import Predictor, quantized_models
+
+# Test models with and without available quantized models
+MODEL_NO_QUANTIZED = "hf-internal-testing/tiny-random-FalconForCausalLM"
+MODEL_WITH_QUANTIZED = "meta-llama/Llama-2-7b-chat-hf"
 
 
 def test_trim_prompt() -> None:
@@ -14,3 +20,22 @@ def test_trim_prompt() -> None:
     """
     output = Predictor.trim_prompt(with_marker)
     assert output == "output\n"
+
+
+@mock.patch("torch.cuda.is_available", return_value=True)
+def test_get_config_cuda(mock: mock.Mock) -> None:
+    model, config = Predictor.get_config(MODEL_WITH_QUANTIZED)
+    assert model == quantized_models[MODEL_WITH_QUANTIZED]
+    assert "quantization_config" not in config
+
+    model, config = Predictor.get_config(MODEL_NO_QUANTIZED)
+    # no pre-quantized model available
+    assert model == MODEL_NO_QUANTIZED
+    assert "quantization_config" in config
+
+
+@mock.patch("torch.cuda.is_available", return_value=False)
+def test_get_config_nocuda(mock: mock.Mock) -> None:
+    model, config = Predictor.get_config(MODEL_WITH_QUANTIZED)
+    assert model == MODEL_WITH_QUANTIZED
+    assert len(config) == 0
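
As a usage note, the same branches can also be exercised with mock.patch used as a context manager rather than a decorator; the test below is a minimal equivalent sketch for the CPU-only path and is not part of this commit.

from unittest import mock

from autora.doc.runtime.predict_hf import Predictor


def test_get_config_nocuda_ctx() -> None:
    # Force the CPU-only branch, mirroring test_get_config_nocuda above.
    with mock.patch("torch.cuda.is_available", return_value=False):
        model, config = Predictor.get_config("meta-llama/Llama-2-7b-chat-hf")
    assert model == "meta-llama/Llama-2-7b-chat-hf"
    assert config == {}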
