Make progress on modal engine
ProbablyFaiz committed Jun 4, 2024
1 parent 0267bbc commit e34fbea
Showing 5 changed files with 414 additions and 33 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -79,4 +79,5 @@ skip_gitignore = true
disable_error_code = ["import", "override"]

[tool.ruff.lint]
extend-select = ["I"]
extend-ignore = ["F401"]
124 changes: 91 additions & 33 deletions rl/llm/engines.py
@@ -1,26 +1,34 @@
import datetime
import hashlib
import json
import math
import os
import socket
import subprocess
import sys
import tempfile
import textwrap as tw
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from subprocess import call
from typing import Any, AsyncGenerator, Iterator, Union, cast
import textwrap as tw

import click
import huggingface_hub
import modal
import modal.runner
import openai
import torch
import tqdm.asyncio
from openai import OpenAI
from anthropic import Anthropic
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
from transformers import AutoTokenizer, PreTrainedTokenizer

import rl.llm.modal_utils
import rl.utils.io
from rl.llm.config import LLMConfig
from rl.utils import LOGGER
@@ -39,7 +47,9 @@
)
from vllm.lora.request import LoRARequest

InferenceInput = Union[str, ChatCompletionMessageParam]

ChatInput = list[ChatCompletionMessageParam]
InferenceInput = Union[str, ChatInput]
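
For orientation, a minimal sketch of the two shapes an InferenceInput can now take; the message contents below are illustrative, and the dict format follows OpenAI's chat-completion messages:

    # Hypothetical example values for the two InferenceInput shapes.
    plain_prompt = "Summarize this changelog in one sentence."
    chat_prompt = [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Summarize this changelog in one sentence."},
    ]
    # plain_prompt is a str; chat_prompt is a ChatInput (a list of message dicts).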


@dataclass(frozen=True)
@@ -93,8 +103,14 @@ def batch_generate(self, prompts: list[InferenceInput]) -> list[InferenceOutput]
return [self.generate(prompt) for prompt in prompts]


_RESPONSE_CANARY = "### Response template begins now, delete this line. ###"


class ManualEditEngine(InferenceEngine):
NAME = "human input"
NAME = "manual_edit"
_EDITOR = os.environ.get("EDITOR", "vim")

tokenizer: PreTrainedTokenizer

def __init__(
self, llm_config: LLMConfig | None = None, response_template: str = ""
@@ -105,14 +121,8 @@ def __init__(
def generate(
self, prompt: InferenceInput, wrap_prompt: bool = True
) -> InferenceOutput:
"""Open a temp file, and put the prompt in there. Then open the file in EDITOR, and wait for the user to write the response. make any necessary imports in the mehtod"""
import sys
import tempfile
import os
import datetime
from subprocess import call

EDITOR = os.environ.get("EDITOR", "vim")
"""Open a temp file, and put the prompt in there. Then open the file in EDITOR,
and wait for the user to write the response. make any necessary imports in the method."""

if not isinstance(prompt, str):
if not hasattr(self, "tokenizer"):
@@ -131,27 +141,26 @@ def generate(
with tempfile.NamedTemporaryFile(suffix=".tmp") as tf:
tf.write(prompt.encode())
if self.response_template:
tf.write(
f"\n### response template begins now, delete this line ###\n{self.response_template}".encode()
)
tf.write(f"\n{_RESPONSE_CANARY}\n{self.response_template}".encode())
tf.flush()
call([EDITOR, tf.name])
call([self._EDITOR, tf.name])
tf.seek(0)
edited_message = tf.read().decode()
if not edited_message.startswith(prompt):
raise ValueError(
"The prompt has been modified. Please do not modify the prompt."
"The prompt has been modified. Please do not modify the prompt. "
"Future attempts to do so will lead to disciplinary action."
)
response = edited_message[len(prompt) :].strip()
if not response:
raise ValueError("The response is empty or unsaved.")
if "### response template begins now, delete this line ###" in response:
if _RESPONSE_CANARY in response:
raise ValueError("The response template marker has not been deleted.")
return InferenceOutput(
prompt=prompt,
text=response,
metadata={
"model": "manual edit",
"model": self.NAME,
"created_at": datetime.datetime.now().isoformat(),
},
)
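
A minimal usage sketch of the renamed engine, assuming it is imported from this module (rl.llm.engines) and that a plain string prompt skips the tokenizer chat-template path shown above; the prompt text is illustrative:

    from rl.llm.engines import ManualEditEngine

    engine = ManualEditEngine(response_template="Answer:")
    # Blocks until $EDITOR (vim by default) exits; the user's text below the
    # canary line becomes the InferenceOutput.
    output = engine.generate("Q: What is 2 + 2?\n")
    print(output.text, output.metadata["created_at"])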
@@ -168,9 +177,7 @@ def __init__(self, llm_config: LLMConfig):
self.llm_config = llm_config

@abstractmethod
def generate(
self, prompt: openai.types.chat.ChatCompletionMessageParam
) -> InferenceOutput:
def generate(self, prompt: ChatInput) -> InferenceOutput:
pass


@@ -179,16 +186,17 @@ class OpenAIClientEngine(InferenceEngine, ABC):
BASE_URL: str
API_KEY_NAME: str
llm_config: LLMConfig
client: openai.Client

def __init__(self, llm_config: LLMConfig):
super().__init__(llm_config)
self.client = openai.OpenAI(

def __enter__(self):
self.client = openai.Client(
api_key=rl.utils.io.getenv(self.API_KEY_NAME), base_url=self.BASE_URL
)

def generate(
self, prompt: openai.types.chat.ChatCompletionMessageParam
) -> InferenceOutput:
def generate(self, prompt: ChatInput) -> InferenceOutput:
"""Given the input prompt, returns the generated text.
Args:
@@ -228,20 +236,24 @@ class OpenAIEngine(OpenAIClientEngine):
API_KEY_NAME = "OPENAI_API_KEY"


class GroqEngine(OpenAIClientEngine):
NAME = "groq"
BASE_URL = "https://api.groq.com/openai/v1"
API_KEY_NAME = "GROQ_API_KEY"


class AnthropicEngine(ClientEngine):
NAME = "anthropic"
BASE_URL = "https://api.anthropic.com/v1"
API_KEY_NAME = "ANTHROPIC_API_KEY"

def __init__(self, llm_config: LLMConfig):
super().__init__(llm_config)

def __enter__(self):
self.client = Anthropic(api_key=rl.utils.io.getenv(self.API_KEY_NAME))

def generate(
self,
prompt: openai.types.chat.ChatCompletionMessageParam,
max_tokens: int = 1024,
) -> InferenceOutput:
def generate(self, prompt: ChatInput) -> InferenceOutput:
"""Given the input prompt, returns the generated text.
Args:
Expand All @@ -262,9 +274,9 @@ def generate(

message = self.client.messages.create(
model=self.llm_config.model_name_or_path,
messages=prompt,
system=system_prompt,
max_tokens=max_tokens,
messages=prompt,
max_tokens=self.llm_config.max_new_tokens,
)
return InferenceOutput(
prompt=prompt, # type: ignore
@@ -276,6 +288,52 @@ def generate(
)
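
The construction of system_prompt is elided in the hunk above; a rough, purely illustrative sketch (not from this commit) of how a ChatInput might be split into the separate system string and messages list that Anthropic's messages.create expects:

    def split_system_prompt(prompt: list[dict]) -> tuple[str, list[dict]]:
        # Illustrative helper: Anthropic takes the system prompt as a separate
        # argument rather than as a "system" role message in the list.
        if prompt and prompt[0].get("role") == "system":
            return prompt[0]["content"], prompt[1:]
        return "", prompt

    system_prompt, messages = split_system_prompt(
        [
            {"role": "system", "content": "Answer in one word."},
            {"role": "user", "content": "What color is the sky?"},
        ]
    )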


class ModalEngine(OpenAIClientEngine):
NAME = "modal"
app_name: str

def __init__(self, llm_config: LLMConfig):
super().__init__(llm_config)
self.app_name = self._get_modal_app_name()

def __enter__(self):
if self.llm_config.num_gpus is None:
LOGGER.warning(
"num_gpus is not set. Will deploy to Modal with 1 A100 80GB GPU."
)
deployed_id = rl.llm.modal_utils.get_deployed_id(self.app_name)
if deployed_id is None:
LOGGER.info(f"No deployed app found for {self.app_name}. Deploying...")
deploy_config = {
"app_name": self.app_name,
"model_name_or_path": self.llm_config.model_name_or_path,
"num_gpus": self.llm_config.num_gpus,
"vllm_kwargs": _get_vllm_kwargs(self.llm_config),
}
deploy_env = {"MODAL_DEPLOY_CONFIG": json.dumps(deploy_config)}
print(deploy_env)
entrypoint_path = str(Path(__file__).parent / "modal_entrypoint.py")
subprocess.run(
[
sys.executable,
"-m",
"modal",
"deploy",
entrypoint_path,
],
check=True,
env=deploy_env,
)

def _get_modal_app_name(self):
vllm_kwargs = _get_vllm_kwargs(self.llm_config)
vllm_kwargs["enforce_eager"] = False
app_name = f"vllm_{self.llm_config.model_name_or_path}_"
# Add the hash of the kwargs to the app name
app_name += hashlib.md5(json.dumps(vllm_kwargs).encode()).hexdigest()[:16]
return app_name.replace("/", "_")
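
A small worked example of the deterministic app-name scheme above: the same model and vLLM kwargs always hash to the same Modal app name, which is what lets __enter__ look up an existing deployment instead of redeploying. The model id and kwargs here are placeholders:

    import hashlib
    import json

    model = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder
    vllm_kwargs = {"max_model_len": 8192, "enforce_eager": False}  # placeholder

    app_name = f"vllm_{model}_" + hashlib.md5(
        json.dumps(vllm_kwargs).encode()
    ).hexdigest()[:16]
    print(app_name.replace("/", "_"))
    # -> vllm_meta-llama_Meta-Llama-3-8B-Instruct_<first 16 hex chars of the md5>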


class AsyncInferenceEngine:
NAME: str
llm_config: LLMConfig
