Skip to content

Commit

Permalink
Add overall report based on total time taken
Browse files Browse the repository at this point in the history
  • Loading branch information
xeon27 committed Mar 22, 2024
1 parent b21325b commit 9e5faff
Showing 1 changed file with 39 additions and 2 deletions.
41 changes: 39 additions & 2 deletions kronfluence/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,47 @@ def _trace_handler(self, p) -> None:
sort_by_metric = "self_cpu_time_total" if is_cpu else "self_cuda_time_total"
output = p.key_averages().table(sort_by=sort_by_metric, row_limit=10)
self.trace_outputs.append(output)
total_time = p.key_averages().self_cpu_time_total if is_cpu else p.key_averages().self_cuda_time_total
total_time = total_time * 10**(-6) # Convert from us to s
self.recorded_durations[self.actions[-1]].append(total_time)

def _reset_output(self) -> None:
"""Resets actions and outputs list."""
self.actions = []
self.trace_outputs = []

def _high_level_summary(self) -> str:
"""Returns a formatted high level summary for the PyTorch Profiler."""
sep = os.linesep
output_string = "Overall PyTorch Profiler Report:"

if len(self.recorded_durations) > 0:
max_key = max(len(k) for k in self.recorded_durations.keys())

def log_row(action: str, mean: str, num_calls: str, total: str, per: str) -> str:
row = f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t|"
row += f" {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|"
return row

header_string = log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %")
output_string_len = len(header_string.expandtabs())
sep_lines = f"{sep}{'-' * output_string_len}"
output_string += sep_lines + header_string + sep_lines
report_extended, total_calls, total_duration = self._make_report()
output_string += log_row("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %")
output_string += sep_lines
for action, mean_duration, num_calls, total_duration, duration_per in report_extended:
output_string += log_row(
action,
f"{mean_duration:.5}",
f"{num_calls}",
f"{total_duration:.5}",
f"{duration_per:.5}",
)
output_string += sep_lines
output_string += sep
return output_string

def summary(self) -> str:
"""Returns a formatted summary for the PyTorch Profiler."""
assert len(self.actions) == len(self.trace_outputs), \
Expand All @@ -225,10 +260,12 @@ def summary(self) -> str:
prof_prefix = "Profiler Summary for Action"
no_summary_str = "*** No Summary returned from PyTorch Profiler ***"
outputs = [no_summary_str if elm=="" else elm for elm in self.trace_outputs]
summary = [f"\n{prof_prefix}: {elm[0]}\n{elm[1]}" for elm in zip(self.actions, outputs)]
summary = "\n".join([f"\n{prof_prefix}: {elm[0]}\n{elm[1]}" for elm in zip(self.actions, outputs)])
# Append overall action level summary
summary = f"{summary}\n\n{self._high_level_summary()}"
# Reset actions and outputs once summary is invoked
self._reset_output()
return "\n".join(summary)
return summary


# Timing utilities copied from
Expand Down

0 comments on commit 9e5faff

Please sign in to comment.