Skip to content

Commit

Permalink
Fixes + all unit testing completed
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-braf committed Sep 11, 2024
1 parent b0a2244 commit 52c1bfe
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 197 deletions.
16 changes: 7 additions & 9 deletions genai-perf/genai_perf/measurements/model_config_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dataclasses import dataclass
from functools import total_ordering
from statistics import mean
from typing import Any, Dict, List, Optional, TypeAlias
from typing import Any, Dict, Optional, TypeAlias

from genai_perf.record.record import Record

Expand Down Expand Up @@ -121,7 +121,7 @@ def _read_perf_metrics_from_checkpoint(
) -> Records:
perf_metrics: Records = {}

for [tag, record_dict] in perf_metrics_dict.items():
for [tag, record_dict] in perf_metrics_dict.values():
record = Record.get(tag)
record = record.read_from_checkpoint(record_dict) # type: ignore
perf_metrics[tag] = record # type: ignore
Expand Down Expand Up @@ -149,13 +149,11 @@ def __gt__(self, other: "ModelConfigMeasurement") -> bool:
== ModelConfigMeasurementDefaults.SELF_IS_BETTER
)

# TODO: OPTIMIZE
# Why is mypy complaining about this?
# def __eq__(self, other: "ModelConfigMeasurement") -> bool:
# return (
# self._compare_measurements(other)
# == ModelConfigMeasurementDefaults.EQUALIVILENT
# )
def __eq__(self, other: "ModelConfigMeasurement") -> bool: # type: ignore
return (
self._compare_measurements(other)
== ModelConfigMeasurementDefaults.EQUALIVILENT
)

def _compare_measurements(self, other: "ModelConfigMeasurement") -> int:
"""
Expand Down
2 changes: 1 addition & 1 deletion genai-perf/genai_perf/record/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def tag(self) -> str:
the name tag of the record type.
"""

def to_dict(self):
def write_to_checkpoint(self):
return (self.tag, self.__dict__)

@classmethod
Expand Down
11 changes: 11 additions & 0 deletions genai-perf/genai_perf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,14 @@ def get_enum_entry(name: str, enum: Type[Enum]) -> Optional[Enum]:

def scale(value, factor):
return value * factor


# FIXME: OPTIMIZE
# This will move to the checkpoint class when it's created
def checkpoint_encoder(obj):
if isinstance(obj, bytes):
return obj.decode("utf-8")
elif hasattr(obj, "write_to_checkpoint"):
return obj.write_to_checkpoint()
else:
return obj.__dict__
Loading

0 comments on commit 52c1bfe

Please sign in to comment.