From a9f66491945445e190448c142d141009de80271d Mon Sep 17 00:00:00 2001
From: Faiz Surani
Date: Thu, 6 Jun 2024 17:34:08 -0700
Subject: [PATCH] Add Gemini engine, at significant cost to well-being

---
 pyproject.toml    |  1 +
 rl/llm/engines.py | 73 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 74 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index c032a11..02b6a05 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ llm = [
     "modal",
     "anthropic",
     "openai",
+    "google-generativeai",
 ]
 dev = [
     "build",
diff --git a/rl/llm/engines.py b/rl/llm/engines.py
index 0f3a7b9..ba4babb 100644
--- a/rl/llm/engines.py
+++ b/rl/llm/engines.py
@@ -15,6 +15,7 @@ from typing import Any, AsyncGenerator, Iterator, Union, cast
 
 import click
+import google.generativeai as genai
 import huggingface_hub
 import modal
 import modal.runner
@@ -193,6 +194,7 @@ def __enter__(self):
         self.client = openai.Client(
             api_key=rl.utils.io.getenv(self.API_KEY_NAME), base_url=self.BASE_URL
         )
+        return self
 
     def generate(self, prompt: ChatInput) -> InferenceOutput:
         """Given the input prompt, returns the generated text.
@@ -240,6 +242,76 @@ class GroqEngine(OpenAIClientEngine):
     API_KEY_NAME = "GROQ_API_KEY"
 
 
+class GeminiEngine(InferenceEngine):
+    NAME: str = "gemini"
+
+    def __init__(self, llm_config: LLMConfig):
+        super().__init__(llm_config)
+
+    def __enter__(self):
+        genai.configure(api_key=rl.utils.io.getenv("GEMINI_API_KEY"))
+        return self
+
+    def generate(self, prompt: ChatInput) -> InferenceOutput:
+        if not isinstance(prompt, list):
+            raise ValueError(
+                "GeminiEngine requires a list of dicts, in the OpenAI API style."
+            )
+        system_message, prev_messages, last_message = self._convert_openai_to_gemini(
+            prompt
+        )
+        # One might reasonably ask, why not initialize the model in __enter__?
+        # Well, I'll tell you: Google's moronic abstraction requires you to
+        # pass the system instruction when *initializing* the model object,
+        # because that makes sense.
+        model = genai.GenerativeModel(
+            model_name=self.llm_config.model_name_or_path,
+            generation_config={
+                "temperature": self.llm_config.temperature,
+                "max_output_tokens": self.llm_config.max_new_tokens,
+                "response_mime_type": "text/plain",
+            },
+            system_instruction=system_message,
+        )
+        chat_session = model.start_chat(
+            history=prev_messages,
+        )
+        # Can't include the last message in the history, because
+        # that would make too much sense!
+        response = chat_session.send_message(last_message)
+
+        return InferenceOutput(
+            prompt=prompt,
+            text=response.text,
+            metadata={
+                "model": self.llm_config.model_name_or_path,
+            },
+        )
+
+    def _convert_openai_to_gemini(
+        self, prompt: ChatInput
+    ) -> tuple[str | None, list, str]:
+        """Returns the system instruction, the previous messages, and the last message in the Gemini format."""
+        system_prompt = None
+        if prompt and prompt[0]["role"] == "system":
+            system_prompt = prompt[0]["content"]
+            prompt = prompt[1:]
+        last_message = prompt[-1]["content"]
+        prompt = prompt[:-1]
+        return (
+            system_prompt,
+            [
+                # Gemini uses the role "model" where OpenAI uses "assistant".
+                {
+                    "role": "model" if msg["role"] == "assistant" else msg["role"],
+                    "parts": [msg["content"]],
+                }
+                for msg in prompt
+            ],
+            last_message,
+        )
+
+
 class AnthropicEngine(ClientEngine):
     NAME = "anthropic"
     BASE_URL = "https://api.anthropic.com/v1"
@@ -749,6 +821,7 @@ def _wrap_output(self, req_output) -> InferenceOutput:
         GroqEngine,
         AnthropicEngine,
         ModalEngine,
+        GeminiEngine,
     )
 }
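
For reviewers, a minimal usage sketch of the new engine (not part of the
patch). It assumes LLMConfig is importable from rl.llm.engines and accepts
these fields as keyword arguments, that "gemini-1.5-pro" is a valid model
name for the account, and that GEMINI_API_KEY is set in the environment.

    from rl.llm.engines import GeminiEngine, LLMConfig

    config = LLMConfig(
        model_name_or_path="gemini-1.5-pro",  # assumed model id
        temperature=0.2,
        max_new_tokens=512,
    )
    # __enter__ configures the API key and returns the engine itself.
    with GeminiEngine(config) as engine:
        output = engine.generate(
            [
                {"role": "system", "content": "Answer in one word."},
                {"role": "user", "content": "What color is the sky?"},
            ]
        )
        print(output.text)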
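
To make the history/last-message split concrete, here is what
_convert_openai_to_gemini produces for a short multi-turn conversation.
This is a hand-traced sketch following the code above, not the output of
a real run; the messages are illustrative.

    messages = [
        {"role": "system", "content": "Be terse."},
        {"role": "user", "content": "Hi."},
        {"role": "assistant", "content": "Hello."},
        {"role": "user", "content": "Name a prime number."},
    ]
    system, history, last = engine._convert_openai_to_gemini(messages)
    # system  == "Be terse."
    # history == [
    #     {"role": "user", "parts": ["Hi."]},
    #     {"role": "model", "parts": ["Hello."]},  # "assistant" mapped to "model"
    # ]
    # last    == "Name a prime number."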