Skip to content

Commit

Permalink
Fix abstract class issue
Browse the repository at this point in the history
  • Loading branch information
ProbablyFaiz committed Aug 26, 2024
1 parent 3833b4d commit a7be127
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions rl/llm/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _import_if_available(module_name: str) -> bool:
return False


class InferenceEngine(ABC):
class InferenceEngine:
NAME: str
REQUIRED_MODULES: set[str]
SUPPORTED_FEATURES: set[EngineFeature]
Expand All @@ -180,7 +180,6 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
pass

@abstractmethod
def generate(self, prompt: InferenceInput) -> InferenceOutput:
"""Given the input prompt, returns the generated text.
Expand All @@ -190,7 +189,7 @@ def generate(self, prompt: InferenceInput) -> InferenceOutput:
Returns:
The generated text (not including the prompt).
"""
pass
raise NotImplementedError

def batch_generate(self, prompts: list[InferenceInput]) -> list[InferenceOutput]:
"""Given the input prompts, returns the generated texts.
Expand Down Expand Up @@ -268,19 +267,12 @@ def generate(
)


class ClientEngine(InferenceEngine):
    """Base class for engines that call a hosted LLM provider's HTTP API.

    NOTE(review): this span was a scraped diff interleave (old and new lines
    mixed); the body below reconstructs the "+" side of the commit — confirm
    against the repository before relying on it.

    Per the commit ("Fix abstract class issue"), this class is deliberately
    NOT declared with ``abc.ABC``/``@abstractmethod``: unimplemented methods
    stay concrete and raise ``NotImplementedError`` instead, so the class
    hierarchy avoids whatever the abstract-class machinery broke upstream.
    """

    # Per-provider constants, supplied by concrete subclasses:
    BASE_URL: str      # root URL of the provider's API
    API_KEY_NAME: str  # name of the environment variable holding the API key
    llm_config: LLMConfig  # configuration captured at construction time

    def __init__(self, llm_config: LLMConfig):
        """Store *llm_config* after ensuring ``.env`` variables are loaded.

        ``ensure_dotenv_loaded`` runs first so that the API key named by
        ``API_KEY_NAME`` is available in the environment before any request.
        """
        rl.utils.io.ensure_dotenv_loaded()
        self.llm_config = llm_config

    def generate(self, prompt: ChatInput) -> InferenceOutput:
        """Return the generated output for *prompt*.

        Subclasses must override this; the base implementation raises
        ``NotImplementedError`` (kept concrete rather than ``@abstractmethod``
        — see class docstring).
        """
        raise NotImplementedError


class OpenAIClientEngine(InferenceEngine, ABC):
Expand Down

0 comments on commit a7be127

Please sign in to comment.