diff --git a/scripts/prepare_test_weights.py b/scripts/prepare_test_weights.py index 7837c10d7..19f942375 100644 --- a/scripts/prepare_test_weights.py +++ b/scripts/prepare_test_weights.py @@ -238,6 +238,11 @@ def download_loras(): "https://huggingface.co/radames/sdxl-DPO-LoRA/resolve/main/pytorch_lora_weights.safetensors", dest_folder ) + dest_folder = os.path.join(test_weights_dir, "loras", "sliders") + download_file("https://sliders.baulab.info/weights/xl_sliders/age.pt", dest_folder) + download_file("https://sliders.baulab.info/weights/xl_sliders/cartoon_style.pt", dest_folder) + download_file("https://sliders.baulab.info/weights/xl_sliders/eyesize.pt", dest_folder) + def download_preprocessors(): dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings") diff --git a/src/refiners/fluxion/adapters/lora.py b/src/refiners/fluxion/adapters/lora.py index 6b8e99259..06ed305ca 100644 --- a/src/refiners/fluxion/adapters/lora.py +++ b/src/refiners/fluxion/adapters/lora.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Any from torch import Tensor, device as Device, dtype as DType from torch.nn import Parameter as TorchParameter @@ -7,18 +6,20 @@ import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import Adapter -from refiners.fluxion.layers.chain import Chain class Lora(fl.Chain, ABC): def __init__( self, + name: str, + /, rank: int = 16, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None, ) -> None: - self.rank = rank + self.name = name + self._rank = rank self._scale = scale super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale)) @@ -44,6 +45,10 @@ def up(self) -> fl.WeightedModule: assert isinstance(up_layer, fl.WeightedModule) return up_layer + @property + def rank(self) -> int: + return self._rank + @property def scale(self) -> float: return self._scale @@ -56,19 +61,21 @@ def scale(self, value: float) -> None: @classmethod def 
from_weights( cls, + name: str, + /, down: Tensor, up: Tensor, ) -> "Lora": match (up.ndim, down.ndim): case (2, 2): - return LinearLora.from_weights(up=up, down=down) + return LinearLora.from_weights(name, up=up, down=down) case (4, 4): - return Conv2dLora.from_weights(up=up, down=down) + return Conv2dLora.from_weights(name, up=up, down=down) case _: raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}") @classmethod - def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]: + def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora"]: """ Create a dictionary of LoRA layers from a state dict. @@ -80,13 +87,37 @@ def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]: list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2] ): key = ".".join(down_key.split(".")[:-2]) - loras[key] = cls.from_weights(down=down_tensor, up=up_tensor) + loras[key] = cls.from_weights(name, down=down_tensor, up=up_tensor) return loras @abstractmethod - def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any: + def is_compatible(self, layer: fl.WeightedModule, /) -> bool: ... 
+ def auto_attach( + self, target: fl.Chain, exclude: list[str] | None = None + ) -> "tuple[LoraAdapter, fl.Chain | None] | None": + for layer, parent in target.walk(self.up.__class__): + if isinstance(parent, Lora): + continue + + if exclude is not None and any( + [any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude] + ): + continue + + if not self.is_compatible(layer): + continue + + if isinstance(parent, LoraAdapter): + if self.name in parent.names: + continue + else: + parent.add_lora(self) + return parent, None + + return LoraAdapter(layer, self), parent + def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None: assert down_weight.shape == self.down.weight.shape assert up_weight.shape == self.up.weight.shape @@ -97,6 +128,8 @@ def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None: class LinearLora(Lora): def __init__( self, + name: str, + /, in_features: int, out_features: int, rank: int = 16, @@ -107,35 +140,29 @@ def __init__( self.in_features = in_features self.out_features = out_features - super().__init__(rank=rank, scale=scale, device=device, dtype=dtype) + super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype) @classmethod def from_weights( cls, + name: str, + /, down: Tensor, up: Tensor, ) -> "LinearLora": assert up.ndim == 2 and down.ndim == 2 assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}" lora = cls( - in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype + name, + in_features=down.shape[1], + out_features=up.shape[0], + rank=down.shape[0], + device=up.device, + dtype=up.dtype, ) lora.load_weights(down_weight=down, up_weight=up) return lora - def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None": - for layer, parent in target.walk(fl.Linear): - if isinstance(parent, Lora) or isinstance(parent, 
LoraAdapter): - continue - - if exclude is not None and any( - [any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude] - ): - continue - - if layer.in_features == self.in_features and layer.out_features == self.out_features: - return LoraAdapter(target=layer, lora=self), parent - def lora_layers( self, device: Device | str | None = None, dtype: DType | None = None ) -> tuple[fl.Linear, fl.Linear]: @@ -156,10 +183,17 @@ def lora_layers( ), ) + def is_compatible(self, layer: fl.WeightedModule, /) -> bool: + if isinstance(layer, fl.Linear): + return layer.in_features == self.in_features and layer.out_features == self.out_features + return False + class Conv2dLora(Lora): def __init__( self, + name: str, + /, in_channels: int, out_channels: int, rank: int = 16, @@ -176,20 +210,24 @@ def __init__( self.stride = stride self.padding = padding - super().__init__(rank=rank, scale=scale, device=device, dtype=dtype) + super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype) @classmethod def from_weights( cls, + name: str, + /, down: Tensor, up: Tensor, ) -> "Conv2dLora": assert up.ndim == 4 and down.ndim == 4 assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}" down_kernel_size, up_kernel_size = down.shape[2], up.shape[2] + # padding is set so the spatial dimensions are preserved (assuming stride=1 and kernel_size either 1 or 3) down_padding = 1 if down_kernel_size == 3 else 0 up_padding = 1 if up_kernel_size == 3 else 0 lora = cls( + name, in_channels=down.shape[1], out_channels=up.shape[0], rank=down.shape[0], @@ -201,25 +239,6 @@ def from_weights( lora.load_weights(down_weight=down, up_weight=up) return lora - def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None": - for layer, parent in target.walk(fl.Conv2d): - if isinstance(parent, Lora) or isinstance(parent, LoraAdapter): - continue - - if exclude is not 
None and any( - [any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude] - ): - continue - - if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels: - if layer.stride != (self.stride[0], self.stride[0]): - self.down.stride = layer.stride - - return LoraAdapter( - target=layer, - lora=self, - ), parent - def lora_layers( self, device: Device | str | None = None, dtype: DType | None = None ) -> tuple[fl.Conv2d, fl.Conv2d]: @@ -246,20 +265,47 @@ def lora_layers( ), ) + def is_compatible(self, layer: fl.WeightedModule, /) -> bool: + if ( + isinstance(layer, fl.Conv2d) + and layer.in_channels == self.in_channels + and layer.out_channels == self.out_channels + ): + # stride cannot be inferred from the weights, so we assume it's the same as the layer + self.down.stride = layer.stride + + return True + return False + class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]): - def __init__(self, target: fl.WeightedModule, lora: Lora) -> None: + def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None: with self.setup_adapter(target): - super().__init__(target, lora) + super().__init__(target, *loras) @property - def lora(self) -> Lora: - return self.ensure_find(Lora) + def names(self) -> list[str]: + return [lora.name for lora in self.layers(Lora)] @property - def scale(self) -> float: - return self.lora.scale + def loras(self) -> dict[str, Lora]: + return {lora.name: lora for lora in self.layers(Lora)} - @scale.setter - def scale(self, value: float) -> None: - self.lora.scale = value + @property + def scales(self) -> dict[str, float]: + return {lora.name: lora.scale for lora in self.layers(Lora)} + + @scales.setter + def scale(self, values: dict[str, float]) -> None: + for name, value in values.items(): + self.loras[name].scale = value + + def add_lora(self, lora: Lora, /) -> None: + assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists" + self.append(lora) + + def 
remove_lora(self, name: str, /) -> Lora | None: + if name in self.names: + lora = self.loras[name] + self.remove(lora) + return lora diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index d25115d20..884874bf4 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -26,19 +26,20 @@ def clip_text_encoder(self) -> fl.Chain: assert isinstance(clip_text_encoder, fl.Chain) return clip_text_encoder - def load( + def add_loras( self, - tensors: dict[str, Tensor], + name: str, /, + tensors: dict[str, Tensor], scale: float = 1.0, ) -> None: """Load the LoRA weights from a dictionary of tensors. Expects the keys to be in the commonly found formats on CivitAI's hub. """ - assert len(self.lora_adapters) == 0, "Loras already loaded" + assert name not in self.names, f"LoRA {name} already exists" loras = Lora.from_dict( - {key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()} + name, {key: value.to(device=self.target.device, dtype=self.target.dtype) for key, value in tensors.items()} ) loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)} @@ -46,16 +47,25 @@ def load( if not "unet" in loras and not "text" in loras: loras = {f"unet_{key}": loras[key] for key in loras.keys()} - self.load_unet(loras) - self.load_text_encoder(loras) + self.add_loras_to_unet(loras) + self.add_loras_to_text_encoder(loras) - self.scale = scale + self.set_scale(name, scale) + + def add_multiple_loras( + self, + /, + tensors: dict[str, dict[str, Tensor]], + scale: dict[str, float] | None = None, + ) -> None: + for name, lora_tensors in tensors.items(): + self.add_loras(name, tensors=lora_tensors, scale=scale[name] if scale else 1.0) - def load_text_encoder(self, loras: dict[str, Lora], /) -> None: + def add_loras_to_text_encoder(self, loras: dict[str, Lora], /) -> None: text_encoder_loras = 
{key: loras[key] for key in loras.keys() if "text" in key} SDLoraManager.auto_attach(text_encoder_loras, self.clip_text_encoder) - def load_unet(self, loras: dict[str, Lora], /) -> None: + def add_loras_to_unet(self, loras: dict[str, Lora], /) -> None: unet_loras = {key: loras[key] for key in loras.keys() if "unet" in key} exclude: list[str] = [] exclude = [ @@ -65,14 +75,43 @@ def load_unet(self, loras: dict[str, Lora], /) -> None: ] SDLoraManager.auto_attach(unet_loras, self.unet, exclude=exclude) - def unload(self) -> None: + def remove_loras(self, *names: str) -> None: + for lora_adapter in self.lora_adapters: + for name in names: + lora_adapter.remove_lora(name) + + if len(lora_adapter.loras) == 0: + lora_adapter.eject() + + def remove_all(self) -> None: for lora_adapter in self.lora_adapters: lora_adapter.eject() + def get_loras_by_name(self, name: str, /) -> list[Lora]: + return [lora for lora in self.loras if lora.name == name] + + def get_scale(self, name: str, /) -> float: + loras = self.get_loras_by_name(name) + assert all([lora.scale == loras[0].scale for lora in loras]), "lora scales are not all the same" + return loras[0].scale + + def set_scale(self, name: str, scale: float, /) -> None: + self.update_scales({name: scale}) + + def update_scales(self, scales: dict[str, float], /) -> None: + assert all([name in self.names for name in scales]), f"Scales keys must be a subset of {self.names}" + for name, scale in scales.items(): + for lora in self.get_loras_by_name(name): + lora.scale = scale + @property def loras(self) -> list[Lora]: return list(self.unet.layers(Lora)) + list(self.clip_text_encoder.layers(Lora)) + @property + def names(self) -> list[str]: + return list(set(lora.name for lora in self.loras)) + @property def lora_adapters(self) -> list[LoraAdapter]: return list(self.unet.layers(LoraAdapter)) + list(self.clip_text_encoder.layers(LoraAdapter)) @@ -87,15 +126,8 @@ def unet_exclusions(self) -> dict[str, str]: } @property - def scale(self) -> 
float: - assert len(self.loras) > 0, "No loras found" - assert all([lora.scale == self.loras[0].scale for lora in self.loras]) - return self.loras[0].scale - - @scale.setter - def scale(self, value: float) -> None: - for lora in self.loras: - lora.scale = value + def scales(self) -> dict[str, float]: + return {name: self.get_scale(name) for name in self.names} @staticmethod def pad(input: str, /, padding_length: int = 2) -> str: @@ -130,7 +162,9 @@ def auto_attach( for key, lora in loras.items(): if attach := lora.auto_attach(target, exclude=exclude): adapter, parent = attach - adapter.inject(parent) + # if parent is None, `adapter` is already attached and `lora` has been added to it + if parent is not None: + adapter.inject(parent) else: failed_loras[key] = lora diff --git a/tests/adapters/test_lora.py b/tests/adapters/test_lora.py new file mode 100644 index 000000000..d298aa117 --- /dev/null +++ b/tests/adapters/test_lora.py @@ -0,0 +1,118 @@ +import pytest +import torch + +from refiners.fluxion import layers as fl +from refiners.fluxion.adapters.lora import Conv2dLora, LinearLora, Lora, LoraAdapter + + +@pytest.fixture +def lora() -> LinearLora: + return LinearLora("test", in_features=320, out_features=128, rank=16) + + +@pytest.fixture +def conv_lora() -> Lora: + return Conv2dLora("conv_test", in_channels=16, out_channels=8, kernel_size=(3, 1), rank=4) + + +def test_properties(lora: LinearLora, conv_lora: Lora) -> None: + assert lora.name == "test" + assert lora.rank == lora.down.out_features == lora.up.in_features == 16 + assert lora.scale == 1.0 + assert lora.in_features == lora.down.in_features == 320 + assert lora.out_features == lora.up.out_features == 128 + + assert conv_lora.name == "conv_test" + assert conv_lora.rank == conv_lora.down.out_channels == conv_lora.up.in_channels == 4 + assert conv_lora.scale == 1.0 + assert conv_lora.in_channels == conv_lora.down.in_channels == 16 + assert conv_lora.out_channels == conv_lora.up.out_channels == 8 + assert 
conv_lora.kernel_size == (conv_lora.down.kernel_size[0], conv_lora.up.kernel_size[0]) == (3, 1) + # padding is set so the spatial dimensions are preserved + assert conv_lora.padding == (conv_lora.down.padding[0], conv_lora.up.padding[0]) == (0, 1) + + +def test_scale_setter(lora: LinearLora) -> None: + lora.scale = 2.0 + assert lora.scale == 2.0 + assert lora.ensure_find(fl.Multiply).scale == 2.0 + + +def test_from_weights(lora: LinearLora, conv_lora: Conv2dLora) -> None: + new_lora = LinearLora.from_weights("test", down=lora.down.weight, up=lora.up.weight) + x = torch.randn(1, 320) + assert torch.allclose(lora(x), new_lora(x)) + + new_conv_lora = Conv2dLora.from_weights("conv_test", down=conv_lora.down.weight, up=conv_lora.up.weight) + x = torch.randn(1, 16, 64, 64) + assert torch.allclose(conv_lora(x), new_conv_lora(x)) + + +def test_from_dict() -> None: + state_dict = { + "down.weight": torch.randn(320, 128), + "up.weight": torch.randn(128, 320), + "this.is_not_used.alpha": torch.randn(1, 320), + "probably.a.conv.down.weight": torch.randn(4, 16, 3, 3), + "probably.a.conv.up.weight": torch.randn(8, 4, 1, 1), + } + loras = Lora.from_dict("test", state_dict=state_dict) + assert len(loras) == 2 + linear_lora, conv_lora = loras.values() + assert isinstance(linear_lora, LinearLora) + assert isinstance(conv_lora, Conv2dLora) + assert linear_lora.name == "test" + assert conv_lora.name == "test" + + +@pytest.fixture +def lora_adapter() -> LoraAdapter: + target = fl.Linear(320, 128) + lora1 = LinearLora("test1", in_features=320, out_features=128, rank=16, scale=2.0) + lora2 = LinearLora("test2", in_features=320, out_features=128, rank=16, scale=-1.0) + return LoraAdapter(target, lora1, lora2) + + +def test_names(lora_adapter: LoraAdapter) -> None: + assert set(lora_adapter.names) == {"test1", "test2"} + + +def test_loras(lora_adapter: LoraAdapter) -> None: + assert set(lora_adapter.loras.keys()) == {"test1", "test2"} + + +def test_scales(lora_adapter: LoraAdapter) -> 
None: + assert set(lora_adapter.scales.keys()) == {"test1", "test2"} + assert lora_adapter.scales["test1"] == 2.0 + assert lora_adapter.scales["test2"] == -1.0 + + +def test_scale_setter_lora_adapter(lora_adapter: LoraAdapter) -> None: + lora_adapter.scale = {"test1": 0.0, "test2": 3.0} + assert lora_adapter.scales == {"test1": 0.0, "test2": 3.0} + + +def test_add_lora(lora_adapter: LoraAdapter) -> None: + lora3 = LinearLora("test3", in_features=320, out_features=128, rank=16) + lora_adapter.add_lora(lora3) + assert "test3" in lora_adapter.names + + +def test_remove_lora(lora_adapter: LoraAdapter) -> None: + lora_adapter.remove_lora("test1") + assert "test1" not in lora_adapter.names + + +def test_add_existing_lora(lora_adapter: LoraAdapter) -> None: + lora3 = LinearLora("test1", in_features=320, out_features=128, rank=16) + with pytest.raises(AssertionError): + lora_adapter.add_lora(lora3) + + +def test_remove_nonexistent_lora(lora_adapter: LoraAdapter) -> None: + assert lora_adapter.remove_lora("test3") is None + + +def test_set_scale_for_nonexistent_lora(lora_adapter: LoraAdapter) -> None: + with pytest.raises(KeyError): + lora_adapter.scale = {"test3": 2.0} diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 0c544e0e3..d59c627b5 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -216,6 +216,26 @@ def lora_data_dpo(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, return expected_image, tensors +@pytest.fixture(scope="module") +def lora_sliders(test_weights_path: Path) -> tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]]: + weights_path = test_weights_path / "loras" / "sliders" + + if not weights_path.is_dir(): + warn(f"could not find weights at {weights_path}, skipping") + pytest.skip(allow_module_level=True) + + return { + "age": load_tensors(weights_path / "age.pt"), # type: ignore + "cartoon_style": load_tensors(weights_path / "cartoon_style.pt"), # type: ignore + "eyesize": 
load_tensors(weights_path / "eyesize.pt"), # type: ignore + }, { + "age": 0.3, + "cartoon_style": -0.2, + "dpo": 1.4, + "eyesize": -0.2, + } + + @pytest.fixture def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image: return Image.open(ref_path / "inpainting-scene.png").convert("RGB") @@ -266,6 +286,11 @@ def expected_freeu(ref_path: Path) -> Image.Image: return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB") +@pytest.fixture +def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image: + return Image.open(fp=ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB") + + @pytest.fixture def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.Image, Image.Image]: assets = Path(__file__).parent.parent.parent / "assets" @@ -1034,7 +1059,7 @@ def test_diffusion_lora( sd15.set_inference_steps(30) - SDLoraManager(sd15).load(lora_weights, scale=1) + SDLoraManager(sd15).add_loras("pokemon", lora_weights, scale=1) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) @@ -1067,7 +1092,7 @@ def test_diffusion_sdxl_lora( prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography" negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white" - SDLoraManager(sdxl).load(lora_weights, scale=lora_scale) + SDLoraManager(sdxl).add_loras("dpo", lora_weights, scale=lora_scale) clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt, negative_text=negative_prompt @@ -1094,6 +1119,54 @@ def test_diffusion_sdxl_lora( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) +@no_grad() +def test_diffusion_sdxl_multiple_loras( + sdxl_ddim: StableDiffusion_XL, + lora_data_dpo: tuple[Image.Image, dict[str, torch.Tensor]], + 
lora_sliders: tuple[dict[str, dict[str, torch.Tensor]], dict[str, float]], + expected_sdxl_multi_loras: Image.Image, +) -> None: + sdxl = sdxl_ddim + expected_image = expected_sdxl_multi_loras + _, dpo = lora_data_dpo + loras, scales = lora_sliders + loras["dpo"] = dpo + + SDLoraManager(sdxl).add_multiple_loras(loras, scales) + + # parameters are the same as https://huggingface.co/radames/sdxl-DPO-LoRA + # except that we are using DDIM instead of sde-dpmsolver++ + n_steps = 40 + seed = 12341234123 + guidance_scale = 4 + prompt = "professional portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography" + negative_prompt = "3d render, cartoon, drawing, art, low light, blur, pixelated, low resolution, black and white" + + clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( + text=prompt, negative_text=negative_prompt + ) + + time_ids = sdxl.default_time_ids + sdxl.set_inference_steps(n_steps) + + manual_seed(seed=seed) + x = torch.randn(1, 4, 128, 128, device=sdxl.device, dtype=sdxl.dtype) + + for step in sdxl.steps: + x = sdxl( + x, + step=step, + clip_text_embedding=clip_text_embedding, + pooled_text_embedding=pooled_text_embedding, + time_ids=time_ids, + condition_scale=guidance_scale, + ) + + predicted_image = sdxl.lda.decode_latents(x) + + ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) + + @no_grad() def test_diffusion_refonly( sd15_ddim: StableDiffusion_1, diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md index 0f572b935..fdd22b176 100644 --- a/tests/e2e/test_diffusion_ref/README.md +++ b/tests/e2e/test_diffusion_ref/README.md @@ -49,6 +49,7 @@ Special cases: - `expected_freeu.png` - `expected_dropy_slime_9752.png` - `expected_sdxl_dpo_lora.png` + - `expected_sdxl_multi_loras.png` ## Other images 
diff --git a/tests/e2e/test_diffusion_ref/expected_sdxl_multi_loras.png b/tests/e2e/test_diffusion_ref/expected_sdxl_multi_loras.png
new file mode 100644
index 000000000..035c81cd2
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_sdxl_multi_loras.png differ
diff --git a/tests/foundationals/latent_diffusion/test_lora_manager.py b/tests/foundationals/latent_diffusion/test_lora_manager.py
new file mode 100644
index 000000000..fdce986e1
--- /dev/null
+++ b/tests/foundationals/latent_diffusion/test_lora_manager.py
@@ -0,0 +1,90 @@
+from pathlib import Path
+from warnings import warn
+
+import pytest
+import torch
+
+from refiners.fluxion.utils import load_tensors
+from refiners.foundationals.latent_diffusion import StableDiffusion_1
+from refiners.foundationals.latent_diffusion.lora import Lora, SDLoraManager
+
+
+@pytest.fixture
+def manager() -> SDLoraManager:
+    target = StableDiffusion_1()
+    return SDLoraManager(target)
+
+
+@pytest.fixture
+def weights(test_weights_path: Path) -> dict[str, torch.Tensor]:
+    weights_path = test_weights_path / "loras" / "pokemon-lora" / "pytorch_lora_weights.bin"
+
+    if not weights_path.is_file():
+        warn(f"could not find weights at {weights_path}, skipping")
+        pytest.skip(allow_module_level=True)
+
+    return load_tensors(weights_path)
+
+
+def test_add_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", tensors=weights)
+    assert "pokemon-lora" in manager.names
+
+    with pytest.raises(AssertionError) as exc:
+        manager.add_loras("pokemon-lora", tensors=weights)
+    assert "already exists" in str(exc.value)
+
+
+def test_add_multiple_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
+    manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
+    assert "pokemon-lora" in manager.names
+    assert "pokemon-lora2" in manager.names
+
+
+def test_remove_loras(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
+    manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
+    manager.remove_loras("pokemon-lora")
+    assert "pokemon-lora" not in manager.names
+    assert "pokemon-lora2" in manager.names
+
+    manager.remove_loras("pokemon-lora2")
+    assert "pokemon-lora2" not in manager.names
+    assert len(manager.names) == 0
+
+
+def test_remove_all(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
+    manager.add_multiple_loras({"pokemon-lora": weights, "pokemon-lora2": weights})
+    manager.remove_all()
+    assert len(manager.names) == 0
+
+
+def test_get_lora(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", tensors=weights)
+    assert all(isinstance(lora, Lora) for lora in manager.get_loras_by_name("pokemon-lora"))
+
+
+def test_get_scale(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
+    manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
+    assert manager.get_scale("pokemon-lora") == 0.4
+
+
+def test_names(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
+    assert manager.names == []
+
+    manager.add_loras("pokemon-lora", tensors=weights)
+    assert manager.names == ["pokemon-lora"]
+
+    manager.add_loras("pokemon-lora2", tensors=weights)
+    # `SDLoraManager.names` is built from a set, so the order of its elements is
+    # arbitrary; sort before comparing instead of pinning one particular order.
+    assert sorted(manager.names) == ["pokemon-lora", "pokemon-lora2"]
+
+
+def test_scales(manager: SDLoraManager, weights: dict[str, torch.Tensor]) -> None:
+    assert manager.scales == {}
+
+    manager.add_loras("pokemon-lora", tensors=weights, scale=0.4)
+    assert manager.scales == {"pokemon-lora": 0.4}
+
+    manager.add_loras("pokemon-lora2", tensors=weights, scale=0.5)
+    assert manager.scales == {"pokemon-lora": 0.4, "pokemon-lora2": 0.5}