StableDiffusion XL with TensorRT EP (#17748)
Accelerate Stable Diffusion XL with the TensorRT EP. The pipeline is adapted from the
TensorRT demo diffusion, and we updated the design so that it works with
different backend engines.

The following results are from an A100 80GB GPU, generating 1024x1024 images with
30 steps of the Base model, or 30 steps of Base plus 30 steps of Refiner. The engines
are built with static input shapes, and CUDA graph is enabled.

| Pipeline | Batch Size | TRT Latency (ms) | ORT_TRT Latency (ms) | Diff |
| -- | -- | -- | -- | -- |
| Base | 1 | 2714 | 2679 | -1.3% |
| Base & Refiner | 1 | 3593 | 3530 | -1.8% |
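
For reference, the Diff column is read as the relative change of ORT_TRT latency versus the TensorRT demo; a minimal check, assuming that definition:

```
# Relative latency difference of ORT_TRT vs. the TensorRT demo (assumed definition of Diff).
def relative_diff(trt_ms: float, ort_trt_ms: float) -> float:
    return (ort_trt_ms - trt_ms) / trt_ms * 100

print(f"Base: {relative_diff(2714, 2679):+.1f}%")            # -1.3%
print(f"Base & Refiner: {relative_diff(3593, 3530):+.1f}%")  # -1.8%
```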

Test environment: onnxruntime-gpu is built from source, and the following packages and
libraries are used in this test:
* tensorrt==8.6.1.post1
* torch==2.2.0.dev20230920+cu121
* transformers==4.31.0
* diffusers==0.19.3
* onnx==1.14.1
* onnx-graphsurgeon==0.3.27
* polygraphy==0.47.1
* protobuf==3.20.2
* onnxruntime-gpu==1.17.0 (built from source of main branch)
* CUDA 12.2.2
* cuDNN 8.9.5.29
* python 3.10.13
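
A small script to confirm the versions above against what is actually installed (a sketch; it only assumes the listed packages are importable in the current environment):

```
# Print the versions of the packages listed above to confirm the test environment.
import diffusers, onnx, onnxruntime, tensorrt, torch, transformers

for module in (tensorrt, torch, transformers, diffusers, onnx, onnxruntime):
    print(f"{module.__name__}=={module.__version__}")
```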
tianleiwu authored Oct 4, 2023
1 parent 8e6019a commit a05580e
Showing 26 changed files with 5,090 additions and 1,292 deletions.
20 changes: 16 additions & 4 deletions onnxruntime/python/tools/transformers/benchmark_helper.py
@@ -542,7 +542,7 @@ def measure_gpu_usage(self):
        while True:
            for i in range(device_count):
                max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i))
-           time.sleep(0.005)  # 2ms
+           time.sleep(0.005)  # 5ms
            if not self.keep_measuring:
                break
        return [
@@ -555,7 +555,7 @@ def measure_gpu_usage(self):
]


-def measure_memory(is_gpu, func, monitor_type="cuda"):
+def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None):
memory_monitor_type = None
if monitor_type == "rocm":
memory_monitor_type = RocmMemoryMonitor
@@ -565,10 +565,16 @@ def measure_memory(is_gpu, func, monitor_type="cuda"):
    monitor = memory_monitor_type(False)

    if is_gpu:
-       memory_before_test = monitor.measure_gpu_usage()
+       if start_memory is not None:
+           memory_before_test = start_memory
+       else:
+           memory_before_test = monitor.measure_gpu_usage()
        if memory_before_test is None:
            return None

+       if func is None:
+           return memory_before_test
+
        with ThreadPoolExecutor() as executor:
            monitor = memory_monitor_type()
            mem_thread = executor.submit(monitor.measure_gpu_usage)
@@ -595,7 +601,13 @@ def measure_memory(is_gpu, func, monitor_type="cuda"):
return None

    # CPU memory
-   memory_before_test = monitor.measure_cpu_usage()
+   if start_memory is not None:
+       memory_before_test = start_memory
+   else:
+       memory_before_test = monitor.measure_cpu_usage()
+
+   if func is None:
+       return memory_before_test

    with ThreadPoolExecutor() as executor:
        monitor = MemoryMonitor()
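
The new start_memory argument lets a caller take the baseline reading once and reuse it across runs, and passing func=None returns just that baseline. A minimal usage sketch (import path assumed, e.g. when running from onnxruntime/python/tools/transformers):

```
# Hypothetical usage of the updated helper: capture a GPU-memory baseline once,
# then reuse it for several measured runs.
from benchmark_helper import measure_memory

# func=None returns only the baseline measurement.
baseline = measure_memory(is_gpu=True, func=None, monitor_type="cuda")

def run_pipeline():
    ...  # e.g. one image-generation run

# Reuse the same baseline instead of re-measuring before every run.
usage = measure_memory(is_gpu=True, func=run_pipeline, monitor_type="cuda", start_memory=baseline)
print("memory measurement:", usage)
```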
1 change: 1 addition & 0 deletions onnxruntime/python/tools/transformers/io_binding_helper.py
@@ -283,6 +283,7 @@ def infer(self, feed_dict: Dict[str, torch.Tensor]):
            if name in self.input_names:
                if self.enable_cuda_graph:
                    assert self.input_tensors[name].nelement() == tensor.nelement()
+                   assert self.input_tensors[name].dtype == tensor.dtype
                    assert tensor.device.type == "cuda"
                    # Please install cuda-python package with a version corresponding to CUDA in your machine.
                    from cuda import cudart
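
The extra assert exists because, with CUDA graph enabled, inference copies each feed into pre-allocated input buffers that the captured graph replays, so a later feed must match those buffers in element count and dtype, not just in name. An illustrative sketch of the invariant (not code from this commit; the tensor name, shape, and dtype are made up):

```
# Illustration only: with CUDA graphs the captured input buffers are fixed, so
# feeds must match their element count and dtype before an in-place copy.
import torch

# Pre-allocated buffer, as it would exist after graph capture.
input_tensors = {"sample": torch.zeros(2, 4, 128, 128, dtype=torch.float16, device="cuda")}

def bind_inputs(feed_dict):
    for name, tensor in feed_dict.items():
        buffer = input_tensors[name]
        assert buffer.nelement() == tensor.nelement()  # same element count as the captured buffer
        assert buffer.dtype == tensor.dtype  # a mismatched dtype would be silently cast by copy_
        assert tensor.device.type == "cuda"
        buffer.copy_(tensor)  # in-place copy keeps the device address the graph recorded

bind_inputs({"sample": torch.rand(2, 4, 128, 128, dtype=torch.float16, device="cuda")})
```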
@@ -74,15 +74,16 @@ Below is an example to optimize Stable Diffusion 1.5 in Linux. For Windows OS, p

### Setup Environment (CUDA)

-It is recommended to create a Conda environment with Python 3.8, 3.9 or 3.10, and run the model with [CUDA 11.7](https://developer.nvidia.com/cuda-11-7-0-download-archive) or 11.8.
+It is recommended to create a Conda environment with Python 3.8, 3.9 or 3.10, and run the model with CUDA 11.8.
+If you use CUDA 12.*, you will need to build onnxruntime-gpu from source.
```
conda create -n py38 python=3.8
conda activate py38
-pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
+pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
pip install -r requirements-cuda.txt
```

-ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. CUDA 11.7 and cuDNN 8.5 are used in our tests.
+ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. CUDA 11.8 and cuDNN 8.5 or above are recommended.
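
After installation, a quick way to confirm that the GPU build is picked up (a sketch; it assumes onnxruntime-gpu and the nightly torch wheel from above are installed):

```
# Verify that torch sees CUDA 11.8 and that onnxruntime exposes the CUDA EP.
import onnxruntime as ort
import torch

print(torch.__version__, torch.version.cuda)   # expect a cu118 build
print(ort.get_available_providers())           # expect "CUDAExecutionProvider" in this list
```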

#### Install Nightly (Optional)

@@ -233,18 +234,21 @@ Sometime, it complains ptxas not found when there are multiple CUDA versions ins
Note that torch.compile is not supported in Windows: we encountered error `Windows not yet supported for torch.compile`. So it is excluded from RTX 3060 results of Windows.


-### Run Benchmark with TensorRT and TensorRT execution provider
+### Run Benchmark with TensorRT or TensorRT execution provider

For TensorRT installation, follow https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html.

```
pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
-pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
+pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
pip install -r requirements-tensorrt.txt
export CUDA_MODULE_LOADING=LAZY
python benchmark.py -e tensorrt -b 1 -v 1.5
python benchmark.py -e onnxruntime -r tensorrt -b 1 -v 1.5
python benchmark.py -e onnxruntime -r tensorrt -b 1 -v 1.5 --enable_cuda_graph
+python benchmark.py -e tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 --enable_cuda_graph
+python benchmark.py -e onnxruntime -r tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 --enable_cuda_graph
```
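
Before running the TensorRT benchmarks, it can help to confirm that both the TensorRT Python bindings and the TensorRT execution provider are visible (a sketch assuming the packages above are installed):

```
# Confirm TensorRT and the TensorRT execution provider are available.
import onnxruntime as ort
import tensorrt as trt

print("TensorRT:", trt.__version__)
assert "TensorrtExecutionProvider" in ort.get_available_providers()
```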

### Example Benchmark output