diff --git a/rl/llm/engines.py b/rl/llm/engines.py
index 098e7d9..f08aeb8 100644
--- a/rl/llm/engines.py
+++ b/rl/llm/engines.py
@@ -74,7 +74,7 @@ def __init__(self, llm_config: LLMConfig):
         self.llm_config = llm_config
 
     def __enter__(self):
-        pass
+        return self
 
     def __exit__(self, exc_type, exc_value, traceback):
         pass
@@ -251,6 +251,7 @@ def __init__(self, llm_config: LLMConfig):
 
     def __enter__(self):
         genai.configure(api_key=rl.utils.io.getenv("GEMINI_API_KEY"))
+        return self
 
     def generate(self, prompt: ChatInput) -> InferenceOutput:
         if not isinstance(prompt, list):
@@ -325,6 +326,7 @@ def __init__(self, llm_config: LLMConfig):
 
     def __enter__(self):
         self.client = Anthropic(api_key=rl.utils.io.getenv(self.API_KEY_NAME))
+        return self
 
     def generate(self, prompt: ChatInput) -> InferenceOutput:
         """Given the input prompt, returns the generated text.
@@ -385,6 +387,7 @@ def __enter__(self):
         LOGGER.warning(
             f"Modal app {self.app_name} is ready! Took {time.time() - start_time:.2f}s"
         )
+        return self
 
     def generate(self, prompt: InferenceInput) -> InferenceOutput:
         return self.batch_generate([prompt])[0]
@@ -433,7 +436,7 @@ def __enter__(self):
         self.tokenizer = AutoTokenizer.from_pretrained(
             self.llm_config.tokenizer_name_or_path
         )
-        pass
+        return self
 
     def __exit__(self, exc_type, exc_value, traceback):
         pass
@@ -580,6 +583,7 @@ def __enter__(self):
             self.llm_config, use_async=False
         )
         self.tokenizer = self.vllm.get_tokenizer()
+        return self
 
     def __exit__(self, exc_type, exc_value, traceback):
         del self.vllm
@@ -772,6 +776,7 @@ def __enter__(self):
         self.tokenizer = AutoTokenizer.from_pretrained(
             self.llm_config.tokenizer_name_or_path
        )
+        return self
 
     def __exit__(self, exc_type, exc_value, traceback):
         del self.vllm
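
Every hunk above makes the same fix: each engine's __enter__ previously ended without a return statement (implicitly returning None), so `with SomeEngine(config) as engine:` bound `engine` to None and a subsequent call such as `engine.generate(...)` would fail with an AttributeError. The following is a minimal, self-contained sketch of that behavior difference; the class names (BrokenEngine, FixedEngine) are hypothetical and are not the actual engine classes in rl/llm/engines.py.

    # Hypothetical sketch: the `with ... as` target is bound to whatever
    # __enter__ returns, so `pass` (which returns None) leaves the caller
    # holding None instead of the engine instance.

    class BrokenEngine:
        def __enter__(self):
            pass  # returns None: `with BrokenEngine() as e:` binds e to None

        def __exit__(self, exc_type, exc_value, traceback):
            pass


    class FixedEngine:
        def __enter__(self):
            return self  # `with FixedEngine() as e:` binds e to this instance

        def __exit__(self, exc_type, exc_value, traceback):
            pass

        def generate(self, prompt: str) -> str:
            return f"echo: {prompt}"


    if __name__ == "__main__":
        with BrokenEngine() as engine:
            print(engine)  # None; engine.generate(...) would raise AttributeError

        with FixedEngine() as engine:
            print(engine.generate("hello"))  # works: engine is the FixedEngine instance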