add mlu model class for comparison
jwilber committed Sep 28, 2023
1 parent 6a6b901 commit 1bc02dc
Showing 3 changed files with 55 additions and 3 deletions.
1 change: 1 addition & 0 deletions pykoi/chat/llm/constants.py
@@ -10,3 +10,4 @@ class ModelSource(Enum):
     OPENAI = "openai"
     HUGGINGFACE = "huggingface"
     PEFT_HUGGINGFACE = "peft_huggingface"
+    MLU = "mlu"
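A quick illustrative check of the new member. The factory's Union[str, ModelSource] signature suggests strings are coerced through the enum's value lookup, but that coercion is outside this diff, so treat it as an assumption:

from pykoi.chat.llm.constants import ModelSource

# Standard Enum value lookup: "mlu" resolves to the new member.
assert ModelSource("mlu") is ModelSource.MLU
assert ModelSource.MLU.value == "mlu"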
49 changes: 49 additions & 0 deletions pykoi/chat/llm/mlu.py
@@ -0,0 +1,49 @@
"""MLU HF model."""
from transformers import GenerationConfig

from pykoi.chat.llm.abs_llm import AbsLlm


class MLUWrapper(AbsLlm):
    """Serve a model trained by an MLU trainer through the AbsLlm chat interface."""

    model_source = "mlu_trainer"

    def __init__(self, trainer, tokenizer, name=None):
        self._trainer = trainer
        self._model = trainer.model
        self._tokenizer = tokenizer
        self._name = name
        self._model.to("cuda:0")
        self._model.eval()
        super().__init__()

    @property
    def name(self):
        if self._name:
            return self._name
        return "_".join([str(MLUWrapper.model_source), "trainer_model"])

    def predict(self, message: str, num_of_response: int = 1):
        MAX_RESPONSE = 100
        prompt_template = """Below is a sentence that you need to complete. Write a response that appropriately completes the request. Sentence: {instruction}\n Response:"""
        answer_template = """{response}"""  # defined but unused here

        generation_output = self._model.generate(
            input_ids=self._tokenizer(
                prompt_template.format(instruction=message), return_tensors="pt"
            )["input_ids"].cuda(),
            generation_config=GenerationConfig(
                do_sample=False, num_beams=2
            ),  # Match the standalone function
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=MAX_RESPONSE,
            num_return_sequences=num_of_response,  # must be <= num_beams for beam search
        )

        response = [
            self._tokenizer.decode(seq, skip_special_tokens=True)
            for seq in generation_output.sequences
        ]
        # Keep the segment right after the first newline in each decoded sequence.
        response = [resp.split("\n")[1] for resp in response if "\n" in resp]
        return response
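For context, a minimal usage sketch of the wrapper above. The trainer stub and the gpt2 checkpoint are illustrative stand-ins (the real MLU trainer is not part of this commit), and a CUDA device is required because the constructor pins the model to cuda:0:

from transformers import AutoModelForCausalLM, AutoTokenizer

from pykoi.chat.llm.mlu import MLUWrapper


class StubTrainer:
    """Hypothetical stand-in: MLUWrapper only reads `trainer.model`."""

    def __init__(self, model):
        self.model = model


tokenizer = AutoTokenizer.from_pretrained("gpt2")
trainer = StubTrainer(AutoModelForCausalLM.from_pretrained("gpt2"))

llm = MLUWrapper(trainer, tokenizer)  # moves the model to cuda:0
print(llm.name)  # no name given, so falls back to "mlu_trainer_trainer_model"
print(llm.predict("The capital of France is"))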
8 changes: 5 additions & 3 deletions pykoi/chat/llm/model_factory.py
@@ -48,9 +48,11 @@ def create_model(model_source: Union[str, ModelSource], **kwargs) -> AbsLlm:
             from pykoi.chat.llm.peft_huggingface import PeftHuggingfacemodel
 
             return PeftHuggingfacemodel(**kwargs)
+        elif model_source == ModelSource.MLU:
+            from pykoi.chat.llm.mlu import MLUWrapper
+
+            return MLUWrapper(**kwargs)
         else:
-            raise ValueError(
-                f"[llm_factory]: Unknown model source {model_source}"
-            )
+            raise ValueError(f"[llm_factory]: Unknown model source {model_source}")
     except ValueError as ex:
         raise ValueError("[llm_factory]: initialize model failure") from ex
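The corresponding factory path, as a sketch. Whether create_model is exported as a module-level function or a method on a factory class is not visible in this hunk, so the import below is an assumption:

from pykoi.chat.llm.constants import ModelSource
from pykoi.chat.llm.model_factory import create_model  # assumed export

# **kwargs are forwarded verbatim to MLUWrapper(trainer=..., tokenizer=..., name=...).
llm = create_model(
    model_source=ModelSource.MLU,
    trainer=trainer,  # e.g. the StubTrainer from the sketch above
    tokenizer=tokenizer,
)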
