Skip to content

Commit

Permalink
fix refiner
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Dec 20, 2023
1 parent d0c0478 commit fe7b237
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def load_ort_cuda_pipeline(name, engine, use_control_net=False, enable_cuda_grap
)

engine_type = EngineType.ORT_CUDA if engine == "ort_cuda" else EngineType.ORT_TRT
onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths(
work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type
)

Expand All @@ -157,7 +157,6 @@ def load_ort_cuda_pipeline(name, engine, use_control_net=False, enable_cuda_grap
engine_dir=engine_dir,
framework_model_dir=framework_model_dir,
onnx_dir=onnx_dir,
force_engine_rebuild=False,
device_id=torch.cuda.current_device(),
)

Check failure

Code scanning / CodeQL

Wrong number of arguments in a call Error

Call to
method OrtTensorrtEngineBuilder.build_engines
with too few arguments; should be no fewer than 6.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ def repeat_prompt(args):

def initialize_pipeline(
version="xl-turbo",
is_refiner:bool=False,
is_inpaint:bool=False,
engine_type=EngineType.ORT_CUDA,
work_dir: str = ".",
engine_dir=None,
Expand All @@ -406,6 +408,8 @@ def initialize_pipeline(
):
pipeline_info = PipelineInfo(
version,
is_refiner=is_refiner,
is_inpaint=is_inpaint,
use_vae=use_vae,
min_image_size=min_image_size,
max_image_size=max_image_size,
Expand Down Expand Up @@ -529,6 +533,8 @@ def load_pipelines(args, batch_size=None):

params = {
"version": args.version,
"is_refiner": False,
"is_inpaint": False,
"engine_type": engine_type,
"work_dir": args.work_dir,
"engine_dir": args.engine_dir,
Expand Down Expand Up @@ -561,6 +567,7 @@ def load_pipelines(args, batch_size=None):
refiner = None
if "xl" in args.version and args.enable_refiner:
params["version"] = "xl-1.0" # Allow SDXL Turbo to use refiner.
params["is_refiner"] = True
params["scheduler"] = args.refiner_scheduler
params["do_classifier_free_guidance"] = args.refiner_guidance > 1.0
params["lcm"] = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import torch
from diffusion_models import PipelineInfo
from engine_builder import EngineBuilder, EngineType
from onnx import TensorProto
from ort_utils import CudaSession, OnnxModel
from packaging import version

import onnxruntime as ort
from onnxruntime.transformers.io_binding_helper import CudaSession
from onnxruntime.transformers.onnx_model import OnnxModel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -200,7 +200,7 @@ def import_diffusers_engine(self, diffusers_onnx_dir: str, engine_dir: str):
)

if model_name == "clip2":
model.change_graph_input_type(model.find_graph_input("input_ids"), TensorProto.INT32)
model.change_graph_input_type(model.find_graph_input("input_ids"), onnx.TensorProto.INT32)

model.save_model_to_file(onnx_opt_path, use_external_data_format=(model_name == "clip2"))
elif model_name in ["unet", "unetxl"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from cuda import cudart
from diffusion_models import PipelineInfo
from engine_builder import EngineBuilder, EngineType
from ort_utils import CudaSession
from packaging import version

import onnxruntime as ort
from onnxruntime.transformers.io_binding_helper import CudaSession

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
use_cuda_graph=use_cuda_graph,
)

self.compile_config = {}
if use_cuda_graph:
self.compile_config = {
"clip": {"mode": "reduce-overhead", "dynamic": False},
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -395,19 +395,20 @@ def tokenize(prompt, output_hidden_states):
# Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

if pooled_outputs:
pooled_output = text_embeddings

if output_hidden_states:
if do_classifier_free_guidance:
text_embeddings = torch.cat([uncond_hidden_states, hidden_states])
else:
text_embeddings = hidden_states
if output_hidden_states:
hidden_states = torch.cat([uncond_hidden_states, hidden_states])

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable Error

Local variable 'uncond_hidden_states' may be used before it is initialized.

self.stop_profile("clip")

if pooled_outputs:
return text_embeddings.to(dtype=dtype), pooled_output.to(dtype=dtype)
# For text encoder in sdxl base
return hidden_states.to(dtype=dtype), text_embeddings.to(dtype=dtype)

if output_hidden_states:
# For text encoder 2 in sdxl base or refiner
return hidden_states.to(dtype=dtype)

# For text encoder in sd 1.5
return text_embeddings.to(dtype=dtype)

def denoise_latent(
Expand Down

0 comments on commit fe7b237

Please sign in to comment.