Fix profile options argument: allow disabling selected options

Summary:
After D49785828 enabled detailed profiling by default, --profile-options became unusable because --no-profile-detailed, the only way to turn the detailed options back off, is all-or-nothing.
Fix this by replacing --profile-options with --disable-profile-options, which turns selected options off instead of on.

Most of the changes are auto-formatting (I marked the logic changes with diff comments below). A simplified sketch of the new option-resolution rule follows.
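
As an illustration of the new behavior: every detailed option is now on by default and is only turned off explicitly, either wholesale via --no-profile-detailed or one at a time via --disable-profile-options. Below is a minimal, self-contained sketch of that resolution rule. The option names mirror the ones run.py passes to torch.profiler, but SUPPORT_PROFILE_LIST itself and the real argument parsing live outside this diff, so treat this as an approximation rather than the exact run.py code:

    import argparse

    # Assumed to match run.py's SUPPORT_PROFILE_LIST (defined outside this diff).
    SUPPORT_PROFILE_LIST = [
        "record_shapes",
        "profile_memory",
        "with_stack",
        "with_flops",
        "with_modules",
    ]

    parser = argparse.ArgumentParser()
    parser.add_argument("--no-profile-detailed", action="store_true")
    parser.add_argument(
        "--disable-profile-options", type=lambda s: s.split(","), default=None
    )
    # Hypothetical invocation, e.g.:
    #   python run.py <model> --profile --disable-profile-options with_stack,profile_memory
    args = parser.parse_args(["--disable-profile-options", "with_stack,profile_memory"])

    profile_opts = {}
    for opt in SUPPORT_PROFILE_LIST:
        # Detailed profiling is on by default; --no-profile-detailed disables everything.
        profile_opts[opt] = not args.no_profile_detailed
        # Individual options can then be switched off via --disable-profile-options.
        if args.disable_profile_options is not None and opt in args.disable_profile_options:
            profile_opts[opt] = False

    # with_stack and profile_memory end up False; the other options stay True.
    print(profile_opts)

Under the old flag the default was inverted: an option was off unless it was listed in --profile-options.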

Reviewed By: davidberard98

Differential Revision: D50524522

fbshipit-source-id: a74624c6f9b0d1619804161f7c132687baf82fd4
YuqingJ authored and facebook-github-bot committed Oct 23, 2023
1 parent 2f9b20e commit bd0e8d5
Showing 1 changed file with 175 additions and 56 deletions.
run.py: 231 changes (175 additions & 56 deletions)
@@ -57,7 +57,7 @@ def run_one_step_with_cudastreams(func, streamcount):
for i in range(1, streamcount + 1, 1):

# create additional streams and prime with load
while len(streamlist) < i :
while len(streamlist) < i:
s = torch.cuda.Stream()
streamlist.append(s)

@@ -80,50 +80,110 @@ def run_one_step_with_cudastreams(func, streamcount):
torch.cuda.synchronize()

print(f"Cuda StreamCount:{len(streamlist)}")
print('{:<20} {:>20}'.format("GPU Time:", "%.3f milliseconds" % start_event.elapsed_time(end_event)), sep='')
print(
"{:<20} {:>20}".format(
"GPU Time:", "%.3f milliseconds" % start_event.elapsed_time(end_event)
),
sep="",
)


def printResultSummaryTime(result_summary, model, metrics_needed=[], flops_model_analyzer=None, model_flops=None, cpu_peak_mem=None, mem_device_id=None, gpu_peak_mem=None):
assert (model is not None), "model can not be None."
def printResultSummaryTime(
result_summary,
model,
metrics_needed=[],
flops_model_analyzer=None,
model_flops=None,
cpu_peak_mem=None,
mem_device_id=None,
gpu_peak_mem=None,
):
assert model is not None, "model can not be None."
if args.device == "cuda":
gpu_time = np.median(list(map(lambda x: x[0], result_summary)))
cpu_walltime = np.median(list(map(lambda x: x[1], result_summary)))
print('{:<20} {:>20}'.format("GPU Time per batch:", "%.3f milliseconds" %
(gpu_time / model.num_batch), sep=''))
print('{:<20} {:>20}'.format("CPU Wall Time per batch:", "%.3f milliseconds" %
(cpu_walltime / model.num_batch), sep=''))
print(
"{:<20} {:>20}".format(
"GPU Time per batch:",
"%.3f milliseconds" % (gpu_time / model.num_batch),
sep="",
)
)
print(
"{:<20} {:>20}".format(
"CPU Wall Time per batch:",
"%.3f milliseconds" % (cpu_walltime / model.num_batch),
sep="",
)
)
else:
cpu_walltime = np.median(list(map(lambda x: x[0], result_summary)))
print('{:<20} {:>20}'.format("CPU Wall Time per batch:", "%.3f milliseconds" % (cpu_walltime / model.num_batch), sep=''))
print(
"{:<20} {:>20}".format(
"CPU Wall Time per batch:",
"%.3f milliseconds" % (cpu_walltime / model.num_batch),
sep="",
)
)
# if model_flops is not None, output the TFLOPs per sec
if 'flops' in metrics_needed:
if flops_model_analyzer.metrics_backend_mapping['flops'] == 'dcgm':
if "flops" in metrics_needed:
if flops_model_analyzer.metrics_backend_mapping["flops"] == "dcgm":
tflops_device_id, tflops = flops_model_analyzer.calculate_flops()
else:
flops = model.get_flops()
tflops = flops / (cpu_walltime / 1.0e3) / 1.0e12
print('{:<20} {:>20}'.format("GPU FLOPS:", "%.4f TFLOPs per second" % tflops, sep=''))
if 'ttfb' in metrics_needed:
print('{:<20} {:>20}'.format("Time to first batch:", "%.4f ms" % model.ttfb, sep=''))
print(
"{:<20} {:>20}".format(
"GPU FLOPS:", "%.4f TFLOPs per second" % tflops, sep=""
)
)
if "ttfb" in metrics_needed:
print(
"{:<20} {:>20}".format(
"Time to first batch:", "%.4f ms" % model.ttfb, sep=""
)
)
if model_flops is not None:
tflops = model_flops / (cpu_walltime / 1.0e3) / 1.0e12
print('{:<20} {:>20}'.format("Model Flops:", "%.4f TFLOPs per second" % tflops, sep=''))
print(
"{:<20} {:>20}".format(
"Model Flops:", "%.4f TFLOPs per second" % tflops, sep=""
)
)
if gpu_peak_mem is not None:
print('{:<20} {:>20}'.format("GPU %d Peak Memory:" % mem_device_id, "%.4f GB" % gpu_peak_mem, sep=''))
print(
"{:<20} {:>20}".format(
"GPU %d Peak Memory:" % mem_device_id, "%.4f GB" % gpu_peak_mem, sep=""
)
)
if cpu_peak_mem is not None:
print('{:<20} {:>20}'.format("CPU Peak Memory:", "%.4f GB" % cpu_peak_mem, sep=''))
print(
"{:<20} {:>20}".format("CPU Peak Memory:", "%.4f GB" % cpu_peak_mem, sep="")
)


def run_one_step(func, model, nwarmup=WARMUP_ROUNDS, num_iter=10, export_metrics_file=None, stress=0, metrics_needed=[], metrics_gpu_backend=None):
def run_one_step(
func,
model,
nwarmup=WARMUP_ROUNDS,
num_iter=10,
export_metrics_file=None,
stress=0,
metrics_needed=[],
metrics_gpu_backend=None,
):
# Warm-up `nwarmup` rounds
for _i in range(nwarmup):
func()

result_summary = []
flops_model_analyzer = None
if 'flops' in metrics_needed:
if "flops" in metrics_needed:
from components.model_analyzer.TorchBenchAnalyzer import ModelAnalyzer
flops_model_analyzer = ModelAnalyzer(export_metrics_file, ['flops'], metrics_gpu_backend)

flops_model_analyzer = ModelAnalyzer(
export_metrics_file, ["flops"], metrics_gpu_backend
)
flops_model_analyzer.start_monitor()

if stress:
@@ -135,7 +195,7 @@ def run_one_step(func, model, nwarmup=WARMUP_ROUNDS, num_iter=10, export_metrics
_i = 0
last_it = 0
first_print_out = True
while (not stress and _i < num_iter) or (stress and cur_time < target_time) :
while (not stress and _i < num_iter) or (stress and cur_time < target_time):
if args.device == "cuda":
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
@@ -149,7 +209,9 @@ def run_one_step(func, model, nwarmup=WARMUP_ROUNDS, num_iter=10, export_metrics
end_event.record()
torch.cuda.synchronize()
t1 = time.time_ns()
result_summary.append((start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000))
result_summary.append(
(start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000)
)
elif args.device == "mps":
t0 = time.time_ns()
func()
@@ -167,11 +229,19 @@ def run_one_step(func, model, nwarmup=WARMUP_ROUNDS, num_iter=10, export_metrics
# print out the status every 10s.
if (cur_time - last_time) >= 10 * 1e9:
if first_print_out:
print('|{:^20}|{:^20}|{:^20}|'.format("Iterations", "Time/Iteration(ms)", "Rest Time(s)"))
print(
"|{:^20}|{:^20}|{:^20}|".format(
"Iterations", "Time/Iteration(ms)", "Rest Time(s)"
)
)
first_print_out = False
est = (target_time - cur_time) / 1e9
time_per_it = (cur_time - last_time) / (_i - last_it) / 1e6
print('|{:^20}|{:^20}|{:^20}|'.format("%d" % _i, "%.2f" % time_per_it , "%d" % int(est)))
print(
"|{:^20}|{:^20}|{:^20}|".format(
"%d" % _i, "%.2f" % time_per_it, "%d" % int(est)
)
)
last_time = cur_time
last_it = _i
_i += 1
@@ -183,38 +253,64 @@ def run_one_step(func, model, nwarmup=WARMUP_ROUNDS, num_iter=10, export_metrics
gpu_peak_mem = None
mem_device_id = None
model_flops = None
if 'cpu_peak_mem' in metrics_needed or 'gpu_peak_mem' in metrics_needed:
cpu_peak_mem, mem_device_id, gpu_peak_mem = get_peak_memory(func, model.device, export_metrics_file=export_metrics_file, metrics_needed=metrics_needed, metrics_gpu_backend=metrics_gpu_backend)
if 'model_flops' in metrics_needed:
if "cpu_peak_mem" in metrics_needed or "gpu_peak_mem" in metrics_needed:
cpu_peak_mem, mem_device_id, gpu_peak_mem = get_peak_memory(
func,
model.device,
export_metrics_file=export_metrics_file,
metrics_needed=metrics_needed,
metrics_gpu_backend=metrics_gpu_backend,
)
if "model_flops" in metrics_needed:
model_flops = get_model_flops(model)
printResultSummaryTime(result_summary, model, metrics_needed, flops_model_analyzer, model_flops, cpu_peak_mem, mem_device_id, gpu_peak_mem)
printResultSummaryTime(
result_summary,
model,
metrics_needed,
flops_model_analyzer,
model_flops,
cpu_peak_mem,
mem_device_id,
gpu_peak_mem,
)


def profile_one_step(func, model, nwarmup=WARMUP_ROUNDS):
activity_groups = []
result_summary = []
device_to_activity = {'cuda': profiler.ProfilerActivity.CUDA, 'cpu': profiler.ProfilerActivity.CPU}
device_to_activity = {
"cuda": profiler.ProfilerActivity.CUDA,
"cpu": profiler.ProfilerActivity.CPU,
}
if args.profile_devices:
activity_groups = [
device_to_activity[device] for device in args.profile_devices if (device in device_to_activity)
device_to_activity[device]
for device in args.profile_devices
if (device in device_to_activity)
]
else:
if args.device == 'cuda':
if args.device == "cuda":
activity_groups = [
profiler.ProfilerActivity.CUDA,
profiler.ProfilerActivity.CPU,
]
elif args.device == 'cpu':
elif args.device == "cpu":
activity_groups = [profiler.ProfilerActivity.CPU]

profile_opts = {}

for opt in SUPPORT_PROFILE_LIST:
profile_opts[opt] = True if args.profile_options is not None and opt in args.profile_options else False
profile_opts[opt] = False if args.no_profile_detailed else True
# options can be overridden by disable-profile-options
if args.disable_profile_options is not None and opt in args.disable_profile_options:
profile_opts[opt] = False

if args.profile_eg:
from datetime import datetime
import os
from datetime import datetime

from torch.profiler import ExecutionTraceObserver

start_time = datetime.now()
timestamp = int(datetime.timestamp(start_time))
eg_file = f"{args.model}_{timestamp}_eg.json"
@@ -227,12 +323,17 @@ def profile_one_step(func, model, nwarmup=WARMUP_ROUNDS):
with profiler.profile(
schedule=profiler.schedule(wait=0, warmup=nwarmup, active=1, repeat=1),
activities=activity_groups,
record_shapes=args.no_profile_detailed if args.no_profile_detailed else profile_opts["record_shapes"],
profile_memory=args.no_profile_detailed if args.no_profile_detailed else profile_opts["profile_memory"],
with_stack=args.no_profile_detailed if args.no_profile_detailed else profile_opts["with_stack"],
with_flops=args.no_profile_detailed if args.no_profile_detailed else profile_opts["with_flops"],
with_modules=args.no_profile_detailed if args.no_profile_detailed else profile_opts["with_modules"],
on_trace_ready= partial(trace_handler, f"torchbench_{args.model}") if (not hasattr(torch.version, "git_version") and args.profile_export_chrome_trace) else profiler.tensorboard_trace_handler(args.profile_folder),
record_shapes=profile_opts["record_shapes"],
profile_memory=profile_opts["profile_memory"],
with_stack=profile_opts["with_stack"],
with_flops=profile_opts["with_flops"],
with_modules=profile_opts["with_modules"],
on_trace_ready=partial(trace_handler, f"torchbench_{args.model}")
if (
not hasattr(torch.version, "git_version")
and args.profile_export_chrome_trace
)
else profiler.tensorboard_trace_handler(args.profile_folder),
) as prof:
if args.device == "cuda":
start_event = torch.cuda.Event(enable_timing=True)
@@ -245,10 +346,12 @@ def profile_one_step(func, model, nwarmup=WARMUP_ROUNDS):
end_event.record()
t1 = time.time_ns()
if i >= nwarmup:
result_summary.append((start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000))
result_summary.append(
(start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000)
)
prof.step()
else:
for i in range(nwarmup + 1):
for i in range(nwarmup + 1):
t0 = time.time_ns()
func()
t1 = time.time_ns()
@@ -259,25 +362,34 @@ def profile_one_step(func, model, nwarmup=WARMUP_ROUNDS):
eg.stop()
eg.unregister_callback()
print(f"Save Exeution Trace to : {args.profile_eg_folder}/{eg_file}")
print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=30))
print(
prof.key_averages(group_by_input_shape=True).table(
sort_by="cpu_time_total", row_limit=30
)
)
print(f"Saved TensorBoard Profiler traces to {args.profile_folder}.")

printResultSummaryTime(result_summary, model=m)


def _validate_devices(devices: str):
devices_list = devices.split(",")
valid_devices = SUPPORT_DEVICE_LIST
for d in devices_list:
if d not in valid_devices:
raise ValueError(f'Invalid device {d} passed into --profile-devices. Expected devices: {valid_devices}.')
raise ValueError(
f"Invalid device {d} passed into --profile-devices. Expected devices: {valid_devices}."
)
return devices_list


def _validate_profile_options(profile_options: str):
profile_options_list = profile_options.split(",")
for opt in profile_options_list:
if opt not in SUPPORT_PROFILE_LIST:
raise ValueError(f'Invalid profile option {opt} passed into --profile-options. Expected options: {SUPPORT_PROFILE_LIST}.')
raise ValueError(
f"Invalid profile option {opt} passed into --profile-options. Expected options: {SUPPORT_PROFILE_LIST}."
)
return profile_options_list


@@ -305,9 +417,9 @@ def _validate_profile_options(profile_options: str):
"--profile", action="store_true", help="Run the profiler around the function"
)
parser.add_argument(
"--profile-options",
"--disable-profile-options",
type=_validate_profile_options,
help=f"Select which profile options to enable. Valid options: {SUPPORT_PROFILE_LIST}.",
help=f"Select which profile options to disable. Valid options: {SUPPORT_PROFILE_LIST}.",
)
parser.add_argument("--amp", action="store_true", help="enable torch.autocast()")
parser.add_argument(
@@ -317,8 +429,9 @@ def _validate_profile_options(profile_options: str):
)
parser.add_argument(
"--no-profile-detailed",
action="store_false",
help=f"Only profile GPU kernels, excluding {SUPPORT_PROFILE_LIST}. Overrides --profile-options.",
action="store_true",
help=f"Only profile GPU kernels, excluding {SUPPORT_PROFILE_LIST}. "
"To only disable some profile options, use --disable-profile-options instead.",
)
parser.add_argument(
"--profile-export-chrome-trace",
@@ -359,14 +472,23 @@ def _validate_profile_options(profile_options: str):
"--metrics",
type=str,
default="cpu_peak_mem,gpu_peak_mem,ttfb",
help="Specify metrics [cpu_peak_mem,gpu_peak_mem,ttfb,flops,model_flops]to be collected. You can also set `none` to disable all metrics. The metrics are separated by comma such as cpu_peak_mem,gpu_peak_mem.",
help="Specify metrics [cpu_peak_mem,gpu_peak_mem,ttfb,flops,model_flops]to be collected. "
"You can also set `none` to disable all metrics. The metrics are separated by comma such as cpu_peak_mem,gpu_peak_mem.",
)
parser.add_argument(
"--metrics-gpu-backend",
choices=["dcgm", "default"],
default="default",
help="""Specify the backend [dcgm, default] to collect metrics. \nIn default mode, the latency(execution time) is collected by time.time_ns() and it is always enabled. Optionally,
\n - you can specify cpu peak memory usage by --metrics cpu_peak_mem, and it is collected by psutil.Process(). \n - you can specify gpu peak memory usage by --metrics gpu_peak_mem, and it is collected by nvml library.\n - you can specify flops by --metrics flops, and it is collected by fvcore.\nIn dcgm mode, the latency(execution time) is collected by time.time_ns() and it is always enabled. Optionally,\n - you can specify cpu peak memory usage by --metrics cpu_peak_mem, and it is collected by psutil.Process().\n - you can specify cpu and gpu peak memory usage by --metrics cpu_peak_mem,gpu_peak_mem, and they are collected by dcgm library.""",
help="""
Specify the backend [dcgm, default] to collect metrics.
In default mode, the latency(execution time) is collected by time.time_ns() and it is always enabled.
Optionally, - you can specify cpu peak memory usage by --metrics cpu_peak_mem, and it is collected by psutil.Process().
- you can specify gpu peak memory usage by --metrics gpu_peak_mem, and it is collected by nvml library.
- you can specify flops by --metrics flops, and it is collected by fvcore.
In dcgm mode, the latency(execution time) is collected by time.time_ns() and it is always enabled.
Optionally,
- you can specify cpu peak memory usage by --metrics cpu_peak_mem, and it is collected by psutil.Process().
- you can specify cpu and gpu peak memory usage by --metrics cpu_peak_mem,gpu_peak_mem, and they are collected by dcgm library.""",
)
parser.add_argument(
"--channels-last", action="store_true", help="enable torch.channels_last()"
@@ -464,10 +586,7 @@ def _validate_profile_options(profile_options: str):
else:
export_metrics_file = None
if args.profile:
profile_one_step(
test,
model=m
)
profile_one_step(test, model=m)
elif args.cudastreams:
run_one_step_with_cudastreams(test, 10)
else:
