Use tempfile for batch
ProbablyFaiz committed Sep 3, 2024 · 1 parent 6e51461 · commit 913fb56
Showing 1 changed file with 33 additions and 28 deletions.
rl/llm/engines/client.py
@@ -1,8 +1,9 @@
-import io
 import json
 import re
+import tempfile
 import time
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import TYPE_CHECKING
 
 import tqdm
@@ -35,7 +36,7 @@ class ClientEngine(InferenceEngine, ABC):
     def generate(self, prompt: ChatInput) -> InferenceOutput:
         pass
 
-    def batch_generate(self, prompts: list[ChatInput]) -> InferenceOutput:
+    def batch_generate(self, prompts: list[ChatInput]) -> list[InferenceOutput]:
         return thread_map(
             self.generate,
             prompts,
@@ -342,37 +343,41 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:
         return self.batch_generate([prompt])[0]
 
     def batch_generate(self, prompts: list[ChatInput]) -> list[InferenceOutput]:
-        # Create in-memory JSONL file
-        jsonl_file = io.StringIO()
-        for i, prompt in enumerate(prompts):
-            body_kwargs = {
-                "model": self.llm_config.model_name_or_path,
-                "messages": prompt,
-            }
-            if self.llm_config.max_new_tokens is not None:
-                body_kwargs["max_tokens"] = self.llm_config.max_new_tokens
-            if self.llm_config.temperature is not None:
-                body_kwargs["temperature"] = self.llm_config.temperature
-            if EngineFeature.JSON_OUTPUT in self.enabled_features:
-                body_kwargs["response_format"] = {"type": "json_object"}
-            request = {
-                "custom_id": f"request-{i}",
-                "method": "POST",
-                "url": "/v1/chat/completions",
-                "body": body_kwargs,
-            }
-            jsonl_file.write(json.dumps(request) + "\n")
-
-        jsonl_file.seek(0)
-        batch_input_file = self.client.files.create(file=jsonl_file, purpose="batch")
+        with tempfile.NamedTemporaryFile(
+            mode="w+", suffix=".jsonl", delete=False
+        ) as temp_file:
+            for i, prompt in enumerate(prompts):
+                body_kwargs = {
+                    "model": self.llm_config.model_name_or_path,
+                    "messages": prompt,
+                }
+                if self.llm_config.max_new_tokens is not None:
+                    body_kwargs["max_tokens"] = self.llm_config.max_new_tokens
+                if self.llm_config.temperature is not None:
+                    body_kwargs["temperature"] = self.llm_config.temperature
+                if EngineFeature.JSON_OUTPUT in self.enabled_features:
+                    body_kwargs["response_format"] = {"type": "json_object"}
+                request = {
+                    "custom_id": f"request-{i}",
+                    "method": "POST",
+                    "url": "/v1/chat/completions",
+                    "body": body_kwargs,
+                }
+                json.dump(request, temp_file)
+                temp_file.write("\n")
+            temp_file_path = temp_file.name
+
+        # Upload file
+        with Path(temp_file_path).open("rb") as file:
+            batch_input_file = self.client.files.create(file=file, purpose="batch")
 
         # Create batch
         batch = self.client.batches.create(
             input_file_id=batch_input_file.id,
             endpoint="/v1/chat/completions",
             completion_window="24h",
         )
 
         # Poll for status and update progress bar
         pbar = tqdm.tqdm(total=len(prompts), desc="Polling batch status")
         while batch.status not in ["completed", "failed", "expired"]:
             time.sleep(5)  # Poll every 5 seconds
@@ -385,11 +390,9 @@ def batch_generate(self, prompts: list[ChatInput]) -> list[InferenceOutput]:
         if batch.status != "completed":
             raise RuntimeError(f"Batch failed with status: {batch.status}")
 
-        # Retrieve results
         output_file = self.client.files.content(batch.output_file_id)
         results = [json.loads(line) for line in output_file.text.strip().split("\n")]
 
-        # Process results
         outputs = []
         for result in results:
             response = result["response"]["body"]
@@ -404,4 +407,6 @@ def batch_generate(self, prompts: list[ChatInput]) -> list[InferenceOutput]:
                 )
             )
 
+        Path(temp_file_path).unlink()
+
         return outputs
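For reference, the round trip this batch_generate performs against the OpenAI Batch API can be exercised standalone. The sketch below is illustrative, not code from this repo: it assumes OPENAI_API_KEY is set in the environment, and the model name and prompt are placeholder values rather than anything read from LLMConfig.

import json
import tempfile
import time
from pathlib import Path

from openai import OpenAI

client = OpenAI()
prompts = [[{"role": "user", "content": "Say hello."}]]  # one chat per request

# Write one JSON request per line; the Batch API consumes a .jsonl file.
with tempfile.NamedTemporaryFile(mode="w+", suffix=".jsonl", delete=False) as f:
    for i, messages in enumerate(prompts):
        request = {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {"model": "gpt-4o-mini", "messages": messages},  # placeholder model
        }
        json.dump(request, f)
        f.write("\n")
    temp_path = f.name

# Upload the request file, then start a batch against the chat completions endpoint.
with Path(temp_path).open("rb") as fh:
    input_file = client.files.create(file=fh, purpose="batch")
batch = client.batches.create(
    input_file_id=input_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# Poll until the batch reaches a terminal state, then read one JSON result per line.
while batch.status not in ("completed", "failed", "expired"):
    time.sleep(5)
    batch = client.batches.retrieve(batch.id)
if batch.status != "completed":
    raise RuntimeError(f"Batch failed with status: {batch.status}")
for line in client.files.content(batch.output_file_id).text.strip().split("\n"):
    result = json.loads(line)
    print(result["custom_id"], result["response"]["body"]["choices"][0]["message"]["content"])
Path(temp_path).unlink()

The delete=False plus explicit unlink() pattern matches the committed code: the JSONL file must be closed and then reopened in binary mode for files.create, which an auto-deleting temporary file would not reliably allow on every platform.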
