Use fp16 as the default vae dtype for the audio VAE.
comfyanonymous committed Jun 16, 2024
1 parent 8ddc151 commit 6425252
Showing 2 changed files with 24 additions and 16 deletions.
35 changes: 20 additions & 15 deletions comfy/model_management.py
@@ -167,7 +167,7 @@ def is_nvidia():
ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILABLE = False

-VAE_DTYPE = torch.float32
+VAE_DTYPES = [torch.float32]

try:
if is_nvidia():
@@ -176,25 +176,18 @@ def is_nvidia():
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
-VAE_DTYPE = torch.bfloat16
+VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
if is_intel_xpu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
pass

if is_intel_xpu():
-VAE_DTYPE = torch.bfloat16
+VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES

if args.cpu_vae:
-    VAE_DTYPE = torch.float32
-
-if args.fp16_vae:
-    VAE_DTYPE = torch.float16
-elif args.bf16_vae:
-    VAE_DTYPE = torch.bfloat16
-elif args.fp32_vae:
-    VAE_DTYPE = torch.float32
+    VAE_DTYPES = [torch.float32]


if ENABLE_PYTORCH_ATTENTION:
@@ -258,7 +251,6 @@ def get_torch_device_name(device):
except:
logging.warning("Could not pick default device.")

logging.info("VAE dtype: {}".format(VAE_DTYPE))

current_loaded_models = []

@@ -619,9 +611,22 @@ def vae_offload_device():
else:
return torch.device("cpu")

-def vae_dtype():
-    global VAE_DTYPE
-    return VAE_DTYPE
+def vae_dtype(device=None, allowed_dtypes=[]):
+    global VAE_DTYPES
+    if args.fp16_vae:
+        return torch.float16
+    elif args.bf16_vae:
+        return torch.bfloat16
+    elif args.fp32_vae:
+        return torch.float32
+
+    for d in allowed_dtypes:
+        if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
+            return d
+        if d in VAE_DTYPES:
+            return d
+
+    return VAE_DTYPES[0]

def get_autocast_device(dev):
if hasattr(dev, 'type'):
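For context (not part of the commit): the new vae_dtype() resolves a dtype in priority order. The explicit --fp16-vae / --bf16-vae / --fp32-vae flags win outright; otherwise the caller's allowed_dtypes are scanned, with fp16 accepted only when should_use_fp16() approves the device and any other entry accepted if it already appears in VAE_DTYPES; VAE_DTYPES[0] is the fallback. A minimal sketch of how the two VAE families resolve on an fp16-capable GPU (get_torch_device() is the module's usual device helper, not shown in this diff):

import torch
import comfy.model_management as mm

device = mm.get_torch_device()

# The audio VAE now advertises fp16, so it defaults to fp16 on capable GPUs.
audio_dtype = mm.vae_dtype(device, allowed_dtypes=[torch.float16, torch.bfloat16, torch.float32])

# Image VAEs only list bf16/fp32, so their selection is unchanged by this commit.
image_dtype = mm.vae_dtype(device, allowed_dtypes=[torch.bfloat16, torch.float32])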
5 changes: 4 additions & 1 deletion comfy/sd.py
@@ -178,6 +178,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None):
self.output_channels = 3
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+self.working_dtypes = [torch.bfloat16, torch.float32]

if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@@ -245,6 +246,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None):
self.downscale_ratio = 2048
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
+self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -265,12 +267,13 @@ def __init__(self, sd=None, device=None, config=None, dtype=None):
self.device = device
offload_device = model_management.vae_offload_device()
if dtype is None:
-dtype = model_management.vae_dtype()
+dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()

self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

def vae_encode_crop_pixels(self, pixels):
dims = pixels.shape[1:-1]
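Side note (not part of the commit): because vae_dtype() checks the CLI flags before consulting allowed_dtypes, an explicit --bf16-vae or --fp32-vae launch flag still overrides a VAE's working_dtypes, including the audio VAE's new fp16 preference. A hedged sketch, toggling the flag in code purely for illustration (normally it is set by the command line via comfy.cli_args):

import torch
import comfy.model_management as mm
from comfy.cli_args import args

args.bf16_vae = True  # illustration only: equivalent to launching with --bf16-vae

# The flag wins over the audio VAE's working_dtypes, so bfloat16 is returned.
dtype = mm.vae_dtype(mm.get_torch_device(), [torch.float16, torch.bfloat16, torch.float32])
assert dtype == torch.bfloat16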
