From 9a92b8177cb3f8bf4b095fff55da3b45a3607960 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 30 Oct 2024 18:04:15 +0530 Subject: [PATCH 01/20] Allegro VAE fix (#9811) fix --- .../models/autoencoders/autoencoder_kl_allegro.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index 4836de7e16ab..922fd15c08fb 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -1091,8 +1091,6 @@ def forward( sample_posterior: bool = False, return_dict: bool = True, generator: Optional[torch.Generator] = None, - encoder_local_batch_size: int = 2, - decoder_local_batch_size: int = 2, ) -> Union[DecoderOutput, torch.Tensor]: r""" Args: @@ -1103,18 +1101,14 @@ def forward( Whether or not to return a [`DecoderOutput`] instead of a plain tuple. generator (`torch.Generator`, *optional*): PyTorch random number generator. - encoder_local_batch_size (`int`, *optional*, defaults to 2): - Local batch size for the encoder's batch inference. - decoder_local_batch_size (`int`, *optional*, defaults to 2): - Local batch size for the decoder's batch inference. """ x = sample - posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist + posterior = self.encode(x).latent_dist if sample_posterior: z = posterior.sample(generator=generator) else: z = posterior.mode() - dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample + dec = self.decode(z).sample if not return_dict: return (dec,) From c1d4a0dded4d5b5f434051435c3cb091ffb9cabd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Oct 2024 14:58:05 +0530 Subject: [PATCH 02/20] [CI] add new runner for testing (#9699) new runner. --- .github/workflows/ssh-runner.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ssh-runner.yml b/.github/workflows/ssh-runner.yml index 0d4fe1578ba6..fd65598a53a7 100644 --- a/.github/workflows/ssh-runner.yml +++ b/.github/workflows/ssh-runner.yml @@ -4,12 +4,13 @@ on: workflow_dispatch: inputs: runner_type: - description: 'Type of runner to test (aws-g6-4xlarge-plus: a10 or aws-g4dn-2xlarge: t4)' + description: 'Type of runner to test (aws-g6-4xlarge-plus: a10, aws-g4dn-2xlarge: t4, aws-g6e-xlarge-plus: L40)' type: choice required: true options: - aws-g6-4xlarge-plus - aws-g4dn-2xlarge + - aws-g6e-xlarge-plus docker_image: description: 'Name of the Docker image' required: true From 09b8aebd67018d4fb8a559fc8a5ad4e74e956d9d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Oct 2024 15:46:00 +0530 Subject: [PATCH 03/20] [training] fixes to the quantization training script and add AdEMAMix optimizer as an option (#9806) * fixes * more fixes. --- .../train_dreambooth_lora_flux_miniature.py | 45 +++++++++++++------ 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py index fd2b5568d6d8..f3b4602c7fcf 100644 --- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py +++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py @@ -349,7 +349,7 @@ def parse_args(input_args=None): "--optimizer", type=str, default="AdamW", - help=('The optimizer type to use. 
Choose between ["AdamW", "prodigy"]'), + choices=["AdamW", "Prodigy", "AdEMAMix"], ) parser.add_argument( @@ -357,6 +357,11 @@ def parse_args(input_args=None): action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", ) + parser.add_argument( + "--use_8bit_ademamix", + action="store_true", + help="Whether or not to use 8-bit AdEMAMix from bitsandbytes.", + ) parser.add_argument( "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." @@ -820,16 +825,15 @@ def load_model_hook(models, input_dir): params_to_optimize = [transformer_parameters_with_lr] # Optimizer creation - if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": logger.warning( - f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." - "Defaulting to adamW" + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" ) - args.optimizer = "adamw" - if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix": logger.warning( - f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was " f"set to {args.optimizer.lower()}" ) @@ -853,6 +857,20 @@ def load_model_hook(models, input_dir): eps=args.adam_epsilon, ) + elif args.optimizer.lower() == "ademamix": + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`." 
+ ) + if args.use_8bit_ademamix: + optimizer_class = bnb.optim.AdEMAMix8bit + else: + optimizer_class = bnb.optim.AdEMAMix + + optimizer = optimizer_class(params_to_optimize) + if args.optimizer.lower() == "prodigy": try: import prodigyopt @@ -868,7 +886,6 @@ def load_model_hook(models, input_dir): optimizer = optimizer_class( params_to_optimize, - lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, @@ -1020,12 +1037,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) - vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) + vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1) latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], - model_input.shape[2], - model_input.shape[3], + model_input.shape[2] // 2, + model_input.shape[3] // 2, accelerator.device, weight_dtype, ) @@ -1059,7 +1076,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) # handle guidance - if transformer.config.guidance_embeds: + if unwrap_model(transformer).config.guidance_embeds: guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: @@ -1082,8 +1099,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): )[0] model_pred = FluxPipeline._unpack_latents( model_pred, - height=int(model_input.shape[2] * vae_scale_factor / 2), - width=int(model_input.shape[3] * vae_scale_factor / 2), + height=model_input.shape[2] * vae_scale_factor, + width=model_input.shape[3] * vae_scale_factor, vae_scale_factor=vae_scale_factor, ) From 8ce37ab055372dedf4e9621ed63374a019d93f5d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Oct 2024 15:51:42 +0530 Subject: [PATCH 04/20] [training] use the lr when using 8bit adam. (#9796) * use the lr when using 8bit adam. * remove lr as we pack it in params_to_optimize. 
--------- Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> --- .../train_dreambooth_lora_flux_advanced.py | 14 +++----------- .../train_dreambooth_lora_sd15_advanced.py | 6 +----- .../train_dreambooth_lora_sdxl_advanced.py | 1 - .../train_cogvideox_image_to_video_lora.py | 1 - examples/cogvideo/train_cogvideox_lora.py | 1 - examples/dreambooth/train_dreambooth_flux.py | 6 +----- examples/dreambooth/train_dreambooth_lora_flux.py | 6 +----- examples/dreambooth/train_dreambooth_lora_sd3.py | 1 - examples/dreambooth/train_dreambooth_lora_sdxl.py | 1 - examples/dreambooth/train_dreambooth_sd3.py | 1 - .../dreambooth/train_dreambooth_lora_sdxl.py | 1 - 11 files changed, 6 insertions(+), 33 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 92d296c0f1e8..bf726e65c94b 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -1778,15 +1778,10 @@ def load_model_hook(models, input_dir): if not args.enable_t5_ti: # pure textual inversion - only clip if pure_textual_inversion: - params_to_optimize = [ - text_parameters_one_with_lr, - ] + params_to_optimize = [text_parameters_one_with_lr] te_idx = 0 else: # regular te training or regular pivotal for clip - params_to_optimize = [ - transformer_parameters_with_lr, - text_parameters_one_with_lr, - ] + params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr] te_idx = 1 elif args.enable_t5_ti: # pivotal tuning of clip & t5 @@ -1809,9 +1804,7 @@ def load_model_hook(models, input_dir): ] te_idx = 1 else: - params_to_optimize = [ - transformer_parameters_with_lr, - ] + params_to_optimize = [transformer_parameters_with_lr] # Optimizer creation if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): @@ -1871,7 +1864,6 @@ def load_model_hook(models, input_dir): params_to_optimize[-1]["lr"] = args.learning_rate optimizer = optimizer_class( params_to_optimize, - lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 024722536d88..7fdea56dc5cb 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -1358,10 +1358,7 @@ def load_model_hook(models, input_dir): else args.adam_weight_decay, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } - params_to_optimize = [ - unet_lora_parameters_with_lr, - text_lora_parameters_one_with_lr, - ] + params_to_optimize = [unet_lora_parameters_with_lr, text_lora_parameters_one_with_lr] else: params_to_optimize = [unet_lora_parameters_with_lr] @@ -1423,7 +1420,6 @@ def load_model_hook(models, input_dir): optimizer = optimizer_class( params_to_optimize, - lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index bc06cc9213dc..74d52186dd81 100644 --- 
a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1794,7 +1794,6 @@ def load_model_hook(models, input_dir): optimizer = optimizer_class( params_to_optimize, - lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py index 4ef392baa2b5..1f055bcecbed 100644 --- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -947,7 +947,6 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False): optimizer = optimizer_class( params_to_optimize, - lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index 011466bc7d58..e591e0ee5900 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -969,7 +969,6 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False): optimizer = optimizer_class( params_to_optimize, - lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index f720afef6542..d23d05f7e38b 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1226,10 +1226,7 @@ def load_model_hook(models, input_dir): "weight_decay": args.adam_weight_decay_text_encoder, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } - params_to_optimize = [ - transformer_parameters_with_lr, - text_parameters_one_with_lr, - ] + params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr] else: params_to_optimize = [transformer_parameters_with_lr] @@ -1291,7 +1288,6 @@ def load_model_hook(models, input_dir): optimizer = optimizer_class( params_to_optimize, - lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index b6e657234850..a0a197b1b2ee 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1335,10 +1335,7 @@ def load_model_hook(models, input_dir): "weight_decay": args.adam_weight_decay_text_encoder, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } - params_to_optimize = [ - transformer_parameters_with_lr, - text_parameters_one_with_lr, - ] + params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr] else: params_to_optimize = [transformer_parameters_with_lr] @@ -1400,7 +1397,6 @@ def load_model_hook(models, input_dir): optimizer = optimizer_class( params_to_optimize, - lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index fc3c69b8901f..dcf093a94c5a 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ 
b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1468,7 +1468,6 @@ def load_model_hook(models, input_dir): optimizer = optimizer_class( params_to_optimize, - lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index bf8c8f7d0578..6e621b3caee3 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1402,7 +1402,6 @@ def load_model_hook(models, input_dir): optimizer = optimizer_class( params_to_optimize, - lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 5d10345304ab..525a4cc906e9 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1328,7 +1328,6 @@ def load_model_hook(models, input_dir): optimizer = optimizer_class( params_to_optimize, - lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py index d16780131139..2a9801038999 100644 --- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py @@ -1475,7 +1475,6 @@ def load_model_hook(models, input_dir): optimizer = optimizer_class( params_to_optimize, - lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, From 4adf6affbb5800ba7ff3c9d87ccc427300dd1ba1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Oct 2024 18:24:19 +0530 Subject: [PATCH 05/20] [Tests] clean up and refactor gradient checkpointing tests (#9494) * check. 
* fixes * fixes * updates * fixes * fixes --- tests/models/autoencoders/test_models_vae.py | 109 ++++-------------- tests/models/test_modeling_common.py | 97 ++++++++++++++++ .../test_models_dit_transformer2d.py | 7 ++ .../test_models_pixart_transformer2d.py | 4 + .../test_models_transformer_allegro.py | 4 + .../test_models_transformer_aura_flow.py | 4 + .../test_models_transformer_cogvideox.py | 4 + .../test_models_transformer_cogview3plus.py | 4 + .../test_models_transformer_flux.py | 4 + .../test_models_transformer_latte.py | 4 + .../test_models_transformer_sd3.py | 8 ++ .../unets/test_models_unet_2d_condition.py | 76 +----------- .../unets/test_models_unet_controlnetxs.py | 28 +---- tests/models/unets/test_models_unet_motion.py | 26 +---- .../unets/test_models_unet_spatiotemporal.py | 74 +----------- 15 files changed, 180 insertions(+), 273 deletions(-) diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py index 0188f9121ae0..d29defbf6085 100644 --- a/tests/models/autoencoders/test_models_vae.py +++ b/tests/models/autoencoders/test_models_vae.py @@ -39,7 +39,6 @@ load_hf_numpy, require_torch_accelerator, require_torch_accelerator_with_fp16, - require_torch_accelerator_with_training, require_torch_gpu, skip_mps, slow, @@ -170,52 +169,17 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict + @unittest.skip("Not tested.") def test_forward_signature(self): pass + @unittest.skip("Not tested.") def test_training(self): pass - @require_torch_accelerator_with_training - def test_gradient_checkpointing(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - - assert not model.is_gradient_checkpointing and model.training - - out = model(**inputs_dict).sample - # run the backwards pass on the model. For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model.zero_grad() - - labels = torch.randn_like(out) - loss = (out - labels).mean() - loss.backward() - - # re-instantiate the model now enabling gradient checkpointing - model_2 = self.model_class(**init_dict) - # clone model - model_2.load_state_dict(model.state_dict()) - model_2.to(torch_device) - model_2.enable_gradient_checkpointing() - - assert model_2.is_gradient_checkpointing and model_2.training - - out_2 = model_2(**inputs_dict).sample - # run the backwards pass on the model. 
For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model_2.zero_grad() - loss_2 = (out_2 - labels).mean() - loss_2.backward() - - # compare the output and parameters gradients - self.assertTrue((loss - loss_2).abs() < 1e-5) - named_params = dict(model.named_parameters()) - named_params_2 = dict(model_2.named_parameters()) - for name, param in named_params.items(): - self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) + def test_gradient_checkpointing_is_applied(self): + expected_set = {"Decoder", "Encoder"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) def test_from_pretrained_hub(self): model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True) @@ -329,9 +293,11 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict + @unittest.skip("Not tested.") def test_forward_signature(self): pass + @unittest.skip("Not tested.") def test_forward_with_norm_groups(self): pass @@ -364,9 +330,20 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict + @unittest.skip("Not tested.") def test_outputs_equivalence(self): pass + def test_gradient_checkpointing_is_applied(self): + expected_set = {"DecoderTiny", "EncoderTiny"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip( + "Gradient checkpointing is supported but this test doesn't apply to this class because it's forward is a bit different from the rest." + ) + def test_effective_gradient_checkpointing(self): + pass + class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase): model_class = ConsistencyDecoderVAE @@ -443,55 +420,17 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict + @unittest.skip("Not tested.") def test_forward_signature(self): pass + @unittest.skip("Not tested.") def test_training(self): pass - @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") - def test_gradient_checkpointing(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - - assert not model.is_gradient_checkpointing and model.training - - out = model(**inputs_dict).sample - # run the backwards pass on the model. For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model.zero_grad() - - labels = torch.randn_like(out) - loss = (out - labels).mean() - loss.backward() - - # re-instantiate the model now enabling gradient checkpointing - model_2 = self.model_class(**init_dict) - # clone model - model_2.load_state_dict(model.state_dict()) - model_2.to(torch_device) - model_2.enable_gradient_checkpointing() - - assert model_2.is_gradient_checkpointing and model_2.training - - out_2 = model_2(**inputs_dict).sample - # run the backwards pass on the model. 
For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model_2.zero_grad() - loss_2 = (out_2 - labels).mean() - loss_2.backward() - - # compare the output and parameters gradients - self.assertTrue((loss - loss_2).abs() < 1e-5) - named_params = dict(model.named_parameters()) - named_params_2 = dict(model_2.named_parameters()) - for name, param in named_params.items(): - if "post_quant_conv" in name: - continue - - self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) + def test_gradient_checkpointing_is_applied(self): + expected_set = {"Encoder", "TemporalDecoder"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): @@ -522,9 +461,11 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict + @unittest.skip("Not tested.") def test_forward_signature(self): pass + @unittest.skip("Not tested.") def test_forward_with_norm_groups(self): pass diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 5548fdd0723d..7f8dc63e00ac 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect import json import os @@ -57,6 +58,7 @@ require_torch_gpu, require_torch_multi_gpu, run_test_in_subprocess, + torch_all_close, torch_device, ) @@ -785,6 +787,101 @@ def test_enable_disable_gradient_checkpointing(self): model.disable_gradient_checkpointing() self.assertFalse(model.is_gradient_checkpointing) + @require_torch_accelerator_with_training + def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5): + if not self.model_class._supports_gradient_checkpointing: + return # Skip test if model does not support gradient checkpointing + + # enable deterministic behavior for gradient checkpointing + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + inputs_dict_copy = copy.deepcopy(inputs_dict) + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + assert not model.is_gradient_checkpointing and model.training + + out = model(**inputs_dict).sample + # run the backwards pass on the model. For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model.zero_grad() + + labels = torch.randn_like(out) + loss = (out - labels).mean() + loss.backward() + + # re-instantiate the model now enabling gradient checkpointing + torch.manual_seed(0) + model_2 = self.model_class(**init_dict) + # clone model + model_2.load_state_dict(model.state_dict()) + model_2.to(torch_device) + model_2.enable_gradient_checkpointing() + + assert model_2.is_gradient_checkpointing and model_2.training + + out_2 = model_2(**inputs_dict_copy).sample + # run the backwards pass on the model. 
For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model_2.zero_grad() + loss_2 = (out_2 - labels).mean() + loss_2.backward() + + # compare the output and parameters gradients + self.assertTrue((loss - loss_2).abs() < loss_tolerance) + named_params = dict(model.named_parameters()) + named_params_2 = dict(model_2.named_parameters()) + + for name, param in named_params.items(): + if "post_quant_conv" in name: + continue + self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol)) + + @unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.") + def test_gradient_checkpointing_is_applied( + self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None + ): + if not self.model_class._supports_gradient_checkpointing: + return # Skip test if model does not support gradient checkpointing + if self.model_class.__name__ in [ + "UNetSpatioTemporalConditionModel", + "AutoencoderKLTemporalDecoder", + ]: + return + + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + if attention_head_dim is not None: + init_dict["attention_head_dim"] = attention_head_dim + if num_attention_heads is not None: + init_dict["num_attention_heads"] = num_attention_heads + if block_out_channels is not None: + init_dict["block_out_channels"] = block_out_channels + + model_class_copy = copy.copy(self.model_class) + + modules_with_gc_enabled = {} + + # now monkey patch the following function: + # def _set_gradient_checkpointing(self, module, value=False): + # if hasattr(module, "gradient_checkpointing"): + # module.gradient_checkpointing = value + + def _set_gradient_checkpointing_new(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + modules_with_gc_enabled[module.__class__.__name__] = True + + model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new + + model = model_class_copy(**init_dict) + model.enable_gradient_checkpointing() + + print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}") + + assert set(modules_with_gc_enabled.keys()) == expected_set + assert all(modules_with_gc_enabled.values()), "All modules should be enabled" + def test_deprecated_kwargs(self): has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 diff --git a/tests/models/transformers/test_models_dit_transformer2d.py b/tests/models/transformers/test_models_dit_transformer2d.py index b12cae1a8879..5f4a2f587e92 100644 --- a/tests/models/transformers/test_models_dit_transformer2d.py +++ b/tests/models/transformers/test_models_dit_transformer2d.py @@ -84,6 +84,13 @@ def test_correct_class_remapping_from_dict_config(self): model = Transformer2DModel.from_config(init_dict) assert isinstance(model, DiTTransformer2DModel) + def test_gradient_checkpointing_is_applied(self): + expected_set = {"DiTTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + def test_effective_gradient_checkpointing(self): + super().test_effective_gradient_checkpointing(loss_tolerance=1e-4) + def test_correct_class_remapping_from_pretrained_config(self): config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer") model = Transformer2DModel.from_config(config) diff --git 
a/tests/models/transformers/test_models_pixart_transformer2d.py b/tests/models/transformers/test_models_pixart_transformer2d.py index 30293f5d35cb..a544a3fc4607 100644 --- a/tests/models/transformers/test_models_pixart_transformer2d.py +++ b/tests/models/transformers/test_models_pixart_transformer2d.py @@ -92,6 +92,10 @@ def test_output(self): expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape ) + def test_gradient_checkpointing_is_applied(self): + expected_set = {"PixArtTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + def test_correct_class_remapping_from_dict_config(self): init_dict, _ = self.prepare_init_args_and_inputs_for_common() model = Transformer2DModel.from_config(init_dict) diff --git a/tests/models/transformers/test_models_transformer_allegro.py b/tests/models/transformers/test_models_transformer_allegro.py index ad8b7a3824ba..3479803da61d 100644 --- a/tests/models/transformers/test_models_transformer_allegro.py +++ b/tests/models/transformers/test_models_transformer_allegro.py @@ -77,3 +77,7 @@ def prepare_init_args_and_inputs_for_common(self): } inputs_dict = self.dummy_input return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"AllegroTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/transformers/test_models_transformer_aura_flow.py b/tests/models/transformers/test_models_transformer_aura_flow.py index 376d8b57da4d..d1ff7d2c96d3 100644 --- a/tests/models/transformers/test_models_transformer_aura_flow.py +++ b/tests/models/transformers/test_models_transformer_aura_flow.py @@ -74,6 +74,10 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict + def test_gradient_checkpointing_is_applied(self): + expected_set = {"AuraFlowTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + @unittest.skip("AuraFlowTransformer2DModel uses its own dedicated attention processor. 
This test does not apply") def test_set_attn_processor_for_determinism(self): pass diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py index 6db4113cbd1b..1342577f0114 100644 --- a/tests/models/transformers/test_models_transformer_cogvideox.py +++ b/tests/models/transformers/test_models_transformer_cogvideox.py @@ -81,3 +81,7 @@ def prepare_init_args_and_inputs_for_common(self): } inputs_dict = self.dummy_input return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"CogVideoXTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/transformers/test_models_transformer_cogview3plus.py b/tests/models/transformers/test_models_transformer_cogview3plus.py index 46612dbd9190..eda9813808e9 100644 --- a/tests/models/transformers/test_models_transformer_cogview3plus.py +++ b/tests/models/transformers/test_models_transformer_cogview3plus.py @@ -83,3 +83,7 @@ def prepare_init_args_and_inputs_for_common(self): } inputs_dict = self.dummy_input return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"CogView3PlusTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 6cf7a4f75707..4a784eee4732 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -111,3 +111,7 @@ def test_deprecated_inputs_img_txt_ids_3d(self): torch.allclose(output_1, output_2, atol=1e-5), msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs", ) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"FluxTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/transformers/test_models_transformer_latte.py b/tests/models/transformers/test_models_transformer_latte.py index 3fe0a6098045..0cb9094f5165 100644 --- a/tests/models/transformers/test_models_transformer_latte.py +++ b/tests/models/transformers/test_models_transformer_latte.py @@ -86,3 +86,7 @@ def test_output(self): super().test_output( expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape ) + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"LatteTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py index 2be4744c5ac4..af86fa9c3bc1 100644 --- a/tests/models/transformers/test_models_transformer_sd3.py +++ b/tests/models/transformers/test_models_transformer_sd3.py @@ -84,6 +84,10 @@ def prepare_init_args_and_inputs_for_common(self): def test_set_attn_processor_for_determinism(self): pass + def test_gradient_checkpointing_is_applied(self): + expected_set = {"SD3Transformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + class SD35TransformerTests(ModelTesterMixin, unittest.TestCase): model_class = SD3Transformer2DModel @@ -139,3 +143,7 @@ def prepare_init_args_and_inputs_for_common(self): @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. 
This test doesn't apply") def test_set_attn_processor_for_determinism(self): pass + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"SD3Transformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 37d55cedeb28..fec34822904c 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -43,7 +43,6 @@ require_peft_backend, require_torch_accelerator, require_torch_accelerator_with_fp16, - require_torch_accelerator_with_training, require_torch_gpu, skip_mps, slow, @@ -406,47 +405,6 @@ def test_xformers_enable_works(self): == "XFormersAttnProcessor" ), "xformers is not enabled" - @require_torch_accelerator_with_training - def test_gradient_checkpointing(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - - assert not model.is_gradient_checkpointing and model.training - - out = model(**inputs_dict).sample - # run the backwards pass on the model. For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model.zero_grad() - - labels = torch.randn_like(out) - loss = (out - labels).mean() - loss.backward() - - # re-instantiate the model now enabling gradient checkpointing - model_2 = self.model_class(**init_dict) - # clone model - model_2.load_state_dict(model.state_dict()) - model_2.to(torch_device) - model_2.enable_gradient_checkpointing() - - assert model_2.is_gradient_checkpointing and model_2.training - - out_2 = model_2(**inputs_dict).sample - # run the backwards pass on the model. 
For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model_2.zero_grad() - loss_2 = (out_2 - labels).mean() - loss_2.backward() - - # compare the output and parameters gradients - self.assertTrue((loss - loss_2).abs() < 1e-5) - named_params = dict(model.named_parameters()) - named_params_2 = dict(model_2.named_parameters()) - for name, param in named_params.items(): - self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) - def test_model_with_attention_head_dim_tuple(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -599,31 +557,7 @@ def check_sliceable_dim_attr(module: torch.nn.Module): check_sliceable_dim_attr(module) def test_gradient_checkpointing_is_applied(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["block_out_channels"] = (16, 32) - init_dict["attention_head_dim"] = (8, 16) - - model_class_copy = copy.copy(self.model_class) - - modules_with_gc_enabled = {} - - # now monkey patch the following function: - # def _set_gradient_checkpointing(self, module, value=False): - # if hasattr(module, "gradient_checkpointing"): - # module.gradient_checkpointing = value - - def _set_gradient_checkpointing_new(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - modules_with_gc_enabled[module.__class__.__name__] = True - - model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new - - model = model_class_copy(**init_dict) - model.enable_gradient_checkpointing() - - EXPECTED_SET = { + expected_set = { "CrossAttnUpBlock2D", "CrossAttnDownBlock2D", "UNetMidBlock2DCrossAttn", @@ -631,9 +565,11 @@ def _set_gradient_checkpointing_new(self, module, value=False): "Transformer2DModel", "DownBlock2D", } - - assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET - assert all(modules_with_gc_enabled.values()), "All modules should be enabled" + attention_head_dim = (8, 16) + block_out_channels = (16, 32) + super().test_gradient_checkpointing_is_applied( + expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels + ) def test_special_attn_proc(self): class AttnEasyProc(torch.nn.Module): diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index 6f3662e01750..3025d7117f35 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy import unittest import numpy as np @@ -269,37 +268,14 @@ def assert_unfrozen(module): assert_unfrozen(u.ctrl_to_base) def test_gradient_checkpointing_is_applied(self): - model_class_copy = copy.copy(UNetControlNetXSModel) - - modules_with_gc_enabled = {} - - # now monkey patch the following function: - # def _set_gradient_checkpointing(self, module, value=False): - # if hasattr(module, "gradient_checkpointing"): - # module.gradient_checkpointing = value - - def _set_gradient_checkpointing_new(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - modules_with_gc_enabled[module.__class__.__name__] = True - - model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new - - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - model = model_class_copy(**init_dict) - - model.enable_gradient_checkpointing() - - EXPECTED_SET = { + expected_set = { "Transformer2DModel", "UNetMidBlock2DCrossAttn", "ControlNetXSCrossAttnDownBlock2D", "ControlNetXSCrossAttnMidBlock2D", "ControlNetXSCrossAttnUpBlock2D", } - - assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET - assert all(modules_with_gc_enabled.values()), "All modules should be enabled" + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) @is_flaky def test_forward_no_control(self): diff --git a/tests/models/unets/test_models_unet_motion.py b/tests/models/unets/test_models_unet_motion.py index ee05f0d93824..209806a5fe26 100644 --- a/tests/models/unets/test_models_unet_motion.py +++ b/tests/models/unets/test_models_unet_motion.py @@ -161,27 +161,7 @@ def test_xformers_enable_works(self): ), "xformers is not enabled" def test_gradient_checkpointing_is_applied(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model_class_copy = copy.copy(self.model_class) - - modules_with_gc_enabled = {} - - # now monkey patch the following function: - # def _set_gradient_checkpointing(self, module, value=False): - # if hasattr(module, "gradient_checkpointing"): - # module.gradient_checkpointing = value - - def _set_gradient_checkpointing_new(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - modules_with_gc_enabled[module.__class__.__name__] = True - - model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new - - model = model_class_copy(**init_dict) - model.enable_gradient_checkpointing() - - EXPECTED_SET = { + expected_set = { "CrossAttnUpBlockMotion", "CrossAttnDownBlockMotion", "UNetMidBlockCrossAttnMotion", @@ -189,9 +169,7 @@ def _set_gradient_checkpointing_new(self, module, value=False): "Transformer2DModel", "DownBlockMotion", } - - assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET - assert all(modules_with_gc_enabled.values()), "All modules should be enabled" + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) def test_feed_forward_chunking(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/models/unets/test_models_unet_spatiotemporal.py b/tests/models/unets/test_models_unet_spatiotemporal.py index afdd3d127702..0d7dc823b026 100644 --- a/tests/models/unets/test_models_unet_spatiotemporal.py +++ b/tests/models/unets/test_models_unet_spatiotemporal.py @@ -25,7 +25,6 @@ enable_full_determinism, floats_tensor, skip_mps, - torch_all_close, torch_device, ) @@ -160,47 +159,6 @@ def test_xformers_enable_works(self): == 
"XFormersAttnProcessor" ), "xformers is not enabled" - @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS") - def test_gradient_checkpointing(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - - assert not model.is_gradient_checkpointing and model.training - - out = model(**inputs_dict).sample - # run the backwards pass on the model. For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model.zero_grad() - - labels = torch.randn_like(out) - loss = (out - labels).mean() - loss.backward() - - # re-instantiate the model now enabling gradient checkpointing - model_2 = self.model_class(**init_dict) - # clone model - model_2.load_state_dict(model.state_dict()) - model_2.to(torch_device) - model_2.enable_gradient_checkpointing() - - assert model_2.is_gradient_checkpointing and model_2.training - - out_2 = model_2(**inputs_dict).sample - # run the backwards pass on the model. For backwards pass, for simplicity purpose, - # we won't calculate the loss and rather backprop on out.sum() - model_2.zero_grad() - loss_2 = (out_2 - labels).mean() - loss_2.backward() - - # compare the output and parameters gradients - self.assertTrue((loss - loss_2).abs() < 1e-5) - named_params = dict(model.named_parameters()) - named_params_2 = dict(model_2.named_parameters()) - for name, param in named_params.items(): - self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) - def test_model_with_num_attention_heads_tuple(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -239,30 +197,7 @@ def test_model_with_cross_attention_dim_tuple(self): self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") def test_gradient_checkpointing_is_applied(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["num_attention_heads"] = (8, 16) - - model_class_copy = copy.copy(self.model_class) - - modules_with_gc_enabled = {} - - # now monkey patch the following function: - # def _set_gradient_checkpointing(self, module, value=False): - # if hasattr(module, "gradient_checkpointing"): - # module.gradient_checkpointing = value - - def _set_gradient_checkpointing_new(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - modules_with_gc_enabled[module.__class__.__name__] = True - - model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new - - model = model_class_copy(**init_dict) - model.enable_gradient_checkpointing() - - EXPECTED_SET = { + expected_set = { "TransformerSpatioTemporalModel", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal", @@ -270,9 +205,10 @@ def _set_gradient_checkpointing_new(self, module, value=False): "CrossAttnUpBlockSpatioTemporal", "UNetMidBlockSpatioTemporal", } - - assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET - assert all(modules_with_gc_enabled.values()), "All modules should be enabled" + num_attention_heads = (8, 16) + super().test_gradient_checkpointing_is_applied( + expected_set=expected_set, num_attention_heads=num_attention_heads + ) def test_pickle(self): # enable deterministic behavior for gradient checkpointing From ff182ad6694ada3c01b3514eeae03392b2761b92 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Oct 2024 
18:44:34 +0530 Subject: [PATCH 06/20] [CI] add a big GPU marker to run memory-intensive tests separately on CI (#9691) * add a marker for big gpu tests * update * trigger on PRs temporarily. * onnx * fix * total memory * fixes * reduce memory threshold. * bigger gpu * empty * g6e * Apply suggestions from code review * address comments. * fix * fix * fix * fix * fix * okay * further reduce. * updates * remove * updates * updates * updates * updates * fixes * fixes * updates. * fix * workflow fixes. --------- Co-authored-by: Aryan --- .github/workflows/nightly_tests.yml | 56 +++++++++++++++ src/diffusers/utils/testing_utils.py | 21 ++++++ .../controlnet_flux/test_controlnet_flux.py | 38 +++++++--- .../test_controlnet_flux_img2img.py | 71 ------------------- .../controlnet_sd3/test_controlnet_sd3.py | 35 ++++----- tests/pipelines/flux/test_pipeline_flux.py | 67 ++++++++++++----- .../test_pipeline_stable_diffusion_3.py | 6 +- ...est_pipeline_stable_diffusion_3_img2img.py | 6 +- utils/print_env.py | 4 ++ 9 files changed, 181 insertions(+), 123 deletions(-) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 142dbb0f1e8f..b8e9860aec63 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -180,6 +180,62 @@ jobs: pip install slack_sdk tabulate python utils/log_reports.py >> $GITHUB_STEP_SUMMARY + run_big_gpu_torch_tests: + name: Torch tests on big GPU + strategy: + fail-fast: false + max-parallel: 2 + runs-on: + group: aws-g6e-xlarge-plus + container: + image: diffusers/diffusers-pytorch-cuda + options: --shm-size "16gb" --ipc host --gpus 0 + steps: + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 + - name: NVIDIA-SMI + run: nvidia-smi + - name: Install dependencies + run: | + python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" + python -m uv pip install -e [quality,test] + python -m uv pip install peft@git+https://github.com/huggingface/peft.git + pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git + python -m uv pip install pytest-reportlog + - name: Environment + run: | + python utils/print_env.py + - name: Selected Torch CUDA Test on big GPU + env: + HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} + # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms + CUBLAS_WORKSPACE_CONFIG: :16:8 + BIG_GPU_MEMORY: 40 + run: | + python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ + -m "big_gpu_with_torch_cuda" \ + --make-reports=tests_big_gpu_torch_cuda \ + --report-log=tests_big_gpu_torch_cuda.log \ + tests/ + - name: Failure short reports + if: ${{ failure() }} + run: | + cat reports/tests_big_gpu_torch_cuda_stats.txt + cat reports/tests_big_gpu_torch_cuda_failures_short.txt + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v4 + with: + name: torch_cuda_big_gpu_test_reports + path: reports + - name: Generate Report and Notify Channel + if: always() + run: | + pip install slack_sdk tabulate + python utils/log_reports.py >> $GITHUB_STEP_SUMMARY + run_flax_tpu_tests: name: Nightly Flax TPU Tests runs-on: docker-tpu diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 6361cca663b9..03b9c3752922 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -57,6 +57,7 @@ ) > version.parse("4.33") USE_PEFT_BACKEND = _required_peft_version and 
_required_transformers_version +BIG_GPU_MEMORY = int(os.getenv("BIG_GPU_MEMORY", 40)) if is_torch_available(): import torch @@ -310,6 +311,26 @@ def require_torch_accelerator_with_fp64(test_case): ) +def require_big_gpu_with_torch_cuda(test_case): + """ + Decorator marking a test that requires a bigger GPU (24GB) for execution. Some example pipelines: Flux, SD3, Cog, + etc. + """ + if not is_torch_available(): + return unittest.skip("test requires PyTorch")(test_case) + + import torch + + if not torch.cuda.is_available(): + return unittest.skip("test requires PyTorch CUDA")(test_case) + + device_properties = torch.cuda.get_device_properties(0) + total_memory = device_properties.total_memory / (1024**3) + return unittest.skipUnless( + total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory" + )(test_case) + + def require_torch_accelerator_with_training(test_case): """Decorator marking a test that requires an accelerator with support for training.""" return unittest.skipUnless( diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py index d2db28bdda35..89540232f9cf 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py @@ -17,7 +17,9 @@ import unittest import numpy as np +import pytest import torch +from huggingface_hub import hf_hub_download from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from diffusers import ( @@ -30,7 +32,8 @@ from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, - require_torch_gpu, + numpy_cosine_similarity_distance, + require_big_gpu_with_torch_cuda, slow, torch_device, ) @@ -180,7 +183,8 @@ def test_xformers_attention_forwardGenerator_pass(self): @slow -@require_torch_gpu +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda class FluxControlNetPipelineSlowTests(unittest.TestCase): pipeline_class = FluxControlNetPipeline @@ -199,35 +203,49 @@ def test_canny(self): "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 ) pipe = FluxControlNetPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16 + "black-forest-labs/FLUX.1-dev", + text_encoder=None, + text_encoder_2=None, + controlnet=controlnet, + torch_dtype=torch.bfloat16, ) pipe.enable_model_cpu_offload() pipe.set_progress_bar_config(disable=None) generator = torch.Generator(device="cpu").manual_seed(0) - prompt = "A girl in city, 25 years old, cool, futuristic" control_image = load_image( "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" + ).resize((512, 512)) + + prompt_embeds = torch.load( + hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt") + ) + pooled_prompt_embeds = torch.load( + hf_hub_download( + repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt" + ) ) output = pipe( - prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, control_image=control_image, controlnet_conditioning_scale=0.6, num_inference_steps=2, guidance_scale=3.5, + max_sequence_length=256, output_type="np", + height=512, + width=512, generator=generator, ) image = output.images[0] - assert image.shape == (1024, 1024, 3) + assert image.shape == (512, 512, 3) original_image = image[-3:, -3:, 
-1].flatten() - expected_image = np.array( - [0.33007812, 0.33984375, 0.33984375, 0.328125, 0.34179688, 0.33984375, 0.30859375, 0.3203125, 0.3203125] - ) + expected_image = np.array([0.2734, 0.2852, 0.2852, 0.2734, 0.2754, 0.2891, 0.2617, 0.2637, 0.2773]) - assert np.abs(original_image.flatten() - expected_image).max() < 1e-2 + assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2 diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py index 9c0e948861f7..9b33d4b46d04 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py @@ -1,4 +1,3 @@ -import gc import unittest import numpy as np @@ -13,9 +12,6 @@ FluxTransformer2DModel, ) from diffusers.utils.testing_utils import ( - numpy_cosine_similarity_distance, - require_torch_gpu, - slow, torch_device, ) @@ -222,70 +218,3 @@ def test_fused_qkv_projections(self): assert np.allclose( original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 ), "Original outputs should match when fused QKV projections are disabled." - - -@slow -@require_torch_gpu -class FluxControlNetImg2ImgPipelineSlowTests(unittest.TestCase): - pipeline_class = FluxControlNetImg2ImgPipeline - repo_id = "black-forest-labs/FLUX.1-schnell" - - def setUp(self): - super().setUp() - gc.collect() - torch.cuda.empty_cache() - - def tearDown(self): - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - def get_inputs(self, device, seed=0): - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device="cpu").manual_seed(seed) - - image = torch.randn(1, 3, 64, 64).to(device) - control_image = torch.randn(1, 3, 64, 64).to(device) - - return { - "prompt": "A photo of a cat", - "image": image, - "control_image": control_image, - "num_inference_steps": 2, - "guidance_scale": 5.0, - "controlnet_conditioning_scale": 1.0, - "strength": 0.8, - "output_type": "np", - "generator": generator, - } - - @unittest.skip("We cannot run inference on this model with the current CI hardware") - def test_flux_controlnet_img2img_inference(self): - pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16) - pipe.enable_model_cpu_offload() - - inputs = self.get_inputs(torch_device) - - image = pipe(**inputs).images[0] - image_slice = image[0, :10, :10] - expected_slice = np.array( - [ - [0.36132812, 0.30004883, 0.25830078], - [0.36669922, 0.31103516, 0.23754883], - [0.34814453, 0.29248047, 0.23583984], - [0.35791016, 0.30981445, 0.23999023], - [0.36328125, 0.31274414, 0.2607422], - [0.37304688, 0.32177734, 0.26171875], - [0.3671875, 0.31933594, 0.25756836], - [0.36035156, 0.31103516, 0.2578125], - [0.3857422, 0.33789062, 0.27563477], - [0.3701172, 0.31982422, 0.265625], - ], - dtype=np.float32, - ) - - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) - - assert max_diff < 1e-4 diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index 74cb56e0337a..aae1dc0ebcb0 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -17,6 +17,7 @@ import unittest import numpy as np +import pytest import torch from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel @@ -30,7 
+31,8 @@ from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, - require_torch_gpu, + numpy_cosine_similarity_distance, + require_big_gpu_with_torch_cuda, slow, torch_device, ) @@ -195,7 +197,8 @@ def test_xformers_attention_forwardGenerator_pass(self): @slow -@require_torch_gpu +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase): pipeline_class = StableDiffusion3ControlNetPipeline @@ -238,11 +241,9 @@ def test_canny(self): original_image = image[-3:, -3:, -1].flatten() - expected_image = np.array( - [0.20947266, 0.1574707, 0.19897461, 0.15063477, 0.1418457, 0.17285156, 0.14160156, 0.13989258, 0.30810547] - ) + expected_image = np.array([0.7314, 0.7075, 0.6611, 0.7539, 0.7563, 0.6650, 0.6123, 0.7275, 0.7222]) - assert np.abs(original_image.flatten() - expected_image).max() < 1e-2 + assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2 def test_pose(self): controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Pose", torch_dtype=torch.float16) @@ -272,15 +273,12 @@ def test_pose(self): assert image.shape == (1024, 1024, 3) original_image = image[-3:, -3:, -1].flatten() + expected_image = np.array([0.9048, 0.8740, 0.8936, 0.8516, 0.8799, 0.9360, 0.8379, 0.8408, 0.8652]) - expected_image = np.array( - [0.8671875, 0.86621094, 0.91015625, 0.8491211, 0.87890625, 0.9140625, 0.8300781, 0.8334961, 0.8623047] - ) - - assert np.abs(original_image.flatten() - expected_image).max() < 1e-2 + assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2 def test_tile(self): - controlnet = SD3ControlNetModel.from_pretrained("InstantX//SD3-Controlnet-Tile", torch_dtype=torch.float16) + controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Tile", torch_dtype=torch.float16) pipe = StableDiffusion3ControlNetPipeline.from_pretrained( "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16 ) @@ -307,12 +305,9 @@ def test_tile(self): assert image.shape == (1024, 1024, 3) original_image = image[-3:, -3:, -1].flatten() + expected_image = np.array([0.6699, 0.6836, 0.6226, 0.6572, 0.7310, 0.6646, 0.6650, 0.6694, 0.6011]) - expected_image = np.array( - [0.6982422, 0.7011719, 0.65771484, 0.6904297, 0.7416992, 0.6904297, 0.6977539, 0.7080078, 0.6386719] - ) - - assert np.abs(original_image.flatten() - expected_image).max() < 1e-2 + assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2 def test_multi_controlnet(self): controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16) @@ -344,8 +339,6 @@ def test_multi_controlnet(self): assert image.shape == (1024, 1024, 3) original_image = image[-3:, -3:, -1].flatten() - expected_image = np.array( - [0.7451172, 0.7416992, 0.7158203, 0.7792969, 0.7607422, 0.7089844, 0.6855469, 0.71777344, 0.7314453] - ) + expected_image = np.array([0.7207, 0.7041, 0.6543, 0.7500, 0.7490, 0.6592, 0.6001, 0.7168, 0.7231]) - assert np.abs(original_image.flatten() - expected_image).max() < 1e-2 + assert numpy_cosine_similarity_distance(original_image.flatten(), expected_image) < 1e-2 diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 4caff4030261..3ccf3f80ba3c 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -2,13 +2,15 @@ 
import unittest import numpy as np +import pytest import torch +from huggingface_hub import hf_hub_download from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel from diffusers.utils.testing_utils import ( numpy_cosine_similarity_distance, - require_torch_gpu, + require_big_gpu_with_torch_cuda, slow, torch_device, ) @@ -191,7 +193,8 @@ def test_fused_qkv_projections(self): @slow -@require_torch_gpu +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda class FluxPipelineSlowTests(unittest.TestCase): pipeline_class = FluxPipeline repo_id = "black-forest-labs/FLUX.1-schnell" @@ -212,18 +215,28 @@ def get_inputs(self, device, seed=0): else: generator = torch.Generator(device="cpu").manual_seed(seed) + prompt_embeds = torch.load( + hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt") + ) + pooled_prompt_embeds = torch.load( + hf_hub_download( + repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt" + ) + ) return { - "prompt": "A photo of a cat", + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, "num_inference_steps": 2, - "guidance_scale": 5.0, + "guidance_scale": 0.0, + "max_sequence_length": 256, "output_type": "np", "generator": generator, } - # TODO: Dhruv. Move large model tests to a dedicated runner) - @unittest.skip("We cannot run inference on this model with the current CI hardware") def test_flux_inference(self): - pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16) + pipe = self.pipeline_class.from_pretrained( + self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None + ) pipe.enable_model_cpu_offload() inputs = self.get_inputs(torch_device) @@ -232,16 +245,36 @@ def test_flux_inference(self): image_slice = image[0, :10, :10] expected_slice = np.array( [ - [0.36132812, 0.30004883, 0.25830078], - [0.36669922, 0.31103516, 0.23754883], - [0.34814453, 0.29248047, 0.23583984], - [0.35791016, 0.30981445, 0.23999023], - [0.36328125, 0.31274414, 0.2607422], - [0.37304688, 0.32177734, 0.26171875], - [0.3671875, 0.31933594, 0.25756836], - [0.36035156, 0.31103516, 0.2578125], - [0.3857422, 0.33789062, 0.27563477], - [0.3701172, 0.31982422, 0.265625], + 0.3242, + 0.3203, + 0.3164, + 0.3164, + 0.3125, + 0.3125, + 0.3281, + 0.3242, + 0.3203, + 0.3301, + 0.3262, + 0.3242, + 0.3281, + 0.3242, + 0.3203, + 0.3262, + 0.3262, + 0.3164, + 0.3262, + 0.3281, + 0.3184, + 0.3281, + 0.3281, + 0.3203, + 0.3281, + 0.3281, + 0.3164, + 0.3320, + 0.3320, + 0.3203, ], dtype=np.float32, ) diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index 94a85a56f510..7767c94c4879 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -2,13 +2,14 @@ import unittest import numpy as np +import pytest import torch from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline from diffusers.utils.testing_utils import ( numpy_cosine_similarity_distance, - require_torch_gpu, + require_big_gpu_with_torch_cuda, 
slow, torch_device, ) @@ -226,7 +227,8 @@ def test_fused_qkv_projections(self): @slow -@require_torch_gpu +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda class StableDiffusion3PipelineSlowTests(unittest.TestCase): pipeline_class = StableDiffusion3Pipeline repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py index 9d131b28c308..695954163c8f 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py @@ -3,6 +3,7 @@ import unittest import numpy as np +import pytest import torch from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel @@ -16,7 +17,7 @@ from diffusers.utils.testing_utils import ( floats_tensor, numpy_cosine_similarity_distance, - require_torch_gpu, + require_big_gpu_with_torch_cuda, slow, torch_device, ) @@ -194,7 +195,8 @@ def test_multi_vae(self): @slow -@require_torch_gpu +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase): pipeline_class = StableDiffusion3Img2ImgPipeline repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" diff --git a/utils/print_env.py b/utils/print_env.py index 3e4495c98094..9f88d940fe7d 100644 --- a/utils/print_env.py +++ b/utils/print_env.py @@ -37,6 +37,10 @@ print("Cuda version:", torch.version.cuda) print("CuDNN version:", torch.backends.cudnn.version()) print("Number of GPUs available:", torch.cuda.device_count()) + if torch.cuda.is_available(): + device_properties = torch.cuda.get_device_properties(0) + total_memory = device_properties.total_memory / (1024**3) + print(f"CUDA memory: {total_memory} GB") except ImportError: print("Torch version:", None) From 41e4779d988ead99e7acd78dc8e752de88777d0f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Oct 2024 21:17:41 +0530 Subject: [PATCH 07/20] [LoRA] fix: lora loading when using with a device_mapped model. (#9449) * fix: lora loading when using with a device_mapped model. * better attibutung * empty Co-authored-by: Benjamin Bossan * Apply suggestions from code review Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * minors * better error messages. * fix-copies * add: tests, docs. * add hardware note. * quality * Update docs/source/en/training/distributed_inference.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * fixes * skip properly. 
* fixes --------- Co-authored-by: Benjamin Bossan Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- .../en/training/distributed_inference.md | 2 + src/diffusers/loaders/lora_base.py | 12 +- src/diffusers/loaders/unet.py | 12 +- .../pipelines/pipeline_loading_utils.py | 7 + src/diffusers/pipelines/pipeline_utils.py | 31 ++++ tests/pipelines/audioldm2/test_audioldm2.py | 5 + tests/pipelines/controlnet/test_controlnet.py | 24 +++ .../controlnet/test_controlnet_img2img.py | 12 ++ .../controlnet/test_controlnet_inpaint.py | 12 ++ .../controlnet/test_controlnet_sdxl.py | 24 +++ tests/pipelines/flux/test_pipeline_flux.py | 171 ++++++++++++++++++ .../kandinsky/test_kandinsky_combined.py | 36 ++++ .../kandinsky2_2/test_kandinsky_combined.py | 36 ++++ tests/pipelines/musicldm/test_musicldm.py | 4 + .../test_stable_cascade_combined.py | 12 ++ .../test_stable_diffusion_adapter.py | 12 ++ .../test_stable_diffusion_xl_adapter.py | 18 +- .../stable_unclip/test_stable_unclip.py | 12 ++ .../test_stable_unclip_img2img.py | 12 ++ tests/pipelines/test_pipelines_common.py | 79 ++++++++ .../pipelines/unidiffuser/test_unidiffuser.py | 9 + .../wuerstchen/test_wuerstchen_combined.py | 12 ++ 22 files changed, 546 insertions(+), 8 deletions(-) diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index 0e1eb7962bf7..8e68b1bed382 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -237,3 +237,5 @@ with torch.no_grad(): ``` By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs. + +This workflow is also compatible with LoRAs via [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. However, only LoRAs without text encoder components are currently supported in this workflow. 
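As a quick illustration of the workflow documented above, here is a minimal sketch assembled from the `test_flux_component_sharding_with_lora` test added later in this patch. The checkpoint, LoRA repository, and `max_memory` limits are the values used in that test and are illustrative only; this is not an official recipe, just one way the pieces fit together under those assumptions.

```py
# Minimal sketch, assuming two 16GB GPUs; values mirror the sharding test in this patch.
import gc

import torch

from diffusers import FluxPipeline, FluxTransformer2DModel

ckpt_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

# 1. Encode the prompt with the text encoders sharded across the available GPUs,
#    then free them so only the denoiser occupies GPU memory later.
pipeline = FluxPipeline.from_pretrained(
    ckpt_id, transformer=None, vae=None, device_map="balanced", max_memory={0: "16GB", 1: "16GB"}, torch_dtype=dtype
)
with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
        prompt="jon snow eating pizza.", prompt_2=None, max_sequence_length=512
    )
del pipeline
gc.collect()
torch.cuda.empty_cache()

# 2. Reload the pipeline with only the device-mapped transformer and apply the LoRA.
transformer = FluxTransformer2DModel.from_pretrained(
    ckpt_id, subfolder="transformer", device_map="auto", max_memory={0: "16GB", 1: "16GB"}, torch_dtype=dtype
)
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    transformer=transformer,
    torch_dtype=dtype,
)
# Only LoRAs without text encoder components are supported in this workflow.
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")

latents = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=10,
    guidance_scale=3.5,
    height=768,
    width=1360,
    output_type="latent",
    generator=torch.manual_seed(0),
).images
# The latents are decoded with a separately loaded VAE afterwards, as in the test.
```

Dropping and reloading components at each stage keeps peak memory low, which is why the sketch follows the same load/unload pattern as the documentation above.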
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index e124b6eeacf3..a13f8c20112a 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -31,6 +31,7 @@ delete_adapter_layers, deprecate, is_accelerate_available, + is_accelerate_version, is_peft_available, is_transformers_available, logging, @@ -214,9 +215,18 @@ def _optionally_disable_offloading(cls, _pipeline): is_model_cpu_offload = False is_sequential_cpu_offload = False + def model_has_device_map(model): + if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): + return False + return getattr(model, "hf_device_map", None) is not None + if _pipeline is not None and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): - if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): + if ( + isinstance(component, nn.Module) + and hasattr(component, "_hf_hook") + and not model_has_device_map(component) + ): if not is_model_cpu_offload: is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) if not is_sequential_cpu_offload: diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 2fa7732a6a3b..55b1a24e60db 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -39,6 +39,7 @@ get_adapter_name, get_peft_kwargs, is_accelerate_available, + is_accelerate_version, is_peft_version, is_torch_version, logging, @@ -398,9 +399,18 @@ def _optionally_disable_offloading(cls, _pipeline): is_model_cpu_offload = False is_sequential_cpu_offload = False + def model_has_device_map(model): + if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): + return False + return getattr(model, "hf_device_map", None) is not None + if _pipeline is not None and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): - if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): + if ( + isinstance(component, nn.Module) + and hasattr(component, "_hf_hook") + and not model_has_device_map(component) + ): if not is_model_cpu_offload: is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) if not is_sequential_cpu_offload: diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 5eba1952e608..7d42ed5bcba8 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -36,6 +36,7 @@ deprecate, get_class_from_dynamic_module, is_accelerate_available, + is_accelerate_version, is_peft_available, is_transformers_available, logging, @@ -947,3 +948,9 @@ def _get_ignore_patterns( ) return ignore_patterns + + +def model_has_device_map(model): + if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): + return False + return getattr(model, "hf_device_map", None) is not None diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2e1858b16148..791b3e5e9394 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -85,6 +85,7 @@ _update_init_kwargs_with_connected_pipeline, load_sub_model, maybe_raise_or_warn, + model_has_device_map, variant_compatible_siblings, warn_deprecated_model_variant, ) @@ -406,6 +407,16 @@ def module_is_offloaded(module): return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) + # device-mapped modules should not go through any device 
placements. + device_mapped_components = [ + key for key, component in self.components.items() if model_has_device_map(component) + ] + if device_mapped_components: + raise ValueError( + "The following pipeline components have been found to use a device map: " + f"{device_mapped_components}. This is incompatible with explicitly setting the device using `to()`." + ) + # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer pipeline_is_sequentially_offloaded = any( module_is_sequentially_offloaded(module) for _, module in self.components.items() @@ -1002,6 +1013,16 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will default to "cuda". """ + # device-mapped modules should not go through any device placements. + device_mapped_components = [ + key for key, component in self.components.items() if model_has_device_map(component) + ] + if device_mapped_components: + raise ValueError( + "The following pipeline components have been found to use a device map: " + f"{device_mapped_components}. This is incompatible with `enable_model_cpu_offload()`." + ) + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: raise ValueError( @@ -1104,6 +1125,16 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will default to "cuda". """ + # device-mapped modules should not go through any device placements. + device_mapped_components = [ + key for key, component in self.components.items() if model_has_device_map(component) + ] + if device_mapped_components: + raise ValueError( + "The following pipeline components have been found to use a device map: " + f"{device_mapped_components}. This is incompatible with `enable_sequential_cpu_offload()`." 
+ ) + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py index fb550dd3219d..9af49697f913 100644 --- a/tests/pipelines/audioldm2/test_audioldm2.py +++ b/tests/pipelines/audioldm2/test_audioldm2.py @@ -506,9 +506,14 @@ def test_to_dtype(self): model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values())) + @unittest.skip("Test currently not supported.") def test_sequential_cpu_offload_forward_pass(self): pass + @unittest.skip("Test currently not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + @nightly class AudioLDM2PipelineSlowTests(unittest.TestCase): diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index b12655d989d4..1cb6569716a8 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -514,6 +514,18 @@ def test_inference_multiple_prompt_input(self): assert image.shape == (4, 64, 64, 3) + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass + class StableDiffusionMultiControlNetOneModelPipelineFastTests( IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase @@ -697,6 +709,18 @@ def test_save_pretrained_raise_not_implemented_exception(self): except NotImplementedError: pass + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass + @slow @require_torch_gpu diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index 7c4ae716b37d..45bc70c809f2 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -389,6 +389,18 @@ def test_save_pretrained_raise_not_implemented_exception(self): except NotImplementedError: pass + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass + @slow @require_torch_gpu diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index e49106334c2e..af8ddb7e6b28 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -441,6 +441,18 @@ def test_save_pretrained_raise_not_implemented_exception(self): except NotImplementedError: pass + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + 
@unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass + @slow @require_torch_gpu diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index c931391ac4d5..a8fa23678fc7 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -683,6 +683,18 @@ def test_inference_batch_single_identical(self): def test_save_load_optional_components(self): return self._test_save_load_optional_components() + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass + class StableDiffusionXLMultiControlNetOneModelPipelineFastTests( PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase @@ -887,6 +899,18 @@ def test_negative_conditions(self): self.assertTrue(np.abs(image_slice_without_neg_cond - image_slice_with_neg_cond).max() > 1e-2) + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass + @slow @require_torch_gpu diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 3ccf3f80ba3c..e864ff85daa4 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -8,9 +8,11 @@ from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel +from diffusers.image_processor import VaeImageProcessor from diffusers.utils.testing_utils import ( numpy_cosine_similarity_distance, require_big_gpu_with_torch_cuda, + require_torch_multi_gpu, slow, torch_device, ) @@ -282,3 +284,172 @@ def test_flux_inference(self): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) assert max_diff < 1e-4 + + @require_torch_multi_gpu + @torch.no_grad() + def test_flux_component_sharding(self): + """ + internal note: test was run on `audace`. 
+ """ + + ckpt_id = "black-forest-labs/FLUX.1-dev" + dtype = torch.bfloat16 + prompt = "a photo of a cat with tiger-like look" + + pipeline = FluxPipeline.from_pretrained( + ckpt_id, + transformer=None, + vae=None, + device_map="balanced", + max_memory={0: "16GB", 1: "16GB"}, + torch_dtype=dtype, + ) + prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt( + prompt=prompt, prompt_2=None, max_sequence_length=512 + ) + + del pipeline.text_encoder + del pipeline.text_encoder_2 + del pipeline.tokenizer + del pipeline.tokenizer_2 + del pipeline + + gc.collect() + torch.cuda.empty_cache() + + transformer = FluxTransformer2DModel.from_pretrained( + ckpt_id, subfolder="transformer", device_map="auto", max_memory={0: "16GB", 1: "16GB"}, torch_dtype=dtype + ) + pipeline = FluxPipeline.from_pretrained( + ckpt_id, + text_encoder=None, + text_encoder_2=None, + tokenizer=None, + tokenizer_2=None, + vae=None, + transformer=transformer, + torch_dtype=dtype, + ) + + height, width = 768, 1360 + # No need to wrap it up under `torch.no_grad()` as pipeline call method + # is already wrapped under that. + latents = pipeline( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_inference_steps=10, + guidance_scale=3.5, + height=height, + width=width, + output_type="latent", + generator=torch.manual_seed(0), + ).images + latent_slice = latents[0, :3, :3].flatten().float().cpu().numpy() + expected_slice = np.array([-0.377, -0.3008, -0.5117, -0.252, 0.0615, -0.3477, -0.1309, -0.1914, 0.1533]) + + assert numpy_cosine_similarity_distance(latent_slice, expected_slice) < 1e-4 + + del pipeline.transformer + del pipeline + + gc.collect() + torch.cuda.empty_cache() + + vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype).to(torch_device) + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + + latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor) + latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor + + image = vae.decode(latents, return_dict=False)[0] + image = image_processor.postprocess(image, output_type="np") + image_slice = image[0, :3, :3, -1].flatten() + expected_slice = np.array([0.127, 0.1113, 0.1055, 0.1172, 0.1172, 0.1074, 0.1191, 0.1191, 0.1152]) + + assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4 + + @require_torch_multi_gpu + @torch.no_grad() + def test_flux_component_sharding_with_lora(self): + """ + internal note: test was run on `audace`. + """ + + ckpt_id = "black-forest-labs/FLUX.1-dev" + dtype = torch.bfloat16 + prompt = "jon snow eating pizza." 
+ + pipeline = FluxPipeline.from_pretrained( + ckpt_id, + transformer=None, + vae=None, + device_map="balanced", + max_memory={0: "16GB", 1: "16GB"}, + torch_dtype=dtype, + ) + prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt( + prompt=prompt, prompt_2=None, max_sequence_length=512 + ) + + del pipeline.text_encoder + del pipeline.text_encoder_2 + del pipeline.tokenizer + del pipeline.tokenizer_2 + del pipeline + + gc.collect() + torch.cuda.empty_cache() + + transformer = FluxTransformer2DModel.from_pretrained( + ckpt_id, subfolder="transformer", device_map="auto", max_memory={0: "16GB", 1: "16GB"}, torch_dtype=dtype + ) + pipeline = FluxPipeline.from_pretrained( + ckpt_id, + text_encoder=None, + text_encoder_2=None, + tokenizer=None, + tokenizer_2=None, + vae=None, + transformer=transformer, + torch_dtype=dtype, + ) + pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors") + + height, width = 768, 1360 + # No need to wrap it up under `torch.no_grad()` as pipeline call method + # is already wrapped under that. + latents = pipeline( + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + num_inference_steps=10, + guidance_scale=3.5, + height=height, + width=width, + output_type="latent", + generator=torch.manual_seed(0), + ).images + latent_slice = latents[0, :3, :3].flatten().float().cpu().numpy() + expected_slice = np.array([-0.6523, -0.4961, -0.9141, -0.5, -0.2129, -0.6914, -0.375, -0.5664, -0.1699]) + + assert numpy_cosine_similarity_distance(latent_slice, expected_slice) < 1e-4 + + del pipeline.transformer + del pipeline + + gc.collect() + torch.cuda.empty_cache() + + vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype).to(torch_device) + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + + latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor) + latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor + + image = vae.decode(latents, return_dict=False)[0] + image = image_processor.postprocess(image, output_type="np") + image_slice = image[0, :3, :3, -1].flatten() + expected_slice = np.array([0.1211, 0.1094, 0.1035, 0.1094, 0.1113, 0.1074, 0.1133, 0.1133, 0.1094]) + + assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4 diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py index 607a47e08e58..739f8676cbd3 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py @@ -139,6 +139,18 @@ def test_float16_inference(self): def test_dict_tuple_outputs_equivalent(self): super().test_dict_tuple_outputs_equivalent(expected_max_difference=5e-4) + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass + class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = KandinskyImg2ImgCombinedPipeline @@ -248,6 +260,18 @@ def test_dict_tuple_outputs_equivalent(self): def test_save_load_optional_components(self): super().test_save_load_optional_components(expected_max_difference=5e-4) + 
@unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass + class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = KandinskyInpaintCombinedPipeline @@ -363,3 +387,15 @@ def test_save_load_optional_components(self): def test_save_load_local(self): super().test_save_load_local(expected_max_difference=5e-3) + + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py index dbba0831397b..cf2b70f4c990 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py @@ -159,6 +159,18 @@ def test_callback_inputs(self): def test_callback_cfg(self): pass + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass + class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = KandinskyV22Img2ImgCombinedPipeline @@ -281,6 +293,18 @@ def test_callback_inputs(self): def test_callback_cfg(self): pass + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass + class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = KandinskyV22InpaintCombinedPipeline @@ -404,3 +428,15 @@ def test_callback_inputs(self): def test_callback_cfg(self): pass + + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass diff --git a/tests/pipelines/musicldm/test_musicldm.py b/tests/pipelines/musicldm/test_musicldm.py index e51f5103933a..70765d981bbc 100644 --- a/tests/pipelines/musicldm/test_musicldm.py +++ b/tests/pipelines/musicldm/test_musicldm.py @@ -404,6 +404,10 @@ def test_to_dtype(self): model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values())) + @unittest.skip("Test currently not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + @nightly @require_torch_gpu diff --git 
a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py index d256deed376c..d799ae6e623a 100644 --- a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py +++ b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py @@ -279,3 +279,15 @@ def test_stable_cascade_combined_prompt_embeds(self): ) assert np.abs(output_prompt.images - output_prompt_embeds.images).max() < 1e-5 + + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass diff --git a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py index 2a1e691e9e8f..996afbb9d323 100644 --- a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py +++ b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py @@ -593,6 +593,18 @@ def test_inference_batch_single_identical( if test_mean_pixel_difference: assert_mean_pixel_difference(output_batch[0][0], output[0][0]) + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass + @slow @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py index 2091af9c0383..61b5b754c44c 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py @@ -642,9 +642,6 @@ def test_adapter_sdxl_lcm(self): assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448]) - debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()] - print(",".join(debug)) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 def test_adapter_sdxl_lcm_custom_timesteps(self): @@ -667,7 +664,16 @@ def test_adapter_sdxl_lcm_custom_timesteps(self): assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448]) - debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()] - print(",".join(debug)) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py index bb54d212a786..be5e3783ff5c 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip.py @@ -184,6 +184,18 @@ def test_attention_slicing_forward_pass(self): def 
test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=1e-3) + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass + @nightly @require_torch_gpu diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index a5cbf7761501..1a662819b00f 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -205,6 +205,18 @@ def test_inference_batch_single_identical(self): def test_xformers_attention_forwardGenerator_pass(self): self._test_xformers_attention_forwardGenerator_pass(test_max_difference=False) + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass + @nightly @require_torch_gpu diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 295a94c1d2e4..f5ceda8f2703 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -41,8 +41,11 @@ from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available from diffusers.utils.testing_utils import ( CaptureLogger, + nightly, require_torch, + require_torch_multi_gpu, skip_mps, + slow, torch_device, ) @@ -59,6 +62,10 @@ from ..others.test_utils import TOKEN, USER, is_staging_test +if is_accelerate_available(): + from accelerate.utils import compute_module_sizes + + def to_np(tensor): if isinstance(tensor, torch.Tensor): tensor = tensor.detach().cpu().numpy() @@ -1908,6 +1915,78 @@ def test_StableDiffusionMixin_component(self): ) ) + @require_torch_multi_gpu + @slow + @nightly + def test_calling_to_raises_error_device_mapped_components(self, safe_serialization=True): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + max_model_size = max( + compute_module_sizes(module)[""] + for _, module in pipe.components.items() + if isinstance(module, torch.nn.Module) + ) + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=safe_serialization) + max_memory = {0: max_model_size, 1: max_model_size} + loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory) + + with self.assertRaises(ValueError) as err_context: + loaded_pipe.to(torch_device) + + self.assertTrue( + "The following pipeline components have been found" in str(err_context.exception) + and "This is incompatible with explicitly setting the device using `to()`" in str(err_context.exception) + ) + + @require_torch_multi_gpu + @slow + @nightly + def test_calling_mco_raises_error_device_mapped_components(self, safe_serialization=True): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + max_model_size = max( + compute_module_sizes(module)[""] + for _, module in pipe.components.items() + if isinstance(module, torch.nn.Module) + ) + 
with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=safe_serialization) + max_memory = {0: max_model_size, 1: max_model_size} + loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory) + + with self.assertRaises(ValueError) as err_context: + loaded_pipe.enable_model_cpu_offload() + + self.assertTrue( + "The following pipeline components have been found" in str(err_context.exception) + and "This is incompatible with `enable_model_cpu_offload()`" in str(err_context.exception) + ) + + @require_torch_multi_gpu + @slow + @nightly + def test_calling_sco_raises_error_device_mapped_components(self, safe_serialization=True): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + max_model_size = max( + compute_module_sizes(module)[""] + for _, module in pipe.components.items() + if isinstance(module, torch.nn.Module) + ) + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=safe_serialization) + max_memory = {0: max_model_size, 1: max_model_size} + loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory) + + with self.assertRaises(ValueError) as err_context: + loaded_pipe.enable_sequential_cpu_offload() + + self.assertTrue( + "The following pipeline components have been found" in str(err_context.exception) + and "This is incompatible with `enable_sequential_cpu_offload()`" in str(err_context.exception) + ) + @is_staging_test class PipelinePushToHubTester(unittest.TestCase): diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py index 2e0ba1cfb8eb..5cf017029fdf 100644 --- a/tests/pipelines/unidiffuser/test_unidiffuser.py +++ b/tests/pipelines/unidiffuser/test_unidiffuser.py @@ -576,6 +576,15 @@ def test_unidiffuser_default_img2text_v1_cuda_fp16(self): expected_text_prefix = '" This This' assert text[0][: len(expected_text_prefix)] == expected_text_prefix + def test_calling_mco_raises_error_device_mapped_components(self): + super().test_calling_mco_raises_error_device_mapped_components(safe_serialization=False) + + def test_calling_to_raises_error_device_mapped_components(self): + super().test_calling_to_raises_error_device_mapped_components(safe_serialization=False) + + def test_calling_sco_raises_error_device_mapped_components(self): + super().test_calling_sco_raises_error_device_mapped_components(safe_serialization=False) + @nightly @require_torch_gpu diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py index 0caed159100a..cd7891767f65 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py @@ -237,3 +237,15 @@ def test_callback_inputs(self): def test_callback_cfg(self): pass + + @unittest.skip("Test not supported.") + def test_calling_mco_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_to_raises_error_device_mapped_components(self): + pass + + @unittest.skip("Test not supported.") + def test_calling_sco_raises_error_device_mapped_components(self): + pass From d2e5cb3c1072ad324d1c9c4bf19be98bc4280282 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 31 Oct 2024 08:19:32 -1000 Subject: [PATCH 08/20] =?UTF-8?q?Revert=20"[LoRA]=20fix:=20lora=20loading?= =?UTF-8?q?=20when=20using=20with=20a=20device=5Fmapped=20mode=E2=80=A6=20?= 
=?UTF-8?q?(#9823)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Revert "[LoRA] fix: lora loading when using with a device_mapped model. (#9449)" This reverts commit 41e4779d988ead99e7acd78dc8e752de88777d0f. --- .../en/training/distributed_inference.md | 2 - src/diffusers/loaders/lora_base.py | 12 +- src/diffusers/loaders/unet.py | 12 +- .../pipelines/pipeline_loading_utils.py | 7 - src/diffusers/pipelines/pipeline_utils.py | 31 ---- tests/pipelines/audioldm2/test_audioldm2.py | 5 - tests/pipelines/controlnet/test_controlnet.py | 24 --- .../controlnet/test_controlnet_img2img.py | 12 -- .../controlnet/test_controlnet_inpaint.py | 12 -- .../controlnet/test_controlnet_sdxl.py | 24 --- tests/pipelines/flux/test_pipeline_flux.py | 171 ------------------ .../kandinsky/test_kandinsky_combined.py | 36 ---- .../kandinsky2_2/test_kandinsky_combined.py | 36 ---- tests/pipelines/musicldm/test_musicldm.py | 4 - .../test_stable_cascade_combined.py | 12 -- .../test_stable_diffusion_adapter.py | 12 -- .../test_stable_diffusion_xl_adapter.py | 18 +- .../stable_unclip/test_stable_unclip.py | 12 -- .../test_stable_unclip_img2img.py | 12 -- tests/pipelines/test_pipelines_common.py | 79 -------- .../pipelines/unidiffuser/test_unidiffuser.py | 9 - .../wuerstchen/test_wuerstchen_combined.py | 12 -- 22 files changed, 8 insertions(+), 546 deletions(-) diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index 8e68b1bed382..0e1eb7962bf7 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -237,5 +237,3 @@ with torch.no_grad(): ``` By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs. - -This workflow is also compatible with LoRAs via [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. However, only LoRAs without text encoder components are currently supported in this workflow. 
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index a13f8c20112a..e124b6eeacf3 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -31,7 +31,6 @@ delete_adapter_layers, deprecate, is_accelerate_available, - is_accelerate_version, is_peft_available, is_transformers_available, logging, @@ -215,18 +214,9 @@ def _optionally_disable_offloading(cls, _pipeline): is_model_cpu_offload = False is_sequential_cpu_offload = False - def model_has_device_map(model): - if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): - return False - return getattr(model, "hf_device_map", None) is not None - if _pipeline is not None and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): - if ( - isinstance(component, nn.Module) - and hasattr(component, "_hf_hook") - and not model_has_device_map(component) - ): + if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): if not is_model_cpu_offload: is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) if not is_sequential_cpu_offload: diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 55b1a24e60db..2fa7732a6a3b 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -39,7 +39,6 @@ get_adapter_name, get_peft_kwargs, is_accelerate_available, - is_accelerate_version, is_peft_version, is_torch_version, logging, @@ -399,18 +398,9 @@ def _optionally_disable_offloading(cls, _pipeline): is_model_cpu_offload = False is_sequential_cpu_offload = False - def model_has_device_map(model): - if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): - return False - return getattr(model, "hf_device_map", None) is not None - if _pipeline is not None and _pipeline.hf_device_map is None: for _, component in _pipeline.components.items(): - if ( - isinstance(component, nn.Module) - and hasattr(component, "_hf_hook") - and not model_has_device_map(component) - ): + if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): if not is_model_cpu_offload: is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) if not is_sequential_cpu_offload: diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 7d42ed5bcba8..5eba1952e608 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -36,7 +36,6 @@ deprecate, get_class_from_dynamic_module, is_accelerate_available, - is_accelerate_version, is_peft_available, is_transformers_available, logging, @@ -948,9 +947,3 @@ def _get_ignore_patterns( ) return ignore_patterns - - -def model_has_device_map(model): - if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): - return False - return getattr(model, "hf_device_map", None) is not None diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 791b3e5e9394..2e1858b16148 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -85,7 +85,6 @@ _update_init_kwargs_with_connected_pipeline, load_sub_model, maybe_raise_or_warn, - model_has_device_map, variant_compatible_siblings, warn_deprecated_model_variant, ) @@ -407,16 +406,6 @@ def module_is_offloaded(module): return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) - # device-mapped modules should not go through any device 
placements. - device_mapped_components = [ - key for key, component in self.components.items() if model_has_device_map(component) - ] - if device_mapped_components: - raise ValueError( - "The following pipeline components have been found to use a device map: " - f"{device_mapped_components}. This is incompatible with explicitly setting the device using `to()`." - ) - # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer pipeline_is_sequentially_offloaded = any( module_is_sequentially_offloaded(module) for _, module in self.components.items() @@ -1013,16 +1002,6 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will default to "cuda". """ - # device-mapped modules should not go through any device placements. - device_mapped_components = [ - key for key, component in self.components.items() if model_has_device_map(component) - ] - if device_mapped_components: - raise ValueError( - "The following pipeline components have been found to use a device map: " - f"{device_mapped_components}. This is incompatible with `enable_model_cpu_offload()`." - ) - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: raise ValueError( @@ -1125,16 +1104,6 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will default to "cuda". """ - # device-mapped modules should not go through any device placements. - device_mapped_components = [ - key for key, component in self.components.items() if model_has_device_map(component) - ] - if device_mapped_components: - raise ValueError( - "The following pipeline components have been found to use a device map: " - f"{device_mapped_components}. This is incompatible with `enable_sequential_cpu_offload()`." 
- ) - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py index 9af49697f913..fb550dd3219d 100644 --- a/tests/pipelines/audioldm2/test_audioldm2.py +++ b/tests/pipelines/audioldm2/test_audioldm2.py @@ -506,14 +506,9 @@ def test_to_dtype(self): model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values())) - @unittest.skip("Test currently not supported.") def test_sequential_cpu_offload_forward_pass(self): pass - @unittest.skip("Test currently not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - @nightly class AudioLDM2PipelineSlowTests(unittest.TestCase): diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index 1cb6569716a8..b12655d989d4 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -514,18 +514,6 @@ def test_inference_multiple_prompt_input(self): assert image.shape == (4, 64, 64, 3) - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass - class StableDiffusionMultiControlNetOneModelPipelineFastTests( IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase @@ -709,18 +697,6 @@ def test_save_pretrained_raise_not_implemented_exception(self): except NotImplementedError: pass - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass - @slow @require_torch_gpu diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index 45bc70c809f2..7c4ae716b37d 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -389,18 +389,6 @@ def test_save_pretrained_raise_not_implemented_exception(self): except NotImplementedError: pass - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass - @slow @require_torch_gpu diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index af8ddb7e6b28..e49106334c2e 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -441,18 +441,6 @@ def test_save_pretrained_raise_not_implemented_exception(self): except NotImplementedError: pass - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - 
@unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass - @slow @require_torch_gpu diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index a8fa23678fc7..c931391ac4d5 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -683,18 +683,6 @@ def test_inference_batch_single_identical(self): def test_save_load_optional_components(self): return self._test_save_load_optional_components() - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass - class StableDiffusionXLMultiControlNetOneModelPipelineFastTests( PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase @@ -899,18 +887,6 @@ def test_negative_conditions(self): self.assertTrue(np.abs(image_slice_without_neg_cond - image_slice_with_neg_cond).max() > 1e-2) - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass - @slow @require_torch_gpu diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index e864ff85daa4..3ccf3f80ba3c 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -8,11 +8,9 @@ from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel -from diffusers.image_processor import VaeImageProcessor from diffusers.utils.testing_utils import ( numpy_cosine_similarity_distance, require_big_gpu_with_torch_cuda, - require_torch_multi_gpu, slow, torch_device, ) @@ -284,172 +282,3 @@ def test_flux_inference(self): max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) assert max_diff < 1e-4 - - @require_torch_multi_gpu - @torch.no_grad() - def test_flux_component_sharding(self): - """ - internal note: test was run on `audace`. 
- """ - - ckpt_id = "black-forest-labs/FLUX.1-dev" - dtype = torch.bfloat16 - prompt = "a photo of a cat with tiger-like look" - - pipeline = FluxPipeline.from_pretrained( - ckpt_id, - transformer=None, - vae=None, - device_map="balanced", - max_memory={0: "16GB", 1: "16GB"}, - torch_dtype=dtype, - ) - prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt( - prompt=prompt, prompt_2=None, max_sequence_length=512 - ) - - del pipeline.text_encoder - del pipeline.text_encoder_2 - del pipeline.tokenizer - del pipeline.tokenizer_2 - del pipeline - - gc.collect() - torch.cuda.empty_cache() - - transformer = FluxTransformer2DModel.from_pretrained( - ckpt_id, subfolder="transformer", device_map="auto", max_memory={0: "16GB", 1: "16GB"}, torch_dtype=dtype - ) - pipeline = FluxPipeline.from_pretrained( - ckpt_id, - text_encoder=None, - text_encoder_2=None, - tokenizer=None, - tokenizer_2=None, - vae=None, - transformer=transformer, - torch_dtype=dtype, - ) - - height, width = 768, 1360 - # No need to wrap it up under `torch.no_grad()` as pipeline call method - # is already wrapped under that. - latents = pipeline( - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - num_inference_steps=10, - guidance_scale=3.5, - height=height, - width=width, - output_type="latent", - generator=torch.manual_seed(0), - ).images - latent_slice = latents[0, :3, :3].flatten().float().cpu().numpy() - expected_slice = np.array([-0.377, -0.3008, -0.5117, -0.252, 0.0615, -0.3477, -0.1309, -0.1914, 0.1533]) - - assert numpy_cosine_similarity_distance(latent_slice, expected_slice) < 1e-4 - - del pipeline.transformer - del pipeline - - gc.collect() - torch.cuda.empty_cache() - - vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype).to(torch_device) - vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) - image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) - - latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor) - latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor - - image = vae.decode(latents, return_dict=False)[0] - image = image_processor.postprocess(image, output_type="np") - image_slice = image[0, :3, :3, -1].flatten() - expected_slice = np.array([0.127, 0.1113, 0.1055, 0.1172, 0.1172, 0.1074, 0.1191, 0.1191, 0.1152]) - - assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4 - - @require_torch_multi_gpu - @torch.no_grad() - def test_flux_component_sharding_with_lora(self): - """ - internal note: test was run on `audace`. - """ - - ckpt_id = "black-forest-labs/FLUX.1-dev" - dtype = torch.bfloat16 - prompt = "jon snow eating pizza." 
- - pipeline = FluxPipeline.from_pretrained( - ckpt_id, - transformer=None, - vae=None, - device_map="balanced", - max_memory={0: "16GB", 1: "16GB"}, - torch_dtype=dtype, - ) - prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt( - prompt=prompt, prompt_2=None, max_sequence_length=512 - ) - - del pipeline.text_encoder - del pipeline.text_encoder_2 - del pipeline.tokenizer - del pipeline.tokenizer_2 - del pipeline - - gc.collect() - torch.cuda.empty_cache() - - transformer = FluxTransformer2DModel.from_pretrained( - ckpt_id, subfolder="transformer", device_map="auto", max_memory={0: "16GB", 1: "16GB"}, torch_dtype=dtype - ) - pipeline = FluxPipeline.from_pretrained( - ckpt_id, - text_encoder=None, - text_encoder_2=None, - tokenizer=None, - tokenizer_2=None, - vae=None, - transformer=transformer, - torch_dtype=dtype, - ) - pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors") - - height, width = 768, 1360 - # No need to wrap it up under `torch.no_grad()` as pipeline call method - # is already wrapped under that. - latents = pipeline( - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - num_inference_steps=10, - guidance_scale=3.5, - height=height, - width=width, - output_type="latent", - generator=torch.manual_seed(0), - ).images - latent_slice = latents[0, :3, :3].flatten().float().cpu().numpy() - expected_slice = np.array([-0.6523, -0.4961, -0.9141, -0.5, -0.2129, -0.6914, -0.375, -0.5664, -0.1699]) - - assert numpy_cosine_similarity_distance(latent_slice, expected_slice) < 1e-4 - - del pipeline.transformer - del pipeline - - gc.collect() - torch.cuda.empty_cache() - - vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype).to(torch_device) - vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) - image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) - - latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor) - latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor - - image = vae.decode(latents, return_dict=False)[0] - image = image_processor.postprocess(image, output_type="np") - image_slice = image[0, :3, :3, -1].flatten() - expected_slice = np.array([0.1211, 0.1094, 0.1035, 0.1094, 0.1113, 0.1074, 0.1133, 0.1133, 0.1094]) - - assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4 diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py index 739f8676cbd3..607a47e08e58 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py @@ -139,18 +139,6 @@ def test_float16_inference(self): def test_dict_tuple_outputs_equivalent(self): super().test_dict_tuple_outputs_equivalent(expected_max_difference=5e-4) - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass - class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = KandinskyImg2ImgCombinedPipeline @@ -260,18 +248,6 @@ def test_dict_tuple_outputs_equivalent(self): def test_save_load_optional_components(self): super().test_save_load_optional_components(expected_max_difference=5e-4) - 
@unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass - class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = KandinskyInpaintCombinedPipeline @@ -387,15 +363,3 @@ def test_save_load_optional_components(self): def test_save_load_local(self): super().test_save_load_local(expected_max_difference=5e-3) - - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py index cf2b70f4c990..dbba0831397b 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py @@ -159,18 +159,6 @@ def test_callback_inputs(self): def test_callback_cfg(self): pass - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass - class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = KandinskyV22Img2ImgCombinedPipeline @@ -293,18 +281,6 @@ def test_callback_inputs(self): def test_callback_cfg(self): pass - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass - class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = KandinskyV22InpaintCombinedPipeline @@ -428,15 +404,3 @@ def test_callback_inputs(self): def test_callback_cfg(self): pass - - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass diff --git a/tests/pipelines/musicldm/test_musicldm.py b/tests/pipelines/musicldm/test_musicldm.py index 70765d981bbc..e51f5103933a 100644 --- a/tests/pipelines/musicldm/test_musicldm.py +++ b/tests/pipelines/musicldm/test_musicldm.py @@ -404,10 +404,6 @@ def test_to_dtype(self): model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values())) - @unittest.skip("Test currently not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - @nightly @require_torch_gpu diff --git 
a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py index d799ae6e623a..d256deed376c 100644 --- a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py +++ b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py @@ -279,15 +279,3 @@ def test_stable_cascade_combined_prompt_embeds(self): ) assert np.abs(output_prompt.images - output_prompt_embeds.images).max() < 1e-5 - - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass diff --git a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py index 996afbb9d323..2a1e691e9e8f 100644 --- a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py +++ b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py @@ -593,18 +593,6 @@ def test_inference_batch_single_identical( if test_mean_pixel_difference: assert_mean_pixel_difference(output_batch[0][0], output[0][0]) - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass - @slow @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py index 61b5b754c44c..2091af9c0383 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py @@ -642,6 +642,9 @@ def test_adapter_sdxl_lcm(self): assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448]) + debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()] + print(",".join(debug)) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 def test_adapter_sdxl_lcm_custom_timesteps(self): @@ -664,16 +667,7 @@ def test_adapter_sdxl_lcm_custom_timesteps(self): assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass + debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()] + print(",".join(debug)) - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py index be5e3783ff5c..bb54d212a786 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip.py @@ -184,18 +184,6 @@ 
def test_attention_slicing_forward_pass(self): def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=1e-3) - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass - @nightly @require_torch_gpu diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index 1a662819b00f..a5cbf7761501 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -205,18 +205,6 @@ def test_inference_batch_single_identical(self): def test_xformers_attention_forwardGenerator_pass(self): self._test_xformers_attention_forwardGenerator_pass(test_max_difference=False) - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass - @nightly @require_torch_gpu diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index f5ceda8f2703..295a94c1d2e4 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -41,11 +41,8 @@ from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available from diffusers.utils.testing_utils import ( CaptureLogger, - nightly, require_torch, - require_torch_multi_gpu, skip_mps, - slow, torch_device, ) @@ -62,10 +59,6 @@ from ..others.test_utils import TOKEN, USER, is_staging_test -if is_accelerate_available(): - from accelerate.utils import compute_module_sizes - - def to_np(tensor): if isinstance(tensor, torch.Tensor): tensor = tensor.detach().cpu().numpy() @@ -1915,78 +1908,6 @@ def test_StableDiffusionMixin_component(self): ) ) - @require_torch_multi_gpu - @slow - @nightly - def test_calling_to_raises_error_device_mapped_components(self, safe_serialization=True): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - max_model_size = max( - compute_module_sizes(module)[""] - for _, module in pipe.components.items() - if isinstance(module, torch.nn.Module) - ) - with tempfile.TemporaryDirectory() as tmpdir: - pipe.save_pretrained(tmpdir, safe_serialization=safe_serialization) - max_memory = {0: max_model_size, 1: max_model_size} - loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory) - - with self.assertRaises(ValueError) as err_context: - loaded_pipe.to(torch_device) - - self.assertTrue( - "The following pipeline components have been found" in str(err_context.exception) - and "This is incompatible with explicitly setting the device using `to()`" in str(err_context.exception) - ) - - @require_torch_multi_gpu - @slow - @nightly - def test_calling_mco_raises_error_device_mapped_components(self, safe_serialization=True): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - max_model_size = max( - compute_module_sizes(module)[""] - for _, module in 
pipe.components.items() - if isinstance(module, torch.nn.Module) - ) - with tempfile.TemporaryDirectory() as tmpdir: - pipe.save_pretrained(tmpdir, safe_serialization=safe_serialization) - max_memory = {0: max_model_size, 1: max_model_size} - loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory) - - with self.assertRaises(ValueError) as err_context: - loaded_pipe.enable_model_cpu_offload() - - self.assertTrue( - "The following pipeline components have been found" in str(err_context.exception) - and "This is incompatible with `enable_model_cpu_offload()`" in str(err_context.exception) - ) - - @require_torch_multi_gpu - @slow - @nightly - def test_calling_sco_raises_error_device_mapped_components(self, safe_serialization=True): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - max_model_size = max( - compute_module_sizes(module)[""] - for _, module in pipe.components.items() - if isinstance(module, torch.nn.Module) - ) - with tempfile.TemporaryDirectory() as tmpdir: - pipe.save_pretrained(tmpdir, safe_serialization=safe_serialization) - max_memory = {0: max_model_size, 1: max_model_size} - loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory) - - with self.assertRaises(ValueError) as err_context: - loaded_pipe.enable_sequential_cpu_offload() - - self.assertTrue( - "The following pipeline components have been found" in str(err_context.exception) - and "This is incompatible with `enable_sequential_cpu_offload()`" in str(err_context.exception) - ) - @is_staging_test class PipelinePushToHubTester(unittest.TestCase): diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py index 5cf017029fdf..2e0ba1cfb8eb 100644 --- a/tests/pipelines/unidiffuser/test_unidiffuser.py +++ b/tests/pipelines/unidiffuser/test_unidiffuser.py @@ -576,15 +576,6 @@ def test_unidiffuser_default_img2text_v1_cuda_fp16(self): expected_text_prefix = '" This This' assert text[0][: len(expected_text_prefix)] == expected_text_prefix - def test_calling_mco_raises_error_device_mapped_components(self): - super().test_calling_mco_raises_error_device_mapped_components(safe_serialization=False) - - def test_calling_to_raises_error_device_mapped_components(self): - super().test_calling_to_raises_error_device_mapped_components(safe_serialization=False) - - def test_calling_sco_raises_error_device_mapped_components(self): - super().test_calling_sco_raises_error_device_mapped_components(safe_serialization=False) - @nightly @require_torch_gpu diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py index cd7891767f65..0caed159100a 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py @@ -237,15 +237,3 @@ def test_callback_inputs(self): def test_callback_cfg(self): pass - - @unittest.skip("Test not supported.") - def test_calling_mco_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_to_raises_error_device_mapped_components(self): - pass - - @unittest.skip("Test not supported.") - def test_calling_sco_raises_error_device_mapped_components(self): - pass From c75431843f3b5b4915a57fe68a3e5420dc46a280 Mon Sep 17 00:00:00 2001 From: Abhipsha Das Date: Thu, 31 Oct 2024 17:53:00 -0400 Subject: [PATCH 09/20] [Model Card] standardize advanced diffusion training 
sd15 lora (#7613) * modelcard generation edit * add missed tag * fix param name * fix var * change str to dict * add use_dora check * use correct tags for lora * make style && make quality --------- Co-authored-by: Aryan --- .../train_dreambooth_lora_sd15_advanced.py | 75 ++++++++++--------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 7fdea56dc5cb..afe30680567d 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -67,6 +67,7 @@ convert_state_dict_to_kohya, is_wandb_available, ) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available @@ -79,30 +80,27 @@ def save_model_card( repo_id: str, use_dora: bool, - images=None, - base_model=str, + images: list = None, + base_model: str = None, train_text_encoder=False, train_text_encoder_ti=False, token_abstraction_dict=None, - instance_prompt=str, - validation_prompt=str, + instance_prompt=None, + validation_prompt=None, repo_folder=None, vae_path=None, ): - img_str = "widget:\n" lora = "lora" if not use_dora else "dora" - for i, image in enumerate(images): - image.save(os.path.join(repo_folder, f"image_{i}.png")) - img_str += f""" - - text: '{validation_prompt if validation_prompt else ' ' }' - output: - url: - "image_{i}.png" - """ - if not images: - img_str += f""" - - text: '{instance_prompt}' - """ + + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + else: + widget_dict.append({"text": instance_prompt}) embeddings_filename = f"{repo_folder}_emb" instance_prompt_webui = re.sub(r"", "", re.sub(r"", embeddings_filename, instance_prompt, count=1)) ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt)) @@ -137,24 +135,7 @@ def save_model_card( trigger_str += f""" to trigger concept `{key}` → use `{tokens}` in your prompt \n """ - - yaml = f"""--- -tags: -- stable-diffusion -- stable-diffusion-diffusers -- diffusers-training -- text-to-image -- diffusers -- {lora} -- template:sd-lora -{img_str} -base_model: {base_model} -instance_prompt: {instance_prompt} -license: openrail++ ---- -""" - - model_card = f""" + model_description = f""" # SD1.5 LoRA DreamBooth - {repo_id} @@ -202,8 +183,28 @@ def save_model_card( Special VAE used for training: {vae_path}. 
""" - with open(os.path.join(repo_folder, "README.md"), "w") as f: - f.write(yaml + model_card) + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="openrail++", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + inference=True, + widget=widget_dict, + ) + + tags = [ + "text-to-image", + "diffusers", + "diffusers-training", + lora, + "template:sd-lora" "stable-diffusion", + "stable-diffusion-diffusers", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) def import_model_class_from_model_name_or_path( From 9dcac8305749de1eea84dcc53367bfac9b2bc35b Mon Sep 17 00:00:00 2001 From: Leo Jiang <74156916+leisuzz@users.noreply.github.com> Date: Thu, 31 Oct 2024 21:33:15 -0600 Subject: [PATCH 10/20] NPU Adaption for FLUX (#9751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * NPU implementation for FLUX * NPU implementation for FLUX * NPU implementation for FLUX * NPU implementation for FLUX * NPU implementation for FLUX * NPU implementation for FLUX * NPU implementation for FLUX * NPU implementation for FLUX * NPU implementation for FLUX * NPU implementation for FLUX * NPU implementation for FLUX * NPU implementation for FLUX * NPU implementation for FLUX * NPU implementation for FLUX --------- Co-authored-by: 蒋硕 --- examples/dreambooth/train_dreambooth_flux.py | 22 +- src/diffusers/models/attention_processor.py | 217 ++++++++++++++++++ .../models/transformers/transformer_flux.py | 7 +- 3 files changed, 243 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index d23d05f7e38b..bd1c29009976 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -57,6 +57,7 @@ is_wandb_available, ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available from diffusers.utils.torch_utils import is_compiled_module @@ -68,6 +69,12 @@ logger = get_logger(__name__) +if is_torch_npu_available(): + import torch_npu + + torch.npu.config.allow_internal_format = False + torch.npu.set_compile_mode(jit_compile=False) + def save_model_card( repo_id: str, @@ -189,6 +196,8 @@ def log_validation( del pipeline if torch.cuda.is_available(): torch.cuda.empty_cache() + elif is_torch_npu_available(): + torch_npu.npu.empty_cache() return images @@ -1035,7 +1044,9 @@ def main(args): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + has_supported_fp16_accelerator = ( + torch.cuda.is_available() or torch.backends.mps.is_available() or is_torch_npu_available() + ) torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 if args.prior_generation_precision == "fp32": torch_dtype = torch.float32 @@ -1073,6 +1084,8 @@ def main(args): del pipeline if torch.cuda.is_available(): torch.cuda.empty_cache() + elif is_torch_npu_available(): + torch_npu.npu.empty_cache() # Handle the repository creation if accelerator.is_main_process: @@ -1354,6 +1367,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() + elif is_torch_npu_available(): + torch_npu.npu.empty_cache() # If 
custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't @@ -1719,7 +1734,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) if not args.train_text_encoder: del text_encoder_one, text_encoder_two - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif is_torch_npu_available(): + torch_npu.npu.empty_cache() gc.collect() # Save the lora layers diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index db88ecbbb9d3..20c5cf3d925e 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1893,6 +1893,112 @@ def __call__( return hidden_states +class FluxAttnProcessor2_0_NPU: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU" + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. 
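+            # the text (context) tokens get their own q/k/v projections here; they are concatenated
+            # in front of the image tokens further below, before the attention call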
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if query.dtype in (torch.float16, torch.bfloat16): + hidden_states = torch_npu.npu_fusion_attention( + query, + key, + value, + attn.heads, + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + else: + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + class FusedFluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" @@ -1987,6 +2093,117 @@ def __call__( return hidden_states +class FusedFluxAttnProcessor2_0_NPU: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU" + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. 
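+        # in this fused variant a single `to_qkv` linear computes query, key and value together;
+        # the result is split below into three equal chunks along the last dimension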
+ qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + # `context` projections. + if encoder_hidden_states is not None: + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = torch.split(encoder_qkv, split_size, dim=-1) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if query.dtype in (torch.float16, torch.bfloat16): + hidden_states = torch_npu.npu_fusion_attention( + query, + key, + value, + attn.heads, + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + else: + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + class CogVideoXAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. 
It applies a rotary embedding on diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 5d39a1bb5391..f078cace0f3e 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -27,11 +27,13 @@ Attention, AttentionProcessor, FluxAttnProcessor2_0, + FluxAttnProcessor2_0_NPU, FusedFluxAttnProcessor2_0, ) from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -64,7 +66,10 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): self.act_mlp = nn.GELU(approximate="tanh") self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) - processor = FluxAttnProcessor2_0() + if is_torch_npu_available(): + processor = FluxAttnProcessor2_0_NPU() + else: + processor = FluxAttnProcessor2_0() self.attn = Attention( query_dim=dim, cross_attention_dim=None, From f55f1f7ee50283c4eb239b12e5c88738886c8b21 Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Fri, 1 Nov 2024 09:20:19 +0530 Subject: [PATCH 11/20] Fixes EMAModel "from_pretrained" method (#9779) * fix from_pretrained and added test * make style --------- Co-authored-by: Sayak Paul --- src/diffusers/training_utils.py | 2 +- tests/others/test_ema.py | 38 +++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 0e0d0ce5b568..d2bf3fe07185 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -379,7 +379,7 @@ def __init__( @classmethod def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel": - _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) + _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True) model = model_cls.from_pretrained(path) ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach) diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py index 5bed42b8488f..3443e6366f01 100644 --- a/tests/others/test_ema.py +++ b/tests/others/test_ema.py @@ -59,6 +59,25 @@ def simulate_backprop(self, unet): unet.load_state_dict(updated_state_dict) return unet + def test_from_pretrained(self): + # Save the model parameters to a temporary directory + unet, ema_unet = self.get_models() + with tempfile.TemporaryDirectory() as tmpdir: + ema_unet.save_pretrained(tmpdir) + + # Load the EMA model from the saved directory + loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False) + + # Check that the shadow parameters of the loaded model match the original EMA model + for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params): + assert torch.allclose(original_param, loaded_param, atol=1e-4) + + # Verify that the optimization step is also preserved + assert loaded_ema_unet.optimization_step == ema_unet.optimization_step + + # Check the decay value + assert 
loaded_ema_unet.decay == ema_unet.decay + def test_optimization_steps_updated(self): unet, ema_unet = self.get_models() # Take the first (hypothetical) EMA step. @@ -194,6 +213,25 @@ def simulate_backprop(self, unet): unet.load_state_dict(updated_state_dict) return unet + def test_from_pretrained(self): + # Save the model parameters to a temporary directory + unet, ema_unet = self.get_models() + with tempfile.TemporaryDirectory() as tmpdir: + ema_unet.save_pretrained(tmpdir) + + # Load the EMA model from the saved directory + loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True) + + # Check that the shadow parameters of the loaded model match the original EMA model + for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params): + assert torch.allclose(original_param, loaded_param, atol=1e-4) + + # Verify that the optimization step is also preserved + assert loaded_ema_unet.optimization_step == ema_unet.optimization_step + + # Check the decay value + assert loaded_ema_unet.decay == ema_unet.decay + def test_optimization_steps_updated(self): unet, ema_unet = self.get_models() # Take the first (hypothetical) EMA step. From 7ffbc2525fdf58f9f7aea8b2d5c05c1da63dffa3 Mon Sep 17 00:00:00 2001 From: ScilenceForest <45549187+ScilenceForest@users.noreply.github.com> Date: Fri, 1 Nov 2024 12:45:10 +0800 Subject: [PATCH 12/20] Update train_controlnet_flux.py,Fix size mismatch issue in validation (#9679) Update train_controlnet_flux.py Fix the problem of inconsistency between size of image and size of validation_image which causes np.stack to report error. Co-authored-by: Sayak Paul --- examples/controlnet/train_controlnet_flux.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py index 2958a9e5f28f..2524d299ef89 100644 --- a/examples/controlnet/train_controlnet_flux.py +++ b/examples/controlnet/train_controlnet_flux.py @@ -152,6 +152,7 @@ def log_validation( guidance_scale=3.5, generator=generator, ).images[0] + image = image.resize((args.resolution, args.resolution)) images.append(image) image_logs.append( {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} From 3deed729e677a011c1a2552faccce3cbb9303626 Mon Sep 17 00:00:00 2001 From: Boseong Jeon Date: Fri, 1 Nov 2024 13:46:05 +0900 Subject: [PATCH 13/20] Handling mixed precision for dreambooth flux lora training (#9565) Handling mixed precision and add unwarp Co-authored-by: Sayak Paul Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> --- examples/dreambooth/train_dreambooth_lora_flux.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index a0a197b1b2ee..e21485952583 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -177,7 +177,7 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." 
) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference @@ -1706,7 +1706,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) # handle guidance - if transformer.config.guidance_embeds: + if accelerator.unwrap_model(transformer).config.guidance_embeds: guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: @@ -1819,6 +1819,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # create pipeline if not args.train_text_encoder: text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + text_encoder_one.to(weight_dtype) + text_encoder_two.to(weight_dtype) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, From a98a839de75f1ad82d8d200c3bc2e4ff89929081 Mon Sep 17 00:00:00 2001 From: Leo Jiang <74156916+leisuzz@users.noreply.github.com> Date: Fri, 1 Nov 2024 00:49:32 -0600 Subject: [PATCH 14/20] Reduce Memory Cost in Flux Training (#9829) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Improve NPU performance * Improve NPU performance * Improve NPU performance * Improve NPU performance * [bugfix] bugfix for npu free memory * [bugfix] bugfix for npu free memory * [bugfix] bugfix for npu free memory * Reduce memory cost for flux training process --------- Co-authored-by: 蒋硕 Co-authored-by: Sayak Paul --- examples/dreambooth/train_dreambooth_flux.py | 6 ++++++ examples/dreambooth/train_dreambooth_lora_flux.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index bd1c29009976..9fd95fe823a5 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1740,6 +1740,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): torch_npu.npu.empty_cache() gc.collect() + images = None + del pipeline + # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: @@ -1798,6 +1801,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ignore_patterns=["step_*", "epoch_*"], ) + images = None + del pipeline + accelerator.end_training() diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index e21485952583..2c1126109a36 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1844,6 +1844,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): del text_encoder_one, text_encoder_two free_memory() + images = None + del pipeline + # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: @@ -1908,6 +1911,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ignore_patterns=["step_*", "epoch_*"], ) + images = None + del pipeline + accelerator.end_training() From c10f875ff042f3fd2bc14ed019db68d4ad9567b6 Mon Sep 17 00:00:00 2001 From: Dorsa Rohani Date: Fri, 1 Nov 2024 23:48:44 -0400 Subject: [PATCH 15/20] Add Diffusion Policy for Reinforcement Learning (#9824) * enable cpu ability * model creation + comprehensive testing * training + tests * all tests working * remove unneeded files + clarify docs * update train tests * update readme.md * remove data from gitignore * undo cpu enabled option * Update README.md * update readme * code quality 
fixes * diffusion policy example * update readme * add pretrained model weights + doc * add comment * add documentation * add docstrings * update comments * update readme * fix code quality * Update examples/reinforcement_learning/README.md Co-authored-by: Sayak Paul * Update examples/reinforcement_learning/diffusion_policy.py Co-authored-by: Sayak Paul * suggestions + safe globals for weights_only=True * suggestions + safe weights loading * fix code quality * reformat file --------- Co-authored-by: Sayak Paul --- examples/reinforcement_learning/README.md | 11 +- .../diffusion_policy.py | 201 ++++++++++++++++++ 2 files changed, 211 insertions(+), 1 deletion(-) create mode 100644 examples/reinforcement_learning/diffusion_policy.py diff --git a/examples/reinforcement_learning/README.md b/examples/reinforcement_learning/README.md index 3c3ada2031cf..30d3b5bb1dd8 100644 --- a/examples/reinforcement_learning/README.md +++ b/examples/reinforcement_learning/README.md @@ -1,4 +1,13 @@ -# Overview + +## Diffusion-based Policy Learning for RL + +`diffusion_policy` implements [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/), a diffusion model that predicts robot action sequences in reinforcement learning tasks. + +This example implements a robot control model for pushing a T-shaped block into a target area. The model takes in current state observations as input, and outputs a trajectory of subsequent steps to follow. + +To execute the script, run `diffusion_policy.py` + +## Diffuser Locomotion These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers. There are two ways to use the script, `run_diffuser_locomotion.py`. diff --git a/examples/reinforcement_learning/diffusion_policy.py b/examples/reinforcement_learning/diffusion_policy.py new file mode 100644 index 000000000000..3ef4c1dabc2e --- /dev/null +++ b/examples/reinforcement_learning/diffusion_policy.py @@ -0,0 +1,201 @@ +import numpy as np +import numpy.core.multiarray as multiarray +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from torch.serialization import add_safe_globals + +from diffusers import DDPMScheduler, UNet1DModel + + +add_safe_globals( + [ + multiarray._reconstruct, + np.ndarray, + np.dtype, + np.dtype(np.float32).type, + np.dtype(np.float64).type, + np.dtype(np.int32).type, + np.dtype(np.int64).type, + type(np.dtype(np.float32)), + type(np.dtype(np.float64)), + type(np.dtype(np.int32)), + type(np.dtype(np.int64)), + ] +) + +""" +An example of using HuggingFace's diffusers library for diffusion policy, +generating smooth movement trajectories. + +This implements a robot control model for pushing a T-shaped block into a target area. +The model takes in the robot arm position, block position, and block angle, +then outputs a sequence of 16 (x,y) positions for the robot arm to follow. 
+""" + + +class ObservationEncoder(nn.Module): + """ + Converts raw robot observations (positions/angles) into a more compact representation + + state_dim (int): Dimension of the input state vector (default: 5) + [robot_x, robot_y, block_x, block_y, block_angle] + + - Input shape: (batch_size, state_dim) + - Output shape: (batch_size, 256) + """ + + def __init__(self, state_dim): + super().__init__() + self.net = nn.Sequential(nn.Linear(state_dim, 512), nn.ReLU(), nn.Linear(512, 256)) + + def forward(self, x): + return self.net(x) + + +class ObservationProjection(nn.Module): + """ + Takes the encoded observation and transforms it into 32 values that represent the current robot/block situation. + These values are used as additional contextual information during the diffusion model's trajectory generation. + + - Input: 256-dim vector (padded to 512) + Shape: (batch_size, 256) + - Output: 32 contextual information values for the diffusion model + Shape: (batch_size, 32) + """ + + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(32, 512)) + self.bias = nn.Parameter(torch.zeros(32)) + + def forward(self, x): # pad 256-dim input to 512-dim with zeros + if x.size(-1) == 256: + x = torch.cat([x, torch.zeros(*x.shape[:-1], 256, device=x.device)], dim=-1) + return nn.functional.linear(x, self.weight, self.bias) + + +class DiffusionPolicy: + """ + Implements diffusion policy for generating robot arm trajectories. + Uses diffusion to generate sequences of positions for a robot arm, conditioned on + the current state of the robot and the block it needs to push. + + The model expects observations in pixel coordinates (0-512 range) and block angle in radians. + It generates trajectories as sequences of (x,y) coordinates also in the 0-512 range. 
+ """ + + def __init__(self, state_dim=5, device="cpu"): + self.device = device + + # define valid ranges for inputs/outputs + self.stats = { + "obs": {"min": torch.zeros(5), "max": torch.tensor([512, 512, 512, 512, 2 * np.pi])}, + "action": {"min": torch.zeros(2), "max": torch.full((2,), 512)}, + } + + self.obs_encoder = ObservationEncoder(state_dim).to(device) + self.obs_projection = ObservationProjection().to(device) + + # UNet model that performs the denoising process + # takes in concatenated action (2 channels) and context (32 channels) = 34 channels + # outputs predicted action (2 channels for x,y coordinates) + self.model = UNet1DModel( + sample_size=16, # length of trajectory sequence + in_channels=34, + out_channels=2, + layers_per_block=2, # number of layers per each UNet block + block_out_channels=(128,), # number of output neurons per layer in each block + down_block_types=("DownBlock1D",), # reduce the resolution of data + up_block_types=("UpBlock1D",), # increase the resolution of data + ).to(device) + + # noise scheduler that controls the denoising process + self.noise_scheduler = DDPMScheduler( + num_train_timesteps=100, # number of denoising steps + beta_schedule="squaredcos_cap_v2", # type of noise schedule + ) + + # load pre-trained weights from HuggingFace + checkpoint = torch.load( + hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt"), weights_only=True, map_location=device + ) + self.model.load_state_dict(checkpoint["model_state_dict"]) + self.obs_encoder.load_state_dict(checkpoint["encoder_state_dict"]) + self.obs_projection.load_state_dict(checkpoint["projection_state_dict"]) + + # scales data to [-1, 1] range for neural network processing + def normalize_data(self, data, stats): + return ((data - stats["min"]) / (stats["max"] - stats["min"])) * 2 - 1 + + # converts normalized data back to original range + def unnormalize_data(self, ndata, stats): + return ((ndata + 1) / 2) * (stats["max"] - stats["min"]) + stats["min"] + + @torch.no_grad() + def predict(self, observation): + """ + Generates a trajectory of robot arm positions given the current state. + + Args: + observation (torch.Tensor): Current state [robot_x, robot_y, block_x, block_y, block_angle] + Shape: (batch_size, 5) + + Returns: + torch.Tensor: Sequence of (x,y) positions for the robot arm to follow + Shape: (batch_size, 16, 2) where: + - 16 is the number of steps in the trajectory + - 2 is the (x,y) coordinates in pixel space (0-512) + + The function first encodes the observation, then uses it to condition a diffusion + process that gradually denoises random trajectories into smooth, purposeful movements. 
+ """ + observation = observation.to(self.device) + normalized_obs = self.normalize_data(observation, self.stats["obs"]) + + # encode the observation into context values for the diffusion model + cond = self.obs_projection(self.obs_encoder(normalized_obs)) + # keeps first & second dimension sizes unchanged, and multiplies last dimension by 16 + cond = cond.view(normalized_obs.shape[0], -1, 1).expand(-1, -1, 16) + + # initialize action with noise - random noise that will be refined into a trajectory + action = torch.randn((observation.shape[0], 2, 16), device=self.device) + + # denoise + # at each step `t`, the current noisy trajectory (`action`) & conditioning info (context) are + # fed into the model to predict a denoised trajectory, then uses self.noise_scheduler.step to + # apply this prediction & slightly reduce the noise in `action` more + + self.noise_scheduler.set_timesteps(100) + for t in self.noise_scheduler.timesteps: + model_output = self.model(torch.cat([action, cond], dim=1), t) + action = self.noise_scheduler.step(model_output.sample, t, action).prev_sample + + action = action.transpose(1, 2) # reshape to [batch, 16, 2] + action = self.unnormalize_data(action, self.stats["action"]) # scale back to coordinates + return action + + +if __name__ == "__main__": + policy = DiffusionPolicy() + + # sample of a single observation + # robot arm starts in center, block is slightly left and up, rotated 90 degrees + obs = torch.tensor( + [ + [ + 256.0, # robot arm x position (middle of screen) + 256.0, # robot arm y position (middle of screen) + 200.0, # block x position + 300.0, # block y position + np.pi / 2, # block angle (90 degrees) + ] + ] + ) + + action = policy.predict(obs) + + print("Action shape:", action.shape) # should be [1, 16, 2] - one trajectory of 16 x,y positions + print("\nPredicted trajectory:") + for i, (x, y) in enumerate(action[0]): + print(f"Step {i:2d}: x={x:6.1f}, y={y:6.1f}") From 13e8fdecda91e27e40b15fa8a8f456ade773e6eb Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 2 Nov 2024 09:50:39 +0530 Subject: [PATCH 16/20] [feat] add `load_lora_adapter()` for compatible models (#9712) * add first draft. * fix * updates. * updates. * updates * updates * updates. * fix-copies * lora constants. * add tests * Apply suggestions from code review Co-authored-by: Benjamin Bossan * docstrings. 
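
A minimal usage sketch of the new method (the LoRA path is an illustrative placeholder; assumes
`load_lora_adapter()` accepts a repo id, local path, or state dict plus a `prefix` for the
pipeline-style key prefix):

    import torch
    from diffusers import FluxTransformer2DModel

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    # load a LoRA checkpoint directly into the model, without going through a pipeline;
    # `prefix="transformer"` is assumed to strip the "transformer." prefix used by pipeline-level LoRA files
    transformer.load_lora_adapter("path/to/lora", prefix="transformer")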
--------- Co-authored-by: Benjamin Bossan --- src/diffusers/loaders/lora_base.py | 242 ++++++------ src/diffusers/loaders/lora_pipeline.py | 498 ++++++------------------ src/diffusers/loaders/peft.py | 223 +++++++++++ tests/lora/test_deprecated_utilities.py | 39 ++ tests/lora/utils.py | 4 +- 5 files changed, 515 insertions(+), 491 deletions(-) create mode 100644 tests/lora/test_deprecated_utilities.py diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index e124b6eeacf3..286d0a12bc71 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -51,6 +51,9 @@ logger = logging.get_logger(__name__) +LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" +LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" + def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): """ @@ -181,6 +184,119 @@ def _remove_text_encoder_monkey_patch(text_encoder): text_encoder._hf_peft_config_loaded = None +def _fetch_state_dict( + pretrained_model_name_or_path_or_dict, + weight_name, + use_safetensors, + local_files_only, + cache_dir, + force_download, + proxies, + token, + revision, + subfolder, + user_agent, + allow_pickle, +): + model_file = None + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + # Here we're relaxing the loading check to enable more Inference API + # friendliness where sometimes, it's not at all possible to automatically + # determine `weight_name`. + if weight_name is None: + weight_name = _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, + file_extension=".safetensors", + local_files_only=local_files_only, + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except (IOError, safetensors.SafetensorError) as e: + if not allow_pickle: + raise e + # try loading non-safetensors weights + model_file = None + pass + + if model_file is None: + if weight_name is None: + weight_name = _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only + ) + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name or LORA_WEIGHT_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + return state_dict + + +def _best_guess_weight_name( + pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False +): + if local_files_only or HF_HUB_OFFLINE: + raise ValueError("When using the offline mode, you must specify a `weight_name`.") + + targeted_files = [] + + if os.path.isfile(pretrained_model_name_or_path_or_dict): + return + elif os.path.isdir(pretrained_model_name_or_path_or_dict): + targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if 
f.endswith(file_extension)] + else: + files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings + targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] + if len(targeted_files) == 0: + return + + # "scheduler" does not correspond to a LoRA checkpoint. + # "optimizer" does not correspond to a LoRA checkpoint + # only top-level checkpoints are considered and not the other ones, hence "checkpoint". + unallowed_substrings = {"scheduler", "optimizer", "checkpoint"} + targeted_files = list( + filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) + ) + + if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): + targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)) + elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): + targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) + + if len(targeted_files) > 1: + raise ValueError( + f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." + ) + weight_name = targeted_files[0] + return weight_name + + class LoraBaseMixin: """Utility class for handling LoRAs.""" @@ -234,124 +350,16 @@ def _optionally_disable_offloading(cls, _pipeline): return (is_model_cpu_offload, is_sequential_cpu_offload) @classmethod - def _fetch_state_dict( - cls, - pretrained_model_name_or_path_or_dict, - weight_name, - use_safetensors, - local_files_only, - cache_dir, - force_download, - proxies, - token, - revision, - subfolder, - user_agent, - allow_pickle, - ): - from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE - - model_file = None - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - # Let's first try to load .safetensors weights - if (use_safetensors and weight_name is None) or ( - weight_name is not None and weight_name.endswith(".safetensors") - ): - try: - # Here we're relaxing the loading check to enable more Inference API - # friendliness where sometimes, it's not at all possible to automatically - # determine `weight_name`. 
- if weight_name is None: - weight_name = cls._best_guess_weight_name( - pretrained_model_name_or_path_or_dict, - file_extension=".safetensors", - local_files_only=local_files_only, - ) - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = safetensors.torch.load_file(model_file, device="cpu") - except (IOError, safetensors.SafetensorError) as e: - if not allow_pickle: - raise e - # try loading non-safetensors weights - model_file = None - pass - - if model_file is None: - if weight_name is None: - weight_name = cls._best_guess_weight_name( - pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only - ) - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name or LORA_WEIGHT_NAME, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = load_state_dict(model_file) - else: - state_dict = pretrained_model_name_or_path_or_dict - - return state_dict + def _fetch_state_dict(cls, *args, **kwargs): + deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`." + deprecate("_fetch_state_dict", "0.35.0", deprecation_message) + return _fetch_state_dict(*args, **kwargs) @classmethod - def _best_guess_weight_name( - cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False - ): - from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE - - if local_files_only or HF_HUB_OFFLINE: - raise ValueError("When using the offline mode, you must specify a `weight_name`.") - - targeted_files = [] - - if os.path.isfile(pretrained_model_name_or_path_or_dict): - return - elif os.path.isdir(pretrained_model_name_or_path_or_dict): - targeted_files = [ - f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension) - ] - else: - files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings - targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)] - if len(targeted_files) == 0: - return - - # "scheduler" does not correspond to a LoRA checkpoint. - # "optimizer" does not correspond to a LoRA checkpoint - # only top-level checkpoints are considered and not the other ones, hence "checkpoint". - unallowed_substrings = {"scheduler", "optimizer", "checkpoint"} - targeted_files = list( - filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files) - ) - - if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files): - targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files)) - elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files): - targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files)) - - if len(targeted_files) > 1: - raise ValueError( - f"Provided path contains more than one weights file in the {file_extension} format. 
Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}." - ) - weight_name = targeted_files[0] - return weight_name + def _best_guess_weight_name(cls, *args, **kwargs): + deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`." + deprecate("_best_guess_weight_name", "0.35.0", deprecation_message) + return _best_guess_weight_name(*args, **kwargs) def unload_lora_weights(self): """ @@ -725,8 +733,6 @@ def write_lora_layers( save_function: Callable, safe_serialization: bool, ): - from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE - if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 5e01ec567f9a..154aa2d8f9bb 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -21,7 +21,6 @@ USE_PEFT_BACKEND, convert_state_dict_to_diffusers, convert_state_dict_to_peft, - convert_unet_state_dict_to_peft, deprecate, get_adapter_name, get_peft_kwargs, @@ -33,7 +32,7 @@ logging, scale_lora_layers, ) -from .lora_base import LoraBaseMixin +from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa from .lora_conversion_utils import ( _convert_kohya_flux_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, @@ -62,9 +61,6 @@ UNET_NAME = "unet" TRANSFORMER_NAME = "transformer" -LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" -LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" - class StableDiffusionLoraLoaderMixin(LoraBaseMixin): r""" @@ -222,7 +218,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -282,7 +278,9 @@ def load_lora_into_unet( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -341,7 +339,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -601,7 +601,9 @@ def load_lora_weights( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. 
If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -744,7 +746,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -805,7 +807,9 @@ def load_lora_into_unet( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -865,7 +869,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1182,7 +1188,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1226,7 +1232,9 @@ def load_lora_weights( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -1250,13 +1258,17 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." 
in k} + if len(transformer_state_dict) > 0: + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) + if not hasattr(self, "transformer") + else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} if len(text_encoder_state_dict) > 0: @@ -1301,94 +1313,24 @@ def load_lora_into_transformer( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." ) - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - - keys = list(state_dict.keys()) - - transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] - state_dict = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys - } - - if len(state_dict.keys()) > 0: - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - - if adapter_name in getattr(transformer, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." - ) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(transformer) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - peft_kwargs = {} - if is_peft_version(">=", "0.13.1"): - peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs) - - warn_msg = "" - if incompatible_keys is not None: - # Check only for unexpected keys. - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] - if lora_unexpected_keys: - warn_msg = ( - f"Loading adapter weights from state_dict led to unexpected keys found in the model:" - f" {', '.join(lora_unexpected_keys)}. 
" - ) - - # Filter missing keys specific to the current adapter. - missing_keys = getattr(incompatible_keys, "missing_keys", None) - if missing_keys: - lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] - if lora_missing_keys: - warn_msg += ( - f"Loading adapter weights from state_dict led to missing keys in the model:" - f" {', '.join(lora_missing_keys)}." - ) - - if warn_msg: - logger.warning(warn_msg) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -1424,7 +1366,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1742,7 +1686,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1819,7 +1763,9 @@ def load_lora_weights( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1843,14 +1789,18 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - self.load_lora_into_transformer( - state_dict, - network_alphas=network_alphas, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k} + if len(transformer_state_dict) > 0: + self.load_lora_into_transformer( + state_dict, + network_alphas=network_alphas, + transformer=getattr(self, self.transformer_name) + if not hasattr(self, "transformer") + else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." 
in k} if len(text_encoder_state_dict) > 0: @@ -1881,104 +1831,32 @@ def load_lora_into_transformer( The value of the network alpha used for stable learning and preventing underflow. This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - transformer (`SD3Transformer2DModel`): + transformer (`FluxTransformer2DModel`): The Transformer model to load the LoRA layers into. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." ) - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - + # Load the layers corresponding to transformer. keys = list(state_dict.keys()) - - transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] - state_dict = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys - } - - if len(state_dict.keys()) > 0: - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - - if adapter_name in getattr(transformer, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." - ) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - if network_alphas is not None and len(network_alphas) >= 1: - prefix = cls.transformer_name - alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." 
- ) - else: - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(transformer) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - peft_kwargs = {} - if is_peft_version(">=", "0.13.1"): - peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs) - - warn_msg = "" - if incompatible_keys is not None: - # Check only for unexpected keys. - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] - if lora_unexpected_keys: - warn_msg = ( - f"Loading adapter weights from state_dict led to unexpected keys found in the model:" - f" {', '.join(lora_unexpected_keys)}. " - ) - - # Filter missing keys specific to the current adapter. - missing_keys = getattr(incompatible_keys, "missing_keys", None) - if missing_keys: - lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] - if lora_missing_keys: - warn_msg += ( - f"Loading adapter weights from state_dict led to missing keys in the model:" - f" {', '.join(lora_missing_keys)}." - ) - - if warn_msg: - logger.warning(warn_msg) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> + transformer_present = any(key.startswith(cls.transformer_name) for key in keys) + if transformer_present: + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -2014,7 +1892,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -2242,7 +2122,10 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): text_encoder_name = TEXT_ENCODER_NAME @classmethod - def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None): + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel + def load_lora_into_transformer( + cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. 
@@ -2255,93 +2138,32 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada The value of the network alpha used for stable learning and preventing underflow. This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - unet (`UNet2DConditionModel`): - The UNet model to load the LoRA layers into. + transformer (`UVit2DModel`): + The Transformer model to load the LoRA layers into. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # Load the layers corresponding to transformer. keys = list(state_dict.keys()) - - transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] - state_dict = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys - } - - if network_alphas is not None: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)] - network_alphas = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys - } - - if len(state_dict.keys()) > 0: - if adapter_name in getattr(transformer, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." - ) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(transformer) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) - - warn_msg = "" - if incompatible_keys is not None: - # Check only for unexpected keys. - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] - if lora_unexpected_keys: - warn_msg = ( - f"Loading adapter weights from state_dict led to unexpected keys found in the model:" - f" {', '.join(lora_unexpected_keys)}. 
" - ) - - # Filter missing keys specific to the current adapter. - missing_keys = getattr(incompatible_keys, "missing_keys", None) - if missing_keys: - lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] - if lora_missing_keys: - warn_msg += ( - f"Loading adapter weights from state_dict led to missing keys in the model:" - f" {', '.join(lora_missing_keys)}." - ) - - if warn_msg: - logger.warning(warn_msg) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> + transformer_present = any(key.startswith(cls.transformer_name) for key in keys) + if transformer_present: + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -2377,7 +2199,9 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -2619,7 +2443,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -2658,7 +2482,9 @@ def load_lora_weights( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -2691,7 +2517,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False ): @@ -2703,99 +2529,29 @@ def load_lora_into_transformer( A standard state dict containing the lora layer parameters. The keys can either be indexed directly into the unet or prefixed with an additional `unet` which can be used to distinguish between text encoder lora layers. - transformer (`SD3Transformer2DModel`): + transformer (`CogVideoXTransformer3DModel`): The Transformer model to load the LoRA layers into. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. 
If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." ) - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - - keys = list(state_dict.keys()) - - transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] - state_dict = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys - } - - if len(state_dict.keys()) > 0: - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - - if adapter_name in getattr(transformer, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." - ) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(transformer) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - peft_kwargs = {} - if is_peft_version(">=", "0.13.1"): - peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs) - - warn_msg = "" - if incompatible_keys is not None: - # Check only for unexpected keys. - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] - if lora_unexpected_keys: - warn_msg = ( - f"Loading adapter weights from state_dict led to unexpected keys found in the model:" - f" {', '.join(lora_unexpected_keys)}. " - ) - - # Filter missing keys specific to the current adapter. - missing_keys = getattr(incompatible_keys, "missing_keys", None) - if missing_keys: - lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] - if lora_missing_keys: - warn_msg += ( - f"Loading adapter weights from state_dict led to missing keys in the model:" - f" {', '.join(lora_missing_keys)}." - ) - - if warn_msg: - logger.warning(warn_msg) - - # Offload back. 
- if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index d1c6721512fa..cf361e88a670 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -16,18 +16,32 @@ from functools import partial from typing import Dict, List, Optional, Union +import torch.nn as nn + from ..utils import ( MIN_PEFT_VERSION, USE_PEFT_BACKEND, check_peft_version, + convert_unet_state_dict_to_peft, delete_adapter_layers, + get_adapter_name, + get_peft_kwargs, + is_accelerate_available, is_peft_available, + is_peft_version, + logging, set_adapter_layers, set_weights_and_activate_adapters, ) +from .lora_base import _fetch_state_dict from .unet_loader_utils import _maybe_expand_lora_scales +if is_accelerate_available(): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + +logger = logging.get_logger(__name__) + _SET_ADAPTER_SCALE_FN_MAPPING = { "UNet2DConditionModel": _maybe_expand_lora_scales, "UNetMotionModel": _maybe_expand_lora_scales, @@ -53,6 +67,215 @@ class PeftAdapterMixin: _hf_peft_config_loaded = False + @classmethod + # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading + def _optionally_disable_offloading(cls, _pipeline): + """ + Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU. + + Args: + _pipeline (`DiffusionPipeline`): + The pipeline to disable offloading for. + + Returns: + tuple: + A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. + """ + is_model_cpu_offload = False + is_sequential_cpu_offload = False + + if _pipeline is not None and _pipeline.hf_device_map is None: + for _, component in _pipeline.components.items(): + if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): + if not is_model_cpu_offload: + is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) + if not is_sequential_cpu_offload: + is_sequential_cpu_offload = ( + isinstance(component._hf_hook, AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) + + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + + return (is_model_cpu_offload, is_sequential_cpu_offload) + + def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs): + r""" + Loads a LoRA adapter into the underlying model. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. 
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + prefix (`str`, *optional*): Prefix to filter the state dict. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + """ + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + adapter_name = kwargs.pop("adapter_name", None) + network_alphas = kwargs.pop("network_alphas", None) + _pipeline = kwargs.pop("_pipeline", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) + allow_pickle = False + + if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
+ ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + keys = list(state_dict.keys()) + transformer_keys = [k for k in keys if k.startswith(prefix)] + if len(transformer_keys) > 0: + state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys} + + if len(state_dict.keys()) > 0: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + if adapter_name in getattr(self, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + if network_alphas is not None and len(network_alphas) >= 1: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] + network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(self) + + # <Unsafe code + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) + + peft_kwargs = {} + if is_peft_version(">=", "0.13.1"): + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) + + warn_msg = "" + if incompatible_keys is not None: + # Check only for unexpected keys. + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k] + if lora_unexpected_keys: + warn_msg = ( + f"Loading adapter weights from state_dict led to unexpected keys found in the model:" + f" {', '.join(lora_unexpected_keys)}. " + ) + + # Filter missing keys specific to the current adapter. + missing_keys = getattr(incompatible_keys, "missing_keys", None) + if missing_keys: + lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k] + if lora_missing_keys: + warn_msg += ( + f"Loading adapter weights from state_dict led to missing keys in the model:" + f" {', '.join(lora_missing_keys)}." + ) + + if warn_msg: + logger.warning(warn_msg) + + # Offload back.
+ if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + def set_adapters( self, adapter_names: Union[List[str], str], diff --git a/tests/lora/test_deprecated_utilities.py b/tests/lora/test_deprecated_utilities.py new file mode 100644 index 000000000000..4275ef8089a3 --- /dev/null +++ b/tests/lora/test_deprecated_utilities.py @@ -0,0 +1,39 @@ +import os +import tempfile +import unittest + +import torch + +from diffusers.loaders.lora_base import LoraBaseMixin + + +class UtilityMethodDeprecationTests(unittest.TestCase): + def test_fetch_state_dict_cls_method_raises_warning(self): + state_dict = torch.nn.Linear(3, 3).state_dict() + with self.assertWarns(FutureWarning) as warning: + _ = LoraBaseMixin._fetch_state_dict( + state_dict, + weight_name=None, + use_safetensors=False, + local_files_only=True, + cache_dir=None, + force_download=False, + proxies=None, + token=None, + revision=None, + subfolder=None, + user_agent=None, + allow_pickle=None, + ) + warning_message = str(warning.warnings[0].message) + assert "Using the `_fetch_state_dict()` method from" in warning_message + + def test_best_guess_weight_name_cls_method_raises_warning(self): + with tempfile.TemporaryDirectory() as tmpdir: + state_dict = torch.nn.Linear(3, 3).state_dict() + torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin")) + + with self.assertWarns(FutureWarning) as warning: + _ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir) + warning_message = str(warning.warnings[0].message) + assert "Using the `_best_guess_weight_name()` method from" in warning_message diff --git a/tests/lora/utils.py b/tests/lora/utils.py index e7fc840fcaa5..b711c8c9791e 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1787,7 +1787,7 @@ def test_missing_keys_warning(self): logger = ( logging.get_logger("diffusers.loaders.unet") if self.unet_kwargs is not None - else logging.get_logger("diffusers.loaders.lora_pipeline") + else logging.get_logger("diffusers.loaders.peft") ) logger.setLevel(30) with CaptureLogger(logger) as cap_logger: @@ -1826,7 +1826,7 @@ def test_unexpected_keys_warning(self): logger = ( logging.get_logger("diffusers.loaders.unet") if self.unet_kwargs is not None - else logging.get_logger("diffusers.loaders.lora_pipeline") + else logging.get_logger("diffusers.loaders.peft") ) logger.setLevel(30) with CaptureLogger(logger) as cap_logger: From a3cc641f78bd0c4a749e8ad03141d7fdb76eec1c Mon Sep 17 00:00:00 2001 From: RogerSinghChugh <35698080+RogerSinghChugh@users.noreply.github.com> Date: Mon, 4 Nov 2024 23:10:44 +0530 Subject: [PATCH 17/20] Refac training utils.py (#9815) * Refac training utils.py * quality --------- Co-authored-by: sayakpaul --- src/diffusers/training_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index d2bf3fe07185..2474ed5c2114 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -43,6 +43,9 @@ def set_seed(seed: int): Args: seed (`int`): The seed to set. + + Returns: + `None` """ random.seed(seed) np.random.seed(seed) @@ -58,6 +61,17 @@ def compute_snr(noise_scheduler, timesteps): """ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + for the given timesteps using the provided noise scheduler. 
+ + Args: + noise_scheduler (`NoiseScheduler`): + An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute + the SNR values. + timesteps (`torch.Tensor`): + A tensor of timesteps for which the SNR is computed. + + Returns: + `torch.Tensor`: A tensor containing the computed SNR values for each timestep. """ alphas_cumprod = noise_scheduler.alphas_cumprod sqrt_alphas_cumprod = alphas_cumprod**0.5 From 3f329a426a09d0bf3f96095301042a5903bc78eb Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 5 Nov 2024 20:33:41 +0530 Subject: [PATCH 18/20] [core] Mochi T2V (#9769) * update * udpate * update transformer * make style * fix * add conversion script * update * fix * update * fix * update * fixes * make style * update * update * update * init * update * update * add * up * up * up * update * mochi transformer * remove original implementation * make style * update inits * update conversion script * docs * Update src/diffusers/pipelines/mochi/pipeline_mochi.py Co-authored-by: Dhruv Nair * Update src/diffusers/pipelines/mochi/pipeline_mochi.py Co-authored-by: Dhruv Nair * fix docs * pipeline fixes * make style * invert sigmas in scheduler; fix pipeline * fix pipeline num_frames * flip proj and gate in swiglu * make style * fix * make style * fix tests * latent mean and std fix * update * cherry-pick 1069d210e1b9e84a366cdc7a13965626ea258178 * remove additional sigma already handled by flow match scheduler * fix * remove hardcoded value * replace conv1x1 with linear * Update src/diffusers/pipelines/mochi/pipeline_mochi.py Co-authored-by: Dhruv Nair * framewise decoding and conv_cache * make style * Apply suggestions from code review * mochi vae encoder changes * rebase correctly * Update scripts/convert_mochi_to_diffusers.py * fix tests * fixes * make style * update * make style * update * add framewise and tiled encoding * make style * make original vae implementation behaviour the default; note: framewise encoding does not work * remove framewise encoding implementation due to presence of attn layers * fight test 1 * fight test 2 --------- Co-authored-by: Dhruv Nair Co-authored-by: yiyixuxu --- docs/source/en/_toctree.yml | 6 + .../en/api/models/autoencoderkl_mochi.md | 32 + .../en/api/models/mochi_transformer3d.md | 30 + docs/source/en/api/pipelines/mochi.md | 36 + scripts/convert_mochi_to_diffusers.py | 461 +++++++ src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 4 + src/diffusers/models/activations.py | 1 + src/diffusers/models/attention_processor.py | 185 ++- src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoders/autoencoder_kl_cogvideox.py | 32 +- .../autoencoders/autoencoder_kl_mochi.py | 1165 +++++++++++++++++ src/diffusers/models/embeddings.py | 117 ++ src/diffusers/models/normalization.py | 52 +- src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_mochi.py | 387 ++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/mochi/__init__.py | 48 + .../pipelines/mochi/pipeline_mochi.py | 724 ++++++++++ .../pipelines/mochi/pipeline_output.py | 20 + .../scheduling_flow_match_euler_discrete.py | 11 +- src/diffusers/utils/dummy_pt_objects.py | 30 + .../dummy_torch_and_transformers_objects.py | 15 + .../test_models_transformer_mochi.py | 84 ++ tests/pipelines/mochi/__init__.py | 0 tests/pipelines/mochi/test_mochi.py | 299 +++++ 26 files changed, 3727 insertions(+), 22 deletions(-) create mode 100644 docs/source/en/api/models/autoencoderkl_mochi.md create mode 100644 
docs/source/en/api/models/mochi_transformer3d.md create mode 100644 docs/source/en/api/pipelines/mochi.md create mode 100644 scripts/convert_mochi_to_diffusers.py create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_mochi.py create mode 100644 src/diffusers/models/transformers/transformer_mochi.py create mode 100644 src/diffusers/pipelines/mochi/__init__.py create mode 100644 src/diffusers/pipelines/mochi/pipeline_mochi.py create mode 100644 src/diffusers/pipelines/mochi/pipeline_output.py create mode 100644 tests/models/transformers/test_models_transformer_mochi.py create mode 100644 tests/pipelines/mochi/__init__.py create mode 100644 tests/pipelines/mochi/test_mochi.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c0d571a5864d..de6cd2981b96 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -270,6 +270,8 @@ title: LatteTransformer3DModel - local: api/models/lumina_nextdit2d title: LuminaNextDiT2DModel + - local: api/models/mochi_transformer3d + title: MochiTransformer3DModel - local: api/models/pixart_transformer2d title: PixArtTransformer2DModel - local: api/models/prior_transformer @@ -306,6 +308,8 @@ title: AutoencoderKLAllegro - local: api/models/autoencoderkl_cogvideox title: AutoencoderKLCogVideoX + - local: api/models/autoencoderkl_mochi + title: AutoencoderKLMochi - local: api/models/asymmetricautoencoderkl title: AsymmetricAutoencoderKL - local: api/models/consistency_decoder_vae @@ -400,6 +404,8 @@ title: Lumina-T2X - local: api/pipelines/marigold title: Marigold + - local: api/pipelines/mochi + title: Mochi - local: api/pipelines/panorama title: MultiDiffusion - local: api/pipelines/musicldm diff --git a/docs/source/en/api/models/autoencoderkl_mochi.md b/docs/source/en/api/models/autoencoderkl_mochi.md new file mode 100644 index 000000000000..9747de4af937 --- /dev/null +++ b/docs/source/en/api/models/autoencoderkl_mochi.md @@ -0,0 +1,32 @@ + + +# AutoencoderKLMochi + +The 3D variational autoencoder (VAE) model with KL loss used in [Mochi](https://github.com/genmoai/models) was introduced in [Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Genmo. + +The model can be loaded with the following code snippet. + +```python +import torch +from diffusers import AutoencoderKLMochi + +vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float32).to("cuda") +``` + +## AutoencoderKLMochi + +[[autodoc]] AutoencoderKLMochi + - decode + - all + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/docs/source/en/api/models/mochi_transformer3d.md b/docs/source/en/api/models/mochi_transformer3d.md new file mode 100644 index 000000000000..05e28654d58c --- /dev/null +++ b/docs/source/en/api/models/mochi_transformer3d.md @@ -0,0 +1,30 @@ + + +# MochiTransformer3DModel + +A Diffusion Transformer model for 3D video-like data was introduced in [Mochi-1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Genmo. + +The model can be loaded with the following code snippet.
+ +```python +import torch +from diffusers import MochiTransformer3DModel + +transformer = MochiTransformer3DModel.from_pretrained("genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.float16).to("cuda") +``` + +## MochiTransformer3DModel + +[[autodoc]] MochiTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/mochi.md b/docs/source/en/api/pipelines/mochi.md new file mode 100644 index 000000000000..f29297e5901c --- /dev/null +++ b/docs/source/en/api/pipelines/mochi.md @@ -0,0 +1,36 @@ + + +# Mochi + +[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) from Genmo. + +*Mochi 1 preview is an open state-of-the-art video generation model with high-fidelity motion and strong prompt adherence in preliminary evaluation. This model dramatically closes the gap between closed and open video generation systems. The model is released under a permissive Apache 2.0 license.* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +## MochiPipeline + +[[autodoc]] MochiPipeline + - all + - __call__ + +## MochiPipelineOutput + +[[autodoc]] pipelines.mochi.pipeline_output.MochiPipelineOutput diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py new file mode 100644 index 000000000000..892fd871c554 --- /dev/null +++ b/scripts/convert_mochi_to_diffusers.py @@ -0,0 +1,461 @@ +import argparse +from contextlib import nullcontext + +import torch +from accelerate import init_empty_weights +from safetensors.torch import load_file +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel +from diffusers.utils.import_utils import is_accelerate_available + + +CTX = init_empty_weights if is_accelerate_available() else nullcontext + +TOKENIZER_MAX_LENGTH = 256 + +parser = argparse.ArgumentParser() +parser.add_argument("--transformer_checkpoint_path", default=None, type=str) +parser.add_argument("--vae_encoder_checkpoint_path", default=None, type=str) +parser.add_argument("--vae_decoder_checkpoint_path", default=None, type=str) +parser.add_argument("--output_path", required=True, type=str) +parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving") +parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory") +parser.add_argument("--dtype", type=str, default=None) + +args = parser.parse_args() + + +# This is specific to `AdaLayerNormContinuous`: +# the Diffusers implementation splits the linear projection into scale, shift while Mochi splits it into shift, scale +def swap_scale_shift(weight, dim): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def swap_proj_gate(weight): + proj, gate = weight.chunk(2, dim=0) + new_weight = torch.cat([gate, proj], dim=0) + return new_weight + + +def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path): + original_state_dict = load_file(ckpt_path, device="cpu") + new_state_dict = {} + + # Convert patch_embed + new_state_dict["patch_embed.proj.weight"] =
original_state_dict.pop("x_embedder.proj.weight") + new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias") + + # Convert time_embed + new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight") + new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias") + new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight") + new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("t_embedder.mlp.2.bias") + new_state_dict["time_embed.pooler.to_kv.weight"] = original_state_dict.pop("t5_y_embedder.to_kv.weight") + new_state_dict["time_embed.pooler.to_kv.bias"] = original_state_dict.pop("t5_y_embedder.to_kv.bias") + new_state_dict["time_embed.pooler.to_q.weight"] = original_state_dict.pop("t5_y_embedder.to_q.weight") + new_state_dict["time_embed.pooler.to_q.bias"] = original_state_dict.pop("t5_y_embedder.to_q.bias") + new_state_dict["time_embed.pooler.to_out.weight"] = original_state_dict.pop("t5_y_embedder.to_out.weight") + new_state_dict["time_embed.pooler.to_out.bias"] = original_state_dict.pop("t5_y_embedder.to_out.bias") + new_state_dict["time_embed.caption_proj.weight"] = original_state_dict.pop("t5_yproj.weight") + new_state_dict["time_embed.caption_proj.bias"] = original_state_dict.pop("t5_yproj.bias") + + # Convert transformer blocks + num_layers = 48 + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + old_prefix = f"blocks.{i}." + + # norm1 + new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(old_prefix + "mod_x.weight") + new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(old_prefix + "mod_x.bias") + if i < num_layers - 1: + new_state_dict[block_prefix + "norm1_context.linear.weight"] = original_state_dict.pop( + old_prefix + "mod_y.weight" + ) + new_state_dict[block_prefix + "norm1_context.linear.bias"] = original_state_dict.pop( + old_prefix + "mod_y.bias" + ) + else: + new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop( + old_prefix + "mod_y.weight" + ) + new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop( + old_prefix + "mod_y.bias" + ) + + # Visual attention + qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.to_q.weight"] = q + new_state_dict[block_prefix + "attn1.to_k.weight"] = k + new_state_dict[block_prefix + "attn1.to_v.weight"] = v + new_state_dict[block_prefix + "attn1.norm_q.weight"] = original_state_dict.pop( + old_prefix + "attn.q_norm_x.weight" + ) + new_state_dict[block_prefix + "attn1.norm_k.weight"] = original_state_dict.pop( + old_prefix + "attn.k_norm_x.weight" + ) + new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop( + old_prefix + "attn.proj_x.weight" + ) + new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(old_prefix + "attn.proj_x.bias") + + # Context attention + qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q + new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k + new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v + new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = 
original_state_dict.pop( + old_prefix + "attn.q_norm_y.weight" + ) + new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = original_state_dict.pop( + old_prefix + "attn.k_norm_y.weight" + ) + if i < num_layers - 1: + new_state_dict[block_prefix + "attn1.to_add_out.weight"] = original_state_dict.pop( + old_prefix + "attn.proj_y.weight" + ) + new_state_dict[block_prefix + "attn1.to_add_out.bias"] = original_state_dict.pop( + old_prefix + "attn.proj_y.bias" + ) + + # MLP + new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate( + original_state_dict.pop(old_prefix + "mlp_x.w1.weight") + ) + new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight") + if i < num_layers - 1: + new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate( + original_state_dict.pop(old_prefix + "mlp_y.w1.weight") + ) + new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop( + old_prefix + "mlp_y.w2.weight" + ) + + # Output layers + new_state_dict["norm_out.linear.weight"] = swap_scale_shift( + original_state_dict.pop("final_layer.mod.weight"), dim=0 + ) + new_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.bias"), dim=0) + new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight") + new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias") + + new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies") + + print("Remaining Keys:", original_state_dict.keys()) + + return new_state_dict + + +def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_path): + encoder_state_dict = load_file(encoder_ckpt_path, device="cpu") + decoder_state_dict = load_file(decoder_ckpt_path, device="cpu") + new_state_dict = {} + + # ==== Decoder ===== + prefix = "decoder." 
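+
+    # Rough layout of the original decoder checkpoint, as reflected in the key mappings below:
+    #   blocks.0.0     -> conv_in
+    #   blocks.0.{i+1} -> block_in.resnets.{i}
+    #   blocks.{1..3}  -> up_blocks.{0..2}
+    #   blocks.4.{i}   -> block_out.resnets.{i}
+    #   output_proj    -> proj_out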
+ + # Convert conv_in + new_state_dict[f"{prefix}conv_in.weight"] = decoder_state_dict.pop("blocks.0.0.weight") + new_state_dict[f"{prefix}conv_in.bias"] = decoder_state_dict.pop("blocks.0.0.bias") + + # Convert block_in (MochiMidBlock3D) + for i in range(3): # layers_per_block[-1] = 3 + new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.0.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.0.bias" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.2.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.2.bias" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.3.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.3.bias" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.5.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop( + f"blocks.0.{i+1}.stack.5.bias" + ) + + # Convert up_blocks (MochiUpBlock3D) + down_block_layers = [6, 4, 3] # layers_per_block[-2], layers_per_block[-3], layers_per_block[-4] + for block in range(3): + for i in range(down_block_layers[block]): + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.0.weight" + ) + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.0.bias" + ) + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.2.weight" + ) + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.2.bias" + ) + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.3.weight" + ) + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.3.bias" + ) + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.5.weight" + ) + new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop( + f"blocks.{block+1}.blocks.{i}.stack.5.bias" + ) + new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop( + f"blocks.{block+1}.proj.weight" + ) + new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias") + + # Convert block_out (MochiMidBlock3D) + for i in range(3): # layers_per_block[0] = 3 + new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.0.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.0.bias" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.2.weight" + ) + 
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.2.bias" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.3.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.3.bias" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.5.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop( + f"blocks.4.{i}.stack.5.bias" + ) + + # Convert proj_out (Conv1x1 ~= nn.Linear) + new_state_dict[f"{prefix}proj_out.weight"] = decoder_state_dict.pop("output_proj.weight") + new_state_dict[f"{prefix}proj_out.bias"] = decoder_state_dict.pop("output_proj.bias") + + print("Remaining Decoder Keys:", decoder_state_dict.keys()) + + # ==== Encoder ===== + prefix = "encoder." + + new_state_dict[f"{prefix}proj_in.weight"] = encoder_state_dict.pop("layers.0.weight") + new_state_dict[f"{prefix}proj_in.bias"] = encoder_state_dict.pop("layers.0.bias") + + # Convert block_in (MochiMidBlock3D) + for i in range(3): # layers_per_block[0] = 3 + new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.0.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.0.bias" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.2.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.2.bias" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.3.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.3.bias" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.5.weight" + ) + new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( + f"layers.{i+1}.stack.5.bias" + ) + + # Convert down_blocks (MochiDownBlock3D) + down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3] + for block in range(3): + new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.0.weight" + ) + new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.0.bias" + ) + + for i in range(down_block_layers[block]): + # Convert resnets + new_state_dict[ + f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight" + ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight") + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.stack.0.bias" + ) + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.stack.2.weight" + ) + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.stack.2.bias" + ) + new_state_dict[ + f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight" + ] = 
encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight") + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.stack.3.bias" + ) + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.stack.5.weight" + ) + new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.stack.5.bias" + ) + + # Convert attentions + qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q + new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k + new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v + new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight" + ) + new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias" + ) + new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight" + ) + new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias" + ) + + # Convert block_out (MochiMidBlock3D) + for i in range(3): # layers_per_block[-1] = 3 + # Convert resnets + new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.0.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.0.bias" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.2.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.2.bias" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.3.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.3.bias" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.5.weight" + ) + new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop( + f"layers.{i+7}.stack.5.bias" + ) + + # Convert attentions + qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q + new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k + new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v + new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop( + f"layers.{i+7}.attn_block.attn.out.weight" + ) + new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop( + f"layers.{i+7}.attn_block.attn.out.bias" + ) + new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop( + 
f"layers.{i+7}.attn_block.norm.weight" + ) + new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop( + f"layers.{i+7}.attn_block.norm.bias" + ) + + # Convert output layers + new_state_dict[f"{prefix}norm_out.norm_layer.weight"] = encoder_state_dict.pop("output_norm.weight") + new_state_dict[f"{prefix}norm_out.norm_layer.bias"] = encoder_state_dict.pop("output_norm.bias") + new_state_dict[f"{prefix}proj_out.weight"] = encoder_state_dict.pop("output_proj.weight") + + print("Remaining Encoder Keys:", encoder_state_dict.keys()) + + return new_state_dict + + +def main(args): + if args.dtype is None: + dtype = None + if args.dtype == "fp16": + dtype = torch.float16 + elif args.dtype == "bf16": + dtype = torch.bfloat16 + elif args.dtype == "fp32": + dtype = torch.float32 + else: + raise ValueError(f"Unsupported dtype: {args.dtype}") + + transformer = None + vae = None + + if args.transformer_checkpoint_path is not None: + converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers( + args.transformer_checkpoint_path + ) + transformer = MochiTransformer3DModel() + transformer.load_state_dict(converted_transformer_state_dict, strict=True) + if dtype is not None: + transformer = transformer.to(dtype=dtype) + + if args.vae_encoder_checkpoint_path is not None and args.vae_decoder_checkpoint_path is not None: + vae = AutoencoderKLMochi(latent_channels=12, out_channels=3) + converted_vae_state_dict = convert_mochi_vae_state_dict_to_diffusers( + args.vae_encoder_checkpoint_path, args.vae_decoder_checkpoint_path + ) + vae.load_state_dict(converted_vae_state_dict, strict=True) + if dtype is not None: + vae = vae.to(dtype=dtype) + + text_encoder_id = "google/t5-v1_1-xxl" + tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) + text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + + # Apparently, the conversion does not work anymore without this :shrug: + for param in text_encoder.parameters(): + param.data = param.data.contiguous() + + pipe = MochiPipeline( + scheduler=FlowMatchEulerDiscreteScheduler(invert_sigmas=True), + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub) + + +if __name__ == "__main__": + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ff59a3839552..fb6d22084bd6 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -83,6 +83,7 @@ "AutoencoderKL", "AutoencoderKLAllegro", "AutoencoderKLCogVideoX", + "AutoencoderKLMochi", "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", "AutoencoderTiny", @@ -102,6 +103,7 @@ "Kandinsky3UNet", "LatteTransformer3DModel", "LuminaNextDiT2DModel", + "MochiTransformer3DModel", "ModelMixin", "MotionAdapter", "MultiAdapter", @@ -311,6 +313,7 @@ "LuminaText2ImgPipeline", "MarigoldDepthPipeline", "MarigoldNormalsPipeline", + "MochiPipeline", "MusicLDMPipeline", "PaintByExamplePipeline", "PIAPipeline", @@ -565,6 +568,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLMochi, AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, @@ -584,6 +588,7 @@ Kandinsky3UNet, LatteTransformer3DModel, LuminaNextDiT2DModel, + MochiTransformer3DModel, ModelMixin, MotionAdapter, MultiAdapter, @@ -772,6 +777,7 @@ LuminaText2ImgPipeline, MarigoldDepthPipeline, 
MarigoldNormalsPipeline, + MochiPipeline, MusicLDMPipeline, PaintByExamplePipeline, PIAPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 38dd2819133d..518ab6df65c4 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -30,6 +30,7 @@ _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] + _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] @@ -58,6 +59,7 @@ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] + _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] @@ -85,6 +87,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLMochi, AutoencoderKLTemporalDecoder, AutoencoderOobleck, AutoencoderTiny, @@ -110,6 +113,7 @@ HunyuanDiT2DModel, LatteTransformer3DModel, LuminaNextDiT2DModel, + MochiTransformer3DModel, PixArtTransformer2DModel, PriorTransformer, SD3Transformer2DModel, diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index fb24a36bae75..f4318fc3cd39 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -136,6 +136,7 @@ class SwiGLU(nn.Module): def __init__(self, dim_in: int, dim_out: int, bias: bool = True): super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) self.activation = nn.SiLU() diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 20c5cf3d925e..da01b7a1edcd 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -120,14 +120,16 @@ def __init__( _from_deprecated_attn_block: bool = False, processor: Optional["AttnProcessor"] = None, out_dim: int = None, + out_context_dim: int = None, context_pre_only=None, pre_only=False, elementwise_affine: bool = True, + is_causal: bool = False, ): super().__init__() # To prevent circular import. 
- from .normalization import FP32LayerNorm, RMSNorm + from .normalization import FP32LayerNorm, LpNorm, RMSNorm self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads @@ -142,8 +144,10 @@ def __init__( self.dropout = dropout self.fused_projections = False self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim self.context_pre_only = context_pre_only self.pre_only = pre_only + self.is_causal = is_causal # we make use of this private variable to know whether this class is loaded # with an deprecated state dict so that we can convert it on the fly @@ -192,6 +196,9 @@ def __init__( elif qk_norm == "rms_norm": self.norm_q = RMSNorm(dim_head, eps=eps) self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "l2": + self.norm_q = LpNorm(p=2, dim=-1, eps=eps) + self.norm_k = LpNorm(p=2, dim=-1, eps=eps) else: raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'") @@ -241,7 +248,7 @@ def __init__( self.to_out.append(nn.Dropout(dropout)) if self.context_pre_only is not None and not self.context_pre_only: - self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) if qk_norm is not None and added_kv_proj_dim is not None: if qk_norm == "fp32_layer_norm": @@ -1886,6 +1893,7 @@ def __call__( hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states @@ -2714,6 +2722,91 @@ def __call__( return hidden_states +class MochiVaeAttnProcessor2_0: + r""" + Attention processor used in Mochi VAE. 
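+
+    Requires PyTorch 2.0's `scaled_dot_product_attention`. Causality follows `attn.is_causal`, and single-frame
+    inputs (sequence length 1) skip the attention computation and are passed through the value and output
+    projections.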
+ """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + is_single_frame = hidden_states.shape[1] == 1 + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if is_single_frame: + hidden_states = attn.to_v(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class StableAudioAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is @@ -3389,6 +3482,94 @@ def __call__( return hidden_states +class MochiAttnProcessor2_0: + """Attention processor used in Mochi.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. 
To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + if image_rotary_emb is not None: + + def apply_rotary_emb(x, freqs_cos, freqs_sin): + x_even = x[..., 0::2].float() + x_odd = x[..., 1::2].float() + + cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype) + sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) + + return torch.stack([cos, sin], dim=-1).flatten(-2) + + query = apply_rotary_emb(query, *image_rotary_emb) + key = apply_rotary_emb(key, *image_rotary_emb) + + query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) + encoder_query, encoder_key, encoder_value = ( + encoder_query.transpose(1, 2), + encoder_key.transpose(1, 2), + encoder_value.transpose(1, 2), + ) + + sequence_length = query.size(2) + encoder_sequence_length = encoder_query.size(2) + + query = torch.cat([query, encoder_query], dim=2) + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( + (sequence_length, encoder_sequence_length), dim=1 + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if hasattr(attn, "to_add_out"): + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + class FusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
It uses diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 9628fe7f21b0..ba45d6671252 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -2,6 +2,7 @@ from .autoencoder_kl import AutoencoderKL from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX +from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_oobleck import AutoencoderOobleck from .autoencoder_tiny import AutoencoderTiny diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 68b49d72acc5..8575c7658605 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -94,11 +94,13 @@ def __init__( time_kernel_size, height_kernel_size, width_kernel_size = kernel_size - self.pad_mode = pad_mode - time_pad = dilation * (time_kernel_size - 1) + (1 - stride) - height_pad = height_kernel_size // 2 - width_pad = width_kernel_size // 2 + # TODO(aryan): configure calculation based on stride and dilation in the future. + # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi + time_pad = time_kernel_size - 1 + height_pad = (height_kernel_size - 1) // 2 + width_pad = (width_kernel_size - 1) // 2 + self.pad_mode = pad_mode self.height_pad = height_pad self.width_pad = width_pad self.time_pad = time_pad @@ -107,7 +109,7 @@ def __init__( self.temporal_dim = 2 self.time_kernel_size = time_kernel_size - stride = (stride, 1, 1) + stride = stride if isinstance(stride, tuple) else (stride, 1, 1) dilation = (dilation, 1, 1) self.conv = CogVideoXSafeConv3d( in_channels=in_channels, @@ -120,18 +122,24 @@ def __init__( def fake_context_parallel_forward( self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None ) -> torch.Tensor: - kernel_size = self.time_kernel_size - if kernel_size > 1: - cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) - inputs = torch.cat(cached_inputs + [inputs], dim=2) + if self.pad_mode == "replicate": + inputs = F.pad(inputs, self.time_causal_padding, mode="replicate") + else: + kernel_size = self.time_kernel_size + if kernel_size > 1: + cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) + inputs = torch.cat(cached_inputs + [inputs], dim=2) return inputs def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor: inputs = self.fake_context_parallel_forward(inputs, conv_cache) - conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() - padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - inputs = F.pad(inputs, padding_2d, mode="constant", value=0) + if self.pad_mode == "replicate": + conv_cache = None + else: + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() + inputs = F.pad(inputs, padding_2d, mode="constant", value=0) output = self.conv(inputs) return output, conv_cache diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py new file mode 100644 index 000000000000..57e8b8f647ba --- /dev/null +++ 
b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -0,0 +1,1165 @@ +# Copyright 2024 The Mochi team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..attention_processor import Attention, MochiVaeAttnProcessor2_0 +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MochiChunkedGroupNorm3D(nn.Module): + r""" + Applies per-frame group normalization for 5D video inputs. It also supports memory-efficient chunked group + normalization. + + Args: + num_channels (int): Number of channels expected in input + num_groups (int, optional): Number of groups to separate the channels into. Default: 32 + affine (bool, optional): If True, this module has learnable affine parameters. Default: True + chunk_size (int, optional): Size of each chunk for processing. Default: 8 + + """ + + def __init__( + self, + num_channels: int, + num_groups: int = 32, + affine: bool = True, + chunk_size: int = 8, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine) + self.chunk_size = chunk_size + + def forward(self, x: torch.Tensor = None) -> torch.Tensor: + batch_size = x.size(0) + + x = x.permute(0, 2, 1, 3, 4).flatten(0, 1) + output = torch.cat([self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)], dim=0) + output = output.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + return output + + +class MochiResnetBlock3D(nn.Module): + r""" + A 3D ResNet block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + act_fn: str = "swish", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.in_channels = in_channels + self.out_channels = out_channels + self.nonlinearity = get_activation(act_fn) + + self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels) + self.conv1 = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate" + ) + self.norm2 = MochiChunkedGroupNorm3D(num_channels=out_channels) + self.conv2 = CogVideoXCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate" + ) + + def forward( + self, + inputs: torch.Tensor, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states = inputs + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1")) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2")) + + hidden_states = hidden_states + inputs + return hidden_states, new_conv_cache + + +class MochiDownBlock3D(nn.Module): + r""" + An downsampling block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet blocks in the block. + temporal_expansion (`int`, defaults to `2`): + Temporal expansion factor. + spatial_expansion (`int`, defaults to `2`): + Spatial expansion factor. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + temporal_expansion: int = 2, + spatial_expansion: int = 2, + add_attention: bool = True, + ): + super().__init__() + self.temporal_expansion = temporal_expansion + self.spatial_expansion = spatial_expansion + + self.conv_in = CogVideoXCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(temporal_expansion, spatial_expansion, spatial_expansion), + stride=(temporal_expansion, spatial_expansion, spatial_expansion), + pad_mode="replicate", + ) + + resnets = [] + norms = [] + attentions = [] + for _ in range(num_layers): + resnets.append(MochiResnetBlock3D(in_channels=out_channels)) + if add_attention: + norms.append(MochiChunkedGroupNorm3D(num_channels=out_channels)) + attentions.append( + Attention( + query_dim=out_channels, + heads=out_channels // 32, + dim_head=32, + qk_norm="l2", + is_causal=True, + processor=MochiVaeAttnProcessor2_0(), + ) + ) + else: + norms.append(None) + attentions.append(None) + + self.resnets = nn.ModuleList(resnets) + self.norms = nn.ModuleList(norms) + self.attentions = nn.ModuleList(attentions) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + chunk_size: int = 2**15, + ) -> torch.Tensor: + r"""Forward method of the `MochiUpBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states, new_conv_cache["conv_in"] = self.conv_in(hidden_states) + + for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)): + conv_cache_key = f"resnet_{i}" + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + conv_cache=conv_cache.get(conv_cache_key), + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + if attn is not None: + residual = hidden_states + hidden_states = norm(hidden_states) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous() + + # Perform attention in chunks to avoid following error: + # RuntimeError: CUDA error: invalid configuration argument + if hidden_states.size(0) <= chunk_size: + hidden_states = attn(hidden_states) + else: + hidden_states_chunks = [] + for i in range(0, hidden_states.size(0), chunk_size): + hidden_states_chunk = hidden_states[i : i + chunk_size] + hidden_states_chunk = attn(hidden_states_chunk) + hidden_states_chunks.append(hidden_states_chunk) + hidden_states = torch.cat(hidden_states_chunks) + + hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2) + + hidden_states = residual + hidden_states + + return hidden_states, new_conv_cache + + +class MochiMidBlock3D(nn.Module): + r""" + A middle block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + num_layers (`int`, defaults to `3`): + Number of resnet blocks in the block. 
+ """ + + def __init__( + self, + in_channels: int, # 768 + num_layers: int = 3, + add_attention: bool = True, + ): + super().__init__() + + resnets = [] + norms = [] + attentions = [] + + for _ in range(num_layers): + resnets.append(MochiResnetBlock3D(in_channels=in_channels)) + + if add_attention: + norms.append(MochiChunkedGroupNorm3D(num_channels=in_channels)) + attentions.append( + Attention( + query_dim=in_channels, + heads=in_channels // 32, + dim_head=32, + qk_norm="l2", + is_causal=True, + processor=MochiVaeAttnProcessor2_0(), + ) + ) + else: + norms.append(None) + attentions.append(None) + + self.resnets = nn.ModuleList(resnets) + self.norms = nn.ModuleList(norms) + self.attentions = nn.ModuleList(attentions) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + r"""Forward method of the `MochiMidBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)): + conv_cache_key = f"resnet_{i}" + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + if attn is not None: + residual = hidden_states + hidden_states = norm(hidden_states) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous() + hidden_states = attn(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2) + + hidden_states = residual + hidden_states + + return hidden_states, new_conv_cache + + +class MochiUpBlock3D(nn.Module): + r""" + An upsampling block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet blocks in the block. + temporal_expansion (`int`, defaults to `2`): + Temporal expansion factor. + spatial_expansion (`int`, defaults to `2`): + Spatial expansion factor. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + temporal_expansion: int = 2, + spatial_expansion: int = 2, + ): + super().__init__() + self.temporal_expansion = temporal_expansion + self.spatial_expansion = spatial_expansion + + resnets = [] + for _ in range(num_layers): + resnets.append(MochiResnetBlock3D(in_channels=in_channels)) + self.resnets = nn.ModuleList(resnets) + + self.proj = nn.Linear(in_channels, out_channels * temporal_expansion * spatial_expansion**2) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + conv_cache: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + r"""Forward method of the `MochiUpBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, resnet in enumerate(self.resnets): + conv_cache_key = f"resnet_{i}" + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + conv_cache=conv_cache.get(conv_cache_key), + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + hidden_states = hidden_states.permute(0, 2, 3, 4, 1) + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + st = self.temporal_expansion + sh = self.spatial_expansion + sw = self.spatial_expansion + + # Reshape and unpatchify + hidden_states = hidden_states.view(batch_size, -1, st, sh, sw, num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + hidden_states = hidden_states.view(batch_size, -1, num_frames * st, height * sh, width * sw) + + return hidden_states, new_conv_cache + + +class FourierFeatures(nn.Module): + def __init__(self, start: int = 6, stop: int = 8, step: int = 1): + super().__init__() + + self.start = start + self.stop = stop + self.step = step + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + r"""Forward method of the `FourierFeatures` class.""" + + num_channels = inputs.shape[1] + num_freqs = (self.stop - self.start) // self.step + + freqs = torch.arange(self.start, self.stop, self.step, dtype=inputs.dtype, device=inputs.device) + w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs] + w = w.repeat(num_channels)[None, :, None, None, None] # [1, num_channels * num_freqs, 1, 1, 1] + + # Interleaved repeat of input channels to match w + h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W] + # Scale channels by frequency. + h = w * h + + return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1) + + +class MochiEncoder3D(nn.Module): + r""" + The `MochiEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent + representation. + + Args: + in_channels (`int`, *optional*): + The number of input channels. + out_channels (`int`, *optional*): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`): + The number of output channels for each block. + layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`): + The number of resnet blocks for each block. 
+ temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`): + The temporal expansion factor for each of the up blocks. + spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`): + The spatial expansion factor for each of the up blocks. + non_linearity (`str`, *optional*, defaults to `"swish"`): + The non-linearity to use in the decoder. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + block_out_channels: Tuple[int, ...] = (128, 256, 512, 768), + layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), + temporal_expansions: Tuple[int, ...] = (1, 2, 3), + spatial_expansions: Tuple[int, ...] = (2, 2, 2), + add_attention_block: Tuple[bool, ...] = (False, True, True, True, True), + act_fn: str = "swish", + ): + super().__init__() + + self.nonlinearity = get_activation(act_fn) + + self.fourier_features = FourierFeatures() + self.proj_in = nn.Linear(in_channels, block_out_channels[0]) + self.block_in = MochiMidBlock3D( + in_channels=block_out_channels[0], num_layers=layers_per_block[0], add_attention=add_attention_block[0] + ) + + down_blocks = [] + for i in range(len(block_out_channels) - 1): + down_block = MochiDownBlock3D( + in_channels=block_out_channels[i], + out_channels=block_out_channels[i + 1], + num_layers=layers_per_block[i + 1], + temporal_expansion=temporal_expansions[i], + spatial_expansion=spatial_expansions[i], + add_attention=add_attention_block[i + 1], + ) + down_blocks.append(down_block) + self.down_blocks = nn.ModuleList(down_blocks) + + self.block_out = MochiMidBlock3D( + in_channels=block_out_channels[-1], num_layers=layers_per_block[-1], add_attention=add_attention_block[-1] + ) + self.norm_out = MochiChunkedGroupNorm3D(block_out_channels[-1]) + self.proj_out = nn.Linear(block_out_channels[-1], 2 * out_channels, bias=False) + + def forward( + self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None + ) -> torch.Tensor: + r"""Forward method of the `MochiEncoder3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states = self.fourier_features(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 3, 4, 1) + hidden_states = self.proj_in(hidden_states) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in") + ) + + for i, down_block in enumerate(self.down_blocks): + conv_cache_key = f"down_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + else: + hidden_states, new_conv_cache["block_in"] = self.block_in( + hidden_states, conv_cache=conv_cache.get("block_in") + ) + + for i, down_block in enumerate(self.down_blocks): + conv_cache_key = f"down_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = down_block( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + hidden_states, new_conv_cache["block_out"] = self.block_out( + hidden_states, conv_cache=conv_cache.get("block_out") + ) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 3, 4, 1) + 
hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) + + return hidden_states, new_conv_cache + + +class MochiDecoder3D(nn.Module): + r""" + The `MochiDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, *optional*): + The number of input channels. + out_channels (`int`, *optional*): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`): + The number of output channels for each block. + layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`): + The number of resnet blocks for each block. + temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`): + The temporal expansion factor for each of the up blocks. + spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`): + The spatial expansion factor for each of the up blocks. + non_linearity (`str`, *optional*, defaults to `"swish"`): + The non-linearity to use in the decoder. + """ + + def __init__( + self, + in_channels: int, # 12 + out_channels: int, # 3 + block_out_channels: Tuple[int, ...] = (128, 256, 512, 768), + layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), + temporal_expansions: Tuple[int, ...] = (1, 2, 3), + spatial_expansions: Tuple[int, ...] = (2, 2, 2), + act_fn: str = "swish", + ): + super().__init__() + + self.nonlinearity = get_activation(act_fn) + + self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1)) + self.block_in = MochiMidBlock3D( + in_channels=block_out_channels[-1], + num_layers=layers_per_block[-1], + add_attention=False, + ) + + up_blocks = [] + for i in range(len(block_out_channels) - 1): + up_block = MochiUpBlock3D( + in_channels=block_out_channels[-i - 1], + out_channels=block_out_channels[-i - 2], + num_layers=layers_per_block[-i - 2], + temporal_expansion=temporal_expansions[-i - 1], + spatial_expansion=spatial_expansions[-i - 1], + ) + up_blocks.append(up_block) + self.up_blocks = nn.ModuleList(up_blocks) + + self.block_out = MochiMidBlock3D( + in_channels=block_out_channels[0], + num_layers=layers_per_block[0], + add_attention=False, + ) + self.proj_out = nn.Linear(block_out_channels[0], out_channels) + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None + ) -> torch.Tensor: + r"""Forward method of the `MochiDecoder3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states = self.conv_in(hidden_states) + + # 1. 
Mid + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states, new_conv_cache["block_in"] = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.block_in), hidden_states, conv_cache=conv_cache.get("block_in") + ) + + for i, up_block in enumerate(self.up_blocks): + conv_cache_key = f"up_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + else: + hidden_states, new_conv_cache["block_in"] = self.block_in( + hidden_states, conv_cache=conv_cache.get("block_in") + ) + + for i, up_block in enumerate(self.up_blocks): + conv_cache_key = f"up_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = up_block( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + hidden_states, new_conv_cache["block_out"] = self.block_out( + hidden_states, conv_cache=conv_cache.get("block_out") + ) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 3, 4, 1) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) + + return hidden_states, new_conv_cache + + +class AutoencoderKLMochi(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [Mochi 1 preview](https://github.com/genmoai/models). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + scaling_factor (`float`, *optional*, defaults to `1.15258426`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["MochiResnetBlock3D"] + + @register_to_config + def __init__( + self, + in_channels: int = 15, + out_channels: int = 3, + encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384), + decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768), + latent_channels: int = 12, + layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), + act_fn: str = "silu", + temporal_expansions: Tuple[int, ...] = (1, 2, 3), + spatial_expansions: Tuple[int, ...] = (2, 2, 2), + add_attention_block: Tuple[bool, ...] = (False, True, True, True, True), + latents_mean: Tuple[float, ...] 
= (
+            -0.06730895953510081,
+            -0.038011381506090416,
+            -0.07477820912866141,
+            -0.05565264470995561,
+            0.012767231469026969,
+            -0.04703542746246419,
+            0.043896967884726704,
+            -0.09346305707025976,
+            -0.09918314763016893,
+            -0.008729793427399178,
+            -0.011931556316503654,
+            -0.0321993391887285,
+        ),
+        latents_std: Tuple[float, ...] = (
+            0.9263795028493863,
+            0.9248894543193766,
+            0.9393059390890617,
+            0.959253732819592,
+            0.8244560132752793,
+            0.917259975397747,
+            0.9294154431013696,
+            1.3720942357788521,
+            0.881393668867029,
+            0.9168315692124348,
+            0.9185249279345552,
+            0.9274757570805041,
+        ),
+        scaling_factor: float = 1.0,
+    ):
+        super().__init__()
+
+        self.encoder = MochiEncoder3D(
+            in_channels=in_channels,
+            out_channels=latent_channels,
+            block_out_channels=encoder_block_out_channels,
+            layers_per_block=layers_per_block,
+            temporal_expansions=temporal_expansions,
+            spatial_expansions=spatial_expansions,
+            add_attention_block=add_attention_block,
+            act_fn=act_fn,
+        )
+        self.decoder = MochiDecoder3D(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            block_out_channels=decoder_block_out_channels,
+            layers_per_block=layers_per_block,
+            temporal_expansions=temporal_expansions,
+            spatial_expansions=spatial_expansions,
+            act_fn=act_fn,
+        )
+
+        self.spatial_compression_ratio = functools.reduce(lambda x, y: x * y, spatial_expansions, 1)
+        self.temporal_compression_ratio = functools.reduce(lambda x, y: x * y, temporal_expansions, 1)
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
+        self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
+        # at a fixed frame batch size (based on `self.num_latent_frames_batch_size`), the memory requirement can be lowered.
+        self.use_framewise_encoding = False
+        self.use_framewise_decoding = False
+
+        # This flag controls how the number of output frames in the final decoded video is computed. To maintain
+        # consistency with the original implementation, it defaults to `True`.
+        # - Original implementation (drop_last_temporal_frames=True):
+        #   Output frames = (latent_frames - 1) * temporal_compression_ratio + 1
+        # - Without dropping additional temporal upscaled frames (drop_last_temporal_frames=False):
+        #   Output frames = latent_frames * temporal_compression_ratio
+        # The latter case is useful for frame packing and some training/finetuning scenarios where the additional
+        # frames are needed.
+        self.drop_last_temporal_frames = True
+
+        # This can be configured based on the amount of GPU memory available.
+        # `12` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.
+        # Setting it to higher values results in higher memory usage.
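+        # With the default temporal expansions (1, 2, 3), the temporal compression ratio is 6, so a batch of
+        # 12 sample frames corresponds to 2 latent frames, which is why the two defaults below are paired.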
+ self.num_sample_frames_batch_size = 12 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (MochiEncoder3D, MochiDecoder3D)): + module.gradient_checkpointing = value + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _enable_framewise_encoding(self): + r""" + Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the + oneshot encoding implementation without current latent replicate padding. + + Warning: Framewise encoding may not work as expected due to the causal attention layers. If you enable + framewise encoding, encode a video, and try to decode it, there will be noticeable jittering effect. 
+ """ + self.use_framewise_encoding = True + for name, module in self.named_modules(): + if isinstance(module, CogVideoXCausalConv3d): + module.pad_mode = "constant" + + def _enable_framewise_decoding(self): + r""" + Enables the framewise VAE decoding implementation with past latent padding. By default, Diffusers uses the + oneshot decoding implementation without current latent replicate padding. + """ + self.use_framewise_decoding = True + for name, module in self.named_modules(): + if isinstance(module, CogVideoXCausalConv3d): + module.pad_mode = "constant" + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + if self.use_framewise_encoding: + raise NotImplementedError( + "Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. " + "As intermediate frames are not independent from each other, they cannot be encoded frame-wise." + ) + else: + enc, _ = self.encoder(x) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + if self.use_framewise_decoding: + conv_cache = None + dec = [] + + for i in range(0, num_frames, self.num_latent_frames_batch_size): + z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size] + z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache) + dec.append(z_intermediate) + + dec = torch.cat(dec, dim=2) + else: + dec, _ = self.decoder(z) + + if self.drop_last_temporal_frames and dec.size(2) >= self.temporal_compression_ratio: + dec = dec[:, :, self.temporal_compression_ratio - 1 :] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. 
+ + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + if self.use_framewise_encoding: + raise NotImplementedError( + "Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. " + "As intermediate frames are not independent from each other, they cannot be encoded frame-wise." + ) + else: + time, _ = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. 
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+
+        batch_size, num_channels, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                if self.use_framewise_decoding:
+                    time = []
+                    conv_cache = None
+
+                    for k in range(0, num_frames, self.num_latent_frames_batch_size):
+                        tile = z[
+                            :,
+                            :,
+                            k : k + self.num_latent_frames_batch_size,
+                            i : i + tile_latent_min_height,
+                            j : j + tile_latent_min_width,
+                        ]
+                        tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
+                        time.append(tile)
+
+                    time = torch.cat(time, dim=2)
+                else:
+                    time, _ = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
+
+                if self.drop_last_temporal_frames and time.size(2) >= self.temporal_compression_ratio:
+                    time = time[:, :, self.temporal_compression_ratio - 1 :]
+
+                row.append(time)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        x = sample
+        posterior = self.encode(x).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z).sample
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 66917dce6107..7cbd958e1d6e 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1356,6 +1356,41 @@ def forward(self, timestep, caption_feat, caption_mask):
         return conditioning
 
 
+class MochiCombinedTimestepCaptionEmbedding(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        pooled_projection_dim: int,
+        text_embed_dim: int,
+        time_embed_dim: int = 256,
+        
num_attention_heads: int = 8, + ) -> None: + super().__init__() + + self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0) + self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim) + self.pooler = MochiAttentionPool( + num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim + ) + self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim) + + def forward( + self, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + hidden_dtype: Optional[torch.dtype] = None, + ): + time_proj = self.time_proj(timestep) + time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype)) + + pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask) + caption_proj = self.caption_proj(encoder_hidden_states) + + conditioning = time_emb + pooled_projections + return conditioning, caption_proj + + class TextTimeEmbedding(nn.Module): def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): super().__init__() @@ -1484,6 +1519,88 @@ def shape(x): return a[:, 0, :] # cls_token +class MochiAttentionPool(nn.Module): + def __init__( + self, + num_attention_heads: int, + embed_dim: int, + output_dim: Optional[int] = None, + ) -> None: + super().__init__() + + self.output_dim = output_dim or embed_dim + self.num_attention_heads = num_attention_heads + + self.to_kv = nn.Linear(embed_dim, 2 * embed_dim) + self.to_q = nn.Linear(embed_dim, embed_dim) + self.to_out = nn.Linear(embed_dim, self.output_dim) + + @staticmethod + def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor: + """ + Pool tokens in x using mask. + + NOTE: We assume x does not require gradients. + + Args: + x: (B, L, D) tensor of tokens. + mask: (B, L) boolean tensor indicating which tokens are not padding. + + Returns: + pooled: (B, D) tensor of pooled tokens. + """ + assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens. + assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens. + mask = mask[:, :, None].to(dtype=x.dtype) + mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1) + pooled = (x * mask).sum(dim=1, keepdim=keepdim) + return pooled + + def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + r""" + Args: + x (`torch.Tensor`): + Tensor of shape `(B, S, D)` of input tokens. + mask (`torch.Tensor`): + Boolean ensor of shape `(B, S)` indicating which tokens are not padding. + + Returns: + `torch.Tensor`: + `(B, D)` tensor of pooled tokens. + """ + D = x.size(2) + + # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L). + attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L). + attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L). + + # Average non-padding token features. These will be used as the query. + x_pool = self.pool_tokens(x, mask, keepdim=True) # (B, 1, D) + + # Concat pooled features to input sequence. + x = torch.cat([x_pool, x], dim=1) # (B, L+1, D) + + # Compute queries, keys, values. Only the mean token is used to create a query. + kv = self.to_kv(x) # (B, L+1, 2 * D) + q = self.to_q(x[:, 0]) # (B, D) + + # Extract heads. 
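+        # `kv` is (B, 1+L, 2 * D): split the last dimension into (2, num_heads, head_dim) and swap the sequence
+        # and head axes so that `unbind` below yields per-head keys and values of shape (B, H, 1+L, head_dim).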
+ head_dim = D // self.num_attention_heads + kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim)) # (B, 1+L, 2, H, head_dim) + kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim) + k, v = kv.unbind(2) # (B, H, 1+L, head_dim) + q = q.unflatten(1, (self.num_attention_heads, head_dim)) # (B, H, head_dim) + q = q.unsqueeze(2) # (B, H, 1, head_dim) + + # Compute attention. + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) # (B, H, 1, head_dim) + + # Concatenate heads and run output. + x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim) + x = self.to_out(x) + return x + + def get_fourier_embeds_from_boundingbox(embed_dim, box): """ Args: diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 87dec66935da..817b3fff2ea6 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -234,6 +234,33 @@ def forward( return x, gate_msa, scale_mlp, gate_mlp +class MochiRMSNormZero(nn.Module): + r""" + Adaptive RMS Norm used in Mochi. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + def __init__( + self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False + ) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, hidden_dim) + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward( + self, hidden_states: torch.Tensor, emb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + + return hidden_states, gate_msa, scale_mlp, gate_mlp + + class AdaLayerNormSingle(nn.Module): r""" Norm layer adaptive layer norm single (adaLN-single). 
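For reference, the masked mean in `MochiAttentionPool.pool_tokens` above weights every non-padding token equally and ignores padded positions. A minimal standalone sketch of the same arithmetic (plain PyTorch, hypothetical tensor sizes) with a brute-force check:

```py
import torch

# Hypothetical sizes: batch of 2 sequences, length 4, feature dim 3.
x = torch.randn(2, 4, 3)
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]], dtype=torch.bool)

# Same arithmetic as `pool_tokens`: zero out padded tokens, then divide by the
# number of valid tokens (clamped to avoid division by zero).
m = mask[:, :, None].to(x.dtype)
pooled = (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)

# Brute-force reference: average only the valid tokens of each sequence.
reference = torch.stack([x[i, mask[i]].mean(dim=0) for i in range(x.size(0))])
assert torch.allclose(pooled, reference, atol=1e-6)
```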
@@ -356,20 +383,21 @@ def __init__( out_dim: Optional[int] = None, ): super().__init__() + # AdaLN self.silu = nn.SiLU() self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + if norm_type == "layer_norm": self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) else: raise ValueError(f"unknown norm_type {norm_type}") - # linear_2 + + self.linear_2 = None if out_dim is not None: - self.linear_2 = nn.Linear( - embedding_dim, - out_dim, - bias=bias, - ) + self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias) def forward( self, @@ -526,3 +554,15 @@ def forward(self, x): gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6) return self.gamma * (x * nx) + self.beta + x + + +class LpNorm(nn.Module): + def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12): + super().__init__() + + self.p = p + self.dim = dim + self.eps = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 873a2bbecf05..a2c087d708a4 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -17,5 +17,6 @@ from .transformer_allegro import AllegroTransformer3DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel + from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py new file mode 100644 index 000000000000..7f4ad2b328fa --- /dev/null +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -0,0 +1,387 @@ +# Copyright 2024 The Genmo team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
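As a quick reference for the `LpNorm` helper added to `normalization.py` above: it simply delegates to `torch.nn.functional.normalize`. A small sketch with hypothetical tensor sizes:

```py
import torch
import torch.nn.functional as F

hidden_states = torch.randn(2, 8, 16)

# LpNorm(p=2, dim=-1, eps=1e-12) rescales each feature vector along the last dimension to unit L2 norm.
normalized = F.normalize(hidden_states, p=2.0, dim=-1, eps=1e-12)
assert torch.allclose(normalized.norm(p=2, dim=-1), torch.ones(2, 8), atol=1e-5)
```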
+ +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_processor import Attention, MochiAttnProcessor2_0 +from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class MochiTransformerBlock(nn.Module): + r""" + Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview). + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + activation_fn (`str`, defaults to `"swiglu"`): + Activation function to use in feed-forward. + context_pre_only (`bool`, defaults to `False`): + Whether or not to process context-related conditions with additional layers. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + pooled_projection_dim: int, + qk_norm: str = "rms_norm", + activation_fn: str = "swiglu", + context_pre_only: bool = False, + eps: float = 1e-6, + ) -> None: + super().__init__() + + self.context_pre_only = context_pre_only + self.ff_inner_dim = (4 * dim * 2) // 3 + self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3 + + self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False) + + if not context_pre_only: + self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False) + else: + self.norm1_context = LuminaLayerNormContinuous( + embedding_dim=pooled_projection_dim, + conditioning_embedding_dim=dim, + eps=eps, + elementwise_affine=False, + norm_type="rms_norm", + out_dim=None, + ) + + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=False, + qk_norm=qk_norm, + added_kv_proj_dim=pooled_projection_dim, + added_proj_bias=False, + out_dim=dim, + out_context_dim=pooled_projection_dim, + context_pre_only=context_pre_only, + processor=MochiAttnProcessor2_0(), + eps=eps, + elementwise_affine=True, + ) + + # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True + self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + + self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + + self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) + self.ff_context = None + if not context_pre_only: + self.ff_context = FeedForward( + pooled_projection_dim, + inner_dim=self.ff_context_inner_dim, + activation_fn=activation_fn, + bias=False, + ) + + self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False) + self.norm4_context = 
RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + + if not self.context_pre_only: + norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context( + encoder_hidden_states, temb + ) + else: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + + attn_hidden_states, context_attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1) + norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1) + + if not self.context_pre_only: + encoder_hidden_states = encoder_hidden_states + self.norm2_context( + context_attn_hidden_states + ) * torch.tanh(enc_gate_msa).unsqueeze(1) + norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1)) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh( + enc_gate_mlp + ).unsqueeze(1) + + return hidden_states, encoder_hidden_states + + +class MochiRoPE(nn.Module): + r""" + RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview). + + Args: + base_height (`int`, defaults to `192`): + Base height used to compute interpolation scale for rotary positional embeddings. + base_width (`int`, defaults to `192`): + Base width used to compute interpolation scale for rotary positional embeddings. 
+ """ + + def __init__(self, base_height: int = 192, base_width: int = 192) -> None: + super().__init__() + + self.target_area = base_height * base_width + + def _centers(self, start, stop, num, device, dtype) -> torch.Tensor: + edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype) + return (edges[:-1] + edges[1:]) / 2 + + def _get_positions( + self, + num_frames: int, + height: int, + width: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + scale = (self.target_area / (height * width)) ** 0.5 + + t = torch.arange(num_frames, device=device, dtype=dtype) + h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype) + w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype) + + grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") + + positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3) + return positions + + def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: + freqs = torch.einsum("nd,dhf->nhf", pos, freqs.float()) + freqs_cos = torch.cos(freqs) + freqs_sin = torch.sin(freqs) + return freqs_cos, freqs_sin + + def forward( + self, + pos_frequencies: torch.Tensor, + num_frames: int, + height: int, + width: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + pos = self._get_positions(num_frames, height, width, device, dtype) + rope_cos, rope_sin = self._create_rope(pos_frequencies, pos) + return rope_cos, rope_sin + + +@maybe_allow_in_graph +class MochiTransformer3DModel(ModelMixin, ConfigMixin): + r""" + A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview). + + Args: + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + num_layers (`int`, defaults to `48`): + The number of layers of Transformer blocks to use. + in_channels (`int`, defaults to `12`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + time_embed_dim (`int`, defaults to `256`): + Output dimension of timestep embeddings. + activation_fn (`str`, defaults to `"swiglu"`): + Activation function to use in feed-forward. + max_sequence_length (`int`, defaults to `256`): + The maximum sequence length of text embeddings supported. 
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + patch_size: int = 2, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 48, + pooled_projection_dim: int = 1536, + in_channels: int = 12, + out_channels: Optional[int] = None, + qk_norm: str = "rms_norm", + text_embed_dim: int = 4096, + time_embed_dim: int = 256, + activation_fn: str = "swiglu", + max_sequence_length: int = 256, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + pos_embed_type=None, + ) + + self.time_embed = MochiCombinedTimestepCaptionEmbedding( + embedding_dim=inner_dim, + pooled_projection_dim=pooled_projection_dim, + text_embed_dim=text_embed_dim, + time_embed_dim=time_embed_dim, + num_attention_heads=8, + ) + + self.pos_frequencies = nn.Parameter(torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0)) + self.rope = MochiRoPE() + + self.transformer_blocks = nn.ModuleList( + [ + MochiTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + pooled_projection_dim=pooled_projection_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + context_pre_only=i == num_layers - 1, + ) + for i in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm" + ) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: torch.Tensor, + return_dict: bool = True, + ) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p = self.config.patch_size + + post_patch_height = height // p + post_patch_width = width // p + + temb, encoder_hidden_states = self.time_embed( + timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype + ) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.patch_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) + + image_rotary_emb = self.rope( + self.pos_frequencies, + num_frames, + post_patch_height, + post_patch_width, + device=hidden_states.device, + dtype=torch.float32, + ) + + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = 
self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) + hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) + output = hidden_states.reshape(batch_size, -1, num_frames, height, width) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 634088f1b51a..98574de1ad5f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -247,6 +247,7 @@ "MarigoldNormalsPipeline", ] ) + _import_structure["mochi"] = ["MochiPipeline"] _import_structure["musicldm"] = ["MusicLDMPipeline"] _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] _import_structure["pia"] = ["PIAPipeline"] @@ -571,6 +572,7 @@ MarigoldDepthPipeline, MarigoldNormalsPipeline, ) + from .mochi import MochiPipeline from .musicldm import MusicLDMPipeline from .pag import ( AnimateDiffPAGPipeline, diff --git a/src/diffusers/pipelines/mochi/__init__.py b/src/diffusers/pipelines/mochi/__init__.py new file mode 100644 index 000000000000..a8fd4da9fd36 --- /dev/null +++ b/src/diffusers/pipelines/mochi/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_mochi"] = ["MochiPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_mochi import MochiPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py new file mode 100644 index 000000000000..7a9cc41e2dde --- /dev/null +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -0,0 +1,724 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
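The reshape/permute sequence at the end of `MochiTransformer3DModel.forward` above folds each token's `p * p * out_channels` values back into pixel space. A self-contained round-trip sketch (plain PyTorch, hypothetical sizes) showing that this un-patchify step inverts a straightforward p x p patchify; the patchify half is only illustrative, since in the model the input side goes through `PatchEmbed`:

```py
import torch

# Hypothetical sizes: batch 1, 3 channels, 2 frames, 4x6 pixels, patch size 2.
B, C, T, H, W, p = 1, 3, 2, 4, 6, 2
Hp, Wp = H // p, W // p
video = torch.arange(B * C * T * H * W, dtype=torch.float32).reshape(B, C, T, H, W)

# Illustrative patchify: per-frame p x p patches flattened into one token
# sequence of length T * Hp * Wp, channels ordered (p_h, p_w, C) per token.
tokens = video.permute(0, 2, 1, 3, 4)               # (B, T, C, H, W)
tokens = tokens.reshape(B, T, C, Hp, p, Wp, p)      # split spatial dims into patches
tokens = tokens.permute(0, 1, 3, 5, 4, 6, 2)        # (B, T, Hp, Wp, p, p, C)
tokens = tokens.reshape(B, T * Hp * Wp, p * p * C)  # token sequence

# Un-patchify with the same reshape/permute used at the end of the forward pass.
out = tokens.reshape(B, T, Hp, Wp, p, p, -1)
out = out.permute(0, 6, 1, 2, 4, 3, 5)              # (B, C, T, Hp, p, Wp, p)
out = out.reshape(B, -1, T, H, W)

assert torch.equal(out, video)
```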
+ +import inspect +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import MochiTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import MochiPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import MochiPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + >>> pipe.enable_vae_tiling() + >>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k." + >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0] + >>> export_to_video(frames, "mochi.mp4") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 +def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + sigma_schedule = [1.0 - x for x in sigma_schedule] + return sigma_schedule + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. 
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+class MochiPipeline(DiffusionPipeline):
+    r"""
+    The Mochi pipeline for text-to-video generation.
+
+    Reference: https://github.com/genmoai/models
+
+    Args:
+        transformer ([`MochiTransformer3DModel`]):
+            Conditional Transformer architecture to denoise the encoded video latents.
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        text_encoder ([`T5EncoderModel`]):
+            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+        tokenizer (`T5TokenizerFast`):
+            Tokenizer of class
+            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: MochiTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + # TODO: determine these scaling factors from model parameters + self.vae_spatial_scale_factor = 8 + self.vae_temporal_scale_factor = 6 + self.patch_size = 2 + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_height = 480 + self.default_width = 848 + + # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, 
+ ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. 
+ """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + num_frames, + dtype, + device, + generator, + latents=None, + ): + height = height // self.vae_spatial_scale_factor + width = width // self.vae_spatial_scale_factor + num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 19, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 4.5, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `self.default_height`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `self.default_width`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, defaults to `19`): + The number of video frames to generate + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. 
+            guidance_scale (`float`, defaults to `4.5`):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate videos that are closely linked to the text
+                `prompt`, usually at the expense of lower video quality.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for text embeddings.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. If not provided, negative_prompt_embeds will be generated from
+                the `negative_prompt` input argument.
+            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to `256`):
+                Maximum sequence length to use with the `prompt`.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
+                is returned where the first element is a list with the generated frames.
+        """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        height = height or self.default_height
+        width = width or self.default_width
+
+        # 1. Check inputs.
Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timestep + # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 + threshold_noise = 0.025 + sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) + sigmas = np.array(sigmas) + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        if output_type == "latent":
+            video = latents
+        else:
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            video = self.vae.decode(latents, return_dict=False)[0]
+            video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (video,)
+
+        return MochiPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/mochi/pipeline_output.py b/src/diffusers/pipelines/mochi/pipeline_output.py
new file mode 100644
index 000000000000..cc1437279496
--- /dev/null
+++ b/src/diffusers/pipelines/mochi/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class MochiPipelineOutput(BaseOutput):
+    r"""
+    Output class for Mochi pipelines.
+
+    Args:
+        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
+            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+ """ + + frames: torch.Tensor diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 937cae2e47f5..c1096dbe0c29 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -71,6 +71,7 @@ def __init__( max_shift: Optional[float] = 1.15, base_image_seq_len: Optional[int] = 256, max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, ): timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) @@ -204,9 +205,15 @@ def set_timesteps( sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) timesteps = sigmas * self.config.num_train_timesteps - self.timesteps = timesteps.to(device=device) - self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + self.timesteps = timesteps.to(device=device) + self.sigmas = sigmas self._step_index = None self._begin_index = None diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8a87b04a66cb..83d1d4270920 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -92,6 +92,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLMochi(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLTemporalDecoder(metaclass=DummyObject): _backends = ["torch"] @@ -377,6 +392,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class MochiTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ModelMixin(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 83d160b08df4..8b4b158efd0a 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1052,6 +1052,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class MochiPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class MusicLDMPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff 
--git a/tests/models/transformers/test_models_transformer_mochi.py b/tests/models/transformers/test_models_transformer_mochi.py new file mode 100644 index 000000000000..fc1412c7cd31 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_mochi.py @@ -0,0 +1,84 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import MochiTransformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class MochiTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = MochiTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + embedding_dim = 16 + sequence_length = 16 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + "encoder_attention_mask": encoder_attention_mask, + } + + @property + def input_shape(self): + return (4, 2, 16, 16) + + @property + def output_shape(self): + return (4, 2, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": 2, + "num_attention_heads": 2, + "attention_head_dim": 8, + "num_layers": 2, + "pooled_projection_dim": 16, + "in_channels": 4, + "out_channels": None, + "qk_norm": "rms_norm", + "text_embed_dim": 16, + "time_embed_dim": 4, + "activation_fn": "swiglu", + "max_sequence_length": 16, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"MochiTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/mochi/__init__.py b/tests/pipelines/mochi/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py new file mode 100644 index 000000000000..2192c171aa22 --- /dev/null +++ b/tests/pipelines/mochi/test_mochi.py @@ -0,0 +1,299 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = MochiPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = MochiTransformer3DModel( + patch_size=2, + num_attention_heads=2, + attention_head_dim=8, + num_layers=2, + pooled_projection_dim=16, + in_channels=12, + out_channels=None, + qk_norm="rms_norm", + text_embed_dim=32, + time_embed_dim=4, + activation_fn="swiglu", + max_sequence_length=16, + ) + transformer.pos_frequencies.data = transformer.pos_frequencies.new_full(transformer.pos_frequencies.shape, 0) + + torch.manual_seed(0) + vae = AutoencoderKLMochi( + latent_channels=12, + out_channels=3, + encoder_block_out_channels=(32, 32, 32, 32), + decoder_block_out_channels=(32, 32, 32, 32), + layers_per_block=(1, 1, 1, 1, 1), + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.5, + "height": 16, + "width": 16, + # 6 * k + 1 is the recommendation + "num_frames": 7, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (7, 3, 16, 16)) + 
expected_video = torch.randn(7, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = 
np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+            self.assertLess(
+                max(max_diff1, max_diff2),
+                expected_max_diff,
+                "Attention slicing should not affect the inference results",
+            )
+
+    def test_vae_tiling(self, expected_diff_max: float = 0.2):
+        generator_device = "cpu"
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe.to("cpu")
+        pipe.set_progress_bar_config(disable=None)
+
+        # Without tiling
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["height"] = inputs["width"] = 128
+        output_without_tiling = pipe(**inputs)[0]
+
+        # With tiling
+        pipe.vae.enable_tiling(
+            tile_sample_min_height=96,
+            tile_sample_min_width=96,
+            tile_sample_stride_height=64,
+            tile_sample_stride_width=64,
+        )
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["height"] = inputs["width"] = 128
+        output_with_tiling = pipe(**inputs)[0]
+
+        self.assertLess(
+            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+            expected_diff_max,
+            "VAE tiling should not affect the inference results",
+        )
+
+
+@slow
+@require_torch_gpu
+class MochiPipelineIntegrationTests(unittest.TestCase):
+    prompt = "A painting of a squirrel eating a burger."
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_mochi(self):
+        generator = torch.Generator("cpu").manual_seed(0)
+
+        pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float16)
+        pipe.enable_model_cpu_offload()
+        prompt = self.prompt
+
+        videos = pipe(
+            prompt=prompt,
+            height=480,
+            width=848,
+            num_frames=19,
+            generator=generator,
+            num_inference_steps=2,
+            output_type="pt",
+        ).frames
+
+        video = videos[0]
+        expected_video = torch.randn(1, 16, 480, 848, 3).numpy()
+
+        max_diff = numpy_cosine_similarity_distance(video, expected_video)
+        assert max_diff < 1e-3, f"Max diff is too high.
got {video}" From 08ac5cbc7f96d348464a84ef11e31be3e41c6826 Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Wed, 6 Nov 2024 02:35:20 +0530 Subject: [PATCH 19/20] [Fix] Test of sd3 lora (#9843) * fix test * fix test asser * fix format * Update test_lora_layers_sd3.py --- tests/lora/test_lora_layers_sd3.py | 40 +++--------------------------- 1 file changed, 3 insertions(+), 37 deletions(-) diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py index 78d4b786d21b..b37a2a297e04 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -166,48 +166,14 @@ def get_inputs(self, device, seed=0): def test_sd3_img2img_lora(self): pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16) - pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors") + pipe.load_lora_weights("zwloong/sd3-lora-training-rank16-v2", weight_name="pytorch_lora_weights.safetensors") pipe.enable_sequential_cpu_offload() inputs = self.get_inputs(torch_device) image = pipe(**inputs).images[0] - image_slice = image[0, :10, :10] - expected_slice = np.array( - [ - 0.47827148, - 0.5, - 0.71972656, - 0.3955078, - 0.4194336, - 0.69628906, - 0.37036133, - 0.40820312, - 0.6923828, - 0.36450195, - 0.40429688, - 0.6904297, - 0.35595703, - 0.39257812, - 0.68652344, - 0.35498047, - 0.3984375, - 0.68310547, - 0.34716797, - 0.3996582, - 0.6855469, - 0.3388672, - 0.3959961, - 0.6816406, - 0.34033203, - 0.40429688, - 0.6845703, - 0.34228516, - 0.4086914, - 0.6870117, - ] - ) - + image_slice = image[0, -3:, -3:] + expected_slice = np.array([0.5396, 0.5776, 0.7432, 0.5151, 0.5586, 0.7383, 0.5537, 0.5933, 0.7153]) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) assert max_diff < 1e-4, f"Outputs are not close enough, got {max_diff}" From a03bf4a531c69a77f3d0cfbb87fc0bd436b93176 Mon Sep 17 00:00:00 2001 From: Vahid Askari <90127147+vahidaskari@users.noreply.github.com> Date: Wed, 6 Nov 2024 02:07:11 +0330 Subject: [PATCH 20/20] Fix: Remove duplicated comma in distributed_inference.md (#9868) Fix: Remove duplicated comma Co-authored-by: Sayak Paul --- docs/source/en/training/distributed_inference.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index 0e1eb7962bf7..79b4f785f30c 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -183,7 +183,7 @@ Add the transformer model to the pipeline for denoising, but set the other model ```py pipeline = FluxPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", , + "black-forest-labs/FLUX.1-dev", text_encoder=None, text_encoder_2=None, tokenizer=None,