Fixed some errors in diffusers training.
liuchuting committed Nov 15, 2024
1 parent 19c1bfb commit 55ecd83
Showing 21 changed files with 74 additions and 104 deletions.
13 changes: 13 additions & 0 deletions examples/diffusers/README.md
@@ -5,7 +5,20 @@
> We've tried to provide a completely consistent interface and usage with the [huggingface/diffusers](https://github.com/huggingface/diffusers).
> Only necessary changes are made to the [huggingface/diffusers](https://github.com/huggingface/diffusers) to make it seamless for users from torch.
## Requirements

| mindspore | ascend driver | firmware    | cann toolkit/kernel |
|:---------:|:-------------:|:-----------:|:-------------------:|
| 2.3.1     | 24.1.RC2      | 7.3.0.1.231 | 8.0.RC2.beta1       |

To install the remaining dependencies:

```bash
git clone https://github.com/mindspore-lab/mindone
cd mindone
pip install -e .
pip install -e ".[training]"
```
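
If you want a quick sanity check after installing, something like the following can be run (a minimal sketch; it only assumes a correctly configured Ascend environment matching the table above):

```python
# Optional post-install check: verify the MindSpore backend and the MindONE diffusers import.
import mindspore as ms

ms.run_check()                 # confirms MindSpore can reach its device backend
from mindone import diffusers  # confirms the MindONE port of diffusers is importable
print("mindspore", ms.__version__)
```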


## Quickstart
6 changes: 1 addition & 5 deletions examples/diffusers/controlnet/README.md
@@ -15,11 +15,7 @@ To make sure you can successfully run the latest versions of the example scripts
git clone https://github.com/mindspore-lab/mindone
cd mindone
pip install -e .
```

Then cd in the example folder and run
```bash
pip install -r requirements.txt
pip install -e ".[training]"
```

## Circle filling dataset
5 changes: 1 addition & 4 deletions examples/diffusers/controlnet/README_sdxl.md
@@ -16,12 +16,9 @@ To make sure you can successfully run the latest versions of the example scripts
git clone https://github.com/mindspore-lab/mindone
cd mindone
pip install -e .
pip install -e ".[training]"
```

Then cd in the `examples/controlnet` folder and run
```bash
pip install -r requirements_sdxl.txt
```

## Circle filling dataset

2 changes: 1 addition & 1 deletion examples/diffusers/controlnet/train_controlnet_sdxl.py
@@ -1109,7 +1109,7 @@ def __init__(

def forward(self, pixel_values, conditioning_pixel_values, prompt_ids, add_text_embeds, add_time_ids):
# Convert images to latent space
latents = self.vae.diag_gauss_dist.sample(self.vae.encode(pixel_values)[0])
latents = self.vae.diag_gauss_dist.sample(self.vae.encode(pixel_values.to(self.vae.dtype))[0])
latents = latents * self.vae_scaling_factor
latents = latents.to(self.weight_dtype)

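The change above casts `pixel_values` to the VAE's parameter dtype before encoding, presumably so that a VAE kept in a different precision than the rest of the training graph does not hit a dtype mismatch. A minimal sketch of the pattern (illustrative names, not the script's actual objects):

```python
# Sketch of the dtype-alignment fix: encode in the VAE's own dtype, then cast the
# latents back to the mixed-precision training dtype.
def encode_to_latents(vae, pixel_values, scaling_factor, weight_dtype):
    posterior = vae.encode(pixel_values.to(vae.dtype))[0]            # match the VAE's weights
    latents = vae.diag_gauss_dist.sample(posterior) * scaling_factor
    return latents.to(weight_dtype)                                  # back to the training dtype
```
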
21 changes: 4 additions & 17 deletions examples/diffusers/dreambooth/README.md
@@ -17,12 +17,9 @@ To make sure you can successfully run the latest versions of the example scripts
git clone https://github.com/mindspore-lab/mindone
cd mindone
pip install -e .
pip install -e ".[training]"
```

Then cd in the example folder and run
```bash
pip install -r requirements.txt
```

### Dog toy example

@@ -124,16 +121,6 @@ python train_dreambooth.py \
--train_text_encoder
```

### Using DreamBooth for pipelines other than Stable Diffusion

The [AltDiffusion pipeline](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion) also supports dreambooth fine-tuning. The process is the same as above, all you need to do is replace the `MODEL_NAME` like this:

```
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion-m9"
or
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion"
```

### Inference

Once you have trained a model using the above command, you can run inference simply using the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. sks in above example) in your prompt.
@@ -238,7 +225,7 @@ pipe.load_lora_weights("path-to-the-lora-checkpoint")
Finally, we can run the model in inference.

```python
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25)[0][0]
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=50)[0][0]
```
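
For a self-contained version of this step, a sketch along these lines should work, assuming the MindONE pipeline mirrors the diffusers API and that the checkpoint path placeholder is replaced with your training `--output_dir`:

```python
# Hedged end-to-end LoRA inference sketch; the checkpoint path is a placeholder.
from mindone.diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.load_lora_weights("path-to-the-lora-checkpoint")
# Pipelines here return tuples, so the first element holds the list of images.
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=50)[0][0]
image.save("sks_dog_in_bucket.png")
```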

If you are loading the LoRA parameters from the Hub and if the Hub repository has
@@ -391,7 +378,7 @@ python train_dreambooth_lora.py \
export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="dreambooth_dog_upscale"
export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png"
export VALIDATION_IMAGES="dog_downsized/image_1.jpg dog_downsized/image_2.jpg dog_downsized/image_3.jpg dog_downsized/image_4.jpg"

python train_dreambooth_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
@@ -464,7 +451,7 @@ faces required large effective batch sizes.
export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="dreambooth_dog_upscale"
export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png"
export VALIDATION_IMAGES="dog_downsized/image_1.jpg dog_downsized/image_2.jpg dog_downsized/image_3.jpg dog_downsized/image_4.jpg"

python train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME \
6 changes: 1 addition & 5 deletions examples/diffusers/dreambooth/README_sd3.md
@@ -25,11 +25,7 @@ To make sure you can successfully run the latest versions of the example scripts
git clone https://github.com/mindspore-lab/mindone
cd mindone
pip install -e .
```

Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_sd3.txt
pip install -e ".[training]"
```


7 changes: 2 additions & 5 deletions examples/diffusers/dreambooth/README_sdxl.md
@@ -21,12 +21,9 @@ To make sure you can successfully run the latest versions of the example scripts
git clone https://github.com/mindspore-lab/mindone
cd mindone
pip install -e .
pip install -e ".[training]"
```

Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_sdxl.txt
```

### Dog toy example

@@ -184,7 +181,7 @@ python train_dreambooth_lora_sdxl.py \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-5 \
--learning_rate=1e-4 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
6 changes: 0 additions & 6 deletions examples/diffusers/dreambooth/requirements_sd3.txt

This file was deleted.

12 changes: 7 additions & 5 deletions examples/diffusers/dreambooth/train_dreambooth.py
@@ -84,7 +84,7 @@ def log_validation(
else:
for image in args.validation_images:
image = Image.open(image)
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
image = pipeline(**pipeline_args, image=image, generator=generator)[0][0]
images.append(image)

if is_master(args):
@@ -589,7 +589,7 @@ def __getitem__(self, index):
example["instance_images"] = self.image_transforms(instance_image)[0]

if self.encoder_hidden_states is not None:
example["instance_prompt_ids"] = self.encoder_hidden_states
example["instance_prompt_ids"] = self.encoder_hidden_states.asnumpy()
else:
text_inputs = tokenize_prompt(
self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
@@ -822,7 +822,8 @@ def main():
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# set sample_size of unet
unet.register_to_config(sample_size=args.resolution // (2 ** (len(vae.config.block_out_channels) - 1)))
if vae is not None:
unet.register_to_config(sample_size=args.resolution // (2 ** (len(vae.config.block_out_channels) - 1)))

def freeze_params(m: nn.Cell):
for p in m.get_parameters():
@@ -1178,15 +1179,16 @@ def __init__(
self.unet = unet
self.unet_in_channels = unet.config.in_channels
self.vae = vae
self.vae_scaling_factor = vae.config.scaling_factor
if self.vae is not None:
self.vae_scaling_factor = vae.config.scaling_factor
self.text_encoder = text_encoder
self.noise_scheduler = noise_scheduler
self.noise_scheduler_num_train_timesteps = noise_scheduler.config.num_train_timesteps
self.noise_scheduler_prediction_type = noise_scheduler.config.prediction_type
self.weight_dtype = weight_dtype
self.args = AttrJitWrapper(**vars(args))

def forward(self, pixel_values, input_ids, attention_mask):
def forward(self, pixel_values, input_ids, attention_mask=None):
pixel_values = pixel_values.to(dtype=self.weight_dtype)

if self.vae is not None:
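
Several hunks in this file now guard on `vae is not None`; presumably this is because the script also covers pixel-space models such as DeepFloyd IF, which ship without a VAE. A minimal sketch of the guarded encoding path (illustrative, not the script's exact code):

```python
# Illustrative guard: latent-space models encode through the VAE, pixel-space models skip it.
def prepare_model_input(vae, pixel_values, weight_dtype, vae_scaling_factor=None):
    pixel_values = pixel_values.to(weight_dtype)
    if vae is None:  # e.g. DeepFloyd IF trains directly on pixels
        return pixel_values
    latents = vae.diag_gauss_dist.sample(vae.encode(pixel_values.to(vae.dtype))[0])
    return (latents * vae_scaling_factor).to(weight_dtype)
```
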
12 changes: 7 additions & 5 deletions examples/diffusers/dreambooth/train_dreambooth_lora.py
@@ -84,7 +84,7 @@ def log_validation(
else:
for image in args.validation_images:
image = Image.open(image)
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
image = pipeline(**pipeline_args, image=image, generator=generator)[0][0]
images.append(image)

phase_name = "test" if is_final_validation else "validation"
Expand Down Expand Up @@ -551,7 +551,7 @@ def __getitem__(self, index):
example["instance_images"] = self.image_transforms(instance_image)[0]

if self.encoder_hidden_states is not None:
example["instance_prompt_ids"] = self.encoder_hidden_states
example["instance_prompt_ids"] = self.encoder_hidden_states.asnumpy()
else:
text_inputs = tokenize_prompt(
self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
@@ -775,7 +775,8 @@ def main():
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# set sample_size of unet
unet.register_to_config(sample_size=args.resolution // (2 ** (len(vae.config.block_out_channels) - 1)))
if vae is not None:
unet.register_to_config(sample_size=args.resolution // (2 ** (len(vae.config.block_out_channels) - 1)))

# We only train the additional adapter LoRA layers
def freeze_params(m: nn.Cell):
@@ -1261,15 +1262,16 @@ def __init__(
self.unet = unet
self.unet_in_channels = unet.config.in_channels
self.vae = vae
self.vae_scaling_factor = vae.config.scaling_factor
if self.vae is not None:
self.vae_scaling_factor = vae.config.scaling_factor
self.text_encoder = text_encoder
self.noise_scheduler = noise_scheduler
self.noise_scheduler_num_train_timesteps = noise_scheduler.config.num_train_timesteps
self.noise_scheduler_prediction_type = noise_scheduler.config.prediction_type
self.weight_dtype = weight_dtype
self.args = AttrJitWrapper(**vars(args))

def forward(self, pixel_values, input_ids, attention_mask):
def forward(self, pixel_values, input_ids, attention_mask=None):
pixel_values = pixel_values.to(dtype=self.weight_dtype)

if self.vae is not None:
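
The `.asnumpy()` conversion added in `__getitem__` above reflects, as far as I can tell, that MindSpore's data pipeline expects NumPy arrays rather than `mindspore.Tensor` objects from a Python dataset. A tiny sketch with a hypothetical precomputed embedding:

```python
# Hypothetical example: hand precomputed prompt embeddings to the data pipeline as NumPy arrays.
import numpy as np
import mindspore as ms

encoder_hidden_states = ms.Tensor(np.zeros((77, 768), dtype=np.float32))  # placeholder embedding
example = {"instance_prompt_ids": encoder_hidden_states.asnumpy()}
```
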
37 changes: 20 additions & 17 deletions examples/diffusers/dreambooth/train_dreambooth_lora_sdxl.py
@@ -860,7 +860,7 @@ def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=ms.float32):
sigmas = noise_scheduler.sigmas.to(dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps

step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = [(schedule_timesteps == t).nonzero().item(0) for t in timesteps]

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
@@ -870,7 +870,9 @@

def main():
args = parse_args()
ms.set_context(mode=ms.GRAPH_MODE, jit_syntax_level=ms.STRICT)
ms.set_context(mode=ms.GRAPH_MODE, jit_syntax_level=ms.LAX)
if args.train_text_encoder:
ms.set_context(max_call_depth=5000)
init_distributed_device(args) # read attr distributed, writer attrs rank/local_rank/world_size

# tensorboard, mindinsight, wandb logging stuff into logging_dir
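
For reference, the context setup this hunk switches to looks roughly like the following; the values come from the diff, and the `max_call_depth` bump is presumably needed because training the LoRA-wrapped text encoders produces a deeper call graph than the default limit allows:

```python
# Context settings as in the hunk above: graph mode with the relaxed LAX syntax level;
# max_call_depth is only raised when the text encoders are trained.
import mindspore as ms

train_text_encoder = True  # stands in for args.train_text_encoder
ms.set_context(mode=ms.GRAPH_MODE, jit_syntax_level=ms.LAX)
if train_text_encoder:
    ms.set_context(max_call_depth=5000)
```
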
@@ -1448,8 +1450,9 @@ def load_model_hook(models, input_dir):
output_model_file = os.path.join(save_path, "pytorch_model.ckpt")
ms.save_checkpoint(unet, output_model_file)
logger.info(f"Saved state to {save_path}")

logs = {"loss": loss.numpy().item(), "lr": optimizer.get_lr().numpy().item()}
log_lr = optimizer.get_lr()
log_lr = log_lr[0] if isinstance(log_lr, tuple) else log_lr
logs = {"loss": loss.numpy().item(), "lr": log_lr.numpy().item()}
progress_bar.set_postfix(**logs)
for tracker_name, tracker in trackers.items():
if tracker_name == "tensorboard":
@@ -1469,6 +1472,19 @@ def load_model_hook(models, input_dir):
(epoch + 1),
)

# Final inference
if args.validation_prompt and args.num_validation_images > 0:
pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
log_validation(
pipeline,
args,
trackers,
logging_dir,
pipeline_args,
args.num_train_epochs,
is_final_validation=True,
)

# Save the lora layers
if is_master(args):
unet = unet.to(ms.float32)
@@ -1492,19 +1508,6 @@ def load_model_hook(models, input_dir):
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
)

# Final inference
if args.validation_prompt and args.num_validation_images > 0:
pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
log_validation(
pipeline,
args,
trackers,
logging_dir,
pipeline_args,
args.num_train_epochs,
is_final_validation=True,
)

# End of training
for tracker_name, tracker in trackers.items():
if tracker_name == "tensorboard":
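
The learning-rate logging change above accounts for MindSpore optimizers with grouped parameters, where `get_lr()` returns a tuple of per-group learning rates rather than a single tensor. A hedged helper sketching the same logic:

```python
# Sketch of the lr-logging logic: handle both a single lr tensor and a per-group tuple
# (the first group's learning rate is the one reported).
def current_lr(optimizer) -> float:
    lr = optimizer.get_lr()
    if isinstance(lr, tuple):  # grouped parameters, e.g. unet + text encoders
        lr = lr[0]
    return float(lr.asnumpy())
```
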
5 changes: 1 addition & 4 deletions examples/diffusers/text_to_image/README.md
@@ -20,12 +20,9 @@ To make sure you can successfully run the latest versions of the example scripts
git clone https://github.com/mindspore-lab/mindone
cd mindone
pip install .
pip install -e ".[training]"
```

Then cd in the example folder `examples/diffusers/text_to_image` and run
```bash
pip install -r requirements.txt
```

### OnePiece example

11 changes: 1 addition & 10 deletions examples/diffusers/text_to_image/README_sdxl.md
@@ -18,12 +18,9 @@ To make sure you can successfully run the latest versions of the example scripts
git clone https://github.com/mindspore-lab/mindone
cd mindone
pip install -e .
pip install -e ".[training]"
```

Then cd in the `examples/diffusers/text_to_image` folder and run
```bash
pip install -r requirements_sdxl.txt
```

### Training

@@ -83,12 +80,6 @@ In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-de
With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset
on consumer GPUs like Tesla T4, Tesla V100.
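
As a quick illustration of the rank-decomposition idea (a standalone NumPy sketch, not the training script's implementation): the frozen weight `W` receives a trainable low-rank update `B @ A`, so only `r * (d + k)` parameters are trained instead of `d * k`.

```python
# Toy LoRA update: W_eff = W + B @ A, with rank r much smaller than the weight dimensions.
import numpy as np

d, k, r = 768, 768, 4
W = np.random.randn(d, k).astype(np.float32)          # frozen pretrained weight
A = np.random.randn(r, k).astype(np.float32) * 0.01   # trainable down-projection
B = np.zeros((d, r), dtype=np.float32)                # zero-init so training starts from W
W_eff = W + B @ A                                     # effective weight used by the layer
print(f"trainable: {A.size + B.size} params vs frozen: {W.size}")
```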

> [!WARNING]
> If you're using mindspore 2.2.x, you have to set the `MS_DEV_TRAVERSE_SUBSTITUTIONS_MODE` environment variables to `1` before running the training commands,
> otherwise you'll get a segmentation fault (core dumped).
> ```bash
> export MS_DEV_TRAVERSE_SUBSTITUTIONS_MODE=1
> ```

### Training

5 changes: 1 addition & 4 deletions examples/diffusers/textual_inversion/README.md
@@ -17,12 +17,9 @@ To make sure you can successfully run the latest versions of the example scripts
git clone https://github.com/mindspore-lab/mindone
cd mindone
pip install .
pip install -e ".[training]"
```

Then cd in the example folder and run:
```bash
pip install -r requirements.txt
```

### Cat toy example

7 changes: 0 additions & 7 deletions examples/diffusers/textual_inversion/README_sdxl.md
@@ -1,12 +1,5 @@
## Textual Inversion fine-tuning example for SDXL

> [!WARNING]
> If you're using mindspore 2.2.x, you have to set the `MS_DEV_TRAVERSE_SUBSTITUTIONS_MODE` environment variables to `1` before running the training commands,
> otherwise you'll get a segmentation fault (core dumped).
> ```bash
> export MS_DEV_TRAVERSE_SUBSTITUTIONS_MODE=1
> ```
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATA_DIR="./cat"
```