OpenSora-PKU v1.1: training speed improvement (#557)
* correct enable tiling

* skip compress_maxpool2d if compress_kv_factor==1

* revise (65+16)x512x512 train scripts max device memory

* correct printing message

* update profilercallback epoch

* vae dtype and custom fp32 printing

* update speed info

* update tv2.ckpt download link

* update (65+4)x512x512 speed

* update (65+16)x512x512 speed

* fix text encoder dtype printing
wtomin authored Jul 1, 2024
1 parent f4328d1 commit 0922795
Showing 7 changed files with 77 additions and 27 deletions.
11 changes: 7 additions & 4 deletions examples/opensora_pku/README.md
@@ -70,7 +70,7 @@ Other useful documents and links are listed below.
## Installation
1. Use python>=3.8 [[install]](https://www.python.org/downloads/)

2. Install MindSpore 2.3 master (0615daily) according to the [website](https://repo.mindspore.cn/mindspore/mindspore/version/202406/20240615/) and use C18 CANN (0517) which can be downloaded from [here](https://repo.mindspore.cn/ascend/ascend910/20240517/).
2. Install MindSpore 2.3 master (0615daily) according to the [website](https://repo.mindspore.cn/mindspore/mindspore/version/202406/20240615/master_20240615020018_43ccb91e45899b64fe31d304497ab17e3ada3cea_newest/unified/). Select the wheel file that matches your computer's OS and Python version. Please use C18 CANN (0517), which can be downloaded from [here](https://repo.mindspore.cn/ascend/ascend910/20240517/).


3. Install requirements
@@ -250,10 +250,12 @@ msrun --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="output_l
The first-stage training depends on the `t2v.pt` from [Vchitect/Latte](https://huggingface.co/maxin-cn/Latte/tree/main). Please download `t2v.pt` and place it under `LanguageBind/Open-Sora-Plan-v1.1.0/t2v.pt`. Then run model conversion with:
```bash
python tools/model_conversion/convert_latte.py \
--src pretrained/t2v.pt \
--src LanguageBind/Open-Sora-Plan-v1.1.0/t2v.pt \
--target LanguageBind/Open-Sora-Plan-v1.1.0/t2v.ckpt
```

> **Since [Vchitect/Latte](https://huggingface.co/maxin-cn/Latte/tree/main) has deleted `t2v.pt` from their HF repo, please download `t2v.ckpt` from this [URL](https://download-mindspore.osinfra.cn/toolkits/mindone/opensora-pku/tv2.ckpt). There is no need to convert it.**
The [Open-Sora-Dataset-v1.1.0](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main) includes three image datasets and three video datasets, as recorded in `scripts/train_data/image_data.txt` and `scripts/train_data/video_data.txt`. Each line includes the paths to three folders/files: the video folder, the t5 embedding cache folder, and the path to the annotation json file.
For acceleration, we pre-compute the t5 embedding before training the diffusion transformer.

@@ -323,8 +325,9 @@ We evaluated the training performance on MindSpore and Ascend NPUs. The results

| Model | Context | Precision | BS | NPUs | num_frames + num_images| Resolution | Train T. (s/step) |
|:----------------|:---------------|:----------|:--:|:----:|:-----------:|:-----------:|:--------------:|
| LatteT2V-XL/122 | D910\*-[CANN C18(0517)](https://repo.mindspore.cn/ascend/ascend910/20240517/)-[MS2.3_master(0615)](https://repo.mindspore.cn/mindspore/mindspore/version/202406/20240615/) | BF16 | 2 | 8 | 17 + 4 | 512x512 | 2.6 |
| LatteT2V-XL/122 | D910\*-[CANN C18(0517)](https://repo.mindspore.cn/ascend/ascend910/20240517/)-[MS2.3_master(0615)](https://repo.mindspore.cn/mindspore/mindspore/version/202406/20240615/) | BF16 | 2 | 8 | 65 + 16 | 512x512 | 12.4 |
| LatteT2V-XL/122 | D910\*-[CANN C18(0517)](https://repo.mindspore.cn/ascend/ascend910/20240517/)-[MS2.3_master(0615)](https://repo.mindspore.cn/mindspore/mindspore/version/202406/20240615/master_20240615020018_43ccb91e45899b64fe31d304497ab17e3ada3cea_newest/unified/) | BF16 | 2 | 8 | 17 + 4 | 512x512 | 2.6 |
| LatteT2V-XL/122 | D910\*-[CANN C18(0517)](https://repo.mindspore.cn/ascend/ascend910/20240517/)-[MS2.3_master(0615)](https://repo.mindspore.cn/mindspore/mindspore/version/202406/20240615/master_20240615020018_43ccb91e45899b64fe31d304497ab17e3ada3cea_newest/unified/) | BF16 | 2 | 8 | 65 + 16 | 512x512 | 11.2 |
| LatteT2V-XL/122 | D910\*-[CANN C18(0517)](https://repo.mindspore.cn/ascend/ascend910/20240517/)-[MS2.3_master(0615)](https://repo.mindspore.cn/mindspore/mindspore/version/202406/20240615/master_20240615020018_43ccb91e45899b64fe31d304497ab17e3ada3cea_newest/unified/) | BF16 | 2 | 8 | 65 + 4 | 512x512 | 7.5 |
> Context: {NPU type}-{CANN version}-{MindSpore version}
## 👍 Acknowledgement
@@ -380,8 +380,9 @@ def construct(
attention_mask = ops.ones((input_batch_size, frame + use_image_num, h, w), dtype=hidden_states.dtype)
attention_mask = self.vae_to_diff_mask(attention_mask, use_image_num)
dtype = attention_mask.dtype
attention_mask_compress = self.compress_maxpool2d(attention_mask)
attention_mask_compress = attention_mask_compress.to(dtype)
attention_mask_compress = (
self.compress_maxpool2d(attention_mask).to(dtype) if self.compress_kv_factor != 1 else attention_mask
)

attention_mask = self.make_attn_mask(attention_mask, frame, hidden_states.dtype)
attention_mask_compress = self.make_attn_mask(attention_mask_compress, frame, hidden_states.dtype)
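The skip in this hunk is safe as long as pooling with a factor of 1 leaves the mask unchanged. Below is a minimal pure-Python sketch of that reasoning. `max_pool_2d` and `compress_mask` are hypothetical stand-ins, and the assumption that the pool's kernel size and stride both equal `compress_kv_factor` is not confirmed by the diff.

```python
def max_pool_2d(mask, k):
    """Non-overlapping k x k max pooling over a 2D list (assumes stride == kernel size)."""
    h, w = len(mask), len(mask[0])
    return [
        [
            max(mask[i + di][j + dj] for di in range(k) for dj in range(k))
            for j in range(0, w - k + 1, k)
        ]
        for i in range(0, h - k + 1, k)
    ]


def compress_mask(mask, compress_kv_factor):
    """Hypothetical helper: skip pooling entirely when the factor is 1."""
    if compress_kv_factor == 1:
        # A 1 x 1 max pool returns its input, so the call can be avoided.
        return mask
    return max_pool_2d(mask, compress_kv_factor)


mask = [
    [1, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
]
same = compress_mask(mask, 1)
pooled = compress_mask(mask, 2)
```

With a factor of 1 the helper hands back the input object itself, which is the work the commit avoids on every step.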
@@ -82,7 +82,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):

RoPE2D = cuRoPE2D
except ImportError:
print("Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead")
print("Warning, cannot find compiled version of RoPE2D, using a slow version instead")

class RoPE2D(nn.Cell):
def __init__(self, freq=10000.0, F0=1.0, scaling_factor=1.0):
@@ -153,7 +153,7 @@ def construct(self, tokens, positions):

RoPE1D = cuRoPE1D
except ImportError:
print("Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead")
print("Warning, cannot find compiled version of RoPE2D, using a slow version instead")

class RoPE1D(nn.Cell):
def __init__(self, freq=10000.0, F0=1.0, scaling_factor=1.0):
18 changes: 14 additions & 4 deletions examples/opensora_pku/opensora/sample/sample_t2v.py
@@ -208,6 +208,12 @@ def parse_args():
choices=["bf16", "fp16"],
help="what data type to use for vae. Default is `bf16`, which corresponds to ms.bfloat16",
)
parser.add_argument(
"--vae_keep_gn_fp32",
default=False,
type=str2bool,
help="whether keep GroupNorm in fp32. Defaults to False in inference mode. If training vae, better set it to True",
)
parser.add_argument(
"--text_encoder_precision",
default="bf16",
@@ -278,9 +284,12 @@ def parse_args():
vae.vae_scale_factor = ae_stride_config[args.ae]
# use amp level O2 for causal 3D VAE with bfloat16 or float16
vae_dtype = get_precision(args.vae_precision)
custom_fp32_cells = [nn.GroupNorm] if vae_dtype == ms.float16 else [nn.AvgPool2d, TrilinearInterpolate]
if vae_dtype == ms.float16:
custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else []
else:
custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate]
vae = auto_mixed_precision(vae, amp_level="O2", dtype=vae_dtype, custom_fp32_cells=custom_fp32_cells)
logger.info(f"Use amp level O2 for causal 3D VAE with dtype={vae_dtype}")
logger.info(f"Use amp level O2 for causal 3D VAE with dtype={vae_dtype}, custom_fp32_cells: {custom_fp32_cells}")
vae.set_train(False)
for param in vae.get_parameters(): # freeze vae
param.requires_grad = False
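The new branch for `custom_fp32_cells` can be read as a small decision table. The sketch below restates it with plain strings in place of the real cell classes, so it runs without MindSpore. `select_vae_fp32_cells` is a hypothetical name used only for illustration.

```python
def select_vae_fp32_cells(vae_precision, keep_gn_fp32=False):
    """Return the names of the cells kept in fp32, mirroring the branch in this hunk.

    Strings stand in for nn.GroupNorm, nn.AvgPool2d and TrilinearInterpolate.
    """
    if vae_precision == "fp16":
        # GroupNorm stays in fp32 only when explicitly requested.
        return ["GroupNorm"] if keep_gn_fp32 else []
    # Any other precision (bf16 in the diff) keeps these two cells in fp32.
    return ["AvgPool2d", "TrilinearInterpolate"]


fp16_default = select_vae_fp32_cells("fp16")
fp16_keep_gn = select_vae_fp32_cells("fp16", keep_gn_fp32=True)
bf16_cells = select_vae_fp32_cells("bf16", keep_gn_fp32=True)
```

Under this reading, `--vae_keep_gn_fp32` has no effect when the VAE runs in bf16, because only the fp16 branch consults it.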
@@ -445,8 +454,9 @@ def parse_args():
f"Num of samples: {n}",
f"Num params: {num_params:,} (latte: {num_params_latte:,}, vae: {num_params_vae:,})",
f"Num trainable params: {num_params_trainable:,}",
f"Use model dtype: {dtype}",
f"Use FA: {args.enable_flash_attention}",
f"Transformer dtype: {dtype}",
f"VAE dtype: {vae_dtype}",
f"Text encoder dtype: {text_encoder_dtype}",
f"Sampling steps {args.num_sampling_steps}",
f"Sampling method: {args.sample_method}",
f"CFG guidance scale: {args.guidance_scale}",
55 changes: 43 additions & 12 deletions examples/opensora_pku/opensora/train/train_t2v.py
@@ -23,7 +23,7 @@
from opensora.train.commons import create_loss_scaler, init_env, parse_args
from opensora.utils.utils import get_precision

from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback
from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallbackEpoch
from mindone.trainers.checkpoint import resume_train_network
from mindone.trainers.ema import EMA
from mindone.trainers.lr_schedule import create_scheduler
@@ -87,10 +87,13 @@ def main(args):
else:
logger.info("vae init")
vae = getae_wrapper(args.ae)(args.ae_path, subfolder="vae")
vae_dtype = ms.bfloat16
custom_fp32_cells = [nn.GroupNorm] if vae_dtype == ms.float16 else [nn.AvgPool2d, TrilinearInterpolate]
vae_dtype = get_precision(args.vae_precision)
if vae_dtype == ms.float16:
custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else []
else:
custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate]
vae = auto_mixed_precision(vae, amp_level="O2", dtype=vae_dtype, custom_fp32_cells=custom_fp32_cells)
logger.info(f"Use amp level O2 for causal 3D VAE. Use dtype {vae_dtype}")
logger.info(f"Use amp level O2 for causal 3D VAE with dtype={vae_dtype}, custom_fp32_cells {custom_fp32_cells}")

vae.set_train(False)
for param in vae.get_parameters(): # freeze vae
@@ -157,15 +160,19 @@ def main(args):
else:
model_dtype = get_precision(args.precision)
if not args.global_bf16:
if model_dtype == ms.float16:
custom_fp32_cells = [LayerNorm, Attention, nn.SiLU, nn.GELU]
else:
custom_fp32_cells = [nn.MaxPool2d, LayerNorm, nn.SiLU, nn.GELU]
latte_model = auto_mixed_precision(
latte_model,
amp_level=args.amp_level,
dtype=model_dtype,
custom_fp32_cells=[LayerNorm, Attention, nn.SiLU, nn.GELU]
if model_dtype == ms.float16
else [nn.MaxPool2d, LayerNorm, nn.SiLU, nn.GELU],
custom_fp32_cells=custom_fp32_cells,
)
logger.info(
f"Set mixed precision to {args.amp_level} with dtype={args.precision}, custom_fp32_cells: {custom_fp32_cells}"
)
logger.info(f"Set mixed precision to {args.amp_level} with dtype={args.precision}")
else:
logger.info(f"Using global bf16 for latte t2v model. Force model dtype from {model_dtype} to ms.bfloat16")
model_dtype = ms.bfloat16
@@ -186,7 +193,7 @@ def main(args):
model_max_length=args.model_max_length,
)
# mixed precision
text_encoder_dtype = ms.bfloat16 # using bf16 for text encoder and vae
text_encoder_dtype = get_precision(args.text_encoder_precision) # using bf16 for text encoder and vae
text_encoder = auto_mixed_precision(text_encoder, amp_level="O2", dtype=text_encoder_dtype)
text_encoder.dtype = text_encoder_dtype
logger.info(f"Use amp level O2 for text encoder T5 with dtype={text_encoder_dtype}")
@@ -195,6 +202,7 @@ def main(args):
else:
text_encoder = None
tokenizer = None
text_encoder_dtype = None

# 2.3 ldm with loss
diffusion = create_diffusion(timestep_respacing="")
@@ -368,7 +376,7 @@ def main(args):
)
callback.append(save_cb)
if args.profile:
callback.append(ProfilerCallback())
callback.append(ProfilerCallbackEpoch(2, 2, "./profile_data"))

# 5. log and save config
if rank_id == 0:
@@ -388,8 +396,11 @@ def main(args):
else "",
f"Num params: {num_params:,} (latte: {num_params_latte:,}, vae: {num_params_vae:,})",
f"Num trainable params: {num_params_trainable:,}",
f"Use model dtype: {model_dtype}",
f"AMP level: {args.amp_level}" if not args.global_bf16 else "Global BF16: True",
f"Transformer model dtype: {model_dtype}",
f"Transformer AMP level: {args.amp_level}" if not args.global_bf16 else "Global BF16: True",
f"VAE dtype: {vae_dtype} (amp level O2)" + f"\nText encoder dtype: {text_encoder_dtype} (amp level O2)"
if text_encoder_dtype is not None
else "",
f"Learning rate: {args.start_learning_rate}",
f"Batch size: {args.batch_size}",
f"Image size: {args.max_image_size}",
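One detail of the new summary entry is worth checking: a Python conditional expression binds more loosely than `+`, so the `if ... else ""` covers the whole concatenation rather than only the text encoder part. The sketch below reproduces the expression and contrasts it with an explicit version. Both function names are hypothetical, and whether dropping the VAE line is intended is not stated in the commit.

```python
def summary_entry(vae_dtype, text_encoder_dtype):
    """Same expression shape as the hunk: the condition guards both f-strings."""
    return (
        f"VAE dtype: {vae_dtype} (amp level O2)" + f"\nText encoder dtype: {text_encoder_dtype} (amp level O2)"
        if text_encoder_dtype is not None
        else ""
    )


def summary_entry_explicit(vae_dtype, text_encoder_dtype):
    """Alternative that always reports the VAE dtype (an assumption about intent)."""
    entry = f"VAE dtype: {vae_dtype} (amp level O2)"
    if text_encoder_dtype is not None:
        entry += f"\nText encoder dtype: {text_encoder_dtype} (amp level O2)"
    return entry


without_encoder = summary_entry("Float16", None)
without_encoder_explicit = summary_entry_explicit("Float16", None)
```

When `text_encoder_dtype` is `None`, the first form yields an empty string, so the VAE dtype is not logged either; the explicit form keeps it.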
@@ -505,6 +516,26 @@ def parse_t2v_train_args(parser):
help="If use_recompute is True, `num_no_recompute` blocks will be removed from the recomputation list."
"This is a positive integer which can be tuned based on the memory usage.",
)
parser.add_argument(
"--vae_keep_gn_fp32",
default=False,
type=str2bool,
help="whether keep GroupNorm in fp32. Defaults to False in inference mode. If training vae, better set it to True",
)
parser.add_argument(
"--vae_precision",
default="fp16",
type=str,
choices=["bf16", "fp16"],
help="what data type to use for vae. Default is `fp16`, which corresponds to ms.float16",
)
parser.add_argument(
"--text_encoder_precision",
default="bf16",
type=str,
choices=["bf16", "fp16"],
help="what data type to use for T5 text encoder. Default is `bf16`, which corresponds to ms.bfloat16",
)
return parser
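For reviewers who want to see how the three new training flags parse, here is a self-contained sketch using only `argparse`. The `str2bool` below is a hypothetical re-implementation of the repo helper of the same name, and `build_parser` is an illustrative wrapper, not the project's `parse_t2v_train_args`.

```python
import argparse


def str2bool(value):
    """Hypothetical stand-in for the repo's str2bool helper."""
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    """Register the three flags with the defaults shown in the diff."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--vae_keep_gn_fp32", default=False, type=str2bool)
    parser.add_argument("--vae_precision", default="fp16", type=str, choices=["bf16", "fp16"])
    parser.add_argument("--text_encoder_precision", default="bf16", type=str, choices=["bf16", "fp16"])
    return parser


defaults = build_parser().parse_args([])
custom = build_parser().parse_args(["--vae_keep_gn_fp32", "True", "--vae_precision", "bf16"])
print(custom.vae_keep_gn_fp32, custom.vae_precision, custom.text_encoder_precision)
```

Note that the training default for `--vae_precision` is `fp16`, while the line this commit replaces hard-coded `ms.bfloat16`, so `--vae_precision bf16` would be needed to keep the previous VAE dtype.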


@@ -44,4 +44,5 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --
--use_recompute True \
--dataset_sink_mode True \
--use_parallel True \
--parallel_mode "optim" \
--parallel_mode "data" \
--num_no_recompute 6 \
@@ -3,6 +3,8 @@ export MS_ENABLE_NUMA=0
export MS_MEMORY_STATISTIC=1
export GLOG_v=2

export HCCL_BUFFSIZE=1 # reduce memory consumption when dataset_sink_mode=True, may degrade speed
export MS_DATASET_SINK_QUEUE=2 # reduce memory consumption when dataset_sink_mode=True, may degrade speed
# hyper-parameters
image_size=512 # the image size of frames, same to image height and image width
use_image_num=16 # to include n number of images in an input sample
@@ -40,7 +42,9 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --
--model_max_length 300 \
--clip_grad True \
--use_image_num $use_image_num \
--enable_tiling \
--use_recompute True \
--dataset_sink_mode False \
--dataset_sink_mode True \
--use_parallel True \
--parallel_mode "data" \
--parallel_mode "optim" \
--max_device_memory "59GB" \
