From 387f48c003661cd18978bb5cbc859dd5dd4f30a3 Mon Sep 17 00:00:00 2001
From: Rypo
Date: Mon, 25 Nov 2024 19:39:21 -0600
Subject: [PATCH 1/2] feat: fast model loading with accelerate

Prevents slow CPU initialization of the model weights at load time by
constructing the model inside accelerate's `init_empty_weights` context.
This is fully compatible with `from_pretrained`, since the empty weights
are always overwritten by the checkpoint state_dict.

fixes #72
---
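Notes: a minimal, runnable sketch of the loading pattern adopted below.
It assumes PyTorch >= 2.1 (for `load_state_dict(..., assign=True)`), and
`MyModel` is a stand-in for the real OmniGen transformer:

    import torch
    import torch.nn as nn
    from accelerate import init_empty_weights
    from safetensors.torch import load_file, save_file

    class MyModel(nn.Module):  # stand-in for OmniGen(Phi3Config)
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

    save_file(MyModel().state_dict(), 'model.safetensors')  # pretend checkpoint

    with init_empty_weights():
        model = MyModel()      # params land on the meta device; no CPU allocation

    ckpt = load_file('model.safetensors', 'cpu')  # real tensors from the checkpoint
    model.load_state_dict(ckpt, assign=True)      # swap meta params for ckpt tensors
    model = model.to('cuda', torch.bfloat16)      # one host-to-device transfer

Without `init_empty_weights()`, constructing the model would first run
every module's random weight initialization on the CPU, only for those
tensors to be thrown away when the checkpoint is loaded.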
 OmniGen/model.py    | 33 ++++++++++++++++++++++++++++++-------
 OmniGen/pipeline.py | 44 +++++++++++++++++++++++++++++-----------------
 2 files changed, 53 insertions(+), 24 deletions(-)

diff --git a/OmniGen/model.py b/OmniGen/model.py
index 8999a8e..7504e54 100644
--- a/OmniGen/model.py
+++ b/OmniGen/model.py
@@ -1,5 +1,6 @@
 # The code is revised from DiT
 import os
+import gc
 import torch
 import torch.nn as nn
 import numpy as np
@@ -10,6 +11,7 @@
 from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
 from huggingface_hub import snapshot_download
 from safetensors.torch import load_file
+from accelerate import init_empty_weights

 from OmniGen.transformer import Phi3Config, Phi3Transformer

@@ -187,20 +189,37 @@ def __init__(
         self.llm.config.use_cache = False

     @classmethod
-    def from_pretrained(cls, model_name):
+    def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, device: str|torch.device='cuda', low_cpu_mem_usage: bool = True,):
         if not os.path.exists(model_name):
             cache_folder = os.getenv('HF_HUB_CACHE')
             model_name = snapshot_download(repo_id=model_name,
                                            cache_dir=cache_folder,
                                            ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
-        config = Phi3Config.from_pretrained(model_name)
-        model = cls(config)
-        if os.path.exists(os.path.join(model_name, 'model.safetensors')):
+
+        model_path = os.path.join(model_name, 'model.safetensors')
+        if not os.path.exists(model_path):
+            model_path = os.path.join(model_name, 'model.pt')
+            ckpt = torch.load(model_path, map_location='cpu')
+        else:
             print("Loading safetensors")
-            ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
+            ckpt = load_file(model_path, 'cpu')
+
+        if low_cpu_mem_usage:
+            with init_empty_weights():
+                config = Phi3Config.from_pretrained(model_name)
+                model = cls(config)
+
+            model.load_state_dict(ckpt, assign=True)
+            model = model.to(device, dtype)
         else:
-            ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
-        model.load_state_dict(ckpt)
+            config = Phi3Config.from_pretrained(model_name)
+            model = cls(config)
+            model.load_state_dict(ckpt)
+            model = model.to(device, dtype)
+
+        del ckpt
+        torch.cuda.empty_cache()
+        gc.collect()
         return model

     def initialize_weights(self):
diff --git a/OmniGen/pipeline.py b/OmniGen/pipeline.py
index 09b0731..9f34086 100644
--- a/OmniGen/pipeline.py
+++ b/OmniGen/pipeline.py
@@ -41,6 +41,15 @@
 ```
 """

+def best_available_device():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        logger.info("No available GPU detected; falling back to CPU. Image generation may be very slow!")
+        device = torch.device("cpu")
+    return device

 class OmniGenPipeline:
     def __init__(
@@ -55,14 +64,10 @@ def __init__(
         self.processor = processor
         self.device = device

-        if device is None:
-            if torch.cuda.is_available():
-                self.device = torch.device("cuda")
-            elif torch.backends.mps.is_available():
-                self.device = torch.device("mps")
-            else:
-                logger.info("Don't detect any available GPUs, using CPU instead, this may take long time to generate image!!!")
-                self.device = torch.device("cpu")
+        if self.device is None:
+            self.device = best_available_device()
+        elif isinstance(self.device, str):
+            self.device = torch.device(self.device)

         # self.model.to(torch.bfloat16)
         self.model.eval()
@@ -71,7 +76,7 @@ def __init__(
         self.model_cpu_offload = False

     @classmethod
-    def from_pretrained(cls, model_name, vae_path: str=None):
+    def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_mem_usage=True):
         if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"):
             logger.info("Model not found, downloading...")
             cache_folder = os.getenv('HF_HUB_CACHE')
@@ -79,18 +84,23 @@ def from_pretrained(cls, model_name, vae_path: str=None):
                                       cache_dir=cache_folder,
                                       ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'])
         logger.info(f"Downloaded model to {model_name}")
-        model = OmniGen.from_pretrained(model_name)
+
+        if device is None:
+            device = best_available_device()
+
+        model = OmniGen.from_pretrained(model_name, dtype=torch.bfloat16, device=device, low_cpu_mem_usage=low_cpu_mem_usage)
         processor = OmniGenProcessor.from_pretrained(model_name)

-        if os.path.exists(os.path.join(model_name, "vae")):
-            vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae"))
-        elif vae_path is not None:
-            vae = AutoencoderKL.from_pretrained(vae_path).to(device)
-        else:
+        if vae_path is None:
+            vae_path = os.path.join(model_name, "vae")
+
+        if not os.path.exists(vae_path):
             logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF")
-            vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
+            vae_path = "stabilityai/sdxl-vae"
+
+        vae = AutoencoderKL.from_pretrained(vae_path).to(device)

-        return cls(vae, model, processor)
+        return cls(vae, model, processor, device)

     def merge_lora(self, lora_path: str):
         model = PeftModel.from_pretrained(self.model, lora_path)

From 0287b507b2c092cd280c52d2c0b9f4fe764fdd20 Mon Sep 17 00:00:00 2001
From: Rypo
Date: Tue, 26 Nov 2024 13:31:37 -0600
Subject: [PATCH 2/2] fix: avoid moving model to device prematurely

Cast only to `dtype` in `OmniGen.from_pretrained` and leave device
placement to the pipeline, so the weights are not pushed to the GPU
before features like CPU offload can decide where they should live.
---
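Notes: with this fix, `OmniGen.from_pretrained` returns the model on the
CPU (cast to `dtype` only) and the pipeline resolves the target device
and moves the weights later. Sketch of the intended flow (illustrative,
not the exact pipeline code):

    device = best_available_device()
    model = OmniGen.from_pretrained(model_name, dtype=torch.bfloat16,
                                    low_cpu_mem_usage=True)  # weights still on CPU
    model.to(device)  # deferred until the pipeline decides placement

This keeps options like `model_cpu_offload` viable, since the weights
are never pushed to the GPU during construction.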
 OmniGen/model.py    | 6 +++---
 OmniGen/pipeline.py | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/OmniGen/model.py b/OmniGen/model.py
index 7504e54..b45f2e1 100644
--- a/OmniGen/model.py
+++ b/OmniGen/model.py
@@ -189,7 +189,7 @@ def __init__(
         self.llm.config.use_cache = False

     @classmethod
-    def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, device: str|torch.device='cuda', low_cpu_mem_usage: bool = True,):
+    def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, low_cpu_mem_usage: bool = True,):
         if not os.path.exists(model_name):
             cache_folder = os.getenv('HF_HUB_CACHE')
             model_name = snapshot_download(repo_id=model_name,
                                            cache_dir=cache_folder,
                                            ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
@@ -210,12 +210,12 @@ def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch
                 model = cls(config)

             model.load_state_dict(ckpt, assign=True)
-            model = model.to(device, dtype)
+            model = model.to(dtype)
         else:
             config = Phi3Config.from_pretrained(model_name)
             model = cls(config)
             model.load_state_dict(ckpt)
-            model = model.to(device, dtype)
+            model = model.to(dtype)

         del ckpt
         torch.cuda.empty_cache()
diff --git a/OmniGen/pipeline.py b/OmniGen/pipeline.py
index 9f34086..c2325e4 100644
--- a/OmniGen/pipeline.py
+++ b/OmniGen/pipeline.py
@@ -88,7 +88,7 @@ def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_me
         if device is None:
             device = best_available_device()

-        model = OmniGen.from_pretrained(model_name, dtype=torch.bfloat16, device=device, low_cpu_mem_usage=low_cpu_mem_usage)
+        model = OmniGen.from_pretrained(model_name, dtype=torch.bfloat16, low_cpu_mem_usage=low_cpu_mem_usage)
         processor = OmniGenProcessor.from_pretrained(model_name)

         if vae_path is None:
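Example caller-side usage once both patches are applied (hypothetical;
`Shitao/OmniGen-v1` is the repo id referenced above, and `device` may be
omitted because `best_available_device()` fills it in):

    from OmniGen import OmniGenPipeline

    # The transformer is built on the meta device and its weights are
    # materialized directly from the safetensors checkpoint.
    pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1", low_cpu_mem_usage=True)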