Commit: Use tempfile for batch

ProbablyFaiz committed Sep 3, 2024
1 parent 6e51461 commit e99d069
Showing 1 changed file with 32 additions and 26 deletions.
58 changes: 32 additions & 26 deletions rl/llm/engines/client.py
@@ -1,8 +1,9 @@
-import io
 import json
 import re
+import tempfile
 import time
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import TYPE_CHECKING
 
 import tqdm
@@ -35,7 +36,7 @@ class ClientEngine(InferenceEngine, ABC):
     def generate(self, prompt: ChatInput) -> InferenceOutput:
         pass
 
-    def batch_generate(self, prompts: list[ChatInput]) -> InferenceOutput:
+    def batch_generate(self, prompts: list[ChatInput]) -> list[InferenceOutput]:
         return thread_map(
             self.generate,
             prompts,
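
An aside on the annotation fix above: thread_map runs the function over the iterable in a worker pool with a progress bar and returns a list of results, so list[InferenceOutput] is the accurate return type. A minimal sketch of that behavior, assuming the helper is tqdm.contrib.concurrent.thread_map (the actual import is outside the visible hunks):

    from tqdm.contrib.concurrent import thread_map

    def generate(prompt: str) -> str:
        # Stand-in for a real model call.
        return prompt.upper()

    # thread_map maps generate over the prompts in a thread pool and
    # collects the results into a list, preserving input order.
    outputs: list[str] = thread_map(generate, ["hello", "world"], max_workers=4)
    print(outputs)  # ['HELLO', 'WORLD']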
@@ -342,38 +343,43 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:
         return self.batch_generate([prompt])[0]
 
     def batch_generate(self, prompts: list[ChatInput]) -> list[InferenceOutput]:
-        # Create in-memory JSONL file
-        jsonl_file = io.StringIO()
-        for i, prompt in enumerate(prompts):
-            body_kwargs = {
-                "model": self.llm_config.model_name_or_path,
-                "messages": prompt,
-            }
-            if self.llm_config.max_new_tokens is not None:
-                body_kwargs["max_tokens"] = self.llm_config.max_new_tokens
-            if self.llm_config.temperature is not None:
-                body_kwargs["temperature"] = self.llm_config.temperature
-            if EngineFeature.JSON_OUTPUT in self.enabled_features:
-                body_kwargs["response_format"] = {"type": "json_object"}
-            request = {
-                "custom_id": f"request-{i}",
-                "method": "POST",
-                "url": "/v1/chat/completions",
-                "body": body_kwargs,
-            }
-            jsonl_file.write(json.dumps(request) + "\n")
-
-        jsonl_file.seek(0)
-        batch_input_file = self.client.files.create(file=jsonl_file, purpose="batch")
+        with tempfile.NamedTemporaryFile(
+            mode="w+", suffix=".jsonl", delete=False
+        ) as temp_file:
+            for i, prompt in enumerate(prompts):
+                body_kwargs = {
+                    "model": self.llm_config.model_name_or_path,
+                    "messages": prompt,
+                }
+                if self.llm_config.max_new_tokens is not None:
+                    body_kwargs["max_tokens"] = self.llm_config.max_new_tokens
+                if self.llm_config.temperature is not None:
+                    body_kwargs["temperature"] = self.llm_config.temperature
+                if EngineFeature.JSON_OUTPUT in self.enabled_features:
+                    body_kwargs["response_format"] = {"type": "json_object"}
+                request = {
+                    "custom_id": f"request-{i}",
+                    "method": "POST",
+                    "url": "/v1/chat/completions",
+                    "body": body_kwargs,
+                }
+                json.dump(request, temp_file)
+                temp_file.write("\n")
+            temp_file_path = temp_file.name
+
+        # Upload file
+        with Path(temp_file_path).open("rb") as file:
+            batch_input_file = self.client.files.create(file=file, purpose="batch")
 
         # Create batch
         batch = self.client.batches.create(
             input_file_id=batch_input_file.id,
             endpoint="/v1/chat/completions",
             completion_window="24h",
         )
 
         # Poll for status and update progress bar
-        pbar = tqdm.tqdm(total=len(prompts), desc="Polling batch status")
+        pbar = tqdm(total=len(prompts), desc="Polling batch status")
         while batch.status not in ["completed", "failed", "expired"]:
             time.sleep(5)  # Poll every 5 seconds
             batch = self.client.batches.retrieve(batch.id)
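The heart of the change: rather than uploading an in-memory io.StringIO, the requests are written to a named temporary file (delete=False keeps it on disk after the with-block) and the file is reopened in binary mode for the upload, which is the form the OpenAI SDK's files.create accepts. A self-contained sketch of the same write-then-reopen pattern, with hypothetical stand-ins (model_name, prompts) for the config and inputs in the diff:

    import json
    import tempfile
    from pathlib import Path

    # Hypothetical inputs standing in for self.llm_config and prompts.
    model_name = "gpt-4o-mini"  # assumption: any chat-completions model id
    prompts = [
        [{"role": "user", "content": "Hello"}],
        [{"role": "user", "content": "World"}],
    ]

    # Write one JSON request per line (JSONL). delete=False keeps the file
    # on disk after the with-block exits, so it can be reopened for upload.
    with tempfile.NamedTemporaryFile(mode="w+", suffix=".jsonl", delete=False) as f:
        for i, prompt in enumerate(prompts):
            request = {
                "custom_id": f"request-{i}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {"model": model_name, "messages": prompt},
            }
            json.dump(request, f)
            f.write("\n")
        temp_path = f.name

    # The diff then reopens the file in binary mode for the upload, along
    # the lines of:
    #   client.files.create(file=Path(temp_path).open("rb"), purpose="batch")
    print(Path(temp_path).read_text())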

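One detail worth flagging in the polling step: with the file's plain `import tqdm`, the progress bar class is spelled tqdm.tqdm, so the new bare tqdm(...) call presumably relies on a `from tqdm import tqdm` style import somewhere outside the visible hunks. A runnable sketch of the poll-until-terminal loop against a hypothetical stub (FakeBatches and "batch-123" are invented names, not part of the OpenAI SDK):

    import time
    from tqdm import tqdm  # assumption: bare tqdm(...) implies this import style

    class FakeBatches:
        """Hypothetical stub mimicking client.batches, for illustration only."""
        def __init__(self):
            self._polls = 0

        def retrieve(self, batch_id: str):
            # Report "in_progress" twice, then "completed".
            self._polls += 1
            status = "completed" if self._polls >= 3 else "in_progress"
            return type("Batch", (), {"id": batch_id, "status": status})()

    batches = FakeBatches()
    batch = batches.retrieve("batch-123")

    # Poll until the batch reaches a terminal state, mirroring the diff's loop.
    pbar = tqdm(total=1, desc="Polling batch status")
    while batch.status not in ["completed", "failed", "expired"]:
        time.sleep(0.1)  # the diff sleeps 5 seconds between polls
        batch = batches.retrieve(batch.id)
    pbar.update(1)
    pbar.close()
    print(batch.status)  # completed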