diff --git a/README.md b/README.md index 295cf95..ef68d6c 100755 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ [![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/ashleve/lightning-hydra-template#license) -[**Project Page**](https://haoyizhu.github.io/spa/) | [**Paper**](https://haoyizhu.github.io/spa/static/images/paper.pdf) | [**Arxiv**]() | [**HF Model**](https://huggingface.co/HaoyiZhu/SPA) +[**Project Page**](https://haoyizhu.github.io/spa/) | [**Paper**](https://haoyizhu.github.io/spa/static/images/paper.pdf) | [**arXiv**](https://arxiv.org/abs/2410.08208) | [**HuggingFace Model**](https://huggingface.co/HaoyiZhu/SPA) | [**Real-World Codebase**](https://github.com/HaoyiZhu/RealRobot) [Haoyi Zhu](https://www.haoyizhu.site/), [Honghui Yang](https://hhyangcs.github.io/), [Yating Wang](https://scholar.google.com/citations?hl=zh-CN&user=5SuBWh0AAAAJ), [Jiange Yang](https://yangjiangeyjg.github.io/), [Liming Wang](https://wanglimin.github.io/), [Tong He](http://tonghe90.github.io/) @@ -20,13 +20,352 @@ **SPA** is a novel representation learning framework that emphasizes the importance of **3D spatial awareness in embodied AI**. It leverages **differentiable neural rendering** on multi-view images to endow a vanilla Vision Transformer (ViT) with intrinsic spatial understanding. We also present the most comprehensive evaluation of embodied representation learning to date, covering **268 tasks** across **8 simulators** with diverse policies in both single-task and language-conditioned multi-task scenarios. +:partying_face: **NEWS**: + +- *Oct. 2024:* Codebase and pre-trained checkpoints are released! Paper is available on [arXiv](https://arxiv.org/abs/2410.08208). + +## :clipboard: Contents + +- [Project Structure](#telescope-project-structure) +- [Installation](#installation) +- [Usage](#star2-usage) +- [Pre-Training](#rocket-pre-training) +- [SPA Large-Scale Evaluation](#bulb-spa-large-scale-evaluation) +- [Gotchas](#tada-gotchas) +- [License](#books-license) +- [Acknowledgement](#sparkles-acknowledgement) +- [Citation](#pencil-citation) + +## :telescope: Project Structure + +Our codebase draws significant inspiration from the excellent [Lightning Hydra Template](https://github.com/ashleve/lightning-hydra-template). The directory structure of this project is organized as follows: + +
+Show directory structure + +``` +├── .github <- Github Actions workflows +│ +├── configs <- Hydra configs +│ ├── callbacks <- Callbacks configs +│ ├── data <- Data configs +│ ├── debug <- Debugging configs +│ ├── experiment <- Experiment configs +│ ├── extras <- Extra utilities configs +│ ├── hydra <- Hydra configs +│ ├── local <- Local configs +│ ├── logger <- Logger configs +│ ├── model <- Model configs +│ ├── paths <- Project paths configs +│ ├── trainer <- Trainer configs +| | +│ └── train.yaml <- Main config for training +│ +├── data <- Project data +│ +├── logs <- Logs generated by hydra and lightning loggers +│ +├── scripts <- Shell or Python scripts +| +├── spa <- Source code of SPA +│ ├── data <- Data scripts +│ ├── models <- Model scripts +│ ├── utils <- Utility scripts +│ │ +│ └── train.py <- Run SPA pre-training +│ +├── .gitignore <- List of files ignored by git +├── .project-root <- File for inferring the position of project root directory +├── requirements.txt <- File for installing python dependencies +├── setup.py <- File for installing project as a package +└── README.md +``` + +
+ +## :hammer: Installation +
+Basics
+
+```bash
+# clone project
+git clone https://github.com/HaoyiZhu/SPA.git
+cd SPA
+
+# create conda environment
+conda create -n spa python=3.11 -y
+conda activate spa
+
+# install PyTorch, please refer to https://pytorch.org/ for other CUDA versions
+# e.g. CUDA 11.8:
+pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+# install basic packages
+pip3 install -r requirements.txt
+```
+
+ +
+SPA + +```bash +# (optional) if you want to use SPA's volume decoder +cd libs/spa-ops +pip install -e . +cd ../.. + +# install SPA, so that you can import from anywhere +pip install -e . +``` +
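+
+As a quick, optional sanity check (a minimal sketch; whether CUDA is available depends on your machine), you can verify that PyTorch and the `spa` package import correctly:
+
+```python
+import torch
+
+import spa  # should import from anywhere after `pip install -e .`
+
+print(torch.__version__)          # installed PyTorch version
+print(torch.cuda.is_available())  # True if a CUDA-capable GPU is visible
+```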
+ +## :star2: Usage + +
+ Example of Using SPA Pre-trained Encoder + +We provide pre-trained SPA weights for feature extraction. The checkpoints are available on [🤗Hugging Face](https://huggingface.co/HaoyiZhu/SPA). You don't need to manually download the weights, as SPA will automatically handle this if needed. + +```python +import torch + +from spa.models import spa_vit_base_patch16, spa_vit_large_patch16 + +image = torch.rand((1, 3, 224, 224)) # range in [0, 1] + +# Example usage of SPA-Large (recommended) +# or you can use `spa_vit_base_patch16` for SPA-base +model = spa_vit_large_patch16(pretrained=True) +model.eval() + +# Freeze the model +model.freeze() + +# (Recommended) move to CUDA +image = image.cuda() +model = model.cuda() + +# Obtain the [CLS] token +cls_token = model(image) # torch.Size([1, 1024]) + +# Obtain the reshaped feature map concatenated with [CLS] token +feature_map_cat_cls = model( + image, feature_map=True, cat_cls=True +) # torch.Size([1, 2048, 14, 14]) + +# Obtain the reshaped feature map without [CLS] token +feature_map_wo_cls = model( + image, feature_map=True, cat_cls=False +) # torch.Size([1, 1024, 14, 14]) +``` + +> **Note:** The inputs will be automatically resized to `224 x 224` and normalized within the [SPA ViT encoder](spa/models/components/img_backbones/vit.py#L69). + +
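+
+As a further illustration (a hedged sketch, not an official API beyond the calls documented above), the patch-level feature map can be pooled into a single embedding for a downstream head. The image loading below uses standard `torchvision` utilities, and `example.jpg` is a placeholder path:
+
+```python
+import torch
+from torchvision.io import read_image
+from torchvision.transforms.functional import convert_image_dtype
+
+from spa.models import spa_vit_large_patch16
+
+# load an RGB image as a float tensor in [0, 1]; `example.jpg` is a placeholder
+image = convert_image_dtype(read_image("example.jpg"), torch.float32).unsqueeze(0)
+
+model = spa_vit_large_patch16(pretrained=True)
+model.eval()
+model.freeze()
+
+with torch.no_grad():
+    # (1, 1024, 14, 14) patch features, as documented above
+    feats = model(image, feature_map=True, cat_cls=False)
+
+# average over the 14 x 14 spatial grid -> a (1, 1024) global embedding
+embedding = feats.mean(dim=(-2, -1))
+```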
+ + + +## :rocket: Pre-Training + +
+ Example of Pre-Training on ScanNet
+
+We give an example of pre-training SPA on the [ScanNet](http://www.scan-net.org/) v2 dataset.
+
+1) Prepare the dataset
+    - Download the [ScanNet](http://www.scan-net.org/) v2 dataset.
+    - Pre-process and extract RGB-D images following [PonderV2](https://github.com/OpenGVLab/PonderV2/blob/main/docs/data_preparation.md#scannet-v2). The preprocessed data should be put under `data/scannet/`.
+    - Pre-generate metadata for fast data loading. The following command will generate metadata under `data/scannet/metadata`.
+    ```bash
+    python scripts/generate_scannet_metadata.py
+    ```
+
+2) Run the following command for pre-training. Remember to modify hyper-parameters such as the number of nodes and GPU devices according to your machine.
+    ```bash
+    python spa/train.py experiment=spa_pretrain_vitl trainer.num_nodes=5 trainer.devices=8
+    ```
+
+ +## :bulb: SPA Large-Scale Evaluation + +
+ TBD + +
+ +## :tada: Gotchas + +
+ Override any config parameter from the command line
+
+This codebase is based on [Hydra](https://github.com/facebookresearch/hydra), which allows for convenient configuration overriding:
+```bash
+python src/train.py trainer.max_epochs=20 seed=300
+```
+> **Note**: You can also add new parameters with the `+` sign.
+```bash
+python src/train.py +some_new_param=some_new_value
+```
+
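+
+To see what such dot-list overrides resolve to, here is a small self-contained sketch using OmegaConf, the configuration library underlying Hydra (the keys below are illustrative, not the project's actual config tree):
+
+```python
+from omegaconf import OmegaConf
+
+base = OmegaConf.create({"trainer": {"max_epochs": 10}, "seed": 42})
+overrides = OmegaConf.from_dotlist(["trainer.max_epochs=20", "seed=300"])
+
+cfg = OmegaConf.merge(base, overrides)
+print(cfg.trainer.max_epochs, cfg.seed)  # 20 300
+
+# in Hydra, prefixing an override with `+` adds a key that is absent from the base config
+```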
+ +
+Train on CPU, GPU, multi-GPU and TPU + +```bash +# train on CPU +python src/train.py trainer=cpu + +# train on 1 GPU +python src/train.py trainer=gpu + +# train on TPU +python src/train.py +trainer.tpu_cores=8 + +# train with DDP (Distributed Data Parallel) (4 GPUs) +python src/train.py trainer=ddp trainer.devices=4 + +# train with DDP (Distributed Data Parallel) (8 GPUs, 2 nodes) +python src/train.py trainer=ddp trainer.devices=4 trainer.num_nodes=2 + +# simulate DDP on CPU processes +python src/train.py trainer=ddp_sim trainer.devices=2 + +# accelerate training on mac +python src/train.py trainer=mps +``` + +
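+
+If it helps, a small hypothetical helper (the function name and mapping are only suggestions) can pick one of the trainer configs above based on the hardware that PyTorch detects:
+
+```python
+import torch
+
+def suggest_trainer_config() -> str:
+    """Map detected hardware to one of the trainer configs above."""
+    if torch.cuda.is_available():
+        return "ddp" if torch.cuda.device_count() > 1 else "gpu"
+    if torch.backends.mps.is_available():
+        return "mps"
+    return "cpu"
+
+print(f"python src/train.py trainer={suggest_trainer_config()}")
+```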
+ +
+Train with mixed precision
+
+```bash
+# train with PyTorch native automatic mixed precision (AMP)
+python src/train.py trainer=gpu +trainer.precision=16
+```
+
+ +
+Use different tricks available in PyTorch Lightning
+
+```bash
+# gradient clipping may be enabled to avoid exploding gradients
+python src/train.py trainer.gradient_clip_val=0.5
+
+# run validation loop 4 times during a training epoch
+python src/train.py +trainer.val_check_interval=0.25
+
+# accumulate gradients
+python src/train.py trainer.accumulate_grad_batches=10
+
+# terminate training after 12 hours
+python src/train.py +trainer.max_time="00:12:00:00"
+```
+
+> **Note**: PyTorch Lightning provides more than [40 useful trainer flags](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags).
+
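+
+These command-line overrides map directly onto `lightning.pytorch.Trainer` constructor arguments; a minimal sketch with the same values as above (shown for illustration, not how this codebase constructs its trainer):
+
+```python
+from lightning.pytorch import Trainer
+
+trainer = Trainer(
+    gradient_clip_val=0.5,       # clip gradients to avoid explosions
+    val_check_interval=0.25,     # run validation 4 times per training epoch
+    accumulate_grad_batches=10,  # accumulate gradients over 10 batches
+    max_time="00:12:00:00",      # stop training after 12 hours (DD:HH:MM:SS)
+)
+```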
+ +
+Easily debug + +```bash +# runs 1 epoch in default debugging mode +# changes logging directory to `logs/debugs/...` +# sets level of all command line loggers to 'DEBUG' +# enforces debug-friendly configuration +python src/train.py debug=default + +# run 1 train, val and test loop, using only 1 batch +python src/train.py debug=fdr + +# print execution time profiling +python src/train.py debug=profiler + +# try overfitting to 1 batch +python src/train.py debug=overfit + +# raise exception if there are any numerical anomalies in tensors, like NaN or +/-inf +python src/train.py +trainer.detect_anomaly=true + +# use only 20% of the data +python src/train.py +trainer.limit_train_batches=0.2 \ ++trainer.limit_val_batches=0.2 +trainer.limit_test_batches=0.2 +``` + +> **Note**: Visit [configs/debug/](configs/debug/) for different debugging configs. + +
+ +
+Resume training from a checkpoint
+
+```bash
+python src/train.py ckpt_path="/path/to/ckpt/name.ckpt"
+```
+
+> **Note**: The checkpoint can be either a path or a URL.
+
+> **Note**: Currently, loading a ckpt doesn't resume the logger experiment, but it will be supported in a future Lightning release.
+
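+
+Before resuming, it can be useful to peek inside a Lightning checkpoint; a hedged sketch (the path is a placeholder and the exact keys depend on the Lightning version):
+
+```python
+import torch
+
+ckpt = torch.load("/path/to/ckpt/name.ckpt", map_location="cpu")
+
+# typical entries: "epoch", "global_step", "state_dict", "optimizer_states", ...
+print(sorted(ckpt.keys()))
+print(ckpt.get("epoch"), ckpt.get("global_step"))
+```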
+ +
+Create a sweep over hyperparameters
+
+```bash
+# this will run 9 experiments one after the other,
+# each with a different combination of seed and learning rate
+python src/train.py -m seed=100,200,300 model.optimizer.lr=0.0001,0.00005,0.00001
+```
+
+> **Note**: Hydra composes configs lazily at job launch time. If you change code or configs after launching a job/sweep, the final composed configs might be impacted.
+
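+
+The sweep above is a grid over the two value lists, hence 3 × 3 = 9 runs; a tiny sketch enumerating the same combinations:
+
+```python
+from itertools import product
+
+seeds = [100, 200, 300]
+lrs = [0.0001, 0.00005, 0.00001]
+
+for seed, lr in product(seeds, lrs):  # 9 combinations in total
+    print(f"python src/train.py seed={seed} model.optimizer.lr={lr}")
+```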
+ +
+Execute all experiments from a folder
+
+```bash
+python src/train.py -m 'exp_maniskill2_act_policy/maniskill2_task@maniskill2_task=glob(*)'
+```
+
+> **Note**: Hydra provides special syntax for controlling the behavior of multiruns. Learn more [here](https://hydra.cc/docs/next/tutorials/basic/running_your_app/multi-run). The command above executes all task experiments from [configs/exp_maniskill2_act_policy/maniskill2_task](configs/experiment/).
+
+ +
+Execute runs with multiple different seeds
+
+```bash
+python src/train.py -m seed=100,200,300 trainer.deterministic=True
+```
+
+> **Note**: `trainer.deterministic=True` makes PyTorch more deterministic but impacts performance.
+
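+
+In code, the same effect roughly corresponds to seeding all RNGs and enabling deterministic mode; a hedged sketch using Lightning utilities:
+
+```python
+from lightning.pytorch import Trainer, seed_everything
+
+seed_everything(100, workers=True)     # seeds Python, NumPy and PyTorch RNGs
+trainer = Trainer(deterministic=True)  # more reproducible, but can be slower
+```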
+ +For more instructions, refer to the official documentation for [Pytorch Lightning](https://github.com/Lightning-AI/pytorch-lightning), [Hydra](https://github.com/facebookresearch/hydra), and [Lightning Hydra Template](https://github.com/ashleve/lightning-hydra-template). + +## :books: License + +This repository is released under the [MIT license](LICENSE). + +## :sparkles: Acknowledgement + +Our work is primarily built upon [PonderV2](https://github.com/OpenGVLab/PonderV2), [UniPAD](https://github.com/Nightmare-n/UniPAD), [Pytorch Lightning](https://github.com/Lightning-AI/pytorch-lightning), [Hydra](https://github.com/facebookresearch/hydra), [Lightning Hydra Template](https://github.com/ashleve/lightning-hydra-template), [RLBench](https://github.com/stepjam/RLBench), [PerAct](https://github.com/peract/peract), [LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO), [Meta-Wolrd](https://github.com/Farama-Foundation/Metaworld), [ACT](https://github.com/tonyzhaozh/act), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [DP3](https://github.com/YanjieZe/3D-Diffusion-Policy), [TIMM](https://github.com/huggingface/pytorch-image-models), [VC1](https://github.com/facebookresearch/eai-vc), [R3M](https://github.com/facebookresearch/r3m). We extend our gratitude to all these authors for their generously open-sourced code and their significant contributions to the community. + +Contact [Haoyi Zhu](https://www.haoyizhu.site/) if you have any questions or suggestions. + ## :pencil: Citation ```bib @article{zhu2024spa, title = {SPA: 3D Spatial-Awareness Enables Effective Embodied Representation}, author = {Zhu, Haoyi and and Yang, Honghui and Wang, Yating and Yang, Jiange and Wang, Limin and He, Tong}, - journal = {arXiv preprint}, + journal = {arXiv preprint arxiv:2410.08208}, year = {2024}, } ``` \ No newline at end of file diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100755 index 0000000..56bf7f4 --- /dev/null +++ b/configs/__init__.py @@ -0,0 +1 @@ +# this file is needed here to include configs when building project as a package diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml new file mode 100755 index 0000000..628fa70 --- /dev/null +++ b/configs/callbacks/default.yaml @@ -0,0 +1,26 @@ +defaults: + - model_checkpoint + - early_stopping + - model_summary + - rich_progress_bar + - lr_monitor + - device_stats_monitor + # - stochastic_weight_averaging + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "epoch_{epoch:03d}" + monitor: "val/loss" + mode: "min" + save_last: true + auto_insert_metric_name: false + save_top_k: 3 + +early_stopping: + monitor: "val/loss" + patience: 100 + mode: "min" + +model_summary: + max_depth: -1 diff --git a/configs/callbacks/default_ema.yaml b/configs/callbacks/default_ema.yaml new file mode 100644 index 0000000..8908323 --- /dev/null +++ b/configs/callbacks/default_ema.yaml @@ -0,0 +1,32 @@ +defaults: + - ema_model_checkpoint + - early_stopping + - model_summary + - rich_progress_bar + - lr_monitor + - device_stats_monitor + - ema + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "epoch_{epoch:03d}" + monitor: "val/loss" + mode: "min" + save_last: true + auto_insert_metric_name: false + save_top_k: 3 + +early_stopping: + monitor: "val/loss" + patience: 100 + mode: "min" + +model_summary: + max_depth: -1 + +ema: + decay: 0.999 + validate_original_weights: false + every_n_steps: 1 + cpu_offload: true diff --git 
a/configs/callbacks/device_stats_monitor.yaml b/configs/callbacks/device_stats_monitor.yaml new file mode 100644 index 0000000..ee34f4b --- /dev/null +++ b/configs/callbacks/device_stats_monitor.yaml @@ -0,0 +1,3 @@ +device_stat_monitor: + _target_: lightning.pytorch.callbacks.DeviceStatsMonitor + cpu_stats: null \ No newline at end of file diff --git a/configs/callbacks/early_stopping.yaml b/configs/callbacks/early_stopping.yaml new file mode 100755 index 0000000..f4c90e0 --- /dev/null +++ b/configs/callbacks/early_stopping.yaml @@ -0,0 +1,15 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html + +early_stopping: + _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: ??? # quantity to be monitored, must be specified !!! + min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement + patience: 3 # number of checks with no improvement after which training will be stopped + verbose: false # verbosity mode + mode: "min" # "max" means higher metric value is better, can be also "min" + strict: true # whether to crash the training if monitor is not found in the validation metrics + check_finite: true # when set True, stops training when the monitor becomes NaN or infinite + stopping_threshold: # stop training immediately once the monitored quantity reaches this threshold + divergence_threshold: # stop training as soon as the monitored quantity becomes worse than this threshold + check_on_train_epoch_end: # whether to run early stopping at the end of the training epoch + # log_rank_zero_only: False # this keyword argument isn't available in stable version diff --git a/configs/callbacks/ema.yaml b/configs/callbacks/ema.yaml new file mode 100644 index 0000000..02450c7 --- /dev/null +++ b/configs/callbacks/ema.yaml @@ -0,0 +1,6 @@ +ema: + _target_: spa.utils.callbacks.EMA + decay: 0.999 + validate_original_weights: false + every_n_steps: 1 + cpu_offload: true \ No newline at end of file diff --git a/configs/callbacks/ema_model_checkpoint.yaml b/configs/callbacks/ema_model_checkpoint.yaml new file mode 100755 index 0000000..e1a6360 --- /dev/null +++ b/configs/callbacks/ema_model_checkpoint.yaml @@ -0,0 +1,17 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html + +model_checkpoint: + _target_: spa.utils.callbacks.EMAModelCheckpoint + dirpath: # directory to save the model file + filename: # checkpoint filename + monitor: # name of the logged metric which determines when model is improving + verbose: false # verbosity mode + save_last: # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 5 # save k best models (determined by above metric) + mode: "min" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: true # when True, the checkpoints filenames will contain the metric name + save_weights_only: false # if True, then only the model’s weights will be saved + every_n_train_steps: # number of training steps between checkpoints + train_time_interval: # checkpoints are monitored at the specified time interval + every_n_epochs: # number of epochs between checkpoints + save_on_train_epoch_end: # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/configs/callbacks/lr_monitor.yaml b/configs/callbacks/lr_monitor.yaml new file mode 100755 index 0000000..1527b90 --- /dev/null +++ b/configs/callbacks/lr_monitor.yaml @@ -0,0 +1,5 @@ +# 
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html#lightning.pytorch.callbacks.LearningRateMonitor + +lr_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: 'step' diff --git a/configs/callbacks/model_checkpoint.yaml b/configs/callbacks/model_checkpoint.yaml new file mode 100755 index 0000000..0a5a053 --- /dev/null +++ b/configs/callbacks/model_checkpoint.yaml @@ -0,0 +1,17 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html + +model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: # directory to save the model file + filename: # checkpoint filename + monitor: # name of the logged metric which determines when model is improving + verbose: false # verbosity mode + save_last: # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 10 # save k best models (determined by above metric) + mode: "min" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: true # when True, the checkpoints filenames will contain the metric name + save_weights_only: false # if True, then only the model’s weights will be saved + every_n_train_steps: # number of training steps between checkpoints + train_time_interval: # checkpoints are monitored at the specified time interval + every_n_epochs: # number of epochs between checkpoints + save_on_train_epoch_end: # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/configs/callbacks/model_summary.yaml b/configs/callbacks/model_summary.yaml new file mode 100755 index 0000000..b75981d --- /dev/null +++ b/configs/callbacks/model_summary.yaml @@ -0,0 +1,5 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html + +model_summary: + _target_: lightning.pytorch.callbacks.RichModelSummary + max_depth: 1 # the maximum depth of layer nesting that the summary will include diff --git a/configs/callbacks/none.yaml b/configs/callbacks/none.yaml new file mode 100755 index 0000000..e69de29 diff --git a/configs/callbacks/rich_progress_bar.yaml b/configs/callbacks/rich_progress_bar.yaml new file mode 100755 index 0000000..de6f1cc --- /dev/null +++ b/configs/callbacks/rich_progress_bar.yaml @@ -0,0 +1,4 @@ +# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html + +rich_progress_bar: + _target_: lightning.pytorch.callbacks.RichProgressBar diff --git a/configs/callbacks/stochastic_weight_averaging.yaml b/configs/callbacks/stochastic_weight_averaging.yaml new file mode 100755 index 0000000..a28826c --- /dev/null +++ b/configs/callbacks/stochastic_weight_averaging.yaml @@ -0,0 +1,5 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.StochasticWeightAveraging.html#stochasticweightaveraging + +stochastic_weight_averaging: + _target_: lightning.pytorch.callbacks.StochasticWeightAveraging + swa_lrs: 0.05 diff --git a/configs/data/scannet.yaml b/configs/data/scannet.yaml new file mode 100644 index 0000000..b5b82d1 --- /dev/null +++ b/configs/data/scannet.yaml @@ -0,0 +1,100 @@ +_target_: spa.data.base_cat_datamodule.BaseCatDataModule + +batch_size_train: 4 # per device +batch_size_val: 1 +batch_size_test: 1 +num_workers: 16 +pin_memory: true + +data_scannet_processor_cfg: + enabled_proc_list: + train: + - imresize + - imcrop + - trans_to_local + - merge_trans_matrix + - filter_depth_outlier_old + - 
calc_ray_from_depth_v2 + - calc_scene_bbox + - calc_voxel_size + - sample_ray + - collect + test: + - imresize + - imcrop + - trans_to_local + - merge_trans_matrix + - filter_depth_outlier_old + - calc_ray_from_depth_v2 + - calc_scene_bbox + - calc_voxel_size + - sample_ray + - collect + proc_config: + imresize: + mv_consistency: true + extra_keys: ["depth"] + resize_scale: + train: [0.47, 0.5] + test: [0.47, 0.47] + imcrop: # (480, 640) + mv_consistency: true + extra_keys: ["depth"] + crop_size: [224, 224] + crop_pos: + train: [[0.0, 1.0], [0.0, 1.0]] + test: [[0.5, 0.5], [0.5, 0.5]] + trans_to_local: + mapping_keys: {} + to_local_key: ['world2cam', True] + merge_trans_matrix: + keys: + trans2d_matrix: [['cam2img', False], ] + trans3d_matrix: [['world2cam', True], ] + trans_normal: + key: ['world2cam', True] + filter_depth_outlier_old: + percentile: 0.05 + calc_ray_from_depth_v2: + placeholder: null + calc_scene_bbox: + type: dynamic_depth + calc_voxel_size: + grid_size: [128, 128, 32] + sample_ray: + collider: + type: AABBBoxColliderNp + near_plane: 0.1 + ray_nsample: 512 + collect: + keys: + train: ['img', 'dataset_name', 'semantic_img', 'depth', 'ori_shape', 'world2cam', 'cam2img', 'trans2d_matrix', 'scene_name', 'frame_list', 'ray_depth', 'ray_o', 'ray_d', 'ray_p', 'ray_near', 'ray_far', 'ray_idx', 'ray_scale', 'point_cloud_range', 'voxel_size', 'grid_size'] + test: ['img', 'dataset_name', 'semantic_img', 'depth', 'ori_shape', 'world2cam', 'cam2img', 'trans2d_matrix', 'scene_name', 'frame_list', 'ray_depth', 'ray_o', 'ray_d', 'ray_p', 'ray_near', 'ray_far', 'ray_idx', 'ray_scale', 'point_cloud_range', 'voxel_size', 'grid_size'] + +train: + - _target_: spa.data.components.scannet.ScanNetMultiViewSPAPretrain + split: train + scene_root: data/scannet/ + frame_interval: 20 + downsample_ratio: 0.5 + num_cameras: 8 + loop: 3 + data_processor_cfg: ${data.data_scannet_processor_cfg} + batch_max_num_img: 24 + mode: train + scene_box_threshold: 0.05 + depth_area_threshold: 0.1 + semantic_size: [1024, 1024] + +val: + _target_: spa.data.components.scannet.ScanNetMultiViewSPAPretrain + split: val + scene_root: data/scannet/ + frame_interval: 20 + num_cameras: 8 + loop: 1 + downsample_ratio: 1 + data_processor_cfg: ${data.data_scannet_processor_cfg} + batch_max_num_img: 24 + mode: test + semantic_size: [1024, 1024] diff --git a/configs/debug/default.yaml b/configs/debug/default.yaml new file mode 100755 index 0000000..b90fa15 --- /dev/null +++ b/configs/debug/default.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +# default debugging setup, runs 1 full epoch +# other debugging configs can inherit from this one + +# overwrite task name so debugging logs are stored in separate folder +task_name: "debug" + +# disable callbacks and loggers during debugging +callbacks: +logger: + +extras: + ignore_warnings: false + enforce_tags: false + +# sets level of all command line loggers to 'DEBUG' +# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ +hydra: + job_logging: + root: + level: DEBUG + + # use this to also set hydra loggers to 'DEBUG' + # verbose: True + +trainer: + max_epochs: 1 + accelerator: cpu # debuggers don't like gpus + devices: 1 # debuggers don't like multiprocessing + detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor + +data: + num_workers: 0 # debuggers don't like multiprocessing + pin_memory: false # disable gpu memory pin diff --git a/configs/debug/fdr.yaml b/configs/debug/fdr.yaml new file mode 100755 index 0000000..7f2d34f --- /dev/null +++ 
b/configs/debug/fdr.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +# runs 1 train, 1 validation and 1 test step + +defaults: + - default + +trainer: + fast_dev_run: true diff --git a/configs/debug/limit.yaml b/configs/debug/limit.yaml new file mode 100755 index 0000000..514d77f --- /dev/null +++ b/configs/debug/limit.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# uses only 1% of the training data and 5% of validation/test data + +defaults: + - default + +trainer: + max_epochs: 3 + limit_train_batches: 0.01 + limit_val_batches: 0.05 + limit_test_batches: 0.05 diff --git a/configs/debug/overfit.yaml b/configs/debug/overfit.yaml new file mode 100755 index 0000000..22a6b5b --- /dev/null +++ b/configs/debug/overfit.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +# overfits to 3 batches + +defaults: + - default + +trainer: + max_epochs: 20 + overfit_batches: 3 + +# model ckpt and early stopping need to be disabled during overfitting +callbacks: diff --git a/configs/debug/profiler.yaml b/configs/debug/profiler.yaml new file mode 100755 index 0000000..2bd7da8 --- /dev/null +++ b/configs/debug/profiler.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# runs with execution time profiling + +defaults: + - default + +trainer: + max_epochs: 1 + profiler: "simple" + # profiler: "advanced" + # profiler: "pytorch" diff --git a/configs/experiment/spa_pretrain_vitb.yaml b/configs/experiment/spa_pretrain_vitb.yaml new file mode 100755 index 0000000..4063a38 --- /dev/null +++ b/configs/experiment/spa_pretrain_vitb.yaml @@ -0,0 +1,70 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: scannet + - override /trainer: ddp + - override /model: spa_pretrain + - override /callbacks: default_ema + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +task_name: "spa_pretrain_vitb" + +default_ema: + ema: + decay: 0.999 + validate_original_weights: false + every_n_steps: 1 + cpu_offload: true + +trainer: + devices: 8 + num_nodes: 1 + max_epochs: 2000 + check_val_every_n_epoch: 20 + accelerator: gpu + strategy: auto + precision: 16-mixed + gradient_clip_val: 1.0 + accumulate_grad_batches: 8 + sync_batchnorm: false + limit_val_batches: 2 + use_distributed_sampler: false + +lr_scale: 1.0 + +data: + batch_size_train: 2 + num_workers: 16 + +model: + optimizer: + type: AdamW + lr: ${eval:'0.00001 * ${data.batch_size_train} * ${trainer.accumulate_grad_batches} * ${trainer.devices} * ${trainer.num_nodes} / 8'} + weight_decay: 0.04 + betas: [0.9, 0.95] + layer_decay: 0.8 + verbose: true + param_dicts: null + lr_scheduler: + scheduler: + type: OneCycleLR + max_lr: ${model.optimizer.lr} + pct_start: 0.05 + anneal_strategy: cos + div_factor: 100.0 + final_div_factor: 1000.0 + # monitor: val/loss + interval: step + frequency: 1 + model: + ckpt_name: spa-b + img_backbone: + embed_dim: 768 + depth: 12 + num_heads: 12 + \ No newline at end of file diff --git a/configs/experiment/spa_pretrain_vitl.yaml b/configs/experiment/spa_pretrain_vitl.yaml new file mode 100755 index 0000000..9f6a4d4 --- /dev/null +++ b/configs/experiment/spa_pretrain_vitl.yaml @@ -0,0 +1,63 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: scannet + - override /trainer: ddp + - override /model: spa_pretrain + - override /callbacks: default_ema + +# all parameters below will be merged with parameters from default 
configurations set above +# this allows you to overwrite only specified parameters + +task_name: "spa_pretrain_vitl" + +default_ema: + ema: + decay: 0.999 + validate_original_weights: false + every_n_steps: 1 + cpu_offload: true + +trainer: + devices: 8 + num_nodes: 1 + max_epochs: 2000 + check_val_every_n_epoch: 20 + accelerator: gpu + strategy: auto + precision: 16-mixed + gradient_clip_val: 1.0 + accumulate_grad_batches: 8 + sync_batchnorm: false + limit_val_batches: 2 + use_distributed_sampler: false + +lr_scale: 1.0 + +data: + batch_size_train: 2 + num_workers: 16 + +model: + optimizer: + type: AdamW + lr: ${eval:'0.000005 * ${data.batch_size_train} * ${trainer.accumulate_grad_batches} * ${trainer.devices} * ${trainer.num_nodes} / 8'} + weight_decay: 0.04 + betas: [0.9, 0.95] + layer_decay: 0.8 + verbose: true + param_dicts: null + lr_scheduler: + scheduler: + type: OneCycleLR + max_lr: ${model.optimizer.lr} + pct_start: 0.05 + anneal_strategy: cos + div_factor: 100.0 + final_div_factor: 1000.0 + # monitor: val/loss + interval: step + frequency: 1 \ No newline at end of file diff --git a/configs/extras/default.yaml b/configs/extras/default.yaml new file mode 100755 index 0000000..7e94b39 --- /dev/null +++ b/configs/extras/default.yaml @@ -0,0 +1,8 @@ +# disable python warnings if they annoy you +ignore_warnings: false + +# ask user for tags if none are provided in the config +enforce_tags: true + +# pretty print config tree at the start of the run using Rich library +print_config: true diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml new file mode 100755 index 0000000..a61e9b3 --- /dev/null +++ b/configs/hydra/default.yaml @@ -0,0 +1,19 @@ +# https://hydra.cc/docs/configure_hydra/intro/ + +# enable color logging +defaults: + - override hydra_logging: colorlog + - override job_logging: colorlog + +# output directory, generated dynamically on each run +run: + dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} +sweep: + dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num} + +job_logging: + handlers: + file: + # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 + filename: ${hydra.runtime.output_dir}/${task_name}.log diff --git a/configs/local/.gitkeep b/configs/local/.gitkeep new file mode 100755 index 0000000..e69de29 diff --git a/configs/logger/aim.yaml b/configs/logger/aim.yaml new file mode 100755 index 0000000..2abfa27 --- /dev/null +++ b/configs/logger/aim.yaml @@ -0,0 +1,28 @@ +# https://aimstack.io/ + +# example usage in lightning module: +# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py + +# open the Aim UI with the following command (run in the folder containing the `.aim` folder): +# `aim up` + +aim: + _target_: aim.pytorch_lightning.AimLogger + repo: ${paths.root_dir} # .aim folder will be created here + # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# + + # aim allows to group runs under experiment name + experiment: # any string, set to "default" if not specified + + train_metric_prefix: "train/" + val_metric_prefix: "val/" + test_metric_prefix: "test/" + + # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) 
+ system_tracking_interval: 10 # set to null to disable system metrics tracking + + # enable/disable logging of system params such as installed packages, git info, env vars, etc. + log_system_params: true + + # enable/disable tracking console logs (default value is true) + capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 diff --git a/configs/logger/comet.yaml b/configs/logger/comet.yaml new file mode 100755 index 0000000..a669d67 --- /dev/null +++ b/configs/logger/comet.yaml @@ -0,0 +1,12 @@ +# https://www.comet.ml + +comet: + _target_: lightning.pytorch.loggers.comet.CometLogger + api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable + save_dir: "${paths.output_dir}" + project_name: "lightning-hydra-template" + rest_api_key: + # experiment_name: "" + experiment_key: # set to resume experiment + offline: false + prefix: "" diff --git a/configs/logger/csv.yaml b/configs/logger/csv.yaml new file mode 100755 index 0000000..fa028e9 --- /dev/null +++ b/configs/logger/csv.yaml @@ -0,0 +1,7 @@ +# csv logger built in lightning + +csv: + _target_: lightning.pytorch.loggers.csv_logs.CSVLogger + save_dir: "${paths.output_dir}" + name: "csv/" + prefix: "" diff --git a/configs/logger/many_loggers.yaml b/configs/logger/many_loggers.yaml new file mode 100755 index 0000000..934ff79 --- /dev/null +++ b/configs/logger/many_loggers.yaml @@ -0,0 +1,9 @@ +# train with many loggers at once + +defaults: + # - comet.yaml + - csv + # - mlflow.yaml + # - neptune.yaml + - tensorboard + - wandb diff --git a/configs/logger/mlflow.yaml b/configs/logger/mlflow.yaml new file mode 100755 index 0000000..ed2535a --- /dev/null +++ b/configs/logger/mlflow.yaml @@ -0,0 +1,12 @@ +# https://mlflow.org + +mlflow: + _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger + # experiment_name: "" + # run_name: "" + tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI + tags: + # save_dir: "./mlruns" + prefix: "" + artifact_location: + # run_id: "" diff --git a/configs/logger/neptune.yaml b/configs/logger/neptune.yaml new file mode 100755 index 0000000..dfd7f00 --- /dev/null +++ b/configs/logger/neptune.yaml @@ -0,0 +1,9 @@ +# https://neptune.ai + +neptune: + _target_: lightning.pytorch.loggers.neptune.NeptuneLogger + api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable + project: username/lightning-hydra-template + # name: "" + log_model_checkpoints: true + prefix: "" diff --git a/configs/logger/tensorboard.yaml b/configs/logger/tensorboard.yaml new file mode 100755 index 0000000..97775e1 --- /dev/null +++ b/configs/logger/tensorboard.yaml @@ -0,0 +1,10 @@ +# https://www.tensorflow.org/tensorboard/ + +tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.output_dir}/tensorboard/" + name: + log_graph: false + default_hp_metric: true + prefix: "" + version: ${task_name} diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml new file mode 100755 index 0000000..c169bd9 --- /dev/null +++ b/configs/logger/wandb.yaml @@ -0,0 +1,16 @@ +# https://wandb.ai + +wandb: + _target_: lightning.pytorch.loggers.wandb.WandbLogger + # name: "" # name of the run (normally generated by wandb) + save_dir: "${paths.output_dir}" + offline: false + id: # pass correct id to resume experiment! 
+ anonymous: # enable anonymous logging + project: "lightning-hydra-template" + log_model: false # upload lightning ckpts + prefix: "" # a string to put at the beginning of metric keys + # entity: "" # set to name of your wandb team + group: "" + tags: [] + job_type: "" diff --git a/configs/model/spa_pretrain.yaml b/configs/model/spa_pretrain.yaml new file mode 100644 index 0000000..b0b0039 --- /dev/null +++ b/configs/model/spa_pretrain.yaml @@ -0,0 +1,185 @@ +_target_: spa.models.spa_pretrain_module.SPAPretrainModule + +optimizer: + type: AdamW + lr: ${eval:'0.00001 * ${data.batch_size_train} * ${trainer.accumulate_grad_batches} * ${trainer.devices} * ${trainer.num_nodes} / 8'} + weight_decay: 0.04 + betas: [0.9, 0.95] + layer_decay: 0.8 + verbose: true +param_dicts: null + +lr_scheduler: + scheduler: + type: OneCycleLR + max_lr: ${model.optimizer.lr} + pct_start: 0.05 + anneal_strategy: cos + div_factor: 100.0 + final_div_factor: 1000.0 + # monitor: val/loss + interval: step + frequency: 1 + +model: + _target_: spa.models.components.SPA + ckpt_name: spa-l + data_processor_cfg: + enabled_proc_list: + train: + - random_photometric_distort + - imnormalize + test: + - imnormalize + proc_config: + random_photometric_distort: + mv_consistency: true + brightness: [0.875, 1.125] + contrast: [0.5, 1.5] + saturation: [0.5, 1.5] + hue: [-0.05, 0.05] + p: 0.5 + imnormalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + fp16_enabled_layers: ["img_backbone"] + img_backbone: + _target_: spa.models.components.img_backbones.vit.SPAViT + img_size: 224 + patch_size: 16 + in_chans: 3 + embed_dim: 1024 + depth: 24 + num_heads: 16 + # depth / mae shared decoder + decoder_embed_dim: 512 + mlp_ratio: 4.0 + qkv_bias: true + pretrained_weight: + mask_ratio: 0.5 + out_feature_channels: 64 + view_transform: + _target_: spa.models.components.view_transforms.lss_voxelformer.LSSVoxelformer + in_channels: ${model.model.img_backbone.out_feature_channels} + grid_size: [128, 128, 32] + feature_map_stride: 1 # bev_stride + transformer_cfg: + num_layers: 1 + transformerlayers: + attn_cfgs: + - type: VoxelformerCrossAttention + embed_dims: ${model.model.img_backbone.out_feature_channels} + num_heads: 8 + num_points: 5 + num_levels: 1 + - type: VoxelformerSelfAttention + embed_dims: ${model.model.img_backbone.out_feature_channels} + num_convs: 1 + ffn_cfgs: + feedforward_channels: ${eval:'${model.model.img_backbone.out_feature_channels} * 4'} + num_fcs: 2 + ffn_drop: 0.0 + operation_order: ['cross_attn', 'norm', 'self_attn', 'norm', 'ffn', 'norm'] + dense_head: + _target_: spa.models.components.dense_heads.render_head.RenderHead + in_channels: ${model.model.img_backbone.out_feature_channels} + feature_map_stride: ${model.model.view_transform.feature_map_stride} + val_ray_split: 8192 + feature_type: 3d_to_3d + collider_cfg: + type: AABBBoxCollider + near_plane: 0.4 + semantic_cfg: + use_semantic: true + type: radio + img_radio_cfg: + model: radio_v2 + proj_cfg: + type: ExpConv3D + mid_channels: 64 + sdf_channels: 33 + rgb_channels: 27 # 3 * (sh_deg + 1) ** 2, sh_deg = 2 + group_norm: true + semantic_channels: 1280 + use_convnext_block: true + depth: 1 + render_cfg: + type: NeuSModel + scene_scale: 1.6 + field_cfg: + type: SDFFieldExp + beta_init: 0.3 + use_gradient: True + volume_type: default + padding_mode: zeros + render_rgb: true + render_semantic: true + shared_volume: true + use_alpha: true + sampler_cfg: + type: NeuSSampler + initial_sampler: UniformSampler + num_samples: 72 + num_samples_importance: 24 + 
num_upsample_steps: 1 + train_stratified: true + single_jitter: true + loss_cfg: + sensor_depth_truncation: 0.25 + temperature: 0.01 + weights: + eikonal_loss: 0.01 + free_space_loss: 1.0 + sdf_loss: 10.0 + depth_loss: 1.0 + rgb_loss: 10.0 + semantic_loss: 1.0 + +train_metrics: + _target_: spa.utils.metrics.Metrics + metrics: + - _target_: torchmetrics.MeanMetric + - _target_: torchmetrics.MeanMetric + - _target_: torchmetrics.MeanMetric + - _target_: torchmetrics.MeanMetric + - _target_: torchmetrics.MeanMetric + input_keys: + - loss + - rgb_loss + - psnr + - depth_loss + - semantic_loss + output_keys: + - train/loss + - train/rgb_loss + - train/psnr + - train/depth_loss + - train/semantic_loss +val_metrics: + _target_: spa.utils.metrics.Metrics + metrics: + - _target_: torchmetrics.MeanMetric + - _target_: torchmetrics.MeanMetric + - _target_: torchmetrics.MeanMetric + - _target_: torchmetrics.MeanMetric + - _target_: torchmetrics.MeanMetric + input_keys: + - loss + - rgb_loss + - psnr + - depth_loss + - semantic_loss + output_keys: + - val/loss + - val/rgb_loss + - val/psnr + - val/depth_loss + - val/semantic_loss +best_val_metrics: + _target_: spa.utils.metrics.Metrics + metrics: + - _target_: torchmetrics.MinMetric + input_keys: + - val/loss + output_keys: + - val/loss diff --git a/configs/paths/default.yaml b/configs/paths/default.yaml new file mode 100755 index 0000000..ec81db2 --- /dev/null +++ b/configs/paths/default.yaml @@ -0,0 +1,18 @@ +# path to root directory +# this requires PROJECT_ROOT environment variable to exist +# you can replace it with "." if you want the root to be the current working directory +root_dir: ${oc.env:PROJECT_ROOT} + +# path to data directory +data_dir: ${paths.root_dir}/data/ + +# path to logging directory +log_dir: ${paths.root_dir}/logs/ + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like ckpts and metrics +output_dir: ${hydra:runtime.output_dir} + +# path to working directory +work_dir: ${hydra:runtime.cwd} diff --git a/configs/train.yaml b/configs/train.yaml new file mode 100755 index 0000000..e82ca8b --- /dev/null +++ b/configs/train.yaml @@ -0,0 +1,52 @@ +# @package _global_ + +# specify here default configuration +# order of defaults determines the order in which configs override each other +defaults: + - _self_ + - data: scannet + - model: spa + - callbacks: default_ema + - logger: tensorboard # set logger here or use command line (e.g. `python train.py logger=tensorboard`) + - trainer: ddp + - paths: default + - extras: default + - hydra: default + + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: + + # config for hyperparameter optimization + - hparams_search: + + # optional local config for machine/user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional local: default + + # debugging config (enable through command line, e.g. 
`python train.py debug=default) + - debug: + +# task name, determines output directory path +task_name: "train" + +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` +tags: ["dev"] + +# set False to skip model training +train: true + +# evaluate on test set, using best model weights achieved during training +# lightning chooses best weights based on the metric specified in checkpoint callback +test: false + +# compile model for faster training with pytorch 2.0 +compile: false + +# simply provide checkpoint path to resume training +ckpt_path: + +# seed for random number generators in pytorch, numpy and python.random +seed: 3407 diff --git a/configs/trainer/cpu.yaml b/configs/trainer/cpu.yaml new file mode 100755 index 0000000..b7d6767 --- /dev/null +++ b/configs/trainer/cpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: cpu +devices: 1 diff --git a/configs/trainer/ddp.yaml b/configs/trainer/ddp.yaml new file mode 100755 index 0000000..303d0b9 --- /dev/null +++ b/configs/trainer/ddp.yaml @@ -0,0 +1,15 @@ +defaults: + - default + +strategy: ddp # _find_unused_parameters_true # ddp + +accelerator: gpu +devices: 4 +num_nodes: 1 +sync_batchnorm: true +max_epochs: 100 +check_val_every_n_epoch: 2 +gradient_clip_val: 0.5 +accumulate_grad_batches: 1 +precision: 32-true +log_every_n_steps: 50 \ No newline at end of file diff --git a/configs/trainer/ddp_sim.yaml b/configs/trainer/ddp_sim.yaml new file mode 100755 index 0000000..8404419 --- /dev/null +++ b/configs/trainer/ddp_sim.yaml @@ -0,0 +1,7 @@ +defaults: + - default + +# simulate DDP on CPU, useful for debugging +accelerator: cpu +devices: 2 +strategy: ddp_spawn diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml new file mode 100755 index 0000000..292f4d6 --- /dev/null +++ b/configs/trainer/default.yaml @@ -0,0 +1,19 @@ +_target_: lightning.pytorch.trainer.Trainer + +default_root_dir: ${paths.output_dir} + +min_epochs: 1 # prevents early stopping +max_epochs: 10 + +accelerator: cpu +devices: 1 + +# mixed precision for extra speed-up +# precision: 16 + +# perform a validation loop every N training epochs +check_val_every_n_epoch: 1 + +# set True to to ensure deterministic results +# makes training slower but gives more reproducibility than just setting seeds +deterministic: false diff --git a/configs/trainer/gpu.yaml b/configs/trainer/gpu.yaml new file mode 100755 index 0000000..b238951 --- /dev/null +++ b/configs/trainer/gpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: gpu +devices: 1 diff --git a/configs/trainer/mps.yaml b/configs/trainer/mps.yaml new file mode 100755 index 0000000..1ecf6d5 --- /dev/null +++ b/configs/trainer/mps.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: mps +devices: 1 diff --git a/libs/spa-ops/__init__.py b/libs/spa-ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/spa-ops/deform_attn/ms_deform_attn_utils.py b/libs/spa-ops/deform_attn/ms_deform_attn_utils.py new file mode 100644 index 0000000..406af6f --- /dev/null +++ b/libs/spa-ops/deform_attn/ms_deform_attn_utils.py @@ -0,0 +1,60 @@ +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from . 
import ms_deform_attn_cuda + + +class MSDeformAttnFunction(Function): + @staticmethod + def forward( + ctx, + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + im2col_step, + ): + ctx.im2col_step = im2col_step + output = ms_deform_attn_cuda.ms_deform_attn_forward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ctx.im2col_step, + ) + ctx.save_for_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + ( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + ) = ctx.saved_tensors + ( + grad_value, + grad_sampling_loc, + grad_attn_weight, + ) = ms_deform_attn_cuda.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output, + ctx.im2col_step, + ) + + return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None diff --git a/libs/spa-ops/deform_attn/src/ms_deform_attn.cpp b/libs/spa-ops/deform_attn/src/ms_deform_attn.cpp new file mode 100644 index 0000000..2201f63 --- /dev/null +++ b/libs/spa-ops/deform_attn/src/ms_deform_attn.cpp @@ -0,0 +1,16 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "ms_deform_attn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); +} diff --git a/libs/spa-ops/deform_attn/src/ms_deform_attn.h b/libs/spa-ops/deform_attn/src/ms_deform_attn.h new file mode 100644 index 0000000..288f7b7 --- /dev/null +++ b/libs/spa-ops/deform_attn/src/ms_deform_attn.h @@ -0,0 +1,57 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include "ms_deform_attn_cuda.h" + + +at::Tensor +ms_deform_attn_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + if (value.is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +ms_deform_attn_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + if (value.is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + diff --git a/libs/spa-ops/deform_attn/src/ms_deform_attn_cuda.cu b/libs/spa-ops/deform_attn/src/ms_deform_attn_cuda.cu new file mode 100644 index 0000000..af37ca5 --- /dev/null +++ b/libs/spa-ops/deform_attn/src/ms_deform_attn_cuda.cu @@ -0,0 +1,153 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include +#include "ms_deform_im2col_cuda.cuh" + +#include +#include +#include +#include + + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + + AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); + + const int batch_n = im2col_step_; + auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto columns = output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] { + ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), + value.data_ptr() + n * im2col_step_ * per_value_size, + spatial_shapes.data_ptr(), + level_start_index.data_ptr(), + sampling_loc.data_ptr() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data_ptr() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + columns.data_ptr()); + + })); + } + + output = output.view({batch, num_query, num_heads*channels}); + + return output; +} + + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + + AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); + AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor 
has to be contiguous"); + AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); + AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); + AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); + AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); + + AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor"); + AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor"); + AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor"); + AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor"); + AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); + AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor"); + + const int batch = value.size(0); + const int spatial_size = value.size(1); + const int num_heads = value.size(2); + const int channels = value.size(3); + + const int num_levels = spatial_shapes.size(0); + + const int num_query = sampling_loc.size(1); + const int num_point = sampling_loc.size(4); + + const int im2col_step_ = std::min(batch, im2col_step); + + AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); + + auto grad_value = at::zeros_like(value); + auto grad_sampling_loc = at::zeros_like(sampling_loc); + auto grad_attn_weight = at::zeros_like(attn_weight); + + const int batch_n = im2col_step_; + auto per_value_size = spatial_size * num_heads * channels; + auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; + auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; + auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); + + for (int n = 0; n < batch/im2col_step_; ++n) + { + auto grad_output_g = grad_output_n.select(0, n); + AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] { + ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), + grad_output_g.data_ptr(), + value.data_ptr() + n * im2col_step_ * per_value_size, + spatial_shapes.data_ptr(), + level_start_index.data_ptr(), + sampling_loc.data_ptr() + n * im2col_step_ * per_sample_loc_size, + attn_weight.data_ptr() + n * im2col_step_ * per_attn_weight_size, + batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, + grad_value.data_ptr() + n * im2col_step_ * per_value_size, + grad_sampling_loc.data_ptr() + n * im2col_step_ * per_sample_loc_size, + grad_attn_weight.data_ptr() + n * im2col_step_ * per_attn_weight_size); + + })); + } + + return { + grad_value, grad_sampling_loc, grad_attn_weight + }; +} \ No newline at end of file diff --git a/libs/spa-ops/deform_attn/src/ms_deform_attn_cuda.h b/libs/spa-ops/deform_attn/src/ms_deform_attn_cuda.h new file mode 100644 index 0000000..c7ae53f --- /dev/null +++ b/libs/spa-ops/deform_attn/src/ms_deform_attn_cuda.h @@ -0,0 +1,30 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor ms_deform_attn_cuda_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector ms_deform_attn_cuda_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + diff --git a/libs/spa-ops/deform_attn/src/ms_deform_im2col_cuda.cuh b/libs/spa-ops/deform_attn/src/ms_deform_im2col_cuda.cuh new file mode 100644 index 0000000..f504dd9 --- /dev/null +++ b/libs/spa-ops/deform_attn/src/ms_deform_im2col_cuda.cuh @@ -0,0 +1,1327 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#include +#include + +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = 
lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const 
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + // const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } 
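+    // write the accumulated result: the attention-weighted sum of bilinear samples over all levels and points for this (batch, query, head, channel) output element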
+ *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + // const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void 
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + // const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template 
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + // const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void 
ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + // const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + 
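+          // after the shared-memory tree reduction above, thread 0 holds the block-wide sums; write them out for this sampling point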
*grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + // const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += 
cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + // const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int 
batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +template +void ms_deformable_col2im_cuda(cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t * data_spatial_shapes, + const int64_t * data_level_start_index, + const scalar_t * data_sampling_loc, + const scalar_t * data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + 
num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + 
num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} \ No newline at end of file diff --git a/libs/spa-ops/grid_sampler/__init__.py b/libs/spa-ops/grid_sampler/__init__.py new file mode 100644 index 0000000..e6dfef8 --- /dev/null +++ b/libs/spa-ops/grid_sampler/__init__.py @@ -0,0 +1,3 @@ +from .grid_sampler import GridSampler2D, GridSampler3D + +__all__ = ["GridSampler2D", "GridSampler3D"] diff --git a/libs/spa-ops/grid_sampler/grid_sampler.py b/libs/spa-ops/grid_sampler/grid_sampler.py new file mode 100644 index 0000000..317d1a0 --- /dev/null +++ b/libs/spa-ops/grid_sampler/grid_sampler.py @@ -0,0 +1,156 @@ +import torch +from torch import nn as nn +from torch.autograd import Function + +from . import grid_sampler_cuda + +padding_mode_enum = {"zeros": 0, "border": 1, "reflection": 2} + + +class GridSampler3DBackward(Function): + @staticmethod + def forward( + ctx, + input, + grid, + grad_output, + padding_mode="zeros", + align_corners=True, + ): + ctx.align_corners = align_corners + ctx.padding_mode = padding_mode + grad_input, grad_grid = grid_sampler_cuda.grid_sampler_3d_backward( + grad_output, + input, + grid, + padding_mode_enum[padding_mode], + ctx.align_corners, + ) + ctx.save_for_backward(input, grid, grad_output) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_gard_input, grad2_grad_grid): + input, grid, grad_output = ctx.saved_tensors + + ( + grad_input, + grad_grid, + grad_grad_output, + ) = grid_sampler_cuda.grid_sampler_3d_backward_backward( + grad2_gard_input.contiguous(), + grad2_grad_grid.contiguous(), + input, + grid, + grad_output, + padding_mode_enum[ctx.padding_mode], + ctx.align_corners, + ) + return grad_input, grad_grid, grad_grad_output, None, None + + +class GridSampler3D(Function): + @staticmethod + def forward( + ctx, + input, + grid, + padding_mode="zeros", + align_corners=True, + ): + output = grid_sampler_cuda.grid_sampler_3d_forward( + input, + grid, + padding_mode_enum[padding_mode], + align_corners, + ) + ctx.save_for_backward(input, grid) + ctx.align_corners = align_corners + ctx.padding_mode = padding_mode + return output + + @staticmethod + def backward(ctx, grad_out): + input, grid = ctx.saved_tensors + d_input, d_grid = GridSampler3DBackward.apply( + input, + grid, + grad_out.contiguous(), + ctx.padding_mode, + ctx.align_corners, + ) + return d_input, d_grid, None, None + + +class GridSampler2DBackward(Function): + @staticmethod + def forward( + ctx, + input, + grid, + grad_output, + padding_mode="zeros", + align_corners=True, + ): + ctx.align_corners = align_corners + ctx.padding_mode = padding_mode + grad_input, grad_grid = grid_sampler_cuda.grid_sampler_2d_backward( + grad_output, + input, + grid, + padding_mode_enum[padding_mode], + ctx.align_corners, + ) + ctx.save_for_backward(input, grid, grad_output) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_gard_input, grad2_grad_grid): + input, grid, grad_output = ctx.saved_tensors + ( + grad_input, + grad_grid, + grad_grad_output, + ) = grid_sampler_cuda.grid_sampler_2d_backward_backward( + grad2_gard_input.contiguous(), + grad2_grad_grid.contiguous(), + input, + grid, + grad_output, + padding_mode_enum[ctx.padding_mode], + ctx.align_corners, + ) + return grad_input, grad_grid, grad_grad_output, None, None + + +class 
GridSampler2D(Function): + @staticmethod + def forward( + ctx, + input, + grid, + padding_mode="zeros", + align_corners=True, + ): + output = grid_sampler_cuda.grid_sampler_2d_forward( + input, + grid, + padding_mode_enum[padding_mode], + align_corners, + ) + ctx.save_for_backward(input, grid) + ctx.align_corners = align_corners + ctx.padding_mode = padding_mode + return output + + @staticmethod + def backward(ctx, grad_out): + input, grid = ctx.saved_tensors + d_input, d_grid = GridSampler2DBackward.apply( + input, + grid, + grad_out.contiguous(), + ctx.padding_mode, + ctx.align_corners, + ) + return d_input, d_grid, None, None diff --git a/libs/spa-ops/grid_sampler/src/grid_sampler.cpp b/libs/spa-ops/grid_sampler/src/grid_sampler.cpp new file mode 100644 index 0000000..9557ddd --- /dev/null +++ b/libs/spa-ops/grid_sampler/src/grid_sampler.cpp @@ -0,0 +1,133 @@ +#include +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA Tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +void launch_grid_sampler_2d_forward_kernel( + const torch::TensorBase &output, const torch::TensorBase &input, const torch::TensorBase &grid, + int64_t padding_mode, bool align_corners); + +void launch_grid_sampler_3d_forward_kernel( + const torch::TensorBase &output, const torch::TensorBase &input, const torch::TensorBase &grid, + int64_t padding_mode, bool align_corners); + +void launch_grid_sampler_2d_backward_kernel( + const torch::TensorBase &grad_input, const torch::TensorBase &grad_grid, + const torch::TensorBase &grad_output, const torch::TensorBase &input, + const torch::TensorBase &grid, int64_t padding_mode, bool align_corners); + +void launch_grid_sampler_3d_backward_kernel( + const torch::TensorBase &grad_input, const torch::TensorBase &grad_grid, + const torch::TensorBase &grad_output, const torch::TensorBase &input, + const torch::TensorBase &grid, int64_t padding_mode, bool align_corners); + +void launch_grid_sampler_2d_backward_backward_kernel( + const torch::TensorBase &grad_input, const torch::TensorBase &grad_grid, const torch::TensorBase &grad_grad_output, + const torch::TensorBase &grad2_grad_input, const torch::TensorBase &grad2_grad_grid, + const torch::TensorBase &input, const torch::TensorBase &grid, const torch::TensorBase &grad_output, + int64_t padding_mode, bool align_corners); + +void launch_grid_sampler_3d_backward_backward_kernel( + const torch::TensorBase &grad_input, const torch::TensorBase &grad_grid, const torch::TensorBase &grad_grad_output, + const torch::TensorBase &grad2_grad_input, const torch::TensorBase &grad2_grad_grid, + const torch::TensorBase &input, const torch::TensorBase &grid, const torch::TensorBase &grad_output, + int64_t padding_mode, bool align_corners); + +torch::Tensor grid_sampler_2d_forward(const torch::Tensor& input, const torch::Tensor& grid, + int64_t padding_mode, bool align_corners) { + CHECK_INPUT(input) + CHECK_INPUT(grid) + auto in_size = input.sizes(); + auto grid_size = grid.sizes(); + auto output = at::empty( + {in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options()); + launch_grid_sampler_2d_forward_kernel( + output, input, grid, padding_mode, align_corners); + return output; +} + +torch::Tensor grid_sampler_3d_forward(const torch::Tensor &input, const torch::Tensor &grid, + int64_t padding_mode, bool align_corners) { + CHECK_INPUT(input) + CHECK_INPUT(grid) + auto in_size = input.sizes(); + auto 
grid_size = grid.sizes(); + auto output = torch::empty( + {in_size[0], in_size[1], grid_size[1], grid_size[2], grid_size[3]}, input.options()); + launch_grid_sampler_3d_forward_kernel( + output, input, grid, padding_mode, align_corners); + return output; +} + +std::tuple +grid_sampler_2d_backward(const torch::Tensor &grad_output, const torch::Tensor &input, + const torch::Tensor &grid, int64_t padding_mode, bool align_corners) { + CHECK_INPUT(grad_output) + CHECK_INPUT(input) + CHECK_INPUT(grid) + auto grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto grad_grid = torch::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + launch_grid_sampler_2d_backward_kernel( + grad_input, grad_grid, grad_output, input, + grid, padding_mode, align_corners); + return std::make_tuple(grad_input, grad_grid); +} + +std::tuple +grid_sampler_3d_backward(const torch::Tensor &grad_output, const torch::Tensor &input, + const torch::Tensor &grid, int64_t padding_mode, bool align_corners) { + CHECK_INPUT(grad_output) + CHECK_INPUT(input) + CHECK_INPUT(grid) + auto grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto grad_grid = torch::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + launch_grid_sampler_3d_backward_kernel( + grad_input, grad_grid, grad_output, input, + grid, padding_mode, align_corners); + return std::make_tuple(grad_input, grad_grid); +} + +std::tuple +grid_sampler_2d_backward_backward(const torch::Tensor &grad2_grad_input, const torch::Tensor &grad2_grad_grid, + const torch::Tensor &input, const torch::Tensor &grid, const torch::Tensor &grad_output, + int64_t padding_mode, bool align_corners) { + CHECK_INPUT(grad2_grad_input) + CHECK_INPUT(grad2_grad_grid) + CHECK_INPUT(input) + CHECK_INPUT(grid) + CHECK_INPUT(grad_output) + auto grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto grad_grid = torch::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + launch_grid_sampler_2d_backward_backward_kernel(grad_input, grad_grid, grad_grad_output, grad2_grad_input, grad2_grad_grid, + input, grid, grad_output, padding_mode, align_corners); + return std::make_tuple(grad_input, grad_grid, grad_grad_output); +} + +std::tuple +grid_sampler_3d_backward_backward(const torch::Tensor &grad2_grad_input, const torch::Tensor &grad2_grad_grid, + const torch::Tensor &input, const torch::Tensor &grid, const torch::Tensor &grad_output, + int64_t padding_mode, bool align_corners) { + CHECK_INPUT(grad2_grad_input) + CHECK_INPUT(grad2_grad_grid) + CHECK_INPUT(input) + CHECK_INPUT(grid) + CHECK_INPUT(grad_output) + auto grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto grad_grid = torch::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + launch_grid_sampler_3d_backward_backward_kernel(grad_input, grad_grid, grad_grad_output, grad2_grad_input, grad2_grad_grid, + input, grid, grad_output, padding_mode, align_corners); + return std::make_tuple(grad_input, grad_grid, grad_grad_output); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("grid_sampler_3d_forward", &grid_sampler_3d_forward, "grid_sampler_3d_forward"); + m.def("grid_sampler_3d_backward", &grid_sampler_3d_backward, "grid_sampler_3d_backward"); + m.def("grid_sampler_3d_backward_backward", &grid_sampler_3d_backward_backward, "grid_sampler_3d_backward_backward"); + 
m.def("grid_sampler_2d_forward", &grid_sampler_2d_forward, "grid_sampler_2d_forward"); + m.def("grid_sampler_2d_backward", &grid_sampler_2d_backward, "grid_sampler_2d_backward"); + m.def("grid_sampler_2d_backward_backward", &grid_sampler_2d_backward_backward, "grid_sampler_2d_backward_backward"); +} \ No newline at end of file diff --git a/libs/spa-ops/grid_sampler/src/grid_sampler_cuda.cu b/libs/spa-ops/grid_sampler/src/grid_sampler_cuda.cu new file mode 100644 index 0000000..453f8e3 --- /dev/null +++ b/libs/spa-ops/grid_sampler/src/grid_sampler_cuda.cu @@ -0,0 +1,1276 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +using namespace at::cuda::detail; +using at::native::detail::GridSamplerPadding; + + +template +C10_LAUNCH_BOUNDS_1(256) +__global__ void grid_sampler_2d_kernel( + const index_t nthreads, + TensorInfo input, + TensorInfo grid, + TensorInfo output, + const GridSamplerPadding padding_mode, + bool align_corners) { + + index_t C = input.sizes[1]; + index_t inp_H = input.sizes[2]; + index_t inp_W = input.sizes[3]; + index_t out_H = grid.sizes[1]; + index_t out_W = grid.sizes[2]; + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sH = input.strides[2]; + index_t inp_sW = input.strides[3]; + index_t grid_sN = grid.strides[0]; + index_t grid_sH = grid.strides[1]; + index_t grid_sW = grid.strides[2]; + index_t grid_sCoor = grid.strides[3]; + index_t out_sN = output.strides[0]; + index_t out_sC = output.strides[1]; + index_t out_sH = output.strides[2]; + index_t out_sW = output.strides[3]; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) { + const index_t w = index % out_W; + const index_t h = (index / out_W) % out_H; + const index_t n = index / (out_H * out_W); + const index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y co-ordinates from grid + scalar_t x = grid.data[grid_offset]; + scalar_t y = grid.data[grid_offset + grid_sCoor]; + + scalar_t ix = at::native::grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners); + scalar_t iy = at::native::grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners); + + // get NE, NW, SE, SW pixel values from (x, y) + index_t ix_nw = static_cast(::floor(ix)); + index_t iy_nw = static_cast(::floor(iy)); + index_t ix_ne = ix_nw + 1; + index_t iy_ne = iy_nw; + index_t ix_sw = ix_nw; + index_t iy_sw = iy_nw + 1; + index_t ix_se = ix_nw + 1; + index_t iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + scalar_t nw = (ix_se - ix) * (iy_se - iy); + scalar_t ne = (ix - ix_sw) * (iy_sw - iy); + scalar_t sw = (ix_ne - ix) * (iy - iy_ne); + scalar_t se = (ix - ix_nw) * (iy - iy_nw); + + // calculate bilinear weighted pixel value and set output pixel + auto inp_ptr_NC = input.data + n * inp_sN; + auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW; + for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { + *out_ptr_NCHW = static_cast(0); + if (at::native::within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw; + } + if (at::native::within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne; + } + if (at::native::within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw; + } + if (at::native::within_bounds_2d(iy_se, 
ix_se, inp_H, inp_W)) { + *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se; + } + } + } +} + +template +C10_LAUNCH_BOUNDS_1(512) +__global__ void grid_sampler_3d_kernel( + const index_t nthreads, + TensorInfo input, + TensorInfo grid, + TensorInfo output, + const GridSamplerPadding padding_mode, + bool align_corners) { + + index_t C = input.sizes[1]; + index_t inp_D = input.sizes[2]; + index_t inp_H = input.sizes[3]; + index_t inp_W = input.sizes[4]; + index_t out_D = grid.sizes[1]; + index_t out_H = grid.sizes[2]; + index_t out_W = grid.sizes[3]; + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sD = input.strides[2]; + index_t inp_sH = input.strides[3]; + index_t inp_sW = input.strides[4]; + index_t grid_sN = grid.strides[0]; + index_t grid_sD = grid.strides[1]; + index_t grid_sH = grid.strides[2]; + index_t grid_sW = grid.strides[3]; + index_t grid_sCoor = grid.strides[4]; + index_t out_sN = output.strides[0]; + index_t out_sC = output.strides[1]; + index_t out_sD = output.strides[2]; + index_t out_sH = output.strides[3]; + index_t out_sW = output.strides[4]; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) { + const index_t w = index % out_W; + const index_t h = (index / out_W) % out_H; + const index_t d = (index / (out_H * out_W)) % out_D; + const index_t n = index / (out_D * out_H * out_W); + const index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y, z co-ordinates from grid + scalar_t ix = grid.data[grid_offset]; + scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor]; + + ix = at::native::grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); + iy = at::native::grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + iz = at::native::grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners); + + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + index_t ix_tnw = static_cast(::floor(ix)); + index_t iy_tnw = static_cast(::floor(iy)); + index_t iz_tnw = static_cast(::floor(iz)); + + index_t ix_tne = ix_tnw + 1; + index_t iy_tne = iy_tnw; + index_t iz_tne = iz_tnw; + + index_t ix_tsw = ix_tnw; + index_t iy_tsw = iy_tnw + 1; + index_t iz_tsw = iz_tnw; + + index_t ix_tse = ix_tnw + 1; + index_t iy_tse = iy_tnw + 1; + index_t iz_tse = iz_tnw; + + index_t ix_bnw = ix_tnw; + index_t iy_bnw = iy_tnw; + index_t iz_bnw = iz_tnw + 1; + + index_t ix_bne = ix_tnw + 1; + index_t iy_bne = iy_tnw; + index_t iz_bne = iz_tnw + 1; + + index_t ix_bsw = ix_tnw; + index_t iy_bsw = iy_tnw + 1; + index_t iz_bsw = iz_tnw + 1; + + index_t ix_bse = ix_tnw + 1; + index_t iy_bse = iy_tnw + 1; + index_t iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + auto inp_ptr_NC = input.data + n * inp_sN; + auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (index_t c = 0; c 
< C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { + // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne + // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse + // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne + // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse + *out_ptr_NCDHW = static_cast(0); + if (at::native::within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; + } + if (at::native::within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; + } + if (at::native::within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; + } + if (at::native::within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; + } + if (at::native::within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; + } + if (at::native::within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; + } + if (at::native::within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; + } + if (at::native::within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; + } + } + } +} + +// Note [Passing pointer and offset to fastAtomicAdd] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// For its internal bounds checking, fastAtomicAdd needs to know where the destination address +// lies relative to the entire tensor, so we pass the base grad_input.data and full offset information, +// including batch * channel offset (NC_offset). + +template +C10_LAUNCH_BOUNDS_1(256) +__global__ void grid_sampler_2d_backward_kernel( + const index_t nthreads, + TensorInfo grad_output, + TensorInfo input, + TensorInfo grid, + TensorInfo grad_input, // initialized to zeros (or unused if input_requires_grad is false) + TensorInfo grad_grid, // initialized to empty + const GridSamplerPadding padding_mode, + bool align_corners, + const index_t grad_input_memory_span) { + + index_t C = input.sizes[1]; + index_t inp_H = input.sizes[2]; + index_t inp_W = input.sizes[3]; + index_t out_H = grid.sizes[1]; + index_t out_W = grid.sizes[2]; + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sH = input.strides[2]; + index_t inp_sW = input.strides[3]; + index_t grid_sN = grid.strides[0]; + index_t grid_sH = grid.strides[1]; + index_t grid_sW = grid.strides[2]; + index_t grid_sCoor = grid.strides[3]; + index_t gOut_sN = grad_output.strides[0]; + index_t gOut_sC = grad_output.strides[1]; + index_t gOut_sH = grad_output.strides[2]; + index_t gOut_sW = grad_output.strides[3]; + // gInp_* (and NC_offset below) are not really needed if input_requires_grad is false. 
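+  // grad_input strides; together with NC_offset they form the flat offsets passed to safe_add_2d / fastAtomicAdd below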
+ index_t gInp_sN = grad_input.strides[0]; + index_t gInp_sC = grad_input.strides[1]; + index_t gInp_sH = grad_input.strides[2]; + index_t gInp_sW = grad_input.strides[3]; + + index_t gGrid_sW = grad_grid.strides[2]; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) { + const index_t w = index % out_W; + const index_t h = (index / out_W) % out_H; + const index_t n = index / (out_H * out_W); + const auto grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y co-ordinates from grid + scalar_t x = grid.data[grid_offset]; + scalar_t y = grid.data[grid_offset + grid_sCoor]; + + // multipliers for gradients on ix and iy + scalar_t gix_mult, giy_mult; + scalar_t ix = at::native::grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult); + scalar_t iy = at::native::grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult); + + // get NE, NW, SE, SW pixel values from (x, y) + index_t ix_nw = static_cast(::floor(ix)); + index_t iy_nw = static_cast(::floor(iy)); + index_t ix_ne = ix_nw + 1; + index_t iy_ne = iy_nw; + index_t ix_sw = ix_nw; + index_t iy_sw = iy_nw + 1; + index_t ix_se = ix_nw + 1; + index_t iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + scalar_t nw = (ix_se - ix) * (iy_se - iy); + scalar_t ne = (ix - ix_sw) * (iy_sw - iy); + scalar_t sw = (ix_ne - ix) * (iy - iy_ne); + scalar_t se = (ix - ix_nw) * (iy - iy_nw); + + scalar_t gix = static_cast(0), giy = static_cast(0); + scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW; + index_t NC_offset = n * gInp_sN; + scalar_t *inp_ptr_NC = input.data + n * inp_sN; + for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) { + scalar_t gOut = *gOut_ptr_NCHW; + + // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd]. + at::native::safe_add_2d(grad_input.data, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut, NC_offset, grad_input_memory_span); + at::native::safe_add_2d(grad_input.data, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut, NC_offset, grad_input_memory_span); + at::native::safe_add_2d(grad_input.data, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut, NC_offset, grad_input_memory_span); + at::native::safe_add_2d(grad_input.data, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut, NC_offset, grad_input_memory_span); + + // calculate grad_grid + if (at::native::within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { + scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW]; + gix -= nw_val * (iy_se - iy) * gOut; + giy -= nw_val * (ix_se - ix) * gOut; + } + if (at::native::within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { + scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW]; + gix += ne_val * (iy_sw - iy) * gOut; + giy -= ne_val * (ix - ix_sw) * gOut; + } + if (at::native::within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { + scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW]; + gix -= sw_val * (iy - iy_ne) * gOut; + giy += sw_val * (ix_ne - ix) * gOut; + } + if (at::native::within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { + scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW]; + gix += se_val * (iy - iy_nw) * gOut; + giy += se_val * (ix - ix_nw) * gOut; + } + } + + // assuming grad_grid is contiguous + // thus we can + // 1. use index with gGrid_sW to directly compute gGrid_ptr_NHW + // 2. 
directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1] + scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW; + gGrid_ptr_NHW[0] = gix_mult * gix; + gGrid_ptr_NHW[1] = giy_mult * giy; + } +} + +template +C10_LAUNCH_BOUNDS_1(256) +__global__ void grid_sampler_3d_backward_kernel( + const index_t nthreads, + TensorInfo grad_output, + TensorInfo input, + TensorInfo grid, + TensorInfo grad_input, // initialized to zeros (or unused if input_requires_grad is false) + TensorInfo grad_grid, // initialized to empty + const GridSamplerPadding padding_mode, + bool align_corners, + const index_t grad_input_memory_span) { + + index_t C = input.sizes[1]; + index_t inp_D = input.sizes[2]; + index_t inp_H = input.sizes[3]; + index_t inp_W = input.sizes[4]; + index_t out_D = grid.sizes[1]; + index_t out_H = grid.sizes[2]; + index_t out_W = grid.sizes[3]; + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sD = input.strides[2]; + index_t inp_sH = input.strides[3]; + index_t inp_sW = input.strides[4]; + index_t grid_sN = grid.strides[0]; + index_t grid_sD = grid.strides[1]; + index_t grid_sH = grid.strides[2]; + index_t grid_sW = grid.strides[3]; + index_t grid_sCoor = grid.strides[4]; + index_t gOut_sN = grad_output.strides[0]; + index_t gOut_sC = grad_output.strides[1]; + index_t gOut_sD = grad_output.strides[2]; + index_t gOut_sH = grad_output.strides[3]; + index_t gOut_sW = grad_output.strides[4]; + // gInp_* (and NC_offset below) are not really needed if input_requires_grad is false. + int64_t gInp_sN = grad_input.strides[0]; + int64_t gInp_sC = grad_input.strides[1]; + int64_t gInp_sD = grad_input.strides[2]; + int64_t gInp_sH = grad_input.strides[3]; + int64_t gInp_sW = grad_input.strides[4]; + + index_t gGrid_sW = grad_grid.strides[3]; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) { + const index_t w = index % out_W; + const index_t h = (index / out_W) % out_H; + const index_t d = (index / (out_H * out_W)) % out_D; + const index_t n = index / (out_D * out_H * out_W); + const auto grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y, z co-ordinates from grid + scalar_t ix = grid.data[grid_offset]; + scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor]; + + // multipliers for gradients on ix, iy, and iz + scalar_t gix_mult, giy_mult, giz_mult; + ix = at::native::grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); + iy = at::native::grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + iz = at::native::grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult); + + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + index_t ix_tnw = static_cast(::floor(ix)); + index_t iy_tnw = static_cast(::floor(iy)); + index_t iz_tnw = static_cast(::floor(iz)); + + index_t ix_tne = ix_tnw + 1; + index_t iy_tne = iy_tnw; + index_t iz_tne = iz_tnw; + + index_t ix_tsw = ix_tnw; + index_t iy_tsw = iy_tnw + 1; + index_t iz_tsw = iz_tnw; + + index_t ix_tse = ix_tnw + 1; + index_t iy_tse = iy_tnw + 1; + index_t iz_tse = iz_tnw; + + index_t ix_bnw = ix_tnw; + index_t iy_bnw = iy_tnw; + index_t iz_bnw = iz_tnw + 1; + + index_t ix_bne = ix_tnw + 1; + index_t iy_bne = iy_tnw; + index_t iz_bne = iz_tnw + 1; + + index_t ix_bsw = ix_tnw; + index_t iy_bsw = iy_tnw + 1; + index_t iz_bsw = 
iz_tnw + 1; + + index_t ix_bse = ix_tnw + 1; + index_t iy_bse = iy_tnw + 1; + index_t iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + scalar_t gix = static_cast(0), giy = static_cast(0), giz = static_cast(0); + scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + index_t NC_offset = n * gInp_sN; + scalar_t *inp_ptr_NC = input.data + n * inp_sN; + // calculate bilinear weighted pixel value and set output pixel + for (index_t c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, NC_offset += gInp_sC, inp_ptr_NC += inp_sC) { + scalar_t gOut = *gOut_ptr_NCDHW; + + // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd]. + at::native::safe_add_3d(grad_input.data, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut, + NC_offset, grad_input_memory_span); + // calculate grad_grid + if (at::native::within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW]; + gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut; + giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut; + giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut; + } + if (at::native::within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW]; + gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut; + giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut; + giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut; + } + if (at::native::within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW]; + gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut; + giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut; + giz -= 
tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut; + } + if (at::native::within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW]; + gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut; + giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut; + giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut; + } + if (at::native::within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW]; + gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut; + giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut; + giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut; + } + if (at::native::within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW]; + gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut; + giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut; + giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut; + } + if (at::native::within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW]; + gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut; + giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut; + giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut; + } + if (at::native::within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW]; + gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut; + giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut; + giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut; + } + } + + // assuming grad_grid is contiguous + // thus we can + // 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW + // 2. 
directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2] + scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW; + gGrid_ptr_NDHW[0] = gix_mult * gix; + gGrid_ptr_NDHW[1] = giy_mult * giy; + gGrid_ptr_NDHW[2] = giz_mult * giz; + } +} + + +template + C10_LAUNCH_BOUNDS_1(256) + __global__ void grid_sampler_2d_backward_backward_kernel( + const index_t nthreads, + TensorInfo grad2_grad_input, + TensorInfo grad2_grad_grid, + TensorInfo input, + TensorInfo grid, + TensorInfo grad_output, + TensorInfo grad_input, + TensorInfo grad_grid, + TensorInfo grad_grad_output, + const GridSamplerPadding padding_mode, + bool align_corners, + const index_t grad_input_memory_span) { + + index_t C = input.sizes[1]; + index_t inp_H = input.sizes[2]; + index_t inp_W = input.sizes[3]; + + index_t out_H = grid.sizes[1]; + index_t out_W = grid.sizes[2]; + + index_t g2inp_sN = grad2_grad_input.strides[0]; + index_t g2inp_sC = grad2_grad_input.strides[1]; + index_t g2inp_sH = grad2_grad_input.strides[2]; + index_t g2inp_sW = grad2_grad_input.strides[3]; + + index_t g2grid_sN = grad2_grad_grid.strides[0]; + index_t g2grid_sH = grad2_grad_grid.strides[1]; + index_t g2grid_sW = grad2_grad_grid.strides[2]; + index_t g2grid_sCoor = grad2_grad_grid.strides[3]; + + index_t gOut_sN = grad_output.strides[0]; + index_t gOut_sC = grad_output.strides[1]; + index_t gOut_sH = grad_output.strides[2]; + index_t gOut_sW = grad_output.strides[3]; + + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sH = input.strides[2]; + index_t inp_sW = input.strides[3]; + + index_t grid_sN = grid.strides[0]; + index_t grid_sH = grid.strides[1]; + index_t grid_sW = grid.strides[2]; + index_t grid_sCoor = grid.strides[3]; + + index_t gInp_sN = grad_input.strides[0]; + index_t gInp_sC = grad_input.strides[1]; + index_t gInp_sH = grad_input.strides[2]; + index_t gInp_sW = grad_input.strides[3]; + + index_t gGrid_sW = grad_grid.strides[2]; + + index_t ggOut_sN = grad_grad_output.strides[0]; + index_t ggOut_sC = grad_grad_output.strides[1]; + index_t ggOut_sH = grad_grad_output.strides[2]; + index_t ggOut_sW = grad_grad_output.strides[3]; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) { + const index_t w = index % out_W; + const index_t h = (index / out_W) % out_H; + const index_t n = index / (out_H * out_W); + + /* Grid related staff */ + index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y co-ordinates from grid + scalar_t ix = grid.data[grid_offset]; + scalar_t iy = grid.data[grid_offset + grid_sCoor]; + + // multipliers for gradients on ix and iy + scalar_t gix_mult, giy_mult; + ix = at::native::grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); + iy = at::native::grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + + // get NE, NW, SE, SW pixel values from (x, y) + index_t ix_nw = static_cast(::floor(ix)); + index_t iy_nw = static_cast(::floor(iy)); + + index_t ix_ne = ix_nw + 1; + index_t iy_ne = iy_nw; + + index_t ix_sw = ix_nw; + index_t iy_sw = iy_nw + 1; + + index_t ix_se = ix_nw + 1; + index_t iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + scalar_t nw = (ix_se - ix) * (iy_se - iy); + scalar_t ne = (ix - ix_sw) * (iy_sw - iy); + scalar_t sw = (ix_ne - ix) * (iy - iy_ne); + scalar_t se = (ix - ix_nw) * (iy - iy_nw); + + /* grad2_grad_input related init */ + scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN; + + /* 
grad2_grad_grid related init */ + grid_offset = n * g2grid_sN + h * g2grid_sH + w * g2grid_sW; + scalar_t dx = grad2_grad_grid.data[grid_offset]; + scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor]; + + dx = dx * gix_mult; + dy = dy * giy_mult; + + /* grad_output related init */ + scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW; + + /* input related init */ + scalar_t *inp_ptr_NC = input.data + n * inp_sN; + + /* grad_grad_output related init */ + scalar_t *ggOut_ptr_NCHW = grad_grad_output.data + n * ggOut_sN + h * ggOut_sH + w * ggOut_sW; + + /* grad_input related init */ + index_t NC_offset = n * gInp_sN; + + /* grad_grid related init */ + scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW; + scalar_t gix = static_cast(0), giy = static_cast(0); + + scalar_t nw_val = static_cast(0), ne_val = static_cast(0), sw_val = static_cast(0), se_val = static_cast(0); + scalar_t g2_nw_val = static_cast(0), g2_ne_val = static_cast(0), g2_sw_val = static_cast(0), g2_se_val = static_cast(0); + + for (index_t c = 0; c < C; ++c, g2_inp_ptr_NC += g2inp_sC, inp_ptr_NC += inp_sC, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC, ggOut_ptr_NCHW += ggOut_sC) { + if (at::native::within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { + nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW]; + g2_nw_val = g2_inp_ptr_NC[iy_nw * g2inp_sH + ix_nw * g2inp_sW]; + } + if (at::native::within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { + ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW]; + g2_ne_val = g2_inp_ptr_NC[iy_ne * g2inp_sH + ix_ne * g2inp_sW]; + } + if (at::native::within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { + sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW]; + g2_sw_val = g2_inp_ptr_NC[iy_sw * g2inp_sH + ix_sw * g2inp_sW]; + } + if (at::native::within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { + se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW]; + g2_se_val = g2_inp_ptr_NC[iy_se * g2inp_sH + ix_se * g2inp_sW]; + } + // Computing gradient wrt to grad_output = grad2_grad_input * x * y + grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val + // grad2_grad_input * x * y + *ggOut_ptr_NCHW = static_cast(0); + *ggOut_ptr_NCHW += g2_nw_val * nw + g2_ne_val * ne + g2_sw_val * sw + g2_se_val * se; + + scalar_t nw_tmp = -dx * (iy_se - iy) - dy * (ix_se - ix); + scalar_t ne_tmp = +dx * (iy_sw - iy) - dy * (ix - ix_sw); + scalar_t sw_tmp = -dx * (iy - iy_ne) + dy * (ix_ne - ix); + scalar_t se_tmp = +dx * (iy - iy_nw) + dy * (ix - ix_nw); + + // grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val + *ggOut_ptr_NCHW += nw_val * nw_tmp + ne_tmp * ne_val + sw_tmp * sw_val + se_tmp * se_val; + + // Computing gradient wrt input = grad2_grad_grid_x * grad_output * y + grad2_grad_grid_y * grad_output * x + scalar_t gOut = *gOut_ptr_NCHW; + //scalar_t val; + //val = gOut * (-dx * (iy_se - iy) - dy * (ix_se - ix)); + at::native::safe_add_2d(grad_input.data, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw_tmp * gOut, NC_offset, grad_input_memory_span); + //val = gOut * (+dx * (iy_sw - iy) - dy * (ix - ix_sw)); + at::native::safe_add_2d(grad_input.data, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne_tmp * gOut, NC_offset, grad_input_memory_span); + //val = gOut * (-dx * (iy - iy_ne) + dy * (ix_ne - ix)); + at::native::safe_add_2d(grad_input.data, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw_tmp * gOut, NC_offset, grad_input_memory_span); + //val = gOut * (+dx * (iy - iy_nw) + dy * (ix - ix_nw)); + at::native::safe_add_2d(grad_input.data, iy_se, ix_se, gInp_sH, 
gInp_sW, inp_H, inp_W, se_tmp * gOut, NC_offset, grad_input_memory_span); + + scalar_t dxy = nw_val - ne_val - sw_val + se_val; + // Computing gradient wrt grid_x = grad2_grad_input * y * gOut + grad2_grad_grid_y * val * gOut + gix += gOut * (-g2_nw_val * (iy_se - iy) + g2_ne_val * (iy_sw - iy) + -g2_sw_val * (iy - iy_ne) + g2_se_val * (iy - iy_nw)); + gix += gOut * dy * dxy; + + // Computing gradient wrt grid_y = grad2_grad_input * x * gOut + grad2_grad_grid_x * val * gOut + giy += gOut * (-g2_nw_val * (ix_se - ix) - g2_ne_val * (ix - ix_sw) + +g2_sw_val * (ix_ne - ix) + g2_se_val * (ix - ix_nw)); + giy += gOut * dx * dxy; + } + + gGrid_ptr_NHW[0] = gix * gix_mult; + gGrid_ptr_NHW[1] = giy * giy_mult; + } +} + +template + C10_LAUNCH_BOUNDS_1(256) + __global__ void grid_sampler_3d_backward_backward_kernel( + const index_t nthreads, + TensorInfo grad2_grad_input, + TensorInfo grad2_grad_grid, + TensorInfo input, + TensorInfo grid, + TensorInfo grad_output, + TensorInfo grad_input, + TensorInfo grad_grid, + TensorInfo grad_grad_output, + const GridSamplerPadding padding_mode, + bool align_corners, + const index_t grad_input_memory_span) { + + index_t C = input.sizes[1]; + index_t inp_D = input.sizes[2]; + index_t inp_H = input.sizes[3]; + index_t inp_W = input.sizes[4]; + + index_t out_D = grid.sizes[1]; + index_t out_H = grid.sizes[2]; + index_t out_W = grid.sizes[3]; + + index_t g2inp_sN = grad2_grad_input.strides[0]; + index_t g2inp_sC = grad2_grad_input.strides[1]; + index_t g2inp_sD = grad2_grad_input.strides[2]; + index_t g2inp_sH = grad2_grad_input.strides[3]; + index_t g2inp_sW = grad2_grad_input.strides[4]; + + index_t g2grid_sN = grad2_grad_grid.strides[0]; + index_t g2grid_sD = grad2_grad_grid.strides[1]; + index_t g2grid_sH = grad2_grad_grid.strides[2]; + index_t g2grid_sW = grad2_grad_grid.strides[3]; + index_t g2grid_sCoor = grad2_grad_grid.strides[4]; + + index_t gOut_sN = grad_output.strides[0]; + index_t gOut_sC = grad_output.strides[1]; + index_t gOut_sD = grad_output.strides[2]; + index_t gOut_sH = grad_output.strides[3]; + index_t gOut_sW = grad_output.strides[4]; + + index_t inp_sN = input.strides[0]; + index_t inp_sC = input.strides[1]; + index_t inp_sD = input.strides[2]; + index_t inp_sH = input.strides[3]; + index_t inp_sW = input.strides[4]; + + index_t grid_sN = grid.strides[0]; + index_t grid_sD = grid.strides[1]; + index_t grid_sH = grid.strides[2]; + index_t grid_sW = grid.strides[3]; + index_t grid_sCoor = grid.strides[4]; + + index_t gInp_sN = grad_input.strides[0]; + index_t gInp_sC = grad_input.strides[1]; + index_t gInp_sD = grad_input.strides[2]; + index_t gInp_sH = grad_input.strides[3]; + index_t gInp_sW = grad_input.strides[4]; + + index_t gGrid_sW = grad_grid.strides[3]; + + index_t ggOut_sN = grad_grad_output.strides[0]; + index_t ggOut_sC = grad_grad_output.strides[1]; + index_t ggOut_sD = grad_grad_output.strides[2]; + index_t ggOut_sH = grad_grad_output.strides[3]; + index_t ggOut_sW = grad_grad_output.strides[4]; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) { + const index_t w = index % out_W; + const index_t h = (index / out_W) % out_H; + const index_t d = (index / (out_H * out_W)) % out_D; + const index_t n = index / (out_D * out_H * out_W); + + /* Grid related staff */ + index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y co-ordinates from grid + scalar_t ix = grid.data[grid_offset]; + scalar_t iy = grid.data[grid_offset + grid_sCoor]; + scalar_t iz = grid.data[grid_offset 
+ 2 * grid_sCoor]; + + // multipliers for gradients on ix and iy + scalar_t gix_mult, giy_mult, giz_mult; + ix = at::native::grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); + iy = at::native::grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); + iz = at::native::grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult); + + // get NE, NW, SE, SW pixel values from (x, y) + index_t ix_tnw = static_cast(::floor(ix)); + index_t iy_tnw = static_cast(::floor(iy)); + index_t iz_tnw = static_cast(::floor(iz)); + + index_t ix_tne = ix_tnw + 1; + index_t iy_tne = iy_tnw; + index_t iz_tne = iz_tnw; + + index_t ix_tsw = ix_tnw; + index_t iy_tsw = iy_tnw + 1; + index_t iz_tsw = iz_tnw; + + index_t ix_tse = ix_tnw + 1; + index_t iy_tse = iy_tnw + 1; + index_t iz_tse = iz_tnw; + + index_t ix_bnw = ix_tnw; + index_t iy_bnw = iy_tnw; + index_t iz_bnw = iz_tnw + 1; + + index_t ix_bne = ix_tnw + 1; + index_t iy_bne = iy_tnw; + index_t iz_bne = iz_tnw + 1; + + index_t ix_bsw = ix_tnw; + index_t iy_bsw = iy_tnw + 1; + index_t iz_bsw = iz_tnw + 1; + + index_t ix_bse = ix_tnw + 1; + index_t iy_bse = iy_tnw + 1; + index_t iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + /* grad2_grad_input related init */ + scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN; + + /* grad2_grad_grid related init */ + grid_offset = n * g2grid_sN + d * g2grid_sD + h * g2grid_sH + w * g2grid_sW; + scalar_t dx = grad2_grad_grid.data[grid_offset]; + scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor]; + scalar_t dz = grad2_grad_grid.data[grid_offset + 2 * g2grid_sCoor]; + + dx = dx * gix_mult; + dy = dy * giy_mult; + dz = dz * giz_mult; + + /* grad_output related init */ + scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + + /* input related init */ + scalar_t *inp_ptr_NC = input.data + n * inp_sN; + + /* grad_grad_output related init */ + scalar_t *ggOut_ptr_NCDHW = grad_grad_output.data + n * ggOut_sN + d * ggOut_sD + h * ggOut_sH + w * ggOut_sW; + + /* grad_input related init */ + index_t NC_offset = n * gInp_sN; + + /* grad_grid related init */ + scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW; + scalar_t gix = static_cast(0), giy = static_cast(0), giz = static_cast(0); + + scalar_t tnw_val = static_cast(0), tne_val = static_cast(0), tsw_val = static_cast(0), tse_val = static_cast(0); + scalar_t bnw_val = static_cast(0), bne_val = static_cast(0), bsw_val = static_cast(0), bse_val = static_cast(0); + scalar_t g2_tnw_val = static_cast(0), g2_tne_val = static_cast(0), g2_tsw_val = static_cast(0), g2_tse_val = static_cast(0); + scalar_t g2_bnw_val = static_cast(0), g2_bne_val = static_cast(0), g2_bsw_val = static_cast(0), g2_bse_val = static_cast(0); + + for (index_t c = 0; c < C; ++c, g2_inp_ptr_NC += g2inp_sC, inp_ptr_NC += inp_sC, NC_offset += gInp_sC, gOut_ptr_NCDHW += gOut_sC, ggOut_ptr_NCDHW += ggOut_sC) 
{ + + if (at::native::within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { + tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW]; + g2_tnw_val = g2_inp_ptr_NC[iz_tnw * g2inp_sD + iy_tnw * g2inp_sH + ix_tnw * g2inp_sW]; + } + if (at::native::within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { + tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW]; + g2_tne_val = g2_inp_ptr_NC[iz_tne * g2inp_sD + iy_tne * g2inp_sH + ix_tne * g2inp_sW]; + } + if (at::native::within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { + tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW]; + g2_tsw_val = g2_inp_ptr_NC[iz_tsw * g2inp_sD + iy_tsw * g2inp_sH + ix_tsw * g2inp_sW]; + } + if (at::native::within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { + tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW]; + g2_tse_val = g2_inp_ptr_NC[iz_tse * g2inp_sD + iy_tse * g2inp_sH + ix_tse * g2inp_sW]; + } + if (at::native::within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { + bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW]; + g2_bnw_val = g2_inp_ptr_NC[iz_bnw * g2inp_sD + iy_bnw * g2inp_sH + ix_bnw * g2inp_sW]; + } + if (at::native::within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { + bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW]; + g2_bne_val = g2_inp_ptr_NC[iz_bne * g2inp_sD + iy_bne * g2inp_sH + ix_bne * g2inp_sW]; + } + if (at::native::within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { + bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW]; + g2_bsw_val = g2_inp_ptr_NC[iz_bsw * g2inp_sD + iy_bsw * g2inp_sH + ix_bsw * g2inp_sW]; + } + if (at::native::within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { + bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW]; + g2_bse_val = g2_inp_ptr_NC[iz_bse * g2inp_sD + iy_bse * g2inp_sH + ix_bse * g2inp_sW]; + } + + // Computing gradient wrt to grad_output = + // grad2_grad_input * x * y * z + *ggOut_ptr_NCDHW = static_cast(0); + *ggOut_ptr_NCDHW += g2_tnw_val * tnw + g2_tne_val * tne + g2_tsw_val * tsw + g2_tse_val * tse + +g2_bnw_val * bnw + g2_bne_val * bne + g2_bsw_val * bsw + g2_bse_val * bse; + + // +val * (grad2_grad_grid_x * y * z + grad2_grad_grid_y * x * z + grad2_grad_grid_z * x * y) + scalar_t tnw_tmp = (-dx * (iy_bse - iy) * (iz_bse - iz) - dy * (ix_bse - ix) * (iz_bse - iz) - dz * (ix_bse - ix) * (iy_bse - iy)); + scalar_t tne_tmp = (+dx * (iy_bsw - iy) * (iz_bsw - iz) - dy * (ix - ix_bsw) * (iz_bsw - iz) - dz * (ix - ix_bsw) * (iy_bsw - iy)); + scalar_t tsw_tmp = (-dx * (iy - iy_bne) * (iz_bne - iz) + dy * (ix_bne - ix) * (iz_bne - iz) - dz * (ix_bne - ix) * (iy - iy_bne)); + scalar_t tse_tmp = (+dx * (iy - iy_bnw) * (iz_bnw - iz) + dy * (ix - ix_bnw) * (iz_bnw - iz) - dz * (ix - ix_bnw) * (iy - iy_bnw)); + scalar_t bnw_tmp = (-dx * (iy_tse - iy) * (iz - iz_tse) - dy * (ix_tse - ix) * (iz - iz_tse) + dz * (ix_tse - ix) * (iy_tse - iy)); + scalar_t bne_tmp = (+dx * (iy_tsw - iy) * (iz - iz_tsw) - dy * (ix - ix_tsw) * (iz - iz_tsw) + dz * (ix - ix_tsw) * (iy_tsw - iy)); + scalar_t bsw_tmp = (-dx * (iy - iy_tne) * (iz - iz_tne) + dy * (ix_tne - ix) * (iz - iz_tne) + dz * (ix_tne - ix) * (iy - iy_tne)); + scalar_t bse_tmp = (+dx * (iy - iy_tnw) * (iz - iz_tnw) + dy * (ix - ix_tnw) * (iz - iz_tnw) + dz * (ix - ix_tnw) * (iy - iy_tnw)); + + *ggOut_ptr_NCDHW += tnw_val * tnw_tmp + tne_val * tne_tmp + 
tsw_val * tsw_tmp + tse_val * tse_tmp + +bnw_val * bnw_tmp + bne_val * bne_tmp + bsw_val * bsw_tmp + bse_val * bse_tmp; + + // Computing gradient wrt input = grad2_grad_grid_x * grad_output * y * z + grad2_grad_grid_y * grad_output * x * z + + // grad2_grad_grid_z * grad_output * y * z + scalar_t gOut = *gOut_ptr_NCDHW; + + at::native::safe_add_3d(grad_input.data, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw_tmp * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne_tmp * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw_tmp * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse_tmp * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw_tmp * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne_tmp * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw_tmp * gOut, + NC_offset, grad_input_memory_span); + at::native::safe_add_3d(grad_input.data, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse_tmp * gOut, + NC_offset, grad_input_memory_span); + + //Computing gradient wrt grid + scalar_t dxy = (tnw_val * (iz_bse - iz) - tne_val * (iz_bsw - iz) + -tsw_val * (iz_bne - iz) + tse_val * (iz_bnw - iz) + +bnw_val * (iz - iz_tse) - bne_val * (iz - iz_tsw) + -bsw_val * (iz - iz_tne) + bse_val * (iz - iz_tnw)); + + scalar_t dxz = (tnw_val * (iy_bse - iy) - tne_val * (iy_bsw - iy) + +tsw_val * (iy - iy_bne) - tse_val * (iy - iy_bnw) + -bnw_val * (iy_tse - iy) + bne_val * (iy_tsw - iy) + -bsw_val * (iy - iy_tne) + bse_val * (iy - iy_tnw)); + + scalar_t dyz = (tnw_val * (ix_bse - ix) + tne_val * (ix - ix_bsw) + -tsw_val * (ix_bne - ix) - tse_val * (ix - ix_bnw) + -bnw_val * (ix_tse - ix) - bne_val * (ix - ix_tsw) + +bsw_val * (ix_tne - ix) + bse_val * (ix - ix_tnw)); + + + // Computing gradient wrt grid_x = + // grad2_grad_input * z * y * gOut + gix += gOut * (-g2_tnw_val * (iy_bse - iy) * (iz_bse - iz) + g2_tne_val * (iy_bsw - iy) * (iz_bsw - iz) + -g2_tsw_val * (iy - iy_bne) * (iz_bne - iz) + g2_tse_val * (iy - iy_bnw) * (iz_bnw - iz) + -g2_bnw_val * (iy_tse - iy) * (iz - iz_tse) + g2_bne_val * (iy_tsw - iy) * (iz - iz_tsw) + -g2_bsw_val * (iy - iy_tne) * (iz - iz_tne) + g2_bse_val * (iy - iy_tnw) * (iz - iz_tnw)); + + //+ grad2_grad_grid_z * y * val * gOut + grad2_grad_grid_y * z * val * gOut + gix += gOut * (dz * dxz + dy * dxy); + + // Computing gradient wrt grid_y = + // grad2_grad_input * x * z * gOut + giy += gOut * (-g2_tnw_val * (ix_bse - ix) * (iz_bse - iz) - g2_tne_val * (ix - ix_bsw) * (iz_bsw - iz) + +g2_tsw_val * (ix_bne - ix) * (iz_bne - iz) + g2_tse_val * (ix - ix_bnw) * (iz_bnw - iz) + -g2_bnw_val * (ix_tse - ix) * (iz - iz_tse) - g2_bne_val * (ix - ix_tsw) * (iz - iz_tsw) + +g2_bsw_val * (ix_tne - ix) * (iz - iz_tne) + g2_bse_val * (ix - ix_tnw) * (iz - iz_tnw)); + //+ grad2_grad_grid_x * z * val * gOut + grad2_grad_grid_z * x * val * gOut + giy += gOut * (dx * dxy + dz * dyz); + 
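+      // dxy, dxz and dyz above are the mixed second derivatives of the trilinear interpolant w.r.t. pairs of grid coordinates; the grid_z gradient below combines them with the first-order grad2_grad_input terms, mirroring grid_x and grid_y.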
+ // Computing gradient wrt grid_z = + // grad2_grad_input * x * y * gOut + giz += gOut * (-g2_tnw_val * (ix_bse - ix) * (iy_bse - iy) - g2_tne_val * (ix - ix_bsw) * (iy_bsw - iy) + -g2_tsw_val * (ix_bne - ix) * (iy - iy_bne) - g2_tse_val * (ix - ix_bnw) * (iy - iy_bnw) + +g2_bnw_val * (ix_tse - ix) * (iy_tse - iy) + g2_bne_val * (ix - ix_tsw) * (iy_tsw - iy) + +g2_bsw_val * (ix_tne - ix) * (iy - iy_tne) + g2_bse_val * (ix - ix_tnw) * (iy - iy_tnw)); + //+ grad2_grad_grid_x * y * val * gOut + grad2_grad_grid_y * x * val * gOut + giz += gOut * (dx * dxz + dy * dyz); + } + + gGrid_ptr_NDHW[0] = gix * gix_mult; + gGrid_ptr_NDHW[1] = giy * giy_mult; + gGrid_ptr_NDHW[2] = giz * giz_mult; + } +} + +void launch_grid_sampler_2d_forward_kernel( + const at::TensorBase &output, const at::TensorBase &input, const at::TensorBase &grid, + int64_t padding_mode, bool align_corners) { + auto N = input.size(0); + auto H = grid.size(1); + auto W = grid.size(2); + int64_t count = N * H * W; + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_cuda", [&] { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(output)) { + grid_sampler_2d_kernel + <<>>( + static_cast(count), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(output), + static_cast(padding_mode), + align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + grid_sampler_2d_kernel + <<>>( + count, + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(output), + static_cast(padding_mode), + align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } +} + +void launch_grid_sampler_3d_forward_kernel( + const at::TensorBase &output, const at::TensorBase &input, const at::TensorBase &grid, + int64_t padding_mode, bool align_corners) { + auto N = input.size(0); + auto D = grid.size(1); + auto H = grid.size(2); + auto W = grid.size(3); + int64_t count = N * D * H * W; + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_cuda", [&] { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(output)) { + grid_sampler_3d_kernel + <<>>( + static_cast(count), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(output), + static_cast(padding_mode), + align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + grid_sampler_3d_kernel + <<>>( + count, + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(output), + static_cast(padding_mode), + align_corners); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } +} + +void launch_grid_sampler_2d_backward_kernel( + const at::TensorBase &grad_input, const at::TensorBase &grad_grid, + const at::TensorBase &grad_output, const at::TensorBase &input, + const at::TensorBase &grid, int64_t padding_mode, bool align_corners) { + auto N = input.size(0); + auto H = grid.size(1); + auto W = grid.size(2); + int64_t count = N * H * W; + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(grad_output)) { + grid_sampler_2d_backward_kernel + <<>>( + static_cast(count), + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + static_cast(padding_mode), + align_corners, + static_cast(grad_input.numel())); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + grid_sampler_2d_backward_kernel + <<>>( + count, + 
getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + static_cast(padding_mode), + align_corners, + grad_input.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } +} + +void launch_grid_sampler_3d_backward_kernel( + const at::TensorBase &grad_input, const at::TensorBase &grad_grid, + const at::TensorBase &grad_output, const at::TensorBase &input, + const at::TensorBase &grid, int64_t padding_mode, bool align_corners) { + auto N = input.size(0); + auto D = grid.size(1); + auto H = grid.size(2); + auto W = grid.size(3); + int64_t count = N * D * H * W; + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_backward_cuda", [&] { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(grad_output)) { + grid_sampler_3d_backward_kernel + <<>>( + static_cast(count), + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + static_cast(padding_mode), + align_corners, + static_cast(grad_input.numel())); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + grid_sampler_3d_backward_kernel + <<>>( + count, + getTensorInfo(grad_output), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + static_cast(padding_mode), + align_corners, + grad_input.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } +} + +void launch_grid_sampler_2d_backward_backward_kernel( + const at::TensorBase &grad_input, const at::TensorBase &grad_grid, const at::TensorBase &grad_grad_output, + const at::TensorBase &grad2_grad_input, const at::TensorBase &grad2_grad_grid, + const at::TensorBase &input, const at::TensorBase &grid, const at::TensorBase &grad_output, + int64_t padding_mode, bool align_corners) { + + auto N = input.size(0); + auto H = grid.size(1); + auto W = grid.size(2); + int64_t count = N * H * W; + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_backward_backward_cuda", [&] { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(grad_output)) { + grid_sampler_2d_backward_backward_kernel + <<>>( + static_cast(count), + getTensorInfo(grad2_grad_input), + getTensorInfo(grad2_grad_grid), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_output), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + getTensorInfo(grad_grad_output), + static_cast(padding_mode), + align_corners, + static_cast(grad_input.numel())); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + grid_sampler_2d_backward_backward_kernel + <<>>( + count, + getTensorInfo(grad2_grad_input), + getTensorInfo(grad2_grad_grid), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_output), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + getTensorInfo(grad_grad_output), + static_cast(padding_mode), + align_corners, + grad_input.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } +} + +void launch_grid_sampler_3d_backward_backward_kernel( + const at::TensorBase &grad_input, const at::TensorBase &grad_grid, const at::TensorBase &grad_grad_output, + const at::TensorBase &grad2_grad_input, const at::TensorBase &grad2_grad_grid, + const at::TensorBase &input, const at::TensorBase &grid, const at::TensorBase &grad_output, + int64_t padding_mode, bool align_corners) { + + auto N = input.size(0); + auto D = grid.size(1); + auto H = grid.size(2); + auto W = 
grid.size(3); + int64_t count = N * D * H * W; + if (count > 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_backward_backward_cuda", [&] { + if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && + canUse32BitIndexMath(grad_output)) { + grid_sampler_3d_backward_backward_kernel + <<>>( + static_cast(count), + getTensorInfo(grad2_grad_input), + getTensorInfo(grad2_grad_grid), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_output), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + getTensorInfo(grad_grad_output), + static_cast(padding_mode), + align_corners, + static_cast(grad_input.numel())); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + grid_sampler_3d_backward_backward_kernel + <<>>( + count, + getTensorInfo(grad2_grad_input), + getTensorInfo(grad2_grad_grid), + getTensorInfo(input), + getTensorInfo(grid), + getTensorInfo(grad_output), + getTensorInfo(grad_input), + getTensorInfo(grad_grid), + getTensorInfo(grad_grad_output), + static_cast(padding_mode), + align_corners, + grad_input.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + } +} \ No newline at end of file diff --git a/libs/spa-ops/setup.py b/libs/spa-ops/setup.py new file mode 100644 index 0000000..e46c28a --- /dev/null +++ b/libs/spa-ops/setup.py @@ -0,0 +1,91 @@ +import os +import subprocess + +from setuptools import find_packages, setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + +def get_git_commit_number(): + if not os.path.exists(".git"): + return "0000000" + + cmd_out = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE) + git_commit_number = cmd_out.stdout.decode("utf-8")[:7] + return git_commit_number + + +def make_cuda_ext(name, module, sources): + cuda_ext = CUDAExtension( + name="%s.%s" % (module, name), + sources=[os.path.join(*module.split("."), src) for src in sources], + define_macros=[("WITH_CUDA", None)], + extra_compile_args={ + "cxx": [], + "nvcc": [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + "-O2", + "-std=c++17", + ], + }, + ) + return cuda_ext + + +def write_version_to_file(version, target_file): + with open(target_file, "w") as f: + print('__version__ = "%s"' % version, file=f) + + +if __name__ == "__main__": + version = "0.6.0+%s" % get_git_commit_number() + setup( + name="spa_ops", + version=version, + description="SPA ops", + install_requires=[ + "numpy", + "llvmlite", + "numba", + "tensorboardX", + "easydict", + "pyyaml", + "scikit-image", + "tqdm", + ], + author="Haoyi Zhu", + author_email="hyizhu1108@gmail.com", + license="Apache License 2.0", + packages=find_packages(), + cmdclass={ + "build_ext": BuildExtension, + }, + ext_modules=[ + make_cuda_ext( + name="voxel_pool_ext", + module="voxel_pool", + sources=[ + "src/voxel_pool.cpp", + "src/voxel_pool_cuda.cu", + ], + ), + make_cuda_ext( + name="grid_sampler_cuda", + module="grid_sampler", + sources=[ + "src/grid_sampler.cpp", + "src/grid_sampler_cuda.cu", + ], + ), + make_cuda_ext( + name="ms_deform_attn_cuda", + module="deform_attn", + sources=[ + "src/ms_deform_attn.cpp", + "src/ms_deform_attn_cuda.cu", + ], + ), + ], + ) diff --git a/libs/spa-ops/voxel_pool/__init__.py b/libs/spa-ops/voxel_pool/__init__.py new file mode 100644 index 0000000..d4e0077 --- /dev/null +++ b/libs/spa-ops/voxel_pool/__init__.py @@ -0,0 +1 @@ +from .voxel_pool import voxel_pool diff --git a/libs/spa-ops/voxel_pool/src/voxel_pool.cpp 
b/libs/spa-ops/voxel_pool/src/voxel_pool.cpp new file mode 100644 index 0000000..ce3b396 --- /dev/null +++ b/libs/spa-ops/voxel_pool/src/voxel_pool.cpp @@ -0,0 +1,70 @@ +#include +#include + +// CUDA function declarations +void voxel_pool(int b, int z, int y, int x, int n, int c, int n_intervals, const float* feats, + const int* coords, const int* interval_starts, const int* interval_lengths, float* out); + +void voxel_pool_grad(int b, int z, int y, int x, int n, int c, int n_intervals, const float* out_grad, + const int* coords, const int* interval_starts, const int* interval_lengths, float* feats_grad); + + +at::Tensor voxel_pool_forward( + const at::Tensor _feats, + const at::Tensor _coords, + const at::Tensor _interval_lengths, + const at::Tensor _interval_starts, + int b, int z, int y, int x +) { + int n = _feats.size(0); + int c = _feats.size(1); + int n_intervals = _interval_lengths.size(0); + + const float* feats = _feats.data_ptr(); + const int* coords = _coords.data_ptr(); + const int* interval_lengths = _interval_lengths.data_ptr(); + const int* interval_starts = _interval_starts.data_ptr(); + + auto options = torch::TensorOptions().dtype(_feats.dtype()).device(_feats.device()); + at::Tensor _out = torch::zeros({b, c, z, y, x}, options); + float* out = _out.data_ptr(); + voxel_pool( + b, z, y, x, n, c, n_intervals, feats, + coords, interval_starts, interval_lengths, out + ); + return _out; +} + + +at::Tensor voxel_pool_backward( + const at::Tensor _out_grad, + const at::Tensor _coords, + const at::Tensor _interval_lengths, + const at::Tensor _interval_starts, + int b, int z, int y, int x +) { + int n = _coords.size(0); + int c = _out_grad.size(1); + int n_intervals = _interval_lengths.size(0); + + const float* out_grad = _out_grad.data_ptr(); + const int* coords = _coords.data_ptr(); + const int* interval_lengths = _interval_lengths.data_ptr(); + const int* interval_starts = _interval_starts.data_ptr(); + + auto options = torch::TensorOptions().dtype(_out_grad.dtype()).device(_out_grad.device()); + at::Tensor _feats_grad = torch::zeros({n, c}, options); + float* feats_grad = _feats_grad.data_ptr(); + + voxel_pool_grad( + b, z, y, x, n, c, n_intervals, out_grad, + coords, interval_starts, interval_lengths, feats_grad + ); + + return _feats_grad; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("voxel_pool_forward", &voxel_pool_forward, "voxel_pool_forward"); + m.def("voxel_pool_backward", &voxel_pool_backward, "voxel_pool_backward"); +} diff --git a/libs/spa-ops/voxel_pool/src/voxel_pool_cuda.cu b/libs/spa-ops/voxel_pool/src/voxel_pool_cuda.cu new file mode 100644 index 0000000..fb4fcd0 --- /dev/null +++ b/libs/spa-ops/voxel_pool/src/voxel_pool_cuda.cu @@ -0,0 +1,67 @@ +#include +#include + +#define THREADS_PER_BLOCK 256 +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + +__global__ void voxel_pool_kernel(int b, int z, int y, int x, int n, int c, int n_intervals, + const float* feats, + const int* coords, + const int* interval_starts, + const int* interval_lengths, + float* out) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int index = idx / c; + int cur_c = idx % c; + if (index >= n_intervals) return; + int interval_start = interval_starts[index]; + int interval_length = interval_lengths[index]; + const int* cur_coords = coords + interval_start * 4; + const float* cur_feats = feats + interval_start * c + cur_c; + // (b, c, z, y, x) + float* cur_out = out + cur_coords[0] * c * z * y * x + cur_c * z * y * x + + cur_coords[1] * y * x + cur_coords[2] * x + + 
cur_coords[3]; + float psum = 0; + for(int i = 0; i < interval_length; i++){ + psum += cur_feats[i * c]; + } + *cur_out = psum; +} + + +__global__ void voxel_pool_grad_kernel(int b, int z, int y, int x, int n, int c, int n_intervals, + const float* out_grad, + const int* coords, + const int* interval_starts, + const int* interval_lengths, + float* feats_grad) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int index = idx / c; + int cur_c = idx % c; + if (index >= n_intervals) return; + int interval_start = interval_starts[index]; + int interval_length = interval_lengths[index]; + const int* cur_coords = coords + interval_start * 4; + float* cur_feats_grad = feats_grad + interval_start * c + cur_c; + const float* cur_out_grad = out_grad + cur_coords[0] * c * z * y * x + + cur_c * z * y * x + cur_coords[1] * y * x + + cur_coords[2] * x + cur_coords[3]; + for(int i = 0; i < interval_length; i++){ + cur_feats_grad[i * c] = *cur_out_grad; + } +} + +void voxel_pool(int b, int z, int y, int x, int n, int c, int n_intervals, const float* feats, + const int* coords, const int* interval_starts, const int* interval_lengths, float* out) { + voxel_pool_kernel<<>>( + b, z, y, x, n, c, n_intervals, feats, coords, interval_starts, interval_lengths, out + ); +} + +void voxel_pool_grad(int b, int z, int y, int x, int n, int c, int n_intervals, const float* out_grad, + const int* coords, const int* interval_starts, const int* interval_lengths, float* feats_grad) { + voxel_pool_grad_kernel<<>>( + b, z, y, x, n, c, n_intervals, out_grad, coords, interval_starts, interval_lengths, feats_grad + ); +} diff --git a/libs/spa-ops/voxel_pool/voxel_pool.py b/libs/spa-ops/voxel_pool/voxel_pool.py new file mode 100644 index 0000000..c399c3d --- /dev/null +++ b/libs/spa-ops/voxel_pool/voxel_pool.py @@ -0,0 +1,65 @@ +import torch + +from . 
import voxel_pool_ext + + +class VoxelPoolFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, feats, coords, ranks, B, Z, Y, X): + kept = torch.ones(ranks.shape[0], device=feats.device, dtype=torch.bool) + kept[1:] = ranks[1:] != ranks[:-1] + interval_starts = torch.nonzero(kept, as_tuple=True)[0].int() + interval_lengths = torch.zeros_like(interval_starts) + interval_lengths[:-1] = interval_starts[1:] - interval_starts[:-1] + interval_lengths[-1] = ranks.shape[0] - interval_starts[-1] + coords = coords.int() + + out = voxel_pool_ext.voxel_pool_forward( + feats, + coords, + interval_lengths, + interval_starts, + B, + Z, + Y, + X, + ) + + ctx.save_for_backward(interval_starts, interval_lengths, coords) + ctx.saved_shapes = B, Z, Y, X + return out + + @staticmethod + def backward(ctx, out_grad): + interval_starts, interval_lengths, coords = ctx.saved_tensors + B, Z, Y, X = ctx.saved_shapes + + out_grad = out_grad.contiguous() + feats_grad = voxel_pool_ext.voxel_pool_backward( + out_grad, + coords, + interval_lengths, + interval_starts, + B, + Z, + Y, + X, + ) + + return feats_grad, None, None, None, None, None, None + + +def voxel_pool(feats, coords, B, Z, Y, X): + # (bs_idx, z, y, x) + ranks = ( + coords[:, 0] * (Z * Y * X) + + coords[:, 1] * (Y * X) + + coords[:, 2] * X + + coords[:, 3] + ) + indices = ranks.argsort() + feats, coords, ranks = feats[indices], coords[indices], ranks[indices] + + x = VoxelPoolFunction.apply(feats, coords, ranks, B, Z, Y, X) + + return x diff --git a/requirements.txt b/requirements.txt old mode 100755 new mode 100644 index f985c13..073cd1b --- a/requirements.txt +++ b/requirements.txt @@ -1,25 +1,25 @@ -h5py==3.11.0 -pyyaml==6.0.1 -tensorboard==2.16.2 -tensorboardx==2.6.2.2 -yapf==0.40.2 addict==2.4.0 +black==24.4.2 einops==0.8.0 -scipy==1.13.1 -termcolor==2.4.0 -timm==0.9.10 +h5py==3.11.0 hydra-colorlog==1.2.0 hydra-core==1.3.2 hydra-optuna-sweeper==1.2.0 +imageio==2.34.1 +isort==5.13.2 lightning>=2.0.0 +matplotlib opencv-python==4.10.0.82 pre-commit==3.7.1 pytest==8.2.2 +python-dotenv==1.0.1 +pyyaml==6.0.1 rich==13.7.1 -isort==5.13.2 -black==24.4.2 -torchmetrics>=0.11.4 rootutils==1.0.7 -imageio==2.34.1 -python-dotenv==1.0.1 -matplotlib \ No newline at end of file +scipy==1.13.1 +tensorboard==2.16.2 +tensorboardx==2.6.2.2 +termcolor==2.4.0 +timm==0.9.10 +torchmetrics>=0.11.4 +yapf==0.40.2 \ No newline at end of file diff --git a/scripts/generate_scannet_metadata.py b/scripts/generate_scannet_metadata.py new file mode 100644 index 0000000..80b6f99 --- /dev/null +++ b/scripts/generate_scannet_metadata.py @@ -0,0 +1,84 @@ +import os +from collections import defaultdict +from io import StringIO +from multiprocessing import Pool + +import numpy as np +import src.utils as U +import torch + +scene_root = "data/scannet" +rgbd_root = "data/scannet/rgbd" +scan_root = "data/scannet/scans" + + +def save_meta(scene_name, split): + meta_data = defaultdict(dict) + frame_list = U.io_utils.listdir(os.path.join(rgbd_root, scene_name, "color")) + frame_list = list(frame_list) + frame_list = [frame for frame in frame_list if frame.endswith(".jpg")] + frame_list.sort(key=lambda x: int(x.split(".")[0])) + + intrinsic = U.io_utils.load_numpy_text( + os.path.join(rgbd_root, scene_name, "intrinsic", "intrinsic_depth.txt") + ) + if intrinsic is None or np.isnan(intrinsic).any() or np.isinf(intrinsic).any(): + return + meta_data = defaultdict(dict) + meta_data["scene_name"] = scene_name + meta_data["intrinsic"] = intrinsic + meta_data["frames"] = defaultdict(dict) + 
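+    # For each RGB frame, record its color/depth paths and an extrinsic matrix (the on-disk pose is inverted with np.linalg.inv before being stored); frames whose poses contain NaN/Inf are skipped.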
for frame_name in frame_list: + pose = np.loadtxt( + StringIO( + U.io_utils.client.get_text( + os.path.join( + rgbd_root, + scene_name, + "pose", + frame_name.replace(".jpg", ".txt"), + ) + ) + ) + ) + if pose is None or np.isnan(pose).any() or np.isinf(pose).any(): + continue + pose = np.linalg.inv(pose) + if pose is None or np.isnan(pose).any() or np.isinf(pose).any(): + continue + meta_data["frames"][frame_name]["color_path"] = os.path.join( + rgbd_root, scene_name, "color", frame_name + ) + meta_data["frames"][frame_name]["depth_path"] = os.path.join( + rgbd_root, scene_name, "depth", frame_name.replace(".jpg", ".png") + ) + meta_data["frames"][frame_name]["extrinsic"] = pose + + torch.save( + meta_data, + os.path.join( + scene_root, + "metadata", + split, + f"{scene_name}.pth", + ), + ) + + +def main(): + for split in ["train", "val"]: + scene_list = [ + filename.split(".")[0] + for filename in os.listdir(os.path.join(scene_root, split)) + ] + + with Pool(processes=8) as pool: + for scene_name in scene_list: + pool.apply_async(save_meta, args=(scene_name, split)) + + pool.close() + pool.join() + + +if __name__ == "__main__": + main() diff --git a/spa/__init__.py b/spa/__init__.py new file mode 100755 index 0000000..8b13789 --- /dev/null +++ b/spa/__init__.py @@ -0,0 +1 @@ + diff --git a/spa/data/__init__.py b/spa/data/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/spa/data/base_cat_datamodule.py b/spa/data/base_cat_datamodule.py new file mode 100644 index 0000000..136c457 --- /dev/null +++ b/spa/data/base_cat_datamodule.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import math +from collections.abc import Mapping, Sequence +from functools import partial +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.distributed as dist +from lightning import LightningDataModule +from torch.utils.data import BatchSampler +from torch.utils.data import ConcatDataset as _ConcatDataset +from torch.utils.data import DataLoader, Dataset, RandomSampler, default_collate + +from .combined_loader import CombinedLoader + + +class ConcatDataset(_ConcatDataset): + pass + + +class ConcatRandomSampler(RandomSampler): + def __init__( + self, data_source, replacement=False, num_samples=None, generator=None + ): + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self._generator = generator + self._samplers = [ + RandomSampler( + _data_source, + replacement=replacement, + num_samples=num_samples, + generator=generator, + ) + for _data_source in data_source.datasets + ] + self.epoch = 0 + + @property + def generator(self): + return self._generator + + @generator.setter + def generator(self, generator): + self._generator = generator + for sampler in self._samplers: + sampler.generator = generator + + def set_epoch(self, epoch): + self.epoch = epoch + + def __iter__(self): + raise NotImplementedError + + +class DistributedConcatBatchSampler(BatchSampler): + def __init__( + self, + sampler, + batch_size, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( 
+ f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]" + ) + + super().__init__(sampler, batch_size, drop_last) + self._batch_samplers = [ + BatchSampler(_sampler, batch_size, drop_last) + for _sampler in sampler._samplers + ] + self._length = [len(batch_sampler) for batch_sampler in self._batch_samplers] + self._cumulative_sizes = sampler.data_source.cumulative_sizes + + self.num_replicas = num_replicas + self.rank = rank + self.shuffle = shuffle + self.seed = seed + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.sampler.epoch) + self.sampler.generator = g + + consumed = [0] * len(self._length) + iter_num = 0 + iterators = [iter(batch_sampler) for batch_sampler in self._batch_samplers] + + initial_batch = [] + batch_to_yield = [] + while iter_num < sum(self._length): + for i, iterator in enumerate(iterators): + if consumed[i] >= self._length[i]: + continue + batch_indices = next(iterator) + cumulative_batch_indices = [ + idx + (self._cumulative_sizes[i - 1] if i > 0 else 0) + for idx in batch_indices + ] + if iter_num < self.num_replicas: + initial_batch.append(cumulative_batch_indices) + if iter_num % self.num_replicas == self.rank: + batch_to_yield = cumulative_batch_indices + if iter_num % self.num_replicas == self.num_replicas - 1: + yield batch_to_yield + batch_to_yield = [] + iter_num += 1 + consumed[i] += 1 + + if len(batch_to_yield) > 0: + yield batch_to_yield + elif iter_num < self.__len__() * self.num_replicas: + yield initial_batch[self.rank] + + def __len__(self): + return (sum(self._length) + self.num_replicas - 1) // self.num_replicas + + +class BaseCatDataModule(LightningDataModule): + """`LightningDataModule` for basic datasets. + + A `LightningDataModule` implements 7 key methods: + + ```python + def prepare_data(self): + # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP). + # Download data, pre-process, split, save to disk, etc... + + def setup(self, stage): + # Things to do on every process in DDP. + # Load data, set variables, etc... + + def train_dataloader(self): + # return train dataloader + + def val_dataloader(self): + # return validation dataloader + + def test_dataloader(self): + # return test dataloader + + def predict_dataloader(self): + # return predict dataloader + + def teardown(self, stage): + # Called on every process in DDP. + # Clean up after fit or test. + ``` + + This allows you to share a full dataset without explaining how to download, + split, transform and process the data. + + Read the docs: + https://lightning.ai/docs/pytorch/latest/data/datamodule.html + """ + + def __init__( + self, + train, + val, + test=None, + **kwargs, + ) -> None: + """Initialize a `MNISTDataModule`. + + :param data_dir: The data directory. Defaults to `"data/"`. + :param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`. + :param batch_size: The batch size. Defaults to `64`. + :param num_workers: The number of workers. Defaults to `0`. + :param pin_memory: Whether to pin memory. Defaults to `False`. 
+ """ + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False, ignore=["train", "val", "test"]) + + self.data_train: Optional[Dataset] = train + self.data_val: Optional[Dataset] = val + self.data_test: Optional[Dataset] = test + + def prepare_data(self) -> None: + """Download data if needed. Lightning ensures that `self.prepare_data()` is called only + within a single process on CPU, so you can safely add your downloading logic within. In + case of multi-node training, the execution of this hook depends upon + `self.prepare_data_per_node()`. + + Do not use it to assign state (self.x = y). + """ + pass + + def setup(self, stage: Optional[str] = None) -> None: + """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. + + This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and + `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after + `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to + `self.setup()` once the data is prepared and available for use. + + :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``. + """ + # load and split datasets only if not loaded already + if not self.data_train and not self.data_val and not self.data_test: + self.data_train = self.hparams.get("train") + self.data_val = self.hparams.get("val") + self.data_test = self.hparams.get("test") + + def train_dataloader(self) -> DataLoader[Any]: + """Create and return the train dataloader. + + :return: The train dataloader. + """ + + if hasattr(self.data_train[0], "_collate_fn"): + _collate_fn = self.data_train[0]._collate_fn + else: + _collate_fn = default_collate + + data_train = ConcatDataset(self.data_train) + sampler = ConcatRandomSampler(data_train) + batch_sampler = DistributedConcatBatchSampler( + sampler, batch_size=self.hparams.batch_size_train + ) + return DataLoader( + dataset=data_train, + batch_sampler=batch_sampler, + num_workers=self.hparams.num_workers, + persistent_workers=(True if self.hparams.num_workers > 0 else False), + pin_memory=self.hparams.pin_memory, + collate_fn=_collate_fn, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Create and return the validation dataloader. + + :return: The validation dataloader. + """ + if hasattr(self.data_val, "_collate_fn"): + _collate_fn = self.data_val._collate_fn + else: + _collate_fn = default_collate + + return DataLoader( + dataset=self.data_val, + batch_size=self.hparams.batch_size_val, + num_workers=self.hparams.num_workers, + persistent_workers=True if self.hparams.num_workers > 0 else False, + pin_memory=self.hparams.pin_memory, + shuffle=False, + collate_fn=_collate_fn, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Create and return the test dataloader. + + :return: The test dataloader. 
+ """ + if hasattr(self.data_test, "_collate_fn"): + _collate_fn = self.data_test._collate_fn + else: + _collate_fn = default_collate + + return DataLoader( + dataset=self.data_test, + batch_size=self.hparams.batch_size_test, + num_workers=self.hparams.num_workers, + persistent_workers=True if self.hparams.num_workers > 0 else False, + pin_memory=self.hparams.pin_memory, + shuffle=False, + collate_fn=_collate_fn, + ) + + def teardown(self, stage: Optional[str] = None) -> None: + """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, + `trainer.test()`, and `trainer.predict()`. + + :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. + Defaults to ``None``. + """ + pass + + def state_dict(self) -> dict[Any, Any]: + """Called when saving a checkpoint. Implement to generate and save the datamodule state. + + :return: A dictionary containing the datamodule state that you want to save. + """ + return {} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Called when loading a checkpoint. Implement to reload datamodule state given datamodule + `state_dict()`. + + :param state_dict: The datamodule state returned by `self.state_dict()`. + """ + pass diff --git a/spa/data/combined_loader.py b/spa/data/combined_loader.py new file mode 100644 index 0000000..c58b6a4 --- /dev/null +++ b/spa/data/combined_loader.py @@ -0,0 +1,83 @@ +from collections.abc import Iterable +from typing import List, Optional, Union + +from lightning.fabric.utilities.data import sized_len +from lightning.pytorch.utilities.combined_loader import ( + _ITERATOR_RETURN, + _SUPPORTED_MODES, + CombinedLoader, + _CombinationMode, + _ModeIterator, +) +from typing_extensions import Self, override + + +class _CyclicSequential(_ModeIterator): + def __init__( + self, + iterables: List[Iterable], + limits: Optional[List[Union[int, float]]] = None, + ) -> None: + super().__init__(iterables, limits) + self._consumed: List[int] = [] + self._length: List[int] = [] + self._iterator_idx = 0 + + @override + def __next__(self) -> _ITERATOR_RETURN: + n = len(self.iterators) + out = [None] * n + find = False + for i in range(n): + if self._consumed[self._iterator_idx] < self._length[self._iterator_idx]: + out[self._iterator_idx] = next(self.iterators[self._iterator_idx]) + self._consumed[self._iterator_idx] += 1 + find = True + break + self._iterator_idx = (self._iterator_idx + 1) % n + assert find, "All iterators are exhausted." 
+ + iterator_idx = self._iterator_idx + self._iterator_idx = (self._iterator_idx + 1) % n + batch_idx = self._idx + self._idx += 1 + return out, batch_idx, iterator_idx + + @override + def __iter__(self) -> Self: + super().__iter__() + self._consumed = [0] * len(self.iterables) + lengths = _get_iterables_lengths(self.iterables) + if self.limits is not None: + lengths = [ + min(length, limit) for length, limit in zip(lengths, self.limits) + ] + self._length = lengths + self._iterator_idx = 0 + return self + + @override + def __len__(self) -> int: + lengths = _get_iterables_lengths(self.iterables) + if self.limits is not None: + return sum(min(length, limit) for length, limit in zip(lengths, self.limits)) # type: ignore[misc] + return sum(lengths) # type: ignore[arg-type] + + @override + def reset(self) -> None: + super().reset() + self._consumed = [] + self._length = [] + self._iterator_idx = 0 + + +_SUPPORTED_MODES["max_size_cycle"] = _CombinationMode( + fn=sum, iterator=_CyclicSequential +) + + +def _get_iterables_lengths(iterables: List[Iterable]) -> List[Union[int, float]]: + return [ + (float("inf") if (length := sized_len(iterable)) is None else length) + for iterable in iterables + ] diff --git a/spa/data/components/__init__.py b/spa/data/components/__init__.py new file mode 100755 index 0000000..99e6668 --- /dev/null +++ b/spa/data/components/__init__.py @@ -0,0 +1,2 @@ +from .processor import DataProcessor +from .scannet import ScanNetMultiViewSPAPretrain diff --git a/spa/data/components/processor/__init__.py b/spa/data/components/processor/__init__.py new file mode 100644 index 0000000..9e44cd4 --- /dev/null +++ b/spa/data/components/processor/__init__.py @@ -0,0 +1 @@ +from .data_processor import DataProcessor diff --git a/spa/data/components/processor/augmentor_utils.py b/spa/data/components/processor/augmentor_utils.py new file mode 100644 index 0000000..0712ab4 --- /dev/null +++ b/spa/data/components/processor/augmentor_utils.py @@ -0,0 +1,513 @@ +import cv2 +import numpy as np +import torch +from PIL import Image + + +def check_numpy_to_torch(x): + if isinstance(x, np.ndarray): + return torch.from_numpy(x).float(), True + return x, False + + +def rotate_points_along_z(points, angle): + """ + Args: + points: (B, N, 3 + C) + angle: (B), angle along z-axis, angle increases x ==> y + Returns: + + """ + points, is_numpy = check_numpy_to_torch(points) + angle, _ = check_numpy_to_torch(angle) + + cosa = torch.cos(angle) + sina = torch.sin(angle) + zeros = angle.new_zeros(points.shape[0]) + ones = angle.new_ones(points.shape[0]) + rot_matrix = ( + torch.stack((cosa, sina, zeros, -sina, cosa, zeros, zeros, zeros, ones), dim=1) + .view(-1, 3, 3) + .float() + ) + points_rot = torch.matmul(points[:, :, 0:3], rot_matrix) + points_rot = torch.cat((points_rot, points[:, :, 3:]), dim=-1) + return points_rot.numpy() if is_numpy else points_rot + + +def angle2matrix(angle): + """ + Args: + angle: angle along z-axis, angle increases x ==> y + Returns: + rot_matrix: (3x3 Tensor) rotation matrix + """ + angle, is_numpy = check_numpy_to_torch(angle) + cosa = torch.cos(angle) + sina = torch.sin(angle) + rot_matrix = torch.tensor([[cosa, -sina, 0], [sina, cosa, 0], [0, 0, 1]]) + return rot_matrix.numpy() if is_numpy else rot_matrix + + +def global_drop(points, drop_ratio, prob=0.5): + enable = np.random.choice( + [False, True], + replace=False, + p=[1 - prob, prob], + ) + if enable: + choice = np.arange(0, len(points), dtype=np.int32) + choice = np.random.choice( + choice, int((1 - drop_ratio) * 
len(points)), replace=False + ) + points = points[choice] + return points + + +def random_flip_along_x(gt_boxes, points, prob=0.5): + """ + Args: + gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]] + points: (M, 3 + C) + Returns: + """ + enable = np.random.choice([False, True], replace=False, p=[1 - prob, prob]) + matrix = np.eye(4) + if enable: + if gt_boxes is not None: + gt_boxes[:, 1] = -gt_boxes[:, 1] + gt_boxes[:, 6] = -gt_boxes[:, 6] + if gt_boxes.shape[1] > 7: + gt_boxes[:, 8] = -gt_boxes[:, 8] + if points is not None: + points[:, 1] = -points[:, 1] + matrix[1, 1] = -1 + + return gt_boxes, points, matrix + + +def random_flip_along_y(gt_boxes, points, prob=0.5): + """ + Args: + gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]] + points: (M, 3 + C) + Returns: + """ + enable = np.random.choice([False, True], replace=False, p=[1 - prob, prob]) + matrix = np.eye(4) + if enable: + if gt_boxes is not None: + gt_boxes[:, 0] = -gt_boxes[:, 0] + gt_boxes[:, 6] = -(gt_boxes[:, 6] + np.pi) + if gt_boxes.shape[1] > 7: + gt_boxes[:, 7] = -gt_boxes[:, 7] + if points is not None: + points[:, 0] = -points[:, 0] + matrix[0, 0] = -1 + + return gt_boxes, points, matrix + + +def global_rotation(gt_boxes, points, rot_range, prob=0.5): + """ + Args: + gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]] + points: (M, 3 + C), + rot_range: [min, max] + Returns: + """ + enable = np.random.choice( + [False, True], + replace=False, + p=[1 - prob, prob], + ) + matrix = np.eye(4) + if enable: + noise_rotation = np.random.uniform(rot_range[0], rot_range[1]) + if points is not None: + points = rotate_points_along_z( + points[np.newaxis, :, :], np.array([noise_rotation]) + )[0] + if gt_boxes is not None: + gt_boxes[:, 0:3] = rotate_points_along_z( + gt_boxes[np.newaxis, :, 0:3], np.array([noise_rotation]) + )[0] + gt_boxes[:, 6] += noise_rotation + if gt_boxes.shape[1] > 7: + gt_boxes[:, 7:9] = rotate_points_along_z( + np.hstack((gt_boxes[:, 7:9], np.zeros((gt_boxes.shape[0], 1))))[ + np.newaxis, :, : + ], + np.array([noise_rotation]), + )[0][:, 0:2] + matrix[:3, :3] = angle2matrix(np.array(noise_rotation)) + + return gt_boxes, points, matrix + + +def global_scaling(gt_boxes, points, scale_range, prob=0.5): + """ + Args: + gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading] + points: (M, 3 + C), + scale_range: [min, max] + Returns: + """ + enable = np.random.choice( + [False, True], + replace=False, + p=[1 - prob, prob], + ) + matrix = np.eye(4) + if enable: + noise_scale = np.random.uniform(scale_range[0], scale_range[1]) + if points is not None: + points[:, :3] *= noise_scale + if gt_boxes is not None: + gt_boxes[:, :6] *= noise_scale + if gt_boxes.shape[1] > 7: + gt_boxes[:, 7:] *= noise_scale + matrix[:3, :3] *= noise_scale + + return gt_boxes, points, matrix + + +def global_translation(gt_boxes, points, translate_std, prob=0.5): + """ + Args: + gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading] + points: (M, 3 + C), + scale_range: [min, max] + Returns: + """ + enable = np.random.choice( + [False, True], + replace=False, + p=[1 - prob, prob], + ) + matrix = np.eye(4) + if enable: + noise_translate = np.array( + [ + np.random.normal(0, translate_std[0], 1), + np.random.normal(0, translate_std[1], 1), + np.random.normal(0, translate_std[2], 1), + ], + dtype=np.float32, + ).T + if points is not None: + points[:, :3] += noise_translate + if gt_boxes is not None: + gt_boxes[:, :3] += noise_translate + matrix[:3, 3] = noise_translate.squeeze() + + return gt_boxes, points, matrix 
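+
+# All of the random_flip_along_* / global_* helpers above share the same contract:
+# they take (gt_boxes, points, ...) and return (gt_boxes, points, matrix), where
+# `matrix` is the 4x4 homogeneous transform that was actually applied (identity when
+# the augmentation is skipped), so callers can chain them and accumulate the overall
+# world transform. Illustrative sketch only (argument values are arbitrary and not
+# part of the original pipeline):
+#
+#   boxes, pts, m_flip = random_flip_along_x(boxes, pts, prob=0.5)
+#   boxes, pts, m_rot = global_rotation(boxes, pts, rot_range=[-0.4, 0.4], prob=0.5)
+#   boxes, pts, m_scale = global_scaling(boxes, pts, scale_range=[0.95, 1.05], prob=0.5)
+#   world_matrix = m_scale @ m_rot @ m_flip  # later transforms compose on the left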
+ + +def create_grid_mask( + size, + mask_ratio=0.5, + probability=0.7, + rotate_angle=0, + add_noise=False, + rise_with_epoch=False, + epoch_state=None, +): + if rise_with_epoch: + assert epoch_state is not None + probability = epoch_state[0] / epoch_state[1] * probability + + ratio = 1 - mask_ratio + h, w = size[:2] + hh, ww = int(1.5 * h), int(1.5 * w) + d = np.random.randint(2, h) + l = min(max(int(d * ratio + 0.5), 1), d - 1) + st_h = np.random.randint(d) + st_w = np.random.randint(d) + + mask = np.ones((hh, ww), dtype=np.uint8) + for i in range(hh // d): + s = d * i + st_h + t = min(s + l, hh) + mask[s:t, :] = 0 + + for i in range(ww // d): + s = d * i + st_w + t = min(s + l, ww) + mask[:, s:t] = 0 + + if rotate_angle > 0: + angle = np.random.randint(rotate_angle + 1) + c_x, c_y = (ww - 1) * 0.5, (hh - 1) * 0.5 + matrix = cv2.getRotationMatrix2D((c_x, c_y), -angle, 1) + mask = cv2.warpAffine(mask, matrix, (ww, hh)) + + mask = mask[(hh - h) // 2 : (hh - h) // 2 + h, (ww - w) // 2 : (ww - w) // 2 + w] + mask = 1 - mask # 1 for visible, 0 for mask + + if add_noise: + # add noise for mask region + noise = 2 * (np.random.rand(*mask.shape) - 0.5) # [-1, 1] + mask_noise = noise * (1 - mask) + else: + mask_noise = np.zeros_like(mask) + + return mask, mask_noise + + +def crop(img, start_h, start_w, crop_h, crop_w): + img_src = np.zeros((crop_h, crop_w, *img.shape[2:]), dtype=img.dtype) + hsize, wsize = crop_h, crop_w + dh, dw, sh, sw = start_h, start_w, 0, 0 + if dh < 0: + sh = -dh + hsize += dh + dh = 0 + if dh + hsize > img.shape[0]: + hsize = img.shape[0] - dh + if dw < 0: + sw = -dw + wsize += dw + dw = 0 + if dw + wsize > img.shape[1]: + wsize = img.shape[1] - dw + img_src[sh : sh + hsize, sw : sw + wsize] = img[dh : dh + hsize, dw : dw + wsize] + return img_src + + +cv2_interp_codes = { + "nearest": cv2.INTER_NEAREST, + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4, +} + +pillow_interp_codes = { + "nearest": Image.NEAREST, + "bilinear": Image.BILINEAR, + "bicubic": Image.BICUBIC, + "box": Image.BOX, + "lanczos": Image.LANCZOS, + "hamming": Image.HAMMING, +} + + +def resize(img, size, interpolation="bilinear", backend="cv2"): + """Resize image to a given size. + + Args: + img (ndarray): The input image. + size (tuple[int]): Target size (w, h). + return_scale (bool): Whether to return `w_scale` and `h_scale`. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + out (ndarray): The output destination. + backend (str | None): The image resize backend type. Options are `cv2`, + `pillow`, `None`. If backend is None, the global imread_backend + specified by ``mmcv.use_backend()`` will be used. Default: None. + + Returns: + tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or + `resized_img`. + """ + h, w = img.shape[:2] + if h == size[1] and w == size[0]: + return img + + if backend not in ["cv2", "pillow"]: + raise ValueError( + f"backend: {backend} is not supported for resize." 
+ f"Supported backends are 'cv2', 'pillow'" + ) + + if backend == "pillow": + assert img.dtype == np.uint8, "Pillow backend only support uint8 type" + pil_image = Image.fromarray(img) + pil_image = pil_image.resize(size, pillow_interp_codes[interpolation]) + resized_img = np.array(pil_image) + else: + resized_img = cv2.resize( + img, size, interpolation=cv2_interp_codes[interpolation] + ) + + return resized_img + + +class ColorJitter(object): + def __init__( + self, + contrast=[0.5, 1.5], + saturation=[0.5, 1.5], + hue=[-0.05, 0.05], + brightness=[0.875, 1.125], + p=0.5, + ): + self.contrast = contrast + self.saturation = saturation + self.hue = hue + self.brightness = brightness + self.p = p + self.reset_params() + + def rgb_to_grayscale(self, img): + r, g, b = img.unbind(dim=-3) + # This implementation closely follows the TF one: + # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138 + l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype) + l_img = l_img.unsqueeze(dim=-3) + return l_img + + def adjust_brightness(self, img, brightness_factor): + return self._blend(img, torch.zeros_like(img), brightness_factor) + + def adjust_contrast(self, img, contrast_factor): + mean = torch.mean(self.rgb_to_grayscale(img), dim=(-3, -2, -1), keepdim=True) + return self._blend(img, mean, contrast_factor) + + def adjust_hue(self, img, hue_factor): + img = self._rgb2hsv(img) + h, s, v = img.unbind(dim=-3) + h = (h + hue_factor) % 1.0 + img = torch.stack((h, s, v), dim=-3) + img_hue_adj = self._hsv2rgb(img) + return img_hue_adj + + def adjust_saturation(self, img, saturation_factor): + return self._blend(img, self.rgb_to_grayscale(img), saturation_factor) + + def _blend(self, img1, img2, ratio): + return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, 1).to(img1.dtype) + + def _rgb2hsv(self, img): + r, g, b = img.unbind(dim=-3) + + # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/ + # src/libImaging/Convert.c#L330 + maxc = torch.max(img, dim=-3).values + minc = torch.min(img, dim=-3).values + + # The algorithm erases S and H channel where `maxc = minc`. This avoids NaN + # from happening in the results, because + # + S channel has division by `maxc`, which is zero only if `maxc = minc` + # + H channel has division by `(maxc - minc)`. + # + # Instead of overwriting NaN afterwards, we just prevent it from occurring, so + # we don't need to deal with it in case we save the NaN in a buffer in + # backprop, if it is ever supported, but it doesn't hurt to do so. + eqc = maxc == minc + + cr = maxc - minc + # Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine. + ones = torch.ones_like(maxc) + s = cr / torch.where(eqc, ones, maxc) + # Note that `eqc => maxc = minc = r = g = b`. So the following calculation + # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it + # would not matter what values `rc`, `gc`, and `bc` have here, and thus + # replacing denominator with 1 when `eqc` is fine. 
+ cr_divisor = torch.where(eqc, ones, cr) + rc = (maxc - r) / cr_divisor + gc = (maxc - g) / cr_divisor + bc = (maxc - b) / cr_divisor + + hr = (maxc == r) * (bc - gc) + hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc) + hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc) + h = hr + hg + hb + h = torch.fmod((h / 6.0 + 1.0), 1.0) + return torch.stack((h, s, maxc), dim=-3) + + def _hsv2rgb(self, img): + h, s, v = img.unbind(dim=-3) + i = torch.floor(h * 6.0) + f = (h * 6.0) - i + i = i.to(dtype=torch.int32) + + p = torch.clamp((v * (1.0 - s)), 0.0, 1.0) + q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0) + t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0) + i = i % 6 + + mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1) + + a1 = torch.stack((v, q, p, p, t, v), dim=-3) + a2 = torch.stack((t, v, v, q, p, p), dim=-3) + a3 = torch.stack((p, p, t, v, v, q), dim=-3) + a4 = torch.stack((a1, a2, a3), dim=-4) + + return torch.einsum("...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4) + + def reset_params(self): + """Get the parameters for the randomized transform to be applied on image. + + Args: + brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen + uniformly. Pass None to turn off the transformation. + contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen + uniformly. Pass None to turn off the transformation. + saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen + uniformly. Pass None to turn off the transformation. + hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. + Pass None to turn off the transformation. + + Returns: + tuple: The parameters used to apply the randomized transform + along with their random order. + """ + r = torch.rand(7) + b = ( + float(torch.empty(1).uniform_(self.brightness[0], self.brightness[1])) + if r[0] < self.p + else None + ) + c = ( + float(torch.empty(1).uniform_(self.contrast[0], self.contrast[1])) + if r[2] < self.p or r[5] < self.p + else None + ) + s = ( + float(torch.empty(1).uniform_(self.saturation[0], self.saturation[1])) + if r[3] < self.p + else None + ) + h = ( + float(torch.empty(1).uniform_(self.hue[0], self.hue[1])) + if r[4] < self.p + else None + ) + p = torch.randperm(3) if r[6] < self.p else None + self.jitter_params = (r, b, c, s, h, p) + + def __call__(self, img): + """ + Args: + img (PIL Image or Tensor): Input image. + + Returns: + PIL Image or Tensor: Color jittered image. + """ + r, b, c, s, h, p = self.jitter_params + + if r[0] < self.p: + img = self.adjust_brightness(img, b) + + contrast_before = r[1] < 0.5 + if contrast_before: + if r[2] < self.p: + img = self.adjust_contrast(img, c) + + if r[3] < self.p: + img = self.adjust_saturation(img, s) + + if r[4] < self.p: + img = self.adjust_hue(img, h) + + if not contrast_before: + if r[5] < self.p: + img = self.adjust_contrast(img, c) + + if r[6] < self.p: + img = img[..., p, :, :] + + return img diff --git a/spa/data/components/processor/data_processor.py b/spa/data/components/processor/data_processor.py new file mode 100644 index 0000000..a44e399 --- /dev/null +++ b/spa/data/components/processor/data_processor.py @@ -0,0 +1,691 @@ +import copy +from functools import partial, reduce + +import cv2 +import numpy as np +import torch + +from spa.models.components.model_utils.render_utils import scene_colliders + +from . 
import augmentor_utils + + +class DataProcessor(object): + def __init__(self, processor_cfg, mode, logger): + self.mode = mode + self.logger = logger + self.collider = None + enabled_proc_list = processor_cfg.get("enabled_proc_list", {self.mode: []}) + print(f"Init {self.mode} DataProcessor with {enabled_proc_list}") + proc_config = processor_cfg.get("proc_config", {}) + self.data_processor_queue = [] + for proc_name in enabled_proc_list[self.mode]: + assert proc_name in proc_config.keys(), f"{proc_name} not in proc_config" + cur_processor = getattr(self, proc_name)(config=proc_config[proc_name]) + self.data_processor_queue.append(cur_processor) + + def random_world_drop(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.random_world_drop, config=config) + + points = data_dict["points"] + points = getattr(augmentor_utils, "global_drop")( + points, config["drop_ratio"], config["probability"] + ) + + data_dict["points"] = points + return data_dict + + def random_world_flip(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.random_world_flip, config=config) + + gt_boxes = data_dict.get("gt_boxes", None) + points = data_dict.get("points", None) + matrix = [] + for cur_axis in config["along_axis_list"]: + assert cur_axis in ["x", "y"] + gt_boxes, points, mat = getattr( + augmentor_utils, "random_flip_along_%s" % cur_axis + )(gt_boxes, points, config["probability"]) + matrix.append(mat) + matrix = reduce(np.dot, matrix[::-1]) + + if gt_boxes is not None: + data_dict["gt_boxes"] = gt_boxes + if points is not None: + data_dict["points"] = points + data_dict["trans3d_matrix"] = matrix @ data_dict["trans3d_matrix"] + return data_dict + + def random_world_rotation(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.random_world_rotation, config=config) + + gt_boxes = data_dict.get("gt_boxes", None) + points = data_dict.get("points", None) + rot_range = config["world_rot_angle"] + gt_boxes, points, matrix = augmentor_utils.global_rotation( + gt_boxes, points, rot_range, config["probability"] + ) + + if gt_boxes is not None: + data_dict["gt_boxes"] = gt_boxes + if points is not None: + data_dict["points"] = points + data_dict["trans3d_matrix"] = matrix @ data_dict["trans3d_matrix"] + return data_dict + + def random_world_scaling(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.random_world_scaling, config=config) + + gt_boxes = data_dict.get("gt_boxes", None) + points = data_dict.get("points", None) + gt_boxes, points, matrix = augmentor_utils.global_scaling( + gt_boxes, + points, + config["world_scale_range"], + config["probability"], + ) + + if gt_boxes is not None: + data_dict["gt_boxes"] = gt_boxes + if points is not None: + data_dict["points"] = points + data_dict["trans3d_matrix"] = matrix @ data_dict["trans3d_matrix"] + return data_dict + + def random_world_translation(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.random_world_translation, config=config) + + gt_boxes = data_dict.get("gt_boxes", None) + points = data_dict.get("points", None) + noise_translate_std = config["noise_translate_std"] + assert len(noise_translate_std) == 3 + gt_boxes, points, matrix = augmentor_utils.global_translation( + gt_boxes, + points, + noise_translate_std, + config["probability"], + ) + + if gt_boxes is not None: + data_dict["gt_boxes"] = gt_boxes + if points is not None: + data_dict["points"] = points + data_dict["trans3d_matrix"] = matrix @ 
data_dict["trans3d_matrix"] + return data_dict + + def filter_depth_outlier(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.filter_depth_outlier, config=config) + + new_depth = [] + for i, _depth in enumerate(data_dict["depth"]): + mask = _depth > 1e-3 + valid_depth = _depth[mask] + k = int(len(valid_depth) * config["percentile"]) + dmax = np.sort(np.partition(valid_depth, -k)[-k:])[0] + dmin = np.sort(np.partition(valid_depth, k)[:k])[-1] + mask &= (_depth > dmin) & (_depth < dmax) + new_depth.append(_depth * mask) + + data_dict["depth"] = new_depth + return data_dict + + def imresize(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.imresize, config=config) + + extra_keys = config.get("extra_keys", []) + extra_imgs = {key: [] for key in extra_keys} + new_img = [] + new_trans2d_matrix = [] + resize = None + for i, _img in enumerate(data_dict["img"]): + if not config["mv_consistency"] or resize is None: + resize = np.random.uniform(*config["resize_scale"][self.mode]) + h, w = _img.shape[:2] + new_size = (int(w * resize), int(h * resize)) + new_img.append(augmentor_utils.resize(_img, new_size, "lanczos", "pillow")) + matrix = np.eye(4) + matrix[0, 0] = new_size[0] / w + matrix[1, 1] = new_size[1] / h + new_trans2d_matrix.append(matrix @ data_dict["trans2d_matrix"][i]) + + for extra_key in extra_keys: + if extra_key not in data_dict.keys(): + continue + extra_img = data_dict[extra_key][i] + assert extra_img.shape[:2] == _img.shape[:2] + extra_imgs[extra_key].append( + augmentor_utils.resize(extra_img, new_size, "nearest", "cv2") + ) + + data_dict["img"] = new_img + data_dict["trans2d_matrix"] = new_trans2d_matrix + for extra_key in extra_keys: + if extra_key not in data_dict.keys(): + continue + data_dict[extra_key] = extra_imgs[extra_key] + return data_dict + + def imcrop(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.imcrop, config=config) + + extra_keys = config.get("extra_keys", []) + extra_imgs = {key: [] for key in extra_keys} + new_img = [] + new_trans2d_matrix = [] + crop_pos_w, crop_pos_h = None, None + for i, _img in enumerate(data_dict["img"]): + crop_size = config["crop_size"] + if not config["mv_consistency"] or crop_pos_w is None: + crop_pos_h = np.random.uniform(*config["crop_pos"][self.mode][0]) + crop_pos_w = np.random.uniform(*config["crop_pos"][self.mode][1]) + + start_h = int(crop_pos_h * max(0, _img.shape[0] - crop_size[0])) + start_w = int(crop_pos_w * max(0, _img.shape[1] - crop_size[1])) + + _img_src = augmentor_utils.crop( + _img, start_h, start_w, crop_size[0], crop_size[1] + ) + new_img.append(_img_src) + matrix = np.eye(4) + matrix[0, 2] = -start_w + matrix[1, 2] = -start_h + new_trans2d_matrix.append(matrix @ data_dict["trans2d_matrix"][i]) + + for extra_key in extra_keys: + if extra_key not in data_dict.keys(): + continue + extra_img = data_dict[extra_key][i] + assert extra_img.shape[:2] == _img.shape[:2] + _extra_img_src = augmentor_utils.crop( + extra_img, start_h, start_w, crop_size[0], crop_size[1] + ) + extra_imgs[extra_key].append(_extra_img_src) + + data_dict["img"] = new_img + data_dict["trans2d_matrix"] = new_trans2d_matrix + for extra_key in extra_keys: + if extra_key not in data_dict.keys(): + continue + data_dict[extra_key] = extra_imgs[extra_key] + return data_dict + + def imflip(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.imflip, config=config) + + extra_keys = config.get("extra_keys", []) + extra_imgs = {key: 
[] for key in extra_keys} + new_img = [] + new_trans2d_matrix = [] + enable = None + for i, _img in enumerate(data_dict["img"]): + if not config["mv_consistency"] or enable is None: + flip_ratio = config["flip_ratio"] + enable = np.random.choice( + [False, True], replace=False, p=[1 - flip_ratio, flip_ratio] + ) + matrix = np.eye(4) + if enable: + _img = np.flip(_img, axis=1) + matrix[0, 0] = -1 + matrix[0, 2] = _img.shape[1] - 1 + new_img.append(_img) + new_trans2d_matrix.append(matrix @ data_dict["trans2d_matrix"][i]) + + for extra_key in extra_keys: + if extra_key not in data_dict.keys(): + continue + extra_img = data_dict[extra_key][i] + assert extra_img.shape[:2] == _img.shape[:2] + extra_imgs[extra_key].append( + np.flip(extra_img, axis=1) if enable else extra_img + ) + + data_dict["img"] = new_img + data_dict["trans2d_matrix"] = new_trans2d_matrix + for extra_key in extra_keys: + if extra_key not in data_dict.keys(): + continue + data_dict[extra_key] = extra_imgs[extra_key] + return data_dict + + def imrotate(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.imrotate, config=config) + + extra_keys = config.get("extra_keys", []) + extra_imgs = {key: [] for key in extra_keys} + new_img = [] + new_trans2d_matrix = [] + angle = None + for i, _img in enumerate(data_dict["img"]): + if not config["mv_consistency"] or angle is None: + angle = np.random.uniform(*config["rotate_angle"]) + h, w = _img.shape[:2] + c_x, c_y = (w - 1) * 0.5, (h - 1) * 0.5 + matrix = cv2.getRotationMatrix2D((c_x, c_y), -angle, 1) + new_img.append(cv2.warpAffine(_img, matrix, (w, h))) + rot_sin, rot_cos = np.sin(angle / 180 * np.pi), np.cos(angle / 180 * np.pi) + matrix = np.eye(4) + matrix[:2, :3] = np.array( + [ + [rot_cos, -rot_sin, (1 - rot_cos) * c_x + rot_sin * c_y], + [rot_sin, rot_cos, (1 - rot_cos) * c_y - rot_sin * c_x], + ] + ) + new_trans2d_matrix.append(matrix @ data_dict["trans2d_matrix"][i]) + + for extra_key in extra_keys: + if extra_key not in data_dict.keys(): + continue + extra_img = data_dict[extra_key][i] + assert extra_img.shape[:2] == _img.shape[:2] + extra_imgs[extra_key].append( + cv2.warpAffine(extra_img, matrix, (w, h), flags=cv2.INTER_NEAREST) + ) + + data_dict["img"] = new_img + data_dict["trans2d_matrix"] = new_trans2d_matrix + for extra_key in extra_keys: + if extra_key not in data_dict.keys(): + continue + data_dict[extra_key] = extra_imgs[extra_key] + return data_dict + + def imnormalize(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.imnormalize, config=config) + new_img = [] + mean = np.array(config["mean"], dtype=np.float32) + std = np.array(config["std"], dtype=np.float32) + to_rgb = config.get("to_rgb", False) + for i, _img in enumerate(data_dict["img"]): + _img = _img.astype(np.float32) + if to_rgb: + _img = cv2.cvtColor(_img, cv2.COLOR_BGR2RGB) + _img = (_img - mean) / std + new_img.append(_img) + data_dict["img"] = new_img + data_dict["img_norm_cfg"] = {"to_rgb": to_rgb, "mean": mean, "std": std} + return data_dict + + def impad(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.impad, config=config) + + extra_keys = config.get("extra_keys", []) + extra_imgs = {key: [] for key in extra_keys} + new_img = [] + for i, _img in enumerate(data_dict["img"]): + if config.get("size", None) is not None: + size = config["size"] + else: + size_divisor = config["size_divisor"] + size = ( + int(np.ceil(_img.shape[0] / size_divisor)) * size_divisor, + int(np.ceil(_img.shape[1] / size_divisor)) 
* size_divisor, + ) + padding = (0, 0, size[1] - _img.shape[1], size[0] - _img.shape[0]) + _img = cv2.copyMakeBorder( + _img, + padding[1], + padding[3], + padding[0], + padding[2], + cv2.BORDER_CONSTANT, + value=0, + ) + new_img.append(_img) + + for extra_key in extra_keys: + if extra_key not in data_dict.keys(): + continue + extra_img = data_dict[extra_key][i] + assert extra_img.shape[:2] == _img.shape[:2] + extra_imgs[extra_key].append( + cv2.copyMakeBorder( + extra_img, + padding[1], + padding[3], + padding[0], + padding[2], + cv2.BORDER_CONSTANT, + value=0, + ) + ) + + data_dict["img"] = new_img + for extra_key in extra_keys: + if extra_key not in data_dict.keys(): + continue + data_dict[extra_key] = extra_imgs[extra_key] + return data_dict + + def grid_mask(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.grid_mask, config=config) + + new_img = [] + mask, mask_noise = None, None + for i, _img in enumerate(data_dict["img"]): + if not config["mv_consistency"] or mask is None: + mask, mask_noise = augmentor_utils.create_grid_mask( + _img.shape, + mask_ratio=config["mask_ratio"], + probability=config["probability"], + rotate_angle=config["rotate_angle"], + add_noise=config["add_noise"], + rise_with_epoch=config["rise_with_epoch"], + epoch_state=data_dict["epoch_state"], + ) + new_img.append(_img * mask[..., None] + mask_noise[..., None]) + data_dict["img"] = new_img + return data_dict + + def trans_to_local(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.trans_to_local, config=config) + for key, val in config["mapping_keys"].items(): + data_dict[val] = data_dict.pop(key) + if config.get("to_local_key", None) is not None: + key, inverse = config["to_local_key"] + # We assume the matrix shape is: (N, 4, 4) + inv_mat = np.linalg.inv(data_dict[key][0]) + data_dict[key] = ( + data_dict[key] @ inv_mat[None] + if inverse + else inv_mat[None] @ data_dict[key] + ) + + return data_dict + + def merge_trans_matrix(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.merge_trans_matrix, config=config) + + if "trans2d_matrix" in data_dict.keys(): + data_dict["trans2d_matrix"] = np.stack(data_dict["trans2d_matrix"], axis=0) + + for key, val in config["keys"].items(): + for data_key, inverse in val: + trans_matrix = ( + np.linalg.inv(data_dict[key]) if inverse else data_dict[key] + ) + data_dict[data_key] = ( + data_dict[data_key] @ trans_matrix + if inverse + else trans_matrix @ data_dict[data_key] + ) + + return data_dict + + def filter_depth_outlier_old(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.filter_depth_outlier_old, config=config) + + depth = np.stack(data_dict["depth"], axis=0) + mask = depth > 1e-3 + valid_depth = depth[mask] + k = int(len(valid_depth) * config["percentile"]) + dmax = np.sort(np.partition(valid_depth, -k)[-k:])[0] + dmin = np.sort(np.partition(valid_depth, k)[:k])[-1] + mask &= (depth > dmin) & (depth < dmax) + + new_depth = [] + for i in range(len(depth)): + new_depth.append(depth[i] * mask[i]) + + data_dict["depth"] = new_depth + return data_dict + + def calc_ray_from_depth(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.calc_ray_from_depth, config=config) + + # (N, 4, 4) + cam2img = data_dict["cam2img"] + world2cam = data_dict["world2cam"] + # (N, H, W) + depth = np.stack(data_dict["depth"], axis=0) + N = len(depth) + + mask = depth > 1e-3 + img2cam = np.linalg.inv(cam2img) + cam2world = 
np.linalg.inv(world2cam) + img2world = cam2world @ img2cam + + H, W = depth.shape[-2:] + pixel_y, pixel_x = np.meshgrid( + np.linspace(0.0, H - 1.0, H), + np.linspace(0.0, W - 1.0, W), + indexing="ij", + ) + ray_end = np.stack( + [ + np.broadcast_to(pixel_x, depth.shape), + np.broadcast_to(pixel_y, depth.shape), + np.where(mask, depth, np.ones_like(depth)), + np.ones_like(depth), + ], + axis=-1, + ) + ray_end[..., :2] *= ray_end[..., 2:3] + + # (N, H, W, 4, 4) @ (N, H, W, 4, 1) -> (N, H, W, 3) + ray_end = np.matmul(img2world[:, None, None, :, :], ray_end[..., None])[ + ..., :3, 0 + ] + ray_o = np.broadcast_to(cam2world[:, None, None, :3, 3], ray_end.shape) + ray_d = ray_end - ray_o + ray_d_unit = ray_d / np.linalg.norm(ray_d, axis=-1, keepdims=True) + ray_depth = np.linalg.norm(ray_d, axis=-1) + ray_depth[~mask] = 0.0 + + ray_scale = max(ray_depth[mask].sum() / max(mask.sum(), 1), 1e-6) + + ray_p = np.stack([pixel_x / W, pixel_y / H], axis=-1) + ray_p = np.broadcast_to(ray_p[None, ...], ray_d[..., :2].shape) + + data_dict["ray_depth"] = ray_depth.reshape(N, -1) + data_dict["ray_scale"] = ray_scale + data_dict["ray_o"] = ray_o.reshape(N, -1, 3) + data_dict["ray_d"] = ray_d_unit.reshape(N, -1, 3) + data_dict["ray_p"] = ray_p.reshape(N, -1, 2) + return data_dict + + def calc_ray_from_depth_v2(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.calc_ray_from_depth_v2, config=config) + + # (N, 4, 4) + cam2img = data_dict["cam2img"] + world2cam = data_dict["world2cam"] + # (N, H, W) + depth = np.stack(data_dict["depth"], axis=0) + N = len(depth) + + mask = depth > 1e-3 + img2cam = np.linalg.inv(cam2img) + cam2world = np.linalg.inv(world2cam) + img2world = cam2world @ img2cam + + H, W = depth.shape[-2:] + pixel_y, pixel_x = np.meshgrid( + np.linspace(0.0, H - 1.0, H), + np.linspace(0.0, W - 1.0, W), + indexing="ij", + ) + ray_end = np.stack( + [ + np.broadcast_to(pixel_x, depth.shape), + np.broadcast_to(pixel_y, depth.shape), + np.where(mask, depth, np.ones_like(depth)), + np.ones_like(depth), + ], + axis=-1, + ) + ray_end[..., :2] *= ray_end[..., 2:3] + + # (N, H, W, 4, 4) @ (N, H, W, 4, 1) -> (N, H, W, 3) + ray_end = np.matmul(img2world[:, None, None, :, :], ray_end[..., None])[ + ..., :3, 0 + ] + ray_o = np.broadcast_to(cam2world[:, None, None, :3, 3], ray_end.shape) + ray_d = ray_end - ray_o + ray_d_unit = ray_d / np.linalg.norm(ray_d, axis=-1, keepdims=True) + ray_depth = np.linalg.norm(ray_d, axis=-1) + ray_depth[~mask] = 0.0 + + ray_scale = max(ray_depth[mask].sum() / max(mask.sum(), 1), 1e-6) + + ray_p = np.stack([pixel_x, pixel_y], axis=-1) + ray_p = np.broadcast_to(ray_p[None, ...], ray_d[..., :2].shape) + + data_dict["ray_depth"] = ray_depth.reshape(N, -1) + data_dict["ray_scale"] = ray_scale + data_dict["ray_o"] = ray_o.reshape(N, -1, 3) + data_dict["ray_d"] = ray_d_unit.reshape(N, -1, 3) + data_dict["ray_p"] = ray_p.reshape(N, -1, 2) + + return data_dict + + def calc_scene_bbox(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.calc_scene_bbox, config=config) + + if config["type"] == "dynamic_depth": + ray_depth = data_dict["ray_depth"] + ray_d = data_dict["ray_d"] + ray_o = data_dict["ray_o"] + mask = ray_depth > 1e-3 + ray_depth = ray_depth[mask] + ray_d = ray_d[mask] + ray_o = ray_o[mask] + pc = ray_o + ray_d * ray_depth[..., None] + point_cloud_range = np.concatenate([pc.min(axis=0), pc.max(axis=0)]) + elif config["type"] == "dynamic_point": + raise NotImplementedError + elif config["type"] == "static": + 
point_cloud_range = np.array(config["point_cloud_range"], dtype=np.float32) + else: + raise NotImplementedError + + data_dict["point_cloud_range"] = point_cloud_range.astype(np.float32) + return data_dict + + def calc_voxel_size(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.calc_voxel_size, config=config) + + point_cloud_range = data_dict["point_cloud_range"] + grid_size = np.array(config["grid_size"], dtype=np.int64) + voxel_size = (point_cloud_range[3:] - point_cloud_range[:3]) / grid_size + data_dict["voxel_size"] = voxel_size.astype(np.float32) + data_dict["grid_size"] = grid_size + return data_dict + + def sample_ray(self, data_dict=None, config=None): + if data_dict is None: + self.collider = getattr(scene_colliders, config["collider"]["type"])( + **config["collider"] + ) + return partial(self.sample_ray, config=config) + + scene_bbox = data_dict["point_cloud_range"] + ( + iray_depth, + iray_o, + iray_d, + iray_near, + iray_far, + iray_p, + iray_idx, + ) = ([], [], [], [], [], [], []) + for cidx in range(len(data_dict["depth"])): + ray_depth = data_dict["ray_depth"][cidx] + ray_o = data_dict["ray_o"][cidx] + ray_d = data_dict["ray_d"][cidx] + ray_p = data_dict["ray_p"][cidx] + ray_near, ray_far, ray_mask = self.collider(ray_o, ray_d, scene_bbox) + + if self.mode == "train": + ray_mask = ray_mask & (ray_depth > 1e-3) + valid_idx = np.nonzero(ray_mask)[0] + assert len(valid_idx) > 0, "No ray is valid for camera %d" % cidx + sampled_idx = np.random.choice( + len(valid_idx), + config["ray_nsample"], + replace=False, + ) + sampled_idx = valid_idx[sampled_idx] + + ray_depth, ray_o, ray_d, ray_p, ray_near, ray_far = ( + ray_depth[sampled_idx], + ray_o[sampled_idx], + ray_d[sampled_idx], + ray_p[sampled_idx], + ray_near[sampled_idx], + ray_far[sampled_idx], + ) + + iray_depth.append(ray_depth) + iray_o.append(ray_o) + iray_d.append(ray_d) + iray_p.append(ray_p) + iray_near.append(ray_near) + iray_far.append(ray_far) + iray_idx.append(np.full_like(ray_depth, cidx)) + + assert len(iray_depth) > 0, "No ray is valid" + data_dict["ray_depth"] = np.concatenate(iray_depth, axis=0) + data_dict["ray_o"] = np.concatenate(iray_o, axis=0) + data_dict["ray_d"] = np.concatenate(iray_d, axis=0) + data_dict["ray_p"] = np.concatenate(iray_p, axis=0) + data_dict["ray_near"] = np.concatenate(iray_near, axis=0) + data_dict["ray_far"] = np.concatenate(iray_far, axis=0) + data_dict["ray_idx"] = np.concatenate(iray_idx, axis=0) + + return data_dict + + def collect(self, data_dict=None, config=None): + if data_dict is None: + return partial(self.collect, config=config) + + collected_data_dict = {} + for key in config["keys"][self.mode]: + # gt_boxes and gt_names are not necessary for test + if key not in data_dict.keys(): + continue + + if key in ["img", "ori_img", "semantic_img"]: + # (H, W, 3) -> (3, H, W) + data_dict[key] = [ + np.ascontiguousarray(_img.transpose(2, 0, 1).astype(np.float32)) + / 255.0 + for _img in data_dict[key] + ] + collected_data_dict[key] = data_dict[key] + return collected_data_dict + + def forward(self, data_dict): + """ + Args: + data_dict: + points: (N, 3 + C_in) + gt_boxes: optional, (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...] + gt_names: optional, (N), string + ... 
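+
+        Note:
+            The processors enabled for ``self.mode`` are applied sequentially in the
+            order given by the config; each entry of ``self.data_processor_queue`` was
+            bound to its own config via ``functools.partial`` in ``__init__``.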
+ + Returns: + """ + for cur_processor in self.data_processor_queue: + data_dict = cur_processor(data_dict=data_dict) + return data_dict diff --git a/spa/data/components/processor/data_processor_gpu.py b/spa/data/components/processor/data_processor_gpu.py new file mode 100644 index 0000000..7a4b483 --- /dev/null +++ b/spa/data/components/processor/data_processor_gpu.py @@ -0,0 +1,224 @@ +from functools import partial + +import numpy as np +import torch +import torch.nn.functional as F + +from . import augmentor_utils + + +class DataProcessorGPU(object): + def __init__(self, processor_cfg, mode, logger): + self.mode = mode + self.logger = logger + self.color_jitter = None + + enabled_proc_list = processor_cfg.get("enabled_proc_list", {self.mode: []}) + proc_config = processor_cfg.get("proc_config", {}) + self.data_processor_queue = [] + message = "gpu processor:" + for proc_name in enabled_proc_list[self.mode]: + message += f" {proc_name}" + assert proc_name in proc_config.keys(), f"{proc_name} not in proc_config" + cur_processor = getattr(self, proc_name)(config=proc_config[proc_name]) + self.data_processor_queue.append(cur_processor) + # self.logger.info(message) + + def random_photometric_distort(self, batch_dict=None, config=None): + assert self.mode == "train" + if batch_dict is None: + self.color_jitter = augmentor_utils.ColorJitter( + contrast=config["contrast"], + saturation=config["saturation"], + hue=config["hue"], + brightness=config["brightness"], + p=config["p"], + ) + return partial(self.random_photometric_distort, config=config) + + assert config["mv_consistency"] + + img = batch_dict["img"] + for bs_idx in range(len(img)): + self.color_jitter.reset_params() + for cam_idx in range(len(img[bs_idx])): + img[bs_idx][cam_idx] = self.color_jitter(img[bs_idx][cam_idx]) + + batch_dict["img"] = img + return batch_dict + + def imnormalize(self, batch_dict=None, config=None): + if batch_dict is None: + return partial(self.imnormalize, config=config) + img = batch_dict["img"] + for bs_idx in range(len(img)): + mean = img[bs_idx].new_tensor(config["mean"]) + std = img[bs_idx].new_tensor(config["std"]) + img[bs_idx] = (img[bs_idx] - mean[:, None, None]) / std[:, None, None] + batch_dict["img_norm_cfg"] = {"mean": mean, "std": std} + return batch_dict + + def filter_depth_outlier(self, batch_dict=None, config=None): + if batch_dict is None: + return partial(self.filter_depth_outlier, config=config) + + depth = [] + for bidx in range(len(batch_dict["depth"])): + i_depth = batch_dict["depth"][bidx].clone() + i_mask = i_depth > 1e-3 + valid_depth = i_depth[i_mask] + k = int(valid_depth.numel() * config["percentile"]) + rmax = torch.topk(valid_depth, k, largest=True)[0][-1] + rmin = torch.topk(valid_depth, k, largest=False)[0][-1] + i_mask &= (i_depth > rmin) & (i_depth < rmax) + i_depth[~i_mask] = 0.0 + depth.append(i_depth) + + batch_dict["depth"] = depth + return batch_dict + + def calc_ray_from_depth(self, batch_dict=None, config=None): + if batch_dict is None: + return partial(self.calc_ray_from_depth, config=config) + + cam2img = batch_dict["cam2img"] + world2cam = batch_dict["world2cam"] + depth = batch_dict["depth"] + img = batch_dict["img"] + + ( + batch_ray_depth, + batch_ray_rgb, + batch_ray_scale, + batch_ray_o, + batch_ray_d, + batch_ray_p, + ) = ([], [], [], [], [], []) + for bidx in range(len(depth)): + # (N, 3, H, W) + ray_rgb = img[bidx] + # (N, 3, H, W) -> (N, H, W, 3) + ray_rgb = ray_rgb.transpose(-3, -2).transpose(-2, -1).contiguous() + + i_depth = depth[bidx] + i_mask = 
i_depth > 1e-3 + i_img2cam = torch.linalg.inv(cam2img[bidx]) + i_cam2world = torch.linalg.inv(world2cam[bidx]) + i_img2world = i_cam2world @ i_img2cam + + assert (i_depth.shape[1] == ray_rgb.shape[1]) and ( + i_depth.shape[2] == ray_rgb.shape[2] + ) + H, W = i_depth.shape[-2:] + pixel_y, pixel_x = torch.meshgrid( + torch.linspace(0.5, H - 0.5, H, device=i_depth.device), + torch.linspace(0.5, W - 0.5, W, device=i_depth.device), + indexing="ij", + ) + # (N, H, W, 4) + ray_end = torch.stack( + [ + pixel_x[None].expand_as(i_depth), + pixel_y[None].expand_as(i_depth), + torch.where(i_mask, i_depth, torch.ones_like(i_depth)), + torch.ones_like(i_depth), + ], + dim=-1, + ) + ray_end[..., :2] *= ray_end[..., 2:3] + + # (N, H, W, 4, 4) @ (N, H, W, 4, 1) -> (N, H, W, 3) + ray_end = torch.matmul( + i_img2world[:, None, None, :, :], ray_end[..., None] + )[..., :3, 0] + ray_o = i_cam2world[:, None, None, :3, 3].expand_as(ray_end) + ray_d = ray_end - ray_o + ray_depth = torch.linalg.norm(ray_d, dim=-1, keepdim=False) + ray_depth[~i_mask] = 0.0 + + ray_scale = torch.clamp( + ray_depth[i_mask].sum() / torch.clamp(i_mask.sum(), min=1), min=1e-6 + ).item() + + # (H, W, 2) -> (N, H, W, 2) + ray_p = torch.stack([pixel_x / W, pixel_y / H], dim=-1) + ray_p = ray_p[None, ...].expand(ray_d[..., :2].shape) + + batch_ray_depth.append(ray_depth.flatten(1, 2).contiguous()) + batch_ray_rgb.append(ray_rgb.flatten(1, 2).contiguous()) + batch_ray_scale.append(ray_scale) + batch_ray_o.append(ray_o.flatten(1, 2).contiguous()) + batch_ray_d.append(F.normalize(ray_d, dim=-1).flatten(1, 2).contiguous()) + batch_ray_p.append(ray_p.flatten(1, 2).contiguous()) + + batch_dict["ray_depth"] = batch_ray_depth + batch_dict["ray_rgb"] = batch_ray_rgb + batch_dict["ray_scale"] = batch_ray_scale + batch_dict["ray_o"] = batch_ray_o + batch_dict["ray_d"] = batch_ray_d + batch_dict["ray_p"] = batch_ray_p + return batch_dict + + def calc_scene_bbox(self, batch_dict=None, config=None): + if batch_dict is None: + return partial(self.calc_scene_bbox, config=config) + + point_cloud_range = [] + if config["type"] == "dynamic_depth": + ray_depth = batch_dict["ray_depth"] + ray_d = batch_dict["ray_d"] + ray_o = batch_dict["ray_o"] + for bidx in range(len(ray_o)): + i_mask = ray_depth[bidx] > 1e-3 + i_ray_depth = ray_depth[bidx][i_mask] + i_ray_d = ray_d[bidx][i_mask] + i_ray_o = ray_o[bidx][i_mask] + pc = i_ray_o + i_ray_d * i_ray_depth[..., None] + point_cloud_range.append( + torch.cat([pc.min(0).values, pc.max(0).values]) + .cpu() + .numpy() + .astype(np.float32) + ) + elif config["type"] == "dynamic_point": + raise NotImplementedError + elif config["type"] == "static": + point_cloud_range.append( + np.array(config["point_cloud_range"], dtype=np.float32) + ) + else: + raise NotImplementedError + + batch_dict["point_cloud_range"] = point_cloud_range + return batch_dict + + def calc_voxel_size(self, batch_dict=None, config=None): + if batch_dict is None: + return partial(self.calc_voxel_size, config=config) + + point_cloud_range = batch_dict["point_cloud_range"] + grid_size = [] + voxel_size = [] + for bidx in range(len(point_cloud_range)): + pcr = point_cloud_range[bidx] + gs = np.array(config["grid_size"], dtype=np.int64) + grid_size.append(gs) + voxel_size.append((pcr[3:] - pcr[:3]) / gs) + batch_dict["voxel_size"] = voxel_size + batch_dict["grid_size"] = grid_size + return batch_dict + + def forward(self, batch_dict): + """ + Args: + batch_dict: + points: (N, 3 + C_in) + gt_boxes: optional, (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...] 
+ gt_names: optional, (N), string + ... + + Returns: + """ + for cur_processor in self.data_processor_queue: + batch_dict = cur_processor(batch_dict=batch_dict) + return batch_dict diff --git a/spa/data/components/scannet/__init__.py b/spa/data/components/scannet/__init__.py new file mode 100644 index 0000000..b923f9b --- /dev/null +++ b/spa/data/components/scannet/__init__.py @@ -0,0 +1 @@ +from .scannet_multiview_spa_pretrain import ScanNetMultiViewSPAPretrain diff --git a/spa/data/components/scannet/scannet_multiview_spa_pretrain.py b/spa/data/components/scannet/scannet_multiview_spa_pretrain.py new file mode 100644 index 0000000..32c5a23 --- /dev/null +++ b/spa/data/components/scannet/scannet_multiview_spa_pretrain.py @@ -0,0 +1,240 @@ +import glob +import json +import os +from collections import defaultdict +from collections.abc import Sequence +from copy import deepcopy +from io import StringIO + +import cv2 +import numpy as np +import torch +from torch.utils.data import Dataset, default_collate + +import spa.utils as U +from spa.data.components.processor import DataProcessor, augmentor_utils + + +class ScanNetMultiViewSPAPretrain(Dataset): + def __init__( + self, + split="train", + scene_root="data/scannet", + frame_interval=10, + num_cameras=5, + loop=1, + downsample_ratio=1, + data_processor_cfg=None, + batch_max_num_img=16, + max_refetch=10, + semantic_size=None, + mode="train", + scene_box_threshold=0.2, + depth_area_threshold=0.1, + **kwargs, + ): + super().__init__() + self.scene_root = scene_root + self.split = split + self.loop = loop + + self.frame_interval = frame_interval + self.num_cameras = num_cameras + self.scene_interval = max(1, round(1 / downsample_ratio + 1e-6)) + self.semantic_size = semantic_size + + self.logger = U.RankedLogger(__name__, rank_zero_only=True) + self.data_list = self.get_data_list() + self.logger.info( + "Totally {} x {} x {} samples in {} set.".format( + len(self.data_list), self.loop, downsample_ratio, split + ) + ) + + self.data_processor = DataProcessor( + data_processor_cfg, + mode=mode, + logger=self.logger, + ) + self.batch_max_num_img = batch_max_num_img + self.max_refetch = max_refetch + self.scene_box_threshold = scene_box_threshold + self.depth_area_threshold = depth_area_threshold + + def get_data_list(self): + if isinstance(self.split, str): + data_list = glob.glob( + os.path.join(self.scene_root, "metadata", self.split, "*.pth") + ) + elif isinstance(self.split, Sequence): + data_list = [] + for split in self.split: + data_list += glob.glob( + os.path.join(self.scene_root, "metadata", split, "*.pth") + ) + else: + raise NotImplementedError + return data_list + + def get_data(self, idx): + metadata = torch.load(self.data_list[idx % len(self.data_list)]) + scene_name = metadata["scene_name"] + intrinsic = metadata["intrinsic"] + frames = metadata["frames"] + + num_cameras = np.random.randint(1, self.num_cameras + 1) + + frame_idx_start = np.random.randint( + 0, max(len(frames) - num_cameras * self.frame_interval + 1, 1) + ) + frame_keys = list(frames.keys())[ + frame_idx_start : frame_idx_start + + num_cameras * self.frame_interval : self.frame_interval + ] + frames = [frames[frame_key] for frame_key in frame_keys] + + intrinsics = np.stack([intrinsic for _ in range(len(frames))], axis=0) + extrinsics = np.stack([frame["extrinsic"] for frame in frames], axis=0) + + assert (not np.isnan(extrinsics).any()) and ( + not np.isinf(extrinsics).any() + ), "invalid extrinsics" + + data_dict = dict() + + depth = [ + 
U.io_utils.load_image(frame["depth_path"]).astype(np.float32) / 1000.0 + for frame in frames + ] + ori_rgb = [U.io_utils.load_image(frame["color_path"]) for frame in frames] + h, w = depth[0].shape[-2:] + rgb = [ + augmentor_utils.resize( + _rgb, + (w, h), + "lanczos", + "pillow", + ) + for _rgb in ori_rgb + ] + if self.semantic_size is not None: + semantic_rgb = [ + augmentor_utils.resize( + _rgb, + self.semantic_size, + "lanczos", + "pillow", + ) + for _rgb in ori_rgb + ] + data_dict["semantic_img"] = semantic_rgb + + data_dict.update( + dict( + img=rgb, # n, h, w, c + ori_shape=np.stack([x.shape[:2] for x in rgb], axis=0), + world2cam=extrinsics, + cam2img=intrinsics, # n, 4, 4 + depth=depth, # n, h, w + ) + ) + data_dict["trans3d_matrix"] = np.eye(4) + data_dict["trans2d_matrix"] = [np.eye(4) for _img in data_dict["img"]] + data_dict = self.data_processor.forward(data_dict=data_dict) + for d in data_dict["depth"]: + assert (d > 1e-3).astype( + float + ).mean() > self.depth_area_threshold, ( + f"valid depth area is small: {(d > 1e-3).astype(float).mean()}" + ) + + data_dict["scene_name"] = scene_name + data_dict["frame_list"] = frame_keys + data_dict["dataset_name"] = "scannet" + + if "point_cloud_range" in data_dict.keys(): + scene_box = data_dict["point_cloud_range"] + assert ( + scene_box[3:] - scene_box[:3] > self.scene_box_threshold + ).all(), f"too small scene box: {scene_box[3:] - scene_box[:3]}, scene: {scene_name}, frame: {frame_keys}" + + for key in data_dict: + if isinstance(data_dict[key], list): + data_dict[key] = np.stack(data_dict[key]) + if key in [ + "scene_name", + "dataset_name", + "frame_list", + "point_cloud_range", + "voxel_size", + "grid_size", + "ray_scale", + ]: + continue + data_dict[key] = torch.from_numpy(data_dict[key]).float() + + return data_dict + + def _collate_fn(self, batch, trunc_batch=True): + if not isinstance(batch, Sequence): + raise TypeError(f"{batch.dtype} is not supported.") + + if trunc_batch and self.batch_max_num_img > 0: + accum_num_imgs = 0 + ret_batches = [] + for batch_id, data in enumerate(batch): + num_imgs = len(data["img"]) + if accum_num_imgs + num_imgs > self.batch_max_num_img: + # log.info( + # f"Truncating batch {batch_id} since accum_num_imgs {accum_num_imgs} + num_imgs {num_imgs} > batch_max_num_img {self.batch_max_num_img}." + # ) + continue + accum_num_imgs += num_imgs + ret_batches.append(data) + return self._collate_fn(ret_batches, trunc_batch=False) + + return_dict = dict() + return_dict["batch_size"] = len(batch) + + for k in batch[0]: + return_dict[k] = [d[k] for d in batch] + + return return_dict + + def get_data_name(self, scene_path): + return os.path.basename(scene_path).split(".")[0] + + def prepare_train_data(self, idx): + # load data + try: + data_dict = self.get_data(idx) + except Exception as e: + return None, e + return data_dict, None + + def __getitem__(self, idx): + for _ in range(self.max_refetch): + new_idx = idx * self.scene_interval + np.random.randint( + 0, self.scene_interval + ) + data, e = self.prepare_train_data(new_idx) + if data is None: + self.logger.warning( + f"Failed to load data from {self.data_list[new_idx % len(self.data_list)]} for error {e}." + ) + idx = self._rand_another() + continue + else: + return data + raise e + + def _rand_another(self) -> int: + """Get random index. 
+ + Returns: + int: Random index from 0 to ``len(self)-1`` + """ + return np.random.randint(0, len(self)) + + def __len__(self): + return int(len(self.data_list) * self.loop) // self.scene_interval diff --git a/spa/data/components/transform.py b/spa/data/components/transform.py new file mode 100755 index 0000000..92fc1fa --- /dev/null +++ b/spa/data/components/transform.py @@ -0,0 +1,218 @@ +import copy +import os +from collections.abc import Mapping, Sequence +from typing import Union + +import numpy as np +import torch + + +def to_tensor( + data: Union[torch.Tensor, np.ndarray, Sequence, int, float] +) -> torch.Tensor: + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + + Returns: + torch.Tensor: the converted data. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, str): + # note that str is also a kind of sequence, judgement should before sequence + return data + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, bool): + return torch.from_numpy(data) + elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.integer): + return torch.from_numpy(data).long() + elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.floating): + return torch.from_numpy(data).float() + elif isinstance(data, Mapping): + result = {sub_key: to_tensor(item) for sub_key, item in data.items()} + return result + elif isinstance(data, Sequence): + result = [to_tensor(item) for item in data] + return result + else: + raise TypeError(f"type {type(data)} cannot be converted to tensor.") + + +class Collect: + """Collect data from the loader relevant to the specific task. + This keeps the items in ``keys`` as it is, and collect items in + ``meta_keys`` into a meta item called ``meta_name``.This is usually + the last stage of the data loader pipeline. + For example, when keys='imgs', meta_keys=('filename', 'label', + 'original_shape'), meta_name='img_metas', the results will be a dict with + keys 'imgs' and 'img_metas', where 'img_metas' is a DataContainer of + another dict with keys 'filename', 'label', 'original_shape'. + Args: + keys (Sequence[str]): Required keys to be collected. + meta_name (str): The name of the key that contains meta information. + This key is always populated. Default: "img_metas". + meta_keys (Sequence[str]): Keys that are collected under meta_name. + The contents of the ``meta_name`` dictionary depends on + ``meta_keys``. + By default this includes: + - "filename": path to the image file + - "label": label of the image file + - "original_shape": original shape of the image as a tuple + (h, w, c) + - "img_shape": shape of the image input to the network as a tuple + (h, w, c). Note that images may be zero padded on the + bottom/right, if the batch tensor is larger than this shape. + - "pad_shape": image shape after padding + - "flip_direction": a str in ("horiziontal", "vertival") to + indicate if the image is fliped horizontally or vertically. 
+ - "img_norm_cfg": a dict of normalization information: + - mean - per channel mean subtraction + - std - per channel std divisor + - to_rgb - bool indicating if bgr was converted to rgb + nested (bool): If set as True, will apply data[x] = [data[x]] to all + items in data. The arg is added for compatibility. Default: False. + """ + + def __init__( + self, + keys, + meta_keys=( + "filename", + "label", + "original_shape", + "img_shape", + "pad_shape", + "flip_direction", + "img_norm_cfg", + ), + meta_name="img_metas", + nested=False, + ): + self.keys = keys + self.meta_keys = meta_keys if meta_keys is not None else [] + self.meta_name = meta_name + self.nested = nested + + def __call__(self, results): + """Performs the Collect formatting. + + Args: + results (dict): The resulting dict to be modified and passed + to the next transform in pipeline. + """ + data = {} + for key in self.keys: + data[key] = results[key] + if len(self.meta_keys) != 0: + meta = {} + for key in self.meta_keys: + meta[key] = results[key] + data[self.meta_name] = meta + if self.nested: + for k in data: + data[k] = [data[k]] + return data + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"keys={self.keys}, meta_keys={self.meta_keys}, " + f"nested={self.nested})" + ) + + +class Copy: + def __init__(self, keys_dict=dict()): + self.keys_dict = keys_dict + + def __call__(self, data_dict): + for key, value in self.keys_dict.items(): + if isinstance(data_dict[key], np.ndarray): + data_dict[value] = data_dict[key].copy() + elif isinstance(data_dict[key], torch.Tensor): + data_dict[value] = data_dict[key].clone().detach() + else: + data_dict[value] = copy.deepcopy(data_dict[key]) + return data_dict + + +class ToTensor: + """Convert some values in results dict to `torch.Tensor` type in data loader pipeline. + + Args: + keys (Sequence[str]): Required keys to be converted. + """ + + def __init__(self, keys=None): + self.keys = keys + + def __call__(self, results): + """Performs the ToTensor formatting. + + Args: + results (dict): The resulting dict to be modified and passed + to the next transform in pipeline. 
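And a hypothetical `Collect` call, illustrating how listed keys stay top-level while `meta_keys` are grouped under the default `img_metas` entry:

```python
import numpy as np
from spa.data.components.transform import Collect

collect = Collect(keys=["imgs"], meta_keys=("filename", "original_shape"))
results = {
    "imgs": np.zeros((2, 3, 224, 224), dtype=np.float32),
    "filename": "scene0000_00/color/0.jpg",
    "original_shape": (968, 1296, 3),
    "extra": 123,  # anything not listed in keys or meta_keys is dropped
}
out = collect(results)
print(sorted(out))                         # ['img_metas', 'imgs']
print(out["img_metas"]["original_shape"])  # (968, 1296, 3)
```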
+ """ + if self.keys is None: + return to_tensor(results) + + for key in self.keys: + results[key] = to_tensor(results[key]) + return results + + def __repr__(self): + return f"{self.__class__.__name__}(keys={self.keys})" + + +class NormalizeColor: + def __init__(self, keys=["obs_imgs"], mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]): + if isinstance(keys, str): + keys = [keys] + self.keys = keys + self.mean = np.array(mean, dtype=np.float32).reshape(1, -1, 1, 1) + self.std = np.array(std, dtype=np.float32).reshape(1, -1, 1, 1) + + def __call__(self, data_dict): + for key in self.keys: + assert ( + data_dict[key].shape[-3] == 3 or data_dict[key].shape[-3] == 4 + ), f"Only support RGB or RGB-D, but key {key} has shape {data_dict[key].shape}" + data_dict[key][..., :3, :, :] = data_dict[key][..., :3, :, :] / 255.0 + data_dict[key][..., : self.mean.shape[1], :, :] = ( + data_dict[key][..., : self.mean.shape[1], :, :] - self.mean + ) / self.std + + return data_dict + + +class Compose: + def __init__(self, cfg=None): + self.cfg = cfg if cfg is not None else [] + self.transforms = [] + for t_cfg in self.cfg: + self.transforms.append(TRANSFORMS.build(t_cfg)) + + def __call__(self, data_dict): + for t in self.transforms: + data_dict = t(data_dict) + if data_dict is None: + return None + return data_dict + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += f" {t}" + format_string += "\n)" + return format_string diff --git a/spa/models/__init__.py b/spa/models/__init__.py new file mode 100755 index 0000000..879fdc5 --- /dev/null +++ b/spa/models/__init__.py @@ -0,0 +1 @@ +from .components import SPA, spa_vit_base_patch16, spa_vit_large_patch16 diff --git a/spa/models/components/__init__.py b/spa/models/components/__init__.py new file mode 100755 index 0000000..52afae9 --- /dev/null +++ b/spa/models/components/__init__.py @@ -0,0 +1,2 @@ +from .img_backbones import spa_vit_base_patch16, spa_vit_large_patch16 +from .spa import SPA diff --git a/spa/models/components/dense_heads/__init__.py b/spa/models/components/dense_heads/__init__.py new file mode 100644 index 0000000..65e893c --- /dev/null +++ b/spa/models/components/dense_heads/__init__.py @@ -0,0 +1 @@ +from .render_head import RenderHead diff --git a/spa/models/components/dense_heads/render_head.py b/spa/models/components/dense_heads/render_head.py new file mode 100644 index 0000000..c452829 --- /dev/null +++ b/spa/models/components/dense_heads/render_head.py @@ -0,0 +1,308 @@ +import os +from collections import defaultdict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from einops import rearrange +from torchvision.transforms import InterpolationMode + +from spa.utils import RankedLogger + +from ..model_utils import unet3d_utils +from ..model_utils.render_utils import models as render_models +from ..model_utils.render_utils import rays + +log = RankedLogger(__name__, rank_zero_only=True) + + +class RenderHead(nn.Module): + def __init__( + self, + *, + in_channels, + val_ray_split=8192, + feature_type="3d_to_3d", + proj_cfg=None, + render_cfg=None, + semantic_cfg=None, + **kwargs, + ): + super().__init__() + self.val_ray_split = val_ray_split + self.feature_type = feature_type + self.semantic_cfg = semantic_cfg + self.use_semantic = ( + self.semantic_cfg.get("use_semantic", True) + if self.semantic_cfg is not None + else False + ) + if self.use_semantic: + 
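Note that `Compose` builds its transforms through a `TRANSFORMS` registry that is assumed to be provided elsewhere in the package, so this toy run instantiates `NormalizeColor` directly, with ImageNet statistics (key name and shapes are illustrative):

```python
import numpy as np
from spa.data.components.transform import NormalizeColor

norm = NormalizeColor(
    keys=["obs_imgs"],
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
data = {"obs_imgs": np.full((2, 3, 8, 8), 255.0, dtype=np.float32)}  # (N, C, H, W), values in [0, 255]
data = norm(data)
print(data["obs_imgs"].mean(axis=(0, 2, 3)))  # per-channel (1.0 - mean) / std
```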
self.load_semantic_model() + proj_cfg = proj_cfg + self.proj_net = getattr(unet3d_utils, proj_cfg["type"])( + in_channels=in_channels, **proj_cfg + ) + render_cfg = render_cfg + self.renderer = getattr(render_models, render_cfg["type"])(**render_cfg) + self.forward_ret_dict = {} + self.freeze_stages() + + def freeze_stages(self): + if not self.use_semantic: + return + elif self.semantic_cfg.type == "radio": + self.radio.eval() + self.radio.requires_grad_(False) + else: + raise NotImplementedError + + def train(self, mode=True): + super().train(mode) + self.freeze_stages() + + @torch.no_grad() + def load_semantic_model(self): + assert ( + self.semantic_cfg.type == "radio" + ), f"Unsupported semantic model type: {self.semantic_cfg.type}" + if os.path.exists(os.path.expanduser("~/.cache/torch/hub/NVlabs_RADIO_main")): + self.radio = torch.hub.load( + os.path.expanduser("~/.cache/torch/hub/NVlabs_RADIO_main"), + "radio_model", + version=self.semantic_cfg.img_radio_cfg.model, + progress=True, + source="local", + ) + else: + self.radio = torch.hub.load( + "NVlabs/RADIO", + "radio_model", + version=self.semantic_cfg.img_radio_cfg.model, + progress=True, + skip_validation=True, + ) + log.info("Loading pretrained radio model") + + @torch.no_grad() + def create_imsemantic(self, img): + img_radio = img + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + _, spatial_features = self.radio(img_radio) + ret_features = rearrange( + spatial_features, + "b (h w) d -> b d h w", + h=img_radio.shape[-2] // self.radio.patch_size, + w=img_radio.shape[-1] // self.radio.patch_size, + ) + + ret_features = ret_features.to(torch.float32) + return ret_features + + @torch.no_grad() + def prepare_ray(self, batch_dict): + bray_p, bray_idx = batch_dict["ray_p"], batch_dict["ray_idx"] + trans2d_matrix = batch_dict["trans2d_matrix"] + ori_shape = batch_dict["ori_shape"] + img = batch_dict["img"] + img_norm_cfg = batch_dict["img_norm_cfg"] + + if self.use_semantic: + semantic_img = torch.cat(batch_dict["semantic_img"], dim=0) + bsemantic = self.create_imsemantic(semantic_img) + bsemantic = bsemantic.split([len(img[i]) for i in range(len(img))], dim=0) + + ori_img = [ + _img * img_norm_cfg["std"][:, None, None] + + img_norm_cfg["mean"][:, None, None] + for _img in img + ] + + ray_dict = { + "rgb": [], + "depth": [d[..., None] for d in batch_dict["ray_depth"]], + "origin": batch_dict["ray_o"], + "direction": batch_dict["ray_d"], + "near": batch_dict["ray_near"], + "far": batch_dict["ray_far"], + "scene_scale": batch_dict["ray_scale"], + "scene_bbox": batch_dict["point_cloud_range"], + } + + if self.use_semantic: + ray_dict["semantic"] = [] + + for bidx in range(batch_dict["batch_size"]): + iray_rgb, iray_semantic = [], [] + ray_idx = bray_idx[bidx] + for cidx in range(len(img[bidx])): + ray_p = bray_p[bidx][ray_idx == cidx] + + ray_p_rgb = ( + torch.stack( + [ + ray_p[:, 0] / (img[bidx].shape[-1] - 1), + ray_p[:, 1] / (img[bidx].shape[-2] - 1), + ], + dim=-1, + ) + * 2 + - 1 + ) + assert torch.all((ray_p_rgb >= -1) & (ray_p_rgb <= 1)) + ray_rgb = ( + F.grid_sample( + ori_img[bidx][cidx : cidx + 1].contiguous(), + ray_p_rgb[None, None].contiguous(), + align_corners=True, + ) + .squeeze(0) + .squeeze(1) + .transpose(0, 1) + ) + iray_rgb.append(ray_rgb) + + if self.use_semantic: + shp = ori_shape[bidx][cidx] + img2aug = trans2d_matrix[bidx][cidx] + ray_p_semantic = torch.stack( + [ + ray_p[:, 0], + ray_p[:, 1], + torch.ones_like(ray_p[:, 0]), + torch.ones_like(ray_p[:, 0]), + ], + dim=-1, + ) + ray_p_semantic = ray_p_semantic 
@ torch.linalg.inv(img2aug).T + ray_p_semantic = ( + torch.stack( + [ + ray_p_semantic[:, 0] / (shp[1] - 1), + ray_p_semantic[:, 1] / (shp[0] - 1), + ], + dim=-1, + ) + * 2 + - 1 + ) + assert torch.all((ray_p_semantic >= -1) & (ray_p_semantic <= 1)) + ray_semantic = ( + F.grid_sample( + bsemantic[bidx][cidx : cidx + 1].contiguous(), + ray_p_semantic[None, None].contiguous(), + align_corners=True, + ) + .squeeze(0) + .squeeze(1) + .transpose(0, 1) + ).contiguous() + iray_semantic.append(ray_semantic) + + ray_dict["rgb"].append(torch.cat(iray_rgb, dim=0)) + if self.use_semantic: + ray_dict["semantic"].append(torch.cat(iray_semantic, dim=0)) + + for k in [ + "ray_depth", + "ray_rgb", + "ray_o", + "ray_d", + "ray_p", + "ray_scale", + "ray_near", + "ray_far", + "ray_idx", + ]: + batch_dict.pop(k, None) + + rearrange_ray_dict = [] + for bidx in range(batch_dict["batch_size"]): + rearrange_ray_dict.append({k: v[bidx] for k, v in ray_dict.items()}) + return rearrange_ray_dict + + def prepare_volume(self, batch_dict): + if self.feature_type == "2d_to_3d": + volume_feat = batch_dict["spatial_features_2d"] + elif self.feature_type == "3d_to_3d": + volume_feat = batch_dict["encoded_spconv_tensor"] + else: + raise NotImplementedError + if "dataset_name" in batch_dict.keys(): + dataset = batch_dict["dataset_name"][0] + else: + dataset = "default" + + volume_feat = self.proj_net(volume_feat, dataset=dataset) + return volume_feat + + def render_func(self, ray_dict, volume_feature): + batched_render_out = [] + for i in range(len(ray_dict)): + i_ray_o, i_ray_d, i_ray_near, i_ray_far = ( + ray_dict[i]["origin"], + ray_dict[i]["direction"], + ray_dict[i]["near"], + ray_dict[i]["far"], + ) + i_volume_feature = [v[i] for v in volume_feature] + i_scene_bbox = ray_dict[i]["scene_bbox"] + i_scene_scale = ray_dict[i]["scene_scale"] + + if self.training: + ray_bundle = rays.RayBundle( + origins=i_ray_o, + directions=i_ray_d, + nears=i_ray_near, + fars=i_ray_far, + ) + render_out = self.renderer( + ray_bundle, i_volume_feature, i_scene_bbox, i_scene_scale + ) + else: + render_out = defaultdict(list) + for j_ray_o, j_ray_d, j_ray_near, j_ray_far in zip( + i_ray_o.split(self.val_ray_split, dim=0), + i_ray_d.split(self.val_ray_split, dim=0), + i_ray_near.split(self.val_ray_split, dim=0), + i_ray_far.split(self.val_ray_split, dim=0), + ): + ray_bundle = rays.RayBundle( + origins=j_ray_o, + directions=j_ray_d, + nears=j_ray_near, + fars=j_ray_far, + ) + part_render_out = self.renderer( + ray_bundle, i_volume_feature, i_scene_bbox, i_scene_scale + ) + for k, v in part_render_out.items(): + render_out[k].append(v.detach()) + del part_render_out + torch.cuda.empty_cache() + for k, v in render_out.items(): + render_out[k] = torch.cat(v, dim=0) + batched_render_out.append(render_out) + + return batched_render_out + + def get_loss(self, ray_preds, ray_targets): + batch_size = len(ray_targets) + loss_dict = defaultdict(list) + for bs_idx in range(batch_size): + i_loss_dict = self.renderer.get_loss(ray_preds[bs_idx], ray_targets[bs_idx]) + for k, v in i_loss_dict.items(): + loss_dict[k].append(v) + for k, v in loss_dict.items(): + loss_dict[k] = torch.stack(v, dim=0).mean() + loss = sum(_value for _key, _value in loss_dict.items() if "loss" in _key) + return loss, loss_dict + + def forward(self, batch_dict): + ray_dict = self.prepare_ray(batch_dict) + volume_feature = self.prepare_volume(batch_dict) + render_out = self.render_func(ray_dict, volume_feature) + batch_dict.update({"render_out": render_out, "ray_dict": ray_dict}) + 
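The per-ray color (and semantic) lookup in `prepare_ray` is a `grid_sample` over pixel coordinates normalized to [-1, 1]; here is the same pattern in isolation, with a toy image and two ray pixels:

```python
import torch
import torch.nn.functional as F

img = torch.rand(1, 3, 480, 640)                      # (1, C, H, W)
ray_p = torch.tensor([[10.0, 20.0], [320.0, 240.0]])  # (num_rays, 2) pixel (x, y)

# normalize x over (W - 1) and y over (H - 1), then map [0, 1] -> [-1, 1]
grid = torch.stack(
    [ray_p[:, 0] / (img.shape[-1] - 1), ray_p[:, 1] / (img.shape[-2] - 1)], dim=-1
) * 2 - 1
ray_rgb = (
    F.grid_sample(img, grid[None, None], align_corners=True)  # (1, C, 1, num_rays)
    .squeeze(0)
    .squeeze(1)
    .transpose(0, 1)                                           # (num_rays, C)
)
print(ray_rgb.shape)  # torch.Size([2, 3])
```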
return batch_dict diff --git a/spa/models/components/img_backbones/__init__.py b/spa/models/components/img_backbones/__init__.py new file mode 100644 index 0000000..787306f --- /dev/null +++ b/spa/models/components/img_backbones/__init__.py @@ -0,0 +1 @@ +from .vit import SPAViT, spa_vit_base_patch16, spa_vit_large_patch16 diff --git a/spa/models/components/img_backbones/modules.py b/spa/models/components/img_backbones/modules.py new file mode 100644 index 0000000..b1cf07a --- /dev/null +++ b/spa/models/components/img_backbones/modules.py @@ -0,0 +1,69 @@ +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def resize( + input, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, + warning=False, +): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) + return F.interpolate(input, size, scale_factor, mode, align_corners) + + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode, align_corners=False): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + x = self.interp( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + return x + + +class SimpleUpsample(nn.Module): + """ + Initialize: inplanes, planes, upscale_factor + OUTPUT: (planes // upscale_factor^2) * ht * wd + """ + + def __init__(self, inplanes, planes, upscale_factor=4, norm_layer=nn.BatchNorm2d): + super(SimpleUpsample, self).__init__() + self.conv = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, bias=False) + self.gelu = nn.GELU() + self.pixel_shuffle = nn.PixelShuffle(upscale_factor) + + def forward(self, x): + x = self.conv(x) + x = self.gelu(x) + x = self.pixel_shuffle(x) + return x.contiguous() diff --git a/spa/models/components/img_backbones/utils.py b/spa/models/components/img_backbones/utils.py new file mode 100644 index 0000000..b29cdf7 --- /dev/null +++ b/spa/models/components/img_backbones/utils.py @@ -0,0 +1,187 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
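For `SimpleUpsample`, the 3x3 conv maps `inplanes` to `planes` channels and `PixelShuffle(upscale_factor)` then trades `upscale_factor**2` of those channels for spatial resolution, so the output has `planes // upscale_factor**2` channels. A quick shape check (sizes are arbitrary):

```python
import torch
from spa.models.components.img_backbones.modules import SimpleUpsample

up = SimpleUpsample(inplanes=512, planes=128 * 16, upscale_factor=4)
x = torch.rand(2, 512, 14, 14)
print(up(x).shape)  # torch.Size([2, 128, 56, 56]): 128*16 / 4**2 channels, 14*4 spatial
```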
+# -------------------------------------------------------- +# adapted from: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +import os +import sys +import urllib +from functools import partial +from os.path import expanduser + +import hydra +import numpy as np +import omegaconf +import six +import timm.models.vision_transformer +import torch +import torch.nn as nn +import torchvision.transforms as T +from timm.models.vision_transformer import Block, PatchEmbed, resize_pos_embed + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def reshape_embedding(x): + N, L, D = x.shape + H = W = int(L**0.5) + x = x.reshape(N, H, W, D) + x = torch.einsum("nhwd->ndhw", x) + return x + + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
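A shape check of the sine-cosine positional embedding above for a ViT-B/16 at 224x224 input (a 14x14 patch grid), assuming the module's dependencies are installed:

```python
from spa.models.components.img_backbones.utils import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=False)
print(pos.shape)  # (196, 768)

pos_cls = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos_cls.shape, (pos_cls[0] == 0).all())  # (197, 768) True, row 0 is the [cls] slot
```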
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+    if "pos_embed" in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model["pos_embed"]
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches**0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print(
+                "Position interpolate from %dx%d to %dx%d"
+                % (orig_size, orig_size, new_size, new_size)
+            )
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(
+                -1, orig_size, orig_size, embedding_size
+            ).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens,
+                size=(new_size, new_size),
+                mode="bicubic",
+                align_corners=False,
+            )
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = 
torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed diff --git a/spa/models/components/img_backbones/vit.py b/spa/models/components/img_backbones/vit.py new file mode 100644 index 0000000..41e62c1 --- /dev/null +++ b/spa/models/components/img_backbones/vit.py @@ -0,0 +1,394 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# adapted from: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +import os +import sys +import urllib +from functools import partial +from os.path import expanduser + +import numpy as np +import six +import timm.models.vision_transformer +import torch +import torch.nn as nn +import torchvision.transforms as T +from einops import rearrange, repeat +from timm.models.vision_transformer import Block, PatchEmbed, resize_pos_embed + +from spa.utils.fp16_utils import auto_fp16 + +from .modules import SimpleUpsample +from .utils import get_2d_sincos_pos_embed + + +class ToTensorIfNot(T.ToTensor): + def __call__(self, pic): + if not torch.is_tensor(pic): + return super().__call__(pic) + return pic + + +class ViTEncoder(timm.models.vision_transformer.VisionTransformer): + """Vision Transformer Encoder with image transforms, used for downstream purpose""" + + def __init__( + self, + **kwargs, + ): + super().__init__(**kwargs) + del self.head + + self.img_size = kwargs.get("img_size", 224) + + from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + self.image_transform = T.Compose( + [ + T.Resize(self.img_size, interpolation=T.InterpolationMode.BICUBIC), + ToTensorIfNot(), + T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + ] + ) + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x, feature_map=False, cat_cls=True): + # If not `feature_map` (default), will return [cls] token (b c) + # otherwise, return reshaped feature map (b c h w) + # if both `feature_map` and `cat_cls`, concatenate feature map with [cls] token + x = self.image_transform(x) + latents = self.forward_features(x) + if not feature_map: + return latents[:, 0] + else: + h = w = int(latents[:, 1:].shape[1] ** 0.5) + feature_map = rearrange( + latents[:, 1:], + "b (h w) c -> b c h w", + h=h, + w=w, + ) + + if cat_cls: + cls_token = repeat(latents[:, 0:1], "n 1 c -> n c h w", h=h, w=w) + return torch.cat([feature_map, cls_token], dim=1) + else: + return feature_map + + +class SPAViT(timm.models.vision_transformer.VisionTransformer): + """Vision Transformer with SPA's upsampler decoder, used for pre-training""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=1024, + depth=24, + num_heads=16, + decoder_embed_dim=512, + mlp_ratio=4.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + pretrained_weight=None, + mask_ratio=0.75, + out_feature_channels=128, + **kwargs, + ): + super().__init__( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + **kwargs, + ) + del self.head + + self.in_chans = in_chans + self.img_size = img_size + self.patch_size = patch_size + num_patches = 
self.patch_embed.num_patches + self.decoder_embed_dim = decoder_embed_dim + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + + self.readout_project = nn.Sequential( + nn.Linear(decoder_embed_dim * 2, decoder_embed_dim, bias=True), + nn.GELU(approximate="none"), + ) + self.upsample = nn.Sequential( + SimpleUpsample( + decoder_embed_dim, out_feature_channels * 16, upscale_factor=4 + ), + SimpleUpsample( + out_feature_channels, out_feature_channels * 16, upscale_factor=4 + ), + ) + # -------------------------------------------------------------------------- + + self.mask_ratio = mask_ratio + + if pretrained_weight is None: + self.initialize_weights() + else: + self.load_weight(pretrained_weight) + + def load_weight(self, pretrained_weight): + state_dict = torch.load(pretrained_weight, map_location="cpu") + if state_dict.get("model", None) is not None: + state_dict = state_dict["model"] + if state_dict["pos_embed"].shape != self.pos_embed.shape: + state_dict["pos_embed"] = resize_pos_embed( + state_dict["pos_embed"], + self.pos_embed, + getattr(self, "num_tokens", 1), + self.patch_embed.grid_size, + ) + + # filter out keys with name decoder or mask_token + state_dict = { + k: v + for k, v in state_dict.items() + if "decoder" not in k and "mask_token" not in k + } + + self.load_state_dict(state_dict, strict=False) + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.patch_embed.num_patches**0.5), + cls_token=True, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.cls_token, std=0.02) + torch.nn.init.normal_(self.mask_token, std=0.02) + + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + p = self.patch_embed.patch_size[0] + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], self.in_chans, h, p, w, p)) + x = torch.einsum("nchpwq->nhwpqc", x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_chans)) + return x + + def unpatchify(self, x, n_channel=3, p=None): + """ + x: (N, L, patch_size**2 *n_channel) + imgs: (N, n_channel, H, W) + """ + if p is None: + p = self.patch_embed.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, n_channel)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], n_channel, h * p, h * p)) + return imgs + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. 
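`patchify` is the usual ViT flattening: with a 16-pixel patch, a 224x224 RGB image becomes 196 tokens of length 16*16*3 = 768. A standalone restatement of the same reshaping (not a call into the class):

```python
import torch

imgs = torch.rand(2, 3, 224, 224)
p = 16
h = w = 224 // p
x = imgs.reshape(2, 3, h, p, w, p)
x = torch.einsum("nchpwq->nhwpqc", x).reshape(2, h * w, p * p * 3)
print(x.shape)  # torch.Size([2, 196, 768])
```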
+ x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + @auto_fp16(apply_to=("x"), out_fp32=True) + def forward_encoder(self, x, mask_ratio): # (n, 3, h, w) + # embed patches + x = self.patch_embed(x) # (n, p*p, c) + + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking( + x, mask_ratio + ) # (n, p*p*(1-mask_ratio), c) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) # (n, p*p*(1-mask_ratio) + 1, c) + + # apply Transformer blocks + for i, blk in enumerate(self.blocks): + x = blk(x) # (n, p*p*(1-mask_ratio) + 1, c) + + x = self.norm(x) + return mask, ids_restore, x + + @auto_fp16(apply_to=("x"), out_fp32=True) + def forward_decoder(self, latent, ids_restore): + x = self.decoder_embed(latent) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1 + ) + + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + x_ = torch.gather( + x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]) + ) # unshuffle + + readout = x[:, :1, :].expand_as(x_) + x = torch.cat([x_, readout], dim=-1) + x = self.readout_project(x) + x = self.unpatchify(x, self.decoder_embed_dim, p=1) # b c p p + x = self.upsample(x) # b c p*n p*n + + return x + + def forward(self, batch_dict, mask_ratio=None): + if mask_ratio is None: + mask_ratio = self.mask_ratio + + imgs = torch.cat(batch_dict["img"], dim=0) + imgs = imgs.view(-1, *imgs.shape[-3:]) + + mask, ids_restore, latent = self.forward_encoder(imgs, mask_ratio) + feature_map = self.forward_decoder(latent, ids_restore) + batch_dict["img_features"] = [feature_map] + return batch_dict + + +def spa_vit_base_patch16(img_size=224, pretrained=True, **kwargs): + model = ViTEncoder( + img_size=img_size, + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + qkv_bias=True, + **kwargs, + ) + if pretrained: + model = load_pretrained(model, "spa-b") + + return model + + +def spa_vit_large_patch16(img_size=224, pretrained=True, **kwargs): + model = ViTEncoder( + img_size=img_size, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + qkv_bias=True, + **kwargs, + ) + if pretrained: + model = load_pretrained(model, "spa-l") + + return model + + +def load_pretrained(model: nn.Module, ckpt_name: str): + from collections import OrderedDict + + from huggingface_hub import hf_hub_download + + try: + import safetensors.torch + + _has_safetensors = True + except ImportError: + _has_safetensors = False + + if _has_safetensors: + from 
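The masking arithmetic of `random_masking` in isolation: with the default `mask_ratio=0.75`, only a quarter of the 196 patch tokens are kept and passed through the encoder (a toy restatement, not the class method):

```python
import torch

N, L, D = 2, 196, 768                      # batch, tokens, embed dim (ViT at 224 / patch 16)
mask_ratio = 0.75
x = torch.rand(N, L, D)

len_keep = int(L * (1 - mask_ratio))       # 49 tokens survive
noise = torch.rand(N, L)
ids_shuffle = torch.argsort(noise, dim=1)  # a random permutation per sample
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
print(x_masked.shape)                      # torch.Size([2, 49, 768])
```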
safetensors.torch import load_file + + ckpt_file = hf_hub_download( + repo_id="HaoyiZhu/SPA", filename=f"{ckpt_name}.safetensors" + ) + _state_dict = load_file(ckpt_file) + else: + ckpt_file = hf_hub_download( + repo_id="HaoyiZhu/SPA", filename=f"{ckpt_name}.ckpt" + ) + _state_dict = torch.load(ckpt_file)["state_dict"] + + state_dict = OrderedDict() + for key, value in _state_dict.items(): + if key.startswith("model.img_backbone.") and ( + "decoder" not in key + and "head" not in key + and "upsample" not in key + and "mask" not in key + and "readout" not in key + ): + state_dict[key.replace("model.img_backbone.", "")] = value + + model.load_state_dict(state_dict, strict=True) + return model diff --git a/spa/models/components/model_utils/__init__.py b/spa/models/components/model_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/spa/models/components/model_utils/attention_utils.py b/spa/models/components/model_utils/attention_utils.py new file mode 100644 index 0000000..c31ffdb --- /dev/null +++ b/spa/models/components/model_utils/attention_utils.py @@ -0,0 +1,282 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from deform_attn.ms_deform_attn_utils import MSDeformAttnFunction +from einops import rearrange +from torch.nn.functional import linear +from torch.nn.parameter import Parameter + + +def batch_mask_sequence(feats_list, mask): + """ + Args: + feats_list: [(B, N, C), ...] + mask: (B, N) + Returns: + rebatch_feats_list: [(B, M, C), ...] + mask_indices: [(M1,), (M2,), ...] + """ + batch_size = mask.shape[0] + mask_indices = [] + for bs_idx in range(batch_size): + mask_indices.append(mask[bs_idx].nonzero(as_tuple=True)[0]) + max_len = max([len(each) for each in mask_indices]) + rebatch_feats_list = [] + for feats in feats_list: + rebatch_feats = feats.new_zeros( + [batch_size, max_len, feats.shape[-1]], dtype=feats.dtype + ) + for bs_idx in range(batch_size): + i_index = mask_indices[bs_idx] + rebatch_feats[bs_idx, : len(i_index)] = feats[bs_idx, i_index] + rebatch_feats_list.append(rebatch_feats) + return rebatch_feats_list, mask_indices + + +def rebatch_mask_sequence(feats, rebatch_feats, mask_indices): + """ + Args: + feats: (B, N, C) + rebatch_feats: (B, M, C) + mask_indices: [(M1,), (M2,), ...] 
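Putting the encoder pieces together, a minimal downstream feature-extraction sketch built on the factories and `load_pretrained` above (the first call downloads the `spa-l` checkpoint from the `HaoyiZhu/SPA` hub repo; the input tensor is a placeholder):

```python
import torch
from spa.models.components.img_backbones.vit import spa_vit_large_patch16

model = spa_vit_large_patch16(pretrained=True)
model.eval()
model.freeze()

img = torch.rand(1, 3, 224, 224)  # RGB in [0, 1]; ViTEncoder resizes and applies ImageNet normalization
with torch.no_grad():
    cls_feat = model(img)                               # (1, 1024) [cls] token
    fmap = model(img, feature_map=True, cat_cls=False)  # (1, 1024, 14, 14) patch feature map
print(cls_feat.shape, fmap.shape)
```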
+ Returns: + new_feats: (B, N, C) + """ + batch_size = feats.shape[0] + new_feats = rebatch_feats.new_zeros( + [batch_size, feats.shape[1], rebatch_feats.shape[-1]], dtype=rebatch_feats.dtype + ) + for bs_idx in range(batch_size): + i_index = mask_indices[bs_idx] + new_feats[bs_idx, i_index] = rebatch_feats[bs_idx, : len(i_index)] + return new_feats + + +class VoxelformerCrossAttention(nn.Module): + def __init__( + self, + embed_dims, + num_heads, + num_points, + num_levels, + im2col_step=64, + bias=True, + dropout=0.0, + **kwargs, + ): + super(VoxelformerCrossAttention, self).__init__() + self.im2col_step = im2col_step + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_points = num_points + self.num_levels = num_levels + self.sampling_offsets = nn.Linear( + embed_dims, num_heads * num_levels * num_points * 2 + ) + self.attention_weights = nn.Linear( + embed_dims, num_heads * num_levels * num_points + ) + self.value_proj = nn.Linear(embed_dims, embed_dims, bias=bias) + self.out_proj = nn.Linear(embed_dims, embed_dims, bias=bias) + self.dropout_layer = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + self.init_weights() + + def init_weights(self): + nn.init.zeros_(self.sampling_offsets.weight) + thetas = torch.arange( + self.num_heads, dtype=self.sampling_offsets.weight.dtype + ) * (2.0 * math.pi / self.num_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.num_heads, 1, 1, 2) + .repeat(1, self.num_levels, self.num_points, 1) + ) + for i in range(self.num_points): + grid_init[:, :, i, :] *= i + 1 + + self.sampling_offsets.bias.data = grid_init.view(-1) + nn.init.zeros_(self.attention_weights.weight) + nn.init.zeros_(self.attention_weights.bias) + nn.init.xavier_uniform_(self.value_proj.weight) + nn.init.zeros_(self.value_proj.bias) + nn.init.xavier_uniform_(self.out_proj.weight) + nn.init.zeros_(self.out_proj.bias) + + def forward( + self, + query, + key=None, + value=None, + identity=None, + query_pos=None, + key_padding_mask=None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + query_mask=None, + **kwargs, + ): + """Forward Function of MultiScaleDeformAttention. + + Args: + query (Tensor): Query of Transformer with shape + (bs, num_query, embed_dims). + key (Tensor): The key tensor with shape + `(bs, num_key, embed_dims)`. + value (Tensor): The value tensor with shape + `(bs, num_key, embed_dims)`. + identity (Tensor): The tensor used for addition, with the + same shape as `query`. Default None. If None, + `query` will be used. + query_pos (Tensor): The positional encoding for `query`. + Default: None. + key_pos (Tensor): The positional encoding for `key`. Default + None. + reference_points (Tensor): The normalized reference + points with shape (bs, num_query, 3), + all elements is range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area. + key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_key]. + spatial_shapes (Tensor): Spatial shape of features in + different levels. With shape (num_levels, 2), + last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape ``(num_levels, )`` and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + + Returns: + Tensor: forwarded results with shape [bs, num_query, embed_dims]. 
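`batch_mask_sequence` / `rebatch_mask_sequence` simply pack the valid (masked-in) positions into a padded batch and later scatter the processed rows back. Because the module also imports a compiled deformable-attention extension, the sketch below restates that pack/unpack logic with plain tensors instead of importing it:

```python
import torch

feats = torch.rand(2, 5, 8)                                  # (B, N, C)
mask = torch.tensor([[1, 0, 1, 0, 0], [1, 1, 1, 1, 0]], dtype=torch.bool)

# pack: keep only valid positions, pad each sample to the longest kept length
mask_indices = [m.nonzero(as_tuple=True)[0] for m in mask]
max_len = max(len(idx) for idx in mask_indices)
packed = feats.new_zeros(2, max_len, 8)
for b, idx in enumerate(mask_indices):
    packed[b, : len(idx)] = feats[b, idx]

# unpack: scatter rows back to their original positions, zeros elsewhere
unpacked = torch.zeros_like(feats)
for b, idx in enumerate(mask_indices):
    unpacked[b, idx] = packed[b, : len(idx)]
print(packed.shape, torch.equal(unpacked[mask], feats[mask]))  # torch.Size([2, 4, 8]) True
```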
+ """ + length = [tmp.shape[0] for tmp in query_mask] + B, Z, Y, X, C = query.shape + query = [ + query[i : i + 1].expand(N, -1, -1, -1, -1) for i, N in enumerate(length) + ] + # (N1+N2+..., Z*Y*X, C) + query = torch.cat(query, dim=0).view(-1, Z * Y * X, C) + reference_points = torch.cat(reference_points, dim=0).view(-1, Z * Y * X, 2) + query_mask = torch.cat(query_mask, dim=0).view(-1, Z * Y * X) + + (bvoxel_cam_coords, bvoxel_embeds), bmask_indices = batch_mask_sequence( + [reference_points, query], query_mask + ) # [(N1+N2+..., M, 2), (N1+N2+..., M, C)] + + query = bvoxel_embeds + if value is None: + value = query + + if identity is None: + identity = query + if query_pos is not None: + query = query + query_pos + + bs, num_query, _ = query.shape + bs, num_value, _ = value.shape + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + + value = self.value_proj(value) + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], 0.0) + value = value.view(bs, num_value, self.num_heads, -1) + + sampling_offsets = self.sampling_offsets(query).view( + bs, num_query, self.num_heads, self.num_levels, self.num_points, 2 + ) + attention_weights = self.attention_weights(query).view( + bs, num_query, self.num_heads, self.num_levels * self.num_points + ) + attention_weights = attention_weights.softmax(-1) + + attention_weights = attention_weights.view( + bs, num_query, self.num_heads, self.num_levels, self.num_points + ) + + # (num_level, 2) + offset_normalizer = torch.stack( + [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1 + ) + # (bs, num_query, num_heads, num_levels, num_points, 2) + sampling_locations = ( + bvoxel_cam_coords[:, :, None, None, None, :2] + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + ) + + output = MSDeformAttnFunction.apply( + value, + spatial_shapes, + level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + + output = self.out_proj(output) + output = identity + self.dropout_layer(output) + + output = rebatch_mask_sequence(reference_points, output, bmask_indices) + + presum = 0 + final = [] + for l in length: + x = output[presum : presum + l].sum(dim=0) + c = torch.clamp(query_mask[presum : presum + l].sum(dim=0), min=1.0) + x = (x / c[..., None]).view(Z, Y, X, -1) + final.append(x) + presum += l + # (B, Z, Y, X, C) + output = torch.stack(final, dim=0) + return output + + +class VoxelformerSelfAttention(nn.Module): + def __init__( + self, + embed_dims, + num_convs, + **kwargs, + ): + super(VoxelformerSelfAttention, self).__init__() + self.embed_dims = embed_dims + self.conv_layer = nn.ModuleList() + for k in range(num_convs): + self.conv_layer.append( + nn.Sequential( + nn.Conv3d( + embed_dims, + embed_dims, + kernel_size=3, + stride=1, + padding=1, + bias=True, + ), + nn.ReLU(inplace=True), + ) + ) + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv3d): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward( + self, + query, + key=None, + value=None, + identity=None, + **kwargs, + ): + # (B, Z, Y, X, C) -> (B, C, Z, Y, X) + output = query.permute(0, 4, 1, 2, 3) + for layer in self.conv_layer: + output = layer(output) + # (B, C, Z, Y, X) -> (B, Z, Y, X, C) + output = output.permute(0, 2, 3, 4, 1).contiguous() + return output diff --git a/spa/models/components/model_utils/render_utils/decoders.py b/spa/models/components/model_utils/render_utils/decoders.py new file mode 
100644 index 0000000..a9121c0 --- /dev/null +++ b/spa/models/components/model_utils/render_utils/decoders.py @@ -0,0 +1,92 @@ +import numpy as np +import torch +import torch.nn as nn + + +class SDFDecoder(nn.Module): + def __init__(self, in_dim, out_dim, hidden_size=256, n_blocks=5, **kwargs): + super().__init__() + + dims = [hidden_size] + [hidden_size for _ in range(n_blocks)] + [out_dim] + self.num_layers = len(dims) + + for l in range(self.num_layers - 1): + lin = nn.Linear(dims[l], dims[l + 1]) + setattr(self, "lin" + str(l), lin) + + self.fc_c = nn.ModuleList( + [nn.Linear(in_dim, hidden_size) for i in range(self.num_layers - 1)] + ) + self.fc_p = nn.Linear(3, hidden_size) + + self.activation = nn.Softplus(beta=100) + + def forward(self, points, point_feats): + x = self.fc_p(points) + for l in range(self.num_layers - 1): + x = x + self.fc_c[l](point_feats) + lin = getattr(self, "lin" + str(l)) + x = lin(x) + if l < self.num_layers - 2: + x = self.activation(x) + return x + + +class RGBDecoder(nn.Module): + def __init__(self, in_dim, out_dim=3, hidden_size=256, n_blocks=5, **kwargs): + super().__init__() + + dims = [hidden_size] + [hidden_size for _ in range(n_blocks)] + [out_dim] + self.num_layers = len(dims) + + for l in range(self.num_layers - 1): + lin = nn.Linear(dims[l], dims[l + 1]) + setattr(self, "lin" + str(l), lin) + + self.fc_p = nn.Linear(3, hidden_size) + + self.fc_c = nn.ModuleList( + [nn.Linear(in_dim, hidden_size) for i in range(self.num_layers - 1)] + ) + self.activation = nn.ReLU() + + def forward(self, points, point_feats, out_sigmoid=True): + x = self.fc_p(points) + for l in range(self.num_layers - 1): + x = x + self.fc_c[l](point_feats) + lin = getattr(self, "lin" + str(l)) + x = lin(x) + if l < self.num_layers - 2: + x = self.activation(x) + if out_sigmoid: + x = torch.sigmoid(x) + return x + + +class SemanticDecoder(nn.Module): + def __init__(self, in_dim, out_dim, hidden_size=256, n_blocks=5, **kwargs): + super().__init__() + + dims = [hidden_size] + [hidden_size for _ in range(n_blocks)] + [out_dim] + self.num_layers = len(dims) + + for l in range(self.num_layers - 1): + lin = nn.Linear(dims[l], dims[l + 1]) + setattr(self, "lin" + str(l), lin) + + self.fc_p = nn.Linear(3, hidden_size) + + self.fc_c = nn.ModuleList( + [nn.Linear(in_dim, hidden_size) for i in range(self.num_layers - 1)] + ) + self.activation = nn.ReLU() + + def forward(self, points, point_feats): + x = self.fc_p(points) + for l in range(self.num_layers - 1): + x = x + self.fc_c[l](point_feats) + lin = getattr(self, "lin" + str(l)) + x = lin(x) + if l < self.num_layers - 2: + x = self.activation(x) + return x diff --git a/spa/models/components/model_utils/render_utils/fields/__init__.py b/spa/models/components/model_utils/render_utils/fields/__init__.py new file mode 100644 index 0000000..4206eab --- /dev/null +++ b/spa/models/components/model_utils/render_utils/fields/__init__.py @@ -0,0 +1,3 @@ +from .sdf_field import SDFFieldExp + +__all__ = ["SDFFieldExp"] diff --git a/spa/models/components/model_utils/render_utils/fields/sdf_field.py b/spa/models/components/model_utils/render_utils/fields/sdf_field.py new file mode 100644 index 0000000..374a43e --- /dev/null +++ b/spa/models/components/model_utils/render_utils/fields/sdf_field.py @@ -0,0 +1,235 @@ +import math + +import torch +import torch.nn.functional as F +from grid_sampler import GridSampler3D +from torch import nn + +from spa.utils.transforms import components_from_spherical_harmonics + +from ..decoders import RGBDecoder, SDFDecoder, 
SemanticDecoder + + +class LaplaceDensity(nn.Module): # alpha * Laplace(loc=0, scale=beta).cdf(-sdf) + """Laplace density from VolSDF""" + + def __init__(self, init_val, beta_min=0.0001): + super().__init__() + self.register_parameter( + "beta_min", nn.Parameter(beta_min * torch.ones(1), requires_grad=False) + ) + self.register_parameter( + "beta", nn.Parameter(init_val * torch.ones(1), requires_grad=True) + ) + + def forward(self, sdf, beta=None): + """convert sdf value to density value with beta, if beta is missing, then use learable beta""" + if beta is None: + beta = self.get_beta() + + alpha = 1.0 / beta + + density = alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta)) + return density + + def get_beta(self): + """return current beta value""" + beta = self.beta.abs() + self.beta_min + return beta + + +class SingleVarianceNetwork(nn.Module): + """Variance network in NeuS""" + + def __init__(self, init_val): + super(SingleVarianceNetwork, self).__init__() + self.register_parameter( + "variance", nn.Parameter(init_val * torch.ones(1), requires_grad=True) + ) + + def forward(self, x): + """Returns current variance value""" + return torch.ones([len(x), 1], device=x.device) * torch.exp( + self.variance * 10.0 + ) + + def get_variance(self): + """return current variance value""" + return torch.exp(self.variance * 10.0).clip(1e-6, 1e6) + + +class SDFFieldExp(nn.Module): + def __init__( + self, + beta_init, + padding_mode="zeros", + render_rgb=True, + render_semantic=False, + use_alpha=False, + use_density=False, + **kwargs + ): + super().__init__() + self.beta_init = beta_init + self.padding_mode = padding_mode + self.use_density = use_density + self.use_alpha = use_alpha + self.render_rgb = render_rgb + self.render_semantic = render_semantic + if use_density: + # laplace function for transform sdf to density from VolSDF + self.laplace_density = LaplaceDensity(init_val=self.beta_init) + + if use_alpha: + # deviation_network to compute alpha from sdf from NeuS + self.deviation_network = SingleVarianceNetwork(init_val=self.beta_init) + + self._cos_anneal_ratio = 1.0 + + def set_cos_anneal_ratio(self, anneal): + """Set the anneal value for the proposal network.""" + self._cos_anneal_ratio = anneal + + def get_alpha(self, ray_samples, sdf, gradients): + inv_s = self.deviation_network.get_variance() # Single parameter + + true_cos = (ray_samples.frustums.directions * gradients).sum(-1, keepdim=True) + + # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes + # the cos value "not dead" at the beginning training iterations, for better convergence. 
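The VolSDF-style `LaplaceDensity` above is `alpha * Laplace(0, beta).cdf(-sdf)` with `alpha = 1 / beta`; a small numeric sanity check of that closed form:

```python
import torch

def laplace_density(sdf, beta):
    alpha = 1.0 / beta
    return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta))

beta = torch.tensor(0.1)
sdf = torch.tensor([-0.5, 0.0, 0.5])  # negative values are inside the surface
print(laplace_density(sdf, beta))
# roughly 1/beta deep inside, exactly 0.5/beta on the surface, near 0 far outside
```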
+ iter_cos = -( + F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self._cos_anneal_ratio) + + F.relu(-true_cos) * self._cos_anneal_ratio + ) # always non-positive + + # Estimate signed distances at section points + estimated_next_sdf = sdf + iter_cos * ray_samples.deltas * 0.5 + estimated_prev_sdf = sdf - iter_cos * ray_samples.deltas * 0.5 + + prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) + next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) + + p = prev_cdf - next_cdf + c = prev_cdf + + alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0) + + return alpha + + def feature_sampling(self, pts, volume_feature, scene_bbox, scene_scale, index): + """ + Args: + pts: (N, K, 3), [x, y, z], scaled + feats_volume: (C, Z, Y, X) + Returns: + feats: (N, K, C) + """ + scene_bbox = pts.new_tensor(scene_bbox) + pts_norm = (pts * scene_scale - scene_bbox[:3]) / ( + scene_bbox[3:] - scene_bbox[:3] + ) + pts_norm = pts_norm * 2 - 1 # [0, 1] -> [-1, 1] + + ret_feat = ( + GridSampler3D.apply( + volume_feature[index].unsqueeze(0).contiguous().to(pts_norm.dtype), + pts_norm[None, None].contiguous(), + self.padding_mode, + True, + ) + .squeeze(0) + .squeeze(1) + .permute(1, 2, 0) + .contiguous() + ) + # (1, C, 1, N, K) -> (N, K, C) + + return ret_feat + + def get_sdf(self, points, volume_feature, scene_bbox, scene_scale): + """predict the sdf value for ray samples""" + sdf = self.feature_sampling( + points, volume_feature, scene_bbox, scene_scale, index=0 + ) + return (sdf,) + + def get_density(self, ray_samples, volume_feature, scene_bbox, scene_scale): + """Computes and returns the densities.""" + points = ray_samples.frustums.get_positions() + sdf = self.get_sdf(points, volume_feature, scene_bbox, scene_scale)[0] + density = self.laplace_density(sdf) + return density + + def get_occupancy(self, sdf): + """compute occupancy as in UniSurf""" + occupancy = torch.sigmoid(-10.0 * sdf) + return occupancy + + def forward(self, ray_samples, volume_feature, scene_bbox, scene_scale): + """Evaluates the field at points along the ray. + + Args: + ray_samples: Samples to evaluate field on. + """ + outputs = {} + + points = ray_samples.frustums.get_positions() # (num_rays, num_samples, 3) + + points.requires_grad_(True) + with torch.enable_grad(): + (sdf,) = self.get_sdf(points, volume_feature, scene_bbox, scene_scale) + + d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device) + gradients = torch.autograd.grad( + outputs=sdf, + inputs=points, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + if self.render_rgb: + directions = ray_samples.frustums.directions # (num_rays, num_samples, 3) + + sh = self.feature_sampling( + points, volume_feature, scene_bbox, scene_scale, index=1 + ) + sh = sh.view(*sh.shape[:-1], 3, sh.shape[-1] // 3) + + levels = int(math.sqrt(sh.shape[-1])) + components = components_from_spherical_harmonics( + levels=levels, directions=directions + ) + + rgb = sh * components[..., None, :] # [..., num_samples, 3, sh_components] + rgb = torch.sum(sh, dim=-1) + 0.5 # [..., num_samples, 3] + rgb = torch.sigmoid(rgb) + + outputs["rgb"] = rgb + + if self.render_semantic: + semantic = self.feature_sampling( + points, volume_feature, scene_bbox, scene_scale, index=-1 + ) + outputs["semantic"] = semantic + + outputs.update( + { + "sdf": sdf, + "gradients": gradients, + "normal": F.normalize(gradients, dim=-1), # TODO: should normalize? 
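The normals in `SDFFieldExp.forward` come from autograd on the sampled SDF values. Here is the same pattern on an analytic sphere SDF, whose gradient norm is 1 everywhere (the eikonal property the loss later encourages):

```python
import torch
import torch.nn.functional as F

points = torch.rand(4, 8, 3, requires_grad=True)  # (num_rays, num_samples, 3)
sdf = points.norm(dim=-1, keepdim=True) - 0.5      # sphere of radius 0.5

gradients = torch.autograd.grad(
    outputs=sdf,
    inputs=points,
    grad_outputs=torch.ones_like(sdf),
    create_graph=True,
)[0]
print(gradients.norm(dim=-1).mean())               # ~1.0 for an exact SDF
normals = F.normalize(gradients, dim=-1)
```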
+ } + ) + + if self.use_density: + density = self.laplace_density(sdf) + outputs["density"] = density + + if self.use_alpha: + # TODO use mid point sdf for NeuS + # (num_rays, num_samples, 1) + alphas = self.get_alpha(ray_samples, sdf, gradients) + outputs["alphas"] = alphas + + return outputs diff --git a/spa/models/components/model_utils/render_utils/models/__init__.py b/spa/models/components/model_utils/render_utils/models/__init__.py new file mode 100644 index 0000000..9268c70 --- /dev/null +++ b/spa/models/components/model_utils/render_utils/models/__init__.py @@ -0,0 +1,3 @@ +from .neus import NeuSModel + +__all__ = ["NeuSModel"] diff --git a/spa/models/components/model_utils/render_utils/models/base_surface_model.py b/spa/models/components/model_utils/render_utils/models/base_surface_model.py new file mode 100644 index 0000000..377ef11 --- /dev/null +++ b/spa/models/components/model_utils/render_utils/models/base_surface_model.py @@ -0,0 +1,162 @@ +from abc import abstractmethod + +import torch +import torch.nn.functional as F +from torch import nn + +from .. import fields, ray_samplers +from ..renderers import DepthRenderer, OtherRenderer, RGBRenderer + + +class SurfaceModel(nn.Module): + def __init__( + self, + field_cfg, + sampler_cfg, + loss_cfg, + **kwargs, + ): + super().__init__() + self.field = getattr(fields, field_cfg["type"])(**field_cfg) + self.sampler = getattr(ray_samplers, sampler_cfg["type"])(**sampler_cfg) + self.rgb_renderer = RGBRenderer() + self.depth_renderer = DepthRenderer() + self.other_renderer = OtherRenderer() + self.loss_cfg = loss_cfg + + @abstractmethod + def sample_and_forward_field( + self, ray_bundle, volume_feature, scene_bbox, scene_scale + ): + """_summary_ + + Args: + ray_bundle (RayBundle): _description_ + return_samples (bool, optional): _description_. Defaults to False. + """ + + def get_outputs( + self, ray_bundle, volume_feature, scene_bbox, scene_scale, **kwargs + ): + outputs = {} + + samples_and_field_outputs = self.sample_and_forward_field( + ray_bundle, volume_feature, scene_bbox, scene_scale + ) + + # Shotscuts + field_outputs = samples_and_field_outputs["field_outputs"] + ray_samples = samples_and_field_outputs["ray_samples"] + weights = samples_and_field_outputs["weights"] + + depth = self.depth_renderer(ray_samples=ray_samples, weights=weights) + normal = self.other_renderer(vals=field_outputs["normal"], weights=weights) + if "rgb" in field_outputs.keys(): + rgb = self.rgb_renderer(rgb=field_outputs["rgb"], weights=weights) + outputs["rgb"] = rgb + if "semantic" in field_outputs.keys(): + semantic = self.other_renderer( + vals=field_outputs["semantic"], weights=weights + ) + outputs["semantic"] = semantic + + outputs.update( + { + "depth": depth, + "normal": normal, + "weights": weights, + "sdf": field_outputs["sdf"], + "gradients": field_outputs["gradients"], + "z_vals": ray_samples.frustums.starts, + } + ) + + """ add for visualization""" + outputs.update({"sampled_points": samples_and_field_outputs["sampled_points"]}) + if samples_and_field_outputs.get("init_sampled_points", None) is not None: + outputs.update( + { + "init_sampled_points": samples_and_field_outputs[ + "init_sampled_points" + ], + "init_weights": samples_and_field_outputs["init_weights"], + "new_sampled_points": samples_and_field_outputs[ + "new_sampled_points" + ], + } + ) + + return outputs + + def forward(self, ray_bundle, volume_feature, scene_bbox, scene_scale, **kwargs): + """Run forward starting with a ray bundle. 
This outputs different things depending on the + configuration of the model and whether or not the batch is provided (whether or not we are + training basically) + + Args: + ray_bundle: containing all the information needed to render that ray latents included + """ + ray_bundle.origins /= scene_scale + ray_bundle.nears /= scene_scale + ray_bundle.fars /= scene_scale + return self.get_outputs( + ray_bundle, volume_feature, scene_bbox, scene_scale, **kwargs + ) + + def get_loss(self, preds_dict, targets): + loss_dict = {} + loss_weights = self.loss_cfg.weights + scene_scale = targets["scene_scale"] + + depth_pred = preds_dict["depth"] # (num_rays, 1) + depth_gt = targets["depth"] / scene_scale + valid_gt_mask = depth_gt > 0.0 + if loss_weights.get("depth_loss", 0.0) > 0: + depth_loss = torch.sum( + valid_gt_mask * torch.abs(depth_gt - depth_pred) + ) / torch.clamp(valid_gt_mask.sum(), min=1.0) + loss_dict["depth_loss"] = depth_loss * loss_weights.depth_loss + + # free space loss and sdf loss + pred_sdf = preds_dict["sdf"][..., 0] + z_vals = preds_dict["z_vals"][..., 0] + truncation = self.loss_cfg.sensor_depth_truncation / scene_scale + + front_mask = valid_gt_mask & (z_vals < (depth_gt - truncation)) + back_mask = valid_gt_mask & (z_vals > (depth_gt + truncation)) + sdf_mask = valid_gt_mask & (~front_mask) & (~back_mask) + + if loss_weights.get("free_space_loss", 0.0) > 0: + free_space_loss = ( + F.relu(truncation - pred_sdf) * front_mask + ).sum() / torch.clamp(front_mask.sum(), min=1.0) + loss_dict["free_space_loss"] = ( + free_space_loss * loss_weights.free_space_loss + ) + + if loss_weights.get("sdf_loss", 0.0) > 0: + sdf_loss = ( + torch.abs(z_vals + pred_sdf - depth_gt) * sdf_mask + ).sum() / torch.clamp(sdf_mask.sum(), min=1.0) + loss_dict["sdf_loss"] = sdf_loss * loss_weights.sdf_loss + + if loss_weights.get("eikonal_loss", 0.0) > 0: + gradients = preds_dict["gradients"] + eikonal_loss = ((gradients.norm(2, dim=-1) - 1) ** 2).mean() + loss_dict["eikonal_loss"] = eikonal_loss * loss_weights.eikonal_loss + + if loss_weights.get("rgb_loss", 0.0) > 0: + rgb_pred = preds_dict["rgb"] # (num_rays, 3) + rgb_gt = targets["rgb"] + rgb_loss = F.l1_loss(rgb_pred, rgb_gt) + loss_dict["rgb_loss"] = rgb_loss * loss_weights.rgb_loss + psnr = 20.0 * torch.log10(1.0 / (rgb_pred - rgb_gt).pow(2).mean().sqrt()) + loss_dict["psnr"] = psnr + + if loss_weights.get("semantic_loss", 0.0) > 0: + semantic_pred = F.normalize(preds_dict["semantic"], dim=-1) + semantic_gt = F.normalize(targets["semantic"], dim=-1) + semantic_loss = 1 - (semantic_pred * semantic_gt).sum(-1).mean() + loss_dict["semantic_loss"] = semantic_loss * loss_weights.semantic_loss + + return loss_dict diff --git a/spa/models/components/model_utils/render_utils/models/neus.py b/spa/models/components/model_utils/render_utils/models/neus.py new file mode 100644 index 0000000..bc75540 --- /dev/null +++ b/spa/models/components/model_utils/render_utils/models/neus.py @@ -0,0 +1,37 @@ +from functools import partial + +from .base_surface_model import SurfaceModel + + +class NeuSModel(SurfaceModel): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def sample_and_forward_field( + self, ray_bundle, volume_feature, scene_bbox, scene_scale + ): + sampler_out_dict = self.sampler( + ray_bundle, + occupancy_fn=self.field.get_occupancy, + sdf_fn=partial( + self.field.get_sdf, + volume_feature=volume_feature, + scene_bbox=scene_bbox, + scene_scale=scene_scale, + ), + ) + ray_samples = sampler_out_dict.pop("ray_samples") + field_outputs = 
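Two of the quantities in `get_loss`, worked through with toy numbers: the PSNR logged next to the RGB L1 term, and the front/back/SDF masks built from the truncation band (same formulas as above):

```python
import torch

# PSNR as computed alongside rgb_loss: a uniform 0.1 error gives 20 dB
rgb_pred = torch.full((1024, 3), 0.5)
rgb_gt = torch.full((1024, 3), 0.4)
psnr = 20.0 * torch.log10(1.0 / (rgb_pred - rgb_gt).pow(2).mean().sqrt())
print(psnr)  # ~20 dB

# free-space / sdf masks around a ground-truth depth of 2.0 with truncation 0.25
depth_gt, truncation = 2.0, 0.25
z_vals = torch.tensor([0.5, 1.9, 2.1, 3.0])
front_mask = z_vals < depth_gt - truncation  # free space: sdf pushed above the truncation
back_mask = z_vals > depth_gt + truncation   # far behind the surface: not supervised
sdf_mask = ~front_mask & ~back_mask          # near-surface band: sdf pulled toward depth_gt - z_vals
print(front_mask, back_mask, sdf_mask)
```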
self.field(ray_samples, volume_feature, scene_bbox, scene_scale) + weights, _ = ray_samples.get_weights_and_transmittance_from_alphas( + field_outputs["alphas"] + ) + + samples_and_field_outputs = { + "ray_samples": ray_samples, + "field_outputs": field_outputs, + "weights": weights, # (num_rays, num_smaples+num_importance, 1) + "sampled_points": ray_samples.frustums.get_positions(), # (num_rays, num_smaples+num_importance, 3) + **sampler_out_dict, + } + + return samples_and_field_outputs diff --git a/spa/models/components/model_utils/render_utils/ray_samplers.py b/spa/models/components/model_utils/render_utils/ray_samplers.py new file mode 100644 index 0000000..a9c6ca1 --- /dev/null +++ b/spa/models/components/model_utils/render_utils/ray_samplers.py @@ -0,0 +1,779 @@ +from abc import abstractmethod + +import torch +from torch import nn + + +class Sampler(nn.Module): + """Generate Samples + + Args: + num_samples: number of samples to take + """ + + def __init__(self, num_samples=None): + super().__init__() + self.num_samples = num_samples + + @abstractmethod + def generate_ray_samples(self): + """Generate Ray Samples""" + + def forward(self, *args, **kwargs): + """Generate ray samples""" + return self.generate_ray_samples(*args, **kwargs) + + +class SpacedSampler(Sampler): + """Sample points according to a function. + + Args: + num_samples: Number of samples per ray + spacing_fn: Function that dictates sample spacing (ie `lambda x : x` is uniform). + spacing_fn_inv: The inverse of spacing_fn. + train_stratified: Use stratified sampling during training. Defults to True + single_jitter: Use a same random jitter for all samples along a ray. Defaults to False + """ + + def __init__( + self, + spacing_fn, + spacing_fn_inv, + num_samples=None, + train_stratified=True, + single_jitter=False, + ): + super().__init__(num_samples=num_samples) + self.train_stratified = train_stratified + self.single_jitter = single_jitter + self.spacing_fn = spacing_fn + self.spacing_fn_inv = spacing_fn_inv + + def generate_ray_samples(self, ray_bundle, num_samples=None): + """Generates position samples accoring to spacing function. 
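+ Bins are first laid out uniformly in the normalized spacing domain, optionally jittered during training (one shared jitter per ray when `single_jitter` is set), then mapped back to Euclidean depths via the inverse spacing function.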
+ + Args: + ray_bundle: Rays to generate samples for + num_samples: Number of samples per ray + + Returns: + Positions and deltas for samples along a ray + """ + assert ray_bundle is not None + assert ray_bundle.nears is not None + assert ray_bundle.fars is not None + + num_samples = num_samples or self.num_samples + num_rays = ray_bundle.origins.shape[0] + + bins = ( + torch.linspace(0.0, 1.0, num_samples + 1) + .to(ray_bundle.origins.device) + .expand(size=(num_rays, -1)) + ) # [num_rays, num_samples+1] + + if self.train_stratified and self.training: + if self.single_jitter: + t_rand = torch.rand((num_rays, 1), dtype=bins.dtype, device=bins.device) + else: + t_rand = torch.rand( + (num_rays, num_samples + 1), dtype=bins.dtype, device=bins.device + ) + bin_centers = (bins[..., 1:] + bins[..., :-1]) / 2.0 + bin_upper = torch.cat([bin_centers, bins[..., -1:]], -1) + bin_lower = torch.cat([bins[..., :1], bin_centers], -1) + bins = bin_lower + (bin_upper - bin_lower) * t_rand + + s_near, s_far = ( + self.spacing_fn(x) + for x in (ray_bundle.nears.clone(), ray_bundle.fars.clone()) + ) + spacing_to_euclidean_fn = lambda x: self.spacing_fn_inv( + x * s_far + (1 - x) * s_near + ) + euclidean_bins = spacing_to_euclidean_fn(bins) # [num_rays, num_samples+1] + + ray_samples = ray_bundle.get_ray_samples( + bin_starts=euclidean_bins[..., :-1, None], # (num_rays, num_samples, 1) + bin_ends=euclidean_bins[..., 1:, None], + spacing_starts=bins[..., :-1, None], + spacing_ends=bins[..., 1:, None], + spacing_to_euclidean_fn=spacing_to_euclidean_fn, + ) + + return ray_samples + + +class UniformSampler(SpacedSampler): + """Sample uniformly along a ray + + Args: + num_samples: Number of samples per ray + train_stratified: Use stratified sampling during training. Defults to True + single_jitter: Use a same random jitter for all samples along a ray. Defaults to False + """ + + def __init__(self, num_samples=None, train_stratified=True, single_jitter=False): + super().__init__( + num_samples=num_samples, + spacing_fn=lambda x: x, + spacing_fn_inv=lambda x: x, + train_stratified=train_stratified, + single_jitter=single_jitter, + ) + + +class LinearDisparitySampler(SpacedSampler): + """Sample linearly in disparity along a ray + + Args: + num_samples: Number of samples per ray + train_stratified: Use stratified sampling during training. Defults to True + single_jitter: Use a same random jitter for all samples along a ray. Defaults to False + """ + + def __init__(self, num_samples=None, train_stratified=True, single_jitter=False): + super().__init__( + num_samples=num_samples, + spacing_fn=lambda x: 1 / x, + spacing_fn_inv=lambda x: 1 / x, + train_stratified=train_stratified, + single_jitter=single_jitter, + ) + + +class SqrtSampler(SpacedSampler): + """Square root sampler along a ray + + Args: + num_samples: Number of samples per ray + train_stratified: Use stratified sampling during training. Defults to True + """ + + def __init__(self, num_samples=None, train_stratified=True, single_jitter=False): + super().__init__( + num_samples=num_samples, + spacing_fn=torch.sqrt, + spacing_fn_inv=lambda x: x**2, + train_stratified=train_stratified, + single_jitter=single_jitter, + ) + + +class LogSampler(SpacedSampler): + """Log sampler along a ray + + Args: + num_samples: Number of samples per ray + train_stratified: Use stratified sampling during training. 
Defults to True + """ + + def __init__(self, num_samples=None, train_stratified=True, single_jitter=False): + super().__init__( + num_samples=num_samples, + spacing_fn=torch.log, + spacing_fn_inv=torch.exp, + train_stratified=train_stratified, + single_jitter=single_jitter, + ) + + +class UniformLinDispPiecewiseSampler(SpacedSampler): + """Piecewise sampler along a ray that allocates the first half of the samples uniformly and the second half + using linearly in disparity spacing. + + + Args: + num_samples: Number of samples per ray + train_stratified: Use stratified sampling during training. Defults to True + single_jitter: Use a same random jitter for all samples along a ray. Defaults to False + """ + + def __init__(self, num_samples=None, train_stratified=True, single_jitter=False): + super().__init__( + num_samples=num_samples, + spacing_fn=lambda x: torch.where(x < 1, x / 2, 1 - 1 / (2 * x)), + spacing_fn_inv=lambda x: torch.where(x < 0.5, 2 * x, 1 / (2 - 2 * x)), + train_stratified=train_stratified, + single_jitter=single_jitter, + ) + + +class PDFSampler(Sampler): + """Sample based on probability distribution + + Args: + num_samples: Number of samples per ray + train_stratified: Randomize location within each bin during training. + single_jitter: Use a same random jitter for all samples along a ray. Defaults to False + include_original: Add original samples to ray. + histogram_padding: Amount to weights prior to computing PDF. + """ + + def __init__(self, num_samples=None, train_stratified=True, single_jitter=False): + super().__init__(num_samples=num_samples) + self.train_stratified = train_stratified + self.single_jitter = single_jitter + + def generate_ray_samples( + self, ray_bundle, ray_samples, weights, num_samples=None, eps=1e-5 + ): + """Generates position samples given a distribution. + + Args: + ray_bundle: Rays to generate samples for + ray_samples: Existing ray samples + weights: Weights for each bin, (num_rays, num_samples) + eps: Small value to prevent numerical issues. 
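+ num_samples: Number of samples to draw from the weight PDF; falls back to self.num_samples when None.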
+ + Returns: + Positions and deltas for samples along a ray + """ + num_samples = num_samples or self.num_samples + num_bins = num_samples + 1 + + weights = weights[..., 0] + # Add small offset to rays with zero weight to prevent NaNs + weights_sum = torch.sum(weights, dim=-1, keepdim=True) + padding = torch.relu(eps - weights_sum) + weights = weights + padding / weights.shape[-1] + weights_sum += padding + + pdf = weights / weights_sum + cdf = torch.min(torch.ones_like(pdf), torch.cumsum(pdf, dim=-1)) + cdf = torch.cat( + [torch.zeros_like(cdf[..., :1]), cdf], dim=-1 + ) # (num_rays, num_samples+1) + + if self.train_stratified and self.training: + # Stratified samples between 0 and 1 + u = torch.linspace( + 0.0, 1.0 - (1.0 / num_bins), steps=num_bins, device=cdf.device + ) + u = u.expand(size=(*cdf.shape[:-1], num_bins)) + if self.single_jitter: + rand = torch.rand((*cdf.shape[:-1], 1), device=cdf.device) / num_bins + else: + rand = ( + torch.rand((*cdf.shape[:-1], num_bins), device=cdf.device) + / num_bins + ) + u = u + rand + else: + # Uniform samples between 0 and 1 + u = torch.linspace( + 0.0, 1.0 - (1.0 / num_bins), steps=num_bins, device=cdf.device + ) + u = u + 1.0 / (2 * num_bins) + u = u.expand(size=(*cdf.shape[:-1], num_bins)) + u = u.contiguous() + + assert ( + ray_samples.spacing_starts is not None + and ray_samples.spacing_ends is not None + ), "ray_sample spacing_starts and spacing_ends must be provided" + assert ( + ray_samples.spacing_to_euclidean_fn is not None + ), "ray_samples.spacing_to_euclidean_fn must be provided" + existing_bins = torch.cat( + [ + ray_samples.spacing_starts[..., 0], + ray_samples.spacing_ends[..., -1:, 0], + ], + dim=-1, + ) # (num_rays, num_samples+1) + + inds = torch.searchsorted(cdf, u, right=True) + below = torch.clamp(inds - 1, 0, existing_bins.shape[-1] - 1) + above = torch.clamp(inds, 0, existing_bins.shape[-1] - 1) + cdf_g0 = torch.gather(cdf, -1, below) + bins_g0 = torch.gather(existing_bins, -1, below) + cdf_g1 = torch.gather(cdf, -1, above) + bins_g1 = torch.gather(existing_bins, -1, above) + + # t = torch.clip(torch.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1) + denom = cdf_g1 - cdf_g0 + denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) + t = torch.clip((u - cdf_g0) / denom, 0, 1) + bins = bins_g0 + t * (bins_g1 - bins_g0) + + # Stop gradients + bins = bins.detach() + + euclidean_bins = ray_samples.spacing_to_euclidean_fn(bins) + + ray_samples = ray_bundle.get_ray_samples( + bin_starts=euclidean_bins[..., :-1, None], # (num_rays, num_importance, 1) + bin_ends=euclidean_bins[..., 1:, None], + spacing_starts=bins[..., :-1, None], + spacing_ends=bins[..., 1:, None], + spacing_to_euclidean_fn=ray_samples.spacing_to_euclidean_fn, + ) + + return ray_samples + + +class NeuSSampler(Sampler): + """NeuS sampler that uses a sdf network to generate samples with fixed variance value in each iterations.""" + + def __init__( + self, + initial_sampler, + num_samples, + num_samples_importance, + num_upsample_steps, + base_variance=64.0, + train_stratified=True, + single_jitter=True, + **kwargs + ): + super().__init__() + self.num_samples = num_samples + self.num_samples_importance = num_samples_importance + self.num_upsample_steps = num_upsample_steps + self.base_variance = base_variance + + # samplers + self.initial_sampler = eval(initial_sampler)( + num_samples=num_samples, + train_stratified=train_stratified, + single_jitter=single_jitter, + ) + self.pdf_sampler = PDFSampler( + train_stratified=train_stratified, 
single_jitter=single_jitter + ) + + def generate_ray_samples(self, ray_bundle, sdf_fn, **kwargs): + # Start with uniform sampling + ray_samples = self.initial_sampler(ray_bundle) + + total_iters = 0 + sorted_index = None + new_samples = ray_samples + + base_variance = self.base_variance + output_dict = {} + while total_iters < self.num_upsample_steps: + with torch.no_grad(): + new_points = new_samples.frustums.get_positions() + new_sdf = sdf_fn(new_points)[0] + + # merge sdf predictions + if sorted_index is not None: + sdf_merge = torch.cat([sdf.squeeze(-1), new_sdf.squeeze(-1)], -1) + sdf = torch.gather(sdf_merge, 1, sorted_index).unsqueeze(-1) + else: + sdf = new_sdf + + # compute with fix variances + alphas = self.rendering_sdf_with_fixed_inv_s( + ray_samples, sdf.squeeze(-1), inv_s=base_variance * 2**total_iters + ) # (num_rays, num_samples-1) + + weights, _ = ray_samples.get_weights_and_transmittance_from_alphas( + alphas.unsqueeze(-1) + ) + weights = torch.cat( + (weights, torch.zeros_like(weights[:, :1])), dim=1 + ) # (num_rays, num_samples, 1) + + if total_iters == 0: + output_dict.update( + { + "init_sampled_points": new_points, # (num_rays, num_samples, 3) + "init_weights": weights, # (num_rays, num_samples, 1) + } + ) + + new_samples = self.pdf_sampler( + ray_bundle, + ray_samples, + weights, + num_samples=self.num_samples_importance // self.num_upsample_steps, + ) + + if output_dict.get("new_sampled_points", None) is None: + output_dict["new_sampled_points"] = new_samples.frustums.get_positions() + else: + output_dict["new_sampled_points"] = torch.cat( + [ + output_dict["new_sampled_points"], + new_samples.frustums.get_positions(), + ], + dim=1, + ) # (num_rays, num_importance_samples, 3) + + ray_samples, sorted_index = ray_bundle.merge_ray_samples( + ray_samples, new_samples + ) + + total_iters += 1 + + output_dict.update({"ray_samples": ray_samples}) + return output_dict + + def rendering_sdf_with_fixed_inv_s(self, ray_samples, sdf, inv_s): + """rendering given a fixed inv_s as NeuS""" + batch_size = ray_samples.deltas.shape[0] + prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:] + deltas = ray_samples.deltas[:, :-1, 0] + mid_sdf = (prev_sdf + next_sdf) * 0.5 + cos_val = (next_sdf - prev_sdf) / (deltas + 1e-5) + + # ---------------------------------------------------------------------------------------------------------- + # Use min value of [ cos, prev_cos ] + # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more + # robust when meeting situations like below: + # + # SDF + # ^ + # |\ -----x----... 
+ # | \ / + # | x x + # |---\----/-------------> 0 level + # | \ / + # | \/ + # | + # ---------------------------------------------------------------------------------------------------------- + prev_cos_val = torch.cat( + [torch.zeros([batch_size, 1], device=sdf.device), cos_val[:, :-1]], dim=-1 + ) + cos_val = torch.stack([prev_cos_val, cos_val], dim=-1) + cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False) + cos_val = cos_val.clip(-1e3, 0.0) + + dist = deltas + prev_esti_sdf = mid_sdf - cos_val * dist * 0.5 + next_esti_sdf = mid_sdf + cos_val * dist * 0.5 + prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s) + next_cdf = torch.sigmoid(next_esti_sdf * inv_s) + alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5) + + return alpha + + +class ErrorBoundedSampler(Sampler): + """VolSDF's error bounded sampler that uses a sdf network to generate samples.""" + + def __init__( + self, + initial_sampler, + num_samples=64, + num_samples_eval=128, + num_samples_extra=32, + eps=0.1, + beta_iters=10, + max_total_iters=5, + train_stratified=True, + single_jitter=True, + ): + super().__init__() + self.num_samples = num_samples + self.num_samples_eval = num_samples_eval + self.num_samples_extra = num_samples_extra + self.eps = eps + self.beta_iters = beta_iters + self.max_total_iters = max_total_iters + # samplers + self.initial_sampler = eval(initial_sampler)( + train_stratified=train_stratified, single_jitter=single_jitter + ) + self.pdf_sampler = PDFSampler( + train_stratified=train_stratified, single_jitter=single_jitter + ) + + def generate_ray_samples(self, ray_bundle, density_fn, sdf_fn, **kwargs): + beta0 = density_fn.get_beta().detach() + + # Start with uniform sampling + ray_samples = self.initial_sampler( + ray_bundle, num_samples=self.num_samples_eval + ) + + # Get maximum beta from the upper bound (Lemma 2) + deltas = ray_samples.deltas.squeeze(-1) + + bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * ( + deltas**2.0 + ).sum(-1) + beta = torch.sqrt(bound) + + total_iters, not_converge = 0, True + sorted_index = None + new_samples = ray_samples + + output_dict = {} + # Algorithm 1 + while not_converge and total_iters < self.max_total_iters: + with torch.no_grad(): + new_points = new_samples.frustums.get_positions() + new_sdf = sdf_fn(new_points)[0] + + # merge sdf predictions + if sorted_index is not None: + sdf_merge = torch.cat([sdf.squeeze(-1), new_sdf.squeeze(-1)], -1) + sdf = torch.gather(sdf_merge, 1, sorted_index).unsqueeze(-1) + else: + sdf = new_sdf + + # Calculating the bound d* (Theorem 1) + d_star = self.get_dstar(sdf.squeeze(-1), ray_samples) + + # Updating beta using line search + beta = self.get_updated_beta( + beta0, beta, density_fn, sdf.squeeze(-1), d_star, ray_samples + ) + + # Upsample more points + density = density_fn(sdf.squeeze(-1), beta=beta.unsqueeze(-1)) + + weights, transmittance = ray_samples.get_weights_and_transmittance( + density.unsqueeze(-1) + ) + + if total_iters == 0: + output_dict.update( + { + "init_sampled_points": new_points, # (num_rays, num_samples, 3) + "init_weights": weights, # (num_rays, num_samples, 1) + } + ) + + # Check if we are done and this is the last sampling + total_iters += 1 + not_converge = beta.max() > beta0 + + if not_converge and total_iters < self.max_total_iters: + # Sample more points proportional to the current error bound + deltas = ray_samples.deltas.squeeze(-1) + + error_per_section = ( + torch.exp(-d_star / beta.unsqueeze(-1)) + * (deltas**2.0) + / (4 * beta.unsqueeze(-1) ** 2) + ) + + error_integral = 
torch.cumsum(error_per_section, dim=-1) + weights = ( + torch.clamp(torch.exp(error_integral), max=1.0e6) - 1.0 + ) * transmittance[..., 0] + + new_samples = self.pdf_sampler( + ray_bundle, + ray_samples, + weights.unsqueeze(-1), + num_samples=self.num_samples_eval, + ) + + ray_samples, sorted_index = ray_bundle.merge_ray_samples( + ray_samples, new_samples + ) + + else: + # Sample the final sample set to be used in the volume rendering integral + ray_samples = self.pdf_sampler( + ray_bundle, ray_samples, weights, num_samples=self.num_samples + ) + output_dict["new_sampled_points"] = ray_samples.frustums.get_positions() + + # Add extra samples uniformly + if self.num_samples_extra > 0: + ray_samples_uniform = self.initial_sampler( + ray_bundle, num_samples=self.num_samples_extra + ) + ray_samples, _ = ray_bundle.merge_ray_samples( + ray_samples, ray_samples_uniform + ) + + output_dict.update({"ray_samples": ray_samples}) + return output_dict + + def get_dstar(self, sdf, ray_samples): + """Calculating the bound d* (Theorem 1) from VolSDF""" + d = sdf + dists = ray_samples.deltas.squeeze(-1) + a, b, c = dists[:, :-1], d[:, :-1].abs(), d[:, 1:].abs() + first_cond = a.pow(2) + b.pow(2) <= c.pow(2) + second_cond = a.pow(2) + c.pow(2) <= b.pow(2) + d_star = torch.zeros( + ray_samples.deltas.shape[0], ray_samples.deltas.shape[1] - 1 + ).to(d) + d_star[first_cond] = b[first_cond] + d_star[second_cond] = c[second_cond] + s = (a + b + c) / 2.0 + area_before_sqrt = s * (s - a) * (s - b) * (s - c) + mask = ~first_cond & ~second_cond & (b + c - a > 0) + d_star[mask] = ((2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])).to(d) + d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star # Fixing the sign + + # padding to make the same shape as ray_samples + # d_star_left = torch.cat((d_star[:, :1], d_star), dim=-1) + # d_star_right = torch.cat((d_star, d_star[:, -1:]), dim=-1) + # d_star = torch.minimum(d_star_left, d_star_right) + + d_star = torch.cat((d_star, d_star[:, -1:]), dim=-1) + return d_star + + def get_updated_beta(self, beta0, beta, density_fn, sdf, d_star, ray_samples): + curr_error = self.get_error_bound(beta0, density_fn, sdf, d_star, ray_samples) + beta[curr_error <= self.eps] = beta0 + beta_min, beta_max = beta0.repeat(ray_samples.deltas.shape[0]), beta + for j in range(self.beta_iters): + beta_mid = (beta_min + beta_max) / 2.0 + curr_error = self.get_error_bound( + beta_mid.unsqueeze(-1), density_fn, sdf, d_star, ray_samples + ) + beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps] + beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps] + beta = beta_max + return beta + + def get_error_bound(self, beta, density_fn, sdf, d_star, ray_samples): + """Get error bound from VolSDF""" + densities = density_fn(sdf, beta=beta) + + deltas = ray_samples.deltas.squeeze(-1) + delta_density = deltas * densities + + integral_estimation = torch.cumsum(delta_density[..., :-1], dim=-1) + integral_estimation = torch.cat( + [ + torch.zeros( + (*integral_estimation.shape[:1], 1), device=densities.device + ), + integral_estimation, + ], + dim=-1, + ) + + error_per_section = torch.exp(-d_star / beta) * (deltas**2.0) / (4 * beta**2) + error_integral = torch.cumsum(error_per_section, dim=-1) + bound_opacity = ( + torch.clamp(torch.exp(error_integral), max=1.0e6) - 1.0 + ) * torch.exp(-integral_estimation) + + return bound_opacity.max(-1)[0] + + +class UniSurfSampler(Sampler): + """NeuS sampler that uses a sdf network to generate samples with fixed variance value in each 
iterations.""" + + def __init__( + self, + initial_sampler, + num_samples_importance, + num_marching_steps, + num_samples_interval, + delta, + train_stratified=True, + single_jitter=True, + ): + super().__init__() + self.num_samples_importance = num_samples_importance + self.num_marching_steps = num_marching_steps + self.num_samples_interval = num_samples_interval + self.delta = delta + # self.sample_ratio = sample_ratio + self.single_jitter = single_jitter + # samplers + self.initial_sampler = eval(initial_sampler)( + train_stratified=train_stratified, single_jitter=single_jitter + ) + self.pdf_sampler = PDFSampler( + train_stratified=train_stratified, single_jitter=single_jitter + ) + + def generate_ray_samples(self, ray_bundle, occupancy_fn, sdf_fn, **kwargs): + output_dict = {} + # Start with uniform sampling + ray_samples = self.initial_sampler( + ray_bundle, num_samples=self.num_marching_steps + ) + points = ray_samples.frustums.get_positions() + with torch.no_grad(): + sdf = sdf_fn(points)[0] + + # importance sampling + occupancy = occupancy_fn(sdf) + weights, _ = ray_samples.get_weights_and_transmittance_from_alphas(occupancy) + + output_dict.update( + { + "init_sampled_points": ray_samples.frustums.get_positions(), # (num_rays, num_samples, 3) + "init_weights": weights, # (num_rays, num_samples, 1) + } + ) + + importance_samples = self.pdf_sampler( + ray_bundle, + ray_samples, + weights, + num_samples=self.num_samples_importance, + ) + + # surface points + # Calculate if sign change occurred and concat 1 (no sign change) in + # last dimension + n_rays, n_samples = ray_samples.deltas.shape[:2] + starts = ray_samples.frustums.starts + sign_matrix = torch.cat( + [ + torch.sign(sdf[:, :-1, 0] * sdf[:, 1:, 0]), + torch.ones(n_rays, 1).to(sdf.device), + ], + dim=-1, + ) + cost_matrix = sign_matrix * torch.arange(n_samples, 0, -1).float().to( + sdf.device + ) # (n_rays, n_samples) + + # Get first sign change and mask for values where a.) a sign changed + # occurred and b.) 
no a neg to pos sign change occurred (meaning from + # inside surface to outside) + values, indices = torch.min(cost_matrix, -1) + mask_sign_change = values < 0 # (n_rays,) + mask_pos_to_neg = sdf[torch.arange(n_rays), indices, 0] > 0 + + # Define mask where a valid depth value is found + mask = mask_sign_change & mask_pos_to_neg # (n_rays,) + + # Get depth values and function values for the interval + d_low = starts[torch.arange(n_rays), indices, 0][mask] + v_low = sdf[torch.arange(n_rays), indices, 0][mask] + + indices = torch.clamp(indices + 1, max=n_samples - 1) + d_high = starts[torch.arange(n_rays), indices, 0][mask] + v_high = sdf[torch.arange(n_rays), indices, 0][mask] + + # TODO secant method + # linear-interpolations, estimated depth values + z = (v_low * d_high - v_high * d_low) / (v_low - v_high) + + # modify near and far values according current schedule + nears, fars = ray_bundle.nears.clone(), ray_bundle.fars.clone() + dists = fars - nears + + ray_bundle.nears[mask] = z[:, None] - dists[mask] * self.delta + ray_bundle.fars[mask] = z[:, None] + dists[mask] * self.delta + + # min max bound + ray_bundle.nears = torch.maximum(ray_bundle.nears, nears) + ray_bundle.fars = torch.minimum(ray_bundle.fars, fars) + + # samples uniformly with new surface interval + ray_samples_interval = self.initial_sampler( + ray_bundle, num_samples=self.num_samples_interval + ) + + # change back to original values + ray_bundle.nears = nears + ray_bundle.fars = fars + + # merge sampled points + ray_samples = ray_bundle.merge_ray_samples_in_eculidean( + ray_samples_interval, importance_samples + ) + + output_dict["new_sampled_points"] = ray_samples.frustums.get_positions() + output_dict.update({"ray_samples": ray_samples}) + return output_dict diff --git a/spa/models/components/model_utils/render_utils/rays.py b/spa/models/components/model_utils/render_utils/rays.py new file mode 100644 index 0000000..de26d31 --- /dev/null +++ b/spa/models/components/model_utils/render_utils/rays.py @@ -0,0 +1,227 @@ +import torch +import torch.nn as nn + + +class Frustums(nn.Module): + def __init__(self, origins, directions, starts, ends, **kwargs): + super().__init__() + self.origins = origins + self.directions = directions + self.starts = starts + self.ends = ends + + def get_positions(self): + """Calulates "center" position of frustum. Not weighted by mass. + + Returns: + xyz positions: (num_rays, num_samples, 3) + """ + pos = self.origins + self.directions * (self.starts + self.ends) / 2 + return pos + + def get_start_positions(self): + """Calulates "start" position of frustum. We use start positions for MonoSDF + because when we use error bounded sampling, we need to upsample many times. + It's hard to merge two set of ray samples while keeping the mid points fixed. + Every time we up sample the points the mid points will change and + therefore we need to evaluate all points again which is 3 times slower. + But we can skip the evaluation of sdf value if we use start position instead of mid position + because after we merge the points, the starting point is the same and only the delta is changed. 
+ + Returns: + xyz positions: (num_rays, num_samples, 3) + """ + return self.origins + self.directions * self.starts + + +class RaySamples(nn.Module): + """Samples along a ray""" + + def __init__( + self, + frustums, + deltas, + spacing_starts, + spacing_ends, + spacing_to_euclidean_fn, + **kwargs + ): + super().__init__() + self.frustums = frustums + self.deltas = deltas + self.spacing_starts = spacing_starts + self.spacing_ends = spacing_ends + self.spacing_to_euclidean_fn = spacing_to_euclidean_fn + + def get_weights_and_transmittance(self, densities): + """Return weights based on predicted densities + + Args: + densities: Predicted densities for samples along ray + + Returns: + Weights for each sample + """ + + delta_density = self.deltas * densities + alphas = 1 - torch.exp(-delta_density) + + transmittance = torch.cumsum(delta_density[..., :-1, :], dim=-2) + transmittance = torch.cat( + [ + torch.zeros((*transmittance.shape[:1], 1, 1), device=densities.device), + transmittance, + ], + dim=-2, + ) + transmittance = torch.exp(-transmittance) # [..., "num_samples"] + + weights = alphas * transmittance # [..., "num_samples"] + + return weights, transmittance + + def get_weights_and_transmittance_from_alphas(self, alphas): + """Return weights based on predicted alphas + + Args: + alphas: Predicted alphas (maybe from sdf) for samples along ray + + Returns: + Weights for each sample + """ + transmittance = torch.cumprod( + torch.cat( + [ + torch.ones((*alphas.shape[:1], 1, 1), device=alphas.device), + 1.0 - alphas + 1e-7, + ], + 1, + ), + 1, + ) # [..., "num_samples"] + + weights = alphas * transmittance[:, :-1, :] # [num_rays, num_samples, 1] + + return weights, transmittance + + +class RayBundle(nn.Module): + """A bundle of ray parameters.""" + + def __init__(self, origins, directions, nears=None, fars=None, **kwargs): + super().__init__(**kwargs) + self.origins = origins # (num_rays, 3) + self.directions = directions # (num_rays, 3) + self.nears = nears # (num_rays, 1) + self.fars = fars # (num_rays, 1) + + def merge_ray_samples(self, ray_samples_1, ray_samples_2): + """Merge two set of ray samples and return sorted index which can be used to merge sdf values + + Args: + ray_samples_1 : ray_samples to merge + ray_samples_2 : ray_samples to merge + """ + + starts_1 = ray_samples_1.spacing_starts[..., 0] + starts_2 = ray_samples_2.spacing_starts[..., 0] + + ends = torch.maximum( + ray_samples_1.spacing_ends[..., -1:, 0], + ray_samples_2.spacing_ends[..., -1:, 0], + ) + + bins, sorted_index = torch.sort(torch.cat([starts_1, starts_2], -1), -1) + + bins = torch.cat([bins, ends], dim=-1) + + # Stop gradients + bins = bins.detach() + + euclidean_bins = ray_samples_1.spacing_to_euclidean_fn(bins) + + ray_samples = self.get_ray_samples( + bin_starts=euclidean_bins[ + ..., :-1, None + ], # (num_rays, num_samples + num_importance, 1) + bin_ends=euclidean_bins[..., 1:, None], + spacing_starts=bins[..., :-1, None], + spacing_ends=bins[..., 1:, None], + spacing_to_euclidean_fn=ray_samples_1.spacing_to_euclidean_fn, + ) + + return ray_samples, sorted_index + + def merge_ray_samples_in_eculidean(self, ray_samples_1, ray_samples_2): + """Merge two set of ray samples and return sorted index which can be used to merge sdf values + + Args: + ray_samples_1 : ray_samples to merge + ray_samples_2 : ray_samples to merge + """ + starts_1 = ray_samples_1.frustums.starts[..., 0] + starts_2 = ray_samples_2.frustums.starts[..., 0] + + end_1 = ray_samples_1.frustums.ends[:, -1:, 0] + end_2 = 
ray_samples_2.frustums.ends[:, -1:, 0] + + end = torch.maximum(end_1, end_2) + + euclidean_bins, _ = torch.sort(torch.cat([starts_1, starts_2], -1), -1) + + euclidean_bins = torch.cat([euclidean_bins, end], dim=-1) + + # Stop gradients + euclidean_bins = euclidean_bins.detach() + + # TODO convert euclidean bins to spacing bins + bins = euclidean_bins + + ray_samples = self.get_ray_samples( + bin_starts=euclidean_bins[..., :-1, None], + bin_ends=euclidean_bins[..., 1:, None], + spacing_starts=None, + spacing_ends=None, + spacing_to_euclidean_fn=None, # near and far are different + ) + + return ray_samples + + def get_ray_samples( + self, + bin_starts, + bin_ends, + spacing_starts, + spacing_ends, + spacing_to_euclidean_fn, + ): + """Produces samples for each ray by projection points along the ray direction. Currently samples uniformly. + + Args: + bin_starts: Distance from origin to start of bin. + bin_ends: Distance from origin to end of bin. + + Returns: + Samples projected along ray. + """ + deltas = bin_ends - bin_starts + broadcast_size = [*deltas.shape[:-1], -1] + frustums = Frustums( + origins=self.origins[..., None, :].expand( + broadcast_size + ), # (num_rays, num_samples, 3) + directions=self.directions[..., None, :].expand( + broadcast_size + ), # (num_rays, num_samples, 3) + starts=bin_starts, # (num_rays, num_samples, 1) + ends=bin_ends, + ) + ray_samples = RaySamples( + frustums=frustums, + deltas=deltas, # [..., num_samples, 1] + spacing_starts=spacing_starts, # [..., num_samples, 1] + spacing_ends=spacing_ends, # [..., num_samples, 1] + spacing_to_euclidean_fn=spacing_to_euclidean_fn, + ) + + return ray_samples diff --git a/spa/models/components/model_utils/render_utils/renderers.py b/spa/models/components/model_utils/render_utils/renderers.py new file mode 100644 index 0000000..2785715 --- /dev/null +++ b/spa/models/components/model_utils/render_utils/renderers.py @@ -0,0 +1,63 @@ +import torch +from torch import nn + + +class RGBRenderer(nn.Module): + """Standard volumetic rendering.""" + + def __init__(self, background_color=(0.0, 0.0, 0.0)): + super().__init__() + self.background_color = background_color + + def forward(self, rgb, weights): + """Composite samples along ray and render color image + + Args: + rgb: RGB for each sample, (num_rays, num_samples, 3) + weights: Weights for each sample, (num_rays, num_samples, 1) + Returns: + Outputs of rgb values. + """ + comp_rgb = torch.sum(weights * rgb, dim=-2) # (num_rays, 3) + accumulated_weight = torch.sum(weights, dim=-2) + comp_rgb = comp_rgb + comp_rgb.new_tensor(self.background_color) * ( + 1.0 - accumulated_weight + ) + if not self.training: + torch.clamp_(comp_rgb, min=0.0, max=1.0) + return comp_rgb + + +class DepthRenderer(nn.Module): + """Calculate depth along ray.""" + + def __init__(self, **kwargs): + super().__init__() + + def forward(self, ray_samples, weights): + """Composite samples along ray and calculate depths. + + Args: + weights: Weights for each sample. + ray_samples: Set of ray samples. + Returns: + Outputs of depth values. 
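+ Depth is computed as the weight-normalized expectation of the sample midpoints and clipped to the sampled depth range.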
+ """ + eps = 1e-10 + steps = (ray_samples.frustums.starts + ray_samples.frustums.ends) / 2 + # steps = ray_samples.frustums.starts + depth = torch.sum(weights * steps, dim=-2) / (torch.sum(weights, -2) + eps) + depth = torch.clip(depth, steps.min(), steps.max()) + return depth + + +class OtherRenderer(nn.Module): + """Calculate normals along the ray.""" + + def __init__(self, **kwargs): + super().__init__() + + def forward(self, vals, weights): + """Calculate normals along the ray.""" + val = torch.sum(weights * vals, dim=-2) + return val diff --git a/spa/models/components/model_utils/render_utils/scene_colliders.py b/spa/models/components/model_utils/render_utils/scene_colliders.py new file mode 100644 index 0000000..87c472b --- /dev/null +++ b/spa/models/components/model_utils/render_utils/scene_colliders.py @@ -0,0 +1,142 @@ +import numpy as np +import torch +import torch.nn as nn + + +class AABBBoxCollider(nn.Module): + """Module for colliding rays with the scene box to compute near and far values. + + Args: + scene_box: scene box to apply to dataset + """ + + def __init__(self, near_plane, **kwargs): + super().__init__() + self.near_plane = near_plane + + def _intersect_with_aabb(self, rays_o, rays_d, aabb): + """Returns collection of valid rays within a specified near/far bounding box along with a mask + specifying which rays are valid + + Args: + rays_o: (num_rays, 3) ray origins, scaled + rays_d: (num_rays, 3) ray directions + aabb: (6, ) This is [min point (x,y,z), max point (x,y,z)], scaled + """ + # avoid divide by zero + dir_fraction = 1.0 / (rays_d + 1e-6) + + # x + t1 = (aabb[0] - rays_o[:, 0:1]) * dir_fraction[:, 0:1] + t2 = (aabb[3] - rays_o[:, 0:1]) * dir_fraction[:, 0:1] + # y + t3 = (aabb[1] - rays_o[:, 1:2]) * dir_fraction[:, 1:2] + t4 = (aabb[4] - rays_o[:, 1:2]) * dir_fraction[:, 1:2] + # z + t5 = (aabb[2] - rays_o[:, 2:3]) * dir_fraction[:, 2:3] + t6 = (aabb[5] - rays_o[:, 2:3]) * dir_fraction[:, 2:3] + + near = torch.max( + torch.cat( + [torch.minimum(t1, t2), torch.minimum(t3, t4), torch.minimum(t5, t6)], + dim=1, + ), + dim=1, + ).values + far = torch.min( + torch.cat( + [torch.maximum(t1, t2), torch.maximum(t3, t4), torch.maximum(t5, t6)], + dim=1, + ), + dim=1, + ).values + + # clamp to near plane + near = torch.clamp(near, min=self.near_plane) + # assert torch.all(nears < fars), "not collide with scene box" + # fars = torch.maximum(fars, nears + 1e-6) + mask = near < far + near[~mask] = 0.0 + far[~mask] = 0.0 + + return near, far, mask + + def forward(self, origins, directions, scene_bbox): + """Intersects the rays with the scene box and updates the near and far values. + Populates nears and fars fields and returns the ray_bundle. + Returns: + nears: (num_rays, 1) + fars: (num_rays, 1) + """ + near, far, mask = self._intersect_with_aabb(origins, directions, scene_bbox) + near = near[..., None] + far = far[..., None] + return near, far, mask + + +class AABBBoxColliderNp: + """Module for colliding rays with the scene box to compute near and far values. 
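+ NumPy counterpart of AABBBoxCollider, operating on arrays instead of tensors.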
+ + Args: + scene_box: scene box to apply to dataset + """ + + def __init__(self, near_plane, **kwargs): + self.near_plane = near_plane + + def _intersect_with_aabb(self, rays_o, rays_d, aabb): + """Returns collection of valid rays within a specified near/far bounding box along with a mask + specifying which rays are valid + + Args: + rays_o: (num_rays, 3) ray origins, scaled + rays_d: (num_rays, 3) ray directions + aabb: (6, ) This is [min point (x,y,z), max point (x,y,z)], scaled + """ + # avoid divide by zero + dir_fraction = 1.0 / (rays_d + 1e-6) + + # x + t1 = (aabb[0] - rays_o[:, 0:1]) * dir_fraction[:, 0:1] + t2 = (aabb[3] - rays_o[:, 0:1]) * dir_fraction[:, 0:1] + # y + t3 = (aabb[1] - rays_o[:, 1:2]) * dir_fraction[:, 1:2] + t4 = (aabb[4] - rays_o[:, 1:2]) * dir_fraction[:, 1:2] + # z + t5 = (aabb[2] - rays_o[:, 2:3]) * dir_fraction[:, 2:3] + t6 = (aabb[5] - rays_o[:, 2:3]) * dir_fraction[:, 2:3] + + near = np.max( + np.concatenate( + [np.minimum(t1, t2), np.minimum(t3, t4), np.minimum(t5, t6)], + axis=1, + ), + axis=1, + ) + far = np.min( + np.concatenate( + [np.maximum(t1, t2), np.maximum(t3, t4), np.maximum(t5, t6)], + axis=1, + ), + axis=1, + ) + + # clamp to near plane + near = np.clip(near, a_min=self.near_plane, a_max=None) + mask = near < far + near[~mask] = 0.0 + far[~mask] = 0.0 + + return near, far, mask + + def __call__(self, origins, directions, scene_bbox): + """Intersects the rays with the scene box and updates the near and far values. + Populates nears and fars fields and returns the ray_bundle. + Returns: + nears: (num_rays, 1) + fars: (num_rays, 1) + """ + near, far, mask = self._intersect_with_aabb(origins, directions, scene_bbox) + near = near[..., None] + far = far[..., None] + return near, far, mask diff --git a/spa/models/components/model_utils/transformer_utils.py b/spa/models/components/model_utils/transformer_utils.py new file mode 100644 index 0000000..0073c7f --- /dev/null +++ b/spa/models/components/model_utils/transformer_utils.py @@ -0,0 +1,314 @@ +import copy + +import numpy as np +import torch +import torch.nn as nn + +from . import attention_utils + + +def pos2embed(pos, num_pos_feats=128, reverse=False): + scale = 2 * np.pi + pos = pos * scale + dim_t = torch.arange(num_pos_feats).to(pos) + dim_t = 2 * (dim_t // 2) / num_pos_feats + 1 + split_pos = torch.split(pos, 1, dim=-1) + split_pos = [pos_t / dim_t for pos_t in split_pos] + split_pos = [ + torch.stack((pos_t[..., 0::2].sin(), pos_t[..., 1::2].cos()), dim=-1).flatten( + -2 + ) + for pos_t in split_pos + ] + if reverse: + split_pos = split_pos[::-1] + posemb = torch.cat(split_pos, dim=-1) + return posemb + + +def inverse_sigmoid(x, eps=1e-5): + """Inverse function of sigmoid. + + Args: + x (Tensor): The tensor to do the + inverse. + eps (float): EPS avoid numerical + overflow. Defaults 1e-5. + Returns: + Tensor: The x has passed the inverse + function of sigmoid, has same + shape with input. + """ + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +class FFN(nn.Module): + def __init__( + self, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + add_identity=True, + **kwargs, + ): + super(FFN, self).__init__() + assert num_fcs >= 2, "num_fcs should be no less " f"than 2. got {num_fcs}." 
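+ # (num_fcs - 1) Linear + ReLU + Dropout blocks expanding to feedforward_channels, then a final Linear back to embed_dims; a residual identity is added in forward() when add_identity is True.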
+ self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append( + nn.Sequential( + nn.Linear(in_channels, feedforward_channels), + nn.ReLU(inplace=True), + nn.Dropout(ffn_drop), + ) + ) + in_channels = feedforward_channels + layers.append(nn.Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = nn.Sequential(*layers) + self.add_identity = add_identity + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x, identity=None): + """Forward function for `FFN`. + + The function would add x to the output tensor if residue is None. + """ + out = self.layers(x) + if not self.add_identity: + return out + if identity is None: + identity = x + return identity + out + + +class BaseTransformerLayer(nn.Module): + def __init__( + self, + attn_cfgs=None, + ffn_cfgs=dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + ), + operation_order=None, + **kwargs, + ): + super(BaseTransformerLayer, self).__init__() + + num_attn = operation_order.count("self_attn") + operation_order.count( + "cross_attn" + ) + assert num_attn == len(attn_cfgs) + + self.num_attn = num_attn + self.operation_order = operation_order + self.pre_norm = operation_order[0] == "norm" + self.attentions = nn.ModuleList() + + index = 0 + for operation_name in operation_order: + if operation_name in ["self_attn", "cross_attn"]: + attn_type = attn_cfgs[index].pop("type") + attention = getattr(attention_utils, attn_type)(**attn_cfgs[index]) + # Some custom attentions used as `self_attn` + # or `cross_attn` can have different behavior. + attention.operation_name = operation_name + self.attentions.append(attention) + index += 1 + + self.embed_dims = self.attentions[0].embed_dims + + self.ffns = nn.ModuleList() + num_ffns = operation_order.count("ffn") + ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)] + assert len(ffn_cfgs) == num_ffns + for ffn_index in range(num_ffns): + if "embed_dims" not in ffn_cfgs[ffn_index]: + ffn_cfgs[ffn_index]["embed_dims"] = self.embed_dims + else: + assert ffn_cfgs[ffn_index]["embed_dims"] == self.embed_dims + self.ffns.append(FFN(**ffn_cfgs[ffn_index])) + + self.norms = nn.ModuleList() + num_norms = operation_order.count("norm") + for _ in range(num_norms): + self.norms.append(nn.LayerNorm(self.embed_dims)) + + def forward( + self, + query, + key=None, + value=None, + query_pos=None, + key_pos=None, + attn_masks=None, + query_key_padding_mask=None, + key_padding_mask=None, + **kwargs, + ): + """Forward function for `TransformerDecoderLayer`. + + **kwargs contains some specific arguments of attentions. + + Args: + query (Tensor): The input query with shape + [num_queries, bs, embed_dims] if + self.batch_first is False, else + [bs, num_queries embed_dims]. + key (Tensor): The key tensor with shape [num_keys, bs, + embed_dims] if self.batch_first is False, else + [bs, num_keys, embed_dims] . + value (Tensor): The value tensor with same shape as `key`. + query_pos (Tensor): The positional encoding for `query`. + Default: None. + key_pos (Tensor): The positional encoding for `key`. + Default: None. + attn_masks (List[Tensor] | None): 2D Tensor used in + calculation of corresponding attention. 
The length of + it should equal to the number of `attention` in + `operation_order`. Default: None. + query_key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_queries]. Only used in `self_attn` layer. + Defaults to None. + key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_keys]. Default: None. + + Returns: + Tensor: forwarded results with shape [num_queries, bs, embed_dims]. + """ + + norm_index = 0 + attn_index = 0 + ffn_index = 0 + identity = query + if attn_masks is None: + attn_masks = [None for _ in range(self.num_attn)] + + assert len(attn_masks) == self.num_attn + + for layer in self.operation_order: + if layer == "self_attn": + temp_key = temp_value = query + query = self.attentions[attn_index]( + query, + temp_key, + temp_value, + identity if self.pre_norm else None, + query_pos=query_pos, + key_pos=query_pos, + attn_mask=attn_masks[attn_index], + key_padding_mask=query_key_padding_mask, + **kwargs, + ) + attn_index += 1 + identity = query + + elif layer == "norm": + query = self.norms[norm_index](query) + norm_index += 1 + + elif layer == "cross_attn": + query = self.attentions[attn_index]( + query, + key, + value, + identity if self.pre_norm else None, + query_pos=query_pos, + key_pos=key_pos, + attn_mask=attn_masks[attn_index], + key_padding_mask=key_padding_mask, + **kwargs, + ) + attn_index += 1 + identity = query + + elif layer == "ffn": + query = self.ffns[ffn_index](query, identity if self.pre_norm else None) + ffn_index += 1 + + return query + + +class TransformerLayerSequence(nn.Module): + def __init__(self, transformerlayers=None, num_layers=None): + super(TransformerLayerSequence, self).__init__() + + transformerlayers = [ + copy.deepcopy(transformerlayers) for _ in range(num_layers) + ] + + self.num_layers = num_layers + self.layers = nn.ModuleList() + for i in range(num_layers): + self.layers.append(BaseTransformerLayer(**transformerlayers[i])) + self.embed_dims = self.layers[0].embed_dims + self.pre_norm = self.layers[0].pre_norm + + def forward( + self, + query, + key, + value, + query_pos=None, + key_pos=None, + attn_masks=None, + query_key_padding_mask=None, + key_padding_mask=None, + **kwargs, + ): + """Forward function for `TransformerCoder`. + + Args: + query (Tensor): Input query with shape + `(num_queries, bs, embed_dims)`. + key (Tensor): The key tensor with shape + `(num_keys, bs, embed_dims)`. + value (Tensor): The value tensor with shape + `(num_keys, bs, embed_dims)`. + query_pos (Tensor): The positional encoding for `query`. + Default: None. + key_pos (Tensor): The positional encoding for `key`. + Default: None. + attn_masks (List[Tensor], optional): Each element is 2D Tensor + which is used in calculation of corresponding attention in + operation_order. Default: None. + query_key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_queries]. Only used in self-attention + Default: None. + key_padding_mask (Tensor): ByteTensor for `query`, with + shape [bs, num_keys]. Default: None. + + Returns: + Tensor: results with shape [num_queries, bs, embed_dims]. 
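+ Layers are applied sequentially: the query produced by one layer becomes the query of the next, while key, value and the attention masks are shared across all layers.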
+ """ + for layer in self.layers: + query = layer( + query, + key, + value, + query_pos=query_pos, + key_pos=key_pos, + attn_masks=attn_masks, + query_key_padding_mask=query_key_padding_mask, + key_padding_mask=key_padding_mask, + **kwargs, + ) + return query diff --git a/spa/models/components/model_utils/unet3d_utils.py b/spa/models/components/model_utils/unet3d_utils.py new file mode 100644 index 0000000..0bdfb3d --- /dev/null +++ b/spa/models/components/model_utils/unet3d_utils.py @@ -0,0 +1,339 @@ +from functools import partial + +import torch +import torch.nn as nn +from timm.models.layers import trunc_normal_ + + +class BasicBlock3D(nn.Module): + expansion = 1 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=False, + group_norm=False, + ): + super().__init__() + + self.downsample = downsample + self.conv1 = nn.Conv3d(inplanes, planes, 3, stride, padding=1, bias=False) + self.bn1 = ( + nn.BatchNorm3d(planes, eps=1e-3, momentum=0.01) + if not group_norm + else nn.GroupNorm( + num_groups=planes // 16, + num_channels=planes, + ) + ) + self.conv2 = nn.Conv3d(planes, planes, 3, stride=1, padding=1, bias=False) + self.bn2 = ( + nn.BatchNorm3d(planes, eps=1e-3, momentum=0.01) + if not group_norm + else nn.GroupNorm( + num_groups=planes // 16, + num_channels=planes, + ) + ) + self.relu = nn.GELU() # nn.ReLU() + self.downsample = downsample + if self.downsample: + self.downsample_layer = nn.Sequential( + nn.Conv3d( + inplanes, + planes, + 1, + stride=stride, + padding=0, + bias=False, + ), + ( + nn.BatchNorm3d(planes, eps=1e-3, momentum=0.01) + if not group_norm + else nn.GroupNorm(num_groups=planes // 16, num_channels=planes) + ), + ) + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample: + identity = self.downsample_layer(x) + + out = out + identity + out = self.relu(out) + + return out + + +class SimpleConv3D(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size=3, padding=1, stride=1, **kwargs + ): + super(SimpleConv3D, self).__init__() + self.conv = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=padding, + stride=stride, + ), + nn.BatchNorm3d(out_channels), + # nn.ReLU(inplace=True), + nn.GELU(), + ) + + def forward(self, x, **kwargs): + outs = [] + outs.append(self.conv(x)) + return outs + + +class BEV23DConv3D(nn.Module): + def __init__( + self, + in_channels, + mid_channels, + out_channels, + grid_size, + kernel_size=3, + padding=1, + stride=1, + **kwargs + ): + super(BEV23DConv3D, self).__init__() + self.grid_size = grid_size + self.mid_channels = mid_channels + self.conv2d = nn.Sequential( + nn.Conv2d( + in_channels, + mid_channels * grid_size[2], + kernel_size=1, + padding=0, + stride=1, + ), + nn.BatchNorm2d(mid_channels * grid_size[2]), + # nn.ReLU(inplace=True), + nn.GELU(), + ) + self.conv3d = nn.Sequential( + nn.Conv3d( + mid_channels, + out_channels, + kernel_size=kernel_size, + padding=padding, + stride=stride, + ), + nn.BatchNorm3d(out_channels), + # nn.ReLU(inplace=True), + nn.GELU(), + ) + + def forward(self, x, **kwargs): + outs = [] + x = self.conv2d(x) + # (B, C, Y, X) -> (B, C, Z, Y, X) + x = x.view(x.shape[0], -1, *torch.flip(self.grid_size, dims=[-1])) + outs.append(self.conv3d(x)) + return outs + + +class BEV23DExpConv3D(nn.Module): + def __init__( + self, + in_channels, + mid_channels, + grid_size, + sdf_channels, + 
semantic_channels, + kernel_size=3, + padding=1, + stride=1, + **kwargs + ): + super(BEV23DExpConv3D, self).__init__() + self.grid_size = grid_size + self.mid_channels = mid_channels + self.conv2d = nn.Sequential( + nn.Conv2d( + in_channels, + mid_channels * grid_size[2], + kernel_size=1, + padding=0, + stride=1, + ), + nn.BatchNorm2d(mid_channels * grid_size[2]), + # nn.ReLU(inplace=True), + nn.GELU(), + ) + self.shared_conv3d = BasicBlock3D(mid_channels, mid_channels, downsample=False) + self.sdf_conv3d = nn.Sequential( + nn.Conv3d( + mid_channels, + mid_channels, + kernel_size=kernel_size, + padding=padding, + stride=stride, + ), + nn.BatchNorm3d(mid_channels), + # nn.Softplus(beta=100), + nn.GELU(), + nn.Conv3d( + mid_channels, + sdf_channels, + kernel_size=1, + ), + ) + self.semantic_conv3d = nn.Sequential( + nn.Conv3d( + mid_channels + sdf_channels - 1, + mid_channels, + kernel_size=kernel_size, + padding=padding, + stride=stride, + ), + nn.BatchNorm3d(mid_channels), + # nn.ReLU(inplace=True), + nn.GELU(), + nn.Conv3d( + mid_channels, + semantic_channels, + kernel_size=1, + ), + ) + + def forward(self, x, **kwargs): + x = self.conv2d(x) + # (B, C, Y, X) -> (B, C, Z, Y, X) + assert x.shape[-1] == self.grid_size[0] and x.shape[-2] == self.grid_size[1] + x = x.view(x.shape[0], -1, *torch.flip(self.grid_size, dims=[-1])) + x = self.shared_conv3d(x) + sdf = self.sdf_conv3d(x) + semantic = self.semantic_conv3d(torch.cat([x, sdf[:, 1:]], dim=1)) + outs = [sdf[:, :1], semantic] + return outs + + +class ExpConv3D(nn.Module): + def __init__( + self, + in_channels, + mid_channels, + sdf_channels, + rgb_channels=None, + semantic_channels=None, + kernel_size=3, + padding=1, + stride=1, + group_norm=False, + **kwargs + ): + super(ExpConv3D, self).__init__() + + self.shared_conv3d = BasicBlock3D( + in_channels, mid_channels, downsample=True, group_norm=group_norm + ) + self.sdf_conv3d = nn.Sequential( + nn.Conv3d( + mid_channels, + mid_channels, + kernel_size=kernel_size, + padding=padding, + stride=stride, + ), + ( + nn.BatchNorm3d(mid_channels) + if not group_norm + else nn.GroupNorm( + num_groups=mid_channels // 16, + num_channels=mid_channels, + ) + ), + nn.Softplus(beta=100), + nn.Conv3d( + mid_channels, + sdf_channels, + kernel_size=1, + ), + ) + self.rgb_channels = rgb_channels + if self.rgb_channels is not None: + self.rgb_conv3d = nn.Sequential( + nn.Conv3d( + mid_channels + sdf_channels - 1, + mid_channels, + kernel_size=kernel_size, + padding=padding, + stride=stride, + ), + ( + nn.BatchNorm3d(mid_channels) + if not group_norm + else nn.GroupNorm( + num_groups=mid_channels // 16, + num_channels=mid_channels, + ) + ), + nn.GELU(), + nn.Conv3d( + mid_channels, + rgb_channels, + kernel_size=1, + ), + ) + self.semantic_channels = semantic_channels + if self.semantic_channels is not None: + self.semantic_conv3d = nn.Sequential( + nn.Conv3d( + mid_channels + sdf_channels - 1, + mid_channels, + kernel_size=kernel_size, + padding=padding, + stride=stride, + ), + ( + nn.BatchNorm3d(mid_channels) + if not group_norm + else nn.GroupNorm( + num_groups=mid_channels // 16, + num_channels=mid_channels, + ) + ), + nn.GELU(), + nn.Conv3d( + mid_channels, + semantic_channels, + kernel_size=1, + ), + ) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv3d, nn.Linear)): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, **kwargs): + x = self.shared_conv3d(x) + sdf = self.sdf_conv3d(x) + outs = 
[sdf[:, :1]] + if self.rgb_channels is not None: + rgb = self.rgb_conv3d(torch.cat([x, sdf[:, 1:]], dim=1)) + outs.append(rgb) + if self.semantic_channels is not None: + semantic = self.semantic_conv3d(torch.cat([x, sdf[:, 1:]], dim=1)) + outs.append(semantic) + return outs diff --git a/spa/models/components/spa.py b/spa/models/components/spa.py new file mode 100644 index 0000000..f5a676a --- /dev/null +++ b/spa/models/components/spa.py @@ -0,0 +1,279 @@ +import os +import re + +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import spa.utils as U +from spa.data.components.processor.data_processor_gpu import DataProcessorGPU + +logger = U.RankedLogger(__name__, rank_zero_only=True) + + +def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False): + # features: (N, C) + # m: a hyperparam controlling how many std dev outside for outliers + assert len(features.shape) == 2, "features should be (N, C)" + reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2] + colors = features @ reduction_mat + if remove_first_component: + colors_min = colors.min(dim=0).values + colors_max = colors.max(dim=0).values + tmp_colors = (colors - colors_min) / (colors_max - colors_min) + fg_mask = tmp_colors[..., 0] < 0.2 + reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2] + colors = features @ reduction_mat + else: + fg_mask = torch.ones_like(colors[:, 0]).bool() + d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values) + mdev = torch.median(d, dim=0).values + s = d / mdev + try: + rins = colors[fg_mask][s[:, 0] < m, 0] + gins = colors[fg_mask][s[:, 1] < m, 1] + bins = colors[fg_mask][s[:, 2] < m, 2] + rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()]) + rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()]) + except: + rins = colors + gins = colors + bins = colors + rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()]) + rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()]) + + return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat) + + +def get_pca_map( + feature_map: torch.Tensor, + return_pca_stats=False, + pca_stats=None, +): + """ + feature_map: (1, h, w, C) is the feature map of a single image. 
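+ Returns an (h, w, 3) array of PCA colors in [0, 1]; when return_pca_stats is True, the PCA statistics (projection matrix and color range) are returned as well so they can be reused on other feature maps.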
+ """ + if feature_map.shape[0] != 1: + # make it (1, h, w, C) + feature_map = feature_map[None] + if pca_stats is None: + reduct_mat, color_min, color_max = get_robust_pca( + feature_map.reshape(-1, feature_map.shape[-1]) + ) + else: + reduct_mat, color_min, color_max = pca_stats + pca_color = feature_map @ reduct_mat + pca_color = (pca_color - color_min) / (color_max - color_min) + pca_color = pca_color.clamp(0, 1) + pca_color = pca_color.cpu().numpy().squeeze(0) + if return_pca_stats: + return pca_color, (reduct_mat, color_min, color_max) + return pca_color + + +class SPA(nn.Module): + def __init__( + self, + fp16_enabled_layers=[], + img_backbone=None, + view_transform=None, + dense_head=None, + data_processor_cfg={}, + ckpt_name=None, + ): + super().__init__() + + self.logger = logger + + self.fp16_enabled_layers = fp16_enabled_layers + self.img_backbone = img_backbone + self.view_transform = view_transform + self.dense_head = dense_head + self.data_processor_cfg = data_processor_cfg + self.data_processor = None + self.init_data_processor() + + if ckpt_name is not None: + self.load_pretrained(ckpt_name) + + self.logger.info("----------- FP16 Enabled Status -----------") + for module_name in self.fp16_enabled_layers: + getattr(self, module_name).fp16_enabled = True + self.logger.info( + f"{module_name}: {getattr(self, module_name).fp16_enabled}" + ) + + def load_pretrained(self, ckpt_name: str = None): + assert ckpt_name in [ + "spa-l", + "spa-b", + ], f"`ckpt_name` should be 'spa-l' or 'spa-b', got {ckpt_name}" + + from huggingface_hub import hf_hub_download + + try: + import safetensors.torch + + _has_safetensors = True + except ImportError: + _has_safetensors = False + + if _has_safetensors: + from safetensors.torch import load_file + + ckpt_file = hf_hub_download( + repo_id="HaoyiZhu/SPA", filename=f"{ckpt_name}.safetensors" + ) + state_dict = load_file(ckpt_file) + else: + ckpt_file = hf_hub_download( + repo_id="HaoyiZhu/SPA", filename=f"{ckpt_name}.ckpt" + ) + state_dict = torch.load(ckpt_file)["state_dict"] + + self.load_state_dict(state_dict, strict=True) + + def init_data_processor(self): + self.data_processor = DataProcessorGPU( + self.data_processor_cfg, mode=self.mode, logger=self.logger + ) + + @property + def mode(self): + return "train" if self.training else "test" + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super().train(mode) + self.init_data_processor() + + @torch.amp.autocast( + "cuda", enabled=False + ) # manually enable half precision in each module + def forward(self, batch_dict): + batch_dict = self.data_processor.forward(batch_dict=batch_dict) + batch_dict = self.dense_head(self.view_transform(self.img_backbone(batch_dict))) + + render_out = batch_dict.pop("render_out") + ray_dict = batch_dict.pop("ray_dict") + loss, loss_dict = self.dense_head.get_loss(render_out, ray_dict) + + out_dict = dict(loss=loss, **loss_dict) + + if not self.training: + out_dict.update(self.prepare_visualize(render_out, batch_dict, ray_dict)) + + return out_dict + + @torch.no_grad() + def prepare_visualize(self, render_out, data_dict, ray_dict): + W, H = ( + int(data_dict["depth"][0].shape[-1]), + int(data_dict["depth"][0].shape[-2]), + ) + + gt_img = ray_dict[0]["rgb"].reshape(-1, W, 3) * 255.0 + pred_img = render_out[0]["rgb"].reshape(-1, W, 3) * 255.0 + + gt_depth = ray_dict[0]["depth"].reshape(-1, W) + pred_depth = render_out[0]["depth"].reshape(-1, W) + + pred_normal = ( + render_out[0]["normal"].float().reshape(-1, W, 
3) * 127.5 + 127.5 + ).clip(0, 255) + + if "masked_inputs" in data_dict and data_dict["masked_inputs"] is not None: + img_norm_cfg = self.model_cfg.img_norm_cfg + masked_inputs = einops.rearrange( + data_dict["masked_inputs"], "n c h w -> (n h) w c", c=3, w=W + ) + mask = 1.0 - (masked_inputs == 0).float() + masked_inputs = ( + ( + masked_inputs + * torch.FloatTensor(img_norm_cfg["std"]) + .to(masked_inputs.device) + .reshape(1, 1, 3) + + torch.FloatTensor(img_norm_cfg["mean"]) + .to(masked_inputs.device) + .reshape(1, 1, 3) + ) + * mask + * 255.0 + ) + + if "mae_pred" in data_dict: + mae_pred = einops.rearrange( + data_dict["mae_pred"], "n c h w -> (n h) w c", c=3, w=W + ) + mae_pred = ( + mae_pred + * torch.FloatTensor(img_norm_cfg["std"]) + .to(mae_pred.device) + .reshape(1, 1, 3) + + torch.FloatTensor(img_norm_cfg["mean"]) + .to(mae_pred.device) + .reshape(1, 1, 3) + ) * 255.0 + mask = einops.repeat(data_dict["mask"], "n 1 h w -> (n h) w 3", w=W) + paste_img = masked_inputs * (1 - mask) + mae_pred * mask + gt_img = torch.cat([gt_img, masked_inputs, mae_pred, paste_img], dim=1) + else: + gt_img = torch.cat([gt_img, masked_inputs], dim=1) + + if "semantic" in render_out[0]: + semantic_pred = render_out[0]["semantic"] + semantic_gt = ray_dict[0]["semantic"] + similarity = torch.cosine_similarity( + F.normalize(semantic_pred, dim=-1).detach().clone(), + F.normalize(semantic_gt, dim=-1).detach().clone(), + dim=-1, + ).reshape(-1, W) + semantic_gt_pca = get_pca_map( + einops.rearrange( + semantic_gt, + "(b h w) c -> 1 (b h) w c", + h=H, + w=W, + c=semantic_gt.shape[-1], + ) + ) + semantic_pred_pca = get_pca_map( + einops.rearrange( + semantic_pred, + "(b h w) c -> 1 (b h) w c", + h=H, + w=W, + c=semantic_pred.shape[-1], + ) + ) + return dict( + gt_img=gt_img.cpu().numpy(), + pred_img=pred_img.cpu().numpy(), + gt_depth=gt_depth.cpu().numpy(), + pred_depth=pred_depth.cpu().numpy(), + pred_normal=pred_normal.cpu().numpy(), + similarity=similarity.cpu().numpy(), + semantic_gt_pca=semantic_gt_pca, + semantic_pred_pca=semantic_pred_pca, + ) + + return dict( + gt_img=gt_img.cpu().numpy(), + pred_img=pred_img.cpu().numpy(), + gt_depth=gt_depth.cpu().numpy(), + pred_depth=pred_depth.cpu().numpy(), + pred_normal=pred_normal.cpu().numpy(), + ) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r"^img_backbone.cls_token|img_backbone.pos_embed|img_backbone.patch_embed", # stem and embed + blocks=[ + (r"^img_backbone.blocks\.(\d+)", None), + (r"^img_backbone.norm", (99999,)), + ], + ) diff --git a/spa/models/components/view_transforms/__init__.py b/spa/models/components/view_transforms/__init__.py new file mode 100644 index 0000000..1ff707a --- /dev/null +++ b/spa/models/components/view_transforms/__init__.py @@ -0,0 +1 @@ +from .lss_voxelformer import LSSVoxelformer diff --git a/spa/models/components/view_transforms/lss_voxelformer.py b/spa/models/components/view_transforms/lss_voxelformer.py new file mode 100644 index 0000000..c5d30f4 --- /dev/null +++ b/spa/models/components/view_transforms/lss_voxelformer.py @@ -0,0 +1,176 @@ +import numpy as np +import torch +from torch import nn + +from ..model_utils.transformer_utils import TransformerLayerSequence + + +class LearnablePositionalEncoding(nn.Module): + def __init__(self, input_channel, embed_dims=256): + super().__init__() + self.position_embedding = nn.Sequential( + nn.Linear(input_channel, embed_dims), + nn.BatchNorm1d(embed_dims), + nn.ReLU(inplace=True), + nn.Linear(embed_dims, embed_dims), + ) + 
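+        # A small MLP that maps normalized (x, y, z) voxel centers to
+        # `embed_dims`-dimensional query embeddings; weights are
+        # Xavier-initialized in `init_weights` below.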
self.init_weights() + + def init_weights(self): + """Initialize the transformer weights.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, xyz): + position_embedding = self.position_embedding(xyz) + return position_embedding + + +class LSSVoxelformer(nn.Module): + def __init__( + self, in_channels, grid_size, feature_map_stride, transformer_cfg, **kwargs + ): + super().__init__() + self.grid_size = np.array(grid_size, dtype=np.int64) // feature_map_stride + self.feature_map_stride = feature_map_stride + self.in_channels = in_channels + self.register_buffer("voxels", self.create_voxels()) + self.position_encoding = LearnablePositionalEncoding( + 3, embed_dims=self.in_channels + ) + self.decoder = TransformerLayerSequence(**transformer_cfg) + + def create_voxels(self): + zs, ys, xs = torch.meshgrid( + *[torch.arange(0, gs, dtype=torch.float) for gs in self.grid_size[::-1]], + indexing="ij" + ) + # [0, grid_size-1] + voxels = torch.stack([xs, ys, zs], dim=-1) # (Z, Y, X, 3) + return voxels + + def prepare_voxels(self, voxel_size, point_cloud_range): + voxel_coords = self.voxels + voxel_size = voxel_coords.new_tensor(voxel_size) + point_cloud_range = voxel_coords.new_tensor(point_cloud_range) + voxel_coords = (voxel_coords + 0.5) * voxel_size + + voxel_embeds = voxel_coords / (point_cloud_range[3:6] - point_cloud_range[:3]) + assert torch.all((voxel_embeds < 1.0) & (voxel_embeds > 0.0)) + Z, Y, X = voxel_embeds.shape[:3] + # (Z, Y, X, C) + voxel_embeds = self.position_encoding(voxel_embeds.view(Z * Y * X, -1)) + voxel_embeds = voxel_embeds.view(Z, Y, X, -1) + + voxel_coords = voxel_coords + point_cloud_range[:3] + return voxel_coords, voxel_embeds + + def transform_voxels(self, voxel_coords, img, world2cam, cam2img): + # (Z, Y, X, 4) + voxel_coords = torch.cat( + [voxel_coords, torch.ones_like(voxel_coords[..., :1])], dim=-1 + ) + + # (N, 4, 4) + world2img = cam2img @ world2cam + + # (N, 1, 1, 1, 4, 4) @ (1, Z, Y, X, 4, 1) -> (N, Z, Y, X, 3) + voxel_cam_coords = ( + world2img[:, None, None, None] @ voxel_coords[None, ..., None] + )[..., :3, 0] + + eps = 1e-5 + voxel_cam_depths = voxel_cam_coords[..., 2].clone() + mask = voxel_cam_depths > eps + # (N, Z, Y, X, 2) + voxel_cam_coords = voxel_cam_coords[..., :2] / torch.maximum( + voxel_cam_depths, torch.ones_like(voxel_cam_depths) * eps + ).unsqueeze(-1) + + H, W = img.shape[-2:] + voxel_cam_coords[..., 0] /= W + voxel_cam_coords[..., 1] /= H + + mask &= ( + (voxel_cam_coords[..., 0] > 0) + & (voxel_cam_coords[..., 0] < 1) + & (voxel_cam_coords[..., 1] > 0) + & (voxel_cam_coords[..., 1] < 1) + ) + + return voxel_cam_coords, voxel_cam_depths, mask + + def inner_forward(self, mlvl_feats, voxel_cam_coords, voxel_embeds, mask): + feat_flatten = [] + spatial_shapes = [] + for lvl, feat in enumerate(mlvl_feats): + _, _, H, W = feat.shape + spatial_shape = (H, W) + # (N1+N2+..., C, H, W) -> (N1+N2+..., C, H*W) -> (N1+N2+..., H*W, C) + feat = feat.flatten(-2).permute(0, 2, 1).contiguous() + spatial_shapes.append(spatial_shape) + feat_flatten.append(feat) + + # (N1+N2+..., H1*W1+H2*W2+..., C) + feat_flatten = torch.cat(feat_flatten, dim=1) + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=feat_flatten.device + ) # (num_level, 2) + level_start_index = torch.cat( + [ + spatial_shapes.new_zeros((1,)), + spatial_shapes.prod(dim=1).cumsum(dim=0)[:-1], + ], + dim=0, + ) # (num_level,), [0, H1*W1, H1*W1+H2*W2, ...] 
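+        # `spatial_shapes` and `level_start_index` let the decoder locate each
+        # pyramid level inside the flattened token sequence, while
+        # `reference_points` are the voxel centers projected to normalized
+        # image coordinates (the usual deformable-attention interface).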
+ + voxel_features = self.decoder( + query=voxel_embeds, + value=feat_flatten, + key=feat_flatten, + reference_points=voxel_cam_coords, + query_mask=mask, + level_start_index=level_start_index, + spatial_shapes=spatial_shapes, + ) + voxel_features = voxel_features.permute(0, 4, 1, 2, 3).contiguous() + return voxel_features + + def forward(self, batch_dict): + # [(N1+N2+..., C, H, W), ...] + img_features = batch_dict["img_features"] + batch_size = batch_dict["batch_size"] + world2cam = batch_dict["world2cam"] + cam2img = batch_dict["cam2img"] + img = batch_dict["img"] + voxel_size = batch_dict["voxel_size"] + point_cloud_range = batch_dict["point_cloud_range"] + + voxel_embeds, voxel_cam_coords, mask = ( + [], + [], + [], + ) + for bidx in range(batch_size): + vs = voxel_size[bidx] * self.feature_map_stride + pcr = point_cloud_range[bidx] + vc, ve = self.prepare_voxels(vs, pcr) + + l2c = world2cam[bidx] + c2i = cam2img[bidx] + vcc, _, m = self.transform_voxels(vc, img[bidx], l2c, c2i) + + voxel_embeds.append(ve) + voxel_cam_coords.append(vcc) + mask.append(m) + # (B, Z, Y, X, C) + voxel_embeds = torch.stack(voxel_embeds, dim=0) + + # (B, C, Z, Y, X) + encoded_spconv_tensor = self.inner_forward( + img_features, voxel_cam_coords, voxel_embeds, mask + ) + batch_dict["encoded_spconv_tensor"] = encoded_spconv_tensor + batch_dict["encoded_spconv_tensor_stride"] = self.feature_map_stride + return batch_dict diff --git a/spa/models/spa_pretrain_module.py b/spa/models/spa_pretrain_module.py new file mode 100755 index 0000000..b8ef051 --- /dev/null +++ b/spa/models/spa_pretrain_module.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +from typing import Any, Dict, Tuple + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.cm as cm +import matplotlib.pyplot as plt +import numpy as np +import torch +from lightning import LightningModule + +from spa import utils as U +from spa.utils import RankedLogger + +log = RankedLogger(__name__, rank_zero_only=True) + + +class SPAPretrainModule(LightningModule): + def __init__( + self, + model, + optimizer, + lr_scheduler, + train_metrics, + val_metrics, + best_val_metrics, + compile: bool = False, + **kwargs, + ) -> None: + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters( + logger=False, + ignore=["model", "train_metrics", "val_metrics", "best_val_metrics"], + ) + + self.model = model + + # metric objects for calculating and averaging accuracy across batches + self.train_metrics = train_metrics + self.val_metrics = val_metrics + + # for tracking best so far validation metrics + self.best_val_metrics = best_val_metrics + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + def on_train_start(self) -> None: + """Lightning hook that is called when training begins.""" + # by default lightning executes validation step sanity checks before training starts, + # so it's worth to make sure validation metrics don't store results from these checks + self.train_metrics.reset() + self.val_metrics.reset() + self.best_val_metrics.reset() + + def model_step( + self, batch: tuple[torch.Tensor, torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return self.model(batch) + + def on_train_epoch_start(self) -> None: + super().on_train_epoch_start() + self.model.train() + + def training_step( + self, + batch, + batch_idx, + ) -> torch.Tensor: + if not isinstance(batch, list): + batch = [batch] 
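+        # When several dataloaders are combined, `batch` arrives as a list;
+        # exactly one entry is expected to be non-None per step, and that
+        # entry is used as the batch (enforced by the assert below).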
+ + count = 0 + dst_batch = None + dataloader_idx = 0 + for i, x in enumerate(batch): + if x is not None: + count += 1 + dst_batch = x + dataloader_idx = i + batch = dst_batch + assert count == 1, "Only one dataset is allowed for each iteration" + + loss_dict = self.model_step(batch) + + # update and log metrics + self.train_metrics(loss_dict) + batch_size = len(batch["img"]) + self.log_dict( + self.train_metrics.metrics_dict(), + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + batch_size=batch_size, + ) + + # log global_step + self.log( + "global_step", + self.global_step, + on_step=True, + on_epoch=False, + prog_bar=False, + logger=False, + ) + + # return loss or backpropagation will fail + return loss_dict["loss"] + + def on_train_epoch_end(self) -> None: + "Lightning hook that is called when a training epoch ends." + return super().on_train_epoch_end() + + def on_validation_epoch_start(self) -> None: + super().on_validation_epoch_start() + self.model.eval() + + def validation_step( + self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> None: + loss_dict = self.model_step(batch) + + # update and log metrics + self.val_metrics(loss_dict) + batch_size = len(batch["img"]) + self.log_dict( + self.val_metrics.metrics_dict(), + on_step=False, + on_epoch=True, + prog_bar=True, + sync_dist=True, + batch_size=batch_size, + ) + + # check input + def project(xyz, K, R, T): + xyz = np.dot(xyz, R.T) + T.T + xyz = np.dot(xyz, K.T) + xy = xyz[:, :2] / xyz[:, 2:] + return xy + + # visualize + gt_img = loss_dict["gt_img"] + pred_img = loss_dict["pred_img"] + gt_depth = loss_dict["gt_depth"] + pred_depth = loss_dict["pred_depth"] + if "pred_normal" not in loss_dict: + pred_normal = np.zeros_like(gt_img) + else: + pred_normal = loss_dict["pred_normal"] + + img_tfboard = np.concatenate( + [gt_img, pred_img, pred_normal], axis=1 + ) # (H, W, 3) + img_tfboard /= 255.0 + + self.logger.experiment.add_images( + f"val/img_tfboard_{batch_idx}", + img_tfboard, + global_step=self.current_epoch, + dataformats="HWC", + ) + + def min_max_normalize(data): + return ( + (data - data[data > 0].min()) + / (data[data > 0].max() - data[data > 0].min()) + ).clip(0.0, 1.0) + + depth_tfboard = np.concatenate([gt_depth, pred_depth], axis=1) + depth_tfboard = cm.bwr(min_max_normalize(depth_tfboard / 2.0).clip(0.0, 1.0)) + # depth_tfboard = depth_tfboard.transpose(2, 0, 1) + self.logger.experiment.add_images( + f"val/depth_tfboard_{batch_idx}", + depth_tfboard, + global_step=self.current_epoch, + dataformats="HWC", + ) + + if "similarity" in loss_dict: + similarity = loss_dict["similarity"] + self.logger.experiment.add_images( + f"val/semantic_similarity_map_{batch_idx}", + similarity, + global_step=self.current_epoch, + dataformats="HW", + ) + if "semantic_gt_pca" in loss_dict: + semantic_gt_pca = loss_dict["semantic_gt_pca"] + semantic_pred_pca = loss_dict["semantic_pred_pca"] + + semantic_pca = np.concatenate( + [ + semantic_gt_pca, + semantic_pred_pca, + ], + axis=1, + ) + self.logger.experiment.add_images( + f"val/semantic_pca_{batch_idx}", + semantic_pca, + global_step=self.current_epoch, + dataformats="HWC", + ) + + def on_validation_epoch_end(self) -> None: + "Lightning hook that is called when a validation epoch ends." 
+ metrics = self.val_metrics.compute() # get current val metrics + self.best_val_metrics(metrics) # update best so far val metrics + # log `best_val_metrics` as a value through `.compute()` method, instead of as a metric object + # otherwise metric would be reset by lightning after each epoch + self.log_dict(self.best_val_metrics.compute(), sync_dist=True, prog_bar=True) + + def test_step( + self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> None: + raise NotImplementedError + + def on_test_epoch_end(self) -> None: + """Lightning hook that is called when a test epoch ends.""" + pass + + def setup(self, stage: str) -> None: + if self.hparams.compile and stage == "fit": + self.model = torch.compile(self.model) + + def configure_optimizers(self) -> dict[str, Any]: + optimizer = U.build_optimizer(self.hparams.optimizer, self.model) + if self.hparams.lr_scheduler is not None: + self.hparams.lr_scheduler.scheduler.total_steps = ( + self.trainer.estimated_stepping_batches + ) + scheduler = U.build_scheduler( + self.hparams.lr_scheduler.scheduler, optimizer=optimizer + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": self.hparams.lr_scheduler.get("monitor", "val/loss"), + "interval": self.hparams.lr_scheduler.get("interval", "step"), + "frequency": self.hparams.lr_scheduler.get("frequency", 1), + }, + } + return {"optimizer": optimizer} diff --git a/spa/train.py b/spa/train.py new file mode 100755 index 0000000..2a6562a --- /dev/null +++ b/spa/train.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple + +import hydra +import lightning as L +import rootutils +import torch +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig, OmegaConf + +OmegaConf.register_new_resolver("eval", eval) + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. `from spa import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/rootutils +# ------------------------------------------------------------------------------------ # + +from spa.utils import ( + RankedLogger, + extras, + get_metric_value, + instantiate_callbacks, + instantiate_loggers, + log_hyperparameters, + task_wrapper, +) + +log = RankedLogger(__name__, rank_zero_only=True) + + +@task_wrapper +def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + + :param cfg: A DictConfig configuration composed by Hydra. 
+ :return: A tuple with metrics and dict with all instantiated objects. + """ + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating callbacks...") + callbacks: list[Callback] = instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: list[Logger] = instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate( + cfg.trainer, callbacks=callbacks, logger=logger + ) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + log_hyperparameters(object_dict) + + if cfg.get("train"): + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") +def main(cfg: DictConfig) -> Optional[float]: + """Main entry point for training. + + :param cfg: DictConfig configuration composed by Hydra. + :return: Optional[float] with optimized metric value. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
+ extras(cfg) + + # train the model + metric_dict, _ = train(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + metric_value = get_metric_value( + metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") + ) + + # return optimized metric + return metric_value + + +if __name__ == "__main__": + main() diff --git a/spa/utils/__init__.py b/spa/utils/__init__.py new file mode 100755 index 0000000..7cf335b --- /dev/null +++ b/spa/utils/__init__.py @@ -0,0 +1,23 @@ +import spa.utils.io as io_utils +from spa.utils.dist import ( + get_rank, + get_world_size, + is_dist_avail_and_initialized, + is_main_process, +) +from spa.utils.instantiators import instantiate_callbacks, instantiate_loggers +from spa.utils.logging_utils import log_hyperparameters +from spa.utils.misc import ( + import_modules_from_strings, + interpolate_linear, + is_seq_of, + make_dirs, +) +from spa.utils.pylogger import RankedLogger +from spa.utils.registry import Registry, build_from_cfg +from spa.utils.rich_utils import enforce_tags, print_config_tree +from spa.utils.utils import extras, get_metric_value, task_wrapper + +from .optimizer import build_optimizer +from .scheduler import build_scheduler +from .transforms import components_from_spherical_harmonics diff --git a/spa/utils/callbacks.py b/spa/utils/callbacks.py new file mode 100644 index 0000000..2b5fd41 --- /dev/null +++ b/spa/utils/callbacks.py @@ -0,0 +1,444 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import copy +import os +import threading +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import lightning.pytorch as pl +import torch +from lightning import Callback +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.rank_zero import rank_zero_info + + +class EMA(Callback): + """ + Implements Exponential Moving Averaging (EMA). + + When training a model, this callback will maintain moving averages of the trained parameters. + When evaluating, we use the moving averages copy of the trained parameters. + When saving, we save an additional set of parameters with the prefix `ema`. + + Args: + decay: The exponential decay used when calculating the moving average. Has to be between 0-1. + validate_original_weights: Validate the original weights, as apposed to the EMA weights. + every_n_steps: Apply EMA every N steps. + cpu_offload: Offload weights to CPU. 
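+
+    Example (sketch; ``Trainer`` is the Lightning trainer)::
+
+        trainer = Trainer(callbacks=[EMA(decay=0.999)])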
+ """ + + def __init__( + self, + decay: float = 0.999, + validate_original_weights: bool = False, + every_n_steps: int = 1, + cpu_offload: bool = False, + ): + if not (0 <= decay <= 1): + raise MisconfigurationException("EMA decay value must be between 0 and 1") + self.decay = decay + self.validate_original_weights = validate_original_weights + self.every_n_steps = every_n_steps + self.cpu_offload = cpu_offload + + def on_fit_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + device = pl_module.device if not self.cpu_offload else torch.device("cpu") + trainer.optimizers = [ + EMAOptimizer( + optim, + device=device, + decay=self.decay, + every_n_steps=self.every_n_steps, + current_step=trainer.global_step, + ) + for optim in trainer.optimizers + if not isinstance(optim, EMAOptimizer) + ] + + def on_validation_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def on_validation_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def on_test_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def on_test_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool: + return not self.validate_original_weights and self._ema_initialized(trainer) + + def _ema_initialized(self, trainer: "pl.Trainer") -> bool: + return any( + isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers + ) + + def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.switch_main_parameter_weights(saving_ema_model) + + @contextlib.contextmanager + def save_ema_model(self, trainer: "pl.Trainer"): + """ + Saves an EMA copy of the model + EMA optimizer states for resume. + """ + self.swap_model_weights(trainer, saving_ema_model=True) + try: + yield + finally: + self.swap_model_weights(trainer, saving_ema_model=False) + + @contextlib.contextmanager + def save_original_optimizer_state(self, trainer: "pl.Trainer"): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.save_original_optimizer_state = True + try: + yield + finally: + for optimizer in trainer.optimizers: + optimizer.save_original_optimizer_state = False + + def on_load_checkpoint( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + checkpoint: Dict[str, Any], + ) -> None: + checkpoint_callback = trainer.checkpoint_callback + + # use the connector as NeMo calls the connector directly in the exp_manager when restoring. + connector = trainer._checkpoint_connector + # Replace connector._ckpt_path with below to avoid calling into lightning's protected API + ckpt_path = trainer.ckpt_path + + if ( + ckpt_path + and checkpoint_callback is not None + # and "NeMo" in type(checkpoint_callback).__name__ + ): + ext = checkpoint_callback.FILE_EXTENSION + if ckpt_path.endswith(f"-EMA{ext}"): + rank_zero_info( + "loading EMA based weights. 
" + "The callback will treat the loaded EMA weights as the main weights" + " and create a new EMA copy when training." + ) + return + ema_path = ckpt_path.replace(ext, f"-EMA{ext}") + if os.path.exists(ema_path): + ema_state_dict = torch.load(ema_path, map_location=torch.device("cpu")) + + checkpoint["optimizer_states"] = ema_state_dict["optimizer_states"] + del ema_state_dict + rank_zero_info("EMA state has been restored.") + else: + raise MisconfigurationException( + "Unable to find the associated EMA weights when re-loading, " + f"training will start with new EMA weights. Expected them to be at: {ema_path}", + ) + + +@torch.no_grad() +def ema_update(ema_model_tuple, current_model_tuple, decay): + torch._foreach_mul_(ema_model_tuple, decay) + torch._foreach_add_( + ema_model_tuple, + current_model_tuple, + alpha=(1.0 - decay), + ) + + +def run_ema_update_cpu( + ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None +): + if pre_sync_stream is not None: + pre_sync_stream.synchronize() + + ema_update(ema_model_tuple, current_model_tuple, decay) + + +class EMAOptimizer(torch.optim.Optimizer): + r""" + EMAOptimizer is a wrapper for torch.optim.Optimizer that computes + Exponential Moving Average of parameters registered in the optimizer. + + EMA parameters are automatically updated after every step of the optimizer + with the following formula: + + ema_weight = decay * ema_weight + (1 - decay) * training_weight + + To access EMA parameters, use ``swap_ema_weights()`` context manager to + perform a temporary in-place swap of regular parameters with EMA + parameters. + + Notes: + - EMAOptimizer is not compatible with APEX AMP O2. + + Args: + optimizer (torch.optim.Optimizer): optimizer to wrap + device (torch.device): device for EMA parameters + decay (float): decay factor + + Returns: + returns an instance of torch.optim.Optimizer that computes EMA of + parameters + + Example: + model = Model().to(device) + opt = torch.optim.Adam(model.parameters()) + + opt = EMAOptimizer(opt, device, 0.9999) + + for epoch in range(epochs): + training_loop(model, opt) + + regular_eval_accuracy = evaluate(model) + + with opt.swap_ema_weights(): + ema_eval_accuracy = evaluate(model) + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + device: torch.device, + decay: float = 0.9999, + every_n_steps: int = 1, + current_step: int = 0, + ): + self.optimizer = optimizer + self.decay = decay + self.device = device + self.current_step = current_step + self.every_n_steps = every_n_steps + self.save_original_optimizer_state = False + + self.first_iteration = True + self.rebuild_ema_params = True + self.stream = None + self.thread = None + + self.ema_params = () + self.in_saving_ema_model_context = False + + def all_parameters(self) -> Iterable[torch.Tensor]: + return (param for group in self.param_groups for param in group["params"]) + + def step(self, closure=None, grad_scaler=None, **kwargs): + self.join() + + if self.first_iteration: + if any(p.is_cuda for p in self.all_parameters()): + self.stream = torch.cuda.Stream() + + self.first_iteration = False + + if self.rebuild_ema_params: + opt_params = list(self.all_parameters()) + + self.ema_params += tuple( + copy.deepcopy(param.data.detach()).to(self.device) + for param in opt_params[len(self.ema_params) :] + ) + self.rebuild_ema_params = False + + if ( + getattr(self.optimizer, "_step_supports_amp_scaling", False) + and grad_scaler is not None + ): + loss = self.optimizer.step(closure=closure, grad_scaler=grad_scaler) + else: + loss = 
self.optimizer.step(closure) + + if self._should_update_at_step(): + self.update() + self.current_step += 1 + return loss + + def _should_update_at_step(self) -> bool: + return self.current_step % self.every_n_steps == 0 + + @torch.no_grad() + def update(self): + if self.stream is not None: + self.stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self.stream): + current_model_state = tuple( + param.data.to(self.device, non_blocking=True) + for param in self.all_parameters() + ) + + if self.device.type == "cuda": + ema_update(self.ema_params, current_model_state, self.decay) + + if self.device.type == "cpu": + self.thread = threading.Thread( + target=run_ema_update_cpu, + args=( + self.ema_params, + current_model_state, + self.decay, + self.stream, + ), + ) + self.thread.start() + + def swap_tensors(self, tensor1, tensor2): + tmp = torch.empty_like(tensor1) + tmp.copy_(tensor1) + tensor1.copy_(tensor2) + tensor2.copy_(tmp) + + def switch_main_parameter_weights(self, saving_ema_model: bool = False): + self.join() + self.in_saving_ema_model_context = saving_ema_model + for param, ema_param in zip(self.all_parameters(), self.ema_params): + self.swap_tensors(param.data, ema_param) + + @contextlib.contextmanager + def swap_ema_weights(self, enabled: bool = True): + r""" + A context manager to in-place swap regular parameters with EMA + parameters. + It swaps back to the original regular parameters on context manager + exit. + + Args: + enabled (bool): whether the swap should be performed + """ + + if enabled: + self.switch_main_parameter_weights() + try: + yield + finally: + if enabled: + self.switch_main_parameter_weights() + + def __getattr__(self, name): + return getattr(self.optimizer, name) + + def join(self): + if self.stream is not None: + self.stream.synchronize() + + if self.thread is not None: + self.thread.join() + + def state_dict(self): + self.join() + + if self.save_original_optimizer_state: + return self.optimizer.state_dict() + + # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights + ema_params = ( + self.ema_params + if not self.in_saving_ema_model_context + else list(self.all_parameters()) + ) + state_dict = { + "opt": self.optimizer.state_dict(), + "ema": ema_params, + "current_step": self.current_step, + "decay": self.decay, + "every_n_steps": self.every_n_steps, + } + return state_dict + + def load_state_dict(self, state_dict): + self.join() + + self.optimizer.load_state_dict(state_dict["opt"]) + self.ema_params = tuple( + param.to(self.device) for param in copy.deepcopy(state_dict["ema"]) + ) + self.current_step = state_dict["current_step"] + self.decay = state_dict["decay"] + self.every_n_steps = state_dict["every_n_steps"] + self.rebuild_ema_params = False + + def add_param_group(self, param_group): + self.optimizer.add_param_group(param_group) + self.rebuild_ema_params = True + + +class EMAModelCheckpoint(ModelCheckpoint): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _ema_callback(self, trainer: pl.Trainer) -> Optional[EMA]: + ema_callback = None + for callback in trainer.callbacks: + if isinstance(callback, EMA): + ema_callback = callback + return ema_callback + + def _save_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None: + ema_callback = self._ema_callback(trainer) + if ema_callback is not None: + with ema_callback.save_original_optimizer_state(trainer): + super()._save_checkpoint(trainer, filepath) + + # save EMA copy of the model as well. 
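+            # The EMA copy is written next to the regular checkpoint with an
+            # `-EMA` suffix (see `_ema_format_filepath`).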
+ with ema_callback.save_ema_model(trainer): + filepath = self._ema_format_filepath(filepath) + if self.verbose: + rank_zero_info( + f"Saving EMA weights to separate checkpoint {filepath}" + ) + super()._save_checkpoint(trainer, filepath) + else: + super()._save_checkpoint(trainer, filepath) + + def _ema_format_filepath(self, filepath: str) -> str: + return filepath.replace(self.FILE_EXTENSION, f"-EMA{self.FILE_EXTENSION}") + + def _has_ema_ckpts(self, checkpoints: Iterable[Path]) -> bool: + return any( + self._is_ema_filepath(checkpoint_path) for checkpoint_path in checkpoints + ) + + def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool: + return str(filepath).endswith(f"-EMA{self.FILE_EXTENSION}") + + def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None: + super()._remove_checkpoint(trainer, filepath) + ema_callback = self._ema_callback(trainer) + if ema_callback is not None: + # remove EMA copy of the state dict as well. + filepath = self._ema_format_filepath(filepath) + super()._remove_checkpoint(trainer, filepath) diff --git a/spa/utils/common_utils.py b/spa/utils/common_utils.py new file mode 100644 index 0000000..035ab16 --- /dev/null +++ b/spa/utils/common_utils.py @@ -0,0 +1,322 @@ +import logging +import os +import pickle +import random +import shutil +import subprocess +from functools import partial + +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def check_state_dict(state_dict, state_dict_disk, logger): + logger.info("----------- Pre-trained Weights ------------") + update_model_state = {} + for key, val in state_dict_disk.items(): + if key in state_dict and state_dict[key].shape == val.shape: + update_model_state[key] = val + else: + logger.info( + "Not loaded weight: %s, %s" % (key, str(state_dict_disk[key].shape)) + ) + for key, val in state_dict.items(): + if key not in update_model_state: + logger.info( + "Not updated weight: %s, %s" % (key, str(state_dict[key].shape)) + ) + logger.info("==> Done (loaded %d/%d)" % (len(update_model_state), len(state_dict))) + return update_model_state + + +def reduce_mean(tensor): + """Obtain the mean of tensor on different GPUs.""" + if not (dist.is_available() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) + return tensor + + +def multi_apply(func, *args, **kwargs): + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def check_numpy_to_torch(x): + if isinstance(x, np.ndarray): + return torch.from_numpy(x).float(), True + return x, False + + +def limit_period(val, offset=0.5, period=np.pi): + val, is_numpy = check_numpy_to_torch(val) + ans = val - torch.floor(val / period + offset) * period + return ans.numpy() if is_numpy else ans + + +def drop_info_with_name(info, name): + ret_info = {} + keep_indices = [i for i, x in enumerate(info["name"]) if x != name] + for key in info.keys(): + ret_info[key] = info[key][keep_indices] + return ret_info + + +def rotate_points_along_z(points, angle): + """ + Args: + points: (B, N, 3 + C) + angle: (B), angle along z-axis, angle increases x ==> y + Returns: + + """ + points, is_numpy = check_numpy_to_torch(points) + angle, _ = check_numpy_to_torch(angle) + + cosa = torch.cos(angle) + sina = torch.sin(angle) + zeros = angle.new_zeros(points.shape[0]) + ones = angle.new_ones(points.shape[0]) + rot_matrix = ( + torch.stack((cosa, 
sina, zeros, -sina, cosa, zeros, zeros, zeros, ones), dim=1) + .view(-1, 3, 3) + .float() + ) + points_rot = torch.matmul(points[:, :, 0:3], rot_matrix) + points_rot = torch.cat((points_rot, points[:, :, 3:]), dim=-1) + return points_rot.numpy() if is_numpy else points_rot + + +def angle2matrix(angle): + """ + Args: + angle: angle along z-axis, angle increases x ==> y + Returns: + rot_matrix: (3x3 Tensor) rotation matrix + """ + angle, is_numpy = check_numpy_to_torch(angle) + cosa = torch.cos(angle) + sina = torch.sin(angle) + rot_matrix = torch.tensor([[cosa, -sina, 0], [sina, cosa, 0], [0, 0, 1]]) + return rot_matrix.numpy() if is_numpy else rot_matrix + + +def mask_points_by_range(points, limit_range, padding=0.0, with_height=True): + dim = 3 if with_height else 2 + mask = np.all( + points[:, :dim] >= (np.array(limit_range[:dim]) + padding), axis=1 + ) & np.all( + points[:, :dim] <= (np.array(limit_range[3 : 3 + dim]) - padding), axis=1 + ) + return mask + + +def get_voxel_centers( + voxel_coords, downsample_times, voxel_size, point_cloud_range, dim=3 +): + """ + Args: + voxel_coords: (N, 3) + downsample_times: + voxel_size: + point_cloud_range: + + Returns: + + """ + voxel_centers = torch.flip(voxel_coords, dims=[-1]).float() # (x, y, z) or (x, y) + voxel_size = ( + torch.tensor(voxel_size[:dim], device=voxel_centers.device).float() + * downsample_times + ) + pc_range = torch.tensor( + point_cloud_range[:dim], device=voxel_centers.device + ).float() + voxel_centers = (voxel_centers + 0.5) * voxel_size + pc_range + return voxel_centers + + +def create_logger(log_file=None, rank=0, log_level=logging.INFO): + logger = logging.getLogger(__name__) + logger.setLevel(log_level if rank == 0 else "ERROR") + formatter = logging.Formatter("%(asctime)s %(levelname)5s %(message)s") + console = logging.StreamHandler() + console.setLevel(log_level if rank == 0 else "ERROR") + console.setFormatter(formatter) + logger.addHandler(console) + if log_file is not None: + file_handler = logging.FileHandler(filename=log_file) + file_handler.setLevel(log_level if rank == 0 else "ERROR") + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + logger.propagate = False + return logger + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def get_pad_params(desired_size, cur_size): + """ + Get padding parameters for np.pad function + Args: + desired_size: int, Desired padded output size + cur_size: int, Current size. 
Should always be less than or equal to cur_size + Returns: + pad_params: tuple(int), Number of values padded to the edges (before, after) + """ + assert desired_size >= cur_size + + # Calculate amount to pad + diff = desired_size - cur_size + pad_params = (0, diff) + + return pad_params + + +def keep_arrays_by_name(gt_names, used_classes): + inds = [i for i, x in enumerate(gt_names) if x in used_classes] + inds = np.array(inds, dtype=np.int64) + return inds + + +def init_dist_slurm(tcp_port, backend="nccl"): + """ + modified from https://github.com/open-mmlab/mmdetection + Args: + tcp_port: + backend: + + Returns: + + """ + proc_id = int(os.environ["SLURM_PROCID"]) + ntasks = int(os.environ["SLURM_NTASKS"]) + node_list = os.environ["SLURM_NODELIST"] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + "scontrol show hostname {} | head -n1".format(node_list) + ) + os.environ["MASTER_PORT"] = str(tcp_port) + os.environ["MASTER_ADDR"] = addr + os.environ["WORLD_SIZE"] = str(ntasks) + os.environ["RANK"] = str(proc_id) + dist.init_process_group(backend=backend) + + total_gpus = dist.get_world_size() + rank = dist.get_rank() + return total_gpus, rank + + +def init_dist_pytorch(tcp_port, backend="nccl"): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method("spawn") + # os.environ['MASTER_PORT'] = str(tcp_port) + # os.environ['MASTER_ADDR'] = 'localhost' + num_gpus = torch.cuda.device_count() + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank % num_gpus) + + dist.init_process_group( + backend=backend, + # init_method='tcp://127.0.0.1:%d' % tcp_port, + # rank=local_rank, + # world_size=num_gpus + ) + rank = dist.get_rank() + return num_gpus, rank + + +def get_dist_info(return_gpu_per_machine=False): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + + if return_gpu_per_machine: + gpu_per_machine = torch.cuda.device_count() + return rank, world_size, gpu_per_machine + + return rank, world_size + + +def merge_results_dist(result_part, size, tmpdir): + rank, world_size = get_dist_info() + os.makedirs(tmpdir, exist_ok=True) + + dist.barrier() + pickle.dump( + result_part, open(os.path.join(tmpdir, "result_part_{}.pkl".format(rank)), "wb") + ) + dist.barrier() + + if rank != 0: + return None + + part_list = [] + for i in range(world_size): + part_file = os.path.join(tmpdir, "result_part_{}.pkl".format(i)) + part_list.append(pickle.load(open(part_file, "rb"))) + + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + ordered_results = ordered_results[:size] + shutil.rmtree(tmpdir) + return ordered_results + + +def scatter_point_inds(indices, point_inds, shape): + ret = -1 * torch.ones(*shape, dtype=point_inds.dtype, device=point_inds.device) + ndim = indices.shape[-1] + flattened_indices = indices.view(-1, ndim) + slices = [flattened_indices[:, i] for i in range(ndim)] + ret[slices] = point_inds + return ret + + +def generate_voxel2pinds(sparse_tensor): + device = sparse_tensor.indices.device + batch_size = sparse_tensor.batch_size + spatial_shape = sparse_tensor.spatial_shape + indices = sparse_tensor.indices.long() + point_indices = torch.arange(indices.shape[0], device=device, dtype=torch.int32) + output_shape = [batch_size] + list(spatial_shape) + v2pinds_tensor = scatter_point_inds(indices, 
point_indices, output_shape) + return v2pinds_tensor + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count diff --git a/spa/utils/dist.py b/spa/utils/dist.py new file mode 100755 index 0000000..d422de8 --- /dev/null +++ b/spa/utils/dist.py @@ -0,0 +1,25 @@ +import torch.distributed as dist + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 diff --git a/spa/utils/fp16_utils.py b/spa/utils/fp16_utils.py new file mode 100644 index 0000000..ac44407 --- /dev/null +++ b/spa/utils/fp16_utils.py @@ -0,0 +1,499 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +import warnings +from collections import abc +from inspect import getfullargspec +from typing import Callable, Iterable, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from packaging.version import parse +from torch.nn.parameter import Parameter + +TORCH_VERSION = torch.__version__ + +# from .dist_utils import allreduce_grads as _allreduce_grads + +try: + # If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported + # and used; otherwise, auto fp16 will adopt mmcv's implementation. + # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16 + # manually, so the behavior may not be consistent with real amp. + from torch.cuda.amp import autocast +except ImportError: + pass + + +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Defaults to 4. + + Returns: + tuple[int]: The version info in digits (integers). + """ + assert "parrots" not in version_str + version = parse(version_str) + assert version.release, f"failed to parse version {version_str}" + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {"a": -3, "b": -2, "rc": -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn( + f"unknown prerelease version {version.pre[0]}, " + "version checking may go wrong" + ) + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) # type: ignore + else: + release.extend([0, 0]) + return tuple(release) + + +def cast_tensor_type(inputs, src_type: torch.dtype, dst_type: torch.dtype): + """Recursively convert Tensor in inputs from src_type to dst_type. + + Note: + In v1.4.4 and later, ``cast_tersor_type`` will only convert the + torch.Tensor which is consistent with ``src_type`` to the ``dst_type``. 
+ Before v1.4.4, it ignores the ``src_type`` argument, leading to some + potential problems. For example, + ``cast_tensor_type(inputs, torch.float, torch.half)`` will convert all + tensors in inputs to ``torch.half`` including those originally in + ``torch.Int`` or other types, which is not expected. + + Args: + inputs: Inputs that to be casted. + src_type (torch.dtype): Source type.. + dst_type (torch.dtype): Destination type. + + Returns: + The same type with inputs, but all contained Tensors have been cast. + """ + if isinstance(inputs, nn.Module): + return inputs + elif isinstance(inputs, torch.Tensor): + # we need to ensure that the type of inputs to be casted are the same + # as the argument `src_type`. + return inputs.to(dst_type) if inputs.dtype == src_type else inputs + elif isinstance(inputs, str): + return inputs + elif isinstance(inputs, np.ndarray): + return inputs + elif isinstance(inputs, abc.Mapping): + return type(inputs)( + { # type: ignore + k: cast_tensor_type(v, src_type, dst_type) for k, v in inputs.items() + } + ) + elif isinstance(inputs, abc.Iterable): + return type(inputs)( # type: ignore + cast_tensor_type(item, src_type, dst_type) for item in inputs + ) + else: + return inputs + + +def auto_fp16( + apply_to: Optional[Iterable] = None, + out_fp32: bool = False, + supported_types: tuple = (nn.Module,), +) -> Callable: + """Decorator to enable fp16 training automatically. + + This decorator is useful when you write custom modules and want to support + mixed precision training. If inputs arguments are fp32 tensors, they will + be converted to fp16 automatically. Arguments other than fp32 tensors are + ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the + backend, otherwise, original mmcv implementation will be adopted. + + Args: + apply_to (Iterable, optional): The argument names to be converted. + `None` indicates all arguments. + out_fp32 (bool): Whether to convert the output back to fp32. + supported_types (tuple): Classes can be decorated by ``auto_fp16``. + `New in version 1.5.0.` + Example: + + >>> import torch.nn as nn + >>> class MyModule1(nn.Module): + >>> + >>> # Convert x and y to fp16 + >>> @auto_fp16() + >>> def forward(self, x, y): + >>> pass + + >>> import torch.nn as nn + >>> class MyModule2(nn.Module): + >>> + >>> # convert pred to fp16 + >>> @auto_fp16(apply_to=('pred', )) + >>> def do_something(self, pred, others): + >>> pass + """ + + def auto_fp16_wrapper(old_func: Callable) -> Callable: + @functools.wraps(old_func) + def new_func(*args, **kwargs) -> Callable: + # check if the module has set the attribute `fp16_enabled`, if not, + # just fallback to the original method. 
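+            # (In this repo the flag is switched on externally, e.g. `SPA.__init__`
+            # sets `fp16_enabled = True` for the modules in `fp16_enabled_layers`.)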
+ if not isinstance(args[0], supported_types): + raise TypeError( + "@auto_fp16 can only be used to decorate the " + f"method of those classes {supported_types}" + ) + if not (hasattr(args[0], "fp16_enabled") and args[0].fp16_enabled): + return old_func(*args, **kwargs) + + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get the argument names to be casted + args_to_cast = args_info.args if apply_to is None else apply_to + # convert the args that need to be processed + new_args = [] + # NOTE: default args are not taken into consideration + if args: + arg_names = args_info.args[: len(args)] + for i, arg_name in enumerate(arg_names): + if arg_name in args_to_cast: + new_args.append( + cast_tensor_type(args[i], torch.float, torch.half) + ) + else: + new_args.append(args[i]) + # convert the kwargs that need to be processed + new_kwargs = {} + if kwargs: + for arg_name, arg_value in kwargs.items(): + if arg_name in args_to_cast: + new_kwargs[arg_name] = cast_tensor_type( + arg_value, torch.float, torch.half + ) + else: + new_kwargs[arg_name] = arg_value + # apply converted arguments to the decorated method + if TORCH_VERSION != "parrots" and digit_version( + TORCH_VERSION + ) >= digit_version("1.6.0"): + with autocast(enabled=True): + output = old_func(*new_args, **new_kwargs) + else: + output = old_func(*new_args, **new_kwargs) + # cast the results back to fp32 if necessary + if out_fp32: + output = cast_tensor_type(output, torch.half, torch.float) + return output + + return new_func + + return auto_fp16_wrapper + + +def force_fp32(apply_to: Optional[Iterable] = None, out_fp16: bool = False) -> Callable: + """Decorator to convert input arguments to fp32 in force. + + This decorator is useful when you write custom modules and want to support + mixed precision training. If there are some inputs that must be processed + in fp32 mode, then this decorator can handle it. If inputs arguments are + fp16 tensors, they will be converted to fp32 automatically. Arguments other + than fp16 tensors are ignored. If you are using PyTorch >= 1.6, + torch.cuda.amp is used as the backend, otherwise, original mmcv + implementation will be adopted. + + Args: + apply_to (Iterable, optional): The argument names to be converted. + `None` indicates all arguments. + out_fp16 (bool): Whether to convert the output back to fp16. + + Example: + + >>> import torch.nn as nn + >>> class MyModule1(nn.Module): + >>> + >>> # Convert x and y to fp32 + >>> @force_fp32() + >>> def loss(self, x, y): + >>> pass + + >>> import torch.nn as nn + >>> class MyModule2(nn.Module): + >>> + >>> # convert pred to fp32 + >>> @force_fp32(apply_to=('pred', )) + >>> def post_process(self, pred, others): + >>> pass + """ + + def force_fp32_wrapper(old_func): + @functools.wraps(old_func) + def new_func(*args, **kwargs) -> Callable: + # check if the module has set the attribute `fp16_enabled`, if not, + # just fallback to the original method. 
+ if not isinstance(args[0], torch.nn.Module): + raise TypeError( + "@force_fp32 can only be used to decorate the " + "method of nn.Module" + ) + if not (hasattr(args[0], "fp16_enabled") and args[0].fp16_enabled): + return old_func(*args, **kwargs) + # get the arg spec of the decorated method + args_info = getfullargspec(old_func) + # get the argument names to be casted + args_to_cast = args_info.args if apply_to is None else apply_to + # convert the args that need to be processed + new_args = [] + if args: + arg_names = args_info.args[: len(args)] + for i, arg_name in enumerate(arg_names): + if arg_name in args_to_cast: + new_args.append( + cast_tensor_type(args[i], torch.half, torch.float) + ) + else: + new_args.append(args[i]) + # convert the kwargs that need to be processed + new_kwargs = dict() + if kwargs: + for arg_name, arg_value in kwargs.items(): + if arg_name in args_to_cast: + new_kwargs[arg_name] = cast_tensor_type( + arg_value, torch.half, torch.float + ) + else: + new_kwargs[arg_name] = arg_value + # apply converted arguments to the decorated method + if TORCH_VERSION != "parrots" and digit_version( + TORCH_VERSION + ) >= digit_version("1.6.0"): + with autocast(enabled=False): + output = old_func(*new_args, **new_kwargs) + else: + output = old_func(*new_args, **new_kwargs) + # cast the results back to fp32 if necessary + if out_fp16: + output = cast_tensor_type(output, torch.float, torch.half) + return output + + return new_func + + return force_fp32_wrapper + + +# def allreduce_grads(params: List[Parameter], +# coalesce: bool = True, +# bucket_size_mb: int = -1) -> None: +# warnings.warn( +# '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be ' +# 'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads', +# DeprecationWarning) +# _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb) + + +def wrap_fp16_model(model: nn.Module) -> None: + """Wrap the FP32 model to FP16. + + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the + backend, otherwise, original mmcv implementation will be adopted. + + For PyTorch >= 1.6, this function will + 1. Set fp16 flag inside the model to True. + + Otherwise: + 1. Convert FP32 model to FP16. + 2. Remain some necessary layers to be FP32, e.g., normalization layers. + 3. Set `fp16_enabled` flag inside the model to True. + + Args: + model (nn.Module): Model in FP32. + """ + if TORCH_VERSION == "parrots" or digit_version(TORCH_VERSION) < digit_version( + "1.6.0" + ): + # convert model to fp16 + model.half() + # patch the normalization layers to make it work in fp32 mode + patch_norm_fp32(model) + # set `fp16_enabled` flag + for m in model.modules(): + if hasattr(m, "fp16_enabled"): + m.fp16_enabled = True + + +def patch_norm_fp32(module: nn.Module) -> nn.Module: + """Recursively convert normalization layers from FP16 to FP32. + + Args: + module (nn.Module): The modules to be converted in FP16. + + Returns: + nn.Module: The converted module, the normalization layers have been + converted to FP32. 
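+
+    Example (sketch; ``MyFP16Model`` is a placeholder)::
+
+        model = MyFP16Model().half()   # fp16 weights everywhere
+        patch_norm_fp32(model)         # but normalization layers run in fp32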
+ """ + if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)): + module.float() + if isinstance(module, nn.GroupNorm) or torch.__version__ < "1.3": + module.forward = patch_forward_method( + module.forward, torch.half, torch.float + ) + for child in module.children(): + patch_norm_fp32(child) + return module + + +def patch_forward_method( + func: Callable, + src_type: torch.dtype, + dst_type: torch.dtype, + convert_output: bool = True, +) -> Callable: + """Patch the forward method of a module. + + Args: + func (callable): The original forward method. + src_type (torch.dtype): Type of input arguments to be converted from. + dst_type (torch.dtype): Type of input arguments to be converted to. + convert_output (bool): Whether to convert the output back to src_type. + + Returns: + callable: The patched forward method. + """ + + def new_forward(*args, **kwargs): + output = func( + *cast_tensor_type(args, src_type, dst_type), + **cast_tensor_type(kwargs, src_type, dst_type), + ) + if convert_output: + output = cast_tensor_type(output, dst_type, src_type) + return output + + return new_forward + + +class LossScaler: + """Class that manages loss scaling in mixed precision training which + supports both dynamic or static mode. + + The implementation refers to + https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py. + Indirectly, by supplying ``mode='dynamic'`` for dynamic loss scaling. + It's important to understand how :class:`LossScaler` operates. + Loss scaling is designed to combat the problem of underflowing + gradients encountered at long times when training fp16 networks. + Dynamic loss scaling begins by attempting a very high loss + scale. Ironically, this may result in OVERflowing gradients. + If overflowing gradients are encountered, :class:`FP16_Optimizer` then + skips the update step for this particular iteration/minibatch, + and :class:`LossScaler` adjusts the loss scale to a lower value. + If a certain number of iterations occur without overflowing gradients + detected,:class:`LossScaler` increases the loss scale once more. + In this way :class:`LossScaler` attempts to "ride the edge" of always + using the highest loss scale possible without incurring overflow. + + Args: + init_scale (float): Initial loss scale value, default: 2**32. + scale_factor (float): Factor used when adjusting the loss scale. + Default: 2. + mode (str): Loss scaling mode. 'dynamic' or 'static' + scale_window (int): Number of consecutive iterations without an + overflow to wait before increasing the loss scale. Default: 1000. 
+ """ + + def __init__( + self, + init_scale: float = 2**32, + mode: str = "dynamic", + scale_factor: float = 2.0, + scale_window: int = 1000, + ): + self.cur_scale = init_scale + self.cur_iter = 0 + assert mode in ("dynamic", "static"), "mode can only be dynamic or static" + self.mode = mode + self.last_overflow_iter = -1 + self.scale_factor = scale_factor + self.scale_window = scale_window + + def has_overflow(self, params: List[Parameter]) -> bool: + """Check if params contain overflow.""" + if self.mode != "dynamic": + return False + for p in params: + if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data): + return True + return False + + def _has_inf_or_nan(x: torch.Tensor) -> bool: + """Check if params contain NaN.""" + try: + cpu_sum = float(x.float().sum()) + except RuntimeError as instance: + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if ( + cpu_sum == float("inf") + or cpu_sum == -float("inf") + or cpu_sum != cpu_sum + ): + return True + return False + + def update_scale(self, overflow: bool) -> None: + """update the current loss scale value when overflow happens.""" + if self.mode != "dynamic": + return + if overflow: + self.cur_scale = max(self.cur_scale / self.scale_factor, 1) + self.last_overflow_iter = self.cur_iter + else: + if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: + self.cur_scale *= self.scale_factor + self.cur_iter += 1 + + def state_dict(self) -> dict: + """Returns the state of the scaler as a :class:`dict`.""" + return dict( + cur_scale=self.cur_scale, + cur_iter=self.cur_iter, + mode=self.mode, + last_overflow_iter=self.last_overflow_iter, + scale_factor=self.scale_factor, + scale_window=self.scale_window, + ) + + def load_state_dict(self, state_dict: dict) -> None: + """Loads the loss_scaler state dict. + + Args: + state_dict (dict): scaler state. + """ + self.cur_scale = state_dict["cur_scale"] + self.cur_iter = state_dict["cur_iter"] + self.mode = state_dict["mode"] + self.last_overflow_iter = state_dict["last_overflow_iter"] + self.scale_factor = state_dict["scale_factor"] + self.scale_window = state_dict["scale_window"] + + @property + def loss_scale(self) -> float: + return self.cur_scale diff --git a/spa/utils/instantiators.py b/spa/utils/instantiators.py new file mode 100755 index 0000000..b1ccf30 --- /dev/null +++ b/spa/utils/instantiators.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import List + +import hydra +from lightning import Callback +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from spa.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]: + """Instantiates callbacks from config. + + :param callbacks_cfg: A DictConfig object containing callback configurations. + :return: A list of instantiated callbacks. + """ + callbacks: list[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> list[Logger]: + """Instantiates loggers from config. 
+ + :param logger_cfg: A DictConfig object containing logger configurations. + :return: A list of instantiated loggers. + """ + logger: list[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/spa/utils/io.py b/spa/utils/io.py new file mode 100755 index 0000000..aee8244 --- /dev/null +++ b/spa/utils/io.py @@ -0,0 +1,81 @@ +import json +import os +import pickle +from io import BytesIO + +import cv2 +import imageio.v3 as iio +import numpy as np +import torch + + +def load_bytes(filename): + return open(filename, "rb").read() + + +def load_mp4(filename): + return iio.imread(filename, extension=".mp4") + + +def load_text(filename): + return open(filename, "r").read().splitlines() + + +def load_numpy_text(filename): + return np.loadtxt(filename) + + +def listdir(dir): + return os.listdir(dir) + + +def load_json(filename): + return json.load(open(filename)) + + +def dump_json(filename, data, backend="local"): + json.dump(data, open(filename, "w")) + + +def load_numpy_pickle(filename): + array = np.load(filename, allow_pickle=True) + + try: + return array.item() + except: + return array + + +def load_numpy(filename): + return np.load(filename) + + +def load_pickle(filename): + return pickle.load(open(filename, "rb")) + + +def load_image(filename): + if ".jpg" in filename or ".JPG" in filename: + image = iio.imread(filename) + elif ".png" in filename or ".PNG" in filename: # for depth images + image = cv2.imread(filename, cv2.IMREAD_UNCHANGED) + return image + + +def exists(path): + return os.path.exists(path) + + +def isdir(path): + return os.path.isdir(path) + + +def imwrite(img, path): + iio.imwrite(path, img) + + +def load_pth( + path, + map_location="cpu", +): + return torch.load(path, map_location=map_location) diff --git a/spa/utils/logging_utils.py b/spa/utils/logging_utils.py new file mode 100755 index 0000000..6981055 --- /dev/null +++ b/spa/utils/logging_utils.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import Any, Dict + +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import OmegaConf + +from spa.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +@rank_zero_only +def log_hyperparameters(object_dict: dict[str, Any]) -> None: + """Controls which config parts are saved by Lightning loggers. + + Additionally saves: + - Number of model parameters + + :param object_dict: A dictionary containing the following objects: + - `"cfg"`: A DictConfig object containing the main config. + - `"model"`: The Lightning model. + - `"trainer"`: The Lightning trainer. + """ + hparams = {} + + cfg = OmegaConf.to_container(object_dict["cfg"]) + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! 
Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/spa/utils/metrics.py b/spa/utils/metrics.py new file mode 100755 index 0000000..f5fa63b --- /dev/null +++ b/spa/utils/metrics.py @@ -0,0 +1,53 @@ +from collections.abc import Sequence + +import torch +import torch.nn as nn + + +class Metrics(nn.Module): + def __init__(self, metrics, input_keys, output_keys): + super().__init__() + self.metrics = metrics + + self.input_keys = input_keys + if not isinstance(self.input_keys, Sequence): + self.input_keys = [self.input_keys] * len(self.metrics) + else: + assert len(self.input_keys) == len(self.metrics) + + self.output_keys = output_keys + if not isinstance(self.output_keys, Sequence): + self.output_keys = [self.output_keys] * len(self.metrics) + else: + assert len(self.output_keys) == len(self.metrics) + + self.metrics = nn.ModuleList(self.metrics) + + def reset(self): + for c in self.metrics: + c.reset() + + def compute(self): + metrics = dict() + for c, out_k in zip(self.metrics, self.output_keys): + metrics[out_k] = c.compute() + return metrics + + def metrics_dict(self): + metrics = dict() + for c, out_k in zip(self.metrics, self.output_keys): + metrics[out_k] = c + return metrics + + @torch.inference_mode() + def forward(self, data_dict): + metrics = dict() + for c, in_k, out_k in zip(self.metrics, self.input_keys, self.output_keys): + if isinstance(in_k, Sequence): + try: + metrics[out_k] = c(*[data_dict[k_] for k_ in in_k]) + except KeyError: + metrics[out_k] = c(data_dict[in_k]) + else: + metrics[out_k] = c(data_dict[in_k]) + return metrics diff --git a/spa/utils/misc.py b/spa/utils/misc.py new file mode 100755 index 0000000..8291b42 --- /dev/null +++ b/spa/utils/misc.py @@ -0,0 +1,85 @@ +import os +import warnings +from collections import abc +from importlib import import_module + +import numpy as np +import torch + + +def make_dirs(dir_name): + if not os.path.exists(dir_name): + os.makedirs(dir_name) + + +def is_seq_of(seq, expected_type, seq_type=None): + """Check whether it is a sequence of some type. + + Args: + seq (Sequence): The sequence to be checked. + expected_type (type): Expected type of sequence items. + seq_type (type, optional): Expected sequence type. + + Returns: + bool: Whether the sequence is valid. + """ + if seq_type is None: + exp_seq_type = abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + +def import_modules_from_strings(imports, allow_failed_imports=False): + """Import modules from the given list of strings. + + Args: + imports (list | str | None): The given module names to be imported. 
+ allow_failed_imports (bool): If True, the failed imports will return + None. Otherwise, an ImportError is raise. Default: False. + + Returns: + list[module] | module | None: The imported modules. + + Examples: + >>> osp, sys = import_modules_from_strings( + ... ['os.path', 'sys']) + >>> import os.path as osp_ + >>> import sys as sys_ + >>> assert osp == osp_ + >>> assert sys == sys_ + """ + if not imports: + return + single_import = False + if isinstance(imports, str): + single_import = True + imports = [imports] + if not isinstance(imports, list): + raise TypeError(f"custom_imports must be a list but got type {type(imports)}") + imported = [] + for imp in imports: + if not isinstance(imp, str): + raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.") + try: + imported_tmp = import_module(imp) + except ImportError: + if allow_failed_imports: + warnings.warn(f"{imp} failed to import and is ignored.", UserWarning) + imported_tmp = None + else: + raise ImportError + imported.append(imported_tmp) + if single_import: + imported = imported[0] + return imported + + +def interpolate_linear(target_t: int, t1: int, t2: int, x1: np.ndarray, x2: np.ndarray): + return x1 + (target_t - t1) / (t2 - t1) * (x2 - x1) if t1 != t2 else x1 diff --git a/spa/utils/optimizer.py b/spa/utils/optimizer.py new file mode 100755 index 0000000..567a466 --- /dev/null +++ b/spa/utils/optimizer.py @@ -0,0 +1,276 @@ +import collections.abc +import math +import re +from collections import defaultdict +from copy import deepcopy +from itertools import chain, islice +from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn +from omegaconf import OmegaConf + +from spa.utils import RankedLogger +from spa.utils.registry import Registry + +OPTIMIZERS = Registry("optimizers") + + +OPTIMIZERS.register_module(module=torch.optim.SGD, name="SGD") +OPTIMIZERS.register_module(module=torch.optim.Adam, name="Adam") +OPTIMIZERS.register_module(module=torch.optim.AdamW, name="AdamW") + +_logger = RankedLogger(__name__, rank_zero_only=True) + +# optimizers to default to multi-tensor +_DEFAULT_FOREACH = { + "lion", +} + +MATCH_PREV_GROUP = (99999,) + + +def group_with_matcher( + named_objects: Iterator[Tuple[str, Any]], + group_matcher: Union[Dict, Callable], + return_values: bool = False, + reverse: bool = False, +): + if isinstance(group_matcher, dict): + # dictionary matcher contains a dict of raw-string regex expr that must be compiled + compiled = [] + for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()): + if mspec is None: + continue + # map all matching specifications into 3-tuple (compiled re, prefix, suffix) + if isinstance(mspec, (tuple, list)): + # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix) + for sspec in mspec: + compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])] + else: + compiled += [(re.compile(mspec), (group_ordinal,), None)] + group_matcher = compiled + + def _get_grouping(name): + if isinstance(group_matcher, (list, tuple)): + for match_fn, prefix, suffix in group_matcher: + r = match_fn.match(name) + if r: + parts = (prefix, r.groups(), suffix) + # map all tuple elem to int for numeric sort, filter out None entries + return tuple(map(float, chain.from_iterable(filter(None, parts)))) + return ( + float("inf"), + ) # un-matched layers (neck, head) mapped to largest ordinal + else: + ord = group_matcher(name) + if not isinstance(ord, collections.abc.Iterable): + return 
(ord,) + return tuple(ord) + + # map layers into groups via ordinals (ints or tuples of ints) from matcher + grouping = defaultdict(list) + for k, v in named_objects: + grouping[_get_grouping(k)].append(v if return_values else k) + + # remap to integers + layer_id_to_param = defaultdict(list) + lid = -1 + for k in sorted(filter(lambda x: x is not None, grouping.keys())): + if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]: + lid += 1 + layer_id_to_param[lid].extend(grouping[k]) + + if reverse: + assert not return_values, "reverse mapping only sensible for name output" + # output reverse mapping + param_to_layer_id = {} + for lid, lm in layer_id_to_param.items(): + for n in lm: + param_to_layer_id[n] = lid + return param_to_layer_id + + return layer_id_to_param + + +def group_parameters( + module: nn.Module, + group_matcher, + return_values: bool = False, + reverse: bool = False, +): + return group_with_matcher( + module.named_parameters(), + group_matcher, + return_values=return_values, + reverse=reverse, + ) + + +def param_groups_weight_decay( + model: nn.Module, weight_decay=1e-5, no_weight_decay_list=() +): + no_weight_decay_list = set(no_weight_decay_list) + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: + no_decay.append(param) + else: + decay.append(param) + + return [ + {"params": no_decay, "weight_decay": 0.0}, + {"params": decay, "weight_decay": weight_decay}, + ] + + +def _group(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def _layer_map(model, layers_per_group=12, num_groups=None): + def _in_head(n, hp): + if not hp: + return True + elif isinstance(hp, (tuple, list)): + return any([n.startswith(hpi) for hpi in hp]) + else: + return n.startswith(hp) + + head_prefix = getattr(model, "pretrained_cfg", {}).get("classifier", None) + names_trunk = [] + names_head = [] + for n, _ in model.named_parameters(): + names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n) + + # group non-head layers + num_trunk_layers = len(names_trunk) + if num_groups is not None: + layers_per_group = -(num_trunk_layers // -num_groups) + names_trunk = list(_group(names_trunk, layers_per_group)) + + num_trunk_groups = len(names_trunk) + layer_map = {n: i for i, l in enumerate(names_trunk) for n in l} + layer_map.update({n: num_trunk_groups for n in names_head}) + return layer_map + + +def param_groups_layer_decay( + model: nn.Module, + weight_decay: float = 0.05, + no_weight_decay_list: Tuple[str] = (), + layer_decay: float = 0.75, + end_layer_decay: Optional[float] = None, + verbose: bool = False, +): + """ + Parameter groups for layer-wise lr decay & weight decay + Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 + """ + no_weight_decay_list = set(no_weight_decay_list) + param_group_names = {} # NOTE for debugging + param_groups = {} + + if hasattr(model, "group_matcher"): + # FIXME interface needs more work + layer_map = group_parameters( + model, model.group_matcher(coarse=False), reverse=True + ) + else: + # fallback + layer_map = _layer_map(model) + num_layers = max(layer_map.values()) + 1 + layer_max = num_layers - 1 + layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers)) + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + # no decay: all 1D parameters and model specific ones + if param.ndim == 1 or 
name in no_weight_decay_list: + g_decay = "no_decay" + this_decay = 0.0 + else: + g_decay = "decay" + this_decay = weight_decay + + layer_id = layer_map.get(name, layer_max) + group_name = "layer_%d_%s" % (layer_id, g_decay) + + if group_name not in param_groups: + this_scale = layer_scales[layer_id] + param_group_names[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "param_names": [], + } + param_groups[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + + param_group_names[group_name]["param_names"].append(name) + param_groups[group_name]["params"].append(param) + + if verbose: + import json + + _logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) + + return list(param_groups.values()) + + +def build_optimizer( + cfg, + model_or_params, + param_group_fn: Optional[Callable] = None, + weight_decay: float = 0.0, + **kwargs, +): + if isinstance(model_or_params, nn.Module): + # a model was passed in, extract parameters and add weight decays to appropriate layers + no_weight_decay = {} + if hasattr(model_or_params, "no_weight_decay"): + no_weight_decay = model_or_params.no_weight_decay() + + if param_group_fn: + parameters = param_group_fn(model_or_params) + elif cfg.get("layer_decay") is not None: + parameters = param_groups_layer_decay( + model_or_params, + weight_decay=cfg.get("weight_decay", 0.0), + layer_decay=cfg.layer_decay, + no_weight_decay_list=no_weight_decay, + verbose=cfg.get("verbose", False), + ) + weight_decay = 0.0 + elif cfg.get("weight_decay", 0.0) and cfg.get("filter_bias_and_bn", True): + parameters = param_groups_weight_decay( + model_or_params, cfg.get("weight_decay", 0.0), no_weight_decay + ) + weight_decay = 0.0 + else: + parameters = model_or_params.parameters() + + opt_args = dict(weight_decay=weight_decay, **kwargs) + + if cfg.get("lr") is not None: + opt_args.setdefault("lr", cfg.lr) + + if cfg.get("foreach") is None: + if cfg.type.lower() in _DEFAULT_FOREACH: + opt_args.setdefault("foreach", True) + else: + opt_args["foreach"] = cfg.foreach + + opt_args["params"] = parameters + opt_args["type"] = cfg.type + + return OPTIMIZERS.build(opt_args) diff --git a/spa/utils/pylogger.py b/spa/utils/pylogger.py new file mode 100755 index 0000000..8cae808 --- /dev/null +++ b/spa/utils/pylogger.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import logging +from collections.abc import Mapping +from typing import Optional + +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only + + +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" + + def __init__( + self, + name: str = __name__, + rank_zero_only: bool = False, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. + + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. 
+ """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only + + def log( + self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs + ) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. + + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. + """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError( + "The `rank_zero_only.rank` needs to be set before use" + ) + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) diff --git a/spa/utils/registry.py b/spa/utils/registry.py new file mode 100755 index 0000000..2b53d76 --- /dev/null +++ b/spa/utils/registry.py @@ -0,0 +1,325 @@ +import inspect +import warnings +from functools import partial + +from omegaconf import DictConfig, OmegaConf + +from .misc import is_seq_of + + +def build_from_cfg(cfg, registry, default_args=None, **kwargs): + """Build a module from configs dict. + + Args: + cfg (dict): Config dict. It should at least contain the key "type". + registry (:obj:`Registry`): The registry to search the type from. + default_args (dict, optional): Default initialization arguments. + + Returns: + object: The constructed object. + """ + if not (isinstance(cfg, dict) or isinstance(cfg, DictConfig)): + raise TypeError(f"cfg must be a dict, but got {type(cfg)}") + if "type" not in cfg: + if default_args is None or "type" not in default_args: + raise KeyError( + '`cfg` or `default_args` must contain the key "type", ' + f"but got {cfg}\n{default_args}" + ) + if not isinstance(registry, Registry): + raise TypeError( + "registry must be an mmcv.Registry object, " f"but got {type(registry)}" + ) + if not ( + isinstance(default_args, dict) + or isinstance(default_args, DictConfig) + or default_args is None + ): + raise TypeError( + "default_args must be a dict or None, " f"but got {type(default_args)}" + ) + + args = cfg.copy() + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + + if kwargs is not None: + for name, value in kwargs.items(): + args.setdefault(name, value) + + obj_type = args.pop("type") + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError(f"{obj_type} is not in the {registry.name} registry") + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}") + try: + return obj_cls(**args) + except Exception as e: + # Normal TypeError does not print class name. + raise type(e)(f"{obj_cls.__name__}: {e}") + + +class Registry: + """A registry to map strings to classes. + + Registered object could be built from registry. 
+ Example: + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + >>> resnet = MODELS.build(dict(type='ResNet')) + + Please refer to + https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for + advanced usage. + + Args: + name (str): Registry name. + build_func(func, optional): Build function to construct instance from + Registry, func:`build_from_cfg` is used if neither ``parent`` or + ``build_func`` is specified. If ``parent`` is specified and + ``build_func`` is not given, ``build_func`` will be inherited + from ``parent``. Default: None. + parent (Registry, optional): Parent registry. The class registered in + children registry could be built from parent. Default: None. + scope (str, optional): The scope of registry. It is the key to search + for children registry. If not specified, scope will be the name of + the package where class is defined, e.g. mmdet, mmcls, mmseg. + Default: None. + """ + + def __init__(self, name, build_func=None, parent=None, scope=None): + self._name = name + self._module_dict = dict() + self._children = dict() + self._scope = self.infer_scope() if scope is None else scope + + # self.build_func will be set with the following priority: + # 1. build_func + # 2. parent.build_func + # 3. build_from_cfg + if build_func is None: + if parent is not None: + self.build_func = parent.build_func + else: + self.build_func = build_from_cfg + else: + self.build_func = build_func + if parent is not None: + assert isinstance(parent, Registry) + parent._add_children(self) + self.parent = parent + else: + self.parent = None + + def __len__(self): + return len(self._module_dict) + + def __contains__(self, key): + return self.get(key) is not None + + def __repr__(self): + format_str = ( + self.__class__.__name__ + f"(name={self._name}, " + f"items={self._module_dict})" + ) + return format_str + + @staticmethod + def infer_scope(): + """Infer the scope of registry. + + The name of the package where registry is defined will be returned. + + Example: + # in mmdet/models/backbone/resnet.py + >>> MODELS = Registry('models') + >>> @MODELS.register_module() + >>> class ResNet: + >>> pass + The scope of ``ResNet`` will be ``mmdet``. + + + Returns: + scope (str): The inferred scope name. + """ + # inspect.stack() trace where this function is called, the index-2 + # indicates the frame where `infer_scope()` is called + filename = inspect.getmodule(inspect.stack()[2][0]).__name__ + split_filename = filename.split(".") + return split_filename[0] + + @staticmethod + def split_scope_key(key): + """Split scope and key. + + The first scope will be split from key. + + Examples: + >>> Registry.split_scope_key('mmdet.ResNet') + 'mmdet', 'ResNet' + >>> Registry.split_scope_key('ResNet') + None, 'ResNet' + + Return: + scope (str, None): The first scope. + key (str): The remaining key. + """ + split_index = key.find(".") + if split_index != -1: + return key[:split_index], key[split_index + 1 :] + else: + return None, key + + @property + def name(self): + return self._name + + @property + def scope(self): + return self._scope + + @property + def module_dict(self): + return self._module_dict + + @property + def children(self): + return self._children + + def get(self, key): + """Get the registry record. + + Args: + key (str): The class name in string format. + + Returns: + class: The corresponding class. 
+ """ + scope, real_key = self.split_scope_key(key) + if scope is None or scope == self._scope: + # get from self + if real_key in self._module_dict: + return self._module_dict[real_key] + else: + # get from self._children + if scope in self._children: + return self._children[scope].get(real_key) + else: + # goto root + parent = self.parent + while parent.parent is not None: + parent = parent.parent + return parent.get(key) + + def build(self, *args, **kwargs): + return self.build_func(*args, **kwargs, registry=self) + + def _add_children(self, registry): + """Add children for a registry. + + The ``registry`` will be added as children based on its scope. + The parent registry could build objects from children registry. + + Example: + >>> models = Registry('models') + >>> mmdet_models = Registry('models', parent=models) + >>> @mmdet_models.register_module() + >>> class ResNet: + >>> pass + >>> resnet = models.build(dict(type='mmdet.ResNet')) + """ + + assert isinstance(registry, Registry) + assert registry.scope is not None + assert ( + registry.scope not in self.children + ), f"scope {registry.scope} exists in {self.name} registry" + self.children[registry.scope] = registry + + def _register_module(self, module_class, module_name=None, force=False): + if not inspect.isclass(module_class): + raise TypeError("module must be a class, " f"but got {type(module_class)}") + + if module_name is None: + module_name = module_class.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in self._module_dict: + raise KeyError(f"{name} is already registered " f"in {self.name}") + self._module_dict[name] = module_class + + def deprecated_register_module(self, cls=None, force=False): + warnings.warn( + "The old API of register_module(module, force=False) " + "is deprecated and will be removed, please use the new API " + "register_module(name=None, force=False, module=None) instead." + ) + if cls is None: + return partial(self.deprecated_register_module, force=force) + self._register_module(cls, force=force) + return cls + + def register_module(self, name=None, force=False, module=None): + """Register a module. + + A record will be added to `self._module_dict`, whose key is the class + name or the specified name, and value is the class itself. + It can be used as a decorator or a normal function. + + Example: + >>> backbones = Registry('backbone') + >>> @backbones.register_module() + >>> class ResNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> @backbones.register_module(name='mnet') + >>> class MobileNet: + >>> pass + + >>> backbones = Registry('backbone') + >>> class ResNet: + >>> pass + >>> backbones.register_module(ResNet) + + Args: + name (str | None): The module name to be registered. If not + specified, the class name will be used. + force (bool, optional): Whether to override an existing class with + the same name. Default: False. + module (type): Module class to be registered. + """ + if not isinstance(force, bool): + raise TypeError(f"force must be a boolean, but got {type(force)}") + # NOTE: This is a walkaround to be compatible with the old api, + # while it may introduce unexpected bugs. 
+ if isinstance(name, type): + return self.deprecated_register_module(name, force=force) + + # raise the error ahead of time + if not (name is None or isinstance(name, str) or is_seq_of(name, str)): + raise TypeError( + "name must be either of None, an instance of str or a sequence" + f" of str, but got {type(name)}" + ) + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + self._register_module(module_class=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(cls): + self._register_module(module_class=cls, module_name=name, force=force) + return cls + + return _register diff --git a/spa/utils/rich_utils.py b/spa/utils/rich_utils.py new file mode 100755 index 0000000..2e2afef --- /dev/null +++ b/spa/utils/rich_utils.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from collections.abc import Sequence +from pathlib import Path + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from spa.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints the contents of a DictConfig as a tree structure using the Rich library. + + :param cfg: A DictConfig composed by Hydra. + :param print_order: Determines in what order config components are printed. Default is ``("data", "model", + "callbacks", "logger", "trainer", "paths", "extras")``. + :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. + :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. + """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + ( + queue.append(field) + if field in cfg + else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." + ) + ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config. + + :param cfg: A DictConfig composed by Hydra. + :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 
+ """ + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/spa/utils/scheduler.py b/spa/utils/scheduler.py new file mode 100755 index 0000000..c48724f --- /dev/null +++ b/spa/utils/scheduler.py @@ -0,0 +1,143 @@ +import torch.optim.lr_scheduler as lr_scheduler +from omegaconf import OmegaConf + +from .registry import Registry + +SCHEDULERS = Registry("schedulers") + + +@SCHEDULERS.register_module() +class MultiStepLR(lr_scheduler.MultiStepLR): + def __init__( + self, + optimizer, + milestones, + total_steps, + gamma=0.1, + last_epoch=-1, + verbose=False, + ): + super().__init__( + optimizer=optimizer, + milestones=[rate * total_steps for rate in milestones], + gamma=gamma, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class MultiStepWithWarmupLR(lr_scheduler.LambdaLR): + def __init__( + self, + optimizer, + milestones, + total_steps, + gamma=0.1, + warmup_rate=0.05, + warmup_scale=1e-6, + last_epoch=-1, + verbose=False, + ): + milestones = [rate * total_steps for rate in milestones] + + def multi_step_with_warmup(s): + factor = 1.0 + for i in range(len(milestones)): + if s < milestones[i]: + break + factor *= gamma + + if s <= warmup_rate * total_steps: + warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * ( + 1 - warmup_scale + ) + else: + warmup_coefficient = 1.0 + return warmup_coefficient * factor + + super().__init__( + optimizer=optimizer, + lr_lambda=multi_step_with_warmup, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class PolyLR(lr_scheduler.LambdaLR): + def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False): + super().__init__( + optimizer=optimizer, + lr_lambda=lambda s: (1 - s / (total_steps + 1)) ** power, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class ExpLR(lr_scheduler.LambdaLR): + def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False): + super().__init__( + optimizer=optimizer, + lr_lambda=lambda s: gamma ** (s / total_steps), + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR): + def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False): + super().__init__( + optimizer=optimizer, + T_max=total_steps, + eta_min=eta_min, + last_epoch=last_epoch, + verbose=verbose, + ) + + +@SCHEDULERS.register_module() +class OneCycleLR(lr_scheduler.OneCycleLR): + r"""Torch.optim.lr_scheduler.OneCycleLR, Block total_steps.""" + + def __init__( + self, + optimizer, + max_lr, + total_steps=None, + pct_start=0.3, + anneal_strategy="cos", + cycle_momentum=True, + base_momentum=0.85, + max_momentum=0.95, + div_factor=25.0, + final_div_factor=1e4, + three_phase=False, + last_epoch=-1, + verbose=False, + ): + if len(optimizer.param_groups) > 1: + max_lr = [max_lr * pg.get("lr_scale", 1.0) for pg in optimizer.param_groups] + + super().__init__( + optimizer=optimizer, + max_lr=max_lr, + total_steps=total_steps, 
+ pct_start=pct_start, + anneal_strategy=anneal_strategy, + cycle_momentum=cycle_momentum, + base_momentum=base_momentum, + max_momentum=max_momentum, + div_factor=div_factor, + final_div_factor=final_div_factor, + three_phase=three_phase, + last_epoch=last_epoch, + verbose=verbose, + ) + + +def build_scheduler(cfg, optimizer): + cfg.optimizer = optimizer + return SCHEDULERS.build(cfg=OmegaConf.to_container(cfg)) diff --git a/spa/utils/transforms.py b/spa/utils/transforms.py new file mode 100644 index 0000000..691fe09 --- /dev/null +++ b/spa/utils/transforms.py @@ -0,0 +1,75 @@ +import numpy as np +import torch + + +def components_from_spherical_harmonics( + levels: int, + directions: torch.Tensor, +) -> torch.Tensor: + """ + Returns value for each component of spherical harmonics. + + Args: + levels: Number of spherical harmonic levels to compute. + directions: Spherical harmonic coefficients + """ + num_components = levels**2 + components = torch.zeros( + (*directions.shape[:-1], num_components), device=directions.device + ) + + assert 1 <= levels <= 5, f"SH levels must be in [1,4], got {levels}" + assert ( + directions.shape[-1] == 3 + ), f"Direction input should have three dimensions. Got {directions.shape[-1]}" + + x = directions[..., 0] + y = directions[..., 1] + z = directions[..., 2] + + xx = x**2 + yy = y**2 + zz = z**2 + + # l0 + components[..., 0] = 0.28209479177387814 + + # l1 + if levels > 1: + components[..., 1] = 0.4886025119029199 * y + components[..., 2] = 0.4886025119029199 * z + components[..., 3] = 0.4886025119029199 * x + + # l2 + if levels > 2: + components[..., 4] = 1.0925484305920792 * x * y + components[..., 5] = 1.0925484305920792 * y * z + components[..., 6] = 0.9461746957575601 * zz - 0.31539156525251999 + components[..., 7] = 1.0925484305920792 * x * z + components[..., 8] = 0.5462742152960396 * (xx - yy) + + # l3 + if levels > 3: + components[..., 9] = 0.5900435899266435 * y * (3 * xx - yy) + components[..., 10] = 2.890611442640554 * x * y * z + components[..., 11] = 0.4570457994644658 * y * (5 * zz - 1) + components[..., 12] = 0.3731763325901154 * z * (5 * zz - 3) + components[..., 13] = 0.4570457994644658 * x * (5 * zz - 1) + components[..., 14] = 1.445305721320277 * z * (xx - yy) + components[..., 15] = 0.5900435899266435 * x * (xx - 3 * yy) + + # l4 + if levels > 4: + components[..., 16] = 2.5033429417967046 * x * y * (xx - yy) + components[..., 17] = 1.7701307697799304 * y * z * (3 * xx - yy) + components[..., 18] = 0.9461746957575601 * x * y * (7 * zz - 1) + components[..., 19] = 0.6690465435572892 * y * z * (7 * zz - 3) + components[..., 20] = 0.10578554691520431 * (35 * zz * zz - 30 * zz + 3) + components[..., 21] = 0.6690465435572892 * x * z * (7 * zz - 3) + components[..., 22] = 0.47308734787878004 * (xx - yy) * (7 * zz - 1) + components[..., 23] = 1.7701307697799304 * x * z * (xx - 3 * yy) + components[..., 24] = 0.6258357354491761 * ( + xx * (xx - 3 * yy) - yy * (3 * xx - yy) + ) + + return components diff --git a/spa/utils/utils.py b/spa/utils/utils.py new file mode 100755 index 0000000..0c9d4a9 --- /dev/null +++ b/spa/utils/utils.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import warnings +from importlib.util import find_spec +from typing import Any, Callable, Dict, Optional, Tuple + +from omegaconf import DictConfig + +from spa.utils import pylogger, rich_utils + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. 
+ + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + + :param cfg: A DictConfig object containing the config tree. + """ + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + ... + return metric_dict, object_dict + ``` + + :param task_func: The task function to be wrapped. + + :return: The wrapped task function. + """ + + def wrap(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or cause out-of-memory errors + # so when using hparam search plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value( + metric_dict: dict[str, Any], metric_name: Optional[str] +) -> Optional[float]: + """Safely retrieves value of the metric logged in LightningModule. + + :param metric_dict: A dict containing metric values. + :param metric_name: If provided, the name of the metric to retrieve. + :return: If a metric name was provided, the value of the metric. + """ + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") + + return metric_value
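
The utilities added above are designed to compose: configs carry a `type` key that is resolved through a `Registry`, and the optimizer/scheduler factories wrap that lookup. Below is a minimal usage sketch, not part of the patch itself; it assumes the package is installed so that `spa.utils.optimizer` and `spa.utils.scheduler` are importable, and the toy model and hyperparameter values are placeholders.

```python
import torch.nn as nn
from omegaconf import OmegaConf

from spa.utils.optimizer import build_optimizer
from spa.utils.scheduler import SCHEDULERS

# Placeholder model; any nn.Module works.
model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 8))

# "AdamW" is resolved via the OPTIMIZERS registry. Because weight_decay is set
# and filter_bias_and_bn defaults to True, biases and norm parameters are put
# into a separate zero-weight-decay group by param_groups_weight_decay.
opt_cfg = OmegaConf.create({"type": "AdamW", "lr": 1e-4, "weight_decay": 0.05})
optimizer = build_optimizer(opt_cfg, model)

# Schedulers can be built straight from the SCHEDULERS registry with a plain
# dict; CosineAnnealingLR here is the registered wrapper, which maps
# total_steps onto T_max.
scheduler = SCHEDULERS.build(
    dict(type="CosineAnnealingLR", optimizer=optimizer, total_steps=1000)
)
```

Passing a config with `layer_decay` set would instead route through `param_groups_layer_decay`, which assigns a per-group `lr_scale`; the registered `OneCycleLR` wrapper then multiplies `max_lr` by each group's `lr_scale`.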