Commit

add missing file
tianleiwu committed Dec 19, 2023
1 parent f347e53 commit be42152
Showing 2 changed files with 119 additions and 12 deletions.
@@ -1336,18 +1336,18 @@ def main():
         os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1"

         result = run_optimum_ort(
-            sd_model,
-            args.pipeline,
-            provider,
-            args.batch_size,
-            not args.enable_safety_checker,
-            args.height,
-            args.width,
-            args.steps,
-            args.num_prompts,
-            args.batch_count,
-            start_memory,
-            memory_monitor_type,
+            model_name=sd_model,
+            directory=args.pipeline,
+            provider=provider,
+            batch_size=args.batch_size,
+            disable_safety_checker=not args.enable_safety_checker,
+            height=args.height,
+            width=args.width,
+            steps=args.steps,
+            num_prompts=args.num_prompts,
+            batch_count=args.batch_count,
+            start_memory=start_memory,
+            memory_monitor_type=memory_monitor_type,
         )
     elif args.engine == "onnxruntime":
         assert args.pipeline and os.path.isdir(
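The hunk above switches run_optimum_ort from positional to keyword arguments. With twelve parameters of similar types, a positional call can silently bind a value to the wrong parameter; keywords make every binding explicit and order-independent. A minimal sketch with a hypothetical function (not from this repository):

def benchmark(model_name: str, height: int = 512, width: int = 512, steps: int = 30):
    print(f"{model_name}: {height}x{width}, {steps} steps")

# Positional: steps is silently bound to height; no error is raised.
benchmark("sd-turbo", 30, 512, 512)

# Keyword: the same values cannot be mis-bound.
benchmark("sd-turbo", steps=30, height=512, width=512)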
@@ -0,0 +1,107 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging

from diffusion_models import PipelineInfo
from engine_builder import EngineBuilder, EngineType

logger = logging.getLogger(__name__)


class TorchEngineBuilder(EngineBuilder):
def __init__(
self,
pipeline_info: PipelineInfo,
max_batch_size=16,
device="cuda",
use_cuda_graph=False,
):
"""
Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder.
Args:
pipeline_info (PipelineInfo):
Version and Type of pipeline.
max_batch_size (int):
Maximum batch size for dynamic batch engine.
device (str):
device to run.
use_cuda_graph (bool):
Use CUDA graph to capture engine execution and then launch inference
"""
super().__init__(
EngineType.TORCH,
pipeline_info,
max_batch_size=max_batch_size,
device=device,
use_cuda_graph=use_cuda_graph,
)

        # torch.compile settings per model. "reduce-overhead" enables CUDA graph
        # capture to reduce kernel launch overhead; fullgraph=True requires the
        # model to be captured without graph breaks; dynamic=False specializes
        # the compiled code on the input shapes seen at the first call.
        self.compile_config = {
            "clip": {"mode": "reduce-overhead", "dynamic": False},
            "clip2": {"mode": "reduce-overhead", "dynamic": False},
            "unet": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
            "unetxl": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
            "vae": {"mode": "reduce-overhead", "fullgraph": False, "dynamic": False},
        }

def build_engines(
self,
framework_model_dir: str,
):
import torch

self.torch_device = torch.device("cuda", torch.cuda.current_device())
self.load_models(framework_model_dir)

pipe = self.load_pipeline_with_lora() if self.pipeline_info.lora_weights else None
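        # load_pipeline_with_lora returns a pipeline with the configured LoRA
        # weights applied; individual models are then pulled from it below.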

built_engines = {}
        for model_name, model_obj in self.models.items():
            model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
            if model_name == "vae" and self.pipeline_info.is_xl() and not self.custom_fp16_vae:
                # The stock SDXL VAE is numerically unstable in float16, so keep
                # it in float32 unless a float16-safe custom VAE is configured.
                model = model.to(device=self.torch_device, dtype=torch.float32)
            else:
                model = model.to(device=self.torch_device, dtype=torch.float16)

if model_name in self.compile_config:
compile_config = self.compile_config[model_name]
if model_name in ["unet", "unetxl"]:
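                    # Convolution-heavy UNet models run faster on CUDA with the
                    # channels_last (NHWC) memory layout.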
model.to(memory_format=torch.channels_last)
engine = torch.compile(model, **compile_config)
built_engines[model_name] = engine
else:
built_engines[model_name] = model

self.engines = built_engines

def run_engine(self, model_name, feed_dict):
if model_name in ["unet", "unetxl"]:
if "controlnet_images" in feed_dict:
return {"latent": self.engines[model_name](**feed_dict)}

if model_name == "unetxl":
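                # SDXL additionally conditions the UNet on pooled text embeddings
                # and size/crop time ids, passed through added_cond_kwargs.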
added_cond_kwargs = {k: feed_dict[k] for k in feed_dict if k in ["text_embeds", "time_ids"]}
return {
"latent": self.engines[model_name](
feed_dict["sample"],
feed_dict["timestep"],
feed_dict["encoder_hidden_states"],
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
}

return {
"latent": self.engines[model_name](
feed_dict["sample"], feed_dict["timestep"], feed_dict["encoder_hidden_states"], return_dict=False
)[0]
}

        if model_name == "vae_encoder":
            return {"latent": self.engines[model_name](feed_dict["images"])}

        raise RuntimeError(f"Unexpected model name in run_engine: {model_name}")
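
A note on the compile settings used by this builder: mode="reduce-overhead" lets torch.compile use CUDA graphs to cut per-kernel launch overhead, fullgraph=True makes compilation fail fast if the model cannot be captured as a single graph, and dynamic=False specializes the compiled code on the observed input shapes, matching the fixed-shape benchmark setup. A minimal self-contained sketch with a toy module (not a model from this commit):

import torch


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.nn.functional.silu(self.conv(x))


if torch.cuda.is_available():
    model = TinyBlock().eval().to(device="cuda", dtype=torch.float16)
    model = model.to(memory_format=torch.channels_last)
    engine = torch.compile(model, mode="reduce-overhead", fullgraph=True, dynamic=False)

    x = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
    x = x.to(memory_format=torch.channels_last)
    with torch.no_grad():
        for _ in range(3):  # first calls compile the model and capture the CUDA graph
            latent = engine(x)
    print(latent.shape)  # torch.Size([1, 4, 64, 64])

The first few calls are slow while compilation and graph capture happen; steady-state iterations replay the captured graph.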
