From e99d069f68fe918b491252f95ce51e01978948b4 Mon Sep 17 00:00:00 2001
From: Faiz Surani
Date: Tue, 3 Sep 2024 16:51:44 -0700
Subject: [PATCH] Use tempfile for batch

---
 rl/llm/engines/client.py | 58 ++++++++++++++++++++++------------------
 1 file changed, 32 insertions(+), 26 deletions(-)

diff --git a/rl/llm/engines/client.py b/rl/llm/engines/client.py
index 0136a34..4b6a6a3 100644
--- a/rl/llm/engines/client.py
+++ b/rl/llm/engines/client.py
@@ -1,8 +1,9 @@
-import io
 import json
 import re
+import tempfile
 import time
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import TYPE_CHECKING
 
 import tqdm
@@ -35,7 +36,7 @@ class ClientEngine(InferenceEngine, ABC):
     def generate(self, prompt: ChatInput) -> InferenceOutput:
         pass
 
-    def batch_generate(self, prompts: list[ChatInput]) -> InferenceOutput:
+    def batch_generate(self, prompts: list[ChatInput]) -> list[InferenceOutput]:
         return thread_map(
             self.generate,
             prompts,
@@ -342,30 +343,35 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:
         return self.batch_generate([prompt])[0]
 
     def batch_generate(self, prompts: list[ChatInput]) -> list[InferenceOutput]:
-        # Create in-memory JSONL file
-        jsonl_file = io.StringIO()
-        for i, prompt in enumerate(prompts):
-            body_kwargs = {
-                "model": self.llm_config.model_name_or_path,
-                "messages": prompt,
-            }
-            if self.llm_config.max_new_tokens is not None:
-                body_kwargs["max_tokens"] = self.llm_config.max_new_tokens
-            if self.llm_config.temperature is not None:
-                body_kwargs["temperature"] = self.llm_config.temperature
-            if EngineFeature.JSON_OUTPUT in self.enabled_features:
-                body_kwargs["response_format"] = {"type": "json_object"}
-            request = {
-                "custom_id": f"request-{i}",
-                "method": "POST",
-                "url": "/v1/chat/completions",
-                "body": body_kwargs,
-            }
-            jsonl_file.write(json.dumps(request) + "\n")
-
-        jsonl_file.seek(0)
-        batch_input_file = self.client.files.create(file=jsonl_file, purpose="batch")
+        with tempfile.NamedTemporaryFile(
+            mode="w+", suffix=".jsonl", delete=False
+        ) as temp_file:
+            for i, prompt in enumerate(prompts):
+                body_kwargs = {
+                    "model": self.llm_config.model_name_or_path,
+                    "messages": prompt,
+                }
+                if self.llm_config.max_new_tokens is not None:
+                    body_kwargs["max_tokens"] = self.llm_config.max_new_tokens
+                if self.llm_config.temperature is not None:
+                    body_kwargs["temperature"] = self.llm_config.temperature
+                if EngineFeature.JSON_OUTPUT in self.enabled_features:
+                    body_kwargs["response_format"] = {"type": "json_object"}
+                request = {
+                    "custom_id": f"request-{i}",
+                    "method": "POST",
+                    "url": "/v1/chat/completions",
+                    "body": body_kwargs,
+                }
+                json.dump(request, temp_file)
+                temp_file.write("\n")
+            temp_file_path = temp_file.name
+
+        # Upload file
+        with Path(temp_file_path).open("rb") as file:
+            batch_input_file = self.client.files.create(file=file, purpose="batch")
 
+        # Create batch
         batch = self.client.batches.create(
             input_file_id=batch_input_file.id,
             endpoint="/v1/chat/completions",
@@ -373,7 +379,7 @@ def batch_generate(self, prompts: list[ChatInput]) -> list[InferenceOutput]:
         )
 
         # Poll for status and update progress bar
-        pbar = tqdm.tqdm(total=len(prompts), desc="Polling batch status")
+        pbar = tqdm(total=len(prompts), desc="Polling batch status")
         while batch.status not in ["completed", "failed", "expired"]:
             time.sleep(5)  # Poll every 5 seconds
             batch = self.client.batches.retrieve(batch.id)