Skip to content

Commit

Permalink
Make all __enter__ methods return self
Browse files Browse the repository at this point in the history
  • Loading branch information
ProbablyFaiz committed Jun 22, 2024
1 parent b6aed06 commit 917f5f3
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions rl/llm/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, llm_config: LLMConfig):
self.llm_config = llm_config

def __enter__(self):
pass
return self

def __exit__(self, exc_type, exc_value, traceback):
pass
Expand Down Expand Up @@ -251,6 +251,7 @@ def __init__(self, llm_config: LLMConfig):

def __enter__(self):
genai.configure(api_key=rl.utils.io.getenv("GEMINI_API_KEY"))
return self

def generate(self, prompt: ChatInput) -> InferenceOutput:
if not isinstance(prompt, list):
Expand Down Expand Up @@ -325,6 +326,7 @@ def __init__(self, llm_config: LLMConfig):

def __enter__(self):
self.client = Anthropic(api_key=rl.utils.io.getenv(self.API_KEY_NAME))
return self

def generate(self, prompt: ChatInput) -> InferenceOutput:
"""Given the input prompt, returns the generated text.
Expand Down Expand Up @@ -385,6 +387,7 @@ def __enter__(self):
LOGGER.warning(
f"Modal app {self.app_name} is ready! Took {time.time() - start_time:.2f}s"
)
return self

def generate(self, prompt: InferenceInput) -> InferenceOutput:
return self.batch_generate([prompt])[0]
Expand Down Expand Up @@ -433,7 +436,7 @@ def __enter__(self):
self.tokenizer = AutoTokenizer.from_pretrained(
self.llm_config.tokenizer_name_or_path
)
pass
return self

def __exit__(self, exc_type, exc_value, traceback):
pass
Expand Down Expand Up @@ -580,6 +583,7 @@ def __enter__(self):
self.llm_config, use_async=False
)
self.tokenizer = self.vllm.get_tokenizer()
return self

def __exit__(self, exc_type, exc_value, traceback):
del self.vllm
Expand Down Expand Up @@ -772,6 +776,7 @@ def __enter__(self):
self.tokenizer = AutoTokenizer.from_pretrained(
self.llm_config.tokenizer_name_or_path
)
return self

def __exit__(self, exc_type, exc_value, traceback):
del self.vllm
Expand Down

0 comments on commit 917f5f3

Please sign in to comment.