From 9d829bab61db34b917d8b2947aabcc72b153eea2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 4 Oct 2023 01:12:04 +0000 Subject: [PATCH] address review feedback --- .../tools/transformers/benchmark_helper.py | 20 +- .../models/stable_diffusion/README.md | 2 +- .../models/stable_diffusion/benchmark.py | 351 ++++++------------ .../stable_diffusion/diffusion_models.py | 3 + 4 files changed, 129 insertions(+), 247 deletions(-) diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 67d3c95922a87..4f898245d01bd 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -542,7 +542,7 @@ def measure_gpu_usage(self): while True: for i in range(device_count): max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i)) - time.sleep(0.005) # 2ms + time.sleep(0.005) # 5ms if not self.keep_measuring: break return [ @@ -555,7 +555,7 @@ def measure_gpu_usage(self): ] -def measure_memory(is_gpu, func, monitor_type="cuda"): +def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None): memory_monitor_type = None if monitor_type == "rocm": memory_monitor_type = RocmMemoryMonitor @@ -565,10 +565,16 @@ def measure_memory(is_gpu, func, monitor_type="cuda"): monitor = memory_monitor_type(False) if is_gpu: - memory_before_test = monitor.measure_gpu_usage() + if start_memory is not None: + memory_before_test = start_memory + else: + memory_before_test = monitor.measure_gpu_usage() if memory_before_test is None: return None + if func is None: + return memory_before_test + with ThreadPoolExecutor() as executor: monitor = memory_monitor_type() mem_thread = executor.submit(monitor.measure_gpu_usage) @@ -595,7 +601,13 @@ def measure_memory(is_gpu, func, monitor_type="cuda"): return None # CPU memory - memory_before_test = monitor.measure_cpu_usage() + if start_memory is not None: + memory_before_test = start_memory + else: + memory_before_test = monitor.measure_cpu_usage() + + if func is None: + return memory_before_test with ThreadPoolExecutor() as executor: monitor = MemoryMonitor() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 28690839b82af..1fbd5092a719a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -79,7 +79,7 @@ If you use CUDA 12.*, you will need build onnxruntime-gpu from source. ``` conda create -n py38 python=3.8 conda activate py38 -pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118 +pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com pip install -r requirements-cuda.txt ``` diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 4ed4ca4fbc357..f8fda13a35b93 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -15,6 +15,7 @@ # import torch before onnxruntime so that onnxruntime uses the cuDNN in the torch package. 
import torch +from benchmark_helper import measure_memory SD_MODELS = { "1.5": "runwayml/stable-diffusion-v1-5", @@ -50,136 +51,8 @@ def example_prompts(): return prompts, negative_prompt -class CudaMemoryMonitor: - def __init__(self, keep_measuring=True): - self.keep_measuring = keep_measuring - - def measure_gpu_usage(self): - from py3nvml.py3nvml import ( - NVMLError, - nvmlDeviceGetCount, - nvmlDeviceGetHandleByIndex, - nvmlDeviceGetMemoryInfo, - nvmlDeviceGetName, - nvmlInit, - nvmlShutdown, - ) - - max_gpu_usage = [] - gpu_name = [] - try: - nvmlInit() - device_count = nvmlDeviceGetCount() - if not isinstance(device_count, int): - print(f"nvmlDeviceGetCount result is not integer: {device_count}") - return None - - max_gpu_usage = [0 for i in range(device_count)] - gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)] - while True: - for i in range(device_count): - info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)) - if isinstance(info, str): - print(f"nvmlDeviceGetMemoryInfo returns str: {info}") - return None - max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2) - time.sleep(0.002) # 2ms - if not self.keep_measuring: - break - nvmlShutdown() - return [ - { - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i], - } - for i in range(device_count) - ] - except NVMLError as error: - print("Error fetching GPU information using nvml: %s", error) - return None - - -class RocmMemoryMonitor: - def __init__(self, keep_measuring=True): - self.keep_measuring = keep_measuring - rocm_smi_path = "/opt/rocm/libexec/rocm_smi" - if os.path.exists(rocm_smi_path): - if rocm_smi_path not in sys.path: - sys.path.append(rocm_smi_path) - try: - import rocm_smi - - self.rocm_smi = rocm_smi - self.rocm_smi.initializeRsmi() - except ImportError: - self.rocm_smi = None - - def get_used_memory(self, dev): - if self.rocm_smi is None: - return -1 - return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024 - - def measure_gpu_usage(self): - device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0 - max_gpu_usage = [0 for i in range(device_count)] - gpu_name = [f"GPU{i}" for i in range(device_count)] - while True: - for i in range(device_count): - max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i)) - time.sleep(0.002) # 2ms - if not self.keep_measuring: - break - return [ - { - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i], - } - for i in range(device_count) - ] - - def measure_gpu_memory(monitor_type, func, start_memory=None): - if monitor_type is None: - return None - - monitor = monitor_type(False) - memory_before_test = monitor.measure_gpu_usage() - - if start_memory is None: - start_memory = memory_before_test - if start_memory is None: - return None - if func is None: - return start_memory - - from concurrent.futures import ThreadPoolExecutor - - with ThreadPoolExecutor() as executor: - monitor = monitor_type() - mem_thread = executor.submit(monitor.measure_gpu_usage) - try: - fn_thread = executor.submit(func) - _ = fn_thread.result() - finally: - monitor.keep_measuring = False - max_usage = mem_thread.result() - - if max_usage is None: - return None - - print(f"GPU memory usage: before={memory_before_test} peak={max_usage}") - if len(start_memory) >= 1 and len(max_usage) >= 1 and len(start_memory) == len(max_usage): - # When there are multiple GPUs, we will check the one with maximum usage. 
-        max_used = 0
-        for i, memory_before in enumerate(start_memory):
-            before = memory_before["max_used_MB"]
-            after = max_usage[i]["max_used_MB"]
-            used = after - before
-            max_used = max(max_used, used)
-        return max_used
-    return None
+    return measure_memory(is_gpu=True, func=func, monitor_type=monitor_type, start_memory=start_memory)
 
 
 def get_ort_pipeline(model_name: str, directory: str, provider, disable_safety_checker: bool):
@@ -285,7 +158,6 @@ def warmup():
             num_inference_steps=steps,
             negative_prompt=[negative_prompt] * batch_size,
             guidance_scale=7.5,
-            # num_images_per_prompt=batch_size,
         ).images
         inference_end = time.time()
         latency = inference_end - inference_start
@@ -352,7 +224,6 @@ def warmup():
             num_inference_steps=steps,
             guidance_scale=7.5,
             negative_prompt=[negative_prompt] * batch_size,
-            # num_images_per_prompt=batch_size,
             generator=None,  # torch.Generator
         ).images
 
@@ -815,10 +686,10 @@ def run_tensorrt_static(
 
     # Load TensorRT engines and pytorch modules
     pipeline.backend.load_engines(
-        engine_dir,
-        framework_model_dir,
-        onnx_dir,
-        17,
+        engine_dir=engine_dir,
+        framework_model_dir=framework_model_dir,
+        onnx_dir=onnx_dir,
+        onnx_opset=17,
         opt_batch_size=batch_size,
         opt_image_height=height,
         opt_image_width=width,
@@ -967,10 +838,10 @@ def init_pipeline(pipeline_class, pipeline_info):
     )
 
     pipeline.backend.load_engines(
-        engine_dir,
-        framework_model_dir,
-        onnx_dir,
-        17,
+        engine_dir=engine_dir,
+        framework_model_dir=framework_model_dir,
+        onnx_dir=onnx_dir,
+        onnx_opset=17,
         opt_batch_size=batch_size,
         opt_image_height=height,
         opt_image_width=width,
@@ -1526,11 +1397,7 @@ def main():
 
     coloredlogs.install(fmt="%(funcName)20s: %(message)s")
 
-    memory_monitor_type = None
-    if args.provider in ["cuda", "tensorrt"]:
-        memory_monitor_type = CudaMemoryMonitor
-    elif args.provider == "rocm":
-        memory_monitor_type = RocmMemoryMonitor
+    memory_monitor_type = "rocm" if args.provider == "rocm" else "cuda"
 
     start_memory = measure_gpu_memory(memory_monitor_type, None)
     print("GPU memory used before loading models:", start_memory)
@@ -1541,18 +1408,18 @@ def main():
         if "xl" in args.version:
             print("Testing Txt2ImgXLPipeline with static input shape. Backend is ORT TensorRT EP.")
             result = run_ort_trt_xl(
-                args.work_dir,
-                args.version,
-                args.batch_size,
-                True,
-                args.height,
-                args.width,
-                args.steps,
-                args.num_prompts,
-                args.batch_count,
-                start_memory,
-                memory_monitor_type,
-                args.max_trt_batch_size,
+                work_dir=args.work_dir,
+                version=args.version,
+                batch_size=args.batch_size,
+                disable_safety_checker=True,
+                height=args.height,
+                width=args.width,
+                steps=args.steps,
+                num_prompts=args.num_prompts,
+                batch_count=args.batch_count,
+                start_memory=start_memory,
+                memory_monitor_type=memory_monitor_type,
+                max_batch_size=args.max_trt_batch_size,
                 nvtx_profile=False,
                 use_cuda_graph=args.enable_cuda_graph,
             )
@@ -1563,34 +1430,34 @@ def main():
                 )
             )
             result = run_ort_trt(
-                args.version,
-                args.batch_size,
-                not args.enable_safety_checker,
-                args.height,
-                args.width,
-                args.steps,
-                args.num_prompts,
-                args.batch_count,
-                start_memory,
-                memory_monitor_type,
-                args.max_trt_batch_size,
-                args.enable_cuda_graph,
+                version=args.version,
+                batch_size=args.batch_size,
+                disable_safety_checker=not args.enable_safety_checker,
+                height=args.height,
+                width=args.width,
+                steps=args.steps,
+                num_prompts=args.num_prompts,
+                batch_count=args.batch_count,
+                start_memory=start_memory,
+                memory_monitor_type=memory_monitor_type,
+                max_batch_size=args.max_trt_batch_size,
+                enable_cuda_graph=args.enable_cuda_graph,
             )
         else:
             print("Testing Txt2ImgPipeline with static input shape. Backend is ORT TensorRT EP.")
             result = run_ort_trt_static(
-                args.work_dir,
-                args.version,
-                args.batch_size,
-                not args.enable_safety_checker,
-                args.height,
-                args.width,
-                args.steps,
-                args.num_prompts,
-                args.batch_count,
-                start_memory,
-                memory_monitor_type,
-                args.max_trt_batch_size,
+                work_dir=args.work_dir,
+                version=args.version,
+                batch_size=args.batch_size,
+                disable_safety_checker=not args.enable_safety_checker,
+                height=args.height,
+                width=args.width,
+                steps=args.steps,
+                num_prompts=args.num_prompts,
+                batch_count=args.batch_count,
+                start_memory=start_memory,
+                memory_monitor_type=memory_monitor_type,
+                max_batch_size=args.max_trt_batch_size,
                 nvtx_profile=False,
                 use_cuda_graph=args.enable_cuda_graph,
             )
@@ -1602,18 +1469,18 @@ def main():
             )
         )
         result = export_and_run_ort(
-            args.version,
-            provider,
-            args.batch_size,
-            not args.enable_safety_checker,
-            args.height,
-            args.width,
-            args.steps,
-            args.num_prompts,
-            args.batch_count,
-            start_memory,
-            memory_monitor_type,
-            args.enable_cuda_graph,
+            version=args.version,
+            provider=provider,
+            batch_size=args.batch_size,
+            disable_safety_checker=not args.enable_safety_checker,
+            height=args.height,
+            width=args.width,
+            steps=args.steps,
+            num_prompts=args.num_prompts,
+            batch_count=args.batch_count,
+            start_memory=start_memory,
+            memory_monitor_type=memory_monitor_type,
+            enable_cuda_graph=args.enable_cuda_graph,
         )
     elif args.engine == "onnxruntime":
         assert args.pipeline and os.path.isdir(
@@ -1621,54 +1488,54 @@ def main():
             args.pipeline
         ), "--pipeline should be specified for the directory of ONNX models"
         print(f"Testing diffusers StableDiffusionPipeline with {provider} provider and tuning={args.tuning}")
         result = run_ort(
-            sd_model,
-            args.pipeline,
-            provider,
-            args.batch_size,
-            not args.enable_safety_checker,
-            args.height,
-            args.width,
-            args.steps,
-            args.num_prompts,
-            args.batch_count,
-            start_memory,
-            memory_monitor_type,
-            args.tuning,
+            model_name=sd_model,
+            directory=args.pipeline,
+            provider=provider,
+            batch_size=args.batch_size,
+            disable_safety_checker=not args.enable_safety_checker,
+            height=args.height,
+            width=args.width,
+            steps=args.steps,
+            num_prompts=args.num_prompts,
+            batch_count=args.batch_count,
+            start_memory=start_memory,
+            memory_monitor_type=memory_monitor_type,
+            tuning=args.tuning,
         )
     elif args.engine == "tensorrt" and "xl" in args.version:
         print("Testing Txt2ImgXLPipeline with static input shape. Backend is TensorRT.")
         result = run_tensorrt_static_xl(
-            args.work_dir,
-            args.version,
-            args.batch_size,
-            True,
-            args.height,
-            args.width,
-            args.steps,
-            args.num_prompts,
-            args.batch_count,
-            start_memory,
-            memory_monitor_type,
-            args.max_trt_batch_size,
+            work_dir=args.work_dir,
+            version=args.version,
+            batch_size=args.batch_size,
+            disable_safety_checker=True,
+            height=args.height,
+            width=args.width,
+            steps=args.steps,
+            num_prompts=args.num_prompts,
+            batch_count=args.batch_count,
+            start_memory=start_memory,
+            memory_monitor_type=memory_monitor_type,
+            max_batch_size=args.max_trt_batch_size,
            nvtx_profile=False,
             use_cuda_graph=args.enable_cuda_graph,
         )
     elif args.engine == "tensorrt":
         print("Testing Txt2ImgPipeline with static input shape. Backend is TensorRT.")
         result = run_tensorrt_static(
-            args.work_dir,
-            args.version,
-            sd_model,
-            args.batch_size,
-            True,
-            args.height,
-            args.width,
-            args.steps,
-            args.num_prompts,
-            args.batch_count,
-            start_memory,
-            memory_monitor_type,
-            args.max_trt_batch_size,
+            work_dir=args.work_dir,
+            version=args.version,
+            model_name=sd_model,
+            batch_size=args.batch_size,
+            disable_safety_checker=True,
+            height=args.height,
+            width=args.width,
+            steps=args.steps,
+            num_prompts=args.num_prompts,
+            batch_count=args.batch_count,
+            start_memory=start_memory,
+            memory_monitor_type=memory_monitor_type,
+            max_batch_size=args.max_trt_batch_size,
             nvtx_profile=False,
             use_cuda_graph=args.enable_cuda_graph,
         )
@@ -1677,18 +1544,18 @@ def main():
             f"Testing Txt2ImgPipeline with dynamic input shape. Backend is PyTorch: compile={args.enable_torch_compile}, xformers={args.use_xformers}."
         )
         result = run_torch(
-            sd_model,
-            args.batch_size,
-            not args.enable_safety_checker,
-            args.enable_torch_compile,
-            args.use_xformers,
-            args.height,
-            args.width,
-            args.steps,
-            args.num_prompts,
-            args.batch_count,
-            start_memory,
-            memory_monitor_type,
+            model_name=sd_model,
+            batch_size=args.batch_size,
+            disable_safety_checker=not args.enable_safety_checker,
+            enable_torch_compile=args.enable_torch_compile,
+            use_xformers=args.use_xformers,
+            height=args.height,
+            width=args.width,
+            steps=args.steps,
+            num_prompts=args.num_prompts,
+            batch_count=args.batch_count,
+            start_memory=start_memory,
+            memory_monitor_type=memory_monitor_type,
         )
 
     print(result)
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py
index b8e6b1b41931f..951cd66005f4c 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py
@@ -435,6 +435,9 @@ def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path)
             for j in range(len(graph.node[i].output)):
                 if graph.node[i].output[j] == node_output_name:
                     found = True
+                    break
+            if found:
+                break
         if not found:
             raise RuntimeError("Failed to find hidden_states graph output in clip")
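Note (not part of the patch): a minimal usage sketch of the reworked helper, mirroring what main() in benchmark.py now does with it. run_inference is a hypothetical stand-in for a pipeline call, and the sketch assumes onnxruntime/python/tools/transformers is on PYTHONPATH and a CUDA GPU is present.

    from benchmark_helper import measure_memory

    # With func=None the helper now just samples current usage, so a baseline
    # (a list of per-device dicts with "max_used_MB") can be captured once,
    # before any model is loaded.
    start_memory = measure_memory(is_gpu=True, func=None, monitor_type="cuda")

    def run_inference():
        pass  # hypothetical stand-in: load the pipeline and generate images here

    # Passing start_memory reuses that baseline instead of re-sampling it, so
    # already-loaded model weights are not folded into the "before" number;
    # the helper's existing logic then reports peak usage relative to it.
    memory_usage = measure_memory(
        is_gpu=True, func=run_inference, monitor_type="cuda", start_memory=start_memory
    )
    print("GPU memory used during inference:", memory_usage)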