Skip to content

Commit

Permalink
Fix client max workers
Browse files Browse the repository at this point in the history
  • Loading branch information
ProbablyFaiz committed Sep 3, 2024
1 parent 1ca8ba6 commit 80e41b7
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions rl/llm/engines/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@
from transformers import PreTrainedTokenizer


-_CLIENT_ENGINE_MAX_WORKERS = int(rl.utils.io.getenv("RL_MAX_WORKERS", 4))
-
-
class ClientEngine(InferenceEngine, ABC):
BASE_URL: str
API_KEY_NAME: str
Expand All @@ -37,7 +34,9 @@ def generate(self, prompt: ChatInput) -> InferenceOutput:

def batch_generate(self, prompts: list[ChatInput]) -> InferenceOutput:
return thread_map(
-self.generate, prompts, max_workers=_CLIENT_ENGINE_MAX_WORKERS
+self.generate,
+prompts,
+max_workers=int(rl.utils.io.getenv("RL_MAX_WORKERS", 4)),
)


Expand Down

0 comments on commit 80e41b7

Please sign in to comment.