Skip to content

Commit

Permalink
Add thread_map batch_generate for client engine
Browse files Browse the repository at this point in the history
  • Loading branch information
ProbablyFaiz committed Aug 26, 2024
1 parent 5a32637 commit 416febd
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions rl/llm/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import tqdm.asyncio
from pydantic import BaseModel, Field
from tqdm.contrib.concurrent import thread_map
from typing_extensions import TypedDict

import rl.llm.modal_utils
Expand Down Expand Up @@ -201,7 +202,9 @@ def batch_generate(self, prompts: list[InferenceInput]) -> list[InferenceOutput]
Returns:
The generated texts (not including the prompts).
"""
return [self.generate(prompt) for prompt in prompts]
return [
self.generate(prompt) for prompt in tqdm.tqdm(prompts, desc="Generating")
]


_RESPONSE_CANARY = "### Response template begins now, delete this line. ###"
Expand Down Expand Up @@ -268,6 +271,9 @@ def generate(
)


_CLIENT_ENGINE_MAX_WORKERS = int(rl.utils.io.getenv("RL_MAX_WORKERS", 4))


class ClientEngine(InferenceEngine, ABC):
BASE_URL: str
API_KEY_NAME: str
Expand All @@ -276,8 +282,13 @@ class ClientEngine(InferenceEngine, ABC):
def generate(self, prompt: ChatInput) -> InferenceOutput:
pass

def batch_generate(self, prompts: list[ChatInput]) -> InferenceOutput:
return thread_map(
self.generate, prompts, max_workers=_CLIENT_ENGINE_MAX_WORKERS
)


class OpenAIClientEngine(InferenceEngine, ABC):
class _OpenAIClientEngine(ClientEngine, ABC):
BASE_URL: str = "https://api.openai.com/v1"
API_KEY_NAME: str = "OPENAI_API_KEY"
llm_config: LLMConfig
Expand Down Expand Up @@ -323,19 +334,19 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:


@_register_engine("together", required_modules=("openai",))
class TogetherEngine(OpenAIClientEngine):
class TogetherEngine(_OpenAIClientEngine):
BASE_URL = "https://api.together.xyz/v1"
API_KEY_NAME = "TOGETHER_API_KEY"


@_register_engine("openai", required_modules=("openai",))
class OpenAIEngine(OpenAIClientEngine):
class OpenAIEngine(_OpenAIClientEngine):
BASE_URL = "https://api.openai.com/v1"
API_KEY_NAME = "OPENAI_API_KEY"


@_register_engine("groq", required_modules=("openai",))
class GroqEngine(OpenAIClientEngine):
class GroqEngine(_OpenAIClientEngine):
BASE_URL = "https://api.groq.com/openai/v1"
API_KEY_NAME = "GROQ_API_KEY"

Expand Down

0 comments on commit 416febd

Please sign in to comment.