refactor pipe weightsharing for quantized models
noskill committed Nov 29, 2024
1 parent a56ae2e commit c91dc39
Showing 3 changed files with 232 additions and 161 deletions.
141 changes: 57 additions & 84 deletions multigen/loader.py
@@ -1,4 +1,5 @@
from typing import Type, List
from typing import Type, List, Union, Optional, Any
from dataclasses import dataclass
import random
import copy as cp
from contextlib import nullcontext
@@ -10,44 +11,32 @@
import diffusers

from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline
from diffusers.utils import is_accelerate_available
if is_accelerate_available():
from accelerate import init_empty_weights
else:
init_empty_weights = nullcontext

from .util import get_model_size, awailable_ram, quantize, weightshare_copy


logger = logging.getLogger(__file__)


def weightshare_copy(pipe):
@dataclass(frozen=True)
class ModelDescriptor:
"""
    Create a new pipe object, then assign weights using load_state_dict from the passed 'pipe'.
Descriptor class for model identification that includes quantization information
"""
copy = pipe.__class__(**pipe.components)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
for key, component in copy.components.items():
if getattr(copy, key) is None:
continue
if key in ('tokenizer', 'tokenizer_2', 'feature_extractor'):
setattr(copy, key, cp.deepcopy(getattr(copy, key)))
continue
cls = getattr(copy, key).__class__
if hasattr(cls, 'from_config'):
setattr(copy, key, cls.from_config(getattr(copy, key).config))
else:
setattr(copy, key, cls(getattr(copy, key).config))
    # assign=True is needed since our copy is on the "meta" device, i.e. weights are empty
for key, component in copy.components.items():
if key == 'tokenizer' or key == 'tokenizer_2':
continue
obj = getattr(copy, key)
if hasattr(obj, 'load_state_dict'):
obj.load_state_dict(getattr(pipe, key).state_dict(), assign=True)
    # some buffers might not be transferred from pipe to copy
copy.to(pipe.device)
return copy
model_id: str
quantize_dtype: Optional[Any] = None

def __hash__(self):
return hash((self.model_id, str(self.quantize_dtype)))

def __eq__(self, other):
if isinstance(other, str):
return self.model_id == other

if not isinstance(other, ModelDescriptor):
return False
return (self.model_id == other.model_id and
self.quantize_dtype == other.quantize_dtype)
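
As a quick illustration (not part of the commit; the model id is made up), the descriptor keeps quantized and unquantized variants of the same model as distinct cache entries, while plain strings still compare equal on the id alone. The loader's caches are scanned with ==, not hashed lookups, which is what makes the string fallback work:

from optimum.quanto import qfloat8
from multigen.loader import ModelDescriptor

plain = ModelDescriptor('models/sd15')           # hypothetical model id
quant = ModelDescriptor('models/sd15', qfloat8)

assert plain != quant                    # different quantize_dtype
assert plain == 'models/sd15'            # str fallback: matches on id only
cache = {plain: 'bf16 pipe', quant: 'qfloat8 pipe'}
assert len(cache) == 2                   # hashes differ as well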


class Loader:
@@ -56,9 +45,8 @@ class for loading diffusion pipelines from files.
"""
def __init__(self):
self._lock = threading.RLock()
self._cpu_pipes = dict()
# idx -> list of (model_id, pipe) pairs
self._gpu_pipes = dict()
self._cpu_pipes = dict() # ModelDescriptor -> pipe
self._gpu_pipes = dict() # gpu idx -> list of (ModelDescriptor, pipe) pairs

def get_gpu(self, model_id) -> List[int]:
"""
@@ -73,24 +61,29 @@ def get_gpu(self, model_id) -> List[int]:
return result

def load_pipeline(self, cls: Type[DiffusionPipeline], path, torch_dtype=torch.bfloat16,
device=None, offload_device=None, **additional_args):
device=None, offload_device=None, quantize_dtype=None, **additional_args):
with self._lock:
logger.debug(f'looking for pipeline {cls} from {path} on {device}')
result = None
descriptor = ModelDescriptor(path, quantize_dtype)
found_quantized = False
if device is None:
device = torch.device('cpu', 0)
if device.type == 'cuda':
idx = device.index
gpu_pipes = self._gpu_pipes.get(idx, [])
for (key, value) in gpu_pipes:
if key == path:
if key == descriptor:
logger.debug(f'found pipe in gpu cache {key}')
result = self.from_pipe(cls, value, additional_args)
logger.debug(f'created pipe from gpu cache {key} on {device}')
return result
for (key, pipe) in self._cpu_pipes.items():
if key == path:
if key == descriptor:
found_quantized = True
logger.debug(f'found pipe in cpu cache {key} {pipe.device}')
if device.type == 'cuda':
pipe = cp.deepcopy(pipe)
result = self.from_pipe(cls, pipe, additional_args)
break
if result is None:
@@ -106,16 +99,25 @@ def load_pipeline(self, cls: Type[DiffusionPipeline], path, torch_dtype=torch.bf
logger.debug("prepare pipe before returning from loader")
logger.debug(f"{path} on {result.device} {result.dtype}")

# Add quantization if specified
if (not found_quantized) and quantize_dtype is not None:
logger.debug(f'Quantizing pipeline to {quantize_dtype}')
quantize(result, dtype=quantize_dtype)

if result.device != device:
result = result.to(dtype=torch_dtype, device=device)
if result.dtype != torch_dtype:
result = result.to(dtype=torch_dtype)

            self.cache_pipeline(result, descriptor)
logger.debug(f'result device before weightshare_copy {result.device}')
result = weightshare_copy(result)
logger.debug(f'result device after weightshare_copy {result.device}')
assert result.device.type == device.type
if device.type == 'cuda':
assert result.device.index == device.index
logger.debug(f'returning {type(result)} from {path} on {result.device}')
logger.debug(f'returning {type(result)} from {path} \
on {result.device} scheduler {id(result.scheduler)}')
return result
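
For context, a hedged usage sketch (pipeline class, path, and device are assumptions, not part of this commit): repeated requests for the same descriptor are served from cache, and the two returned pipes share weight tensors but not scheduler objects:

import torch
from diffusers import StableDiffusionPipeline
from multigen.loader import Loader
from optimum.quanto import qfloat8

loader = Loader()
args = dict(torch_dtype=torch.bfloat16,
            device=torch.device('cuda', 0),
            quantize_dtype=qfloat8)
pipe_a = loader.load_pipeline(StableDiffusionPipeline, 'models/sd15', **args)
pipe_b = loader.load_pipeline(StableDiffusionPipeline, 'models/sd15', **args)
assert pipe_a is not pipe_b
assert pipe_a.scheduler is not pipe_b.scheduler   # per-caller schedulers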

def from_pipe(self, cls, pipe, additional_args):
@@ -131,86 +133,57 @@ def from_pipe(self, cls, pipe, additional_args):
components.pop('controlnet')
return cls(**components, **additional_args)

def cache_pipeline(self, pipe: DiffusionPipeline, model_id):
def cache_pipeline(self, pipe: DiffusionPipeline, descriptor: ModelDescriptor):
logger.debug(f'caching pipeline {descriptor} {pipe}')
with self._lock:
device = pipe.device
if model_id not in self._cpu_pipes:
if descriptor not in self._cpu_pipes:
# deepcopy is needed since Module.to is an inplace operation
size = get_model_size(pipe)
ram = awailable_ram()
logger.debug(f'{model_id} has size {size}, ram {ram}')
logger.debug(f'{descriptor} has size {size}, ram {ram}')
if ram < size * 2.5 and self._cpu_pipes:
key_to_delete = random.choice(list(self._cpu_pipes.keys()))
self._cpu_pipes.pop(key_to_delete)
item = pipe
if pipe.device.type == 'cuda':
item = cp.deepcopy(pipe).to('cpu')
self._cpu_pipes[model_id] = item
logger.debug(f'storing {model_id} on cpu')
device = pipe.device
item = cp.deepcopy(pipe.to('cpu'))
pipe.to(device)
self._cpu_pipes[descriptor] = item
logger.debug(f'storing {descriptor} on cpu')
assert pipe.device == device
if pipe.device.type == 'cuda':
self._store_gpu_pipe(pipe, model_id)
logger.debug(f'storing {model_id} on {pipe.device}')
self._store_gpu_pipe(pipe, descriptor)
logger.debug(f'storing {descriptor} on {pipe.device}')

def clear_cache(self, device):
logger.debug(f'clear_cache pipelines from {device}')
with self._lock:
if device.type == 'cuda':
self._gpu_pipes[device.index] = []

def _store_gpu_pipe(self, pipe, model_id):
def _store_gpu_pipe(self, pipe, descriptor: ModelDescriptor):
idx = pipe.device.index
assert idx is not None
# for now just clear all other pipelines
self._gpu_pipes[idx] = [(model_id, pipe)]
self._gpu_pipes[idx] = [(descriptor, pipe)]

def remove_pipeline(self, model_id):
self._cpu_pipes.pop(model_id)

def get_pipeline(self, model_id, device=None):
def get_pipeline(self, descriptor: Union[ModelDescriptor, str], device=None):
with self._lock:
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu', 0)
if device.type == 'cuda':
idx = device.index
gpu_pipes = self._gpu_pipes.get(idx, ())
for (key, value) in gpu_pipes:
if key == model_id:
if key == descriptor:
return value
for (key, pipe) in self._cpu_pipes.items():
if key == model_id:
if key == descriptor:
return pipe

return None


def count_params(model):
total_size = sum(param.numel() for param in model.parameters())
mul = 2
if model.dtype in (torch.float16, torch.bfloat16):
mul = 2
elif model.dtype == torch.float32:
mul = 4
return total_size * mul


def get_size(obj):
return sys.getsizeof(obj)


def get_model_size(pipeline):
total_size = 0
for name, component in pipeline.components.items():
if isinstance(component, torch.nn.Module):
total_size += count_params(component)
elif hasattr(component, 'tokenizer'):
total_size += count_params(component.tokenizer)
else:
total_size += get_size(component)
return total_size / (1024 * 1024)


def awailable_ram():
mem = psutil.virtual_memory()
available_ram = mem.available
return available_ram / (1024 * 1024)
94 changes: 94 additions & 0 deletions multigen/util.py
@@ -1,5 +1,19 @@
import sys
import torch
import psutil
from PIL import Image
import copy as cp
import optimum.quanto
from optimum.quanto import freeze, qfloat8, quantize as _quantize
from contextlib import nullcontext
from diffusers.utils import is_accelerate_available

import logging


if is_accelerate_available():
from accelerate import init_empty_weights
else:
init_empty_weights = nullcontext

def create_exif_metadata(im: Image, custom_metadata):
exif = im.getexif()
@@ -47,3 +61,83 @@ def pad_image_to_multiple_of_8(image: Image) -> Image:

return padded_image


def count_params(model):
total_size = sum(param.numel() for param in model.parameters())
mul = 2
if model.dtype in (torch.float16, torch.bfloat16):
mul = 2
elif model.dtype == torch.float32:
mul = 4
return total_size * mul
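
For intuition, the byte multiplier works out as follows (pure arithmetic; the parameter count is made up, and the division mirrors the MiB conversion in get_model_size):

n_params = 1_000_000_000               # hypothetical 1B-parameter module
print(n_params * 2 / 1024**2)          # float16/bfloat16: ~1907 MiB
print(n_params * 4 / 1024**2)          # float32: ~3815 MiB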


def get_size(obj):
return sys.getsizeof(obj)


def get_model_size(pipeline):
total_size = 0
for name, component in pipeline.components.items():
if isinstance(component, torch.nn.Module):
total_size += count_params(component)
elif hasattr(component, 'tokenizer'):
total_size += count_params(component.tokenizer)
else:
total_size += get_size(component)
return total_size / (1024 * 1024)


def awailable_ram():
mem = psutil.virtual_memory()
available_ram = mem.available
return available_ram / (1024 * 1024)
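
Both helpers report MiB, which is how Loader.cache_pipeline compares them; a minimal sketch of that eviction guard (the helper name is illustrative):

def cpu_cache_has_headroom(pipe) -> bool:
    # mirrors the check in Loader.cache_pipeline: keep ~2.5x the model
    # size available before deepcopying the pipe into the CPU cache
    return awailable_ram() >= get_model_size(pipe) * 2.5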


def quantize(pipe, dtype=qfloat8):
    components = ['unet', 'transformer', 'text_encoder', 'text_encoder_2', 'vae']

    for component in components:
        if hasattr(pipe, component):
            component_obj = getattr(pipe, component)
            if component_obj is None:
                continue
            _quantize(component_obj, weights=dtype)
            freeze(component_obj)
            # mark the component so weightshare_copy can requantize its clone
            component_obj._is_quantized = True
            component_obj._quantization_dtype = dtype
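
A short call sketch (the model path is hypothetical); the markers set above are exactly what weightshare_copy reads later to requantize a copy:

from diffusers import StableDiffusionPipeline
from optimum.quanto import qfloat8

pipe = StableDiffusionPipeline.from_pretrained('models/sd15')  # hypothetical path
quantize(pipe, dtype=qfloat8)
assert getattr(pipe.unet, '_is_quantized', False)
assert pipe.unet._quantization_dtype is qfloat8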


def weightshare_copy(pipe):
"""
    Create a new pipe object, then assign weights using load_state_dict from the passed 'pipe'.
"""
copy = pipe.__class__(**pipe.components)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
for key, component in copy.components.items():
if getattr(copy, key) is None:
continue
if key in ('tokenizer', 'tokenizer_2', 'feature_extractor'):
setattr(copy, key, cp.deepcopy(getattr(copy, key)))
continue
cls = getattr(copy, key).__class__
if hasattr(cls, 'from_config'):
setattr(copy, key, cls.from_config(getattr(copy, key).config))
else:
setattr(copy, key, cls(getattr(copy, key).config))

pipe_component = getattr(pipe, key)
if getattr(pipe_component, '_is_quantized', False):
# Quantize the component in the copy using the same dtype
component_obj = getattr(copy, key)
_quantize(component_obj, weights=pipe_component._quantization_dtype)
    # assign=True is needed since our copy is on the "meta" device, i.e. weights are empty
for key, component in copy.components.items():
if key == 'tokenizer' or key == 'tokenizer_2':
continue
obj = getattr(copy, key)
if hasattr(obj, 'load_state_dict'):
obj.load_state_dict(getattr(pipe, key).state_dict(), assign=True)
    # some buffers might not be transferred from pipe to copy
copy.to(pipe.device)
return copy
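
To make the sharing concrete, a hedged check for a unet-based pipeline already on its target device: the copy gets fresh module objects, while load_state_dict(..., assign=True) leaves both pipes pointing at the same weight storage:

def check_weightshare(pipe):
    copy = weightshare_copy(pipe)
    assert copy.unet is not pipe.unet              # distinct module objects
    p_orig = next(pipe.unet.parameters())
    p_copy = next(copy.unet.parameters())
    assert p_orig.data_ptr() == p_copy.data_ptr()  # shared weight memory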