Latte [1] is a novel Latent Diffusion Transformer designed for video generation. It is built on DiT, a diffusion transformer model for image generation; for an introduction to DiT [2], please refer to the README of DiT.
Latte first uses a VAE (Variational AutoEncoder) to compress the video data into a latent space, and then extracts spatial-temporal tokens from the latent codes. Similar to DiT, it stacks multiple transformer blocks to model video diffusion in the latent space. How to design the spatial and temporal blocks then becomes a major question.
Through experiments and analysis, the authors found that the best practice is structure (a) in the image below, which stacks spatial blocks and temporal blocks alternately to model spatial attention and temporal attention in turn.
Figure 1. The Structure of Latte and Latte transformer blocks. [1]
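To make the alternating design concrete, below is a minimal NumPy sketch (not the actual MindONE implementation; the shapes and the identity `attention_block` placeholder are illustrative assumptions) showing how the latent spatio-temporal tokens are regrouped so that spatial blocks attend within a frame and temporal blocks attend across frames:

```python
import numpy as np

# Illustrative token layout: batch B, frames T, spatial tokens S per frame, hidden size D.
B, T, S, D = 2, 16, 64, 32
tokens = np.random.randn(B, T, S, D).astype(np.float32)

def attention_block(x):
    # Placeholder for a full transformer block (self-attention + MLP over axis 1).
    return x

def spatial_block(x):
    # Fold frames into the batch so attention runs over the S tokens of each frame.
    b, t, s, d = x.shape
    return attention_block(x.reshape(b * t, s, d)).reshape(b, t, s, d)

def temporal_block(x):
    # Fold spatial positions into the batch so attention runs over the T frames
    # at each spatial location.
    b, t, s, d = x.shape
    y = attention_block(x.transpose(0, 2, 1, 3).reshape(b * s, t, d))
    return y.reshape(b, s, t, d).transpose(0, 2, 1, 3)

# Variant (a): spatial and temporal blocks stacked alternately.
for _ in range(4):
    tokens = temporal_block(spatial_block(tokens))
print(tokens.shape)  # (2, 16, 64, 32)
```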
Similar to DiT, Latte supports unconditional video generation and class-label-conditioned video generation. In addition, it supports generating videos from text captions.
In this tutorial, we will introduce how to run inference and training experiments using MindONE.
This tutorial includes:
- Pretrained checkpoints conversion;
- Unconditional video sampling with pretrained Latte checkpoints;
- Training unconditional Latte on the Sky Timelapse dataset, with support for (1) training with videos and (2) training with embedding cache;
- Mixed precision: support for (1) Float16 and (2) BFloat16 (set patch_embedder to "linear");
- Standalone training and distributed training;
- Text-to-Video Latte inference and training.
```shell
pip install -r requirements.txt
```

`decord` is required for video generation. In case the `decord` package is not available in your environment, try `pip install eva-decord`.
Instructions for installing ffmpeg and decord on EulerOS:

1. Install ffmpeg 4, referring to https://ffmpeg.org/releases

```shell
wget https://ffmpeg.org/releases/ffmpeg-4.0.1.tar.bz2 --no-check-certificate
tar -xvf ffmpeg-4.0.1.tar.bz2
mv ffmpeg-4.0.1 ffmpeg
cd ffmpeg
./configure --enable-shared  # --enable-shared is needed for sharing libavcodec with decord
make -j 64
make install
```
2. Install decord, referring to https://github.com/dmlc/decord?tab=readme-ov-file#install-from-source

```shell
git clone --recursive https://github.com/dmlc/decord
cd decord
rm -rf build && mkdir build && cd build
cmake .. -DUSE_CUDA=0 -DCMAKE_BUILD_TYPE=Release
make -j 64
make install
cd ../python
python3 setup.py install --user
```
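After the build completes, you can quickly verify that decord is importable and can decode a video; the video path below is a placeholder, so point it at any local file:

```python
# Quick sanity check for the decord installation (the video path is a placeholder).
from decord import VideoReader, cpu

vr = VideoReader("path/to/any/video.mp4", ctx=cpu(0))
print("decodable frames:", len(vr))
print("first frame shape:", vr[0].shape)  # (H, W, 3)
```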
We refer to the official repository of Latte for downloading the pretrained checkpoints. The checkpoint files trained on FaceForensics, SkyTimelapse, Taichi-HD and UCF101 (256x256) can be downloaded from Hugging Face.
After downloading the `{}.pt` file, please place it under the `models/` folder, and then run `tools/latte_converter.py`. For example, to convert `models/skytimelapse.pt`, you can run:

```shell
python tools/latte_converter.py --source models/skytimelapse.pt --target models/skytimelapse.ckpt
```
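Optionally, you can sanity-check the converted file by loading it with MindSpore and listing a few parameter names. This is only an inspection sketch and not part of the conversion itself:

```python
# Inspect the converted MindSpore checkpoint produced by tools/latte_converter.py.
import mindspore as ms

params = ms.load_checkpoint("models/skytimelapse.ckpt")  # dict: name -> Parameter
print("number of parameters:", len(params))
for name in list(params)[:5]:
    print(name, tuple(params[name].shape))
```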
Please also download the VAE checkpoint from the stabilityai repository on Hugging Face, and convert it by running:

```shell
python tools/vae_converter.py --source path/to/vae/ckpt --target models/sd-vae-ft-mse.ckpt
```
For example, to run inference with the `skytimelapse.ckpt` model at a 256x256 image size on Ascend devices, you can use:

```shell
python sample.py -c configs/inference/sky.yaml
```
Some of the generated results are shown here:
Figure 2. The generated videos of the pretrained model converted from the torch checkpoint.
Now, we support training the Latte model on the Sky Timelapse dataset, a video dataset which can be downloaded from https://github.com/weixiong-ur/mdgan.
After uncompressing the downloaded file, you will get a folder named `sky_train/` which contains all training video frames. The folder structure is similar to:
```text
sky_train/
├── video_name_0/
|   ├── frame_id_0.jpg
|   ├── frame_id_1.jpg
|   └── ...
├── video_name_1/
└── ...
```
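If you want to double-check that the extracted dataset matches this layout, a small script like the following (plain Python; the root path is a placeholder) counts the video folders and frames:

```python
# Sanity-check the sky_train/ layout: one sub-folder per video, JPEG frames inside.
import os

root = "/absolute/path/to/sky_train"  # placeholder: set to your dataset path
clips = [d for d in sorted(os.listdir(root)) if os.path.isdir(os.path.join(root, d))]
frame_counts = [
    len([f for f in os.listdir(os.path.join(root, d)) if f.endswith(".jpg")])
    for d in clips
]
print(f"{len(clips)} video folders, {sum(frame_counts)} frames in total")
print("shortest clip:", min(frame_counts), "frames")
```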
First, edit the configuration file `configs/training/data/sky_video.yaml`: change `data_folder` from `""` to the absolute path of `sky_train/`.
Then, you can start standalone training on Ascend devices using:

```shell
python train.py -c configs/training/sky_video.yaml
```

To start training on GPU devices, simply append `--device_target GPU` to the command above.
The default training configuration trains the Latte model from scratch with amp level `O2` (mixed precision). See `configs/training/sky_video.yaml` for the batch size and other details.
To accelerate training, `dataset_sink_mode: True` is set in the configuration file by default. You can also set `enable_flash_attention: True` to further accelerate training.
After training, the checkpoints are saved under `output_dir/ckpt/`. To run inference with a trained checkpoint, change `checkpoint` in `configs/inference/sky.yaml` to the path of that checkpoint, and then run `python sample.py -c configs/inference/sky.yaml`.
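If you are unsure which checkpoint to use, the newest one under `output_dir/ckpt/` can be located programmatically; the snippet below is a small helper sketch that assumes the default output directory layout:

```python
# Find the most recently written checkpoint (default output layout assumed).
import glob
import os

ckpts = glob.glob("output_dir/ckpt/*.ckpt")
latest = max(ckpts, key=os.path.getmtime)
print("set `checkpoint` in configs/inference/sky.yaml to:", latest)
```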
The number of epochs is set to a large value to ensure convergence; you can terminate training once the results are good enough. For example, we took the checkpoint trained for 1700 epochs (about 500k steps), whose generated videos are shown below.
Figure 3. The generated videos of the Latte model trained for 1700 epochs (about 500k steps).
We can accelerate training by caching the embeddings of the dataset before running the training script. This takes three steps:
- Step 1: Cache the embeddings into a cache folder. See the following example for how to cache the embeddings; this step can take a while.
To cache embeddings for the Sky Timelapse dataset, first make sure the `data_path` in `configs/training/sky_video.yaml` is correctly set to the folder named `sky_train/`.

Then you can start saving the embeddings using:

```shell
python tools/embedding_cache.py --config configs/training/sky_video.yaml --cache_folder path/to/cache/folder --cache_file_type numpy
```
You can also change `cache_file_type` to `mindrecord` to save the embeddings in `.mindrecord` files.

In general, we recommend the `mindrecord` file type because it is supported by `MindDataset`, which better accelerates data loading. However, the Sky Timelapse dataset contains some extra-long videos, and caching their embeddings in `mindrecord` files risks exceeding the maximum page size of the MindRecord writer. Therefore, we recommend the `numpy` file type for this dataset.
The embedding caching process can take a while depending on the size of the video dataset. Some exceptions may be thrown during the process. If an unexpected exception is thrown, the program will be stopped and the embedding cache writer's status will be printed on the screen:

```text
Start Video Index: 0.  # the start of the video index to be processed
Saving Attempts: 0: save 120 videos, failed 0 videos.  # the number of saved video files
```
In this case, you can resume embedding caching by appending `--resume_cache_index 120` when re-running `python tools/embedding_cache.py`; it will continue caching from the video at index 120 instead of starting over.
For more usage information, please run `python tools/embedding_cache.py -h`.
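Before moving on to Step 2, it can help to spot-check the cached files. The snippet below assumes the cache was written with `--cache_file_type numpy`; the exact file naming and array layout produced by `tools/embedding_cache.py` may differ, so treat this purely as an inspection sketch:

```python
# Spot-check cached numpy embeddings (file naming and array layout are assumptions).
import glob
import numpy as np

cache_folder = "path/to/cache/folder"  # the folder passed to --cache_folder
files = sorted(glob.glob(f"{cache_folder}/**/*.np[yz]", recursive=True))
print("cached files:", len(files))

sample = np.load(files[0], allow_pickle=True)
if hasattr(sample, "files"):   # .npz archive: print every stored array
    for key in sample.files:
        print(key, sample[key].shape, sample[key].dtype)
else:                          # plain .npy array
    print(sample.shape, sample.dtype)
```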
- Step 2: Change the dataset configuration file's `data_folder` to the cache folder path.

After the embeddings have been cached, edit `configs/training/data/sky_numpy_video.yaml` and change `data_folder` to the folder where the cached embeddings are stored.
- Step 3: Run the training script.
You can start training on the cached embedding dataset of Sky Timelapse using:

```shell
python train.py -c configs/training/sky_numpy_video.yaml
```
Note that in `sky_numpy_video.yaml`, we use a larger number of frames than in `sky_video.yaml` (num_frames=16 and stride=3); embedding caching allows us to train Latte to generate more frames at a higher frame rate.
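As a quick illustration of what these two parameters conventionally mean (check `sky_numpy_video.yaml` for the values actually used there), sampling a clip with `num_frames=16` and `stride=3` picks every third source frame:

```python
# Conventional meaning of num_frames / stride for video sampling (illustration only).
num_frames, stride = 16, 3
start = 0  # index of the first sampled frame
indices = [start + i * stride for i in range(num_frames)]
print(indices)                                    # [0, 3, 6, ..., 45]
print("source frames spanned:", indices[-1] + 1)  # 46
```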
Due to the memory limit, the local batch size is kept small; see `configs/training/sky_numpy_video.yaml` for the exact value.

In case of OOM, set `enable_flash_attention: True` in `configs/training/sky_numpy_video.yaml`; it reduces the memory cost and also accelerates training.
Taking 4-card distributed training as an example, you can start it using:

```shell
export MS_ASCEND_CHECK_OVERFLOW_MODE="INFNAN_MODE"
mpirun -n 4 python train.py \
    -c path/to/configuration/file \
    --use_parallel True
```
where the configuration file can be any of the `.yaml` files in the `configs/training/` folder.
If you have the rank table of the Ascend devices, you can take `scripts/run_distributed_sky_numpy_video.sh` as a reference and start 4-card distributed training using:

```shell
bash scripts/run_distributed_sky_numpy_video.sh path/to/rank/table 0 4
```
The first number (`0`) indicates the start index of the training devices, and the second number (`4`) indicates the total number of distributed processes to launch.
The training speed of the experiments at a 256x256 image size is summarized in the following table:
Cards | Recompute | Dataset Sink mode | Embedding Cache | Train. imgs/s |
---|---|---|---|---|
1 | OFF | ON | OFF | 62.3 |
1 | ON | ON | ON | 93.6 |
4 | ON | ON | ON | 368.3 |
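As a quick sanity check on these numbers, the 4-card run scales almost linearly against the single-card run with the same settings:

```python
# Scaling efficiency computed from the throughput table above.
single_card = 93.6  # imgs/s: 1 card, recompute + sink mode + embedding cache
four_cards = 368.3  # imgs/s: 4 cards, same settings
print(f"per-card throughput: {four_cards / 4:.1f} imgs/s")
print(f"scaling efficiency vs. 1 card: {four_cards / (4 * single_card):.1%}")  # ~98%
```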
[1] Xin Ma, Yaohui Wang, Gengyun Jia, Xinyuan Chen, Ziwei Liu, Yuan-Fang Li, Cunjian Chen, Yu Qiao: Latte: Latent Diffusion Transformer for Video Generation. CoRR abs/2401.03048 (2024)
[2] W. Peebles and S. Xie, “Scalable diffusion models with transformers,” in Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4195–4205, 2023