diff --git a/OmniGen/model.py b/OmniGen/model.py index 5389931..161fbfb 100644 --- a/OmniGen/model.py +++ b/OmniGen/model.py @@ -1,5 +1,6 @@ # The code is revised from DiT import os +import gc import torch import torch.nn as nn import numpy as np @@ -10,6 +11,7 @@ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp from huggingface_hub import snapshot_download from safetensors.torch import load_file +from accelerate import init_empty_weights from OmniGen.transformer import Phi3Config, Phi3Transformer @@ -187,20 +189,37 @@ def __init__( self.llm.config.use_cache = False @classmethod - def from_pretrained(cls, model_name): + def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, low_cpu_mem_usage: bool = True,): if not os.path.exists(model_name): cache_folder = os.getenv('HF_HUB_CACHE') model_name = snapshot_download(repo_id=model_name, cache_dir=cache_folder, ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5']) - config = Phi3Config.from_pretrained(model_name) - model = cls(config) - if os.path.exists(os.path.join(model_name, 'model.safetensors')): + + model_path = os.path.join(model_name, 'model.safetensors') + if not os.path.exists(model_path): + model_path = os.path.join(model_name, 'model.pt') + ckpt = torch.load(model_path, map_location='cpu') + else: print("Loading safetensors") - ckpt = load_file(os.path.join(model_name, 'model.safetensors')) + ckpt = load_file(model_path, 'cpu') + + if low_cpu_mem_usage: + with init_empty_weights(): + config = Phi3Config.from_pretrained(model_name) + model = cls(config) + + model.load_state_dict(ckpt, assign=True) + model = model.to(dtype) else: - ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu') - model.load_state_dict(ckpt) + config = Phi3Config.from_pretrained(model_name) + model = cls(config) + model.load_state_dict(ckpt) + model = model.to(dtype) + + del ckpt + torch.cuda.empty_cache() + gc.collect() return model def initialize_weights(self): diff --git a/OmniGen/pipeline.py b/OmniGen/pipeline.py index 572452b..bf776c9 100644 --- a/OmniGen/pipeline.py +++ b/OmniGen/pipeline.py @@ -41,6 +41,15 @@ ``` """ +def best_available_device(): + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + logger.info("Don't detect any available GPUs, using CPU instead, this may take long time to generate image!!!") + device = torch.device("cpu") + return device class OmniGenPipeline: def __init__( @@ -55,14 +64,10 @@ def __init__( self.processor = processor self.device = device - if device is None: - if torch.cuda.is_available(): - self.device = torch.device("cuda") - elif torch.backends.mps.is_available(): - self.device = torch.device("mps") - else: - logger.info("Don't detect any available GPUs, using CPU instead, this may take long time to generate image!!!") - self.device = torch.device("cpu") + if self.device is None: + self.device = best_available_device() + elif isinstance(self.device, str): + self.device = torch.device(self.device) # self.model.to(torch.bfloat16) self.model.eval() @@ -71,7 +76,7 @@ def __init__( self.model_cpu_offload = False @classmethod - def from_pretrained(cls, model_name, vae_path: str=None): + def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_mem_usage=True): if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"): # logger.info("Model not found, downloading...") print("Model not found, downloading...") @@ -79,20 +84,24 @@ def from_pretrained(cls, model_name, vae_path: str=None): model_name = snapshot_download(repo_id=model_name, cache_dir=cache_folder, ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt']) - # logger.info(f"Downloaded model to {model_name}") - print(f"Downloaded model to {model_name}") - model = OmniGen.from_pretrained(model_name) + logger.info(f"Downloaded model to {model_name}") + + if device is None: + device = best_available_device() + + model = OmniGen.from_pretrained(model_name, dtype=torch.bfloat16, low_cpu_mem_usage=low_cpu_mem_usage) processor = OmniGenProcessor.from_pretrained(model_name) - if os.path.exists(os.path.join(model_name, "vae")): - vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae")) - elif vae_path is not None: - vae = AutoencoderKL.from_pretrained(vae_path).to(device) - else: + if vae_path is None: + vae_path = os.path.join(model_name, "vae") + + if not os.path.exists(vae_path): logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF") - vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device) + vae_path = "stabilityai/sdxl-vae" + + vae = AutoencoderKL.from_pretrained(vae_path).to(device) - return cls(vae, model, processor) + return cls(vae, model, processor, device) def merge_lora(self, lora_path: str): model = PeftModel.from_pretrained(self.model, lora_path)