
Diffusion-LM

Figure from the original paper.

This is an unofficial implementation of Diffusion-LM. For the official implementation, please refer to the authors' repository.

🛠️ Setup

If you are using Poetry, run the following command to install the required dependencies:

poetry install

Next, activate your virtual environment.

poetry shell

You can find more details about the required packages in pyproject.toml.

After that, initialize an 🤗 Accelerate environment with:

accelerate config

Alternatively, to create a default Accelerate configuration without answering the interactive questions about your environment, use:

accelerate config default

🚀 Run

Diffusion-LM Training

This repository allows you to train Diffusion-LM on the E2E dataset with the following command:

accelerate launch scripts/train.py --expn foo
Available arguments
  • --expn (-e): The experiment name, used as the basename of the output directory. If omitted, the directory name is derived from the current time.
  • --wandb (-w): Whether to use the Weights & Biases tracker.
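The two training options can be sketched with Python's argparse. This is a hypothetical reconstruction, not the actual scripts/train.py: the timestamp format, the store_true behavior of --wandb, and the assumption that runs are written to output/<expn> (inferred from the output/foo paths used in the commands that follow) are all illustrative.

```python
import argparse
import os
from datetime import datetime


def parse_train_args(argv=None):
    """Hypothetical sketch of the training CLI; the real script may differ."""
    parser = argparse.ArgumentParser(description="Train Diffusion-LM on E2E")
    # -e/--expn: experiment name, used as the basename of the output directory.
    parser.add_argument("-e", "--expn", type=str, default=None)
    # -w/--wandb: assumed to be an on/off flag enabling the W&B tracker.
    parser.add_argument("-w", "--wandb", action="store_true")
    args = parser.parse_args(argv)
    if args.expn is None:
        # Fallback name derived from the current time (format is an assumption).
        args.expn = datetime.now().strftime("%Y%m%d-%H%M%S")
    # Assumed parent directory "output/", matching `--output output/foo` below.
    args.output_dir = os.path.join("output", args.expn)
    return args
```

For example, `parse_train_args(["--expn", "foo", "-w"])` yields an output directory of `output/foo` with tracking enabled, while calling it with no arguments falls back to a timestamped name.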

Classifier Training for Plug-and-Play Control

This repository also supports training GPT2 classifiers for plug-and-play control of semantic content (the Semantic Content control task).

accelerate launch scripts/clf_train.py --output output/foo
Available arguments
  • -o, --output: The directory where the training results will be saved.
  • -mc, --model_ckpt (default='checkpoints/pytorch_model_1.bin'): Path to the Diffusion-LM checkpoint, relative to the --output directory.
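Because checkpoint options are interpreted relative to the --output directory, the path actually loaded is the run directory joined with the option value. A minimal sketch of that resolution, assuming plain os.path.join semantics (resolve_ckpt and the DEFAULT_* constants are hypothetical names, not part of this repository; the default values are the documented ones):

```python
import os

# Documented defaults of -mc/--model_ckpt and -cc/--clf_ckpt.
DEFAULT_MODEL_CKPT = "checkpoints/pytorch_model_1.bin"
DEFAULT_CLF_CKPT = "classifier/pytorch_model.bin"


def resolve_ckpt(output_dir: str, ckpt: str) -> str:
    """Resolve a checkpoint path relative to --output (hypothetical helper).

    os.path.join discards output_dir when ckpt is already absolute,
    so absolute paths pass through unchanged.
    """
    return os.path.normpath(os.path.join(output_dir, ckpt))
```

With `--output output/foo`, the default Diffusion-LM checkpoint resolves to `output/foo/checkpoints/pytorch_model_1.bin` (on POSIX systems).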

Sampling

Conditional Sampling with Plug-and-Play Control

After training the Diffusion-LM and the GPT2 classifier, you can perform conditional sampling.

accelerate launch scripts/sample.py --output output/foo --control_label 'food : Japanese'
Available arguments
  • -o, --output: The directory where the training results were saved.
  • -n, --n_samples (default=16): The number of samples (also used as the batch size).
  • -mc, --model_ckpt (default='checkpoints/pytorch_model_1.bin'): Path to the Diffusion-LM checkpoint, relative to the --output directory.
  • -ud, --use_ddpm (default=False): Whether to use DDPM sampling (the default is DDIM).
  • -cc, --clf_ckpt (default='classifier/pytorch_model.bin'): Path to the classifier checkpoint, relative to the --output directory.
  • -cl, --control_label (default=None): Label for plug-and-play control.
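Plug-and-play control, as described in the Diffusion-LM paper, steers each denoising step with classifier gradients: the latent x_{t-1} is updated to increase λ·log p(x_{t-1} | x_t) + log p(c | x_{t-1}). The first (fluency) term keeps the latent close to what the diffusion model proposes, and the second pushes it toward the control label c. The toy sketch below is not the code in scripts/sample.py: it uses a Gaussian for the fluency term, a logistic function standing in for the GPT2 classifier, and plain gradient ascent; control_step and its hyperparameters (lam, lr, n_steps) are hypothetical.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def control_step(x, mu, sigma, w, b, lam=0.01, lr=0.1, n_steps=3):
    """Toy plug-and-play update of one latent vector x_{t-1} (hypothetical helper).

    Performs gradient ascent on
        lam * log N(x; mu, sigma^2 I) + log sigmoid(w . x + b)
    where mu is the diffusion model's proposed mean for x_{t-1} (fluency term)
    and the logistic function stands in for the classifier p(c | x_{t-1}).
    """
    x = list(x)
    for _ in range(n_steps):
        # Classifier probability of the target label at the current latent.
        p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
        grad = [
            # Gradient of lam * log N(x; mu, sigma^2 I) is -lam * (x - mu) / sigma^2;
            # gradient of log sigmoid(w . x + b) is (1 - p) * w.
            -lam * (xi - mi) / sigma**2 + (1.0 - p) * wi
            for xi, mi, wi in zip(x, mu, w)
        ]
        x = [xi + lr * gi for xi, gi in zip(x, grad)]
    return x
```

Raising lam keeps the latent closer to mu (favoring fluency over control), while lowering it lets the classifier term dominate.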

Unconditional Sampling

For unconditional sampling, only the Diffusion-LM needs to be trained; the GPT2 classifier is not required. Omit --control_label:

accelerate launch scripts/sample.py --output output/foo
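The --use_ddpm flag switches between two standard reverse-step rules. The sketch below writes them for a scalar latent and a denoiser that outputs a prediction of x_0, with abar_t and abar_prev denoting the cumulative products ᾱ_t and ᾱ_{t-1} of the noise schedule. These are the textbook DDPM (ancestral) and DDIM (η = 0) updates, not code taken from scripts/sample.py, which may differ in its schedule and step skipping; the function names are hypothetical.

```python
import math
import random


def ddpm_step(x_t, x0_pred, abar_t, abar_prev, rng):
    """Stochastic (ancestral) DDPM step x_t -> x_{t-1} given a predicted x_0."""
    alpha_t = abar_t / abar_prev  # per-step alpha_t = abar_t / abar_{t-1}
    beta_t = 1.0 - alpha_t
    # Mean and variance of the Gaussian posterior q(x_{t-1} | x_t, x_0).
    mean = (
        (math.sqrt(abar_prev) * beta_t / (1.0 - abar_t)) * x0_pred
        + (math.sqrt(alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t)) * x_t
    )
    var = beta_t * (1.0 - abar_prev) / (1.0 - abar_t)
    return mean + math.sqrt(var) * rng.gauss(0.0, 1.0)


def ddim_step(x_t, x0_pred, abar_t, abar_prev):
    """Deterministic DDIM step (eta = 0) given a predicted x_0."""
    # Noise implied by x_t and the predicted x_0.
    eps = (x_t - math.sqrt(abar_t) * x0_pred) / math.sqrt(1.0 - abar_t)
    # Re-noise the prediction to the earlier level, reusing the same noise.
    return math.sqrt(abar_prev) * x0_pred + math.sqrt(1.0 - abar_prev) * eps
```

Because the DDIM step is deterministic and abar_prev may correspond to any earlier timestep, DDIM can skip timesteps and sample in fewer steps, which is a common reason to prefer it over ancestral DDPM sampling.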

If you find this repository helpful, please consider giving a star ⭐!