Commit 28190fd

Merge pull request #150 from huggingface/main

Merge changes - yay Stable Cascade!

Skquark authored Mar 6, 2024
2 parents b62dfa1 + 534f5d5
Showing 27 changed files with 3,313 additions and 29 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -318,6 +318,8 @@
title: Semantic Guidance
- local: api/pipelines/shap_e
title: Shap-E
- local: api/pipelines/stable_cascade
title: Stable Cascade
- sections:
- local: api/pipelines/stable_diffusion/overview
title: Overview
88 changes: 88 additions & 0 deletions docs/source/en/api/pipelines/stable_cascade.md
@@ -0,0 +1,88 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Stable Cascade

This model is built on the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture; its main
difference from other models like Stable Diffusion is that it works in a much smaller latent space. Why is this
important? The smaller the latent space, the **faster** inference runs and the **cheaper** training becomes.
How small is the latent space? Stable Diffusion uses a compression factor of 8, so a 1024x1024 image is
encoded to 128x128. Stable Cascade achieves a compression factor of 42, encoding a 1024x1024 image to 24x24
while maintaining crisp reconstructions. The text-conditional model is then trained in this highly compressed
latent space. Previous versions of this architecture achieved a 16x cost reduction over Stable Diffusion 1.5.

This kind of model is therefore well suited for use cases where efficiency matters. Furthermore, all known extensions
such as finetuning, LoRA, ControlNet, IP-Adapter, LCM, etc. are compatible with this approach as well.

The original codebase can be found at [Stability-AI/StableCascade](https://github.com/Stability-AI/StableCascade).

## Model Overview

Stable Cascade consists of three models: Stage A, Stage B, and Stage C. Together they form a cascade for generating
images, hence the name "Stable Cascade".

Stages A and B are used to compress images, playing the same role as the VAE in Stable Diffusion.
With this setup, however, a much higher compression can be achieved: while the Stable Diffusion models use a
spatial compression factor of 8, encoding a 1024 x 1024 image to 128 x 128, Stable Cascade achieves
a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24 while still decoding the image
accurately, which makes both training and inference much cheaper. Stage C, in turn, is responsible
for generating the small 24 x 24 latents from a text prompt.
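
In `diffusers`, Stage C corresponds to the prior pipeline and Stages A and B to the decoder pipeline. Below is a
minimal sketch of the two-stage flow; the model ids, dtypes, and call arguments follow the Stability AI Hub release
at the time of writing and may differ for your setup:

```python
import torch
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

prompt = "an image of a shiba inu, donning a spacesuit and helmet"

# Stage C: generate the compressed 24x24 image embeddings from the text prompt
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("cuda")
prior_output = prior(prompt=prompt, height=1024, width=1024, guidance_scale=4.0, num_inference_steps=20)

# Stages B and A: decode the embeddings back into a full-resolution image
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to("cuda")
image = decoder(
    image_embeddings=prior_output.image_embeddings.to(torch.float16),
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=10,
).images[0]
image.save("stable_cascade.png")
```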

## Uses

### Direct Use

The model is currently intended for research purposes. Possible research areas and tasks include:

- Research on generative models.
- Safe deployment of models which have the potential to generate harmful content.
- Probing and understanding the limitations and biases of generative models.
- Generation of artworks and use in design and other artistic processes.
- Applications in educational or creative tools.

Excluded uses are described below.

### Out-of-Scope Use

The model was not trained to produce factual or true representations of people or events;
generating such content is therefore out of scope for this model's abilities.
The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).

## Limitations and Bias

### Limitations
- Faces and people in general may not be generated properly.
- The autoencoding part of the model is lossy.


## StableCascadeCombinedPipeline

[[autodoc]] StableCascadeCombinedPipeline
- all
- __call__
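
For convenience, both stages can also be run in a single call through the combined pipeline. A short sketch, again
assuming the Stability AI Hub id; the `prior_*` arguments are forwarded to Stage C and the remaining ones to the
decoder:

```python
import torch
from diffusers import StableCascadeCombinedPipeline

pipe = StableCascadeCombinedPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="an image of a shiba inu, donning a spacesuit and helmet",
    height=1024,
    width=1024,
    prior_guidance_scale=4.0,
    prior_num_inference_steps=20,
    num_inference_steps=10,
).images[0]
```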

## StableCascadePriorPipeline

[[autodoc]] StableCascadePriorPipeline
- all
- __call__

## StableCascadePriorPipelineOutput

[[autodoc]] pipelines.stable_cascade.pipeline_stable_cascade_prior.StableCascadePriorPipelineOutput

## StableCascadeDecoderPipeline

[[autodoc]] StableCascadeDecoderPipeline
- all
- __call__

@@ -666,7 +666,6 @@ def parse_args(input_args=None):
     )
     parser.add_argument(
         "--use_dora",
-        type=bool,
         action="store_true",
         default=False,
         help=(
28 changes: 28 additions & 0 deletions examples/research_projects/diffusion_dpo/README.md
@@ -61,6 +61,34 @@ accelerate launch train_diffusion_dpo_sdxl.py \
--push_to_hub
```

## SDXL Turbo training command

```bash
accelerate launch train_diffusion_dpo_sdxl.py \
--pretrained_model_name_or_path=stabilityai/sdxl-turbo \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--output_dir="diffusion-sdxl-turbo-dpo" \
--mixed_precision="fp16" \
--dataset_name=kashif/pickascore \
--train_batch_size=8 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing \
--use_8bit_adam \
--rank=8 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=2000 \
--checkpointing_steps=500 \
--run_validation --validation_steps=50 \
--seed="0" \
--is_turbo --resolution 512 \
--push_to_hub
```
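
Relative to the plain SDXL command above, the `--is_turbo` flag (added to the script in this commit) switches
validation to turbo-style sampling (`guidance_scale=0.0`, 4 inference steps), enforces a zero-terminal-SNR noise
schedule, and restricts training to a fixed 4-timestep schedule; see the script changes below.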


## Acknowledgements

This is based on the amazing work on Diffusion DPO done by [Bram](https://github.com/bram-w): https://github.com/bram-w/trl/blob/dpo/.
@@ -118,9 +118,16 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
     images = []
     context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
 
+    guidance_scale = 5.0
+    num_inference_steps = 25
+    if args.is_turbo:
+        guidance_scale = 0.0
+        num_inference_steps = 4
     for prompt in VALIDATION_PROMPTS:
         with context:
-            image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
+            image = pipeline(
+                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
+            ).images[0]
         images.append(image)
 
     tracker_key = "test" if is_final_validation else "validation"
@@ -141,7 +148,10 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
     if is_final_validation:
         pipeline.disable_lora()
         no_lora_images = [
-            pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in VALIDATION_PROMPTS
+            pipeline(
+                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
+            ).images[0]
+            for prompt in VALIDATION_PROMPTS
         ]
 
     for tracker in accelerator.trackers:
@@ -423,6 +433,11 @@ def parse_args(input_args=None):
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--is_turbo",
action="store_true",
help=("Use if tuning SDXL Turbo instead of SDXL"),
)
parser.add_argument(
"--rank",
type=int,
@@ -444,6 +459,9 @@ def parse_args(input_args=None):
if args.dataset_name is None:
raise ValueError("Must provide a `dataset_name`.")

if args.is_turbo:
assert "turbo" in args.pretrained_model_name_or_path

env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
@@ -560,6 +578,36 @@ def main(args):

# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

def enforce_zero_terminal_snr(scheduler):
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L93
# Original implementation https://arxiv.org/pdf/2305.08891.pdf
# Turbo needs zero terminal SNR
# Turbo: https://static1.squarespace.com/static/6213c340453c3f502425776e/t/65663480a92fba51d0e1023f/1701197769659/adversarial_diffusion_distillation.pdf
# Convert betas to alphas_bar_sqrt
alphas = 1 - scheduler.betas
alphas_bar = alphas.cumprod(0)
alphas_bar_sqrt = alphas_bar.sqrt()

# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so first timestep is back to old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

alphas_bar = alphas_bar_sqrt**2
alphas = alphas_bar[1:] / alphas_bar[:-1]
alphas = torch.cat([alphas_bar[0:1], alphas])

alphas_cumprod = torch.cumprod(alphas, dim=0)
scheduler.alphas_cumprod = alphas_cumprod
return

if args.is_turbo:
enforce_zero_terminal_snr(noise_scheduler)

text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
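
For reference, the rescaling implemented by `enforce_zero_terminal_snr` above is the correction from
["Common Diffusion Noise Schedules and Sample Steps are Flawed"](https://arxiv.org/abs/2305.08891): the cumulative
signal rates are shifted and scaled as

$$
\sqrt{\bar{\alpha}_t} \;\leftarrow\; \left(\sqrt{\bar{\alpha}_t} - \sqrt{\bar{\alpha}_T}\right) \cdot \frac{\sqrt{\bar{\alpha}_0}}{\sqrt{\bar{\alpha}_0} - \sqrt{\bar{\alpha}_T}}
$$

so that $\bar{\alpha}_T = 0$ (pure noise, i.e. zero SNR, at the terminal timestep) while the first value
$\bar{\alpha}_0$ is left unchanged, matching what few-step sampling with SDXL Turbo assumes.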
@@ -909,6 +957,10 @@ def collate_fn(examples):
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
).repeat(2)
if args.is_turbo:
# Learn a 4 timestep schedule
timesteps_0_to_3 = timesteps % 4
timesteps = 250 * timesteps_0_to_3 + 249

# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
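
Note that `250 * (timesteps % 4) + 249` folds the uniformly sampled timesteps onto the fixed grid
{249, 499, 749, 999}, so with `--is_turbo` the model is only trained at the four timesteps it will later be
sampled at.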
