diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 9250865a3ad1..b82dfa38c172 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -57,7 +57,7 @@
     AttnAddedKVProcessor2_0,
     SlicedAttnAddedKVProcessor,
 )
-from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
@@ -70,6 +70,39 @@
 logger = get_logger(__name__)


+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    def text_encoder_attn_modules(text_encoder):
+        from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+        attn_modules = []
+
+        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+                name = f"text_model.encoder.layers.{i}.self_attn"
+                mod = layer.self_attn
+                attn_modules.append((name, mod))
+
+        return attn_modules
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
 def save_model_card(
     repo_id: str,
     images=None,
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 97b60c8f527d..dd7b29ca8842 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -50,7 +50,7 @@
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr, unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
@@ -63,6 +63,39 @@
 logger = get_logger(__name__)


+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    def text_encoder_attn_modules(text_encoder):
+        from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+        attn_modules = []
+
+        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+                name = f"text_model.encoder.layers.{i}.self_attn"
+                mod = layer.self_attn
+                attn_modules.append((name, mod))
+
+        return attn_modules
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
 def save_model_card(
     repo_id: str,
     images=None,
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 78b443d149e8..b7309196dec8 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -40,8 +40,7 @@

 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
@@ -54,6 +53,39 @@
 logger = get_logger(__name__, log_level="INFO")


+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    def text_encoder_attn_modules(text_encoder):
+        from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+        attn_modules = []
+
+        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+                name = f"text_model.encoder.layers.{i}.self_attn"
+                mod = layer.self_attn
+                attn_modules.append((name, mod))
+
+        return attn_modules
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
 def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
     img_str = ""
     for i, image in enumerate(images):
@@ -458,25 +490,43 @@ def main():
     # => 32 layers

     # Set correct lora layers
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-
-        lora_attn_procs[name] = LoRAAttnProcessor(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.rank,
+    unet_lora_parameters = []
+    for attn_processor_name, attn_processor in unet.attn_processors.items():
+        # Parse the attention module.
+        attn_module = unet
+        for n in attn_processor_name.split(".")[:-1]:
+            attn_module = getattr(attn_module, n)
+
+        # Set the `lora_layer` attribute of the attention-related matrices.
+        attn_module.to_q.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
+            )
+        )
+        attn_module.to_k.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
+            )
+        )
+
+        attn_module.to_v.set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
+            )
+        )
+        attn_module.to_out[0].set_lora_layer(
+            LoRALinearLayer(
+                in_features=attn_module.to_out[0].in_features,
+                out_features=attn_module.to_out[0].out_features,
+                rank=args.rank,
+            )
         )

-    unet.set_attn_processor(lora_attn_procs)
+        # Accumulate the LoRA params to optimize.
+        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
+        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())

     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
@@ -491,8 +541,6 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

-    lora_layers = AttnProcsLayers(unet.attn_processors)
-
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
@@ -517,7 +565,7 @@ def main():
         optimizer_cls = torch.optim.AdamW

     optimizer = optimizer_cls(
-        lora_layers.parameters(),
+        unet_lora_parameters,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -644,8 +692,8 @@ def collate_fn(examples):
     )

     # Prepare everything with our `accelerator`.
-    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        lora_layers, optimizer, train_dataloader, lr_scheduler
+    unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
     )

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -777,7 +825,7 @@ def collate_fn(examples):
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = lora_layers.parameters()
+                    params_to_clip = unet_lora_parameters
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index bff928541f57..96bfe9e16783 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -50,7 +50,7 @@
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
@@ -63,6 +63,39 @@
 logger = get_logger(__name__)


+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    def text_encoder_attn_modules(text_encoder):
+        from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+        attn_modules = []
+
+        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+                name = f"text_model.encoder.layers.{i}.self_attn"
+                mod = layer.self_attn
+                attn_modules.append((name, mod))
+
+        return attn_modules
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
 def save_model_card(
     repo_id: str,
     images=None,
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index 684736856029..45c8c97c76eb 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -8,7 +8,7 @@ def text_encoder_lora_state_dict(text_encoder):
     deprecate(
         "text_encoder_load_state_dict in `models`",
         "0.27.0",
-        "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
+        "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
     )

     state_dict = {}
@@ -34,7 +34,7 @@ def text_encoder_attn_modules(text_encoder):
     deprecate(
         "text_encoder_attn_modules in `models`",
         "0.27.0",
-        "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
+        "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
     )

     from transformers import CLIPTextModel, CLIPTextModelWithProjection
@@ -67,7 +67,6 @@ def text_encoder_attn_modules(text_encoder):

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
-        from ..models.lora import text_encoder_lora_state_dict
         from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
         from .unet import UNet2DConditionLoadersMixin
         from .utils import AttnProcsLayers
diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index ab5d0ffd0157..06eb3af05ee2 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -47,9 +47,10 @@


 if is_transformers_available():
-    from transformers import PreTrainedModel
+    from transformers import CLIPTextModel, CLIPTextModelWithProjection

-    from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
+    # To be deprecated soon
+    from ..models.lora import PatchedLoraProjection

 if is_accelerate_available():
     from accelerate import init_empty_weights
@@ -66,6 +67,34 @@
 LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."


+def text_encoder_attn_modules(text_encoder):
+    attn_modules = []
+
+    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+            name = f"text_model.encoder.layers.{i}.self_attn"
+            mod = layer.self_attn
+            attn_modules.append((name, mod))
+    else:
+        raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
+
+    return attn_modules
+
+
+def text_encoder_mlp_modules(text_encoder):
+    mlp_modules = []
+
+    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+            mlp_mod = layer.mlp
+            name = f"text_model.encoder.layers.{i}.mlp"
+            mlp_modules.append((name, mlp_mod))
+    else:
+        raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
+
+    return mlp_modules
+
+
 class LoraLoaderMixin:
     r"""
     Load LoRA layers into [`UNet2DConditionModel`] and [`~transformers.CLIPTextModel`].
@@ -1415,7 +1444,7 @@ def process_weights(adapter_names, weights):
         )
         set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)

-    def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
+    def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):  # noqa: F821
         """
         Disable the text encoder's LoRA layers.

@@ -1445,7 +1474,7 @@ def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"
             raise ValueError("Text Encoder not found.")
         set_adapter_layers(text_encoder, enabled=False)

-    def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
+    def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):  # noqa: F821
         """
         Enables the text encoder's LoRA layers.

diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py
index 9edec19a3a34..daac8f902cd6 100644
--- a/src/diffusers/models/lora.py
+++ b/src/diffusers/models/lora.py
@@ -12,6 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+
+# IMPORTANT:                                                      #
+###################################################################
+# ----------------------------------------------------------------#
+# This file is deprecated and will be removed soon                #
+# (as soon as PEFT will become a required dependency for LoRA)    #
+# ----------------------------------------------------------------#
+###################################################################
+
 from typing import Optional, Tuple, Union

 import torch
@@ -57,25 +66,6 @@ def text_encoder_mlp_modules(text_encoder):
     return mlp_modules


-def text_encoder_lora_state_dict(text_encoder):
-    state_dict = {}
-
-    for name, module in text_encoder_attn_modules(text_encoder):
-        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
-
-    return state_dict
-
-
 def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
     for _, attn_module in text_encoder_attn_modules(text_encoder):
         if isinstance(attn_module.q_proj, PatchedLoraProjection):
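Note on the `train_text_to_image_lora.py` hunks: the old `LoRAAttnProcessor` / `AttnProcsLayers` setup is replaced by `LoRALinearLayer` adapters attached directly to the attention projection matrices. The sketch below condenses that pattern outside the training script; the checkpoint id, rank, and learning rate are illustrative placeholders, not values from the diff.

```python
# Minimal sketch of the new UNet LoRA wiring used in the updated script.
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # placeholder checkpoint
)
rank = 4  # placeholder LoRA rank (args.rank in the script)
unet_lora_parameters = []

for attn_processor_name in unet.attn_processors:
    # Resolve the attention module that owns this processor, e.g.
    # "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor".
    attn_module = unet
    for n in attn_processor_name.split(".")[:-1]:
        attn_module = getattr(attn_module, n)

    # Attach a low-rank adapter to each projection and collect its parameters.
    for proj in (attn_module.to_q, attn_module.to_k, attn_module.to_v, attn_module.to_out[0]):
        proj.set_lora_layer(
            LoRALinearLayer(in_features=proj.in_features, out_features=proj.out_features, rank=rank)
        )
        unet_lora_parameters.extend(proj.lora_layer.parameters())

optimizer = torch.optim.AdamW(unet_lora_parameters, lr=1e-4)  # placeholder hyperparameters
```

Because the adapters are plain module parameters, the list can be handed straight to the optimizer, `accelerator.prepare`, and `accelerator.clip_grad_norm_`, which is why the diff swaps `lora_layers.parameters()` for `unet_lora_parameters` throughout.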
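Note on the deprecation messages in `src/diffusers/loaders/__init__.py`: they point users to PEFT's `get_peft_model` instead of the removed `text_encoder_lora_state_dict` helper. A hedged sketch of what that replacement could look like follows; the rank, alpha, and checkpoint id are assumptions for illustration, and this diff does not itself specify the final PEFT-based flow.

```python
# Hypothetical illustration: collect text encoder LoRA weights via PEFT
# rather than the hand-rolled text_encoder_lora_state_dict helper.
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"  # placeholder checkpoint
)

lora_config = LoraConfig(
    r=4,  # assumed rank
    lora_alpha=4,  # assumed scaling
    # Same projections the removed helper walked (q/k/v/out of each self-attention block).
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder = get_peft_model(text_encoder, lora_config)

# Only the LoRA tensors, analogous to what text_encoder_lora_state_dict returned.
lora_state_dict = get_peft_model_state_dict(text_encoder)
print(f"{len(lora_state_dict)} LoRA tensors collected")
```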