Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Training code #70

Merged
merged 2 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ The default settings are optimized for the best result. However, the behavior of

- `--processing_res`: the processing resolution; set as 0 to process the input resolution directly. When unassigned (`None`), will read default setting from model config. Default: ~~768~~ `None`.
- `--output_processing_res`: produce output at the processing resolution instead of upsampling it to the input resolution. Default: False.
- `--resample_method`: resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`. Default: `bilinear`.
- `--resample_method`: the resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic`, or `nearest`. Default: `bilinear`.

- `--half_precision` or `--fp16`: Run with half-precision (16-bit float) to reduce VRAM usage, might lead to suboptimal result.
- `--half_precision` or `--fp16`: Run with half-precision (16-bit float) to reduce VRAM usage, which might lead to suboptimal results.
- `--seed`: Random seed can be set to ensure additional reproducibility. Default: None (unseeded). Note: forcing `--batch_size 1` helps to increase reproducibility. To ensure full reproducibility, [deterministic mode](https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms) needs to be used.
- `--batch_size`: Batch size of repeated inference. Default: 0 (best value determined automatically).
- `--color_map`: [Colormap](https://matplotlib.org/stable/users/explain/colors/colormaps.html) used to colorize the depth prediction. Default: Spectral. Set to `None` to skip colored depth map generation.
Expand Down Expand Up @@ -196,7 +196,7 @@ python run.py \
--output_dir output/in-the-wild_example
```

## 🦿 Evaluation on test datasets
## 🦿 Evaluation on test datasets <a name="evaluation"></a>

Install additional dependencies:

Expand Down Expand Up @@ -224,6 +224,43 @@ bash script/eval/12_eval_nyu.sh

Note: although the seed has been set, the results might still be slightly different on different hardware.

## 🏋️ Training

Based on the previously created environment, install extended requirements:

```bash
pip install -r requirements++.txt -r requirements+.txt -r requirements.txt
```

Set environment parameters for the data directory:

```bash
export BASE_DATA_DIR=YOUR_DATA_DIR # directory of training data
export BASE_CKPT_DIR=YOUR_CHECKPOINT_DIR # directory of pretrained checkpoint
```

Download Stable Diffusion v2 [checkpoint](https://huggingface.co/stabilityai/stable-diffusion-2) into `${BASE_CKPT_DIR}`

Prepare for [Hypersim](https://github.com/apple/ml-hypersim) and [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/) datasets and save into `${BASE_DATA_DIR}`. Please refer to [this README](script/dataset_preprocess/hypersim/README.md) for Hypersim preprocessing.

Run training script

```bash
python train.py --config config/train_marigold.yaml
```

Resume from a checkpoint, e.g.

```bash
python train.py --resume_from output/marigold_base/checkpoint/latest
```

Evaluating results

Only the U-Net is updated and saved during training. To use the inference pipeline with your training result, replace `unet` folder in Marigold checkpoints with that in the `checkpoint` output folder. Then refer to [this section](#evaluation) for evaluation.

**Note**: Although random seeds have been set, the training result might be slightly different on different hardwares. It's recommended to train without interruption.

## ✏️ Contributing

Please refer to [this](CONTRIBUTING.md) instruction.
Expand Down
4 changes: 4 additions & 0 deletions config/dataset/data_hypersim_train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
name: hypersim
disp_name: hypersim_train
dir: hypersim/hypersim_processed_train.tar
filenames: data_split/hypersim/filename_list_train_filtered.txt
4 changes: 4 additions & 0 deletions config/dataset/data_hypersim_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
name: hypersim
disp_name: hypersim_val
dir: hypersim/hypersim_processed_val.tar
filenames: data_split/hypersim/filename_list_val_filtered.txt
6 changes: 6 additions & 0 deletions config/dataset/data_kitti_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
name: kitti
disp_name: kitti_val800_from_eigen_train
dir: kitti/kitti_sampled_val_800.tar
filenames: data_split/kitti/eigen_val_from_train_800.txt
kitti_bm_crop: true
valid_mask_crop: eigen
5 changes: 5 additions & 0 deletions config/dataset/data_nyu_train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
name: nyu_v2
disp_name: nyu_train_full
dir: nyuv2/nyu_labeled_extracted.tar
filenames: data_split/nyu/labeled/filename_list_train.txt
eigen_valid_mask: true
6 changes: 6 additions & 0 deletions config/dataset/data_vkitti_train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
name: vkitti
disp_name: vkitti_train
dir: vkitti/vkitti.tar
filenames: data_split/vkitti/vkitti_train.txt
kitti_bm_crop: true
valid_mask_crop: null # no valid_mask_crop for training
6 changes: 6 additions & 0 deletions config/dataset/data_vkitti_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
name: vkitti
disp_name: vkitti_val
dir: vkitti/vkitti.tar
filenames: data_split/vkitti/vkitti_val.txt
kitti_bm_crop: true
valid_mask_crop: eigen
18 changes: 18 additions & 0 deletions config/dataset/dataset_train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
dataset:
train:
name: mixed
prob_ls: [0.9, 0.1]
dataset_list:
- name: hypersim
disp_name: hypersim_train
dir: hypersim/hypersim_processed_train.tar
filenames: data_split/hypersim/filename_list_train_filtered.txt
resize_to_hw:
- 480
- 640
- name: vkitti
disp_name: vkitti_train
dir: vkitti/vkitti.tar
filenames: data_split/vkitti/vkitti_train.txt
kitti_bm_crop: true
valid_mask_crop: null
45 changes: 45 additions & 0 deletions config/dataset/dataset_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
dataset:
val:
# - name: hypersim
# disp_name: hypersim_val
# dir: hypersim/hypersim_processed_val.tar
# filenames: data_split/hypersim/filename_list_val_filtered.txt
# resize_to_hw:
# - 480
# - 640

# - name: nyu_v2
# disp_name: nyu_train_full
# dir: nyuv2/nyu_labeled_extracted.tar
# filenames: data_split/nyu/labeled/filename_list_train.txt
# eigen_valid_mask: true

# - name: kitti
# disp_name: kitti_val800_from_eigen_train
# dir: kitti/kitti_sampled_val_800.tar
# filenames: data_split/kitti/eigen_val_from_train_800.txt
# kitti_bm_crop: true
# valid_mask_crop: eigen

# Smaller subsets for faster validation during training
# The first dataset is used to calculate main eval metric.
- name: hypersim
disp_name: hypersim_val_small_80
dir: hypersim/hypersim_processed_val.tar
filenames: data_split/hypersim/filename_list_val_filtered_small_80.txt
resize_to_hw:
- 480
- 640

- name: nyu_v2
disp_name: nyu_train_small_100
dir: nyuv2/nyu_labeled_extracted.tar
filenames: data_split/nyu/labeled/filename_list_train_small_100.txt
eigen_valid_mask: true

- name: kitti
disp_name: kitti_val_from_train_sub_100
dir: kitti/kitti_sampled_val_800.tar
filenames: data_split/kitti/eigen_val_from_train_sub_100.txt
kitti_bm_crop: true
valid_mask_crop: eigen
9 changes: 9 additions & 0 deletions config/dataset/dataset_vis.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
dataset:
vis:
- name: hypersim
disp_name: hypersim_vis
dir: hypersim/hypersim_processed_val.tar
filenames: data_split/hypersim/selected_vis_sample.txt
resize_to_hw:
- 480
- 640
5 changes: 5 additions & 0 deletions config/logging.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
logging:
filename: logging.log
format: ' %(asctime)s - %(levelname)s -%(filename)s - %(funcName)s >> %(message)s'
console_level: 20
file_level: 10
4 changes: 4 additions & 0 deletions config/model_sdv2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
model:
name: marigold_pipeline
pretrained_path: stable-diffusion-2
latent_scale_factor: 0.18215
12 changes: 12 additions & 0 deletions config/train_debug.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
base_config:
- config/train_marigold.yaml


# Training settings
trainer:
save_period: 5
backup_period: 10
validation_period: 5
visualization_period: 5

max_iter: 50
94 changes: 94 additions & 0 deletions config/train_marigold.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
base_config:
- config/logging.yaml
- config/wandb.yaml
- config/dataset/dataset_train.yaml
- config/dataset/dataset_val.yaml
- config/dataset/dataset_vis.yaml
- config/model_sdv2.yaml


pipeline:
name: MarigoldPipeline
kwargs:
scale_invariant: true
shift_invariant: true

depth_normalization:
type: scale_shift_depth
clip: true
norm_min: -1.0
norm_max: 1.0
min_max_quantile: 0.02

augmentation:
lr_flip_p: 0.5

dataloader:
num_workers: 2
effective_batch_size: 32
max_train_batch_size: 2
seed: 2024 # to ensure continuity when resuming from checkpoint

# Training settings
trainer:
name: MarigoldTrainer
training_noise_scheduler:
pretrained_path: stable-diffusion-2
init_seed: 2024 # use null to train w/o seeding
save_period: 50
backup_period: 2000
validation_period: 2000
visualization_period: 2000

multi_res_noise:
strength: 0.9
annealed: true
downscale_strategy: original

gt_depth_type: depth_raw_norm
gt_mask_type: valid_mask_raw

max_epoch: 10000 # a large enough number
max_iter: 30000 # usually converges at around 20k

optimizer:
name: Adam

loss:
name: mse_loss
kwargs:
reduction: mean

lr: 3.0e-05
lr_scheduler:
name: IterExponential
kwargs:
total_iter: 25000
final_ratio: 0.01
warmup_steps: 100

# Validation (and visualization) settings
validation:
denoising_steps: 50
ensemble_size: 1 # simplified setting for on-training validation
processing_res: 0
match_input_res: false
resample_method: bilinear
main_val_metric: abs_relative_difference
main_val_metric_goal: minimize
init_seed: 2024

eval:
alignment: least_square
align_max_res: null
eval_metrics:
- abs_relative_difference
- squared_relative_difference
- rmse_linear
- rmse_log
- log10
- delta1_acc
- delta2_acc
- delta3_acc
- i_rmse
- silog_rmse
3 changes: 3 additions & 0 deletions config/wandb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
wandb:
# entity: your_entity
project: marigold
80 changes: 80 additions & 0 deletions data_split/hypersim/filename_list_val_filtered_small_80.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
ai_003_010/rgb_cam_00_fr0047.png ai_003_010/depth_plane_cam_00_fr0047.png
ai_003_010/rgb_cam_00_fr0048.png ai_003_010/depth_plane_cam_00_fr0048.png
ai_003_010/rgb_cam_01_fr0098.png ai_003_010/depth_plane_cam_01_fr0098.png
ai_004_003/rgb_cam_01_fr0008.png ai_004_003/depth_plane_cam_01_fr0008.png
ai_004_004/rgb_cam_00_fr0025.png ai_004_004/depth_plane_cam_00_fr0025.png
ai_004_004/rgb_cam_00_fr0046.png ai_004_004/depth_plane_cam_00_fr0046.png
ai_004_004/rgb_cam_00_fr0049.png ai_004_004/depth_plane_cam_00_fr0049.png
ai_004_004/rgb_cam_01_fr0023.png ai_004_004/depth_plane_cam_01_fr0023.png
ai_005_005/rgb_cam_00_fr0032.png ai_005_005/depth_plane_cam_00_fr0032.png
ai_006_007/rgb_cam_00_fr0022.png ai_006_007/depth_plane_cam_00_fr0022.png
ai_006_007/rgb_cam_00_fr0095.png ai_006_007/depth_plane_cam_00_fr0095.png
ai_007_001/rgb_cam_00_fr0044.png ai_007_001/depth_plane_cam_00_fr0044.png
ai_007_001/rgb_cam_00_fr0048.png ai_007_001/depth_plane_cam_00_fr0048.png
ai_009_007/rgb_cam_00_fr0017.png ai_009_007/depth_plane_cam_00_fr0017.png
ai_009_007/rgb_cam_00_fr0097.png ai_009_007/depth_plane_cam_00_fr0097.png
ai_009_009/rgb_cam_00_fr0094.png ai_009_009/depth_plane_cam_00_fr0094.png
ai_015_001/rgb_cam_00_fr0058.png ai_015_001/depth_plane_cam_00_fr0058.png
ai_015_001/rgb_cam_00_fr0089.png ai_015_001/depth_plane_cam_00_fr0089.png
ai_017_007/rgb_cam_01_fr0064.png ai_017_007/depth_plane_cam_01_fr0064.png
ai_018_005/rgb_cam_00_fr0014.png ai_018_005/depth_plane_cam_00_fr0014.png
ai_018_005/rgb_cam_00_fr0059.png ai_018_005/depth_plane_cam_00_fr0059.png
ai_022_010/rgb_cam_00_fr0097.png ai_022_010/depth_plane_cam_00_fr0097.png
ai_022_010/rgb_cam_00_fr0099.png ai_022_010/depth_plane_cam_00_fr0099.png
ai_023_003/rgb_cam_00_fr0013.png ai_023_003/depth_plane_cam_00_fr0013.png
ai_023_003/rgb_cam_00_fr0015.png ai_023_003/depth_plane_cam_00_fr0015.png
ai_023_003/rgb_cam_00_fr0036.png ai_023_003/depth_plane_cam_00_fr0036.png
ai_023_003/rgb_cam_00_fr0095.png ai_023_003/depth_plane_cam_00_fr0095.png
ai_023_003/rgb_cam_01_fr0029.png ai_023_003/depth_plane_cam_01_fr0029.png
ai_023_003/rgb_cam_01_fr0036.png ai_023_003/depth_plane_cam_01_fr0036.png
ai_023_003/rgb_cam_01_fr0071.png ai_023_003/depth_plane_cam_01_fr0071.png
ai_032_007/rgb_cam_00_fr0031.png ai_032_007/depth_plane_cam_00_fr0031.png
ai_032_007/rgb_cam_00_fr0040.png ai_032_007/depth_plane_cam_00_fr0040.png
ai_032_007/rgb_cam_00_fr0075.png ai_032_007/depth_plane_cam_00_fr0075.png
ai_035_003/rgb_cam_00_fr0054.png ai_035_003/depth_plane_cam_00_fr0054.png
ai_035_004/rgb_cam_00_fr0077.png ai_035_004/depth_plane_cam_00_fr0077.png
ai_038_009/rgb_cam_00_fr0031.png ai_038_009/depth_plane_cam_00_fr0031.png
ai_038_009/rgb_cam_01_fr0010.png ai_038_009/depth_plane_cam_01_fr0010.png
ai_038_009/rgb_cam_01_fr0088.png ai_038_009/depth_plane_cam_01_fr0088.png
ai_039_003/rgb_cam_01_fr0042.png ai_039_003/depth_plane_cam_01_fr0042.png
ai_039_003/rgb_cam_01_fr0097.png ai_039_003/depth_plane_cam_01_fr0097.png
ai_044_001/rgb_cam_00_fr0043.png ai_044_001/depth_plane_cam_00_fr0043.png
ai_044_001/rgb_cam_01_fr0018.png ai_044_001/depth_plane_cam_01_fr0018.png
ai_044_003/rgb_cam_01_fr0082.png ai_044_003/depth_plane_cam_01_fr0082.png
ai_044_003/rgb_cam_01_fr0087.png ai_044_003/depth_plane_cam_01_fr0087.png
ai_044_003/rgb_cam_02_fr0086.png ai_044_003/depth_plane_cam_02_fr0086.png
ai_044_003/rgb_cam_03_fr0022.png ai_044_003/depth_plane_cam_03_fr0022.png
ai_044_003/rgb_cam_03_fr0063.png ai_044_003/depth_plane_cam_03_fr0063.png
ai_045_008/rgb_cam_00_fr0015.png ai_045_008/depth_plane_cam_00_fr0015.png
ai_045_008/rgb_cam_00_fr0030.png ai_045_008/depth_plane_cam_00_fr0030.png
ai_045_008/rgb_cam_01_fr0029.png ai_045_008/depth_plane_cam_01_fr0029.png
ai_045_008/rgb_cam_01_fr0052.png ai_045_008/depth_plane_cam_01_fr0052.png
ai_045_008/rgb_cam_01_fr0088.png ai_045_008/depth_plane_cam_01_fr0088.png
ai_047_009/rgb_cam_00_fr0097.png ai_047_009/depth_plane_cam_00_fr0097.png
ai_048_001/rgb_cam_00_fr0014.png ai_048_001/depth_plane_cam_00_fr0014.png
ai_048_001/rgb_cam_00_fr0088.png ai_048_001/depth_plane_cam_00_fr0088.png
ai_048_001/rgb_cam_01_fr0045.png ai_048_001/depth_plane_cam_01_fr0045.png
ai_048_001/rgb_cam_02_fr0031.png ai_048_001/depth_plane_cam_02_fr0031.png
ai_048_001/rgb_cam_03_fr0005.png ai_048_001/depth_plane_cam_03_fr0005.png
ai_048_001/rgb_cam_03_fr0045.png ai_048_001/depth_plane_cam_03_fr0045.png
ai_048_001/rgb_cam_03_fr0054.png ai_048_001/depth_plane_cam_03_fr0054.png
ai_048_001/rgb_cam_03_fr0061.png ai_048_001/depth_plane_cam_03_fr0061.png
ai_050_002/rgb_cam_01_fr0016.png ai_050_002/depth_plane_cam_01_fr0016.png
ai_050_002/rgb_cam_02_fr0053.png ai_050_002/depth_plane_cam_02_fr0053.png
ai_050_002/rgb_cam_03_fr0082.png ai_050_002/depth_plane_cam_03_fr0082.png
ai_050_002/rgb_cam_04_fr0033.png ai_050_002/depth_plane_cam_04_fr0033.png
ai_051_004/rgb_cam_00_fr0028.png ai_051_004/depth_plane_cam_00_fr0028.png
ai_051_004/rgb_cam_01_fr0065.png ai_051_004/depth_plane_cam_01_fr0065.png
ai_051_004/rgb_cam_02_fr0054.png ai_051_004/depth_plane_cam_02_fr0054.png
ai_051_004/rgb_cam_02_fr0056.png ai_051_004/depth_plane_cam_02_fr0056.png
ai_051_004/rgb_cam_03_fr0037.png ai_051_004/depth_plane_cam_03_fr0037.png
ai_051_004/rgb_cam_04_fr0083.png ai_051_004/depth_plane_cam_04_fr0083.png
ai_051_004/rgb_cam_05_fr0003.png ai_051_004/depth_plane_cam_05_fr0003.png
ai_052_001/rgb_cam_00_fr0008.png ai_052_001/depth_plane_cam_00_fr0008.png
ai_052_003/rgb_cam_00_fr0097.png ai_052_003/depth_plane_cam_00_fr0097.png
ai_052_003/rgb_cam_01_fr0081.png ai_052_003/depth_plane_cam_01_fr0081.png
ai_052_007/rgb_cam_01_fr0001.png ai_052_007/depth_plane_cam_01_fr0001.png
ai_053_003/rgb_cam_00_fr0005.png ai_053_003/depth_plane_cam_00_fr0005.png
ai_053_005/rgb_cam_00_fr0080.png ai_053_005/depth_plane_cam_00_fr0080.png
ai_055_009/rgb_cam_01_fr0070.png ai_055_009/depth_plane_cam_01_fr0070.png
ai_055_009/rgb_cam_01_fr0086.png ai_055_009/depth_plane_cam_01_fr0086.png
3 changes: 3 additions & 0 deletions data_split/hypersim/selected_vis_sample.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ai_015_004/rgb_cam_00_fr0002.png ai_015_004/depth_plane_cam_00_fr0002.png (val)
ai_044_003/rgb_cam_01_fr0063.png ai_044_003/depth_plane_cam_01_fr0063.png (val)
ai_052_003/rgb_cam_01_fr0076.png ai_052_003/depth_plane_cam_01_fr0076.png (val)
Loading
Loading