Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Dec 15, 2024
1 parent 699a64c commit c1d0160
Showing 1 changed file with 11 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,17 @@ def has_external_data(onnx_model_path):
return False


def is_sd_3(source_dir: Path):
return (source_dir / "text_encoder_3").exists()


def is_sdxl(source_dir: Path):
return (source_dir / "text_encoder_2").exists() and not (source_dir / "text_encoder_3").exists()


def _get_model_list(source_dir: Path):
is_xl = (source_dir / "text_encoder_2").exists()
is_sd3 = (source_dir / "text_encoder_3").exists()
is_xl = is_sdxl(source_dir)
is_sd3 = is_sd_3(source_dir)
model_list_sd3 = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae_encoder", "vae_decoder"]
model_list_sdxl = ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder"]
model_list_sd = ["text_encoder", "unet", "vae_encoder", "vae_decoder"]
Expand Down Expand Up @@ -163,8 +171,7 @@ def _optimize_sd_pipeline(

if float16:
# For SD-XL, use FP16 in VAE decoder will cause NaN and black image so we keep it in FP32.
is_xl = (source_dir / "text_encoder_2").exists()
if is_xl and name == "vae_decoder":
if is_sdxl(source_dir) and name == "vae_decoder":
logger.info("Skip converting %s to float16 to avoid NaN", name)
else:
logger.info("Convert %s to float16 ...", name)
Expand Down

0 comments on commit c1d0160

Please sign in to comment.