Skip to content

Commit

Permalink
redirect to quantized model
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgjs committed Feb 1, 2024
1 parent 35dcbec commit 2a98332
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions src/autora/doc/runtime/predict_hf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Iterable, List
from typing import Dict, Iterable, List, Tuple

import torch
import transformers
Expand All @@ -9,6 +9,9 @@

logger = logging.getLogger(__name__)

# Redirect table: requested model id -> id of a pre-quantized replacement
# (4-bit, judging by the "-4bit" suffix of the target repo name).
# Looked up via `quantized_models.get(model_path, model_path)` in
# `Predictor.get_config` when CUDA is available; ids without an entry are
# used unchanged.
# TODO: Redirect the quantized model to an 'autora' HF org
quantized_models = {"meta-llama/Llama-2-7b-chat-hf": "carlosgjs/Llama-2-7b-chat-hf-4bit"}


def preprocess_code(code: str) -> str:
lines: Iterable[str] = code.splitlines()
Expand All @@ -21,10 +24,12 @@ def preprocess_code(code: str) -> str:


class Predictor:
def __init__(self, model_path: str):
config = self.get_config()
def __init__(self, input_model_path: str):
model_path, config = self.get_config(input_model_path)
if model_path != input_model_path:
logger.info(f"Mapped requested model '{input_model_path}' to '{model_path}'")

logger.info(f"Loading model from {model_path}")
logger.info(f"Loading model from {model_path} using config {config}")
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
Expand Down Expand Up @@ -84,12 +89,14 @@ def tokenize(self, input: List[str]) -> Dict[str, List[List[int]]]:
tokens: Dict[str, List[List[int]]] = self.tokenizer(input)
return tokens

def get_config(self) -> Dict[str, str]:
def get_config(self, model_path: str) -> Tuple[str, Dict[str, str]]:
if torch.cuda.is_available():
from transformers import BitsAndBytesConfig

model_path = quantized_models.get(model_path, model_path)

# Load the model in 4bit quantization for faster inference on smaller GPUs
return {
return model_path, {
"quantization_config": BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
Expand All @@ -99,4 +106,4 @@ def get_config(self) -> Dict[str, str]:
"device_map": "auto",
}
else:
return {}
return model_path, {}

0 comments on commit 2a98332

Please sign in to comment.