Skip to content

Commit

Permalink
Fix engine init
Browse files Browse the repository at this point in the history
  • Loading branch information
ProbablyFaiz committed Aug 26, 2024
1 parent a7be127 commit 5a32637
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions rl/llm/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _register_engine(
supported_features: tuple[EngineFeature, ...] = (),
):
def init_decorator(cls):
original_init = cls.__init__
old_init = cls.__init__

def new_init(self, llm_config: LLMConfig, *args, **kwargs):
if not llm_config.features.issubset(set(supported_features)):
Expand All @@ -117,7 +117,7 @@ def new_init(self, llm_config: LLMConfig, *args, **kwargs):
f"{llm_config.features - set(supported_features)}"
)

original_init(self, llm_config, *args, **kwargs)
old_init(self, llm_config, *args, **kwargs)

cls.__init__ = new_init
return cls
Expand Down Expand Up @@ -162,15 +162,6 @@ class InferenceEngine:
llm_config: LLMConfig

def __init__(self, llm_config: LLMConfig):
if self.__class__ is InferenceEngine:
if llm_config.engine_name is None:
raise MissingEngineNameError(
"When initializing an inference engine via the base class InferenceEngine, "
"you must pass an llm_config with an engine_name set. Available engine names: "
f"{', '.join(ENGINES.keys())}"
)
return get_inference_engine(llm_config)

rl.utils.io.ensure_dotenv_loaded()
self.llm_config = llm_config

Expand All @@ -180,6 +171,16 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
    # No-op context-manager exit: the base engine holds no resources that
    # need releasing. Subclasses with real setup/teardown override
    # __enter__/__exit__. Returning None (falsy) propagates any exception.
    pass

@staticmethod
def from_config(llm_config: LLMConfig) -> "InferenceEngine":
    """Look up and construct the concrete engine named in ``llm_config``.

    Raises:
        MissingEngineNameError: if ``llm_config.engine_name`` is not set,
            since the base class cannot pick an engine on its own.
    """
    # Guard clause inverted: happy path first, error path below.
    if llm_config.engine_name is not None:
        return get_inference_engine(llm_config)
    raise MissingEngineNameError(
        "When initializing an inference engine via the base class InferenceEngine, "
        "you must pass an llm_config with an engine_name set. Available engine names: "
        f"{', '.join(ENGINES.keys())}"
    )

def generate(self, prompt: InferenceInput) -> InferenceOutput:
"""Given the input prompt, returns the generated text.
Expand Down Expand Up @@ -267,12 +268,13 @@ def generate(
)


class ClientEngine(InferenceEngine, ABC):
    """Abstract base for engines that talk to a remote inference API.

    Subclasses must supply the endpoint and credentials metadata and
    implement ``generate``.
    """

    # Endpoint of the remote service; set by each concrete subclass.
    BASE_URL: str
    # Name of the environment variable holding the API key.
    API_KEY_NAME: str

    @abstractmethod
    def generate(self, prompt: ChatInput) -> InferenceOutput:
        """Return the model's completion for the given chat input."""
        # Defect fixed: the source span interleaved pre-change and
        # post-change diff lines (duplicate class headers and bodies),
        # which is not valid Python. This is the post-commit version:
        # ABC base + @abstractmethod, whose body is a plain no-op.
        pass


class OpenAIClientEngine(InferenceEngine, ABC):
Expand Down

0 comments on commit 5a32637

Please sign in to comment.