Add gemini engine, at significant cost to well being
ProbablyFaiz committed Jun 7, 2024
1 parent 4f18fd5 · commit a9f6649
Showing 2 changed files with 66 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -46,6 +46,7 @@ llm = [
    "modal",
    "anthropic",
    "openai",
    "google-generativeai",
]
dev = [
    "build",
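Since the new dependency lives in the optional llm extra, a quick way to confirm it resolves after installation is to exercise the same configure call the engine relies on (a minimal sketch; it assumes GEMINI_API_KEY is set in the environment and makes no network request):

import os

import google.generativeai as genai

# Mirrors what GeminiEngine.__enter__ does below: proves the package
# imports cleanly and accepts an API key, without calling the API.
genai.configure(api_key=os.environ["GEMINI_API_KEY"])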
65 changes: 65 additions & 0 deletions rl/llm/engines.py
@@ -15,6 +15,7 @@
from typing import Any, AsyncGenerator, Iterator, Union, cast

import click
import google.generativeai as genai
import huggingface_hub
import modal
import modal.runner
@@ -193,6 +194,7 @@ def __enter__(self):
        self.client = openai.Client(
            api_key=rl.utils.io.getenv(self.API_KEY_NAME), base_url=self.BASE_URL
        )
        return self

    def generate(self, prompt: ChatInput) -> InferenceOutput:
        """Given the input prompt, returns the generated text.
@@ -240,6 +242,68 @@ class GroqEngine(OpenAIClientEngine):
    API_KEY_NAME = "GROQ_API_KEY"


class GeminiEngine(InferenceEngine):
    NAME: str = "gemini"

    def __init__(self, llm_config: LLMConfig):
        super().__init__(llm_config)

    def __enter__(self):
        genai.configure(api_key=rl.utils.io.getenv("GEMINI_API_KEY"))
        # Return self so the engine works as a context manager, matching
        # OpenAIClientEngine.__enter__ above.
        return self

    def generate(self, prompt: ChatInput) -> InferenceOutput:
        if not isinstance(prompt, list):
            raise ValueError(
                "GeminiEngine requires a list of message dicts in the OpenAI chat format."
            )
        system_message, prev_messages, last_message = self._convert_openai_to_gemini(
            prompt
        )
        # One might reasonably ask, why not initialize the model in __enter__?
        # Well, I'll tell you: Google's moronic abstraction requires you to
        # pass the system instruction when *initializing* the model object,
        # because that makes sense.
        model = genai.GenerativeModel(
            model_name=self.llm_config.model_name_or_path,
            generation_config={
                "temperature": self.llm_config.temperature,
                "max_output_tokens": self.llm_config.max_new_tokens,
                "response_mime_type": "text/plain",
            },
            system_instruction=system_message,
        )
        chat_session = model.start_chat(
            history=prev_messages,
        )
        # Can't include the last message in the history, because
        # that would make too much sense!
        response = chat_session.send_message(last_message)

        return InferenceOutput(
            prompt=prompt,
            text=response.text,
            metadata={
                "model": self.llm_config.model_name_or_path,
            },
        )

    def _convert_openai_to_gemini(
        self, prompt: ChatInput
    ) -> tuple[str | None, list, str]:
        """Returns the system instruction, the prior messages in the Gemini
        format, and the content of the last message."""
        system_prompt = None
        if prompt and prompt[0]["role"] == "system":
            system_prompt = prompt[0]["content"]
            prompt = prompt[1:]
        last_message = prompt[-1]["content"]
        prompt = prompt[:-1]
        return (
            system_prompt,
            [
                # Gemini's history API expects "model" where the OpenAI
                # format uses "assistant".
                {
                    "role": "model" if msg["role"] == "assistant" else msg["role"],
                    "parts": [msg["content"]],
                }
                for msg in prompt
            ],
            last_message,
        )


class AnthropicEngine(ClientEngine):
    NAME = "anthropic"
    BASE_URL = "https://api.anthropic.com/v1"
@@ -749,6 +813,7 @@ def _wrap_output(self, req_output) -> InferenceOutput:
        GroqEngine,
        AnthropicEngine,
        ModalEngine,
        GeminiEngine,
    )
}
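To make the conversion contract concrete, here is roughly how _convert_openai_to_gemini splits a small OpenAI-style conversation (a worked sketch of the method above; resulting values shown in comments):

messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello."},
    {"role": "user", "content": "What is 2 + 2?"},
]
# The leading system turn is peeled off, the final user turn is held back
# for send_message, and the middle turns become Gemini-style history:
#   system_message -> "You are terse."
#   prev_messages  -> [{"role": "user", "parts": ["Hi"]},
#                      {"role": "model", "parts": ["Hello."]}]
#   last_message   -> "What is 2 + 2?"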

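And for completeness, a hedged sketch of driving the new engine end to end. LLMConfig's constructor is not shown in this diff, so the keyword arguments below are assumptions inferred from the attributes the engine reads (model_name_or_path, temperature, max_new_tokens); the model name and import path are likewise illustrative:

from rl.llm.engines import GeminiEngine, LLMConfig  # assumed import path

config = LLMConfig(
    model_name_or_path="gemini-1.5-flash",  # example model name
    temperature=0.2,
    max_new_tokens=256,
)
with GeminiEngine(config) as engine:  # __enter__ configures the API key
    output = engine.generate(
        [
            {"role": "system", "content": "Answer in one word."},
            {"role": "user", "content": "What is the capital of France?"},
        ]
    )
    print(output.text)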
