Skip to content

Commit

Permalink
Clean up and link up llm_inputs
Browse files Browse the repository at this point in the history
Clean up and link up llm_inputs
  • Loading branch information
debermudez committed Mar 4, 2024
1 parent 9e763d3 commit cd9e588
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 31 deletions.
8 changes: 8 additions & 0 deletions src/c++/perf_analyzer/genai-pa/genai_pa/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,11 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Name used for the package-wide logger (see logging.getLogger(LOGGER_NAME)).
LOGGER_NAME: str = "genai-pa"

# Default inference-endpoint URLs, one per supported protocol. parser.py uses
# DEFAULT_HTTP_URL as the -u default and swaps in DEFAULT_GRPC_URL when -i grpc
# is selected with the URL left at its default.
DEFAULT_HTTP_URL: str = "localhost:8000"
DEFAULT_GRPC_URL: str = "localhost:8001"


# Canonical (lower-case) names of the supported HuggingFace datasets; these are
# the keys of LlmInputs.dataset_url_map and the --dataset CLI choices.
OPEN_ORCA: str = "openorca"
CNN_DAILY_MAIL: str = "cnn_dailymail"
# Path of the generated LLM input-data file (LlmInputs.OUTPUT_FILENAME).
DEFAULT_INPUT_DATA_JSON: str = "./llm_inputs.json"
17 changes: 15 additions & 2 deletions src/c++/perf_analyzer/genai-pa/genai_pa/llm_inputs/llm_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Dict, List, Optional, Tuple

import requests
from genai_pa.constants import CNN_DAILY_MAIL, DEFAULT_INPUT_DATA_JSON, OPEN_ORCA
from genai_pa.exceptions import GenAiPAException
from requests import Response

Expand All @@ -26,7 +27,7 @@ class LlmInputs:
A library of methods that control the generation of LLM Inputs
"""

OUTPUT_FILENAME = "./llm_inputs.json"
OUTPUT_FILENAME = DEFAULT_INPUT_DATA_JSON

OPEN_ORCA_URL = "https://datasets-server.huggingface.co/rows?dataset=Open-Orca%2FOpenOrca&config=default&split=train"
CNN_DAILYMAIL_URL = "https://datasets-server.huggingface.co/rows?dataset=cnn_dailymail&config=1.0.0&split=train"
Expand All @@ -39,10 +40,12 @@ class LlmInputs:

EMPTY_JSON_IN_OPENAI_PA_FORMAT = {"data": [{"payload": []}]}

dataset_url_map = {OPEN_ORCA: OPEN_ORCA_URL, CNN_DAILY_MAIL: CNN_DAILYMAIL_URL}

@classmethod
def create_openai_llm_inputs(
cls,
url: str = OPEN_ORCA_URL,
url: str = OPEN_ORCA,
starting_index: int = DEFAULT_STARTING_INDEX,
length: int = DEFAULT_LENGTH,
model_name: str = None,
Expand All @@ -66,6 +69,7 @@ def create_openai_llm_inputs(
If true adds a steam field to each payload
"""

url = LlmInputs._resolve_url(url)
LlmInputs._check_for_valid_args(starting_index, length)
configured_url = LlmInputs._create_configured_url(url, starting_index, length)
dataset = LlmInputs._download_dataset(configured_url, starting_index, length)
Expand All @@ -77,6 +81,15 @@ def create_openai_llm_inputs(

return json_in_pa_format

@classmethod
def _resolve_url(cls, url: str) -> str:
    """
    Map a known dataset name to its download URL.

    If *url* is one of the known dataset keys (e.g. openorca,
    cnn_dailymail), return the corresponding HuggingFace dataset URL;
    otherwise assume *url* is already a URL and return it unchanged.
    """
    return LlmInputs.dataset_url_map.get(url, url)

@classmethod
def _check_for_valid_args(cls, starting_index: int, length: int) -> None:
try:
Expand Down
12 changes: 12 additions & 0 deletions src/c++/perf_analyzer/genai-pa/genai_pa/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,29 @@
from genai_pa import parser
from genai_pa.constants import LOGGER_NAME
from genai_pa.exceptions import GenAiPAException
from genai_pa.llm_inputs.llm_inputs import LlmInputs

logging.basicConfig(level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(LOGGER_NAME)


def generate_inputs(args):
    """Generate the LLM input-data JSON file for the dataset named in *args*."""
    request = (
        args.dataset,
        LlmInputs.DEFAULT_STARTING_INDEX,
        LlmInputs.DEFAULT_LENGTH,
        args.model,
        args.streaming,
    )
    LlmInputs.create_openai_llm_inputs(*request)


# Separate function that can raise exceptions used for testing
# to assert correct errors and messages.
# Optional argv used for testing - will default to sys.argv if None.
def run(argv=None):
try:
args = parser.parse_args(argv)
generate_inputs(args)
args.func(args)
except Exception as e:
raise GenAiPAException(e)
Expand Down
77 changes: 51 additions & 26 deletions src/c++/perf_analyzer/genai-pa/genai_pa/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,18 @@
import logging
from pathlib import Path

from genai_pa.constants import LOGGER_NAME
from genai_pa.constants import (
CNN_DAILY_MAIL,
DEFAULT_GRPC_URL,
DEFAULT_HTTP_URL,
LOGGER_NAME,
OPEN_ORCA,
)

logger = logging.getLogger(LOGGER_NAME)


def prune_args(args: argparse.ArgumentParser) -> argparse.ArgumentParser:
def _prune_args(args: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""
Prune the parsed arguments to remove args with None or False values.
"""
Expand All @@ -42,7 +48,7 @@ def prune_args(args: argparse.ArgumentParser) -> argparse.ArgumentParser:
)


def update_load_manager_args(args: argparse.ArgumentParser) -> argparse.ArgumentParser:
def _update_load_manager_args(args: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""
Update GenAI-PA load manager attributes to PA format
"""
Expand All @@ -54,10 +60,20 @@ def update_load_manager_args(args: argparse.ArgumentParser) -> argparse.Argument
return args


def _verify_valid_arg_combination(
    args: argparse.Namespace,
) -> argparse.Namespace:
    """
    Reconcile mutually dependent arguments after parsing.

    If the protocol (-i) is grpc but the URL (-u) was left at the HTTP
    default, switch the URL to the gRPC default so the two agree.
    Returns the (possibly updated) parsed-args namespace.
    """
    # Plain attribute access: "i" and "u" are fixed dests set by
    # _add_endpoint_args, so getattr/setattr indirection is unnecessary.
    if args.i == "grpc" and args.u == DEFAULT_HTTP_URL:
        args.u = DEFAULT_GRPC_URL

    return args


### Handlers ###


# NOTE: Placeholder
def handler(args):
from genai_pa.wrapper import Profiler

Expand All @@ -67,7 +83,7 @@ def handler(args):
### Parsers ###


def add_model_args(parser):
def _add_model_args(parser):
model_group = parser.add_argument_group("Model")

model_group.add_argument(
Expand All @@ -79,7 +95,7 @@ def add_model_args(parser):
)


def add_profile_args(parser):
def _add_profile_args(parser):
profile_group = parser.add_argument_group("Profiling")
load_management_group = profile_group.add_mutually_exclusive_group()

Expand Down Expand Up @@ -152,34 +168,42 @@ def add_profile_args(parser):
)


def add_endpoint_args(parser):
def _add_endpoint_args(parser):
    """Add the endpoint argument group (protocol and URL) to *parser*."""
    endpoint_group = parser.add_argument_group("Endpoint")

    endpoint_group.add_argument(
        "-i",
        type=str.lower,
        choices=["http", "grpc"],
        default="http",
        required=False,
        # Plain string: the original was an f-string with no placeholders.
        help="Sets the protocol used to communicate with the inference service",
    )

    endpoint_group.add_argument(
        "-u",
        "--url",
        type=str,
        # HTTP default; _verify_valid_arg_combination swaps in the gRPC
        # default when -i grpc is chosen and -u is left untouched.
        default=DEFAULT_HTTP_URL,
        required=False,
        dest="u",
        metavar="URL",
        help="URL of the endpoint to target for benchmarking.",
    )


def add_dataset_args(parser):
pass

def _add_dataset_args(parser):
dataset_group = parser.add_argument_group("Dataset")
# TODO: Do we want to remove dataset and tokenizer?
# dataset_group.add_argument(
# "--dataset",
# type=str,
# default="OpenOrca",
# choices=["OpenOrca", "cnn_dailymail"],
# required=False,
# help="HuggingFace dataset to use for the benchmark.",
# )

dataset_group.add_argument(
"--dataset",
type=str.lower,
default=OPEN_ORCA,
choices=[OPEN_ORCA, CNN_DAILY_MAIL],
required=False,
help="HuggingFace dataset to use for benchmarking.",
)

# dataset_group.add_argument(
# "--tokenizer",
# type=str,
Expand All @@ -202,14 +226,15 @@ def parse_args(argv=None):
parser.set_defaults(func=handler)

# Conceptually group args for easier visualization
add_model_args(parser)
add_profile_args(parser)
add_endpoint_args(parser)
add_dataset_args(parser)
_add_model_args(parser)
_add_profile_args(parser)
_add_endpoint_args(parser)
_add_dataset_args(parser)

args = parser.parse_args(argv)

args = update_load_manager_args(args)
args = prune_args(args)
args = _update_load_manager_args(args)
args = _verify_valid_arg_combination(args)
args = _prune_args(args)

return args
15 changes: 12 additions & 3 deletions src/c++/perf_analyzer/genai-pa/genai_pa/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@

class Profiler:
@staticmethod
def run(model, args=None):
skip_args = ["model", "func"]
def build_cmd(model, args):
skip_args = ["model", "func", "dataset"]
if hasattr(args, "version"):
cmd = f"perf_analyzer --version"
else:
Expand All @@ -52,9 +52,18 @@ def run(model, args=None):
cmd += f"-b {value} "
else:
if len(arg) == 1:
cmd += f"-{arg} {value}"
cmd += f"-{arg} {value} "
else:
arg = utils.convert_option_name(arg)
cmd += f"--{arg} {value} "
# TODO: Once the OpenAI endpoint support is in place in PA core,
# update the input-data option arg
# cmd += f"--input-data {DEFAULT_INPUT_DATA_JSON} -p 10000 -s 99"
cmd += f"--input-data ./input_data.json -p 10000 -s 99"
return cmd

@staticmethod
def run(model, args=None):
    """Build the perf_analyzer command line for *model* and execute it.

    Raises subprocess.CalledProcessError if perf_analyzer exits
    non-zero (check=True).
    """
    cmd = Profiler.build_cmd(model, args)
    logger.info(f"Running Perf Analyzer : '{cmd}'")
    # NOTE(review): shell=True with an interpolated command string — arg
    # values reach the shell unescaped. Confirm all inputs are trusted,
    # or switch to subprocess.run([...], shell=False) with an argv list.
    subprocess.run(cmd, shell=True, check=True)

0 comments on commit cd9e588

Please sign in to comment.