Skip to content

Commit

Permalink
Add simple profiler for PyTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Apr 6, 2024
1 parent 22c7c2c commit 0daff5f
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions bnpm/torch_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,34 @@ def initialize_torch_settings(
torch.linalg.qr(torch.as_tensor([[1.0, 2.0], [3.0, 4.0]], device=init_linalg_device))


def profiler_simple(
path_save: str = 'trace.json',
):
"""
Simple profiler for PyTorch. \n
Makes a context manager that can be used to profile code. \n
Upon exit, will save the trace to the specified path. \n
"""
from torch.profiler import profile, record_function, ProfilerActivity
from contextlib import contextmanager

@contextmanager
def simple_profiler(path_save: str = 'trace.json'):
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=True,
# with_flops=True,
# with_modules=True,
) as p:
with record_function("model_inference"):
yield
p.export_chrome_trace(path_save)

return simple_profiler(path_save=path_save)



######################################
############ DATA HELPERS ############
Expand Down

0 comments on commit 0daff5f

Please sign in to comment.