Add rankings support (#724)
* Initial support for rankings api

Add file or directory support for rankings

Update testing

* Remove commented code

* Address feedback

* Add llm_inputs ranking tests

* Update comments to include rankings in batch-size documentation

* Update error messages
debermudez authored Jun 28, 2024
1 parent 90c60a6 commit 2d7cce0
Showing 5 changed files with 381 additions and 28 deletions.
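As a quick orientation: the change threads a new rankings endpoint type through input generation (llm_inputs.py), the CLI (parser.py, main.py), and the request converter. A rough invocation sketch follows; the model name and the -m and --service-kind flags are assumed from existing GenAI-Perf usage rather than taken from this diff:

    genai-perf -m my-ranking-model --service-kind openai --endpoint-type rankings --input-file rankings_data/ --batch-size 10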
125 changes: 113 additions & 12 deletions src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py
@@ -42,6 +42,7 @@ class OutputFormat(Enum):
    OPENAI_CHAT_COMPLETIONS = auto()
    OPENAI_COMPLETIONS = auto()
    OPENAI_EMBEDDINGS = auto()
    RANKINGS = auto()
    TENSORRTLLM = auto()
    VLLM = auto()

@@ -243,21 +244,39 @@ def get_generic_dataset_json(
        Dict:
            The generic dataset JSON
        """
        if output_format == OutputFormat.OPENAI_EMBEDDINGS:
            if input_type != PromptSource.FILE:
                raise GenAIPerfException(
                    f"{OutputFormat.OPENAI_EMBEDDINGS.to_lowercase()} only supports a file as input."
                )
            input_filename = cast(Path, input_filename)
            input_file_dataset = cls._get_input_dataset_from_embeddings_file(
                input_filename,
                batch_size,
                num_of_output_prompts,
            )
            generic_dataset_json = (
                cls._convert_input_synthetic_or_file_dataset_to_generic_json(
                    input_file_dataset
                )
            )
        elif output_format == OutputFormat.RANKINGS:
            if input_type != PromptSource.FILE:
                raise GenAIPerfException(
                    f"{OutputFormat.RANKINGS.to_lowercase()} only supports a directory as input."
                )
            queries_filename = cast(Path, input_filename) / "queries.jsonl"
            passages_filename = cast(Path, input_filename) / "passages.jsonl"
            input_file_dataset = cls._get_input_dataset_from_rankings_files(
                queries_filename, passages_filename, batch_size, num_of_output_prompts
            )
            generic_dataset_json = (
                cls._convert_input_synthetic_or_file_dataset_to_generic_json(
                    input_file_dataset
                )
            )
        else:
            if input_type == PromptSource.DATASET:
                dataset = cls._get_input_dataset_from_url(
@@ -315,6 +334,41 @@ def _get_input_dataset_from_embeddings_file(

        return dataset_json

    @classmethod
    def _get_input_dataset_from_rankings_files(
        cls,
        queries_filename: Path,
        passages_filename: Path,
        batch_size: int,
        num_prompts: int,
    ) -> Dict[str, Any]:

        with open(queries_filename, "r") as file:
            queries_content = [json.loads(line) for line in file]
            queries_texts = [item for item in queries_content]

        with open(passages_filename, "r") as file:
            passages_content = [json.loads(line) for line in file]
            passages_texts = [item for item in passages_content]

        if batch_size > len(passages_texts):
            raise ValueError(
                "Batch size cannot be larger than the number of available passages"
            )

        dataset_json: Dict[str, Any] = {}
        dataset_json["features"] = [{"name": "input"}]
        dataset_json["rows"] = []

        for _ in range(num_prompts):
            sampled_texts = random.sample(passages_texts, batch_size)
            query_sample = random.choice(queries_texts)
            entry_dict = {}
            entry_dict["query"] = query_sample
            entry_dict["passages"] = sampled_texts
            dataset_json["rows"].append({"row": {"payload": entry_dict}})

        return dataset_json
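For context, a sketch of one generated row, assuming batch_size is 2 and that each JSONL line holds a JSON object; the text_input field name below is hypothetical, since this helper forwards whatever objects the files contain without inspecting them:

    # Illustrative shape of a single entry appended to dataset_json["rows"]
    example_row = {
        "row": {
            "payload": {
                "query": {"text_input": "what is the capital of france"},
                "passages": [
                    {"text_input": "Paris is the capital of France."},
                    {"text_input": "Berlin is the capital of Germany."},
                ],
            }
        }
    }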

    @classmethod
    def _check_for_valid_args(
        cls,
@@ -535,6 +589,13 @@ def _convert_generic_json_to_output_format(
                model_name,
                model_selection_strategy,
            )
        elif output_format == OutputFormat.RANKINGS:
            output_json = cls._convert_generic_json_to_rankings_format(
                generic_dataset,
                extra_inputs,
                model_name,
                model_selection_strategy,
            )
        elif output_format == OutputFormat.VLLM:
            output_json = cls._convert_generic_json_to_vllm_format(
                generic_dataset,
@@ -672,6 +733,46 @@ def _convert_generic_json_to_openai_embeddings_format(

        return pa_json

    @classmethod
    def _convert_generic_json_to_rankings_format(
        cls,
        generic_dataset: Dict,
        extra_inputs: Dict,
        model_name: list = [],
        model_selection_strategy: ModelSelectionStrategy = ModelSelectionStrategy.ROUND_ROBIN,
    ) -> Dict[str, Any]:
        pa_json: Dict[str, Any] = {"data": []}

        for index, entry in enumerate(generic_dataset["rows"]):
            iter_model_name = cls._select_model_name(
                model_name, index, model_selection_strategy
            )
            payload = entry.get("payload", {})
            query_values = payload.get("query")
            passage_values = payload.get("passages")

            if query_values is None:
                raise ValueError("Missing required field 'query' in dataset entry")
            if passage_values is None:
                raise ValueError("Missing required field 'passages' in dataset entry")
            if not isinstance(passage_values, list):
                raise ValueError(
                    f"Required field 'passages' must be a list (actual: {type(passage_values)})"
                )

            payload = {
                "query": query_values,
                "passages": passage_values,
                "model": iter_model_name,
            }

            for key, value in extra_inputs.items():
                payload[key] = value

            pa_json["data"].append({"payload": [payload]})

        return pa_json
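A sketch of the request payload this converter would emit for one such row, assuming a single model named my-ranking-model and no extra inputs (both illustrative; any extra inputs configured on the command line are merged in as additional top-level keys):

    # Hypothetical output of _convert_generic_json_to_rankings_format for one row
    example_pa_json = {
        "data": [
            {
                "payload": [
                    {
                        "query": {"text_input": "what is the capital of france"},
                        "passages": [
                            {"text_input": "Paris is the capital of France."},
                            {"text_input": "Berlin is the capital of Germany."},
                        ],
                        "model": "my-ranking-model",
                    }
                ]
            }
        ]
    }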

    @classmethod
    def _convert_generic_json_to_vllm_format(
        cls,
3 changes: 2 additions & 1 deletion src/c++/perf_analyzer/genai-perf/genai_perf/main.py
@@ -51,7 +51,8 @@ def create_artifacts_dirs(args: Namespace) -> None:

def generate_inputs(args: Namespace, tokenizer: Tokenizer) -> None:
    # TODO (TMA-1759): review if add_model_name is always true
    filepath, _ = args.input_file if args.input_file else (None, None)
    input_filename = Path(filepath) if filepath else None
    add_model_name = True
    try:
        extra_input_dict = parser.get_extra_inputs_as_dict(args)
70 changes: 60 additions & 10 deletions src/c++/perf_analyzer/genai-perf/genai_perf/parser.py
@@ -28,7 +28,9 @@
import json
import os
import sys
from enum import Enum, auto
from pathlib import Path
from typing import Tuple

import genai_perf.logging as logging
import genai_perf.utils as utils
@@ -50,12 +52,22 @@

from . import __version__


class PathType(Enum):
    FILE = auto()
    DIRECTORY = auto()

    def to_lowercase(self):
        return self.name.lower()


logger = logging.getLogger(__name__)

_endpoint_type_map = {
    "chat": "v1/chat/completions",
    "completions": "v1/completions",
    "embeddings": "v1/embeddings",
    "rankings": "v1/ranking",
}


@@ -116,6 +128,8 @@ def _check_conditional_args(
        args.output_format = OutputFormat.OPENAI_COMPLETIONS
    elif args.endpoint_type == "embeddings":
        args.output_format = OutputFormat.OPENAI_EMBEDDINGS
    elif args.endpoint_type == "rankings":
        args.output_format = OutputFormat.RANKINGS

    if args.endpoint is not None:
        args.endpoint = args.endpoint.lstrip(" /")
@@ -147,25 +161,42 @@ def _check_conditional_args(
"The --output-tokens-mean-deterministic option is only supported with the Triton service-kind."
)

_check_conditional_args_embeddings(parser, args)
_check_conditional_args_embeddings_rankings(parser, args)

return args


def _check_conditional_args_embeddings_rankings(
    parser: argparse.ArgumentParser, args: argparse.Namespace
):
    if args.output_format in [
        OutputFormat.OPENAI_EMBEDDINGS,
        OutputFormat.RANKINGS,
    ]:
        if args.streaming:
            parser.error(
                f"The --streaming option is not supported with the {args.endpoint_type} endpoint type."
            )
    else:
        if args.batch_size != LlmInputs.DEFAULT_BATCH_SIZE:
            parser.error(
                "The --batch-size option is currently only supported with the embeddings and rankings endpoint types."
            )

    if args.input_file:
        _, path_type = args.input_file
        if args.output_format != OutputFormat.RANKINGS:
            if path_type == PathType.DIRECTORY:
                parser.error(
                    "A directory is currently only supported for the rankings endpoint type."
                )
        else:
            if path_type == PathType.FILE:
                parser.error(
                    "The rankings endpoint-type requires a directory value for the --input-file flag."
                )


def _check_load_manager_args(args: argparse.Namespace) -> argparse.Namespace:
"""
@@ -224,7 +255,12 @@ def _infer_prompt_source(args: argparse.Namespace) -> argparse.Namespace:
logger.debug(f"Input source is the following dataset: {args.input_dataset}")
elif args.input_file:
args.prompt_source = PromptSource.FILE
logger.debug(f"Input source is the following file: {args.input_file.name}")
if args.endpoint_type == "rankings":
logger.debug(
f"Input source is the following directory: {args.input_file[0]}"
)
else:
logger.debug(f"Input source is the following file: {args.input_file[0]}")
else:
args.prompt_source = PromptSource.SYNTHETIC
logger.debug("Input source is synthetic data")
@@ -241,6 +277,18 @@ def _convert_str_to_enum_entry(args, option, enum):
    return args


### Types ###


def file_or_directory(path: str) -> Tuple[Path, PathType]:
    if os.path.isfile(path):
        return (Path(path), PathType.FILE)
    elif os.path.isdir(path):
        return (Path(path), PathType.DIRECTORY)
    else:
        raise ValueError(f"'{path}' is not a valid file or directory")
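A minimal standalone sketch of how the new argparse type behaves, using the PathType enum and file_or_directory helper defined above; the ./rankings_data path is hypothetical and must exist on disk for the call to succeed:

    import argparse

    arg_parser = argparse.ArgumentParser()
    # file_or_directory stores a (Path, PathType) tuple instead of an open file
    # handle, which is why downstream code unpacks args.input_file.
    arg_parser.add_argument("--input-file", type=file_or_directory, default=None)

    args = arg_parser.parse_args(["--input-file", "./rankings_data"])
    filepath, path_type = args.input_file
    if path_type == PathType.DIRECTORY:
        print(f"Directory input: {filepath}")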


### Parsers ###


@@ -254,7 +302,7 @@ def _add_input_args(parser):
        default=LlmInputs.DEFAULT_BATCH_SIZE,
        required=False,
        help=f"The batch size of the requests GenAI-Perf should send. "
        "This is currently only supported with the embeddings and rankings endpoint types.",
    )

    input_group.add_argument(
Expand All @@ -277,12 +325,14 @@ def _add_input_args(parser):

    prompt_source_group.add_argument(
        "--input-file",
        type=file_or_directory,
        default=None,
        required=False,
        help="The input file containing the prompts to use for profiling. "
        "Each line should be a JSON object with a 'text_input' field in JSONL format. "
        'Example: {"text_input": "Your prompt here"}. '
        "For the rankings endpoint-type, a directory should be passed in instead with "
        'a "queries.jsonl" file and a "passages.jsonl" file with the same format.',
    )

    input_group.add_argument(
@@ -437,7 +487,7 @@ def _add_endpoint_args(parser):
    endpoint_group.add_argument(
        "--endpoint-type",
        type=str,
        choices=["chat", "completions", "embeddings", "rankings"],
        required=False,
        help=f"The endpoint-type to send requests to on the "
        'server. This is only used with the "openai" service-kind.',
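Putting the pieces together, the value passed to --input-file for the rankings endpoint type is a directory laid out roughly as follows. This is a sketch based on the help text above; the text_input field name mirrors that help text, but any consistent per-line JSON object should work, since the loader forwards each parsed line unchanged:

    rankings_data/
        queries.jsonl      # one JSON object per line, e.g. {"text_input": "what is the capital of france"}
        passages.jsonl     # one JSON object per line, e.g. {"text_input": "Paris is the capital of France."}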