diff --git a/bnpm/torch_helpers.py b/bnpm/torch_helpers.py index ff53bd6..21bce82 100644 --- a/bnpm/torch_helpers.py +++ b/bnpm/torch_helpers.py @@ -275,6 +275,34 @@ def initialize_torch_settings( torch.linalg.qr(torch.as_tensor([[1.0, 2.0], [3.0, 4.0]], device=init_linalg_device)) +def profiler_simple( + path_save: str = 'trace.json', +): + """ + Simple profiler for PyTorch. \n + Makes a context manager that can be used to profile code. \n + Upon exit, will save the trace to the specified path. \n + """ + from torch.profiler import profile, record_function, ProfilerActivity + from contextlib import contextmanager + + @contextmanager + def simple_profiler(path_save: str = 'trace.json'): + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + profile_memory=True, + with_stack=True, + # with_flops=True, + # with_modules=True, + ) as p: + with record_function("model_inference"): + yield + p.export_chrome_trace(path_save) + + return simple_profiler(path_save=path_save) + + ###################################### ############ DATA HELPERS ############