-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 36ab80c
Showing
55 changed files
with
10,196 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
models | ||
models/* | ||
|
||
data | ||
data/* | ||
|
||
training-runs | ||
training-runs/* | ||
|
||
slurm/ | ||
slurm/* | ||
|
||
out | ||
out/* | ||
|
||
**/__pycache__ | ||
__pycache__ | ||
.ipynb_checkpoints/ | ||
|
||
*.zip | ||
*.swp | ||
*.pth | ||
*.pt | ||
*.npz | ||
*.tar | ||
*.gz | ||
*.pkl | ||
*.mp4 | ||
*.pyc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
|
||
NVIDIA CORPORATION and its licensors retain all intellectual property | ||
and proprietary rights in and to this software, related documentation | ||
and any modifications thereto. Any use, reproduction, disclosure or | ||
distribution of this software and related documentation without an express | ||
license agreement from NVIDIA CORPORATION is strictly prohibited. | ||
|
||
NVIDIA License | ||
|
||
1. Definitions | ||
|
||
“Licensor” means any person or entity that distributes its Work. | ||
“Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license. | ||
The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. | ||
Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license. | ||
|
||
2. License Grant | ||
|
||
2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. | ||
|
||
3. Limitations | ||
|
||
3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. | ||
|
||
3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. | ||
|
||
3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. | ||
|
||
3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately. | ||
|
||
3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license. | ||
|
||
3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately. | ||
|
||
4. Disclaimer of Warranty. | ||
|
||
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF | ||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. | ||
|
||
5. Limitation of Liability. | ||
|
||
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
<img src="docs/banner.png"> | ||
|
||
#### [[Project]](https://sites.google.com/view/stylegan-t/) [[PDF]](https://arxiv.org/abs/2301.09515) [[Video]](https://www.youtube.com/watch?v=MMj8OTOUIok) | ||
This repository contains the **training code** for our paper "StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis". **We do not provide pretrained checkpoints.** | ||
|
||
by [Axel Sauer](https://axelsauer.com/), [Tero Karras](https://research.nvidia.com/person/tero-karras), [Samuli Laine](https://research.nvidia.com/person/samuli-laine), [Andreas Geiger](https://www.cvlibs.net/), [Timo Aila](https://research.nvidia.com/person/timo-aila) | ||
|
||
## Requirements ## | ||
- Use the following commands with Miniconda3 to create and activate your environment: | ||
``` | ||
conda create --name sgt python=3.9 | ||
conda activate sgt | ||
conda install pytorch=1.9.1 torchvision==0.10.1 pytorch-cuda=11.6 -c pytorch -c nvidia | ||
pip install -r requirements.txt | ||
``` | ||
- GCC 7 or later compilers. The recommended GCC version depends on your CUDA version; see for example, CUDA 11.4 system requirements. | ||
- If you run into problems when setting up the custom CUDA kernels, we refer to the [Troubleshooting docs](https://github.com/NVlabs/stylegan3/blob/main/docs/troubleshooting.md#why-is-cuda-toolkit-installation-necessary) of the StyleGAN3 repo. | ||
|
||
|
||
## Data Preparation ## | ||
|
||
StyleGAN-T can be trained on unconditional and conditional datasets. For small-scale experiments, we recommend zip datasets. When training on datasets with more than 1 million images, we recommend using webdatasets. | ||
|
||
### Zip Dataset | ||
Zip-Datasets are stored in the same format as in the previous iterations of [StyleGAN](https://github.com/NVlabs/stylegan3): uncompressed ZIP archives containing uncompressed PNG files and a metadata file `dataset.json` for labels. Custom datasets can be created from a folder containing images. | ||
|
||
|
||
**CIFAR-10:** Download the [CIFAR-10 python version](https://www.cs.toronto.edu/~kriz/cifar.html) and convert to ZIP archive: | ||
|
||
```.bash | ||
python dataset_tool.py --source downloads/cifar10/cifar-10-python.tar.gz \ | ||
--dest data/cifar10-32x32.zip | ||
``` | ||
|
||
**FFHQ:** Download the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset) as 1024x1024 images and convert to ZIP archive at the same resolution: | ||
|
||
```.bash | ||
python dataset_tool.py --source downloads/ffhq/images1024x1024 \ | ||
--dest data/ffhq1024.zip --resolution 1024x1024 | ||
``` | ||
|
||
**COCO validation set:** The COCO validation set is used for tracking zero-shot FID and CLIP score. First, download the [COCO meta data](https://drive.google.com/file/d/1Xbg36mTJGG68RI_YgSPwhAE2msxgGfaC/view?usp=sharing). Then, run | ||
```.bash | ||
python dataset_tool.py --source downloads/captions_val2014.json \ | ||
--dest data/coco_val256.zip --resolution 256x256 --transform center-crop | ||
``` | ||
|
||
It is recommend to prepare the dataset zip at the highest possible resolution, e.g. for FFHQ, the zip should contain images with 1024x1024 pixels. When training lower-resolution models, the training script can downsample the images on the fly. | ||
|
||
### WebDataset | ||
For preparing webdatasets, we used the excellent [img2dataset](https://github.com/rom1504/img2dataset) tool. Documentation for downloading different datasets can be found [here](https://github.com/rom1504/img2dataset/tree/main/dataset_examples). For our experiments, we used data from the following sources: [CC3M](https://ai.google.com/research/ConceptualCaptions/download), [CC12M](https://github.com/google-research-datasets/conceptual-12m), [YFFC100m](https://huggingface.co/datasets/dalle-mini/YFCC100M_OpenAI_subset), [Redcaps](https://huggingface.co/datasets/red_caps), [LAION-aesthetic-6plus](https://huggingface.co/datasets/ChristophSchuhmann/improved_aesthetics_6plus). | ||
|
||
The joint dataset should have the following structure | ||
```joint_dataset/ | ||
joint_dataset/cc3m/0000.tar | ||
joint_dataset/cc3m/0001.tar | ||
... | ||
joint_dataset/laion_6plus/0000.tar | ||
joint_dataset/laion_6plus/0001.tar | ||
... | ||
``` | ||
## Training ## | ||
|
||
Training StyleGAN-T with full capacity on MYDATASET.zip at a resolution of 64x64 pixels: | ||
|
||
``` | ||
python -m torch.distributed.run --standalone --nproc_per_node 1 train.py \ | ||
--outdir ./training-runs/ --cfg full --data ./data/MYDATASET.zip \ | ||
--img-resolution 64 --batch 128 --batch-gpu 8 --kimg 25000 --metrics fid50k_full | ||
``` | ||
|
||
- The above commands can be parallelized across multiple GPUs by adjusting ```--nproc_per_node```. | ||
- ```--batch``` specifies the overall batch size, ```--batch-gpu``` specifies the batch size per GPU. Be aware that ```--batch-gpu``` is also a hyperparameter as the discriminator uses (local) BatchNorm; We generally recommend ```--batch-gpu``` of 4 or 8. | ||
The training loop will automatically accumulate gradients if you use fewer GPUs until the overall batch size is reached. | ||
- Samples and metrics are saved in ```outdir```. You can inspect ```METRIC_NAME.json``` or run tensorboard in ```training-runs/``` to monitor the training progress. | ||
- The generator will be conditional if the dataset contains text labels; otherwise, it will be unconditional. | ||
- For a webdataset comprised of different subsets, the data path should point to the joint parent directory: ```--data path/to/joint_dataset/``` | ||
|
||
To use the same configuration we used for our ablation study, use ```--cfg lite```. If you want direct control over network parameters, use a custom config. E.g., a smaller models which has 1 residual block, a capacity multiplier of 16384 and a maximum channel count of 256, run | ||
|
||
``` | ||
python -m torch.distributed.run --standalone --nproc_per_node 1 train.py \ | ||
--outdir ./training-runs/ --data ./data/MYDATASET.zip \ | ||
--img-resolution 64 --batch 128 --batch-gpu 8 --kimg 25000 --metrics fid50k_full \ | ||
--cfg custom --cbase 16384 --cmax 256 --res-blocks 1 | ||
``` | ||
|
||
For a description of all input arguments, run ```python train.py --help``` | ||
|
||
### Starting from pretrained checkpoints | ||
|
||
If you want to use a previously trained model, you can start from a checkpoint by specifying its path adding ```--resume PATH_TO_NETWORK_PKL```. If you want to continue training from where you left off in a previous run, you can also specify the number of images processed in that run using ```--resume-kimg XXX```, where XXX is that number. | ||
|
||
### Training modes | ||
|
||
By default, all layers of the generator are trained and the CLIP text encoder is frozen. If you want to train only the text encoder, provide ```--train-mode text-encoder```. | ||
|
||
If you want to do progressive growing, first train a model at 64x64 pixels. Then provide the path to this pretrained network via ```-resume```, the new target resolution via ```--img-resolution``` and use ```--train-mode freeze64``` to freeze the blocks of the 64x64 model and only train the high resolution layers. For example: | ||
|
||
``` | ||
python -m torch.distributed.run --standalone --nproc_per_node 1 train.py \ | ||
--outdir ./training-runs/ --data ./data/MYDATASET.zip \ | ||
--img-resolution 512 --batch 128 --batch-gpu 4 --kimg 25000 \ | ||
--cfg lite --resume PATH_TO_NETWORK_64 --train-mode freeze64 | ||
``` | ||
|
||
## Generating Samples ## | ||
To generate samples with a given network, run | ||
``` | ||
python gen_images.py --network PATH_TO_NETWORK_PKL \ | ||
--prompt 'A painting of a fox in the style of starry night.' \ | ||
--truncation 1 --outdir out --seeds 0-29 | ||
``` | ||
|
||
For a description of all input arguments, run ```python gen_images.py --help``` | ||
|
||
## Quality Metrics ## | ||
To calculate metrics for a specific network snapshot, run | ||
|
||
``` | ||
python calc_metrics.py --metrics METRIC_NAME --network PATH_TO_NETWORK_PKL | ||
``` | ||
|
||
Metric computation is only supported on zip datasets, not webdatasets. The zero-shot COCO metrics expect a ```coco_val256.zip``` to be present in the same folder as the training dataset. Alternatively, one can explicitely set an environment variable as follows: ```export COCOPATH=path/to/coco_val256.zip```. | ||
|
||
To see the available metrics, run ```python calc_metrics.py --help``` | ||
|
||
|
||
## License ## | ||
|
||
Copyright © 2023, NVIDIA Corporation. All rights reserved. | ||
|
||
This work is made available under the [Nvidia Source Code License](https://nvlabs.github.io/stylegan2-ada-pytorch/license.html). | ||
|
||
Excempt are the files ```training/diffaug.py``` and ```networks/vit_utils.py``` which are partially or fully based on third party github repositories. These two files are copyright their respective authors and under their respective licenses; we include the original license and link to the source at the beginning of the files. | ||
|
||
## Development ## | ||
|
||
This is a research reference implementation and is treated as a one-time code drop. As such, we do not accept outside code contributions in the form of pull requests. | ||
|
||
## Citation ## | ||
|
||
```bibtex | ||
@InProceedings{Sauer2023ARXIV, | ||
author = {Axel Sauer and Tero Karras and Samuli Laine and Andreas Geiger and Timo Aila}, | ||
title = {{StyleGAN-T}: Unlocking the Power of {GANs} for Fast Large-Scale Text-to-Image Synthesis}, | ||
journal = {{arXiv.org}}, | ||
volume = {abs/2301.09515}, | ||
year = {2023}, | ||
url = {https://arxiv.org/abs/2301.09515}, | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# NVIDIA CORPORATION and its licensors retain all intellectual property | ||
# and proprietary rights in and to this software, related documentation | ||
# and any modifications thereto. Any use, reproduction, disclosure or | ||
# distribution of this software and related documentation without an express | ||
# license agreement from NVIDIA CORPORATION is strictly prohibited. | ||
|
||
"""Calculate quality metrics for previous training run or pretrained network pickle.""" | ||
|
||
import os | ||
import json | ||
import copy | ||
|
||
import torch | ||
import dill | ||
import click | ||
|
||
import dnnlib | ||
from metrics import metric_main | ||
from metrics import metric_utils | ||
from torch_utils import misc | ||
from torch_utils import custom_ops | ||
from torch_utils import distributed as dist | ||
from torch_utils.ops import conv2d_gradfix | ||
|
||
|
||
def parse_comma_separated_list(s): | ||
if isinstance(s, list): | ||
return s | ||
if s is None or s.lower() == 'none' or s == '': | ||
return [] | ||
return s.split(',') | ||
|
||
|
||
@click.command("cli", context_settings={'show_default': True}) | ||
@click.pass_context | ||
@click.option('--network', 'network_pkl', help='Network pickle filename or URL', type=str, required=True) | ||
@click.option('--metrics', help='Quality metrics', type=parse_comma_separated_list, default='fid50k_full') | ||
@click.option('--data', help='Dataset to evaluate against', type=str) | ||
@click.option('--mirror', help='Enable dataset x-flips', type=bool) | ||
@click.option('--truncation', help='Truncation', type=float, default=1.0) | ||
def calc_metrics( | ||
ctx, | ||
network_pkl: str, | ||
metrics: list, | ||
data: str, | ||
mirror: bool, | ||
truncation: float, | ||
): | ||
"""Calculate quality metrics for previous training run or pretrained network pickle. | ||
Examples: | ||
\b | ||
# Previous training run: look up options automatically, save result to JSONL file. | ||
python calc_metrics.py --metrics=cs10k,fid50k_full \\ | ||
--network=~/training-runs/00000-mydataset@512-custom-gpus1-b4-bgpu2/network-snapshot-000000.pkl | ||
\b | ||
# Pre-trained network pickle: specify dataset explicitly, print result to stdout. | ||
python calc_metrics.py --metrics=fid50k_full --data=~/datasets/mydataset.zip --mirror=1 \\ | ||
--network=~/training-runs/00000-mydataset@512-custom-gpus1-b4-bgpu2/network-snapshot-000000.pkl | ||
\b | ||
General metrics: | ||
fid50k_full Frechet inception distance against the full dataset (50k generated samples). | ||
fid10k_full Frechet inception distance against the full dataset (10k generated samples). | ||
cs10k Clip score (10k generated samples). | ||
pr50k3_full Precision and recall againt the full dataset (50k generated samples, neighborhood size=3). | ||
\b | ||
Zero-shot COCO metrics: | ||
fid30k_coco64 Frechet inception distance against the COCO validation set (30k generated samples). | ||
fid30k_coco256 Frechet inception distance against the COCO validation set (30k generated samples). | ||
cs10k_coco Clip score on the COCO validation set (10k generated samples). | ||
""" | ||
|
||
# Init distributed | ||
torch.multiprocessing.set_start_method('spawn') | ||
dist.init() | ||
device = torch.device('cuda') | ||
|
||
# Validate arguments. | ||
G_kwargs=dnnlib.EasyDict(truncation_psi=truncation) | ||
|
||
if not all(metric_main.is_valid_metric(metric) for metric in metrics): | ||
err = ['--metrics can only contain the following values:'] + metric_main.list_valid_metrics() | ||
ctx.fail('\n'.join(err)) | ||
|
||
# Load network. | ||
if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl): | ||
ctx.fail('--network must point to a file or URL') | ||
|
||
dist.print0(f'Loading network from "{network_pkl}"...') | ||
with dnnlib.util.open_url(network_pkl, verbose=True) as f: | ||
network_dict = dill.load(f) | ||
G = network_dict['G_ema'] # subclass of torch.nn.Module | ||
|
||
# Initialize dataset options. | ||
if data is not None: | ||
dataset_kwargs = dnnlib.EasyDict(class_name='training.data_zip.ImageFolderDataset', path=data) | ||
elif network_dict.get('training_set_kwargs') is not None: | ||
dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs']) | ||
else: | ||
ctx.fail('Could not look up dataset options; please specify --data') | ||
|
||
# Finalize dataset options. | ||
dataset_kwargs.resolution = G.img_resolution | ||
dataset_kwargs.use_labels = (G.c_dim != 0) | ||
if mirror is not None: | ||
dataset_kwargs.xflip = mirror | ||
|
||
# Print dataset options. | ||
dist.print0('Dataset options:') | ||
dist.print0(json.dumps(dataset_kwargs, indent=2)) | ||
|
||
# Locate run dir. | ||
run_dir = None | ||
if os.path.isfile(network_pkl): | ||
pkl_dir = os.path.dirname(network_pkl) | ||
if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')): | ||
run_dir = pkl_dir | ||
|
||
# Launch processes. | ||
dist.print0('Launching processes...') | ||
dnnlib.util.Logger(should_flush=True) | ||
if dist.get_rank() != 0: | ||
custom_ops.verbosity = 'none' | ||
|
||
# Configure torch. | ||
torch.backends.cuda.matmul.allow_tf32 = False | ||
torch.backends.cudnn.allow_tf32 = False | ||
conv2d_gradfix.enabled = True | ||
|
||
# Print network summary. | ||
G = copy.deepcopy(G).eval().requires_grad_(False).to(device) | ||
if dist.get_rank() == 0: | ||
z = torch.empty([1, G.z_dim], device=device) | ||
c = torch.empty([1, G.c_dim], device=device) | ||
misc.print_module_summary(G, [z, c]) | ||
|
||
# Calculate each metric. | ||
for metric in metrics: | ||
dist.print0(f'Calculating {metric}...') | ||
|
||
progress = metric_utils.ProgressMonitor(verbose=True) | ||
result_dict = metric_main.calc_metric(metric=metric, G=G, G_kwargs=G_kwargs, dataset_kwargs=dataset_kwargs, | ||
num_gpus=dist.get_world_size(), rank=dist.get_rank(), device=device, progress=progress) | ||
|
||
if dist.get_rank() == 0: | ||
metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=network_pkl) | ||
dist.print0() | ||
|
||
# Done. | ||
dist.print0('Exiting...') | ||
|
||
|
||
if __name__ == "__main__": | ||
calc_metrics() # pylint: disable=no-value-for-parameter |
Oops, something went wrong.