This is the official repository for Saliency Unlearning for Stable Diffusion. The code structure of this project is adapted from the ESD codebase.
- To get started, clone the original Stable Diffusion repository (Link).
- Then download the files from our repository into the `stable-diffusion` main directory. This replaces the `ldm` folder of the original repo with our custom `ldm` directory.
- Download the weights from here and move them to `SD/models/ldm/`.
- [Only for training] To convert your trained models to diffusers, download the diffusers UNet config from here.
The unlearned weights for NSFW and object forgetting are available here.
- First, we need to generate a saliency map for unlearning.

  ```
  python train-scripts/generate_mask.py --ckpt_path 'models/ldm/stable-diffusion-v1/sd-v1-4-full-ema.ckpt' --classes {label} --device '0'
  ```

  This will save the saliency map in `SD/mask/{label}`.
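The idea behind the mask, as a minimal NumPy sketch (not the actual `generate_mask.py` implementation): accumulate the gradient of the forgetting loss, then mark as salient the weights with the largest gradient magnitude. The `0.5` in the saved filename `with_0.5.pt` suggests a 50% threshold; `saliency_mask` and the toy gradients below are hypothetical names for illustration.

```python
import numpy as np

def saliency_mask(grads, keep_ratio=0.5):
    """Binary mask over parameters: 1 where |gradient| is in the global
    top `keep_ratio` fraction, 0 elsewhere. Masked-in weights are the
    "salient" ones that get updated during unlearning."""
    flat = np.concatenate([np.abs(g).ravel() for g in grads.values()])
    # Global threshold at the (1 - keep_ratio) quantile of |gradient|.
    thresh = np.quantile(flat, 1.0 - keep_ratio)
    return {name: (np.abs(g) >= thresh).astype(np.float32)
            for name, g in grads.items()}

# Toy "accumulated gradients of the forgetting loss" for two layers:
grads = {"layer1": np.array([0.1, 2.0, 0.3, 1.5]),
         "layer2": np.array([0.05, 0.9, 3.0, 0.2])}
mask = saliency_mask(grads, keep_ratio=0.5)  # keeps the 4 largest of 8 weights
```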
- Forgetting training with Saliency-Unlearning:

  ```
  python train-scripts/random_label.py --train_method full --alpha 0.5 --lr 1e-5 --epochs 5 --class_to_forget {label} --mask_path 'mask/{label}/with_0.5.pt' --device '0'
  ```

  This should create another folder in `SD/model`. You can experiment with forgetting different class labels using the `--class_to_forget` flag, but here we consider forgetting class 0 (tench).
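Conceptually, `random_label.py` fine-tunes the model to associate the forgotten class with random other labels, while the saliency mask restricts which weights change. A minimal sketch of the masked update, assuming the mask and gradients are dicts keyed by parameter name (`masked_step` and the toy values are hypothetical, not the repo's code):

```python
import numpy as np

def masked_step(params, grads, mask, lr=1e-5):
    """One unlearning step: apply the gradient only where the saliency
    mask is 1, leaving all other weights untouched."""
    return {name: params[name] - lr * mask[name] * grads[name]
            for name in params}

params = {"w": np.array([1.0, 1.0, 1.0])}
grads  = {"w": np.array([10.0, 10.0, 10.0])}  # e.g. grad of the random-label loss
mask   = {"w": np.array([1.0, 0.0, 1.0])}     # from generate_mask.py
new = masked_step(params, grads, mask, lr=0.1)  # only masked-in weights move
```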
- Forgetting training with ESD:

  Edit `train-scripts/train-esd.py` and change the default argparser values to suit your setup (especially the config paths). For `train_method`, pick one of `'xattn'`, `'noxattn'`, `'selfattn'`, or `'full'`.

  ```
  python train-scripts/train-esd.py --prompt 'your prompt' --train_method 'your choice of training' --devices '0,1'
  ```
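As a rough guide to what the `train_method` options select, here is a sketch based on the common ESD convention that `attn1` modules are self-attention and `attn2` are cross-attention; the exact filters in `train-esd.py` may differ (e.g. `'noxattn'` may also exclude time-embedding layers), and `select_params` is an illustrative name:

```python
def select_params(names, train_method):
    """Pick which UNet parameter names to train, following the rough
    ESD convention: 'attn1' = self-attention, 'attn2' = cross-attention."""
    if train_method == "full":
        return list(names)
    if train_method == "xattn":      # cross-attention only
        return [n for n in names if "attn2" in n]
    if train_method == "selfattn":   # self-attention only
        return [n for n in names if "attn1" in n]
    if train_method == "noxattn":    # everything except cross-attention
        return [n for n in names if "attn2" not in n]
    raise ValueError(f"unknown train_method: {train_method}")

names = ["block1.attn1.to_q.weight",
         "block1.attn2.to_k.weight",
         "block1.ff.weight"]
chosen = select_params(names, "xattn")  # only the attn2 parameter
```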
- To use `eval-scripts/generate-images.py` you need a CSV file with the columns `prompt`, `evaluation_seed`, and `case_number`. (Sample data in `data/`.)
- To generate multiple images per prompt, use the argument `num_samples`; it defaults to 10.
- The path to the model can be customised in the script.
- Note that the current version requires the model to be saved as `SD/model/compvis-<based on hyperparameters>/diffusers-<based on hyperparameters>.pt`.

  ```
  python eval-scripts/generate-images.py --prompts_path 'prompts/imagenette.csv' --save_path 'evaluation_folder/' --model_name {model} --device 'cuda:0'
  ```
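For example, a prompts CSV in the expected shape can be written with the standard library (`my_prompts.csv` and the prompts below are made-up examples; the sample files in `data/` are authoritative for the exact format):

```python
import csv

# One row per prompt: a fixed seed for reproducibility and a case_number
# used to identify the output images.
rows = [
    {"prompt": "an image of a tench",  "evaluation_seed": 42, "case_number": 0},
    {"prompt": "an image of a church", "evaluation_seed": 42, "case_number": 1},
]
with open("my_prompts.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["prompt", "evaluation_seed", "case_number"])
    writer.writeheader()
    writer.writerows(rows)
```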
- FID
  - First, we need to select some images from Imagenette as the real images.
  - Then, we can compute FID between the real images and the generated images.

    ```
    python eval-scripts/compute-fid.py --folder_path {images_path}
    ```
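For reference, FID is the Fréchet distance between Gaussians fitted to Inception features of the two image sets: ||mu1 - mu2||^2 + Tr(S1 + S2 - 2(S1 S2)^{1/2}). Below is a NumPy sketch of that formula on precomputed feature arrays; the script itself presumably extracts Inception features first, and `fid` here is illustrative, not the repo's implementation.

```python
import numpy as np

def psd_sqrt(a):
    """Matrix square root of a symmetric PSD matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(a)
    vals = np.clip(vals, 0.0, None)  # guard against tiny negative eigenvalues
    return vecs @ np.diag(np.sqrt(vals)) @ vecs.T

def fid(feats_real, feats_gen):
    """Frechet distance between Gaussians fitted to two (n_samples, dim)
    feature sets: ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2})."""
    mu1, mu2 = feats_real.mean(0), feats_gen.mean(0)
    s1 = np.cov(feats_real, rowvar=False)
    s2 = np.cov(feats_gen, rowvar=False)
    r1 = psd_sqrt(s1)
    # Tr((S1 S2)^{1/2}) via the symmetric form (S1^{1/2} S2 S1^{1/2})^{1/2}.
    covmean = psd_sqrt(r1 @ s2 @ r1)
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(s1 + s2 - 2.0 * covmean))

rng = np.random.default_rng(0)
same = rng.normal(size=(500, 4))
score = fid(same, same)  # identical feature sets -> distance ~ 0
```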
- Accuracy

  ```
  python eval-scripts/imageclassify.py --prompts_path 'prompts/imagenette.csv' --folder_path {images_path}
  ```
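`imageclassify.py` presumably runs a pretrained classifier over the generated images and reports how often the predicted label matches the prompted class; for the forgotten class, lower accuracy means better unlearning. A toy sketch of that final accuracy computation (`class_accuracy` and the prediction values are made up for illustration):

```python
def class_accuracy(predictions, targets):
    """Fraction of generated images whose predicted label matches the
    class they were prompted with."""
    assert len(predictions) == len(targets)
    hits = sum(p == t for p, t in zip(predictions, targets))
    return hits / len(targets)

# Hypothetical classifier outputs for 8 images prompted with class 0 (tench):
preds   = [0, 217, 0, 482, 491, 217, 566, 569]
targets = [0] * 8
acc = class_accuracy(preds, targets)  # 2 of 8 still classified as tench
```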
- To remove the NSFW concept, we first use SD v1.4 to generate 800 images as Df with the prompt "a photo of a nude person" and store them in `SD/data/nsfw`. We then generate another 800 images as Dr with the prompt "a photo of a person wearing clothes" and store them in `SD/data/not-nsfw`.
- Next, we need to generate a saliency map for the NSFW concept.

  ```
  python train-scripts/generate_mask.py --ckpt_path 'models/ldm/stable-diffusion-v1/sd-v1-4-full-ema.ckpt' --nsfw True --device '0'
  ```

  This will save the saliency map in `SD/mask`.
- Forgetting training with Saliency-Unlearning:

  ```
  python train-scripts/nsfw_removal.py --train_method 'full' --mask_path 'mask/nude_0.5.pt' --device '0'
  ```