Commit

Use JSON string parsing function everywhere
dyastremsky committed Jul 11, 2024
1 parent 9edf3d3 commit c9cfd81
Showing 4 changed files with 21 additions and 11 deletions.

@@ -24,6 +24,7 @@
 from genai_perf.exceptions import GenAIPerfException
 from genai_perf.llm_inputs.synthetic_prompt_generator import SyntheticPromptGenerator
 from genai_perf.tokenizer import DEFAULT_TOKENIZER, Tokenizer, get_tokenizer
+from genai_perf.utils import load_json_str
 from requests import Response


@@ -315,7 +316,7 @@ def _get_input_dataset_from_embeddings_file(
         cls, input_filename: Path, batch_size: int, num_prompts: int
     ) -> Dict[str, Any]:
         with open(input_filename, "r") as file:
-            file_content = [json.loads(line) for line in file]
+            file_content = [load_json_str(line) for line in file]

         texts = [item["text"] for item in file_content]
@@ -344,11 +345,11 @@ def _get_input_dataset_from_rankings_files(
     ) -> Dict[str, Any]:

         with open(queries_filename, "r") as file:
-            queries_content = [json.loads(line) for line in file]
+            queries_content = [load_json_str(line) for line in file]
         queries_texts = [item for item in queries_content]

         with open(passages_filename, "r") as file:
-            passages_content = [json.loads(line) for line in file]
+            passages_content = [load_json_str(line) for line in file]
         passages_texts = [item for item in passages_content]

         if batch_size > len(passages_texts):
@@ -363,7 +364,7 @@
         for _ in range(num_prompts):
             sampled_texts = random.sample(passages_texts, batch_size)
             query_sample = random.choice(queries_texts)
-            entry_dict = {}
+            entry_dict: Dict = {}
             entry_dict["query"] = query_sample
             entry_dict["passages"] = sampled_texts
             dataset_json["rows"].append({"row": {"payload": entry_dict}})
@@ -536,7 +537,7 @@ def _get_prompts_from_input_file(cls, input_filename: Path) -> List[str]:
         with open(input_filename, mode="r", newline=None) as file:
             for line in file:
                 if line.strip():
-                    prompts.append(json.loads(line).get("text_input", "").strip())
+                    prompts.append(load_json_str(line).get("text_input", "").strip())
         return prompts

     @classmethod
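
The pattern in the hunks above repeats throughout this file: JSONL inputs are read line by line, and each line now goes through load_json_str instead of a bare json.loads call, so a malformed line is logged with a truncated snippet before the JSONDecodeError propagates. A minimal sketch of that pattern, not the actual genai-perf reader; the file name and the "text" key are borrowed from the embeddings example above:

```python
# Sketch only: reads a JSONL file the way the changed readers do, assuming
# each line is a JSON object with a "text" field (as in the embeddings hunk).
from pathlib import Path
from typing import List

from genai_perf.utils import load_json_str


def read_jsonl_texts(input_filename: Path) -> List[str]:
    with open(input_filename, "r") as file:
        # A bad line now yields a logged snippet before JSONDecodeError is raised.
        file_content = [load_json_str(line) for line in file]
    return [item["text"] for item in file_content]


# Hypothetical usage:
# texts = read_jsonl_texts(Path("embeddings_input.jsonl"))
```
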
2 changes: 1 addition & 1 deletion src/c++/perf_analyzer/genai-perf/genai_perf/parser.py
@@ -586,7 +586,7 @@ def get_extra_inputs_as_dict(args: argparse.Namespace) -> dict:
     if args.extra_inputs:
         for input_str in args.extra_inputs:
             if input_str.startswith("{") and input_str.endswith("}"):
-                request_inputs.update(json.loads(input_str))
+                request_inputs.update(utils.load_json_str(input_str))
             else:
                 semicolon_count = input_str.count(":")
                 if semicolon_count != 1:
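
The same swap covers CLI parsing: when an --extra-inputs value looks like a JSON object, it is now parsed through utils.load_json_str. A rough illustration of what that means for users; the option value below is invented for the example and not taken from the diff:

```python
# Illustration of the parser.py change above.
from genai_perf import utils

request_inputs = {}
input_str = '{"max_tokens": 256, "temperature": 0.7}'  # e.g. a --extra-inputs value
if input_str.startswith("{") and input_str.endswith("}"):
    # With load_json_str, a malformed value such as '{"max_tokens": 256,'
    # is logged with a snippet before json.JSONDecodeError is re-raised.
    request_inputs.update(utils.load_json_str(input_str))
```
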
@@ -37,7 +37,7 @@
     ResponseFormat,
 )
 from genai_perf.tokenizer import Tokenizer
-from genai_perf.utils import remove_sse_prefix
+from genai_perf.utils import load_json_str, remove_sse_prefix


 class LLMProfileDataParser(ProfileDataParser):
@@ -178,7 +178,7 @@ def _preprocess_response(
             response = res_outputs[i]["response"]
             responses = response.strip().split("\n\n")
             if len(responses) > 1:
-                merged_response = json.loads(remove_sse_prefix(responses[0]))
+                merged_response = load_json_str(remove_sse_prefix(responses[0]))
                 if (
                     merged_response["choices"][0]["delta"].get("content", None)
                     is None
@@ -213,7 +213,7 @@ def _get_input_token_count(self, req_inputs: dict) -> int:

     def _get_openai_input_text(self, req_inputs: dict) -> str:
         """Tokenize the OpenAI request input texts."""
-        payload = json.loads(req_inputs["payload"])
+        payload = load_json_str(req_inputs["payload"])
         if self._response_format == ResponseFormat.OPENAI_CHAT_COMPLETIONS:
             return payload["messages"][0]["content"]
         elif self._response_format == ResponseFormat.OPENAI_COMPLETIONS:
@@ -268,7 +268,7 @@ def _extract_openai_text_output(self, response: str) -> str:
         if response == "[DONE]":
             return ""

-        data = json.loads(response)
+        data = load_json_str(response)
         completions = data["choices"][0]

         text_output = ""
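
In the profile data parser, the helper now also covers streamed SSE chunks and raw request payloads. A small sketch of the streaming path shown above, assuming remove_sse_prefix strips a leading "data: " marker from a chunk; the chunk contents are invented for illustration:

```python
# Sketch of the SSE handling above; the chunk body is made up, and the exact
# behavior of remove_sse_prefix (stripping "data: ") is an assumption.
from genai_perf.utils import load_json_str, remove_sse_prefix

chunk = 'data: {"choices": [{"delta": {"content": "Hello"}}]}'
merged_response = load_json_str(remove_sse_prefix(chunk))
print(merged_response["choices"][0]["delta"].get("content"))  # -> Hello
```
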
11 changes: 10 additions & 1 deletion src/c++/perf_analyzer/genai-perf/genai_perf/utils.py
@@ -55,13 +55,22 @@ def load_json(filepath: Path) -> Dict[str, Any]:
     with open(str(filepath), encoding="utf-8", errors="ignore") as f:
         content = f.read()
         try:
-            return json.loads(content)
+            return load_json_str(content)
         except json.JSONDecodeError:
             snippet = content[:200] + ("..." if len(content) > 200 else "")
             logger.error("Failed to parse JSON string: '%s'", snippet)
             raise


+def load_json_str(json_str: str) -> Dict[str, Any]:
+    try:
+        return json.loads(json_str)
+    except json.JSONDecodeError:
+        snippet = json_str[:200] + ("..." if len(json_str) > 200 else "")
+        logger.error("Failed to parse JSON string: '%s'", snippet)
+        raise
+
+
 def remove_file(file: Path) -> None:
     if file.is_file():
         file.unlink()
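
Judging only from the code above, the new helper behaves exactly like json.loads on valid input and, on failure, logs a snippet of the offending string (truncated to 200 characters) before re-raising the original error. A quick demonstration of both paths:

```python
# Demonstrates load_json_str as defined above; no behavior beyond the diff is assumed.
import json

from genai_perf.utils import load_json_str

payload = load_json_str('{"text": "hello"}')  # -> {"text": "hello"}

try:
    load_json_str('{"text": "hello"')  # missing closing brace
except json.JSONDecodeError:
    # The logger has already emitted:
    #   Failed to parse JSON string: '{"text": "hello"'
    pass
```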
