From 4f18fd5ba6c2a43cfb4bbcc91f7079c269076cc3 Mon Sep 17 00:00:00 2001
From: Faiz Surani
Date: Thu, 6 Jun 2024 14:50:11 -0700
Subject: [PATCH] Add chat formatting

---
 rl/llm/engines.py | 24 ++++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/rl/llm/engines.py b/rl/llm/engines.py
index 5d662ee..0f3a7b9 100644
--- a/rl/llm/engines.py
+++ b/rl/llm/engines.py
@@ -1,7 +1,4 @@
-import dataclasses
 import datetime
-import hashlib
-import json
 import math
 import os
 import re
@@ -293,13 +290,26 @@ class ModalEngine(InferenceEngine):
     NAME = "modal"
     app_name: str
     modal_call: modal.Function
+    tokenizer: PreTrainedTokenizer
 
     def __init__(self, llm_config: LLMConfig):
         super().__init__(llm_config)
         self.app_name = self._get_modal_app_name(self.llm_config.model_name_or_path)
 
     def __enter__(self):
-        self.modal_call = modal.Function.lookup(self.app_name, "Model.call")
+        self.modal_call = modal.Function.lookup(self.app_name, "ModalModel.call")
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.llm_config.tokenizer_name_or_path
+        )
+        LOGGER.warning(
+            f"Waiting for Modal app {self.app_name} to be ready. "
+            "This may take a few minutes."
+        )
+        start_time = time.time()
+        self.modal_call.remote("generate", "foo", {"max_tokens": 1})
+        LOGGER.warning(
+            f"Modal app {self.app_name} is ready! Took {time.time() - start_time:.2f}s"
+        )
 
     def generate(self, prompt: InferenceInput) -> InferenceOutput:
         return self.batch_generate([prompt])[0]
@@ -311,6 +321,12 @@ def batch_generate(self, prompts: list[InferenceInput]) -> list[InferenceOutput]
             "frequency_penalty": self.llm_config.frequency_penalty,
             "top_p": 1.0,
         }
+        prompts = [
+            _apply_chat_template(self.tokenizer, prompt)
+            if not isinstance(prompt, str)
+            else prompt
+            for prompt in prompts
+        ]
         output_texts = self.modal_call.remote(
             "batch_generate", prompts, sampling_params
        )
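
Note for reviewers: the patch calls `_apply_chat_template` but its definition lives
elsewhere in `engines.py` and is not part of this diff. For context, below is a
minimal sketch of what such a helper might look like, assuming it wraps the standard
`transformers` chat-template API; the actual helper's signature and behavior may
differ.

    # Sketch only -- the real _apply_chat_template in engines.py may differ.
    from transformers import PreTrainedTokenizer


    def _apply_chat_template(
        tokenizer: PreTrainedTokenizer, messages: list[dict[str, str]]
    ) -> str:
        # Render a list of {"role": ..., "content": ...} messages into the
        # model's prompt string, appending the assistant generation prompt
        # so the model continues as the assistant.
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    # Example usage (model name is illustrative):
    #   tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    #   prompt = _apply_chat_template(tok, [{"role": "user", "content": "Hi"}])

This matches how batch_generate uses it: string prompts pass through untouched,
while message lists are rendered to a single prompt string before the Modal call.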