From 1de69fe4d56cfb0c1dbf5a14944c60079ba09d23 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 10 Aug 2024 15:29:36 -0400 Subject: [PATCH] Fix some issues with inference slowing down. --- comfy/model_management.py | 39 ++++++++++++++++++++++++--------------- comfy/model_patcher.py | 2 +- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 5da213f2c6d..a0105131d96 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -296,7 +296,7 @@ def model_offloaded_memory(self): def model_memory_required(self, device): if device == self.model.current_loaded_device(): - return 0 + return self.model_offloaded_memory() else: return self.model_memory() @@ -308,15 +308,21 @@ def model_load(self, lowvram_model_memory=0, force_patch_weights=False): load_weights = not self.weights_loaded - try: - if lowvram_model_memory > 0 and load_weights: - self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights) - else: - self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights) - except Exception as e: - self.model.unpatch_model(self.model.offload_device) - self.model_unload() - raise e + if self.model.loaded_size() > 0: + use_more_vram = lowvram_model_memory + if use_more_vram == 0: + use_more_vram = 1e32 + self.model_use_more_vram(use_more_vram) + else: + try: + if lowvram_model_memory > 0 and load_weights: + self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights) + else: + self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights) + except Exception as e: + self.model.unpatch_model(self.model.offload_device) + self.model_unload() + raise e if is_intel_xpu() and not args.disable_ipex_optimize: self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True) @@ -484,18 +490,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu total_memory_required = {} for loaded_model in models_to_load: - if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different - total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different + total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) - for device in total_memory_required: - if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded) + for loaded_model in models_already_loaded: + total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) for loaded_model in models_to_load: weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded if weights_unloaded is not None: loaded_model.weights_loaded = not weights_unloaded + for device in total_memory_required: + if device != torch.device("cpu"): + free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded) + for loaded_model in models_to_load: model = loaded_model.model torch_dev = model.load_device diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 6c67193ebcf..ae3d2051454 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -102,7 +102,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up self.size = size self.model = model if not hasattr(self.model, 'device'): - logging.info("Model doesn't have a device attribute.") + logging.debug("Model doesn't have a device attribute.") self.model.device = offload_device elif self.model.device is None: self.model.device = offload_device