[Lora] Seperate logic (huggingface#5809)
* [Lora] Seperate logic

* [Lora] Seperate logic

* [Lora] Seperate logic

* add comments to explain the code better

* add comments to explain the code better
patrickvonplaten authored Nov 21, 2023
1 parent ba352ae commit 13d73d9
Showing 7 changed files with 219 additions and 54 deletions.
35 changes: 34 additions & 1 deletion examples/dreambooth/train_dreambooth_lora.py
@@ -57,7 +57,7 @@
AttnAddedKVProcessor2_0,
SlicedAttnAddedKVProcessor,
)
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
@@ -70,6 +70,39 @@
logger = get_logger(__name__)


# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}

def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection

attn_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))

return attn_modules

for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v

for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v

for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v

for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v

return state_dict


def save_model_card(
repo_id: str,
images=None,
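The helper added above is a script-local copy of what the training script previously imported from diffusers.models.lora. For orientation only, here is a minimal sketch (not part of the diff; the save_lora_checkpoint name is ours) of how such a state dict is typically combined with unet_lora_state_dict when checkpointing, assuming the LoraLoaderMixin and unet_lora_state_dict imports visible in these training-script diffs:

# Hedged sketch -- illustrates intended usage of the local helper; not taken from the commit.
from diffusers.loaders import LoraLoaderMixin
from diffusers.training_utils import unet_lora_state_dict

def save_lora_checkpoint(unet, text_encoder, output_dir):
    # Collect the LoRA weights from both models and write them with the standard loader API.
    LoraLoaderMixin.save_lora_weights(
        save_directory=output_dir,
        unet_lora_layers=unet_lora_state_dict(unet),
        text_encoder_lora_layers=text_encoder_lora_state_dict(text_encoder),
    )
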
35 changes: 34 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -50,7 +50,7 @@
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr, unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
@@ -63,6 +63,39 @@
logger = get_logger(__name__)


# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}

def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection

attn_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))

return attn_modules

for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v

for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v

for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v

for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v

return state_dict


def save_model_card(
repo_id: str,
images=None,
98 changes: 73 additions & 25 deletions examples/text_to_image/train_text_to_image_lora.py
@@ -40,8 +40,7 @@

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
@@ -54,6 +53,39 @@
logger = get_logger(__name__, log_level="INFO")


# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}

def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection

attn_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))

return attn_modules

for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v

for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v

for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v

for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v

return state_dict


def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
@@ -458,25 +490,43 @@ def main():
# => 32 layers

# Set correct lora layers
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=args.rank,
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)

# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)

attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)

unet.set_attn_processor(lora_attn_procs)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())

if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
@@ -491,8 +541,6 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

lora_layers = AttnProcsLayers(unet.attn_processors)

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
Expand All @@ -517,7 +565,7 @@ def main():
optimizer_cls = torch.optim.AdamW

optimizer = optimizer_cls(
lora_layers.parameters(),
unet_lora_parameters,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -644,8 +692,8 @@ def collate_fn(examples):
)

# Prepare everything with our `accelerator`.
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
lora_layers, optimizer, train_dataloader, lr_scheduler
unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
Expand Down Expand Up @@ -777,7 +825,7 @@ def collate_fn(examples):
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = lora_layers.parameters()
params_to_clip = unet_lora_parameters
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
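In the replacement above, the LoRAAttnProcessor/AttnProcsLayers path gives way to attaching LoRALinearLayer adapters directly to the attention projections and optimizing the collected parameter list. Condensed into a small illustrative helper (the helper itself is not part of the diff; to_q/to_k/to_v/to_out[0] are assumed to be diffusers linear modules that expose set_lora_layer):

# Illustration only -- the per-projection wiring applied in the loop above, factored out.
from diffusers.models.lora import LoRALinearLayer

def attach_lora(linear, rank):
    # Attach a low-rank adapter to one linear projection and return its trainable parameters.
    linear.set_lora_layer(
        LoRALinearLayer(in_features=linear.in_features, out_features=linear.out_features, rank=rank)
    )
    return list(linear.lora_layer.parameters())

Wiring the adapters onto the layers themselves means the optimizer, gradient clipping, and accelerator.prepare all receive a plain parameter list (unet_lora_parameters) instead of an AttnProcsLayers wrapper.
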
35 changes: 34 additions & 1 deletion examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -50,7 +50,7 @@
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
@@ -63,6 +63,39 @@
logger = get_logger(__name__)


# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}

def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection

attn_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))

return attn_modules

for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v

for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v

for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v

for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v

return state_dict


def save_model_card(
repo_id: str,
images=None,
5 changes: 2 additions & 3 deletions src/diffusers/loaders/__init__.py
@@ -8,7 +8,7 @@ def text_encoder_lora_state_dict(text_encoder):
deprecate(
"text_encoder_load_state_dict in `models`",
"0.27.0",
"`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
"`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
)
state_dict = {}

@@ -34,7 +34,7 @@ def text_encoder_attn_modules(text_encoder):
deprecate(
"text_encoder_attn_modules in `models`",
"0.27.0",
"`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
"`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
)
from transformers import CLIPTextModel, CLIPTextModelWithProjection

@@ -67,7 +67,6 @@ def text_encoder_attn_modules(text_encoder):

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from ..models.lora import text_encoder_lora_state_dict
from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
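The rewritten deprecation messages above steer users toward PEFT rather than the relocated diffusers helper. A rough sketch of that route, with the API taken from the linked PEFT quicktour rather than from this commit (rank, target modules, and checkpoint are illustrative):

# Hedged sketch of the PEFT-based alternative the deprecation message points to.
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")  # example checkpoint
lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder = get_peft_model(text_encoder, lora_config)
lora_state_dict = get_peft_model_state_dict(text_encoder)  # stands in for text_encoder_lora_state_dict
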
37 changes: 33 additions & 4 deletions src/diffusers/loaders/lora.py
@@ -47,9 +47,10 @@


if is_transformers_available():
from transformers import PreTrainedModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection

from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
# To be deprecated soon
from ..models.lora import PatchedLoraProjection

if is_accelerate_available():
from accelerate import init_empty_weights
@@ -66,6 +67,34 @@
LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."


def text_encoder_attn_modules(text_encoder):
attn_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
else:
raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")

return attn_modules


def text_encoder_mlp_modules(text_encoder):
mlp_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
mlp_mod = layer.mlp
name = f"text_model.encoder.layers.{i}.mlp"
mlp_modules.append((name, mlp_mod))
else:
raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")

return mlp_modules


class LoraLoaderMixin:
r"""
Load LoRA layers into [`UNet2DConditionModel`] and [`~transformers.CLIPTextModel`].
@@ -1415,7 +1444,7 @@ def process_weights(adapter_names, weights):
)
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)

def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): # noqa: F821
"""
Disable the text encoder's LoRA layers.
@@ -1445,7 +1474,7 @@ def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"
raise ValueError("Text Encoder not found.")
set_adapter_layers(text_encoder, enabled=False)

def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): # noqa: F821
"""
Enables the text encoder's LoRA layers.
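The helpers added above replace the text_encoder_attn_modules / text_encoder_mlp_modules imports that lora.py previously pulled from ..models.lora. A short usage sketch (example checkpoint only, as in the earlier PEFT sketch; not from the commit):

# Hedged usage sketch for text_encoder_attn_modules / text_encoder_mlp_modules.
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")  # example checkpoint
for name, module in text_encoder_attn_modules(text_encoder):
    print(name)  # text_model.encoder.layers.<i>.self_attn
for name, module in text_encoder_mlp_modules(text_encoder):
    print(name)  # text_model.encoder.layers.<i>.mlp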