diff --git a/rl/llm/engines.py b/rl/llm/engines.py index f5a2360..4730244 100644 --- a/rl/llm/engines.py +++ b/rl/llm/engines.py @@ -270,7 +270,9 @@ def generate(self, prompt: ChatInput) -> InferenceOutput: generation_config={ "temperature": self.llm_config.temperature, "max_output_tokens": self.llm_config.max_new_tokens, - "response_mime_type": "application/json" if self.llm_config.json_output else "text/plain", + "response_mime_type": "application/json" + if self.llm_config.json_output + else "text/plain", }, system_instruction=system_message, safety_settings={ @@ -825,6 +827,7 @@ def _wrap_output(self, req_output) -> InferenceOutput: e.NAME: e for e in ( VLLMEngine, + AsyncVLLMEngine, WorkerVLLMEngine, OpenAIEngine, TogetherEngine, @@ -832,6 +835,7 @@ def _wrap_output(self, req_output) -> InferenceOutput: AnthropicEngine, ModalEngine, GeminiEngine, + ManualEditEngine, ) }