diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py index de528aac4..6fcd9372b 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/llm_inputs/llm_inputs.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import json import random from copy import deepcopy from enum import Enum, auto +from io import BytesIO from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, cast @@ -24,9 +26,53 @@ from genai_perf.exceptions import GenAIPerfException from genai_perf.llm_inputs.synthetic_prompt_generator import SyntheticPromptGenerator from genai_perf.tokenizer import DEFAULT_TOKENIZER, Tokenizer, get_tokenizer +from PIL import Image, ImageDraw from requests import Response +# (TMA-1984) Remove the dummy image input with random noise image +def make_snowman_image(): + # Create a blank image with white background + img = Image.new("RGB", (600, 800), color="skyblue") + d = ImageDraw.Draw(img) + + # Draw the snowman's body (three circles) + body_color = "white" + d.ellipse([200, 500, 400, 700], fill=body_color, outline="black") # Bottom circle + d.ellipse([225, 350, 375, 550], fill=body_color, outline="black") # Middle circle + d.ellipse([250, 200, 350, 400], fill=body_color, outline="black") # Head circle + + # Draw the snowman's eyes + eye_color = "black" + d.ellipse([275, 250, 285, 260], fill=eye_color) # Left eye + d.ellipse([315, 250, 325, 260], fill=eye_color) # Right eye + + # Draw the snowman's nose (carrot) + nose_color = "orange" + d.polygon([(300, 270), (300, 280), (340, 275)], fill=nose_color) # Nose + + # Draw the snowman's mouth (smile) + mouth_color = "black" + d.arc([275, 290, 325, 310], start=0, end=180, fill=mouth_color) # Smile + + # Draw the snowman's buttons + d.ellipse([290, 420, 310, 440], fill=eye_color) # Top button + d.ellipse([290, 460, 310, 480], fill=eye_color) # Middle button + d.ellipse([290, 500, 310, 520], fill=eye_color) # Bottom button + + # Draw the snowman's arms + arm_color = "brown" + d.line([225, 450, 150, 400], fill=arm_color, width=5) # Left arm + d.line([375, 450, 450, 400], fill=arm_color, width=5) # Right arm + + return img + + +class ImageFormat(Enum): + PNG = auto() + JPEG = auto() + + class ModelSelectionStrategy(Enum): ROUND_ROBIN = auto() RANDOM = auto() @@ -42,6 +88,7 @@ class OutputFormat(Enum): OPENAI_CHAT_COMPLETIONS = auto() OPENAI_COMPLETIONS = auto() OPENAI_EMBEDDINGS = auto() + OPENAI_VISION = auto() RANKINGS = auto() TENSORRTLLM = auto() VLLM = auto() @@ -308,6 +355,12 @@ def get_generic_dataset_json( else: raise GenAIPerfException("Input source is not recognized.") + if output_format == OutputFormat.OPENAI_VISION: + snowman_image = make_snowman_image() + generic_dataset_json = cls._add_images_to_generic_json( + generic_dataset_json, snowman_image + ) + return generic_dataset_json @classmethod @@ -544,6 +597,35 @@ def verify_file(cls, input_filename: Path) -> None: if not input_filename.exists(): raise FileNotFoundError(f"The file '{input_filename}' does not exist.") + @classmethod + def _add_images_to_generic_json( + cls, generic_dataset_json: Dict[str, List[Dict]], img: Image + ) -> Dict[str, List[Dict]]: + # (TMA-1985) Support multiple image formats + img_format = ImageFormat.PNG + img_base64 = cls._encode_image(img, img_format) + for row in generic_dataset_json["rows"]: + if isinstance(row["text_input"], str): + row["text_input"] = [ + { + "type": "text", + "text": row["text_input"], + }, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_base64}"}, + }, + ] + + return generic_dataset_json + + @classmethod + def _encode_image(cls, img: Image, format=ImageFormat.PNG): + """Encodes an image into base64 encoding.""" + buffered = BytesIO() + img.save(buffered, format=format.name) + return base64.b64encode(buffered.getvalue()).decode("utf-8") + @classmethod def _convert_generic_json_to_output_format( cls, @@ -558,7 +640,10 @@ def _convert_generic_json_to_output_format( model_name: list = [], model_selection_strategy: ModelSelectionStrategy = ModelSelectionStrategy.ROUND_ROBIN, ) -> Dict: - if output_format == OutputFormat.OPENAI_CHAT_COMPLETIONS: + if ( + output_format == OutputFormat.OPENAI_CHAT_COMPLETIONS + or output_format == OutputFormat.OPENAI_VISION + ): output_json = cls._convert_generic_json_to_openai_chat_completions_format( generic_dataset, add_model_name, diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py index 64178fd4c..2fcea8eb8 100644 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/parser.py @@ -68,6 +68,7 @@ def to_lowercase(self): "completions": "v1/completions", "embeddings": "v1/embeddings", "rankings": "v1/ranking", + "vision": "v1/chat/completions", } @@ -131,6 +132,11 @@ def _check_conditional_args( elif args.endpoint_type == "rankings": args.output_format = OutputFormat.RANKINGS + # (TMA-1986) deduce vision format from chat completions + image CLI + # because there's no openai vision endpoint. + elif args.endpoint_type == "vision": + args.output_format = OutputFormat.OPENAI_VISION + if args.endpoint is not None: args.endpoint = args.endpoint.lstrip(" /") else: @@ -492,7 +498,7 @@ def _add_endpoint_args(parser): endpoint_group.add_argument( "--endpoint-type", type=str, - choices=["chat", "completions", "embeddings", "rankings"], + choices=["chat", "completions", "embeddings", "rankings", "vision"], required=False, help=f"The endpoint-type to send requests to on the " 'server. This is only used with the "openai" service-kind.', diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py index cbb2da5ee..b32120deb 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py @@ -218,6 +218,9 @@ def _get_openai_input_text(self, req_inputs: dict) -> str: return payload["messages"][0]["content"] elif self._response_format == ResponseFormat.OPENAI_COMPLETIONS: return payload["prompt"] + elif self._response_format == ResponseFormat.OPENAI_VISION: + content = payload["messages"][0]["content"] + return " ".join(c["text"] for c in content if c["type"] == "text") else: raise ValueError( "Failed to parse OpenAI request input in profile export file." diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py index d18d8f6fb..74eb48a23 100755 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py +++ b/src/c++/perf_analyzer/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py @@ -39,6 +39,7 @@ class ResponseFormat(Enum): OPENAI_CHAT_COMPLETIONS = auto() OPENAI_COMPLETIONS = auto() OPENAI_EMBEDDINGS = auto() + OPENAI_VISION = auto() RANKINGS = auto() TRITON = auto() @@ -59,7 +60,15 @@ def _get_profile_metadata(self, data: dict) -> None: if data["endpoint"] == "rerank": self._response_format = ResponseFormat.HUGGINGFACE_RANKINGS elif data["endpoint"] == "v1/chat/completions": - self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS + # (TPA-66) add PA metadata to deduce the response format instead + # of parsing the request input payload in profile export json + # file. + request = data["experiments"][0]["requests"][0] + request_input = request["request_inputs"]["payload"] + if "image_url" in request_input: + self._response_format = ResponseFormat.OPENAI_VISION + else: + self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS elif data["endpoint"] == "v1/completions": self._response_format = ResponseFormat.OPENAI_COMPLETIONS elif data["endpoint"] == "v1/embeddings": @@ -67,13 +76,17 @@ def _get_profile_metadata(self, data: dict) -> None: elif data["endpoint"] == "v1/ranking": self._response_format = ResponseFormat.RANKINGS else: - # TPA-66: add PA metadata to handle this case + # (TPA-66) add PA metadata to handle this case # When endpoint field is either empty or custom endpoint, fall # back to parsing the response to extract the response format. request = data["experiments"][0]["requests"][0] + request_input = request["request_inputs"]["payload"] response = request["response_outputs"][0]["response"] if "chat.completion" in response: - self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS + if "image_url" in request_input: + self._response_format = ResponseFormat.OPENAI_VISION + else: + self._response_format = ResponseFormat.OPENAI_CHAT_COMPLETIONS elif "text_completion" in response: self._response_format = ResponseFormat.OPENAI_COMPLETIONS elif "embedding" in response: diff --git a/src/c++/perf_analyzer/genai-perf/genai_perf/test_end_to_end.py b/src/c++/perf_analyzer/genai-perf/genai_perf/test_end_to_end.py deleted file mode 100644 index 3cc2999f5..000000000 --- a/src/c++/perf_analyzer/genai-perf/genai_perf/test_end_to_end.py +++ /dev/null @@ -1,92 +0,0 @@ -import itertools -import os -import subprocess -import sys - -# How to run: -# test_end_to_end.py -# Where target is "nim_chat" or "nim_completions" or "vllm_openai" or "triton_tensorrtllm" -# -# For all cases but vllm_openai, it assumes that the server will be on port 9999 -# -# This script will run a sweep of all combinations of values in the testing matrix -# by appending those options on to the genai-pa base command -# - - -testing_matrix = [ - ["--concurrency 1", "--concurrency 32", "--request-rate 1", "--request-rate 32"], - ["--streaming", ""], -] - -base_commands = { - "nim_chat": "genai-perf -s 999 -p 20000 -m llama-2-7b-chat -u http://localhost:9999 --service-kind openai --endpoint-type chat", - "nim_completions": "genai-perf -s 999 -p 20000 -m llama-2-7b -u http://localhost:9999 --service-kind openai --endpoint-type completions", - "vllm_openai": "genai-perf -s 999 -p 20000 -m mistralai/Mistral-7B-v0.1 --service-kind openai --endpoint-type chat", - "triton_tensorrtllm": "genai-perf -s 999 -p 20000 -m llama-2-7b -u 0.0.0.0:9999 --service-kind triton --backend tensorrtllm", - "triton_vllm": "genai-perf -s 999 -p 20000 -m gpt2_vllm --service-kind triton --backend vllm", -} -testname = "" - -if len(sys.argv) == 2: - # The second element in sys.argv is the input string - testname = sys.argv[1] -else: - options = " ".join(base_commands.keys()) - print(f"This script requires exactly one argument. It must be one of {options}") - exit(1) - -base_command = base_commands[testname] - - -def rename_files(files: list, substr: str) -> None: - for f in files: - name, ext = f.rsplit(".", 1) - # Insert the substring and reassemble the filename - new_filename = f"{testname}__{name}__{substr}.{ext}" - try: - os.rename(f, new_filename) - except FileNotFoundError: - # Just ignore the error, since if PA failed these files may not exist - pass - - -def print_summary(): - # FIXME -- print out a few basic metrics. Maybe from the csv? - pass - - -def sanity_check(): - # FIXME -- add in some sanity checking? Throughput isn't 0? - pass - - -# Loop through all combinations -for combination in itertools.product(*testing_matrix): - options_string = " ".join(combination) - command_with_options = f"{base_command} {options_string}" - command_array = command_with_options.split() - - file_options_string = "__".join(combination) - file_options_string = file_options_string.replace(" ", "") - file_options_string = file_options_string.replace("-", "") - output_file = testname + "__" + file_options_string + ".log" - - with open(output_file, "w") as outfile: - print(f"\nCMD: {command_with_options}") - print(f" Output log is {output_file}") - proc = subprocess.run(command_array, stdout=outfile, stderr=subprocess.STDOUT) - - if proc.returncode != 0: - print(f" Command failed with return code: {proc.returncode}") - else: - print(f" Command executed successfully!") - print_summary() - sanity_check() - - files = [ - "profile_export.json", - "profile_export_genai_pa.csv", - "llm_inputs.json", - ] - rename_files(files, file_options_string) diff --git a/src/c++/perf_analyzer/genai-perf/pyproject.toml b/src/c++/perf_analyzer/genai-perf/pyproject.toml index 982ee24b7..68d5e3740 100644 --- a/src/c++/perf_analyzer/genai-perf/pyproject.toml +++ b/src/c++/perf_analyzer/genai-perf/pyproject.toml @@ -59,6 +59,7 @@ dependencies = [ "pytest-mock", "pyyaml", "responses", + "pillow", ] # CLI Entrypoint diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py b/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py index c6351918e..e989224d1 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_llm_inputs.py @@ -29,6 +29,7 @@ ModelSelectionStrategy, OutputFormat, PromptSource, + make_snowman_image, ) from genai_perf.tokenizer import Tokenizer @@ -78,6 +79,7 @@ class TestLlmInputs: ("triton", "tensorrtllm", OutputFormat.TENSORRTLLM), ("openai", "v1/completions", OutputFormat.OPENAI_COMPLETIONS), ("openai", "v1/chat/completions", OutputFormat.OPENAI_CHAT_COMPLETIONS), + ("openai", "v1/chat/completions", OutputFormat.OPENAI_VISION), ] @pytest.fixture @@ -550,6 +552,42 @@ def test_llm_inputs_with_defaults(self, default_configured_url): # else: # assert False, f"Unsupported output format: {output_format}" + def test_add_image_inputs_openai_vision(self) -> None: + generic_json = { + "rows": [ + {"text_input": "test input one"}, + {"text_input": "test input two"}, + ] + } + img = make_snowman_image() + encoded_img = LlmInputs._encode_image(img) + + generic_json = LlmInputs._add_images_to_generic_json(generic_json, img) + + row1 = generic_json["rows"][0]["text_input"] + assert row1 == [ + { + "type": "text", + "text": "test input one", + }, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{encoded_img}"}, + }, + ] + + row2 = generic_json["rows"][1]["text_input"] + assert row2 == [ + { + "type": "text", + "text": "test input two", + }, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{encoded_img}"}, + }, + ] + # def test_trtllm_default_max_tokens(self, default_tokenizer: Tokenizer) -> None: # input_name = "max_tokens" # input_value = 256 diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py b/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py index 05de5b122..689e366cd 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_llm_metrics.py @@ -69,6 +69,7 @@ def test_llm_metric_system_metrics(self) -> None: output_sequence_lengths=[3, 4], input_sequence_lengths=[12, 34], ) + sys_metrics = m.system_metrics assert len(sys_metrics) == 2 assert sys_metrics[0].name == "output_token_throughput" diff --git a/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py b/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py index 75976189d..d776a6a85 100644 --- a/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py +++ b/src/c++/perf_analyzer/genai-perf/tests/test_llm_profile_data_parser.py @@ -71,6 +71,9 @@ def write(self: Any, content: str) -> int: elif filename == "openai_profile_export.json": tmp_file = StringIO(json.dumps(self.openai_profile_data)) return tmp_file + elif filename == "openai_vlm_profile_export.json": + tmp_file = StringIO(json.dumps(self.openai_vlm_profile_data)) + return tmp_file elif filename == "empty_profile_export.json": tmp_file = StringIO(json.dumps(self.empty_profile_data)) return tmp_file @@ -322,6 +325,91 @@ def test_openai_llm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> N with pytest.raises(KeyError): pd.get_statistics(infer_mode="concurrency", load_level="40") + def test_openai_vlm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> None: + """Collect LLM metrics from profile export data and check values. + + Metrics + * time to first tokens + - experiment 1: [5 - 1, 7 - 2] = [4, 5] + * inter token latencies + - experiment 1: [((12 - 1) - 4)/(3 - 1), ((15 - 2) - 5)/(6 - 1)] + : [3.5, 1.6] + : [4, 2] # rounded + * output token throughputs per request + - experiment 1: [3/(12 - 1), 6/(15 - 2)] = [3/11, 6/13] + * output token throughputs + - experiment 1: [(3 + 6)/(15 - 1)] = [9/14] + * output sequence lengths + - experiment 1: [3, 6] + * input sequence lengths + - experiment 1: [3, 4] + """ + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + pd = LLMProfileDataParser( + filename=Path("openai_vlm_profile_export.json"), + tokenizer=tokenizer, + ) + + # experiment 1 statistics + stat_obj = pd.get_statistics(infer_mode="concurrency", load_level="10") + metrics = stat_obj.metrics + stat = stat_obj.stats_dict + assert isinstance(metrics, LLMMetrics) + + assert metrics.time_to_first_tokens == [4, 5] + assert metrics.inter_token_latencies == [4, 2] + ottpr = [3 / ns_to_sec(11), 6 / ns_to_sec(13)] + assert metrics.output_token_throughputs_per_request == pytest.approx(ottpr) + ott = [9 / ns_to_sec(14)] + assert metrics.output_token_throughputs == pytest.approx(ott) + assert metrics.output_sequence_lengths == [3, 6] + assert metrics.input_sequence_lengths == [3, 4] + + assert stat["time_to_first_token"]["avg"] == pytest.approx(4.5) # type: ignore + assert stat["inter_token_latency"]["avg"] == pytest.approx(3) # type: ignore + assert stat["output_token_throughput_per_request"]["avg"] == pytest.approx( # type: ignore + np.mean(ottpr) + ) + assert stat["output_sequence_length"]["avg"] == 4.5 # type: ignore + assert stat["input_sequence_length"]["avg"] == 3.5 # type: ignore + + assert stat["time_to_first_token"]["p50"] == pytest.approx(4.5) # type: ignore + assert stat["inter_token_latency"]["p50"] == pytest.approx(3) # type: ignore + assert stat["output_token_throughput_per_request"]["p50"] == pytest.approx( # type: ignore + np.percentile(ottpr, 50) + ) + assert stat["output_sequence_length"]["p50"] == 4.5 # type: ignore + assert stat["input_sequence_length"]["p50"] == 3.5 # type: ignore + + assert stat["time_to_first_token"]["min"] == pytest.approx(4) # type: ignore + assert stat["inter_token_latency"]["min"] == pytest.approx(2) # type: ignore + min_ottpr = 3 / ns_to_sec(11) + assert stat["output_token_throughput_per_request"]["min"] == pytest.approx(min_ottpr) # type: ignore + assert stat["output_sequence_length"]["min"] == 3 # type: ignore + assert stat["input_sequence_length"]["min"] == 3 # type: ignore + + assert stat["time_to_first_token"]["max"] == pytest.approx(5) # type: ignore + assert stat["inter_token_latency"]["max"] == pytest.approx(4) # type: ignore + max_ottpr = 6 / ns_to_sec(13) + assert stat["output_token_throughput_per_request"]["max"] == pytest.approx(max_ottpr) # type: ignore + assert stat["output_sequence_length"]["max"] == 6 # type: ignore + assert stat["input_sequence_length"]["max"] == 4 # type: ignore + + assert stat["time_to_first_token"]["std"] == np.std([4, 5]) * (1) # type: ignore + assert stat["inter_token_latency"]["std"] == np.std([4, 2]) * (1) # type: ignore + assert stat["output_token_throughput_per_request"]["std"] == pytest.approx( # type: ignore + np.std(ottpr) + ) + assert stat["output_sequence_length"]["std"] == np.std([3, 6]) # type: ignore + assert stat["input_sequence_length"]["std"] == np.std([3, 4]) # type: ignore + + oott = 9 / ns_to_sec(14) + assert stat["output_token_throughput"]["avg"] == pytest.approx(oott) # type: ignore + + # check non-existing profile data + with pytest.raises(KeyError): + pd.get_statistics(infer_mode="concurrency", load_level="40") + def test_merged_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None: """Test merging the multiple sse response.""" res_timestamps = [0, 1, 2, 3] @@ -522,6 +610,73 @@ def test_empty_response(self, mock_read_write: pytest.MonkeyPatch) -> None: ], } + openai_vlm_profile_data = { + "service_kind": "openai", + "endpoint": "v1/chat/completions", + "experiments": [ + { + "experiment": { + "mode": "concurrency", + "value": 10, + }, + "requests": [ + { + "timestamp": 1, + "request_inputs": { + "payload": '{"messages":[{"role":"user","content":[{"type":"text","text":"This is test"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abcdef"}}]}],"model":"llava-1.6","stream":true}', + }, + # the first, and the last two responses will be ignored because they have no "content" + "response_timestamps": [3, 5, 8, 12, 13, 14], + "response_outputs": [ + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"I"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" like"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":" dogs"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":null}]}\n\n' + }, + {"response": "data: [DONE]\n\n"}, + ], + }, + { + "timestamp": 2, + "request_inputs": { + "payload": '{"messages":[{"role":"user","content":[{"type":"text","text":"This is test too"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abcdef"}}]}],"model":"llava-1.6","stream":true}', + }, + # the first, and the last two responses will be ignored because they have no "content" + "response_timestamps": [4, 7, 11, 15, 18, 19], + "response_outputs": [ + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"I"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"don\'t"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"cook food"},"finish_reason":null}]}\n\n' + }, + { + "response": 'data: {"id":"abc","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":null}]}\n\n' + }, + {"response": "data: [DONE]\n\n"}, + ], + }, + ], + }, + ], + } + triton_profile_data = { "service_kind": "triton", "endpoint": "",