diff --git a/.github/workflows/black-format.yaml b/.github/workflows/black-format.yaml new file mode 100644 index 0000000..40e0ef5 --- /dev/null +++ b/.github/workflows/black-format.yaml @@ -0,0 +1,59 @@ +name: Format Code + +on: + push: + paths: + - '**.py' + - 'pyproject.toml' + - '.github/workflows/black-format.yml' + +jobs: + format: + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ github.head_ref }} + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install formatters + run: | + python -m pip install --upgrade pip + pip install black isort + + - name: Create pyproject.toml + run: | + cat << EOF > pyproject.toml + [tool.black] + line-length = 120 + target-version = ['py310'] + include = '\.pyi?$' + + [tool.isort] + profile = "black" + line_length = 120 + multi_line_output = 3 + include_trailing_comma = true + force_grid_wrap = 0 + use_parentheses = true + ensure_newline_before_comments = true + EOF + + - name: Format code + run: | + isort . + black . + + - name: Commit changes + uses: stefanzweifel/git-auto-commit-action@v5 + with: + commit_message: "style: format code with Black and isort" + commit_user_name: "github-actions[bot]" + commit_user_email: "github-actions[bot]@users.noreply.github.com" + commit_author: "github-actions[bot] " diff --git a/.github/workflows/validate-mamba-env.yaml b/.github/workflows/validate-mamba-env.yaml new file mode 100644 index 0000000..51a4f63 --- /dev/null +++ b/.github/workflows/validate-mamba-env.yaml @@ -0,0 +1,39 @@ +name: Validate Mamba Environment + +on: + workflow_dispatch: + push: + paths: + - 'env.yaml' + - '.github/workflows/validate-mamba-env.yml' + pull_request: + paths: + - 'env.yaml' + - '.github/workflows/validate-mamba-env.yml' + +jobs: + validate-environment: + name: Validate Environment Setup + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Setup Micromamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: env.yaml + init-shell: bash + + - name: Verify Environment + shell: bash -el {0} + run: | + micromamba activate torch-provlae + python -c "import torch; print('PyTorch:', torch.__version__)" + python -c "import torchvision; print('Torchvision:', torchvision.__version__)" + + - name: List Environment Info + if: always() + shell: bash -el {0} + run: | + micromamba list \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f9adf24 --- /dev/null +++ b/.gitignore @@ -0,0 +1,165 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +data/ +output/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..7e4c456 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Shosuke Suzuki + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..c05ef15 --- /dev/null +++ b/README.md @@ -0,0 +1,179 @@ +# pytorch pro-VLAE +![MIT LICENSE](https://img.shields.io/badge/LICENSE-MIT-blue) +[![Format Code](https://github.com/suzuki-2001/pytorch-proVLAE/actions/workflows/black-format.yaml/badge.svg)](https://github.com/suzuki-2001/pytorch-proVLAE/actions/workflows/black-format.yaml) +[![Validate Mamba Environment](https://github.com/suzuki-2001/pytorch-proVLAE/actions/workflows/validate-mamba-env.yaml/badge.svg)](https://github.com/suzuki-2001/pytorch-proVLAE/actions/workflows/validate-mamba-env.yaml) + +
+
+This is a PyTorch implementation of the paper [PROGRESSIVE LEARNING AND DISENTANGLEMENT OF HIERARCHICAL REPRESENTATIONS](https://openreview.net/forum?id=SJxpsxrYPS) by Zhiyuan et al., [ICLR 2020](https://iclr.cc/virtual_2020/poster_SJxpsxrYPS.html).
+The code is based on the official TensorFlow implementation [here](https://github.com/Zhiyuan1991/proVLAE).
+
+ + + + + + + + + + +
+
+⬆︎ Traverse the latent space of proVLAE on four datasets: 3D Shapes (top-left), MNIST (top-right), along with preliminary experimental results on 3DIdent (bottom-left) and MPI3D (bottom-right), for which hyperparameter tuning is still in progress.
+
+&nbsp;
+
+This implementation enables flexible configuration of the VAE architecture through dynamic size management for arbitrary input image sizes, automatic calculation of the maximum possible number of ladder layers from the input dimensions, and adaptive handling of the latent space dimensionality. The model automatically adjusts its network depth and feature-map sizes by computing appropriate intermediate dimensions, while enforcing a minimum feature-map size and correct dimension handling in flatten/unflatten operations (see the sketch following the example training commands in the Usage section). These adaptations allow users to freely specify `z_dim`, the number of ladder layers, and the input image size.
+
+![figure-1 in pro-vlae paper](./md/provlae-figure1.png)
+
+> Figure 1: Progressive learning of hierarchical representations. White blocks and solid lines are VAE
+> models at the current progression. α is a fade-in coefficient for blending in the new network component. Gray circles and dash line represents (optional) constraining of the future latent variables.
+
+⬆︎ Ladder architecture and progressive learning of hierarchical representations.
+
+&nbsp;
+
+## Installation
+We recommend using [mamba](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) (via [miniforge](https://github.com/conda-forge/miniforge)) for faster installation of the dependencies, but you can also use [conda](https://docs.anaconda.com/miniconda/miniconda-install/).
+```bash
+mamba env create -f env.yaml # or conda
+mamba activate torch-provlae
+```
+
+&nbsp;
+
+## Usage
+You can train pytorch pro-VLAE with the following commands. Sample hyperparameters and training configurations are provided in the [scripts directory](./scripts/).
+If a checkpoint is available, setting the `mode` argument to "visualize" lets you inspect only the latent traversal; keep the parameter settings the same as those used for the checkpoint.
+
+
+```bash
+# all progressive training steps
+python train.py \
+    --dataset shapes3d \
+    --mode seq_train \
+    --batch_size 100 \
+    --num_epochs 15 \
+    --learning_rate 5e-4 \
+    --beta 15 \
+    --z_dim 3 \
+    --hidden_dim 64 \
+    --fade_in_duration 5000 \
+    --optim adamw \
+    --image_size 64 \
+    --chn_num 3 \
+    --output_dir ./output/shapes3d/
+
+# training with distributed data parallel
+torchrun --nproc_per_node=2 train_ddp.py \
+    --distributed True \
+    --mode seq_train \
+    --dataset ident3d \
+    --num_ladders 3 \
+    --batch_size 128 \
+    --num_epochs 30 \
+    --learning_rate 5e-4 \
+    --beta 1 \
+    --z_dim 3 \
+    --coff 0.5 \
+    --hidden_dim 64 \
+    --fade_in_duration 5000 \
+    --output_dir ./output/ident3d/ \
+    --optim adamw
+```
+
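+The number of ladders the model can actually use is limited by the input resolution. The sketch below restates the size bookkeeping done in `ProVLAE.__init__` (`models/provlae.py`) as a standalone helper; the function name is illustrative and not part of the training API.
+
+```python
+from math import ceil, log2
+
+def provlae_size_plan(image_size: int, num_ladders: int, min_size: int = 4):
+    """Rough restatement of the size bookkeeping in ProVLAE.__init__."""
+    target_size = 2 ** ceil(log2(image_size))       # pad input up to the nearest power of 2
+    max_steps = int(log2(target_size // min_size))  # how often the feature map can be halved
+    usable_ladders = min(num_ladders, max(max_steps, 1))
+
+    sizes = [target_size]                           # spatial size after each encoder stage
+    for _ in range(usable_ladders + 1):             # +1 for the initial stage
+        sizes.append(max(sizes[-1] // 2, min_size))
+    return target_size, usable_ladders, sizes
+
+# e.g. 64x64 inputs with --num_ladders 3 keep all three ladders:
+# provlae_size_plan(64, 3) -> (64, 3, [64, 32, 16, 8, 4])
+```
+
+If more ladders are requested than the resolution supports, the model clamps the count to this maximum in the same way.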
+
+| Argument | Default | Description |
+|----------|---------------|-------------|
+| `dataset` | "shapes3d" | Dataset to use (mnist, fashionmnist, dsprites, shapes3d, mpi3d, ident3d, celeba, flowers102, dtd, imagenet) |
+| `data_path` | "./data" | Path to dataset storage |
+| `z_dim` | 3 | Dimension of latent variables |
+| `num_ladders` | 3 | Number of ladders (hierarchies) in pro-VLAE |
+| `beta` | 8.0 | β parameter for pro-VLAE |
+| `learning_rate` | 5e-4 | Learning rate |
+| `fade_in_duration` | 5000 | Number of steps for fade-in period |
+| `image_size` | 64 | Input image size |
+| `chn_num` | 3 | Number of input image channels |
+| `train_seq` | 1 | Current training sequence number (`indep_train` mode only) |
+| `batch_size` | 100 | Batch size |
+| `num_epochs` | 1 | Number of epochs |
+| `mode` | "seq_train" | Execution mode ("seq_train", "indep_train", "visualize") |
+| `hidden_dim` | 32 | Hidden layer dimension |
+| `coff` | 0.5 | Coefficient for KL divergence |
+| `output_dir` | "outputs" | Output directory |
+| `compile_mode` | "default" | PyTorch compilation mode |
+| `on_cudnn_benchmark` | True | Enable/disable cuDNN benchmark |
+| `optim` | "adam" | Optimization algorithm (adam, adamw, sgd, lamb, diffgrad, madgrad) |
+
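+For reference, the sketch below restates in simplified form how the `beta`, `coff`, `fade_in_duration`, and `train_seq` arguments above enter the training objective in `ProVLAE.forward` and `fade_in_alpha` (`models/provlae.py`) when the default `pre_kl=True` is used; the standalone helper functions are illustrative only.
+
+```python
+def fade_in_alpha(step: int, fade_in_duration: int) -> float:
+    # Newly added network components are blended in linearly over fade_in_duration steps.
+    return 1.0 if step > fade_in_duration else step / fade_in_duration
+
+def provlae_loss(recon_loss, kl_per_ladder, train_seq: int, beta: float, coff: float):
+    # KL terms from index train_seq - 1 onward ("active" in the source) are weighted
+    # by beta; the remaining ("inactive") terms are softly constrained with coff.
+    active = kl_per_ladder[train_seq - 1:]
+    inactive = kl_per_ladder[:train_seq - 1]
+    return recon_loss + beta * sum(active) + coff * sum(inactive)
+```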
+
+Mode descriptions:
+- `seq_train`: Sequential training from ladder 1 to `num_ladders`
+- `indep_train`: Independent training of the ladder specified by `train_seq`
+- `visualize`: Visualize the latent space with a trained model (requires a checkpoint)
+
+&nbsp;
+
+## PyTorch Optimization
+
+- __Fast training__: performance tuning follows the [PyTorch Performance Tuning Guide - Szymon Migacz, NVIDIA](https://t.co/7CIDWfrI0J).
+  - [torch.backends.cudnn.benchmark](https://pytorch.org/docs/stable/backends.html#torch.backends.cudnn.benchmark)
+  - [Automatic Mixed Precision (FP16/BF16)](https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html)
+  - [Asynchronous GPU Copies](https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html)
+  - [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)
+  - [Tensor Float 32 (>= Ampere)](https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html)
+  - [Distributed Data Parallel](https://pytorch.org/docs/stable/notes/ddp.html) (experimental)
+
+- __Optimizers__: DiffGrad, Lamb, and MADGRAD are implemented by [jettify/pytorch-optimizer](https://github.com/jettify/pytorch-optimizer); the other optimizers come from the [torch.optim](https://pytorch.org/docs/stable/optim.html) package.
+  - Adam
+  - AdamW
+  - SGD
+  - [DiffGrad](https://arxiv.org/abs/1909.11015)
+  - [Lamb](https://arxiv.org/abs/1904.00962)
+  - [MADGRAD](https://arxiv.org/abs/2101.11075)
+
+&nbsp;
+
+## Dataset
+We provide the datasets used in the original pro-VLAE paper as well as additional disentanglement datasets. Each dataset is downloaded and preprocessed automatically when its name is passed to the `--dataset` argument, except for `imagenet`.
+
+### Datasets used in the original pro-VLAE paper
+1. [MNIST](https://yann.lecun.com/exdb/mnist/): `mnist`
+2. [Disentanglement testing Sprites dataset (dSprites)](https://github.com/google-deepmind/dsprites-dataset): `dsprites`
+3. [3D Shapes](https://github.com/google-deepmind/3d-shapes): `shapes3d`
+4. [Large-scale CelebFaces Attributes (CelebA)](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html): `celeba`
+
+### Additional Disentanglement Datasets
+1. [MPI3D Disentanglement Datasets](https://github.com/rr-learning/disentanglement_dataset?tab=readme-ov-file): `mpi3d`
+2. [3DIdent](https://paperswithcode.com/dataset/3dident): `ident3d`
+
+### Other Datasets (experimental)
+1. [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist): `fashionmnist`
+2. [Describable Textures Dataset (DTD)](https://www.robots.ox.ac.uk/~vgg/data/dtd/): `dtd`
+3. [102 Category Flower Dataset](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/): `flowers102`
+4. [ImageNet](https://www.image-net.org/): `imagenet`
+
+&nbsp;
+
+## Work in Progress
+Hyperparameter optimization (beta, coff, fade-in duration, learning rates) and the implementation of disentanglement metrics (MIG for detecting factor splitting, MIG-sup for factor entanglement) are currently under development. Benchmark results will be provided in future updates.
+
+&nbsp;
+
+## License
+This repository is licensed under the MIT License; see the [LICENSE](./LICENSE) file for details. It follows the license of the [original implementation](https://github.com/Zhiyuan1991/proVLAE/blob/master/LICENSE) by Zhiyuan.
+
+&nbsp;
+
+***
+*This repository is a contribution to an [AIST (National Institute of Advanced Industrial Science and Technology)](https://www.aist.go.jp/) project.
+ +[Human Informatics and Interaction Research Institute](https://unit.aist.go.jp/hiiri/), [Neuronrehabilitation Research Group](https://unit.aist.go.jp/hiiri/nrehrg/), \ +Shosuke Suzuki, Ryusuke Hayashi diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..904e6d4 --- /dev/null +++ b/dataset.py @@ -0,0 +1,657 @@ +import os +import tarfile +from dataclasses import dataclass + +import h5py +import numpy as np +import requests +import torch +import torchvision +import torchvision.transforms as transforms +from loguru import logger +from PIL import Image +from requests.exceptions import ConnectionError, HTTPError, RequestException, Timeout +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + + +@dataclass +class DatasetConfig: + """Dataset configuration class""" + + name: str + image_size: int + chn_num: int + default_path: str + + @classmethod + def get_config(cls, dataset_name: str) -> "DatasetConfig": + """Get dataset configuration by name""" + configs = { + "mnist": cls("mnist", 28, 1, "./data"), + "fashionmnist": cls("fashionmnist", 28, 1, "./data"), + "dsprites": cls("dsprites", 64, 1, "./data"), + "shapes3d": cls("shapes3d", 64, 3, "./data"), + "celeba": cls("celeba", 128, 3, "./data"), + "flowers102": cls("flowers102", 128, 3, "./data"), + "dtd": cls("dtd", 128, 3, "./data"), + "imagenet": cls("imagenet", 128, 3, "./data/imagenet"), + "mpi3d": cls("mpi3d", 64, 3, "./data"), + "ident3d": cls("ident3d", 128, 3, "./data"), + } + + if dataset_name not in configs: + raise ValueError(f"Unknown dataset: {dataset_name}") + + return configs[dataset_name] + + +def download_file(url, filename): + """Download file with progress bar""" + try: + response = requests.get(url, stream=True) + total_size = int(response.headers.get("content-length", 0)) + + with ( + open(filename, "wb") as f, + tqdm( + desc=filename, + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as pbar, + ): + for chunk in response.iter_content(chunk_size=8192): + size = f.write(chunk) + pbar.update(size) + + logger.success("Dowload complete") + except HTTPError as e: + logger.error(f"HTTP error occurred: {e}") + except ConnectionError as e: + logger.error(f"Connection error occurred: {e}") + except Timeout as e: + logger.error(f"Timeout error occurred: {e}") + except RequestException as e: + logger.error(f"An error occurred during the request: {e}") + except Exception as e: + logger.error(f"An unexpected error occurred: {e}") + finally: + if os.path.exists(filename) and os.path.getsize(filename) == 0: + os.remove(filename) + logger.warning("Incomplete file removed due to download failure") + + +def dsprites_download(root="./data"): + """Download dsprites dataset if not present""" + dsprites_dir = os.path.join(root, "dsprites") + npz_path = os.path.join(dsprites_dir, "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz") + + if not os.path.exists(npz_path): + os.makedirs(dsprites_dir, exist_ok=True) + url = "https://github.com/deepmind/dsprites-dataset/raw/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz" + download_file(url, npz_path) + + return npz_path + + +def shapes3d_download(root="./data"): + """Download 3dshapes dataset if not present""" + shapes_dir = os.path.join(root, "shapes3d") + h5_path = os.path.join(shapes_dir, "3dshapes.h5") + + # GCS URL for 3dshapes dataset + url = "https://storage.googleapis.com/3d-shapes/3dshapes.h5" + + if not os.path.exists(h5_path): + os.makedirs(shapes_dir, exist_ok=True) + logger.info("Downloading 3dshapes dataset...") + 
download_file(url, h5_path) + else: + logger.info("3dshapes dataset already exists.") + + return h5_path + + +def mpi3d_download(root="./data", variant="toy"): + """Download MPI3D dataset if not present""" + mpi3d_dir = os.path.join(root, "mpi3d") + os.makedirs(mpi3d_dir, exist_ok=True) + + variants = { + "toy": { + "url": "https://storage.googleapis.com/mpi3d_disentanglement_dataset/data/mpi3d_toy.npz", + "filename": "mpi3d_toy.npz", + }, + "realistic": { + "url": "https://storage.googleapis.com/mpi3d_disentanglement_dataset/data/mpi3d_realistic.npz", + "filename": "mpi3d_realistic.npz", + }, + "real": { + "url": "https://storage.googleapis.com/mpi3d_disentanglement_dataset/data/real.npz", + "filename": "mpi3d_real.npz", + }, + } + + if variant not in variants: + raise ValueError(f"Unknown MPI3D variant: {variant}. Choose from {list(variants.keys())}") + + npz_path = os.path.join(mpi3d_dir, variants[variant]["filename"]) + + if not os.path.exists(npz_path): + logger.info(f"Downloading MPI3D {variant} dataset...") + download_file(variants[variant]["url"], npz_path) + + return npz_path + + +def ident3d_download(root="./data"): + """Download 3DIdent dataset if not present""" + os.makedirs(root, exist_ok=True) + + train_dir = os.path.join(root, "ident3d/train") + test_dir = os.path.join(root, "ident3d/test") + train_tar = os.path.join(root, "3dident_train.tar") + test_tar = os.path.join(root, "3dident_test.tar") + + # Download and extract training data + if not os.path.exists(train_dir): + try: + logger.info("Downloading 3DIdent training dataset...") + download_file( + "https://zenodo.org/records/4502485/files/3dident_train.tar?download=1", + train_tar, + ) + logger.info("Extracting training dataset...") + with tarfile.open(train_tar, mode="r") as tar: + tar.extractall(root) + finally: + if os.path.exists(train_tar): + os.remove(train_tar) + + # Download and extract test data + if not os.path.exists(test_dir): + try: + logger.info("Downloading 3DIdent test dataset...") + download_file( + "https://zenodo.org/records/4502485/files/3dident_test.tar?download=1", + test_tar, + ) + logger.info("Extracting test dataset...") + with tarfile.open(test_tar, mode="r") as tar: + tar.extractall(root) + finally: + if os.path.exists(test_tar): + os.remove(test_tar) + + downloaded_path = os.path.join(root, "3dident") + if os.path.exists(downloaded_path): + os.rename(downloaded_path, os.path.join(root, "ident3d")) + return train_dir, test_dir + + +class DSpritesDataset(Dataset): + def __init__(self, npz_path, transform=None): + data = np.load(npz_path, allow_pickle=True) + self.images = data["imgs"] * 255 + + if transform is None: + transform = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) + self.transform = transform + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + image = self.images[idx] + if self.transform: + image = self.transform(image) + return image, 0 + + +class DSprites: + def __init__(self, root="./data", batch_size=32, num_workers=4): + self.config = DatasetConfig.get_config("dsprites") + npz_path = dsprites_download(root) + + dataset = DSpritesDataset(npz_path) + + train_size = int(0.9 * len(dataset)) + test_size = len(dataset) - train_size + train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size]) + + self.train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + self.test_loader = DataLoader( + test_dataset, + batch_size=batch_size, + 
shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + + def get_data_loader(self): + return self.train_loader, self.test_loader + + def get_config(self): + return self.config + + +class MPI3DDataset(Dataset): + def __init__(self, npz_path, transform=None): + data = np.load(npz_path) + self.images = data["images"] + + # dataset variation + n_images = len(self.images) + if n_images == 1036800: + self.dataset_type = "regular" + elif n_images == 460800: + self.dataset_type = "complex" + else: + raise ValueError(f"Unexpected number of images: {n_images}") + + self.images = self.images.astype(np.float32) / 255.0 + + if transform is None: + transform = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) + self.transform = transform + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + image = self.images[idx].copy() + + if self.transform: + if isinstance(image, np.ndarray): + image = self.transform(image) + + return image, 0 + + +class MPI3D: + def __init__(self, root="./data", variant="toy", batch_size=32, num_workers=4): + self.config = DatasetConfig.get_config("mpi3d") + npz_path = mpi3d_download(root, variant) + + dataset = MPI3DDataset(npz_path) + + train_size = int(0.9 * len(dataset)) + test_size = len(dataset) - train_size + train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size]) + + self.train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + self.test_loader = DataLoader( + test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + + def get_data_loader(self): + return self.train_loader, self.test_loader + + def get_config(self): + return self.config + + +class Ident3DDataset(Dataset): + def __init__(self, root_dir, split="train", transform=None): + """ + Args: + root_dir (str): Root directory + split (str): 'train' or 'test' + transform (callable, optional): Transform to apply on the images + """ + self.root_dir = os.path.join(root_dir, "ident3d", split) + self.images_dir = os.path.join(self.root_dir, "images") + + # Get list of image files + self.image_files = sorted([f for f in os.listdir(self.images_dir) if f.endswith((".png", ".jpg", ".jpeg"))]) + + if transform is None: + self.transform = transforms.Compose( + [ + transforms.Resize(128), + transforms.ToTensor(), + ] + ) + else: + self.transform = transform + + def __len__(self): + return len(self.image_files) + + def __getitem__(self, idx): + img_path = os.path.join(self.images_dir, self.image_files[idx]) + image = Image.open(img_path).convert("RGB") + + if self.transform: + image = self.transform(image) + + return image, 0 + + +class Ident3D: + def __init__(self, root="./data", batch_size=32, num_workers=4): + train_dir, test_dir = ident3d_download(root=root) + self.config = DatasetConfig.get_config("ident3d") + + train_dataset = Ident3DDataset(root, split="train") + test_dataset = Ident3DDataset(root, split="test") + + self.train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + self.test_loader = DataLoader( + test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + + def get_data_loader(self): + return self.train_loader, self.test_loader + + def get_config(self): + return self.config + + +class MNIST: + def __init__(self, root="./data", batch_size=32, num_workers=4): + self.config = 
DatasetConfig.get_config("mnist") + transform = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) + + trainset = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=transform) + testset = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform) + + self.train_loader = DataLoader( + trainset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + self.test_loader = DataLoader( + testset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + + def get_data_loader(self): + return self.train_loader, self.test_loader + + def get_config(self): + return self.config + + +class FashionMNIST: + def __init__(self, root="./data", batch_size=32, num_workers=4): + self.config = DatasetConfig.get_config("fashionmnist") + transform = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) + + trainset = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform) + testset = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform) + + self.train_loader = DataLoader( + trainset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + self.test_loader = DataLoader( + testset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + + def get_data_loader(self): + return self.train_loader, self.test_loader + + def get_config(self): + return self.config + + +class Shapes3DDataset(Dataset): + def __init__(self, h5_path, transform=None): + with h5py.File(h5_path, "r") as f: + self.data = f["images"][:] # [N, H, W, C] + + self.data = self.data.astype(np.float32) / 255.0 + self.transform = transform + self.labels = torch.zeros(len(self.data)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + image = self.data[idx] # [H, W, C] + + if self.transform: + image = self.transform(image) # [C, H, W] + else: + image = torch.from_numpy(image.transpose(2, 0, 1)).float() # [C, H, W] + + return image, self.labels[idx] + + +class Shapes3D: + def __init__(self, root="./data", batch_size=32, num_workers=4): + h5_path = shapes3d_download(root=root) + self.config = DatasetConfig.get_config("shapes3d") + transform = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) + + dataset = Shapes3DDataset(h5_path, transform=transform) + + train_size = int(0.9 * len(dataset)) + test_size = len(dataset) - train_size + train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size]) + + self.train_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + self.test_loader = DataLoader( + test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + + def get_data_loader(self): + return self.train_loader, self.test_loader + + def get_config(self): + return self.config + + +class CelebA: + def __init__(self, root="./data", batch_size=32, num_workers=4): + self.config = DatasetConfig.get_config("celeba") + transform = transforms.Compose( + [ + transforms.CenterCrop(178), + transforms.Resize(128), + transforms.ToTensor(), + ] + ) + + dataset = torchvision.datasets.CelebA(root=root, split="train", download=True, transform=transform) + test_dataset = torchvision.datasets.CelebA(root=root, split="test", download=True, transform=transform) + + self.train_loader = DataLoader( + dataset, + 
batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + self.test_loader = DataLoader( + test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + + def get_data_loader(self): + return self.train_loader, self.test_loader + + def get_config(self): + return self.config + + +class Flowers102: + def __init__(self, root="./data", batch_size=32, num_workers=4): + self.config = DatasetConfig.get_config("flowers102") + transform = transforms.Compose( + [ + transforms.Resize(146), + transforms.CenterCrop(128), + transforms.ToTensor(), + ] + ) + + trainset = torchvision.datasets.Flowers102(root=root, split="train", download=True, transform=transform) + testset = torchvision.datasets.Flowers102(root=root, split="test", download=True, transform=transform) + + self.train_loader = DataLoader( + trainset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + self.test_loader = DataLoader( + testset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + + def get_data_loader(self): + return self.train_loader, self.test_loader + + def get_config(self): + return self.config + + +class DTD: + def __init__(self, root="./data", batch_size=32, num_workers=4): + self.config = DatasetConfig.get_config("dtd") + transform = transforms.Compose( + [ + transforms.Resize(146), + transforms.CenterCrop(128), + transforms.ToTensor(), + ] + ) + + trainset = torchvision.datasets.DTD(root=root, split="train", download=True, transform=transform) + testset = torchvision.datasets.DTD(root=root, split="test", download=True, transform=transform) + + self.train_loader = DataLoader( + trainset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + self.test_loader = DataLoader( + testset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + + def get_data_loader(self): + return self.train_loader, self.test_loader + + def get_config(self): + return self.config + + +class ImageNet: + def __init__(self, root="./data/imagenet", batch_size=32, num_workers=4): + self.config = DatasetConfig.get_config("imagenet") + transform = transforms.Compose( + [ + transforms.Resize(146), + transforms.CenterCrop(128), + transforms.ToTensor(), + ] + ) + + trainset = torchvision.datasets.ImageNet(root=root, split="train", transform=transform) + valset = torchvision.datasets.ImageNet(root=root, split="val", transform=transform) + + self.train_loader = DataLoader( + trainset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + self.val_loader = DataLoader( + valset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + ) + + def get_data_loader(self): + return self.train_loader, self.val_loader + + def get_config(self): + return self.config diff --git a/env.yaml b/env.yaml new file mode 100644 index 0000000..aef58d8 --- /dev/null +++ b/env.yaml @@ -0,0 +1,18 @@ +name: torch-provlae +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - python>=3.10 + - numpy + - matplotlib + - pytorch>=2.5.1 + - torchvision>=0.20.1 + - pip: + - h5py + - loguru + - imageio + - moviepy + - gdown + - torch_optimizer diff --git a/md/ident3d.gif b/md/ident3d.gif new file mode 100644 index 0000000..2dfc5e8 Binary files /dev/null and b/md/ident3d.gif differ diff --git a/md/mnist.gif b/md/mnist.gif new file mode 100644 index 0000000..135140d Binary 
files /dev/null and b/md/mnist.gif differ diff --git a/md/mpi3d.gif b/md/mpi3d.gif new file mode 100644 index 0000000..0e6f03e Binary files /dev/null and b/md/mpi3d.gif differ diff --git a/md/provlae-figure1.png b/md/provlae-figure1.png new file mode 100644 index 0000000..a986630 Binary files /dev/null and b/md/provlae-figure1.png differ diff --git a/md/shapes3d.gif b/md/shapes3d.gif new file mode 100644 index 0000000..4f3c84f Binary files /dev/null and b/md/shapes3d.gif differ diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..a57f0e7 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1 @@ +from .provlae import ProVLAE diff --git a/models/provlae.py b/models/provlae.py new file mode 100644 index 0000000..c0abec5 --- /dev/null +++ b/models/provlae.py @@ -0,0 +1,292 @@ +from math import ceil, log2 + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributions import Bernoulli, Normal + + +class ProVLAE(nn.Module): + def __init__( + self, + image_size, + z_dim=3, + chn_num=3, + num_ladders=3, + hidden_dim=32, + fc_dim=256, + beta=1.0, + learning_rate=5e-4, + fade_in_duration=5000, + pre_kl=True, + coff=0.5, + train_seq=1, + ): + super(ProVLAE, self).__init__() + + # Calculate network architecture parameters + self.hidden_dims = [hidden_dim] * num_ladders + self.image_size = image_size + self.target_size = 2 ** ceil(log2(image_size)) # Nearest power of 2 + self.min_size = 4 # Minimum feature map size + + # Calculate the number of possible downsampling steps + self.max_steps = int(log2(self.target_size // self.min_size)) + self.num_stages = max(self.max_steps, 1) + self.num_ladders = min(num_ladders, self.num_stages) + + self.z_dim = max(1, z_dim) + self.chn_num = chn_num + self.beta = beta + self.pre_kl = pre_kl + self.coff = coff + self.learning_rate = learning_rate + self.fade_in_duration = fade_in_duration + self.train_seq = min(train_seq, self.num_ladders) + + # Calculate encoder sizes + self.encoder_sizes = [self.target_size] + current_size = self.target_size + for _ in range(self.num_ladders + 1): # +1 for initial size + current_size = max(current_size // 2, self.min_size) + self.encoder_sizes.append(current_size) + + # Dynamic hidden dimensions + if len(self.hidden_dims) < self.num_ladders: + self.hidden_dims.extend([self.hidden_dims[-1]] * (self.num_ladders - len(self.hidden_dims))) + self.hidden_dims = self.hidden_dims[: self.num_ladders] + + # Base setup + self.activation = nn.ELU() # or LeakyReLU + self.q_dist = Normal + self.x_dist = Bernoulli + self.prior_params = nn.Parameter(torch.zeros(self.z_dim, 2)) + + # Create encoder layers + self.encoder_layers = nn.ModuleList() + current_channels = chn_num + for dim in self.hidden_dims: + self.encoder_layers.append(self._create_conv_block(current_channels, dim)) + current_channels = dim + + # Create ladder networks + self.ladders = nn.ModuleList() + for i in range(self.num_ladders): + ladder_input_size = self.encoder_sizes[i + 1] + self.ladders.append(self._create_ladder_block(self.hidden_dims[i], fc_dim, self.z_dim, ladder_input_size)) + + # Create generator networks + self.generators = nn.ModuleList() + for i in range(self.num_ladders): + size = self.encoder_sizes[i + 1] + self.generators.append(self._create_generator_block(self.z_dim, fc_dim, (self.hidden_dims[i], size, size))) + + # Create decoder layers + self.decoder_layers = nn.ModuleList() + for i in range(self.num_ladders - 1): + out_size = self.encoder_sizes[i] + self.decoder_layers.append( + 
self._create_decoder_block( + self.hidden_dims[-(i + 1)] * 2, # Account for concatenation + self.hidden_dims[-(i + 2)], + out_size, + ) + ) + + # Additional upsampling to reach target size + self.additional_ups = nn.ModuleList() + current_size = self.encoder_sizes[1] # Start from size after first encoder + while current_size < self.target_size: + next_size = min(current_size * 2, self.target_size) + self.additional_ups.append(self._create_upsampling_block(self.hidden_dims[0], next_size)) + current_size = next_size + + # Final output layer + self.output_layer = nn.Conv2d(self.hidden_dims[0], chn_num, kernel_size=3, padding=1) + + def _create_conv_block(self, in_channels, out_channels): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(out_channels), + self.activation, + ) + + def _create_ladder_block(self, in_channels, fc_dim, z_dim, input_size): + def get_conv_output_size(input_size, kernel_size=4, stride=2, padding=1): + return ((input_size + 2 * padding - (kernel_size - 1) - 1) // stride) + 1 + + conv_size = get_conv_output_size(input_size) + total_flatten_size = in_channels * conv_size * conv_size + + return nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1), + nn.BatchNorm2d(in_channels), + self.activation, + nn.Flatten(), + nn.Linear(total_flatten_size, fc_dim), + nn.BatchNorm1d(fc_dim), + self.activation, + nn.Linear(fc_dim, z_dim * 2), + ) + + def _create_generator_block(self, z_dim, fc_dim, output_shape): + total_dim = output_shape[0] * output_shape[1] * output_shape[2] + return nn.Sequential( + nn.Linear(z_dim, fc_dim), + nn.BatchNorm1d(fc_dim), + self.activation, + nn.Linear(fc_dim, total_dim), + nn.BatchNorm1d(total_dim), + self.activation, + nn.Unflatten(1, output_shape), + ) + + def _create_decoder_block(self, in_channels, out_channels, target_size): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + self.activation, + nn.Upsample(size=(target_size, target_size), mode="nearest"), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + self.activation, + ) + + def _create_upsampling_block(self, channels, target_size): + return nn.Sequential( + nn.Upsample(size=(target_size, target_size), mode="nearest"), + nn.Conv2d(channels, channels, kernel_size=3, padding=1), + nn.BatchNorm2d(channels), + self.activation, + ) + + def _sample_latent(self, z_params): + z_mean, z_log_var = torch.chunk(z_params, 2, dim=1) + z_log_var = torch.clamp(z_log_var, min=-1e2, max=3) + + std = torch.exp(0.5 * z_log_var) + eps = torch.randn_like(std) + + return z_mean + eps * std, z_mean, z_log_var + + def _kl_divergence(self, z_mean, z_log_var): + return -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp()) + + def fade_in_alpha(self, step): + if step > self.fade_in_duration: + return 1.0 + return step / self.fade_in_duration + + def encode(self, x): + # Store original size + original_size = x.size()[-2:] + + # Resize to target size + if original_size != (self.target_size, self.target_size): + x = F.interpolate( + x, + size=(self.target_size, self.target_size), + mode="bilinear", + align_corners=True, + ) + + h_list = [] + h = x + + # Track encoder outputs + for i, layer in enumerate(self.encoder_layers): + h = layer(h) + expected_size = self.encoder_sizes[i + 1] + assert h.size(-1) == expected_size + h_list.append(h) + + if i + 1 == self.train_seq - 1: + h = h * 
self.fade_in + + z_params = [] + for i in range(self.num_ladders): + ladder_output = self.ladders[i](h_list[i]) + z, z_mean, z_log_var = self._sample_latent(ladder_output) + z_params.append((z, z_mean, z_log_var)) + + return z_params, original_size + + def decode(self, z_list, original_size): + # Generate features from latent vectors + features = [] + for i, z in enumerate(z_list): + f = self.generators[i](z) + if i > self.train_seq - 1: + f = f * 0 + elif i == self.train_seq - 1: + f = f * self.fade_in + features.append(f) + + # Start from deepest layer + x = features[-1] + + # Progressive decoding with explicit size management + for i in range(self.num_ladders - 2, -1, -1): + # Ensure feature maps have matching spatial dimensions + target_size = features[i].size(-1) + if x.size(-1) != target_size: + x = F.interpolate(x, size=(target_size, target_size), mode="nearest") + + # Concatenate features + x = torch.cat([features[i], x], dim=1) + if i < len(self.decoder_layers): + x = self.decoder_layers[i](x) + + # Additional upsampling if needed + for up_layer in self.additional_ups: + x = up_layer(x) + + # Final convolution + x = self.output_layer(x) + + # Resize to original input size + if original_size != (x.size(-2), x.size(-1)): + x = F.interpolate(x, size=original_size, mode="bilinear", align_corners=True) + + return x + + def forward(self, x, step=0): + self.fade_in = self.fade_in_alpha(step) + + # Encode + z_params, original_size = self.encode(x) + + # Calculate KL divergence + latent_losses = [] + zs = [] + for z, z_mean, z_log_var in z_params: + latent_losses.append(self._kl_divergence(z_mean, z_log_var)) + zs.append(z) + + latent_loss = sum(latent_losses) + + # Decode + x_recon = self.decode(zs, original_size) + + # Reconstruction loss + bce_loss = nn.BCEWithLogitsLoss(reduction="sum") + recon_loss = bce_loss(x_recon, x) + + # Calculate final loss + if self.pre_kl: + active_latents = latent_losses[self.train_seq - 1 :] + inactive_latents = latent_losses[: self.train_seq - 1] + loss = recon_loss + self.beta * sum(active_latents) + self.coff * sum(inactive_latents) + else: + loss = recon_loss + self.beta * latent_loss + + return torch.sigmoid(x_recon), loss, latent_loss, recon_loss + + def inference(self, x): + with torch.no_grad(): + z_params, _ = self.encode(x) + return z_params + + def generate(self, z_list): + with torch.no_grad(): + return torch.sigmoid(self.decode(z_list, (self.image_size, self.image_size))) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6d37b3e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,13 @@ +[tool.black] +line-length = 120 +target-version = ['py310'] +include = '\.pyi?$' + +[tool.isort] +profile = "black" +line_length = 120 +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +ensure_newline_before_comments = true diff --git a/scripts/run_ddp.sh b/scripts/run_ddp.sh new file mode 100644 index 0000000..8e75605 --- /dev/null +++ b/scripts/run_ddp.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +torchrun --nproc_per_node=2 train_ddp.py \ + --distributed True \ + --mode seq_train \ + --dataset ident3d \ + --num_ladders 3 \ + --batch_size 128 \ + --num_epochs 30 \ + --learning_rate 5e-4 \ + --beta 1 \ + --z_dim 3 \ + --coff 0.5 \ + --hidden_dim 64 \ + --fade_in_duration 5000 \ + --output_dir ./output/ident3d/ \ + --optim adamw diff --git a/scripts/run_dsprites.sh b/scripts/run_dsprites.sh new file mode 100644 index 0000000..37036de --- /dev/null +++ b/scripts/run_dsprites.sh @@ -0,0 +1,15 @@ 
+#!/bin/bash + +python train.py \ + --mode seq_train \ + --dataset dsprites \ + --num_ladders 3 \ + --batch_size 64 \ + --num_epochs 30 \ + --learning_rate 5e-4 \ + --beta 3 \ + --z_dim 3 \ + --hidden_dim 32 \ + --fade_in_duration 5000 \ + --output_dir ./output/dsprites/ \ + --optim adamw diff --git a/scripts/run_dtd.sh b/scripts/run_dtd.sh new file mode 100644 index 0000000..ad27f15 --- /dev/null +++ b/scripts/run_dtd.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +python train.py \ + --mode seq_train \ + --dataset dtd \ + --num_ladders 3 \ + --batch_size 32 \ + --num_epochs 30 \ + --learning_rate 5e-4 \ + --beta 3 \ + --z_dim 3 \ + --coff 0.2 \ + --hidden_dim 64 \ + --fade_in_duration 5000 \ + --output_dir ./output/dtd/ \ + --optim adamw diff --git a/scripts/run_fashionmnist.sh b/scripts/run_fashionmnist.sh new file mode 100644 index 0000000..1241e67 --- /dev/null +++ b/scripts/run_fashionmnist.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +python train.py \ + --mode seq_train \ + --dataset fashionmnist \ + --num_ladders 3 \ + --batch_size 16 \ + --num_epochs 30 \ + --learning_rate 5e-4 \ + --beta 3 \ + --z_dim 3 \ + --hidden_dim 32 \ + --fade_in_duration 5000 \ + --output_dir ./output/fashionmnist/ \ + --optim adamw diff --git a/scripts/run_flowers102.sh b/scripts/run_flowers102.sh new file mode 100644 index 0000000..d308d50 --- /dev/null +++ b/scripts/run_flowers102.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +python train.py \ + --mode seq_train \ + --dataset flowers102 \ + --num_ladders 3 \ + --batch_size 32 \ + --num_epochs 30 \ + --learning_rate 5e-4 \ + --beta 3 \ + --z_dim 3 \ + --coff 0.2 \ + --hidden_dim 32 \ + --fade_in_duration 5000 \ + --output_dir ./output/flowers102/ \ + --optim adamw diff --git a/scripts/run_ident3d.sh b/scripts/run_ident3d.sh new file mode 100644 index 0000000..5a32372 --- /dev/null +++ b/scripts/run_ident3d.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +python train.py \ + --mode seq_train \ + --dataset ident3d \ + --num_ladders 3 \ + --batch_size 128 \ + --num_epochs 15 \ + --learning_rate 5e-4 \ + --beta 3 \ + --z_dim 3 \ + --coff 0.3 \ + --hidden_dim 32 \ + --fade_in_duration 5000 \ + --output_dir ./output/ident3d/ \ + --optim adamw \ diff --git a/scripts/run_mnist.sh b/scripts/run_mnist.sh new file mode 100644 index 0000000..510c2da --- /dev/null +++ b/scripts/run_mnist.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +python train.py \ + --mode seq_train \ + --dataset mnist \ + --num_ladders 3 \ + --batch_size 16 \ + --num_epochs 30 \ + --learning_rate 5e-4 \ + --beta 3 \ + --z_dim 3 \ + --hidden_dim 32 \ + --fade_in_duration 5000 \ + --output_dir ./output/mnist/ \ + --optim adamw \ No newline at end of file diff --git a/scripts/run_mpi3d.sh b/scripts/run_mpi3d.sh new file mode 100644 index 0000000..85e8aef --- /dev/null +++ b/scripts/run_mpi3d.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +python train.py \ + --mode seq_train \ + --dataset mpi3d \ + --num_ladders 3 \ + --batch_size 128 \ + --num_epochs 15 \ + --learning_rate 5e-4 \ + --beta 3 \ + --z_dim 3 \ + --coff 0.3 \ + --hidden_dim 32 \ + --fade_in_duration 5000 \ + --output_dir ./output/mpi3d/ \ + --optim adamw \ diff --git a/scripts/run_shape3d.sh b/scripts/run_shape3d.sh new file mode 100644 index 0000000..b45f16a --- /dev/null +++ b/scripts/run_shape3d.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +python train.py \ + --mode seq_train \ + --dataset shapes3d \ + --num_ladders 3 \ + --batch_size 256 \ + --num_epochs 1 \ + --learning_rate 5e-4 \ + --beta 20 \ + --z_dim 3 \ + --hidden_dim 32 \ + --fade_in_duration 5000 \ + --output_dir ./output/shapes3d/ \ + --optim 
adamw diff --git a/train.py b/train.py new file mode 100644 index 0000000..a465e6d --- /dev/null +++ b/train.py @@ -0,0 +1,434 @@ +import argparse +import os +import time +from dataclasses import dataclass, field + +import imageio.v3 as imageio +import numpy as np +import torch +import torch.nn.functional as F +import torch.optim as optim +import torch_optimizer as jettify_optim +import torchvision +from loguru import logger +from PIL import Image, ImageDraw, ImageFont +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm + +from dataset import DTD, MNIST, MPI3D, CelebA, DSprites, FashionMNIST, Flowers102, Ident3D, ImageNet, Shapes3D +from models import ProVLAE + + +@dataclass +class HyperParameters: + dataset: str = field(default="shapes3d") + data_path: str = field(default="./data") + z_dim: int = field(default=3) + num_ladders: int = field(default=3) + beta: float = field(default=8.0) + learning_rate: float = field(default=5e-4) + fade_in_duration: int = field(default=5000) + image_size: int = field(default=64) + chn_num: int = field(default=3) + train_seq: int = field(default=1) + batch_size: int = field(default=100) + num_epochs: int = field(default=1) + mode: str = field(default="seq_train") + hidden_dim: int = field(default=32) + coff: float = field(default=0.5) + output_dir: str = field(default="outputs") + + # pytorch optimization + compile_mode: str = field(default="default") # or max-autotune-no-cudagraphs + on_cudnn_benchmark: bool = field(default=True) + optim: str = field(default="adam") + + @property + def checkpoint_path(self): + return os.path.join(self.output_dir, f"checkpoints/model_seq{self.train_seq}.pt") + + @property + def recon_path(self): + return os.path.join(self.output_dir, f"reconstructions/recon_seq{self.train_seq}.png") + + @property + def traverse_path(self): + return os.path.join(self.output_dir, f"traversals/traverse_seq{self.train_seq}.gif") + + @classmethod + def from_args(cls): + parser = argparse.ArgumentParser() + for field_info in cls.__dataclass_fields__.values(): + parser.add_argument(f"--{field_info.name}", type=field_info.type, default=field_info.default) + return cls(**vars(parser.parse_args())) + + +def get_dataset(params): + """Load the dataset and return the data loader""" + dataset_classes = { + "mnist": MNIST, + "fashionmnist": FashionMNIST, + "shapes3d": Shapes3D, + "dsprites": DSprites, + "celeba": CelebA, + "flowers102": Flowers102, + "dtd": DTD, + "imagenet": ImageNet, + "mpi3d": MPI3D, + "ident3d": Ident3D, + } + + if params.dataset not in dataset_classes: + raise ValueError(f"Unknown dataset: {params.dataset}. 
" f"Available datasets: {list(dataset_classes.keys())}") + + dataset_class = dataset_classes[params.dataset] + if params.dataset == "mpi3d": + # mpi3d variants: toy, real + variant = getattr(params, "mpi3d_variant", "toy") + dataset = dataset_class( + root=params.data_path, + variant=variant, + batch_size=params.batch_size, + num_workers=4, + ) + else: # other dataset + dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4) + + config = dataset.get_config() + params.chn_num = config.chn_num + params.image_size = config.image_size + + logger.success("Dataset loaded.") + return dataset.get_data_loader() + + +def exec_time(func): + """Decorates a function to measure its execution time in hours and minutes.""" + + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + execution_time = end_time - start_time + + logger.success(f"Training completed ({int(execution_time // 3600)}h {int((execution_time % 3600) // 60)}min)") + return result + + return wrapper + + +def load_checkpoint(model, optimizer, scaler, checkpoint_path): + """ + Load a model checkpoint to resume training or run further inference. + """ + torch.serialization.add_safe_globals([set]) + checkpoint = torch.load( + checkpoint_path, + map_location=torch.device("cpu" if not torch.cuda.is_available() else "cuda"), + weights_only=True, + ) + + model.load_state_dict(checkpoint["model_state_dict"], strict=False) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + if scaler is not None and "scaler_state_dict" in checkpoint: + scaler.load_state_dict(checkpoint["scaler_state_dict"]) + + logger.info( + f"Loaded checkpoint from '{checkpoint_path}' (Epoch: {checkpoint['epoch']}, Loss: {checkpoint['loss']:.4f})" + ) + + return model, optimizer, scaler + + +def save_reconstruction(inputs, reconstructions, save_path): + """Save a grid of original and reconstructed images""" + batch_size = min(8, inputs.shape[0]) + inputs = inputs[:batch_size].float() + reconstructions = reconstructions[:batch_size].float() + comparison = torch.cat([inputs[:batch_size], reconstructions[:batch_size]]) + + # Denormalize and convert to numpy + images = comparison.cpu().detach() + images = torch.clamp(images, 0, 1) + grid = torchvision.utils.make_grid(images, nrow=batch_size) + image = grid.permute(1, 2, 0).numpy() + + os.makedirs(os.path.dirname(save_path), exist_ok=True) + imageio.imwrite(save_path, (image * 255).astype("uint8")) + + +def create_latent_traversal(model, data_loader, save_path, device, params): + """Create and save organized latent traversal GIF with optimized layout""" + model.eval() + model.fade_in = 1.0 + with torch.no_grad(): + # Get a single batch of images + inputs, _ = next(iter(data_loader)) + inputs = inputs[0:1].to(device) + + # Get latent representations + with torch.amp.autocast(device_type="cuda", enabled=False): + latent_vars = [z[0] for z in model.inference(inputs)] + + # Traverse values + traverse_range = torch.linspace(-1.5, 1.5, 15).to(device) + + # Image layout parameters + img_size = 96 # Base image size + padding = 1 # Reduced padding between images + label_margin = 1 # Margin for labels inside images + font_size = 7 # Smaller font size for better fit + + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size) + except: + font = ImageFont.load_default() + + frames = [] + for t_idx in range(len(traverse_range)): + current_images = [] + + # Generate images for each ladder and 
dimension + for ladder_idx in range(len(latent_vars) - 1, -1, -1): + for dim in range(latent_vars[ladder_idx].shape[1]): + z_mod = [v.clone() for v in latent_vars] + z_mod[ladder_idx][0, dim] = traverse_range[t_idx] + + with torch.amp.autocast(device_type="cuda", enabled=False): + gen_img = model.generate(z_mod) + img = gen_img[0].cpu().float() + img = torch.clamp(img, 0, 1) + + # Resize image if needed + if img.shape[-1] != img_size: + img = F.interpolate( + img.unsqueeze(0), + size=img_size, + mode="bilinear", + align_corners=False, + ).squeeze(0) + + # Handle both single-channel and multi-channel images + if img.shape[0] == 1: + # Single channel (grayscale) - repeat to create RGB + img = img.repeat(3, 1, 1) + + # Convert to PIL Image for adding text + img_array = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8) + img_pil = Image.fromarray(img_array) + draw = ImageDraw.Draw(img_pil) + + # Add label inside the image + label = f"L{len(latent_vars)-ladder_idx} z{dim+1}" + # draw black and white text to create border effect + for offset in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + draw.text( + (label_margin + offset[0], label_margin + offset[1]), + label, + (0, 0, 0), + font=font, + ) + draw.text((label_margin, label_margin), label, (255, 255, 255), font=font) + + # Add value label to bottom-left image + if ladder_idx == 0 and dim == 0: + value_label = f"v = {traverse_range[t_idx].item():.2f}" + for offset in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + draw.text( + ( + label_margin + offset[0], + img_size - font_size - label_margin + offset[1], + ), + value_label, + (0, 0, 0), + font=font, + ) + draw.text( + (label_margin, img_size - font_size - label_margin), + value_label, + (255, 255, 255), + font=font, + ) + + # Convert back to tensor + img_tensor = torch.tensor(np.array(img_pil)).float() / 255.0 + img_tensor = img_tensor.permute(2, 0, 1) + current_images.append(img_tensor) + + # Convert to tensor and create grid + current_images = torch.stack(current_images) + + # Create grid with minimal padding + grid = torchvision.utils.make_grid(current_images, nrow=params.z_dim, padding=padding, normalize=True) + + grid_array = (grid.permute(1, 2, 0).numpy() * 255).astype(np.uint8) + frames.append(grid_array) + + # Save GIF with infinite loop + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + # duration=200 means 5 FPS, loop=0 means infinite loop + imageio.imwrite(save_path, frames, duration=200, loop=0, format="GIF", optimize=False) + + +@exec_time +def train_model(model, data_loader, optimizer, params, device, scaler=None, autocast_dtype=torch.float16): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + model.train() + global_step = 0 + + logger.info("Start training.") + for epoch in range(params.num_epochs): + with tqdm( + enumerate(data_loader), + desc=f"Current epoch [{epoch + 1}/{params.num_epochs}]", + leave=False, + total=len(data_loader), + ) as pbar: + for batch_idx, (inputs, _) in pbar: + inputs = inputs.to(device, non_blocking=True) + + # Forward pass with autocast + with torch.amp.autocast(device_type="cuda", dtype=autocast_dtype): + x_recon, loss, latent_loss, recon_loss = model(inputs, step=global_step) + + # Backward pass with appropriate scaling + optimizer.zero_grad() + if scaler is not None: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + optimizer.step() + + pbar.set_postfix( + total_loss=f"{loss.item():.2f}", + latent_loss=f"{latent_loss:.2f}", + recon_loss=f"{recon_loss:.2f}", + ) + global_step += 1 + + # 
Save checkpoints and visualizations + os.makedirs(os.path.dirname(params.checkpoint_path), exist_ok=True) + + # Model evaluation for visualizations + model.eval() + with torch.no_grad(), torch.amp.autocast("cuda", dtype=autocast_dtype): + save_reconstruction(inputs, x_recon, params.recon_path) + create_latent_traversal(model, data_loader, params.traverse_path, device, params) + model.train() + + checkpoint = { + "epoch": epoch + 1, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scaler_state_dict": scaler.state_dict() if scaler is not None else None, + "loss": loss.item(), + } + torch.save(checkpoint, params.checkpoint_path) + + if (epoch + 1) % 5 == 0: + logger.info(f"Epoch: [{epoch+1}/{params.num_epochs}], Loss: {loss.item():.2f}") + + +def get_optimizer(model, params): + optimizers = { + "adam": optim.Adam, + "adamw": optim.AdamW, + "sgd": optim.SGD, + "lamb": jettify_optim.Lamb, + "diffgrad": jettify_optim.DiffGrad, + "madgrad": jettify_optim.MADGRAD, + } + + optimizer = optimizers.get(params.optim.lower()) + + if optimizer is None: + optimizer = optimizers.get("adam") + logger.warning(f"unsupported optimizer {params.optim}, use Adam optimizer.") + + return optimizer(model.parameters(), lr=params.learning_rate) + + +def main(): + # Setup + params = HyperParameters.from_args() + + # gpu config + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.set_float32_matmul_precision("high") + if params.on_cudnn_benchmark: + torch.backends.cudnn.benchmark = True + + if torch.cuda.get_device_capability()[0] >= 8: + if not torch.cuda.is_bf16_supported(): + logger.warning("BF16 is not supported, falling back to FP16") + autocast_dtype = torch.float16 + scaler = torch.amp.GradScaler() + else: # BF16 + autocast_dtype = torch.bfloat16 + scaler = None + logger.debug("Using BF16 mixed precision") + else: # FP16 + autocast_dtype = torch.float16 + scaler = torch.amp.GradScaler() + logger.debug("Using FP16 mixed precision with gradient scaling") + + os.makedirs(params.output_dir, exist_ok=True) + train_loader, test_loader = get_dataset(params) + + # Initialize model + model = ProVLAE( + z_dim=params.z_dim, + beta=params.beta, + learning_rate=params.learning_rate, + fade_in_duration=params.fade_in_duration, + chn_num=params.chn_num, + train_seq=params.train_seq, + image_size=params.image_size, + num_ladders=params.num_ladders, + hidden_dim=params.hidden_dim, + coff=params.coff, + ).to(device) + + optimizer = get_optimizer(model, params) + model = torch.compile(model, mode=params.compile_mode) + + # Train model or visualize traverse + if params.mode == "seq_train": + logger.opt(colors=True).info(f"✓ Mode: sequential execution [progress 1 >> {params.num_ladders}]") + for i in range(1, params.num_ladders + 1): + params.train_seq, model.train_seq = i, i + if params.train_seq >= 2: + prev_checkpoint = os.path.join(params.output_dir, f"checkpoints/model_seq{params.train_seq-1}.pt") + if os.path.exists(prev_checkpoint): + model, optimizer, scaler = load_checkpoint(model, optimizer, scaler, prev_checkpoint) + train_model(model, train_loader, optimizer, params, device, scaler, autocast_dtype) + + elif params.mode == "indep_train": + logger.opt(colors=True).info(f"✓ Mode: independent execution [progress {params.train_seq}]") + if params.train_seq >= 2: + prev_checkpoint = os.path.join(params.output_dir, f"checkpoints/model_seq{params.train_seq-1}.pt") + if os.path.exists(prev_checkpoint): + model, optimizer, scaler = load_checkpoint(model, 
optimizer, scaler, prev_checkpoint)
+
+        train_model(model, train_loader, optimizer, params, device, scaler, autocast_dtype)
+
+    elif params.mode == "visualize":
+        logger.opt(colors=True).info(f"✓ Mode: visualize latent traversing [progress {params.train_seq}]")
+        current_checkpoint = os.path.join(params.output_dir, f"checkpoints/model_seq{params.train_seq}.pt")
+        if os.path.exists(current_checkpoint):
+            model, _, _ = load_checkpoint(model, optimizer, scaler, current_checkpoint)
+        create_latent_traversal(model, test_loader, params.traverse_path, device, params)
+        logger.success("Latent traversal visualization saved.")
+
+    else:
+        logger.error(f"unsupported mode: {params.mode}, use 'seq_train', 'indep_train', or 'visualize'.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/train_ddp.py b/train_ddp.py
new file mode 100644
index 0000000..3e4390f
--- /dev/null
+++ b/train_ddp.py
@@ -0,0 +1,688 @@
+import argparse
+import datetime
+import os
+import sys
+import time
+from dataclasses import dataclass, field
+
+import imageio.v3 as imageio
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import torch.optim as optim
+import torch_optimizer as jettify_optim
+import torchvision
+from loguru import logger
+from PIL import Image, ImageDraw, ImageFont
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm
+
+from dataset import DTD, MNIST, MPI3D, CelebA, DSprites, FashionMNIST, Flowers102, Ident3D, ImageNet, Shapes3D
+from models import ProVLAE
+
+
+@dataclass
+class HyperParameters:
+    dataset: str = field(default="shapes3d")
+    data_path: str = field(default="./data")
+    z_dim: int = field(default=3)
+    num_ladders: int = field(default=3)
+    beta: float = field(default=8.0)
+    learning_rate: float = field(default=5e-4)
+    fade_in_duration: int = field(default=5000)
+    image_size: int = field(default=64)
+    chn_num: int = field(default=3)
+    train_seq: int = field(default=1)
+    batch_size: int = field(default=100)
+    num_epochs: int = field(default=1)
+    mode: str = field(default="seq_train")
+    hidden_dim: int = field(default=32)
+    coff: float = field(default=0.5)
+    output_dir: str = field(default="outputs")
+
+    # pytorch optimization
+    compile_mode: str = field(default="default")  # or max-autotune-no-cudagraphs
+    on_cudnn_benchmark: bool = field(default=True)
+    optim: str = field(default="adam")
+
+    # distributed training parameters
+    distributed: bool = field(default=False)
+    local_rank: int = field(default=-1)
+    world_size: int = field(default=1)
+    dist_backend: str = field(default="nccl")
+    dist_url: str = field(default="env://")
+
+    @property
+    def checkpoint_path(self):
+        return os.path.join(self.output_dir, f"checkpoints/model_seq{self.train_seq}.pt")
+
+    @property
+    def recon_path(self):
+        return os.path.join(self.output_dir, f"reconstructions/recon_seq{self.train_seq}.png")
+
+    @property
+    def traverse_path(self):
+        return os.path.join(self.output_dir, f"traversals/traverse_seq{self.train_seq}.gif")
+
+    @classmethod
+    def from_args(cls):
+        parser = argparse.ArgumentParser()
+        for field_info in cls.__dataclass_fields__.values():
+            if field_info.type is bool:
+                # argparse's type=bool treats any non-empty string as True,
+                # so parse boolean flags (e.g. --distributed) explicitly
+                parser.add_argument(
+                    f"--{field_info.name}",
+                    type=lambda x: str(x).lower() in ("true", "1", "yes"),
+                    default=field_info.default,
+                )
+            else:
+                parser.add_argument(f"--{field_info.name}", type=field_info.type, default=field_info.default)
+        return cls(**vars(parser.parse_args()))
+
+
+def setup_logger(rank=-1, world_size=1):
+    """Setup logger for distributed training"""
+    config = {
+        "handlers": [
+            {
+                "sink": sys.stdout,
+                "format": (
+                    "{time:YYYY-MM-DD HH:mm:ss} | "
+                    "{level: <8} | " +
"Rank {extra[rank]}/{extra[world_size]} | " + "{name}:{line} | " + "{message}" + ), + "level": "DEBUG", + "colorize": True, + } + ] + } + + try: # Remove all existing handlers + logger.configure(**config) + except ValueError: + pass + + # Create a new logger instance with rank information + return logger.bind(rank=rank, world_size=world_size) + + +def setup_distributed(params): + """Initialize distributed training environment with explicit device mapping""" + if not params.distributed: + return False + + try: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + params.rank = int(os.environ["RANK"]) + params.world_size = int(os.environ["WORLD_SIZE"]) + params.local_rank = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + params.rank = int(os.environ["SLURM_PROCID"]) + params.local_rank = params.rank % torch.cuda.device_count() + params.world_size = int(os.environ["SLURM_NTASKS"]) + else: + raise ValueError("Not running with distributed environment variables set") + + torch.cuda.set_device(params.local_rank) + init_method = "env://" + backend = params.dist_backend + if backend == "nccl" and not torch.cuda.is_available(): + backend = "gloo" + + if not dist.is_initialized(): + dist.init_process_group( + backend=backend, + init_method=init_method, + world_size=params.world_size, + rank=params.rank, + ) + torch.cuda.set_device(params.local_rank) + dist.barrier(device_ids=[params.local_rank]) + + return True + except Exception as e: + print(f"Failed to initialize distributed training: {e}") + return False + + +def cleanup_distributed(): + """Clean up distributed training resources safely with device mapping""" + if dist.is_initialized(): + try: + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + dist.barrier(device_ids=[local_rank]) + dist.destroy_process_group() + except Exception as e: + print(f"Error during distributed cleanup: {e}") + + +def get_dataset(params, logger): + """Load dataset with distributed support""" + dataset_classes = { + "mnist": MNIST, + "fashionmnist": FashionMNIST, + "shapes3d": Shapes3D, + "dsprites": DSprites, + "celeba": CelebA, + "flowers102": Flowers102, + "dtd": DTD, + "imagenet": ImageNet, + "mpi3d": MPI3D, + "ident3d": Ident3D, + } + + if params.dataset not in dataset_classes: + raise ValueError(f"Unknown dataset: {params.dataset}") + + dataset_class = dataset_classes[params.dataset] + + try: + if params.dataset == "mpi3d": + variant = getattr(params, "mpi3d_variant", "toy") + dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4, variant=variant) + else: + dataset = dataset_class(root=params.data_path, batch_size=params.batch_size, num_workers=4) + + config = dataset.get_config() + params.chn_num = config.chn_num + params.image_size = config.image_size + + train_loader, test_loader = dataset.get_data_loader() + if params.distributed: + train_sampler = DistributedSampler( + train_loader.dataset, + num_replicas=params.world_size, + rank=params.local_rank, + shuffle=True, + drop_last=True, + ) + + train_loader = torch.utils.data.DataLoader( + train_loader.dataset, + batch_size=params.batch_size, + sampler=train_sampler, + num_workers=4, + pin_memory=True, + drop_last=True, + persistent_workers=True, + ) + + if params.local_rank == 0: + logger.info(f"Dataset {params.dataset} loaded with distributed sampler") + else: + logger.info(f"Dataset {params.dataset} loaded") + + return train_loader, test_loader + + except Exception as e: + logger.error(f"Failed to load dataset: {str(e)}") + raise + + +def 
exec_time(func): + """Decorates a function to measure its execution time in hours and minutes.""" + + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + execution_time = end_time - start_time + + logger = kwargs.get("logger") # Get logger from kwargs + if not logger: # Find logger in positional arguments + for arg in args: + if isinstance(arg, type(setup_logger())): + logger = arg + break + + if logger: + logger.success( + f"Training completed ({int(execution_time // 3600)}h {int((execution_time % 3600) // 60)}min)" + ) + return result + + return wrapper + + +def load_checkpoint(model, optimizer, scaler, checkpoint_path, device, logger): + """Load a model checkpoint with proper device management.""" + try: + checkpoint = torch.load( + checkpoint_path, + map_location=device, + weights_only=True, + ) + + # Load model state dict + if hasattr(model, "module"): + model.module.load_state_dict(checkpoint["model_state_dict"]) + else: + model.load_state_dict(checkpoint["model_state_dict"], strict=False) + + # Load optimizer state dict + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(device) + + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + + if scaler is not None and "scaler_state_dict" in checkpoint: + scaler.load_state_dict(checkpoint["scaler_state_dict"]) + + logger.info( + f"Loaded checkpoint from '{checkpoint_path}' (Epoch: {checkpoint['epoch']}, Loss: {checkpoint['loss']:.4f})" + ) + + return model, optimizer, scaler + except Exception as e: + logger.error(f"Failed to load checkpoint: {str(e)}") + return model, optimizer, scaler + + +def save_reconstruction(inputs, reconstructions, save_path): + """Save a grid of original and reconstructed images""" + batch_size = min(8, inputs.shape[0]) + inputs = inputs[:batch_size].float() + reconstructions = reconstructions[:batch_size].float() + comparison = torch.cat([inputs[:batch_size], reconstructions[:batch_size]]) + + # Denormalize and convert to numpy + images = comparison.cpu().detach() + images = torch.clamp(images, 0, 1) + grid = torchvision.utils.make_grid(images, nrow=batch_size) + image = grid.permute(1, 2, 0).numpy() + + os.makedirs(os.path.dirname(save_path), exist_ok=True) + imageio.imwrite(save_path, (image * 255).astype("uint8")) + + +def create_latent_traversal(model, data_loader, save_path, device, params): + """Create and save organized latent traversal GIF with optimized layout""" + model.eval() + + if hasattr(model, "module"): + model = model.module + + model.fade_in = 1.0 + with torch.no_grad(): + # Get a single batch of images + inputs, _ = next(iter(data_loader)) + inputs = inputs[0:1].to(device) + + # Get latent representations + with torch.amp.autocast(device_type="cuda", enabled=False): + latent_vars = [z[0] for z in model.inference(inputs)] + + # Traverse values + traverse_range = torch.linspace(-1.5, 1.5, 15).to(device) + + # Image layout parameters + img_size = 96 # Base image size + padding = 1 # Reduced padding between images + label_margin = 1 # Margin for labels inside images + font_size = 7 # Smaller font size for better fit + + try: + font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size) + except: + font = ImageFont.load_default() + + frames = [] + for t_idx in range(len(traverse_range)): + current_images = [] + + # Generate images for each ladder and dimension + for ladder_idx in range(len(latent_vars) - 1, -1, -1): + for 
dim in range(latent_vars[ladder_idx].shape[1]): + z_mod = [v.clone() for v in latent_vars] + z_mod[ladder_idx][0, dim] = traverse_range[t_idx] + + with torch.amp.autocast(device_type="cuda", enabled=False): + gen_img = model.generate(z_mod) + img = gen_img[0].cpu().float() + img = torch.clamp(img, 0, 1) + + # Resize image if needed + if img.shape[-1] != img_size: + img = F.interpolate( + img.unsqueeze(0), + size=img_size, + mode="bilinear", + align_corners=False, + ).squeeze(0) + + # Handle both single-channel and multi-channel images + if img.shape[0] == 1: + # Single channel (grayscale) - repeat to create RGB + img = img.repeat(3, 1, 1) + + # Convert to PIL Image for adding text + img_array = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8) + img_pil = Image.fromarray(img_array) + draw = ImageDraw.Draw(img_pil) + + # Add label inside the image + label = f"L{len(latent_vars)-ladder_idx} z{dim+1}" + # draw black and white text to create border effect + for offset in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + draw.text( + (label_margin + offset[0], label_margin + offset[1]), + label, + (0, 0, 0), + font=font, + ) + draw.text((label_margin, label_margin), label, (255, 255, 255), font=font) + + # Add value label to bottom-left image + if ladder_idx == 0 and dim == 0: + value_label = f"v = {traverse_range[t_idx].item():.2f}" + for offset in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + draw.text( + ( + label_margin + offset[0], + img_size - font_size - label_margin + offset[1], + ), + value_label, + (0, 0, 0), + font=font, + ) + draw.text( + (label_margin, img_size - font_size - label_margin), + value_label, + (255, 255, 255), + font=font, + ) + + # Convert back to tensor + img_tensor = torch.tensor(np.array(img_pil)).float() / 255.0 + img_tensor = img_tensor.permute(2, 0, 1) + current_images.append(img_tensor) + + # Convert to tensor and create grid + current_images = torch.stack(current_images) + + # Create grid with minimal padding + grid = torchvision.utils.make_grid(current_images, nrow=params.z_dim, padding=padding, normalize=True) + + grid_array = (grid.permute(1, 2, 0).numpy() * 255).astype(np.uint8) + frames.append(grid_array) + + # Save GIF with infinite loop + os.makedirs(os.path.dirname(save_path), exist_ok=True) + # duration=200 means 5 FPS, loop=0 means infinite loop + imageio.imwrite(save_path, frames, duration=200, loop=0, format="GIF", optimize=False) + + +@exec_time +def train_model(model, data_loader, optimizer, params, device, logger, scaler=None, autocast_dtype=torch.float16): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + if hasattr(model, "module"): + model.module.to(device) + else: + model.to(device) + + model.train() + global_step = 0 + + logger.info("Start training.") + for epoch in range(params.num_epochs): + if params.distributed: + data_loader.sampler.set_epoch(epoch) + + with tqdm( + enumerate(data_loader), + desc=f"Current epoch [{epoch + 1}/{params.num_epochs}]", + leave=False, + total=len(data_loader), + disable=params.distributed and params.local_rank != 0, + ) as pbar: + for batch_idx, (inputs, _) in pbar: + inputs = inputs.to(device, non_blocking=True) + + with torch.amp.autocast(device_type="cuda", dtype=autocast_dtype): + x_recon, loss, latent_loss, recon_loss = model(inputs, step=global_step) + + optimizer.zero_grad() + if scaler is not None: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + optimizer.step() + + if params.local_rank == 0 or not params.distributed: # Only show progress on main 
process + pbar.set_postfix( + total_loss=f"{loss.item():.2f}", + latent_loss=f"{latent_loss:.2f}", + recon_loss=f"{recon_loss:.2f}", + ) + global_step += 1 + + # Save checkpoints and visualizations only on main process + if params.local_rank == 0 or not params.distributed: + os.makedirs(os.path.dirname(params.checkpoint_path), exist_ok=True) + + model.eval() + with torch.no_grad(), torch.amp.autocast("cuda", dtype=autocast_dtype): + save_reconstruction(inputs, x_recon, params.recon_path) + create_latent_traversal(model, data_loader, params.traverse_path, device, params) + model.train() + + checkpoint = { + "epoch": epoch + 1, + "model_state_dict": model.module.state_dict() if params.distributed else model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scaler_state_dict": scaler.state_dict() if scaler is not None else None, + "loss": loss.item(), + } + torch.save(checkpoint, params.checkpoint_path) + + if (epoch + 1) % 5 == 0: + logger.info(f"Epoch: [{epoch+1}/{params.num_epochs}], Loss: {loss.item():.2f}") + + +def get_optimizer(model, params): + """Get the optimizer based on the parameter settings""" + optimizers = { + "adam": optim.Adam, + "adamw": optim.AdamW, + "sgd": optim.SGD, + "lamb": jettify_optim.Lamb, + "diffgrad": jettify_optim.DiffGrad, + "madgrad": jettify_optim.MADGRAD, + } + + optimizer = optimizers.get(params.optim.lower()) + + if optimizer is None: + optimizer = optimizers.get("adam") + logger.warning(f"unsupported optimizer {params.optim}, use Adam optimizer.") + + return optimizer(model.parameters(), lr=params.learning_rate) + + +def main(): + params = HyperParameters.from_args() + + try: + # Setup distributed training + is_distributed = setup_distributed(params) + rank = params.local_rank if is_distributed else 0 + world_size = params.world_size if is_distributed else 1 + + # Setup device and logger + device = torch.device(f"cuda:{params.local_rank}" if is_distributed else "cuda") + logger = setup_logger(rank, world_size) + + # GPU config + torch.set_float32_matmul_precision("high") + if params.on_cudnn_benchmark: + torch.backends.cudnn.benchmark = True + + # Mixed precision setup + if torch.cuda.get_device_capability()[0] >= 8: + if not torch.cuda.is_bf16_supported(): + logger.warning("BF16 is not supported, falling back to FP16") + autocast_dtype = torch.float16 + scaler = torch.amp.GradScaler() + else: + autocast_dtype = torch.bfloat16 + scaler = None + logger.debug("Using BF16 mixed precision") + else: + autocast_dtype = torch.float16 + scaler = torch.amp.GradScaler() + logger.debug("Using FP16 mixed precision with gradient scaling") + + # Create output directory + if rank == 0: + os.makedirs(params.output_dir, exist_ok=True) + + if is_distributed: + dist.barrier() + + # Get dataset and model + train_loader, test_loader = get_dataset(params, logger) + model = ProVLAE( + z_dim=params.z_dim, + beta=params.beta, + learning_rate=params.learning_rate, + fade_in_duration=params.fade_in_duration, + chn_num=params.chn_num, + train_seq=params.train_seq, + image_size=params.image_size, + num_ladders=params.num_ladders, + hidden_dim=params.hidden_dim, + coff=params.coff, + ).to(device) + + if is_distributed: + model = DDP( + model, + device_ids=[params.local_rank], + output_device=params.local_rank, + find_unused_parameters=True, + broadcast_buffers=True, + ) + torch.cuda.synchronize() + dist.barrier() + + optimizer = get_optimizer(model, params) + if not is_distributed: + model = torch.compile(model, mode=params.compile_mode) + + # Training mode selection 
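+        # seq_train trains ladders 1..num_ladders in a single run, reloading the previous
+        # stage's checkpoint before each stage; indep_train runs only the stage given by
+        # --train_seq (loading the checkpoint of stage train_seq - 1 if it exists).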
+ if params.mode == "seq_train": + if rank == 0: + logger.opt(colors=True).info( + f"✓ Mode: sequential execution [progress 1 >> {params.num_ladders}]" + ) + + for i in range(1, params.num_ladders + 1): + if is_distributed: + torch.cuda.synchronize() + dist.barrier() + + # Update sequence number + if is_distributed: + params.train_seq = i + model.module.train_seq = i + else: + params.train_seq = i + model.train_seq = i + + # Load checkpoint if needed + if params.train_seq >= 2: + prev_checkpoint = os.path.join(params.output_dir, f"checkpoints/model_seq{params.train_seq-1}.pt") + if os.path.exists(prev_checkpoint): + model, optimizer, scaler = load_checkpoint( + model=model, + optimizer=optimizer, + scaler=scaler, + checkpoint_path=prev_checkpoint, + device=device, + logger=logger, + ) + + if is_distributed: + torch.cuda.synchronize() + dist.barrier() + + # Training + train_model( + model=model, + data_loader=train_loader, + optimizer=optimizer, + params=params, + device=device, + logger=logger, + scaler=scaler, + autocast_dtype=autocast_dtype, + ) + + if is_distributed: + torch.cuda.synchronize() + dist.barrier() + + elif params.mode == "indep_train": + if rank == 0: + logger.opt(colors=True).info( + f"✓ Mode: independent execution [progress {params.train_seq}]" + ) + + if is_distributed: + torch.cuda.synchronize() + dist.barrier() + + # Load checkpoint if needed + if params.train_seq >= 2: + prev_checkpoint = os.path.join(params.output_dir, f"checkpoints/model_seq{params.train_seq-1}.pt") + if os.path.exists(prev_checkpoint): + model, optimizer, scaler = load_checkpoint( + model=model, + optimizer=optimizer, + scaler=scaler, + checkpoint_path=prev_checkpoint, + device=device, + logger=logger, + ) + + if is_distributed: + torch.cuda.synchronize() + dist.barrier() + + # Training + train_model( + model=model, + data_loader=train_loader, + optimizer=optimizer, + params=params, + device=device, + logger=logger, + scaler=scaler, + autocast_dtype=autocast_dtype, + ) + + if is_distributed: + torch.cuda.synchronize() + dist.barrier() + + else: + logger.error(f"Unsupported mode: {params.mode}, use 'seq_train' or 'indep_train'") + return + + except Exception as e: + logger.error(f"Training failed: {str(e)}") + raise + finally: + try: + if is_distributed: + torch.cuda.synchronize() + cleanup_distributed() + logger.info("Distributed resources cleaned up successfully") + except Exception as e: + logger.error(f"Error during cleanup: {e}") + + +if __name__ == "__main__": + main()
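For reference, a usage sketch (not part of the patch): train_ddp.py reads RANK, WORLD_SIZE, and LOCAL_RANK from the environment in setup_distributed(), so multi-GPU runs are typically launched with torchrun. The flags below correspond to fields of HyperParameters above; the dataset, GPU count, and other values are illustrative assumptions, not project requirements.

# single process, no DDP (assumes a CUDA device is available)
python train_ddp.py --dataset shapes3d --mode seq_train --num_ladders 3 --output_dir outputs

# 4 GPUs on one node; torchrun sets the env vars consumed by setup_distributed()
torchrun --nproc_per_node=4 train_ddp.py --dataset shapes3d --mode seq_train --distributed True --output_dir outputs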