diff --git a/rl/llm/engines.py b/rl/llm/engines.py
index 3b9108b..2342694 100644
--- a/rl/llm/engines.py
+++ b/rl/llm/engines.py
@@ -154,7 +154,7 @@ def _import_if_available(module_name: str) -> bool:
         return False
 
 
-class InferenceEngine(ABC):
+class InferenceEngine:
     NAME: str
     REQUIRED_MODULES: set[str]
     SUPPORTED_FEATURES: set[EngineFeature]
@@ -180,7 +180,6 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, traceback):
         pass
 
-    @abstractmethod
     def generate(self, prompt: InferenceInput) -> InferenceOutput:
         """Given the input prompt, returns the generated text.
 
@@ -190,7 +189,7 @@ def generate(self, prompt: InferenceInput) -> InferenceOutput:
         Returns:
             The generated text (not including the prompt).
         """
-        pass
+        raise NotImplementedError
 
     def batch_generate(self, prompts: list[InferenceInput]) -> list[InferenceOutput]:
         """Given the input prompts, returns the generated texts.
@@ -268,19 +267,12 @@ def generate(
         )
 
 
-class ClientEngine(InferenceEngine, ABC):
-    NAME: str
+class ClientEngine(InferenceEngine):
     BASE_URL: str
     API_KEY_NAME: str
-    llm_config: LLMConfig
-
-    def __init__(self, llm_config: LLMConfig):
-        rl.utils.io.ensure_dotenv_loaded()
-        self.llm_config = llm_config
 
-    @abstractmethod
     def generate(self, prompt: ChatInput) -> InferenceOutput:
-        pass
+        raise NotImplementedError
 
 
 class OpenAIClientEngine(InferenceEngine, ABC):
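
# --- Illustration, not part of the patch -----------------------------------
# A minimal sketch of the trade-off this change makes: replacing
# abc.ABC/@abstractmethod with plain base classes whose stub methods raise
# NotImplementedError moves the failure for an unimplemented generate() from
# instantiation time to first-call time. All names below (AbstractEngine,
# PlainEngine, IncompleteA, IncompleteP) are hypothetical stand-ins, not
# identifiers from engines.py.
from abc import ABC, abstractmethod


class AbstractEngine(ABC):
    @abstractmethod
    def generate(self, prompt: str) -> str: ...


class PlainEngine:
    def generate(self, prompt: str) -> str:
        raise NotImplementedError


class IncompleteA(AbstractEngine):  # forgets to implement generate()
    pass


class IncompleteP(PlainEngine):  # forgets to implement generate()
    pass


try:
    IncompleteA()  # ABC: fails eagerly, before any generate() call
except TypeError as exc:
    print(f"ABC: {exc}")

engine = IncompleteP()  # plain class: instantiation succeeds
try:
    engine.generate("hi")  # failure is deferred to the first call
except NotImplementedError:
    print("plain: generate() not overridden")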