Add CUDA EP in StableDiffusion demo (#17788)
Add the CUDA execution provider (EP) to the Stable Diffusion demo.

### A100 Performance
Test | Engine Property | Batch Size | TRT Latency (ms) | ORT_TRT Latency (ms) | ORT_CUDA Latency (ms) | TORCH Latency (ms)
-- | -- | -- | -- | -- | -- | --
SD 1.5, 50 steps, 512x512 | Static Input Shape | 1 | 861 | 851 | 861 | N/A
SD 1.5, 50 steps, 512x512 | Dynamic Input Shape, Optimized for batch size 1 and image size 512x512 | 1 | 974 | 1079 | 928 | 1222
SD 1.5, 50 steps, 768x768 | Dynamic Input Shape, Optimized for batch size 1 and image size 512x512 | 1 | 2492 | OOM | 1901 | 1971
SD 1.5, 50 steps, 768x768 | Dynamic Input Shape, Optimized for batch size 1 and image size 512x512 | 4 | 9091 | OOM | 6785 | 6700

We can see that ORT_CUDA is the most robust engine for handling dynamic input shapes. PyTorch can be a good choice when running larger batch sizes.

The above results are from one A100-SXM4-80GB GPU (in a Standard_ND96amsr_A100_v4 Azure VM), using 50 steps to generate 512x512 or 768x768 images with Stable Diffusion 1.5. onnxruntime-gpu was built from source, and the following packages or libraries were used in this test:
* tensorrt==8.6.1.post1
* torch==2.2.0.dev20230920+cu121
* transformers==4.31.0
* diffusers==0.19.3
* onnx==1.14.1
* onnx-graphsurgeon==0.3.27
* polygraphy==0.47.1
* protobuf==3.20.2
* onnxruntime-gpu==1.17.0 (built from source of main branch)
* CUDA 12.2.2
* cuDNN 8.9.5.29
* python 3.10.13
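The environment can be double-checked before running the demo; a quick check like the following (versions and providers will vary with your build) confirms that both GPU execution providers are visible:
```
import onnxruntime as ort
import torch

# Quick environment check before running the demo (output varies with the build).
print("onnxruntime:", ort.__version__)
print("providers:", ort.get_available_providers())  # expect CUDAExecutionProvider and TensorrtExecutionProvider
print("torch:", torch.__version__, "cuda:", torch.version.cuda)
```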

For static input shape, the engine is built with a static batch size and a static image shape, and CUDA graph is enabled.
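For reference, enabling CUDA graph on the CUDA execution provider outside of this demo follows the pattern below. This is a minimal sketch, not code from this change; the model path, tensor names, and shapes are illustrative. CUDA graph capture needs static shapes and fixed device buffers, which is why it is paired with static input shape here.
```
import numpy as np
import onnxruntime as ort

# Minimal sketch (not from this PR): CUDA graph on the CUDA execution provider.
# "model.onnx", tensor names, and shapes are illustrative placeholders.
providers = [("CUDAExecutionProvider", {"enable_cuda_graph": "1"}), "CPUExecutionProvider"]
session = ort.InferenceSession("model.onnx", providers=providers)

# CUDA graph capture/replay requires fixed shapes and fixed device addresses,
# so inputs and outputs are bound to preallocated GPU buffers via IOBinding.
x = ort.OrtValue.ortvalue_from_numpy(np.zeros((1, 4, 64, 64), np.float16), "cuda", 0)
y = ort.OrtValue.ortvalue_from_numpy(np.zeros((1, 4, 64, 64), np.float16), "cuda", 0)
binding = session.io_binding()
binding.bind_ortvalue_input("input", x)
binding.bind_ortvalue_output("output", y)

session.run_with_iobinding(binding)  # initial run captures the CUDA graph
session.run_with_iobinding(binding)  # later runs replay the captured graph
```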

For dynamic input shape, the engine is built to support dynamic batch size and dynamic image shape, and CUDA graph is disabled. The TensorRT engine is built for batch sizes 1 to 4 and image sizes 256x256 to 1024x1024, with 512x512 as the optimized image size.
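The dynamic-shape TensorRT engine corresponds to an optimization profile with min/opt/max ranges. The snippet below is a rough sketch of such a profile using the TensorRT Python API directly (the demo uses its own builder); the ONNX path, tensor name, and latent dimensions are illustrative.
```
import tensorrt as trt

# Rough sketch of a dynamic-shape optimization profile (not the demo's builder code).
# Latent dims are image size / 8: 256x256 -> 32x32, 512x512 -> 64x64, 1024x1024 -> 128x128.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)
with open("model.onnx", "rb") as f:  # illustrative ONNX model with a dynamic "sample" input
    parser.parse(f.read())

config = builder.create_builder_config()
profile = builder.create_optimization_profile()
profile.set_shape("sample", min=(1, 4, 32, 32), opt=(1, 4, 64, 64), max=(4, 4, 128, 128))
config.add_optimization_profile(profile)
engine_bytes = builder.build_serialized_network(network, config)
```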

The scripts used to test static and dynamic input shapes are like the following:
```
prompt="a cute magical flying dog, fantasy art drawn by disney concept artists, highly detailed, digital paintining"
for e in TRT ORT_TRT ORT_CUDA
do
  # static input shape (CUDA graph enabled)
  python demo_txt2img.py --engine $e "$prompt"
  # dynamic input shape, 512x512 (CUDA graph disabled)
  python demo_txt2img.py --engine $e --disable-cuda-graph --build-dynamic-batch --build-dynamic-shape "$prompt"
  # dynamic input shape, 768x768 (CUDA graph disabled)
  python demo_txt2img.py --engine $e --disable-cuda-graph --build-dynamic-batch --build-dynamic-shape --height 768 --width 768 "$prompt"
done
```

The PyTorch performance numbers are from commands like the following:
```
python benchmark.py -e torch -v 1.5 --enable_torch_compile -b 1 --height 512 --width 512
python benchmark.py -e torch -v 1.5 --enable_torch_compile -b 1 --height 768 --width 768
python benchmark.py -e torch -v 1.5 --enable_torch_compile -b 4 --height 768 --width 768
```
tianleiwu authored Oct 5, 2023
1 parent db3901a commit d6dad96
Showing 13 changed files with 341 additions and 53 deletions.
@@ -72,38 +72,52 @@ cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion

Below is an example to optimize Stable Diffusion 1.5 in Linux. For Windows OS, please change the path format to `.\sd` instead of `./sd`.

It is recommended to create a Conda environment with Python 3.10 for the following setup:
```
conda create -n py310 python=3.10
conda activate py310
```

### Setup Environment (CUDA)

It is recommended to create a Conda environment with Python 3.8, 3.9 or 3.10, and run the model with CUDA 11.8.
If you use CUDA 12.*, you will need to build onnxruntime-gpu from source.
First, we need to install CUDA 11.8 or 12.1, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html) 8.5 or above, and [TensorRT 8.6.1](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) on the machine.

#### CUDA 11.8:

In the Conda environment, install PyTorch 2.1 or above, and other required packages like the following:
```
conda create -n py38 python=3.8
conda activate py38
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
pip install torch --index-url https://download.pytorch.org/whl/nightly/cu118
pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
pip install -r requirements-cuda.txt
pip install -r requirements-cuda11.txt
```
ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. CUDA 11.8 and cuDNN 8.5 or above are recommended.

#### Install Nightly (Optional)
We cannot directly `pip install tensorrt` for CUDA 11. Follow https://github.com/NVIDIA/TensorRT/issues/2773 to install TensorRT for CUDA 11 in Linux. For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead.

Skip this step if you use onnxruntime-gpu package from official releases.
#### CUDA 12.*:
The official package of onnxruntime-gpu 1.16.* is built for CUDA 11.8. To use CUDA 12.*, you will need to [build onnxruntime from source](https://onnxruntime.ai/docs/build/inferencing.html).

To try the latest optimizations, you can install the [ort-nightly-gpu](https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/ort-nightly-gpu/) package like the following:
```
git clone --recursive https://github.com/Microsoft/onnxruntime.git
cd onnxruntime
pip install -r requirements-dev.txt
```
Follow the [example script for A100 in Ubuntu](https://github.com/microsoft/onnxruntime/blob/26a7b63716e3125bfe35fe3663ba10d2d7322628/build_release.sh)
or the [example script for RTX 4090 in Windows](https://github.com/microsoft/onnxruntime/blob/8df5f4e0df1f3b9ceeb0f1f2561b09727ace9b37/build_trt.cmd) to build and install the onnxruntime-gpu wheel.

Then install the other required Python packages like the following:
```
pip uninstall onnxruntime-gpu
pip install ort-nightly-gpu -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/
pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121
pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
pip install -r requirements-cuda12.txt
```
Finally, `pip install tensorrt` for Linux. For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead.

### Setup Environment (ROCm)

It is recommended that the users run the model with ROCm 5.4 or newer and Python 3.8, 3.9 or 3.10.
It is recommended that the users run the model with ROCm 5.4 or newer and Python 3.10.
Note that Windows is not supported for ROCm at the moment.

```
conda create -n py38 python=3.8
conda activate py38
wget https://repo.radeon.com/rocm/manylinux/rocm-rel-5.4/torch-1.12.1%2Brocm5.4-cp38-cp38-linux_x86_64.whl
pip install torch-1.12.1+rocm5.4-cp38-cp38-linux_x86_64.whl
pip install -r requirements-rocm.txt
```
@@ -34,12 +34,15 @@ class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatte
def parse_arguments(is_xl: bool, description: str):
    parser = argparse.ArgumentParser(description=description, formatter_class=RawTextArgumentDefaultsHelpFormatter)

    engines = ["ORT_TRT", "TRT"] if is_xl else ["ORT_CUDA", "ORT_TRT", "TRT"]

    parser.add_argument(
        "--engine",
        type=str,
        default="ORT_TRT",
        choices=["ORT_TRT", "TRT"],
        help="Backend engine. Default is OnnxRuntime CUDA execution provider.",
        default=engines[0],
        choices=engines,
        help="Backend engine in {engines}. "
        "ORT_CUDA is CUDA execution provider; ORT_TRT is Tensorrt execution provider; TRT is TensorRT",
    )

    supported_versions = PipelineInfo.supported_versions(is_xl)
@@ -106,7 +109,7 @@ def parse_arguments(is_xl: bool, description: str):
    parser.add_argument(
        "--onnx-opset",
        type=int,
        default=17,
        default=None,
        choices=range(14, 18),
        help="Select ONNX opset version to target for exported models.",
    )
@@ -163,6 +166,16 @@ def parse_arguments(is_xl: bool, description: str):

    args = parser.parse_args()

    if (
        args.engine in ["ORT_CUDA", "ORT_TRT"]
        and (args.force_onnx_export or args.force_onnx_optimize)
        and not args.force_engine_build
    ):
        raise ValueError(
            "For ORT_CUDA or ORT_TRT, --force_onnx_export and --force_onnx_optimize are not supported. "
            "Please use --force_engine_build instead."
        )

    # Validate image dimensions
    if args.height % 8 != 0 or args.width % 8 != 0:
        raise ValueError(
@@ -173,6 +186,9 @@ def parse_arguments(is_xl: bool, description: str):
        print("[I] CUDA Graph is disabled since dynamic input shape is configured.")
        args.disable_cuda_graph = True

    if args.onnx_opset is None:
        args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17

    print(args)

    return args
@@ -197,7 +213,7 @@ def repeat_prompt(args):

def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_size, batch_size):
    onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
        args.work_dir, pipeline_info, engine_type
        work_dir=args.work_dir, pipeline_info=pipeline_info, engine_type=engine_type
    )

    # Initialize demo
@@ -214,7 +230,24 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si
        engine_type=engine_type,
    )

    if engine_type == EngineType.ORT_TRT:
    if engine_type == EngineType.ORT_CUDA:
        # Build CUDA EP engines and load pytorch modules
        pipeline.backend.build_engines(
            engine_dir=engine_dir,
            framework_model_dir=framework_model_dir,
            onnx_dir=onnx_dir,
            onnx_opset=args.onnx_opset,
            opt_image_height=args.height,
            opt_image_width=args.height,
            opt_batch_size=batch_size,
            force_engine_rebuild=args.force_engine_build,
            device_id=torch.cuda.current_device(),
            disable_cuda_graph_models=[
                "clip2",  # TODO: Add ArgMax cuda kernel to enable cuda graph for clip2.
                "unetxl",
            ],
        )
    elif engine_type == EngineType.ORT_TRT:
        # Build TensorRT EP engines and load pytorch modules
        pipeline.backend.build_engines(
            engine_dir,
@@ -298,9 +298,16 @@ def get_input_profile(self, batch_size, image_height, image_width, static_batch,
    def get_shape_dict(self, batch_size, image_height, image_width):
        return None

    def fp32_input_output_names(self) -> List[str]:
        """For CUDA EP, we export ONNX model with FP32 first, then convert it to mixed precision model.
        This is a list of input or output names that are kept as float32 during converting.
        For the first version, we will use same data type as TensorRT.
        """
        return []

    def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True):
        optimizer = self.get_ort_optimizer()
        optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16)
        optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16, keep_io_types=self.fp32_input_output_names())

    def optimize_trt(self, input_onnx_path, optimized_onnx_path):
        onnx_graph = onnx.load(input_onnx_path)
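The `fp32_input_output_names` hook above feeds `keep_io_types` during the FP32-to-mixed-precision conversion described in its docstring. A rough sketch of that conversion using the float16 helper in onnxruntime.transformers follows; it assumes `keep_io_types` accepts a list of tensor names (as the demo's optimizer passes), and the paths and names are illustrative, not the demo's `OrtStableDiffusionOptimizer`.
```
import onnx
from onnxruntime.transformers.float16 import convert_float_to_float16

# Sketch of the FP32 export -> mixed precision flow (paths and tensor names are illustrative).
model = onnx.load("unet_fp32.onnx")
model_fp16 = convert_float_to_float16(model, keep_io_types=["sample", "timestep"])  # kept as float32
onnx.save(model_fp16, "unet_fp16.onnx")
```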
@@ -416,7 +423,7 @@ def get_sample_input(self, batch_size, image_height, image_width):
        self.check_dims(batch_size, image_height, image_width)
        return (torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device),)

    def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path):
    def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path, use_external_data_format=False):
        graph: GraphProto = model.graph
        hidden_layers = -1
        for i in range(len(graph.node)):
@@ -457,7 +464,29 @@ def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path)

        onnx_model = OnnxModel(model)
        onnx_model.add_node(cast_node)
        onnx_model.save_model_to_file(optimized_onnx_path)
        onnx_model.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format)

    def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True):
        optimizer = self.get_ort_optimizer()
        if not self.output_hidden_state:
            optimizer.optimize(
                input_onnx_path, optimized_onnx_path, to_fp16, keep_io_types=[], keep_outputs=["text_embeddings"]
            )
        else:
            with tempfile.TemporaryDirectory() as tmp_dir:
                # Save to a temporary file so that we can load it with Onnx Runtime.
                logger.info("Saving a temporary model to add hidden_states to graph output ...")
                tmp_model_path = os.path.join(tmp_dir, "model.onnx")

                model = onnx.load(input_onnx_path)
                self.add_hidden_states_graph_output(model, tmp_model_path, use_external_data_format=True)
                optimizer.optimize(
                    tmp_model_path,
                    optimized_onnx_path,
                    to_fp16,
                    keep_io_types=[],
                    keep_outputs=["text_embeddings", "hidden_states"],
                )

    def optimize_trt(self, input_onnx_path, optimized_onnx_path):
        onnx_graph = onnx.load(input_onnx_path)
@@ -598,6 +627,9 @@ def get_sample_input(self, batch_size, image_height, image_width):
            torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
        )

    def fp32_input_output_names(self) -> List[str]:
        return ["sample", "timestep"]


class UNetXL(BaseModel):
    def __init__(
@@ -703,6 +735,9 @@ def get_sample_input(self, batch_size, image_height, image_width):
            },
        )

    def fp32_input_output_names(self) -> List[str]:
        return ["sample", "timestep"]


# VAE Decoder
class VAE(BaseModel):
@@ -773,6 +808,9 @@ def get_sample_input(self, batch_size, image_height, image_width):
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        return (torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device),)

    def fp32_input_output_names(self) -> List[str]:
        return ["latent", "images"]


def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, hf_token, subfolder="tokenizer"):
    tokenizer_dir = os.path.join(framework_model_dir, pipeline_info.name(), subfolder)
@@ -77,9 +77,8 @@ def get_cached_model_name(self, model_name):

    def get_onnx_path(self, model_name, onnx_dir, opt=True):
        engine_name = self.engine_type.name.lower()
        onnx_model_dir = os.path.join(
            onnx_dir, self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "")
        )
        directory_name = self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "")
        onnx_model_dir = os.path.join(onnx_dir, directory_name)
        os.makedirs(onnx_model_dir, exist_ok=True)
        return os.path.join(onnx_model_dir, "model.onnx")
