Skip to content

Commit

Permalink
Load Multiple LoRAs with SDLoraManager
Browse files Browse the repository at this point in the history
  • Loading branch information
limiteinductive committed Jan 22, 2024
1 parent 40c33b9 commit 889ce08
Show file tree
Hide file tree
Showing 8 changed files with 439 additions and 74 deletions.
5 changes: 5 additions & 0 deletions scripts/prepare_test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ def download_loras():
"https://huggingface.co/radames/sdxl-DPO-LoRA/resolve/main/pytorch_lora_weights.safetensors", dest_folder
)

dest_folder = os.path.join(test_weights_dir, "loras", "sliders")
download_file("https://sliders.baulab.info/weights/xl_sliders/age.pt", dest_folder)
download_file("https://sliders.baulab.info/weights/xl_sliders/cartoon_style.pt", dest_folder)
download_file("https://sliders.baulab.info/weights/xl_sliders/eyesize.pt", dest_folder)


def download_preprocessors():
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")
Expand Down
150 changes: 98 additions & 52 deletions src/refiners/fluxion/adapters/lora.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
from abc import ABC, abstractmethod
from typing import Any

from torch import Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter
from torch.nn.init import normal_, zeros_

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.layers.chain import Chain


class Lora(fl.Chain, ABC):
def __init__(
self,
name: str,
/,
rank: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.rank = rank
self.name = name
self._rank = rank
self._scale = scale

super().__init__(*self.lora_layers(device=device, dtype=dtype), fl.Multiply(scale))
Expand All @@ -44,6 +45,10 @@ def up(self) -> fl.WeightedModule:
assert isinstance(up_layer, fl.WeightedModule)
return up_layer

@property
def rank(self) -> int:
return self._rank

@property
def scale(self) -> float:
return self._scale
Expand All @@ -56,19 +61,21 @@ def scale(self, value: float) -> None:
@classmethod
def from_weights(
cls,
name: str,
/,
down: Tensor,
up: Tensor,
) -> "Lora":
match (up.ndim, down.ndim):
case (2, 2):
return LinearLora.from_weights(up=up, down=down)
return LinearLora.from_weights(name, up=up, down=down)
case (4, 4):
return Conv2dLora.from_weights(up=up, down=down)
return Conv2dLora.from_weights(name, up=up, down=down)
case _:
raise ValueError(f"Unsupported weight shapes: up={up.shape}, down={down.shape}")

@classmethod
def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]:
def from_dict(cls, name: str, /, state_dict: dict[str, Tensor]) -> dict[str, "Lora"]:
"""
Create a dictionary of LoRA layers from a state dict.
Expand All @@ -80,13 +87,37 @@ def from_dict(cls, state_dict: dict[str, Tensor], /) -> dict[str, "Lora"]:
list(state_dict.keys())[::2], list(state_dict.values())[::2], list(state_dict.values())[1::2]
):
key = ".".join(down_key.split(".")[:-2])
loras[key] = cls.from_weights(down=down_tensor, up=up_tensor)
loras[key] = cls.from_weights(name, down=down_tensor, up=up_tensor)
return loras

@abstractmethod
def auto_attach(self, target: fl.Chain, exclude: list[str] | None = None) -> Any:
def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
...

def auto_attach(
self, target: fl.Chain, exclude: list[str] | None = None
) -> "tuple[LoraAdapter, fl.Chain | None] | None":
for layer, parent in target.walk(self.up.__class__):
if isinstance(parent, Lora):
continue

if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue

if not self.is_compatible(layer):
continue

if isinstance(parent, LoraAdapter):
if self.name in parent.names:
continue
else:
parent.add_lora(self)
return parent, None

return LoraAdapter(layer, self), parent

def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
assert down_weight.shape == self.down.weight.shape
assert up_weight.shape == self.up.weight.shape
Expand All @@ -97,6 +128,8 @@ def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
class LinearLora(Lora):
def __init__(
self,
name: str,
/,
in_features: int,
out_features: int,
rank: int = 16,
Expand All @@ -107,35 +140,29 @@ def __init__(
self.in_features = in_features
self.out_features = out_features

super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)

@classmethod
def from_weights(
cls,
name: str,
/,
down: Tensor,
up: Tensor,
) -> "LinearLora":
assert up.ndim == 2 and down.ndim == 2
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
lora = cls(
in_features=down.shape[1], out_features=up.shape[0], rank=down.shape[0], device=up.device, dtype=up.dtype
name,
in_features=down.shape[1],
out_features=up.shape[0],
rank=down.shape[0],
device=up.device,
dtype=up.dtype,
)
lora.load_weights(down_weight=down, up_weight=up)
return lora

def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Linear):
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
continue

if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue

if layer.in_features == self.in_features and layer.out_features == self.out_features:
return LoraAdapter(target=layer, lora=self), parent

def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Linear, fl.Linear]:
Expand All @@ -156,10 +183,17 @@ def lora_layers(
),
)

def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
if isinstance(layer, fl.Linear):
return layer.in_features == self.in_features and layer.out_features == self.out_features
return False


class Conv2dLora(Lora):
def __init__(
self,
name: str,
/,
in_channels: int,
out_channels: int,
rank: int = 16,
Expand All @@ -176,20 +210,24 @@ def __init__(
self.stride = stride
self.padding = padding

super().__init__(rank=rank, scale=scale, device=device, dtype=dtype)
super().__init__(name, rank=rank, scale=scale, device=device, dtype=dtype)

@classmethod
def from_weights(
cls,
name: str,
/,
down: Tensor,
up: Tensor,
) -> "Conv2dLora":
assert up.ndim == 4 and down.ndim == 4
assert down.shape[0] == up.shape[1], f"Rank mismatch: down rank={down.shape[0]} and up rank={up.shape[1]}"
down_kernel_size, up_kernel_size = down.shape[2], up.shape[2]
# padding is set so the spatial dimensions are preserved (assuming stride=1 and kernel_size either 1 or 3)
down_padding = 1 if down_kernel_size == 3 else 0
up_padding = 1 if up_kernel_size == 3 else 0
lora = cls(
name,
in_channels=down.shape[1],
out_channels=up.shape[0],
rank=down.shape[0],
Expand All @@ -201,25 +239,6 @@ def from_weights(
lora.load_weights(down_weight=down, up_weight=up)
return lora

def auto_attach(self, target: Chain, exclude: list[str] | None = None) -> "tuple[LoraAdapter, fl.Chain] | None":
for layer, parent in target.walk(fl.Conv2d):
if isinstance(parent, Lora) or isinstance(parent, LoraAdapter):
continue

if exclude is not None and any(
[any([p.__class__.__name__ == e for p in parent.get_parents() + [parent]]) for e in exclude]
):
continue

if layer.in_channels == self.in_channels and layer.out_channels == self.out_channels:
if layer.stride != (self.stride[0], self.stride[0]):
self.down.stride = layer.stride

return LoraAdapter(
target=layer,
lora=self,
), parent

def lora_layers(
self, device: Device | str | None = None, dtype: DType | None = None
) -> tuple[fl.Conv2d, fl.Conv2d]:
Expand All @@ -246,20 +265,47 @@ def lora_layers(
),
)

def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
if (
isinstance(layer, fl.Conv2d)
and layer.in_channels == self.in_channels
and layer.out_channels == self.out_channels
):
# stride cannot be inferred from the weights, so we assume it's the same as the layer
self.down.stride = layer.stride

return True
return False


class LoraAdapter(fl.Sum, Adapter[fl.WeightedModule]):
def __init__(self, target: fl.WeightedModule, lora: Lora) -> None:
def __init__(self, target: fl.WeightedModule, /, *loras: Lora) -> None:
with self.setup_adapter(target):
super().__init__(target, lora)
super().__init__(target, *loras)

@property
def lora(self) -> Lora:
return self.ensure_find(Lora)
def names(self) -> list[str]:
return [lora.name for lora in self.layers(Lora)]

@property
def scale(self) -> float:
return self.lora.scale
def loras(self) -> dict[str, Lora]:
return {lora.name: lora for lora in self.layers(Lora)}

@scale.setter
def scale(self, value: float) -> None:
self.lora.scale = value
@property
def scales(self) -> dict[str, float]:
return {lora.name: lora.scale for lora in self.layers(Lora)}

@scales.setter
def scale(self, values: dict[str, float]) -> None:
for name, value in values.items():
self.loras[name].scale = value

def add_lora(self, lora: Lora, /) -> None:
assert lora.name not in self.names, f"LoRA layer with name {lora.name} already exists"
self.append(lora)

def remove_lora(self, name: str, /) -> Lora | None:
if name in self.names:
lora = self.loras[name]
self.remove(lora)
return lora
Loading

0 comments on commit 889ce08

Please sign in to comment.