Commit

Merge pull request #140 from huggingface/main
Merge changes
Skquark authored Jan 22, 2024
2 parents b489020 + 5e96333 commit 90630f0
Showing 44 changed files with 3,767 additions and 537 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr_test_peft_backend.yml
@@ -59,7 +59,7 @@ jobs:
- name: Run fast PyTorch LoRA CPU tests with PEFT backend
run: |
-python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
+python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v \
--make-reports=tests_${{ matrix.config.report }} \
tests/lora/test_lora_layers_peft.py
4 changes: 2 additions & 2 deletions README.md
@@ -79,7 +79,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi

## Quickstart

-Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 16000+ checkpoints):
+Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 19000+ checkpoints):

```python
from diffusers import DiffusionPipeline
@@ -221,7 +221,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
- https://github.com/deep-floyd/IF
- https://github.com/bentoml/BentoML
- https://github.com/bmaltais/kohya_ss
-- +7000 other amazing GitHub repositories 💪
+- +8000 other amazing GitHub repositories 💪

Thank you for using us ❤️.

5 changes: 2 additions & 3 deletions docker/diffusers-pytorch-compile-cuda/Dockerfile
@@ -40,7 +40,6 @@ RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \
numpy \
scipy \
tensorboard \
-transformers \
omegaconf

+transformers

CMD ["/bin/bash"]
1 change: 0 additions & 1 deletion docker/diffusers-pytorch-cuda/Dockerfile
@@ -40,7 +40,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
scipy \
tensorboard \
transformers \
-omegaconf \
pytorch-lightning

CMD ["/bin/bash"]
1 change: 0 additions & 1 deletion docker/diffusers-pytorch-xformers-cuda/Dockerfile
@@ -40,7 +40,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
scipy \
tensorboard \
transformers \
-omegaconf \
xformers

CMD ["/bin/bash"]
3 changes: 3 additions & 0 deletions docs/source/en/api/models/autoencoderkl.md
@@ -33,6 +33,9 @@ model = AutoencoderKL.from_single_file(url)
## AutoencoderKL

[[autodoc]] AutoencoderKL
+  - decode
+  - encode
+  - all
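
For orientation, here is a minimal sketch of what the newly documented `encode` and `decode` methods do; the checkpoint name and tensor shapes are illustrative and not taken from this diff:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
image = torch.randn(1, 3, 512, 512)  # dummy image batch, values roughly in [-1, 1]

# encode returns a posterior distribution over the latent space
posterior = vae.encode(image).latent_dist
latents = posterior.sample()

# decode maps latents back to image space
reconstruction = vae.decode(latents).sample
```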

## AutoencoderKLOutput

58 changes: 58 additions & 0 deletions docs/source/en/api/pipelines/animatediff.md
@@ -235,6 +235,62 @@ export_to_gif(frames, "animation.gif")
</tr>
</table>

## Using FreeInit

[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://arxiv.org/abs/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu.

FreeInit is an effective method that improves the temporal consistency and overall quality of videos generated with video diffusion models, without any additional training. It can be applied seamlessly at inference time to AnimateDiff, ModelScope, VideoCrafter, and various other video generation models, and works by iteratively refining the initial latent noise. More details can be found in the paper.

The following example demonstrates the usage of FreeInit.

```python
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16).to("cuda")
pipe.scheduler = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
beta_schedule="linear",
clip_sample=False,
timestep_spacing="linspace",
steps_offset=1
)

# enable memory savings
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

# enable FreeInit
# Refer to the enable_free_init documentation for a full list of configurable parameters
pipe.enable_free_init(method="butterworth", use_fast_sampling=True)

# run inference
output = pipe(
prompt="a panda playing a guitar, on a boat, in the ocean, high quality",
negative_prompt="bad quality, worse quality",
num_frames=16,
guidance_scale=7.5,
num_inference_steps=20,
generator=torch.Generator("cpu").manual_seed(666),
)

# disable FreeInit
pipe.disable_free_init()

frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```

<Tip warning={true}>

FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times, depending on the `num_iters` parameter set when enabling it. Setting `use_fast_sampling=True` improves overall runtime at the cost of somewhat lower quality than `use_fast_sampling=False`, though the results are still better than those of the vanilla video generation model.

</Tip>
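
As a rough illustration of that tradeoff, using the parameters named above (the values here are arbitrary, not recommendations):

```python
# more refinement iterations: higher quality, proportionally more sampling passes
pipe.enable_free_init(method="butterworth", num_iters=5, use_fast_sampling=False)

# fast sampling: fewer effective steps per iteration, faster but somewhat lower quality
pipe.enable_free_init(method="butterworth", num_iters=3, use_fast_sampling=True)
```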

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.

@@ -248,6 +304,8 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- __call__
- enable_freeu
- disable_freeu
+  - enable_free_init
+  - disable_free_init
- enable_vae_slicing
- disable_vae_slicing
- enable_vae_tiling
2 changes: 2 additions & 0 deletions docs/source/en/installation.md
@@ -37,8 +37,10 @@ source .env/bin/activate

You should also install 🤗 Transformers because 🤗 Diffusers relies on its models:


<frameworkcontent>
<pt>
+Note - PyTorch only supports Python 3.8 - 3.11 on Windows.
```bash
pip install diffusers["torch"] transformers
```
3 changes: 2 additions & 1 deletion docs/source/en/using-diffusers/loading_adapters.md
@@ -344,7 +344,8 @@ pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-a
IP-Adapter relies on an image encoder to generate the image features. If your IP-Adapter weights folder contains an `image_encoder` subfolder, the image encoder is automatically loaded and registered to the pipeline. Otherwise, you can load a [`~transformers.CLIPVisionModelWithProjection`] model yourself and pass it to the Stable Diffusion pipeline when you create it.

```py
-from diffusers import AutoPipelineForText2Image, CLIPVisionModelWithProjection
+from diffusers import AutoPipelineForText2Image
+from transformers import CLIPVisionModelWithProjection
import torch

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
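    "h94/IP-Adapter",                  # NOTE: the diff truncates the original call here; this
    subfolder="models/image_encoder",  # completion is an editor's assumption based on the
    torch_dtype=torch.float16,         # IP-Adapter repo layout used elsewhere on this page
)
```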
2 changes: 1 addition & 1 deletion docs/source/en/using-diffusers/sdxl.md
@@ -26,7 +26,7 @@ Before you begin, make sure you have the following libraries installed:

```py
# uncomment to install the necessary libraries in Colab
-#!pip install -q diffusers transformers accelerate omegaconf invisible-watermark>=0.2.0
+#!pip install -q diffusers transformers accelerate invisible-watermark>=0.2.0
```

<Tip warning={true}>
2 changes: 1 addition & 1 deletion docs/source/en/using-diffusers/sdxl_turbo.md
@@ -23,7 +23,7 @@ Before you begin, make sure you have the following libraries installed:

```py
# uncomment to install the necessary libraries in Colab
-#!pip install -q diffusers transformers accelerate omegaconf
+#!pip install -q diffusers transformers accelerate
```

## Load model checkpoints
@@ -68,6 +68,7 @@
is_wandb_available,
)
from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1278,7 +1279,7 @@ def main(args):
for name, param in text_encoder_one.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-param = param.to(dtype=torch.float32)
+param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_one.append(param)
else:
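
The switch to `param.data` above matters because plain assignment merely rebinds the loop variable and leaves the model's parameter untouched; a minimal standalone sketch of the difference:

```python
import torch

linear = torch.nn.Linear(2, 2).to(torch.float16)

for _, param in linear.named_parameters():
    param = param.to(torch.float32)  # rebinds the local name only
print(linear.weight.dtype)  # still torch.float16

for _, param in linear.named_parameters():
    param.data = param.data.to(torch.float32)  # casts the parameter in place
print(linear.weight.dtype)  # now torch.float32
```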
@@ -1287,12 +1288,17 @@
for name, param in text_encoder_two.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-param = param.to(dtype=torch.float32)
+param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_two.append(param)
else:
param.requires_grad = False

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
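
The `unwrap_model` helper is needed because `torch.compile` wraps a module in an `OptimizedModule` that keeps the original module at `._orig_mod`, so `isinstance` checks against the original class fail on the wrapper; a quick sketch, assuming PyTorch 2.x:

```python
import torch

net = torch.nn.Linear(2, 2)
compiled = torch.compile(net)

print(isinstance(compiled, torch.nn.Linear))  # False: it is an OptimizedModule wrapper
print(compiled._orig_mod is net)              # True: the original module is recoverable
```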
@@ -1303,14 +1309,14 @@ def save_model_hook(models, weights, output_dir):
text_encoder_two_lora_layers_to_save = None

for model in models:
-if isinstance(model, type(accelerator.unwrap_model(unet))):
+if isinstance(model, type(unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+elif isinstance(model, type(unwrap_model(text_encoder_one))):
if args.train_text_encoder:
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
-elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+elif isinstance(model, type(unwrap_model(text_encoder_two))):
if args.train_text_encoder:
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
@@ -1338,11 +1344,11 @@ def load_model_hook(models, input_dir):
while len(models) > 0:
model = models.pop()

-if isinstance(model, type(accelerator.unwrap_model(unet))):
+if isinstance(model, type(unwrap_model(unet))):
unet_ = model
-elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_ = model
-elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+elif isinstance(model, type(unwrap_model(text_encoder_two))):
text_encoder_two_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1719,19 +1725,19 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
elif args.train_text_encoder_ti: # args.train_text_encoder_ti
num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)

+# flag used for textual inversion
+pivoted = False
for epoch in range(first_epoch, args.num_train_epochs):
# if performing any kind of optimization of text_encoder params
if args.train_text_encoder or args.train_text_encoder_ti:
if epoch == num_train_epochs_text_encoder:
print("PIVOT HALFWAY", epoch)
# stopping optimization of text_encoder params
-# re-setting the optimizer to optimize only on unet params
-optimizer.param_groups[1]["lr"] = 0.0
-optimizer.param_groups[2]["lr"] = 0.0
+# this flag is used to reset the optimizer to optimize only on unet params
+pivoted = True

else:
-# still optimizng the text encoder
+# still optimizing the text encoder
text_encoder_one.train()
text_encoder_two.train()
# set top parameter requires_grad = True for gradient checkpointing works
@@ -1741,6 +1747,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

unet.train()
for step, batch in enumerate(train_dataloader):
+            if pivoted:
+                # stopping optimization of text_encoder params
+                # re-setting the optimizer to optimize only on unet params
+                optimizer.param_groups[1]["lr"] = 0.0
+                optimizer.param_groups[2]["lr"] = 0.0

with accelerator.accumulate(unet):
prompts = batch["prompts"]
# encode batch prompts when custom prompts are provided for each image -
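
Zeroing a param group's learning rate, as the `pivoted` branch above does, is a lightweight way to stop updating those parameters without rebuilding the optimizer or losing its state; a minimal sketch:

```python
import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD([
    {"params": [model.weight], "lr": 0.1},
    {"params": [model.bias], "lr": 0.1},
])

# freeze the second group (the bias) mid-training without recreating the optimizer
optimizer.param_groups[1]["lr"] = 0.0
```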
@@ -1879,8 +1891,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

# every step, we reset the embeddings to the original embeddings.
if args.train_text_encoder_ti:
-for idx, text_encoder in enumerate(text_encoders):
-    embedding_handler.retract_embeddings()
+embedding_handler.retract_embeddings()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
29 changes: 28 additions & 1 deletion examples/community/README.md
@@ -58,6 +58,8 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
| Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) |
| Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#Rerender_A_Video) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
+| AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
+| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |

To load a custom pipeline, just pass the `custom_pipeline` argument to `DiffusionPipeline` with the name of one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines; we will merge them quickly.
@@ -2990,7 +2992,7 @@ pipe = DiffusionPipeline.from_pretrained(
custom_pipeline="pipeline_animatediff_controlnet",
).to(device="cuda", dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
model_id, subfolder="scheduler", beta_schedule="linear", clip_sample=False, timestep_spacing="linspace", steps_offset=1
)
pipe.enable_vae_slicing()
@@ -3409,7 +3411,32 @@ images = pipe(
pipe.disable_style_aligned()
```

### AnimateDiff Image-To-Video Pipeline

This pipeline adds experimental support for the image-to-video task using AnimateDiff. Refer to [this](https://github.com/huggingface/diffusers/pull/6328) PR for more examples and results.

```py
import torch
from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
from diffusers.utils import export_to_gif, load_image

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timestep_spacing="linspace")

image = load_image("snail.png")
output = pipe(
image=image,
prompt="A snail moving on the ground",
strength=0.8,
latent_interpolation_method="slerp", # can be lerp, slerp, or your own callback
)
frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```

### IP Adapter Face ID

IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
You need to install `insightface` and all its requirements to use this model.
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
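
A rough sketch of the flow described above; the `insightface` calls follow its public API, but the checkpoint names, embedding shape, and exact pipeline wiring are editorial assumptions that should be checked against the community pipeline:

```py
import cv2
import torch
from insightface.app import FaceAnalysis
from diffusers import StableDiffusionPipeline

# extract a FaceID embedding with insightface
app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))
faces = app.get(cv2.imread("face.png"))  # hypothetical input image
image_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)

# assumption: an SD 1.5 base model plus the FaceID IP-Adapter weights
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter(
    "h94/IP-Adapter-FaceID", subfolder="", weight_name="ip-adapter-faceid_sd15.bin"
)

# pass the embedding tensor as image_embeds instead of ip_adapter_image
image = pipeline(
    prompt="a portrait photo, studio lighting",
    image_embeds=image_embeds,
    num_inference_steps=30,
).images[0]
```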