diff --git a/multigen/loader.py b/multigen/loader.py
index cbe9ed8..8ced8b0 100644
--- a/multigen/loader.py
+++ b/multigen/loader.py
@@ -1,4 +1,5 @@
-from typing import Type, List
+from typing import Type, List, Union, Optional, Any
+from dataclasses import dataclass
 import random
 import copy as cp
 from contextlib import nullcontext
@@ -10,44 +11,32 @@
 import diffusers
 from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline
-from diffusers.utils import is_accelerate_available
-if is_accelerate_available():
-    from accelerate import init_empty_weights
-else:
-    init_empty_weights = nullcontext
+
+from .util import get_model_size, awailable_ram, quantize, weightshare_copy
 
 logger = logging.getLogger(__file__)
 
 
-def weightshare_copy(pipe):
+@dataclass(frozen=True)
+class ModelDescriptor:
     """
-    Create a new pipe object then assign weights using load_state_dict from passed 'pipe'
+    Descriptor class for model identification that includes quantization information
    """
-    copy = pipe.__class__(**pipe.components)
-    ctx = init_empty_weights if is_accelerate_available() else nullcontext
-    with ctx():
-        for key, component in copy.components.items():
-            if getattr(copy, key) is None:
-                continue
-            if key in ('tokenizer', 'tokenizer_2', 'feature_extractor'):
-                setattr(copy, key, cp.deepcopy(getattr(copy, key)))
-                continue
-            cls = getattr(copy, key).__class__
-            if hasattr(cls, 'from_config'):
-                setattr(copy, key, cls.from_config(getattr(copy, key).config))
-            else:
-                setattr(copy, key, cls(getattr(copy, key).config))
-    # assign=True is needed since our copy is on "meta" device, i.g. weights are empty
-    for key, component in copy.components.items():
-        if key == 'tokenizer' or key == 'tokenizer_2':
-            continue
-        obj = getattr(copy, key)
-        if hasattr(obj, 'load_state_dict'):
-            obj.load_state_dict(getattr(pipe, key).state_dict(), assign=True)
-    # some buffers might not be transfered from pipe to copy
-    copy.to(pipe.device)
-    return copy
+    model_id: str
+    quantize_dtype: Optional[Any] = None
+
+    def __hash__(self):
+        return hash((self.model_id, str(self.quantize_dtype)))
+
+    def __eq__(self, other):
+        if isinstance(other, str):
+            return self.model_id == other
+
+        if not isinstance(other, ModelDescriptor):
+            return False
+        return (self.model_id == other.model_id and
+                self.quantize_dtype == other.quantize_dtype)
 
 
 class Loader:
@@ -56,9 +45,8 @@ class for loading diffusion pipelines from files.
     """
     def __init__(self):
         self._lock = threading.RLock()
-        self._cpu_pipes = dict()
-        # idx -> list of (model_id, pipe) pairs
-        self._gpu_pipes = dict()
+        self._cpu_pipes = dict()  # ModelDescriptor -> pipe
+        self._gpu_pipes = dict()  # gpu idx -> list of (ModelDescriptor, pipe) pairs
 
     def get_gpu(self, model_id) -> List[int]:
         """
@@ -73,24 +61,29 @@ def get_gpu(self, model_id) -> List[int]:
         return result
 
     def load_pipeline(self, cls: Type[DiffusionPipeline], path, torch_dtype=torch.bfloat16,
-                      device=None, offload_device=None, **additional_args):
+                      device=None, offload_device=None, quantize_dtype=None, **additional_args):
         with self._lock:
             logger.debug(f'looking for pipeline {cls} from {path} on {device}')
             result = None
+            descriptor = ModelDescriptor(path, quantize_dtype)
+            found_quantized = False
             if device is None:
                 device = torch.device('cpu', 0)
             if device.type == 'cuda':
                 idx = device.index
                 gpu_pipes = self._gpu_pipes.get(idx, [])
                 for (key, value) in gpu_pipes:
-                    if key == path:
+                    if key == descriptor:
                         logger.debug(f'found pipe in gpu cache {key}')
                         result = self.from_pipe(cls, value, additional_args)
                         logger.debug(f'created pipe from gpu cache {key} on {device}')
                         return result
             for (key, pipe) in self._cpu_pipes.items():
-                if key == path:
+                if key == descriptor:
+                    found_quantized = True
                     logger.debug(f'found pipe in cpu cache {key} {pipe.device}')
+                    if device.type == 'cuda':
+                        pipe = cp.deepcopy(pipe)
                     result = self.from_pipe(cls, pipe, additional_args)
                     break
             if result is None:
@@ -106,16 +99,25 @@ def load_pipeline(self, cls: Type[DiffusionPipeline], path, torch_dtype=torch.bf
             logger.debug("prepare pipe before returning from loader")
             logger.debug(f"{path} on {result.device} {result.dtype}")
+            # Add quantization if specified
+            if (not found_quantized) and quantize_dtype is not None:
+                logger.debug(f'Quantizing pipeline to {quantize_dtype}')
+                quantize(result, dtype=quantize_dtype)
+
             if result.device != device:
                 result = result.to(dtype=torch_dtype, device=device)
             if result.dtype != torch_dtype:
                 result = result.to(dtype=torch_dtype)
+            self.cache_pipeline(result, path)
+            logger.debug(f'result device before weightshare_copy {result.device}')
             result = weightshare_copy(result)
+            logger.debug(f'result device after weightshare_copy {result.device}')
             assert result.device.type == device.type
             if device.type == 'cuda':
                 assert result.device.index == device.index
-            logger.debug(f'returning {type(result)} from {path} on {result.device}')
+            logger.debug(f'returning {type(result)} from {path} \
+                         on {result.device} scheduler {id(result.scheduler)}')
             return result
 
     def from_pipe(self, cls, pipe, additional_args):
@@ -131,26 +133,29 @@ def from_pipe(self, cls, pipe, additional_args):
             components.pop('controlnet')
         return cls(**components, **additional_args)
 
-    def cache_pipeline(self, pipe: DiffusionPipeline, model_id):
+    def cache_pipeline(self, pipe: DiffusionPipeline, descriptor: ModelDescriptor):
+        logger.debug(f'caching pipeline {descriptor} {pipe}')
         with self._lock:
             device = pipe.device
-            if model_id not in self._cpu_pipes:
+            if descriptor not in self._cpu_pipes:
                 # deepcopy is needed since Module.to is an inplace operation
                 size = get_model_size(pipe)
                 ram = awailable_ram()
-                logger.debug(f'{model_id} has size {size}, ram {ram}')
+                logger.debug(f'{descriptor} has size {size}, ram {ram}')
                 if ram < size * 2.5 and self._cpu_pipes:
                     key_to_delete = random.choice(list(self._cpu_pipes.keys()))
                     self._cpu_pipes.pop(key_to_delete)
                 item = pipe
                 if pipe.device.type == 'cuda':
-                    item = cp.deepcopy(pipe).to('cpu')
-                self._cpu_pipes[model_id] = item
-                logger.debug(f'storing {model_id} on cpu')
+                    device = pipe.device
+                    item = cp.deepcopy(pipe.to('cpu'))
+                    pipe.to(device)
+                self._cpu_pipes[descriptor] = item
+                logger.debug(f'storing {descriptor} on cpu')
             assert pipe.device == device
             if pipe.device.type == 'cuda':
-                self._store_gpu_pipe(pipe, model_id)
-                logger.debug(f'storing {model_id} on {pipe.device}')
+                self._store_gpu_pipe(pipe, descriptor)
+                logger.debug(f'storing {descriptor} on {pipe.device}')
 
     def clear_cache(self, device):
         logger.debug(f'clear_cache pipelines from {device}')
@@ -158,16 +163,16 @@
             if device.type == 'cuda':
                 self._gpu_pipes[device.index] = []
 
-    def _store_gpu_pipe(self, pipe, model_id):
+    def _store_gpu_pipe(self, pipe, descriptor: ModelDescriptor):
         idx = pipe.device.index
         assert idx is not None
         # for now just clear all other pipelines
-        self._gpu_pipes[idx] = [(model_id, pipe)]
+        self._gpu_pipes[idx] = [(descriptor, pipe)]
 
     def remove_pipeline(self, model_id):
         self._cpu_pipes.pop(model_id)
 
-    def get_pipeline(self, model_id, device=None):
+    def get_pipeline(self, descriptor: Union[ModelDescriptor, str], device=None):
         with self._lock:
             if device is None:
                 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu', 0)
@@ -175,42 +180,10 @@ def get_pipeline(self, model_id, device=None):
                 idx = device.index
                 gpu_pipes = self._gpu_pipes.get(idx, ())
                 for (key, value) in gpu_pipes:
-                    if key == model_id:
+                    if key == descriptor:
                         return value
             for (key, pipe) in self._cpu_pipes.items():
-                if key == model_id:
+                if key == descriptor:
                     return pipe
             return None
-
-
-def count_params(model):
-    total_size = sum(param.numel() for param in model.parameters())
-    mul = 2
-    if model.dtype in (torch.float16, torch.bfloat16):
-        mul = 2
-    elif model.dtype == torch.float32:
-        mul = 4
-    return total_size * mul
-
-
-def get_size(obj):
-    return sys.getsizeof(obj)
-
-
-def get_model_size(pipeline):
-    total_size = 0
-    for name, component in pipeline.components.items():
-        if isinstance(component, torch.nn.Module):
-            total_size += count_params(component)
-        elif hasattr(component, 'tokenizer'):
-            total_size += count_params(component.tokenizer)
-        else:
-            total_size += get_size(component)
-    return total_size / (1024 * 1024)
-
-
-def awailable_ram():
-    mem = psutil.virtual_memory()
-    available_ram = mem.available
-    return available_ram / (1024 * 1024)
 
diff --git a/multigen/util.py b/multigen/util.py
index d03f611..2d3de57 100644
--- a/multigen/util.py
+++ b/multigen/util.py
@@ -1,5 +1,20 @@
+import sys
+import torch
+import psutil
 from PIL import Image
+import copy as cp
+from contextlib import nullcontext  # needed for the init_empty_weights fallback below
+import optimum.quanto
+from optimum.quanto import freeze, qfloat8, quantize as _quantize
+from diffusers.utils import is_accelerate_available
+import logging
+
+
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+else:
+    init_empty_weights = nullcontext
 
 
 def create_exif_metadata(im: Image, custom_metadata):
     exif = im.getexif()
@@ -47,3 +62,83 @@
     return padded_image
 
+
+def count_params(model):
+    total_size = sum(param.numel() for param in model.parameters())
+    mul = 2
+    if model.dtype in (torch.float16, torch.bfloat16):
+        mul = 2
+    elif model.dtype == torch.float32:
+        mul = 4
+    return total_size * mul
+
+
+def get_size(obj):
+    return sys.getsizeof(obj)
+
+
+def get_model_size(pipeline):
+    total_size = 0
+    for name, component in pipeline.components.items():
+        if isinstance(component, torch.nn.Module):
+            total_size += count_params(component)
+        elif hasattr(component, 'tokenizer'):
+            total_size += count_params(component.tokenizer)
+        else:
+            total_size += get_size(component)
+    return total_size / (1024 * 1024)
+
+
+def awailable_ram():
+    mem = psutil.virtual_memory()
+    available_ram = mem.available
+    return available_ram / (1024 * 1024)
+
+
+def quantize(pipe, dtype=qfloat8):
+    components = ['unet', 'transformer', 'text_encoder', 'text_encoder_2', 'vae']
+
+    for component in components:
+        if hasattr(pipe, component):
+            component_obj = getattr(pipe, component)
+            _quantize(component_obj, weights=dtype)
+            freeze(component_obj)
+            # Add attributes to indicate quantization
+            component_obj._is_quantized = True
+            component_obj._quantization_dtype = dtype
+
+
+def weightshare_copy(pipe):
+    """
+    Create a new pipe object then assign weights using load_state_dict from passed 'pipe'
+    """
+    copy = pipe.__class__(**pipe.components)
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        for key, component in copy.components.items():
+            if getattr(copy, key) is None:
+                continue
+            if key in ('tokenizer', 'tokenizer_2', 'feature_extractor'):
+                setattr(copy, key, cp.deepcopy(getattr(copy, key)))
+                continue
+            cls = getattr(copy, key).__class__
+            if hasattr(cls, 'from_config'):
+                setattr(copy, key, cls.from_config(getattr(copy, key).config))
+            else:
+                setattr(copy, key, cls(getattr(copy, key).config))
+
+            pipe_component = getattr(pipe, key)
+            if getattr(pipe_component, '_is_quantized', False):
+                # Quantize the component in the copy using the same dtype
+                component_obj = getattr(copy, key)
+                _quantize(component_obj, weights=pipe_component._quantization_dtype)
+    # assign=True is needed since our copy is on "meta" device, i.e. weights are empty
+    for key, component in copy.components.items():
+        if key == 'tokenizer' or key == 'tokenizer_2':
+            continue
+        obj = getattr(copy, key)
+        if hasattr(obj, 'load_state_dict'):
+            obj.load_state_dict(getattr(pipe, key).state_dict(), assign=True)
+    # some buffers might not be transferred from pipe to copy
+    copy.to(pipe.device)
+    return copy
diff --git a/multigen/worker.py b/multigen/worker.py
index 867e7ab..413fa62 100755
--- a/multigen/worker.py
+++ b/multigen/worker.py
@@ -4,6 +4,7 @@
 import concurrent
 from queue import Empty
 import PIL
+from optimum.quanto import qfloat8
 
 from .worker_base import ServiceThreadBase
 from .prompting import Cfgen
@@ -29,19 +30,20 @@ def _get_device(self, model_id):
         # choose random from resting gpus
         # if there is no resting gpus choose
         # one with our model_id otherwise choose random
-        devices = list(range(torch.cuda.device_count()))
-        self.logger.debug('awailable devices %s', devices)
-        if not devices:
-            self.logger.debug('returning cpu device')
-            return torch.device('cpu')
         with self._lock:
-            self.logger.debug('locked gpu %s', self._gpu_jobs)
+            devices = list(range(torch.cuda.device_count()))
+            self.logger.debug('cuda devices %s', devices)
+            if not devices:
+                self.logger.debug('returning cpu device')
+                return torch.device('cpu')
+            self.logger.debug('gpu jobs %s', self._gpu_jobs)
             free_gpus = [x for x in devices if x not in self._gpu_jobs]
             self.logger.debug('free gpus %s', free_gpus)
             if free_gpus:
                 idx = random.choice(free_gpus)
             else:
                 self.logger.debug('no free gpus')
+                raise RuntimeError("no free gpus")
             gpus_with_model = self._loader.get_gpu(model_id)
             if gpus_with_model:
                 idx = random.choice(gpus_with_model)
@@ -51,7 +53,7 @@ def _get_device(self, model_id):
             self.logger.debug(f'locked device cuda:{idx} for {model_id}')
         return torch.device('cuda', idx)
 
-    def _get_pipeline(self, pipe_class, model_id, model_type, cnet=None):
+    def _get_pipeline(self, pipe_class, model_id, model_type, cnet=None, quantize_dtype=None):
         device = self._get_device(model_id)
         offload_device = None
         if cnet is None:
@@ -59,15 +61,16 @@ def _get_pipeline(self, pipe_class, model_id, model_type, cnet=None):
             if model_type == ModelType.SDXL:
                 cls = pipe_class._classxl
             elif model_type == ModelType.FLUX:
-                # use offload by default for now
+                # use quantisation by default for now
                 cls = pipe_class._classflux
                 if device.type == 'cuda':
-                    offload_device = device.index
-                    device = torch.device('cpu')
+                    quantize_dtype = qfloat8
+                    # offload_device = device.index
+                    # device = torch.device('cpu')
             else:
                 cls = pipe_class._class
             pipeline = self._loader.load_pipeline(cls, model_id, torch_dtype=torch.bfloat16,
-                                                  device=device)
+                                                  device=device, quantize_dtype=quantize_dtype)
             self.logger.debug(f'requested {cls} {model_id} on device {device}, got {pipeline.device}')
             pipe = pipe_class(model_id, pipe=pipeline, device=device, offload_device=offload_device)
             if offload_device is None:
@@ -77,9 +80,6 @@ def _get_pipeline(self, pipe_class, model_id, model_type, cnet=None):
             pipeline = self._loader.get_pipeline(model_id, device=device, model_type=model_type)
             if model_type == FLUX:
                 cnet_type = ModelType.FLUX
-                if device.type == 'cuda':
-                    offload_device = device.index
-                    device = torch.device('cpu')
             if pipeline is None or 'controlnet' not in pipeline.components:
                 # reload
                 pipe = pipe_class(model_id, ctypes=[cnet], model_type=model_type, device=device, offload_device=offload_device)
@@ -89,7 +89,6 @@ def _get_pipeline(self, pipe_class, model_id, model_type, cnet=None):
         return pipe
 
     def run(self):
-        self.logger.debug('running thread')
         num_of_workers = torch.cuda.device_count()
         if num_of_workers == 0:
             num_of_workers = 1
@@ -103,6 +102,7 @@
             time.sleep(0.2)
 
     def worker(self, data):
+        self.logger.debug('running worker')
         def _update(sess, job, gs):
             sess["images"].append(gs.last_img_name)
             if 'image_callback' in data:
@@ -111,69 +111,71 @@ def _update(sess, job, gs):
         device = None
         # keep the job in the queue until complete
         try:
-            sess = data.get('session', None)
-            session_id = data["session_id"]
-            if sess is None:
-                sess = self.sessions[session_id]
-            sess['status'] ='running'
-            self.logger.info("GENERATING: " + str(data))
-            if 'start_callback' in data:
-                data['start_callback']()
+            with torch.no_grad():
+                sess = data.get('session', None)
+                session_id = data["session_id"]
+                if sess is None:
+                    sess = self.sessions[session_id]
+                sess['status'] ='running'
+                self.logger.info("GENERATING: " + str(data))
+                if 'start_callback' in data:
+                    data['start_callback']()
 
-            pipe_name = sess.get('pipe', 'Prompt2ImPipe')
-            model_id = str(self.cwd/self.config["model_dir"]/self.models['base'][sess["model"]]['id'])
-            loras = [str(self.cwd/self.config["model_dir"]/'Lora'/self.models['lora'][x]['id']) for x in sess.get("loras", [])]
-            data['loras'] = loras
-            mt = self.models['base'][sess["model"]]['type']
-            if mt == SDXL:
-                model_type = ModelType.SDXL
-            elif mt == SD:
-                model_type = ModelType.SD
-            elif mt == FLUX:
-                model_type = ModelType.FLUX
-            else:
-                raise RuntimeError(f"unexpected model type {mt}")
-            pipe = self.get_pipeline(pipe_name, model_id, model_type, cnet=data.get('cnet', None))
-            device = pipe.pipe.device
-            offload_device = None
-            if hasattr(pipe, 'offload_gpu_id'):
-                offload_device = pipe.offload_gpu_id
-            self.logger.debug(f'running job on {device} offload {offload_device}')
-            if device.type in ['cuda', 'meta']:
-                with self._lock:
-                    if device.type == 'meta':
-                        self._gpu_jobs[offload_device] = model_id
-                    else:
-                        self._gpu_jobs[device.index] = model_id
-            class_name = str(pipe.__class__)
-            self.logger.debug(f'got pipeline {class_name}')
+                pipe_name = sess.get('pipe', 'Prompt2ImPipe')
+                model_id = str(self.cwd/self.config["model_dir"]/self.models['base'][sess["model"]]['id'])
+                loras = [str(self.cwd/self.config["model_dir"]/'Lora'/self.models['lora'][x]['id']) for x in sess.get("loras", [])]
+                data['loras'] = loras
+                mt = self.models['base'][sess["model"]]['type']
+                if mt == SDXL:
+                    model_type = ModelType.SDXL
+                elif mt == SD:
+                    model_type = ModelType.SD
+                elif mt == FLUX:
+                    model_type = ModelType.FLUX
+                else:
+                    raise RuntimeError(f"unexpected model type {mt}")
+                pipe = self.get_pipeline(pipe_name, model_id, model_type, cnet=data.get('cnet', None))
+                device = pipe.pipe.device
+                offload_device = None
+                if hasattr(pipe, 'offload_gpu_id'):
+                    offload_device = pipe.offload_gpu_id
+                self.logger.debug(f'running job on {device} offload {offload_device}')
+                if device.type in ['cuda', 'meta']:
+                    with self._lock:
+                        if device.type == 'meta':
+                            self._gpu_jobs[offload_device] = model_id
+                        else:
+                            self._gpu_jobs[device.index] = model_id
+                class_name = str(pipe.__class__)
+                self.logger.debug(f'got pipeline {class_name}')
 
-            images = data.get('images', None)
-            if images and 'MaskedIm2ImPipe' in class_name:
-                pipe.setup(**data, original_image=str(images[0]),
-                           image_painted=str(images[1]))
-            elif images and any([x in class_name for x in ('Im2ImPipe', 'Cond2ImPipe')]):
-                if isinstance(images[0], PIL.Image.Image):
-                    pipe.setup(**data, fimage=None, image=images[0])
+                images = data.get('images', None)
+                if images and 'MaskedIm2ImPipe' in class_name:
+                    pipe.setup(**data, original_image=str(images[0]),
+                               image_painted=str(images[1]))
+                elif images and any([x in class_name for x in ('Im2ImPipe', 'Cond2ImPipe')]):
+                    if isinstance(images[0], PIL.Image.Image):
+                        pipe.setup(**data, fimage=None, image=images[0])
+                    else:
+                        pipe.setup(**data, fimage=str(images[0]))
                 else:
-                    pipe.setup(**data, fimage=str(images[0]))
-            else:
-                pipe.setup(**data)
-            # TODO: add negative prompt to parameters
-            nprompt_default = "jpeg artifacts, blur, distortion, watermark, signature, extra fingers, fewer fingers, lowres, nude, bad hands, duplicate heads, bad anatomy, bad crop"
-            nprompt = data.get('nprompt', nprompt_default)
-            seeds = data.get('seeds', None)
-            self.logger.debug(f"offload_device {pipe.offload_gpu_id}")
-            directory = data.get('gen_dir', None)
-            if directory is None:
-                directory = self.get_image_pathname(data["session_id"], None)
-            gs = GenSession(directory,
-                            pipe, Cfgen(data["prompt"], nprompt, seeds=seeds))
-            gs.gen_sess(add_count = data["count"],
-                        callback = lambda: _update(sess, data, gs))
-            if 'finish_callback' in data:
-                data['finish_callback']()
-        except (RuntimeError, TypeError, NotImplementedError, OSError) as e:
+                    pipe.setup(**data)
+                # TODO: add negative prompt to parameters
+                nprompt_default = "jpeg artifacts, blur, distortion, watermark, signature, extra fingers, fewer fingers, lowres, nude, bad hands, duplicate heads, bad anatomy, bad crop"
+                nprompt = data.get('nprompt', nprompt_default)
+                seeds = data.get('seeds', None)
+                self.logger.debug(f"device {device} offload_device {pipe.offload_gpu_id}")
+                directory = data.get('gen_dir', None)
+                if directory is None:
+                    directory = self.get_image_pathname(data["session_id"], None)
+                gs = GenSession(directory,
+                                pipe, Cfgen(data["prompt"], nprompt, seeds=seeds))
+                gs.gen_sess(add_count = data["count"],
+                            callback = lambda: _update(sess, data, gs))
+                self.logger.info(f"running finish callback")
+                if 'finish_callback' in data:
+                    data['finish_callback']()
+        except (RuntimeError, TypeError, NotImplementedError, OSError, IndexError) as e:
             self.logger.error("error in generation", exc_info=e)
             if hasattr(pipe.pipe, '_offload_gpu_id'):
                 self.logger.error(f"offload_device {pipe.pipe._offload_gpu_id}")
@@ -185,10 +187,12 @@ def _update(sess, job, gs):
         finally:
             with self._lock:
                 index = None
+                self.logger.info(f"finished job, unlocking device {device}")
                 if device is not None and device.type == 'cuda':
                     index = device.index
-                if pipe.pipe._offload_gpu_id is not None:
+                self.logger.debug(f'device index {index}')
+                if hasattr(pipe.pipe, '_offload_gpu_id') and pipe.pipe._offload_gpu_id is not None:
                     index = pipe.pipe._offload_gpu_id
                 if index is not None:
-                    self.logger.debug('unlock device %s', index)
+                    self.logger.debug(f'unlock device {index}')
                     del self._gpu_jobs[index]
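Minimal usage sketch (not part of the patch): it only assumes the Loader.load_pipeline signature and the ModelDescriptor caching shown in this diff; the model path below is hypothetical.

# Illustrative sketch only: exercises the new quantize_dtype path from this patch.
import torch
from diffusers import DiffusionPipeline
from optimum.quanto import qfloat8

from multigen.loader import Loader

loader = Loader()
# The first call quantizes the freshly loaded pipeline and caches it; a later
# call with the same path and quantize_dtype is intended to reuse the cached
# copy (via weightshare_copy) instead of quantizing again.
pipe = loader.load_pipeline(DiffusionPipeline, '/models/flux.1-dev',  # hypothetical path
                            torch_dtype=torch.bfloat16,
                            device=torch.device('cuda', 0),
                            quantize_dtype=qfloat8)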