diff --git a/optimum_benchmark/import_utils.py b/optimum_benchmark/import_utils.py index 09a2a08d..7fdff6ef 100644 --- a/optimum_benchmark/import_utils.py +++ b/optimum_benchmark/import_utils.py @@ -246,16 +246,16 @@ def get_hf_libs_info(): return { "optimum_benchmark_version": optimum_benchmark_version(), "optimum_benchmark_commit": get_git_revision_hash("optimum_benchmark"), - "transformers_version": transformers_version(), + "transformers_version": transformers_version() if is_transformers_available() else None, "transformers_commit": get_git_revision_hash("transformers"), - "accelerate_version": accelerate_version(), + "accelerate_version": accelerate_version() if is_accelerate_available else None, "accelerate_commit": get_git_revision_hash("accelerate"), - "diffusers_version": diffusers_version(), + "diffusers_version": diffusers_version() if is_diffusers_available() else None, "diffusers_commit": get_git_revision_hash("diffusers"), - "optimum_version": optimum_version(), + "optimum_version": optimum_version() if is_optimum_available() else None, "optimum_commit": get_git_revision_hash("optimum"), - "timm_version": timm_version(), + "timm_version": timm_version() if is_timm_available() else None, "timm_commit": get_git_revision_hash("timm"), - "peft_version": peft_version(), + "peft_version": peft_version() if is_peft_available() else None, "peft_commit": get_git_revision_hash("peft"), } diff --git a/optimum_benchmark/trackers/latency.py b/optimum_benchmark/trackers/latency.py index 1f74a377..6b8d614f 100644 --- a/optimum_benchmark/trackers/latency.py +++ b/optimum_benchmark/trackers/latency.py @@ -264,7 +264,7 @@ def __init__(self, device: str, backend: str): self.start_time: Optional[float] = None self.prefilled: Optional[bool] = None - self.per_token_events: List[Union[float, torch.cuda.Event]] = [] + self.per_token_events: List[List[Union[float, torch.cuda.Event]]] = [] self.prefill_start_events: List[Union[float, torch.cuda.Event]] = [] self.prefill_end_events: List[Union[float, torch.cuda.Event]] = [] self.decode_start_events: List[Union[float, torch.cuda.Event]] = [] @@ -282,6 +282,9 @@ def reset(self): @contextmanager def track(self): + self.prefilled = False + self.per_token_events.append([]) + if self.is_distributed: torch.distributed.barrier() @@ -291,14 +294,10 @@ def track(self): else: self.prefill_start_events.append(time.perf_counter()) - self.prefilled = False - # this is where generate is called, # and for each decoded token, we record an event yield - self.prefilled = None - if self.is_asynchronous: self.decode_end_events.append(torch.cuda.Event(enable_timing=True)) self.decode_end_events[-1].record() @@ -308,6 +307,8 @@ def track(self): if self.is_distributed: torch.distributed.barrier() + self.prefilled = False + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): assert ( self.prefilled is not None @@ -319,13 +320,13 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): else: event = time.perf_counter() - self.per_token_events.append(event) - if not self.prefilled: self.prefill_end_events.append(event) self.decode_start_events.append(event) self.prefilled = True + self.per_token_events[-1].append(event) + return scores def get_prefill_latency(self) -> Latency: @@ -368,13 +369,15 @@ def get_per_token_latency(self) -> Latency: torch.cuda.synchronize() latencies_list = [ - self.per_token_events[i].elapsed_time(self.per_token_events[i + 1]) / 1e3 - for i in range(0, len(self.per_token_events) - 1) + self.per_token_events[i][j].elapsed_time(self.per_token_events[i][j + 1]) / 1e3 + for i in range(len(self.per_token_events)) + for j in range(0, len(self.per_token_events[i]) - 1) ] else: latencies_list = [ - (self.per_token_events[i + 1] - self.per_token_events[i]) - for i in range(0, len(self.per_token_events) - 1) + (self.per_token_events[i][j + 1] - self.per_token_events[i][j]) + for i in range(len(self.per_token_events)) + for j in range(0, len(self.per_token_events[i]) - 1) ] assert not any(latency < 0 for latency in latencies_list), "Negative latency detected"