
Commit

divide utils and fix arguments
suzuki-2001 committed Nov 18, 2024
1 parent a391fbf commit 36e9141
Showing 21 changed files with 445 additions and 713 deletions.
63 changes: 34 additions & 29 deletions README.md
@@ -40,6 +40,9 @@ This implementation introduces dynamic size management for arbitrary input image
## Installation
We recommend using [mamba](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) (via [miniforge](https://github.com/conda-forge/miniforge)) for faster installation of dependencies, but you can also use [conda](https://docs.anaconda.com/miniconda/miniconda-install/).
```bash
git clone https://github.com/suzuki-2001/pytorch-proVLAE.git
cd pytorch-proVLAE

mamba env create -f env.yaml # or conda
mamba activate torch-provlae
```
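After activating the environment, a quick sanity check that PyTorch and CUDA are visible can save a failed first run. This snippet is our suggestion and is not part of the repository:

```python
# Optional check that the freshly created environment works.
import torch

print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
```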
@@ -48,51 +51,38 @@ mamba activate torch-provlae

## Usage
You can train pytorch-proVLAE with the following commands. Sample hyperparameters and training configurations are provided in the [scripts directory](./scripts/).
If you have a checkpoint file from a pytorch-proVLAE training run, setting the mode argument to "visualize" allows you to inspect the latent traversals. Please ensure that the parameter settings match those used for the checkpoint file when running this mode.
If you have a checkpoint file from a pytorch-proVLAE training run, setting the mode argument to "traverse" allows you to inspect the latent traversals. Please ensure that the parameter settings match those used for the checkpoint file when running this mode.

</br>

```bash
# all progressive training steps
python train.py \
--dataset shapes3d \
--mode seq_train \
--batch_size 100 \
--num_epochs 15 \
--learning_rate 5e-4 \
--beta 15 \
--z_dim 3 \
--hidden_dim 64 \
--fade_in_duration 5000 \
--optim adamw \
--image_size 64 \
--chn_num 3 \
--output_dir ./output/shapes3d/

# training with distributed data parallel
torchrun --nproc_per_node=2 train_ddp.py \
--distributed True \
# tested on NVIDIA V100 PCIE 16GB + 32GB and NVIDIA A6000 48GB x2
torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
--distributed \
--mode seq_train \
--dataset ident3d \
--dataset shapes3d \
--optim adamw \
--num_ladders 3 \
--batch_size 128 \
--num_epochs 30 \
--num_epochs 15 \
--learning_rate 5e-4 \
--beta 1 \
--beta 8 \
--z_dim 3 \
--coff 0.5 \
--hidden_dim 64 \
--pre_kl \
--hidden_dim 32 \
--fade_in_duration 5000 \
--output_dir ./output/ident3d/ \
--optim adamw
--output_dir ./output/shapes3d/ \
--data_path ./data/shapes3d/
```

</br>

- ### Hyperparameters

| Argument | Default | Description |
|----------|---------------|-------------|
| `dataset` | "shapes3d" | Dataset to use (mnist, fashionmnist, dsprites, shapes3d, mpi3d, ident3d, celeba, flowers102, dtd, imagenet) |
| `data_path` | "./data" | Path to dataset storage |
| `z_dim` | 3 | Dimension of latent variables |
| `num_ladders` | 3 | Number of ladders (hierarchies) in pro-VLAE |
| `beta` | 8.0 | β parameter for pro-VLAE |
@@ -103,20 +93,35 @@ torchrun --nproc_per_node=2 train_ddp.py \
| `train_seq` | 1 | Current training sequence number (`indep_train` mode only) |
| `batch_size` | 100 | Batch size |
| `num_epochs` | 1 | Number of epochs |
| `mode` | "seq_train" | Execution mode ("seq_train", "indep_train", "visualize") |
| `hidden_dim` | 32 | Hidden layer dimension |
| `coff` | 0.5 | Coefficient for KL divergence |
| `pre_kl` | True | Use inactive ladder loss (see note below) |
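For orientation, `beta`, `coff`, and `pre_kl` interact roughly as in the original pro-VLAE objective (our paraphrase of the paper, not code taken from this repository): `beta` weights the KL terms of the ladders that are active at the current progression step, while `coff` (times `beta`) down-weights the KL of the not-yet-active ladders, a term that is only included when `pre_kl` is set:

$$
\mathcal{L} \;=\; -\,\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big]
\;+\; \beta \sum_{l \in \text{active}} \mathrm{KL}\big(q(z_l \mid x)\,\|\,p(z_l)\big)
\;+\; \mathrm{coff} \cdot \beta \sum_{l \in \text{inactive}} \mathrm{KL}\big(q(z_l \mid x)\,\|\,p(z_l)\big)
$$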

</br>

- ### Training Parameters

| Argument | Default | Description |
|----------|---------------|-------------|
| `dataset` | "shapes3d" | Dataset to use (mnist, fashionmnist, dsprites, shapes3d, mpi3d, ident3d, celeba, flowers102, dtd, imagenet) |
| `data_path` | "./data" | Path to dataset storage |
| `output_dir` | "outputs" | Output directory |
| `checkpoint_dir` | "checkpoints" | Checkpoint output directory |
| `recon_dir` | "reconstructions" | Reconstruction results directory |
| `traverse_dir` | "travesals" | Latent traversal results directory |
| `mode` | "seq_train" | Execution mode ("seq_train", "indep_train", "traverse") |
| `compile_mode` | "default" | PyTorch compilation mode |
| `on_cudnn_benchmark` | True | Enable/disable cuDNN benchmark |
| `optim` | "adam" | Optimization algorithm (adam, adamw, sgd, lamb, diffgrad, madgrad) |
| `distributed` | False | Enable distributed data parallel |
| `num_workers` | 4 | Number of workers for data loader |
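As a rough guide, `compile_mode` and `on_cudnn_benchmark` correspond to the standard PyTorch switches shown below; this is an illustrative sketch of the usual usage, and the repository's exact wiring may differ:

```python
# Typical PyTorch counterparts of the two flags above (illustrative only).
import torch
import torch.nn as nn

torch.backends.cudnn.benchmark = True         # on_cudnn_benchmark=True
model = nn.Linear(8, 8)                       # stand-in for the proVLAE model
model = torch.compile(model, mode="default")  # compile_mode: "default",
                                              # "reduce-overhead", "max-autotune"
```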

</br>

Mode descriptions:
- `seq_train`: Sequential training from ladder 1 to `num_ladders`
- `indep_train`: Independent training of the ladder specified by `train_seq`
- `visualize`: Visualize the latent space using a trained model (requires checkpoints)
- `traverse`: Visualize the latent space using a trained model (requires checkpoints)
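The pseudocode below is purely illustrative of how the modes relate; `train_one_ladder` and `render_traversals` are assumed placeholders, not functions of this repository:

```python
# Illustrative pseudocode only (not the repository's train.py).
def run(mode, num_ladders, train_seq, train_one_ladder, render_traversals):
    if mode == "seq_train":
        # Sequential: train ladder 1, then 2, ... up to num_ladders.
        for seq in range(1, num_ladders + 1):
            train_one_ladder(seq)
    elif mode == "indep_train":
        # Independent: train only the ladder selected by --train_seq.
        train_one_ladder(train_seq)
    elif mode == "traverse":
        # No training: load a checkpoint trained with matching settings
        # and render latent traversals.
        render_traversals()
```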

&nbsp;

1 change: 0 additions & 1 deletion models/__init__.py

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_dsprites.sh

This file was deleted.

16 changes: 0 additions & 16 deletions scripts/run_flowers102.sh

This file was deleted.

16 changes: 0 additions & 16 deletions scripts/run_ident3d.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_shape3d.sh

This file was deleted.

File renamed without changes.
86 changes: 86 additions & 0 deletions src/ddp_utils.py
@@ -0,0 +1,86 @@
import sys
import os

from loguru import logger
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler


def setup_logger(rank=-1, world_size=1):
"""Setup logger for distributed training"""
config = {
"handlers": [
{
"sink": sys.stdout,
"format": (
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
"<level>{level: <8}</level> | "
"<cyan>Rank {extra[rank]}/{extra[world_size]}</cyan> | "
"<cyan>{name}</cyan>:<cyan>{line}</cyan> | "
"<level>{message}</level>"
),
"level": "DEBUG",
"colorize": True,
}
]
}

try: # Remove all existing handlers
logger.configure(**config)
except ValueError:
pass

# Create a new logger instance with rank information
return logger.bind(rank=rank, world_size=world_size)


def setup_distributed(params):
"""Initialize distributed training environment with explicit device mapping"""
if not params.distributed:
return False

try:
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
params.rank = int(os.environ["RANK"])
params.world_size = int(os.environ["WORLD_SIZE"])
params.local_rank = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
params.rank = int(os.environ["SLURM_PROCID"])
params.local_rank = params.rank % torch.cuda.device_count()
params.world_size = int(os.environ["SLURM_NTASKS"])
else:
raise ValueError("Not running with distributed environment variables set")

torch.cuda.set_device(params.local_rank)
init_method = "env://"
backend = params.dist_backend
if backend == "nccl" and not torch.cuda.is_available():
backend = "gloo"

if not dist.is_initialized():
dist.init_process_group(
backend=backend,
init_method=init_method,
world_size=params.world_size,
rank=params.rank,
)
torch.cuda.set_device(params.local_rank)
dist.barrier(device_ids=[params.local_rank])

return True
except Exception as e:
print(f"Failed to initialize distributed training: {e}")
return False


def cleanup_distributed():
"""Clean up distributed training resources safely with device mapping"""
if dist.is_initialized():
try:
local_rank = int(os.environ.get("LOCAL_RANK", 0))
dist.barrier(device_ids=[local_rank])
dist.destroy_process_group()
except Exception as e:
print(f"Error during distributed cleanup: {e}")
File renamed without changes.
20 changes: 20 additions & 0 deletions src/scripts/run_dsprites.sh
@@ -0,0 +1,20 @@
#!/bin/bash

torchrun --nproc_per_node=2 --master_port=29502 src/train.py \
--distributed \
--mode seq_train \
--dataset dsprites \
--optim adamw \
--num_ladders 3 \
--batch_size 256 \
--num_epochs 30 \
--learning_rate 5e-4 \
--beta 8 \
--z_dim 2 \
--coff 0.5 \
--pre_kl \
--hidden_dim 64 \
--fade_in_duration 5000 \
--output_dir ./output/dsprites/ \
--data_path ./data/dsprites/

13 changes: 8 additions & 5 deletions scripts/run_dtd.sh → src/scripts/run_dtd.sh
@@ -1,16 +1,19 @@
#!/bin/bash

python train.py \
torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
--distributed \
--mode seq_train \
--dataset dtd \
--optim adamw \
--num_ladders 3 \
--batch_size 32 \
--batch_size 128 \
--num_epochs 30 \
--learning_rate 5e-4 \
--beta 3 \
--beta 8 \
--z_dim 3 \
--coff 0.2 \
--coff 0.5 \
--pre_kl \
--hidden_dim 64 \
--fade_in_duration 5000 \
--output_dir ./output/dtd/ \
--optim adamw
--data_path ./data/dtd/
13 changes: 9 additions & 4 deletions scripts/run_fashionmnist.sh → src/scripts/run_fashionmnist.sh
@@ -1,15 +1,20 @@
#!/bin/bash

python train.py \
torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
--distributed \
--mode seq_train \
--dataset fashionmnist \
--optim adamw \
--num_ladders 3 \
--batch_size 16 \
--batch_size 64 \
--num_epochs 30 \
--learning_rate 5e-4 \
--beta 3 \
--z_dim 3 \
--z_dim 2 \
--coff 0.5 \
--pre_kl \
--hidden_dim 32 \
--fade_in_duration 5000 \
--output_dir ./output/fashionmnist/ \
--optim adamw
--data_path ./data/fashionmnist/

19 changes: 19 additions & 0 deletions src/scripts/run_flowers102.sh
@@ -0,0 +1,19 @@
#!/bin/bash

torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
--distributed \
--mode seq_train \
--dataset flowers102 \
--optim adamw \
--num_ladders 3 \
--batch_size 128 \
--num_epochs 30 \
--learning_rate 5e-4 \
--beta 8 \
--z_dim 3 \
--coff 0.5 \
--pre_kl \
--hidden_dim 64 \
--fade_in_duration 5000 \
--output_dir ./output/flowers102/ \
--data_path ./data/flowers102/
10 changes: 6 additions & 4 deletions scripts/run_ddp.sh → src/scripts/run_ident3d.sh
@@ -1,17 +1,19 @@
#!/bin/bash

torchrun --nproc_per_node=2 train_ddp.py \
--distributed True \
torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
--distributed \
--mode seq_train \
--dataset ident3d \
--optim adamw \
--num_ladders 3 \
--batch_size 128 \
--num_epochs 30 \
--learning_rate 5e-4 \
--beta 1 \
--beta 8 \
--z_dim 3 \
--coff 0.5 \
--pre_kl \
--hidden_dim 64 \
--fade_in_duration 5000 \
--output_dir ./output/ident3d/ \
--optim adamw
--data_path ./data/ident3d/
