Skip to content

Commit

Permalink
Add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
xeon27 committed Mar 22, 2024
1 parent 9e5faff commit bce44f6
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions kronfluence/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,13 @@ def start(self, action_name: str) -> None:
return
if action_name in self.current_actions:
raise ValueError(f"Attempted to start {action_name} which has already started.")
self.current_actions[action_name] = 0.0 # Dummy value
# Set dummy value, since only used to track duplicate actions
self.current_actions[action_name] = 0.0
self.actions.append(action_name)
self._torch_prof.start()

def stop(self, action_name: str) -> None:
"""Defines how to record the duration once an action is complete."""
"""Defines how to stop recording an action."""
if not self.state.is_main_process:
return
if action_name not in self.current_actions:
Expand All @@ -206,13 +207,17 @@ def _set_up_torch_profiler(self) -> None:

def _trace_handler(self, p) -> None:
"""Adds the PyTorch Profiler trace output to a list once it is ready."""
# Set metric to sort by based on device
# Set the metric to sort by, based on the device
is_cpu = (self.state.device == torch.device("cpu"))
sort_by_metric = "self_cpu_time_total" if is_cpu else "self_cuda_time_total"

# Obtain formatted output from profiler
output = p.key_averages().table(sort_by=sort_by_metric, row_limit=10)
self.trace_outputs.append(output)

# Obtain total time taken for the action
total_time = p.key_averages().self_cpu_time_total if is_cpu else p.key_averages().self_cuda_time_total
total_time = total_time * 10**(-6) # Convert from us to s
total_time = total_time * 10**(-6) # Convert from microseconds to seconds
self.recorded_durations[self.actions[-1]].append(total_time)

def _reset_output(self) -> None:
Expand Down Expand Up @@ -255,10 +260,11 @@ def log_row(action: str, mean: str, num_calls: str, total: str, per: str) -> str
def summary(self) -> str:
"""Returns a formatted summary for the PyTorch Profiler."""
assert len(self.actions) == len(self.trace_outputs), \
f"Mismatch in the number of actions and profiles collected: " + \
f"# Actions: {len(self.actions)}, # Profiles: {len(self.trace_outputs)}"
f"Mismatch in the number of actions and outputs collected: " + \
f"# Actions: {len(self.actions)}, # Outputs: {len(self.trace_outputs)}"
prof_prefix = "Profiler Summary for Action"
no_summary_str = "*** No Summary returned from PyTorch Profiler ***"
# Consolidate detailed summary
outputs = [no_summary_str if elm=="" else elm for elm in self.trace_outputs]
summary = "\n".join([f"\n{prof_prefix}: {elm[0]}\n{elm[1]}" for elm in zip(self.actions, outputs)])
# Append overall action level summary
Expand Down

0 comments on commit bce44f6

Please sign in to comment.