Skip to content

Commit

Permalink
Inject sanity checks via register_engine
Browse files Browse the repository at this point in the history
  • Loading branch information
ProbablyFaiz committed Sep 3, 2024
1 parent 1865d12 commit 76795df
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions rl/llm/engines/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class EngineNotSupportedError(InferenceEngineError):
pass


class EngineNotEnteredContextError(InferenceEngineError):
pass


class FeatureNotSupportedError(InferenceEngineError):
pass

Expand Down Expand Up @@ -88,6 +92,10 @@ def register_engine(
):
def init_decorator(cls):
old_init = cls.__init__
old_enter = cls.__enter__
old_exit = cls.__exit__
old_generate = cls.generate
old_batch_generate = cls.batch_generate

def new_init(self, llm_config: LLMConfig, *args, **kwargs):
if not llm_config.features.issubset(cls.SUPPORTED_FEATURES):
Expand All @@ -96,9 +104,41 @@ def new_init(self, llm_config: LLMConfig, *args, **kwargs):
f"{llm_config.features - cls.SUPPORTED_FEATURES}"
)
self.enabled_features = llm_config.features.copy()
self.entered_context = False
old_init(self, llm_config, *args, **kwargs)

def new_enter(self):
self.entered_context = True
return old_enter(self)

def new_exit(self, exc_type, exc_value, traceback):
if self.entered_context:
self.entered_context = False
old_exit(self, exc_type, exc_value, traceback)

def new_generate(self, prompt: InferenceInput) -> InferenceOutput:
if not self.entered_context:
raise EngineNotEnteredContextError(
"You must enter the context of the engine by calling "
"`with engine as e:` before you can call `e.generate()`."
)
return old_generate(self, prompt)

def new_batch_generate(
self, prompts: list[InferenceInput]
) -> list[InferenceOutput]:
if not self.entered_context:
raise EngineNotEnteredContextError(
"You must enter the context of the engine by calling "
"`with engine as e:` before you can call `e.batch_generate()`."
)
return old_batch_generate(self, prompts)

cls.__init__ = new_init
cls.__enter__ = new_enter
cls.__exit__ = new_exit
cls.generate = new_generate
cls.batch_generate = new_batch_generate
return cls

def decorator(cls):
Expand Down

0 comments on commit 76795df

Please sign in to comment.