From d0f14f7389fbe420d4fb8cf5cabf5b9a0cff57b6 Mon Sep 17 00:00:00 2001 From: Varun Date: Sun, 23 Jun 2024 17:28:59 -0700 Subject: [PATCH] add json type option, support for gemini --- rl/llm/config.py | 1 + rl/llm/engines.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/rl/llm/config.py b/rl/llm/config.py index 28d6463..cbba111 100644 --- a/rl/llm/config.py +++ b/rl/llm/config.py @@ -20,6 +20,7 @@ class LLMConfig: frequency_penalty: float = 0.2 # Experiment with this num_gpus: int | None = None visible_devices: str | None = None + json_output: bool = False def __post_init__(self): if not self.tokenizer_name_or_path: diff --git a/rl/llm/engines.py b/rl/llm/engines.py index f08aeb8..2e1713c 100644 --- a/rl/llm/engines.py +++ b/rl/llm/engines.py @@ -270,7 +270,7 @@ def generate(self, prompt: ChatInput) -> InferenceOutput: generation_config={ "temperature": self.llm_config.temperature, "max_output_tokens": self.llm_config.max_new_tokens, - "response_mime_type": "text/plain", + "response_mime_type": "application/json" if self.llm_config.json_output else "text/plain", }, system_instruction=system_message, safety_settings={