Add chat formatting
ProbablyFaiz committed Jun 6, 2024
1 parent f137cd9 commit 4f18fd5
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions rl/llm/engines.py
@@ -1,7 +1,4 @@
 import dataclasses
-import datetime
-import hashlib
-import json
 import math
 import os
 import re
@@ -293,13 +290,26 @@ class ModalEngine(InferenceEngine):
     NAME = "modal"
     app_name: str
     modal_call: modal.Function
+    tokenizer: PreTrainedTokenizer
 
     def __init__(self, llm_config: LLMConfig):
         super().__init__(llm_config)
         self.app_name = self._get_modal_app_name(self.llm_config.model_name_or_path)
 
     def __enter__(self):
-        self.modal_call = modal.Function.lookup(self.app_name, "Model.call")
+        self.modal_call = modal.Function.lookup(self.app_name, "ModalModel.call")
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.llm_config.tokenizer_name_or_path
+        )
+        LOGGER.warning(
+            f"Waiting for Modal app {self.app_name} to be ready. "
+            "This may take a few minutes."
+        )
+        start_time = time.time()
+        self.modal_call.remote("generate", "foo", {"max_tokens": 1})
+        LOGGER.warning(
+            f"Modal app {self.app_name} is ready! Took {time.time() - start_time:.2f}s"
+        )
 
     def generate(self, prompt: InferenceInput) -> InferenceOutput:
         return self.batch_generate([prompt])[0]
@@ -311,6 +321,12 @@ def batch_generate(self, prompts: list[InferenceInput]) -> list[InferenceOutput]
             "frequency_penalty": self.llm_config.frequency_penalty,
             "top_p": 1.0,
         }
+        prompts = [
+            _apply_chat_template(self.tokenizer, prompt)
+            if not isinstance(prompt, str)
+            else prompt
+            for prompt in prompts
+        ]
         output_texts = self.modal_call.remote(
             "batch_generate", prompts, sampling_params
         )
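The _apply_chat_template helper called in batch_generate is not part of this diff. A minimal sketch of what it plausibly does, assuming chat prompts arrive as OpenAI-style message dicts and that the helper defers to Hugging Face's tokenizer.apply_chat_template (both assumptions, not confirmed by this commit):

from transformers import PreTrainedTokenizer

def _apply_chat_template(
    tokenizer: PreTrainedTokenizer, messages: list[dict[str, str]]
) -> str:
    # Render [{"role": ..., "content": ...}, ...] into the model's prompt
    # format, appending the assistant generation prompt so the model starts
    # a fresh reply rather than continuing the last message.
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )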

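A hypothetical usage sketch, assuming an LLMConfig built elsewhere and an __exit__ method below the visible hunk. Entering the context looks up the Modal function, loads the tokenizer, and fires a one-token warm-up call so the first real request does not pay the container cold start:

engine = ModalEngine(llm_config)  # llm_config: an LLMConfig, assumed
with engine:
    # A non-str prompt is treated as chat messages and run through
    # _apply_chat_template before being sent to the Modal app.
    output = engine.generate([{"role": "user", "content": "Hello!"}])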