diff --git a/README.md b/README.md
index 295cf95..ef68d6c 100755
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@
[![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/ashleve/lightning-hydra-template#license)
-[**Project Page**](https://haoyizhu.github.io/spa/) | [**Paper**](https://haoyizhu.github.io/spa/static/images/paper.pdf) | [**Arxiv**]() | [**HF Model**](https://huggingface.co/HaoyiZhu/SPA)
+[**Project Page**](https://haoyizhu.github.io/spa/) | [**Paper**](https://haoyizhu.github.io/spa/static/images/paper.pdf) | [**arXiv**](https://arxiv.org/abs/2410.08208) | [**HuggingFace Model**](https://huggingface.co/HaoyiZhu/SPA) | [**Real-World Codebase**](https://github.com/HaoyiZhu/RealRobot)
[Haoyi Zhu](https://www.haoyizhu.site/), [Honghui Yang](https://hhyangcs.github.io/), [Yating Wang](https://scholar.google.com/citations?hl=zh-CN&user=5SuBWh0AAAAJ), [Jiange Yang](https://yangjiangeyjg.github.io/), [Limin Wang](https://wanglimin.github.io/), [Tong He](http://tonghe90.github.io/)
@@ -20,13 +20,352 @@
**SPA** is a novel representation learning framework that emphasizes the importance of **3D spatial awareness in embodied AI**. It leverages **differentiable neural rendering** on multi-view images to endow a vanilla Vision Transformer (ViT) with intrinsic spatial understanding. We also present the most comprehensive evaluation of embodied representation learning to date, covering **268 tasks** across **8 simulators** with diverse policies in both single-task and language-conditioned multi-task scenarios.
+:partying_face: **NEWS**:
+
+- *Oct. 2024:* Codebase and pre-trained checkpoints are released! Paper is available on [arXiv](https://arxiv.org/abs/2410.08208).
+
+## :clipboard: Contents
+
+- [Project Structure](#telescope-project-structure)
+- [Installation](#hammer-installation)
+- [Usage](#star2-usage)
+- [Pre-Training](#rocket-pre-training)
+- [SPA Large-Scale Evaluation](#bulb-spa-large-scale-evaluation)
+- [Gotchas](#tada-gotchas)
+- [License](#books-license)
+- [Acknowledgement](#sparkles-acknowledgement)
+- [Citation](#pencil-citation)
+
+## :telescope: Project Structure
+
+Our codebase draws significant inspiration from the excellent [Lightning Hydra Template](https://github.com/ashleve/lightning-hydra-template). The directory structure of this project is organized as follows:
+
+```
+├── .github <- Github Actions workflows
+│
+├── configs <- Hydra configs
+│ ├── callbacks <- Callbacks configs
+│ ├── data <- Data configs
+│ ├── debug <- Debugging configs
+│ ├── experiment <- Experiment configs
+│ ├── extras <- Extra utilities configs
+│ ├── hydra <- Hydra configs
+│ ├── local <- Local configs
+│ ├── logger <- Logger configs
+│ ├── model <- Model configs
+│ ├── paths <- Project paths configs
+│ ├── trainer <- Trainer configs
+| |
+│ └── train.yaml <- Main config for training
+│
+├── data <- Project data
+│
+├── logs <- Logs generated by hydra and lightning loggers
+│
+├── scripts <- Shell or Python scripts
+|
+├── spa <- Source code of SPA
+│ ├── data <- Data scripts
+│ ├── models <- Model scripts
+│ ├── utils <- Utility scripts
+│ │
+│ └── train.py <- Run SPA pre-training
+│
+├── .gitignore <- List of files ignored by git
+├── .project-root <- File for inferring the position of project root directory
+├── requirements.txt <- File for installing python dependencies
+├── setup.py <- File for installing project as a package
+└── README.md
+```
+
+
+
+## :hammer: Installation
+
+**Basics**
+
+```bash
+# clone project
+git clone https://github.com/HaoyiZhu/SPA.git
+cd SPA
+
+# create conda environment
+conda create -n spa python=3.11 -y
+conda activate spa
+
+# install PyTorch, please refer to https://pytorch.org/ for other CUDA versions
+# e.g. cuda 11.8:
+pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+# install basic packages
+pip3 install -r requirements.txt
+```
+
+
+
+**SPA**
+
+```bash
+# (optional) if you want to use SPA's volume decoder
+cd libs/spa-ops
+pip install -e .
+cd ../..
+
+# install SPA, so that you can import from anywhere
+pip install -e .
+```
+
+
+## :star2: Usage
+
+
+**Example of Using SPA Pre-trained Encoder**
+
+We provide pre-trained SPA weights for feature extraction. The checkpoints are available on [🤗Hugging Face](https://huggingface.co/HaoyiZhu/SPA). You don't need to manually download the weights, as SPA will automatically handle this if needed.
+
+```python
+import torch
+
+from spa.models import spa_vit_base_patch16, spa_vit_large_patch16
+
+image = torch.rand((1, 3, 224, 224)) # range in [0, 1]
+
+# Example usage of SPA-Large (recommended)
+# or you can use `spa_vit_base_patch16` for SPA-base
+model = spa_vit_large_patch16(pretrained=True)
+model.eval()
+
+# Freeze the model
+model.freeze()
+
+# (Recommended) move to CUDA
+image = image.cuda()
+model = model.cuda()
+
+# Obtain the [CLS] token
+cls_token = model(image) # torch.Size([1, 1024])
+
+# Obtain the reshaped feature map concatenated with [CLS] token
+feature_map_cat_cls = model(
+ image, feature_map=True, cat_cls=True
+) # torch.Size([1, 2048, 14, 14])
+
+# Obtain the reshaped feature map without [CLS] token
+feature_map_wo_cls = model(
+ image, feature_map=True, cat_cls=False
+) # torch.Size([1, 1024, 14, 14])
+```
+
+> **Note:** The inputs will be automatically resized to `224 x 224` and normalized within the [SPA ViT encoder](spa/models/components/img_backbones/vit.py#L69).
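+
+As a downstream usage sketch, the frozen feature map can be fed into a task-specific head. The snippet below is illustrative only: the pooling/linear `policy_head` and the 7-dimensional action are our own assumptions, not part of SPA; only the encoder calls follow the API shown above.
+
+```python
+import torch
+import torch.nn as nn
+
+from spa.models import spa_vit_large_patch16
+
+encoder = spa_vit_large_patch16(pretrained=True)
+encoder.eval()
+encoder.freeze()
+
+# Hypothetical policy head, for illustration only (not part of SPA).
+policy_head = nn.Sequential(
+    nn.AdaptiveAvgPool2d(1),  # pool the 14 x 14 feature map
+    nn.Flatten(),
+    nn.Linear(1024, 7),  # e.g. a 7-DoF action
+)
+
+obs = torch.rand(1, 3, 224, 224)  # RGB observation, range in [0, 1]
+with torch.no_grad():
+    feat = encoder(obs, feature_map=True, cat_cls=False)  # torch.Size([1, 1024, 14, 14])
+action = policy_head(feat)  # torch.Size([1, 7])
+```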
+
+
+
+
+
+## :rocket: Pre-Training
+
+
+**Example of Pre-Training on ScanNet**
+
+We provide an example of pre-training SPA on the [ScanNet](http://www.scan-net.org/) v2 dataset.
+
+1) Prepare the dataset
+ - Download the [ScanNet](http://www.scan-net.org/) v2 dataset.
+ - Pre-process and extract RGB-D images following [PonderV2](https://github.com/OpenGVLab/PonderV2/blob/main/docs/data_preparation.md#scannet-v2). The preprocessed data should be put under `data/scannet/`.
+ - Pre-generate metadata for fast data loading. The following command will generate metadata under `data/scannet/metadata`.
+ ```bash
+ python scripts/generate_scannet_metadata.py
+ ```
+
+2) Run the following command for pre-training. Remember to adjust hyper-parameters such as the number of nodes and GPU devices to match your machine.
+ ```bash
+ python spa/train.py experiment=spa_pretrain_vitl trainer.num_nodes=5 trainer.devices=8
+ ```
+
+
+
+## :bulb: SPA Large-Scale Evaluation
+
+
+ TBD
+
+
+
+## :tada: Gotchas
+
+
+**Override any config parameter from the command line**
+
+This codebase is based on [Hydra](https://github.com/facebookresearch/hydra), which allows for convenient configuration overriding:
+```bash
+python spa/train.py trainer.max_epochs=20 seed=300
+```
+> **Note**: You can also add new parameters with the `+` sign.
+```bash
+python spa/train.py +some_new_param=some_new_value
+```
+
+
+
+
+**Train on CPU, GPU, multi-GPU, and TPU**
+
+```bash
+# train on CPU
+python spa/train.py trainer=cpu
+
+# train on 1 GPU
+python spa/train.py trainer=gpu
+
+# train on TPU
+python spa/train.py +trainer.tpu_cores=8
+
+# train with DDP (Distributed Data Parallel) (4 GPUs)
+python spa/train.py trainer=ddp trainer.devices=4
+
+# train with DDP (Distributed Data Parallel) (8 GPUs, 2 nodes)
+python spa/train.py trainer=ddp trainer.devices=4 trainer.num_nodes=2
+
+# simulate DDP on CPU processes
+python spa/train.py trainer=ddp_sim trainer.devices=2
+
+# accelerate training on Mac
+python spa/train.py trainer=mps
+```
+
+
+
+
+**Train with mixed precision**
+
+```bash
+# train with pytorch native automatic mixed precision (AMP)
+python spa/train.py trainer=gpu +trainer.precision=16
+```
+
+
+
+
+**Use different tricks available in PyTorch Lightning**
+
+```bash
+# gradient clipping may be enabled to avoid exploding gradients
+python spa/train.py trainer.gradient_clip_val=0.5
+
+# run validation loop 4 times during a training epoch
+python spa/train.py +trainer.val_check_interval=0.25
+
+# accumulate gradients
+python spa/train.py trainer.accumulate_grad_batches=10
+
+# terminate training after 12 hours
+python spa/train.py +trainer.max_time="00:12:00:00"
+```
+
+> **Note**: PyTorch Lightning provides [40+ useful trainer flags](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags).
+
+
+
+
+**Easily debug**
+
+```bash
+# runs 1 epoch in default debugging mode
+# changes logging directory to `logs/debugs/...`
+# sets level of all command line loggers to 'DEBUG'
+# enforces debug-friendly configuration
+python spa/train.py debug=default
+
+# run 1 train, val and test loop, using only 1 batch
+python spa/train.py debug=fdr
+
+# print execution time profiling
+python spa/train.py debug=profiler
+
+# try overfitting to a few batches
+python spa/train.py debug=overfit
+
+# raise exception if there are any numerical anomalies in tensors, like NaN or +/-inf
+python spa/train.py +trainer.detect_anomaly=true
+
+# use only 20% of the data
+python spa/train.py +trainer.limit_train_batches=0.2 \
++trainer.limit_val_batches=0.2 +trainer.limit_test_batches=0.2
+```
+
+> **Note**: Visit [configs/debug/](configs/debug/) for different debugging configs.
+
+
+
+
+**Resume training from checkpoint**
+
+```bash
+python spa/train.py ckpt_path="/path/to/ckpt/name.ckpt"
+```
+
+> **Note**: The checkpoint can be either a path or a URL.
+
+> **Note**: Currently, loading a checkpoint doesn't resume the logger experiment, but this is expected to be supported in a future Lightning release.
+
+
+
+
+**Create a sweep over hyperparameters**
+
+```bash
+# this will run 9 experiments one after the other,
+# each with a different combination of seed and learning rate
+python spa/train.py -m seed=100,200,300 model.optimizer.lr=0.0001,0.00005,0.00001
+```
+
+> **Note**: Hydra composes configs lazily at job launch time. If you change code or configs after launching a job/sweep, the final composed configs might be impacted.
+
+
+
+
+**Execute all experiments from a folder**
+
+```bash
+python spa/train.py -m 'exp_maniskill2_act_policy/maniskill2_task@maniskill2_task=glob(*)'
+```
+
+> **Note**: Hydra provides special syntax for controlling behavior of multiruns. Learn more [here](https://hydra.cc/docs/next/tutorials/basic/running_your_app/multi-run). The command above executes all task experiments from [configs/exp_maniskill2_act_policy/maniskill2_task](configs/experiment/).
+
+
+
+
+**Execute a run for multiple different seeds**
+
+```bash
+python spa/train.py -m seed=100,200,300 trainer.deterministic=True
+```
+
+> **Note**: `trainer.deterministic=True` makes PyTorch more deterministic but impacts performance.
+
+
+
+For more instructions, refer to the official documentation for [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning), [Hydra](https://github.com/facebookresearch/hydra), and [Lightning Hydra Template](https://github.com/ashleve/lightning-hydra-template).
+
+## :books: License
+
+This repository is released under the [MIT license](LICENSE).
+
+## :sparkles: Acknowledgement
+
+Our work is primarily built upon [PonderV2](https://github.com/OpenGVLab/PonderV2), [UniPAD](https://github.com/Nightmare-n/UniPAD), [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning), [Hydra](https://github.com/facebookresearch/hydra), [Lightning Hydra Template](https://github.com/ashleve/lightning-hydra-template), [RLBench](https://github.com/stepjam/RLBench), [PerAct](https://github.com/peract/peract), [LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO), [Meta-World](https://github.com/Farama-Foundation/Metaworld), [ACT](https://github.com/tonyzhaozh/act), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [DP3](https://github.com/YanjieZe/3D-Diffusion-Policy), [TIMM](https://github.com/huggingface/pytorch-image-models), [VC1](https://github.com/facebookresearch/eai-vc), and [R3M](https://github.com/facebookresearch/r3m). We extend our gratitude to all these authors for generously open-sourcing their code and for their significant contributions to the community.
+
+Contact [Haoyi Zhu](https://www.haoyizhu.site/) if you have any questions or suggestions.
+
## :pencil: Citation
```bib
@article{zhu2024spa,
title = {SPA: 3D Spatial-Awareness Enables Effective Embodied Representation},
  author = {Zhu, Haoyi and Yang, Honghui and Wang, Yating and Yang, Jiange and Wang, Limin and He, Tong},
- journal = {arXiv preprint},
+  journal = {arXiv preprint arXiv:2410.08208},
year = {2024},
}
```
\ No newline at end of file
diff --git a/configs/__init__.py b/configs/__init__.py
new file mode 100755
index 0000000..56bf7f4
--- /dev/null
+++ b/configs/__init__.py
@@ -0,0 +1 @@
+# this file is needed here to include configs when building project as a package
diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml
new file mode 100755
index 0000000..628fa70
--- /dev/null
+++ b/configs/callbacks/default.yaml
@@ -0,0 +1,26 @@
+defaults:
+ - model_checkpoint
+ - early_stopping
+ - model_summary
+ - rich_progress_bar
+ - lr_monitor
+ - device_stats_monitor
+ # - stochastic_weight_averaging
+ - _self_
+
+model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/loss"
+ mode: "min"
+ save_last: true
+ auto_insert_metric_name: false
+ save_top_k: 3
+
+early_stopping:
+ monitor: "val/loss"
+ patience: 100
+ mode: "min"
+
+model_summary:
+ max_depth: -1
diff --git a/configs/callbacks/default_ema.yaml b/configs/callbacks/default_ema.yaml
new file mode 100644
index 0000000..8908323
--- /dev/null
+++ b/configs/callbacks/default_ema.yaml
@@ -0,0 +1,32 @@
+defaults:
+ - ema_model_checkpoint
+ - early_stopping
+ - model_summary
+ - rich_progress_bar
+ - lr_monitor
+ - device_stats_monitor
+ - ema
+ - _self_
+
+model_checkpoint:
+ dirpath: ${paths.output_dir}/checkpoints
+ filename: "epoch_{epoch:03d}"
+ monitor: "val/loss"
+ mode: "min"
+ save_last: true
+ auto_insert_metric_name: false
+ save_top_k: 3
+
+early_stopping:
+ monitor: "val/loss"
+ patience: 100
+ mode: "min"
+
+model_summary:
+ max_depth: -1
+
+ema:
+ decay: 0.999
+ validate_original_weights: false
+ every_n_steps: 1
+ cpu_offload: true
diff --git a/configs/callbacks/device_stats_monitor.yaml b/configs/callbacks/device_stats_monitor.yaml
new file mode 100644
index 0000000..ee34f4b
--- /dev/null
+++ b/configs/callbacks/device_stats_monitor.yaml
@@ -0,0 +1,3 @@
+device_stat_monitor:
+ _target_: lightning.pytorch.callbacks.DeviceStatsMonitor
+ cpu_stats: null
\ No newline at end of file
diff --git a/configs/callbacks/early_stopping.yaml b/configs/callbacks/early_stopping.yaml
new file mode 100755
index 0000000..f4c90e0
--- /dev/null
+++ b/configs/callbacks/early_stopping.yaml
@@ -0,0 +1,15 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
+
+early_stopping:
+ _target_: lightning.pytorch.callbacks.EarlyStopping
+ monitor: ??? # quantity to be monitored, must be specified !!!
+ min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
+ patience: 3 # number of checks with no improvement after which training will be stopped
+ verbose: false # verbosity mode
+ mode: "min" # "max" means higher metric value is better, can be also "min"
+ strict: true # whether to crash the training if monitor is not found in the validation metrics
+ check_finite: true # when set True, stops training when the monitor becomes NaN or infinite
+ stopping_threshold: # stop training immediately once the monitored quantity reaches this threshold
+ divergence_threshold: # stop training as soon as the monitored quantity becomes worse than this threshold
+ check_on_train_epoch_end: # whether to run early stopping at the end of the training epoch
+ # log_rank_zero_only: False # this keyword argument isn't available in stable version
diff --git a/configs/callbacks/ema.yaml b/configs/callbacks/ema.yaml
new file mode 100644
index 0000000..02450c7
--- /dev/null
+++ b/configs/callbacks/ema.yaml
@@ -0,0 +1,6 @@
+ema:
+ _target_: spa.utils.callbacks.EMA
+ decay: 0.999
+ validate_original_weights: false
+ every_n_steps: 1
+ cpu_offload: true
\ No newline at end of file
diff --git a/configs/callbacks/ema_model_checkpoint.yaml b/configs/callbacks/ema_model_checkpoint.yaml
new file mode 100755
index 0000000..e1a6360
--- /dev/null
+++ b/configs/callbacks/ema_model_checkpoint.yaml
@@ -0,0 +1,17 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
+
+model_checkpoint:
+ _target_: spa.utils.callbacks.EMAModelCheckpoint
+ dirpath: # directory to save the model file
+ filename: # checkpoint filename
+ monitor: # name of the logged metric which determines when model is improving
+ verbose: false # verbosity mode
+ save_last: # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+ save_top_k: 5 # save k best models (determined by above metric)
+ mode: "min" # "max" means higher metric value is better, can be also "min"
+ auto_insert_metric_name: true # when True, the checkpoints filenames will contain the metric name
+ save_weights_only: false # if True, then only the model’s weights will be saved
+ every_n_train_steps: # number of training steps between checkpoints
+ train_time_interval: # checkpoints are monitored at the specified time interval
+ every_n_epochs: # number of epochs between checkpoints
+ save_on_train_epoch_end: # whether to run checkpointing at the end of the training epoch or the end of validation
diff --git a/configs/callbacks/lr_monitor.yaml b/configs/callbacks/lr_monitor.yaml
new file mode 100755
index 0000000..1527b90
--- /dev/null
+++ b/configs/callbacks/lr_monitor.yaml
@@ -0,0 +1,5 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html#lightning.pytorch.callbacks.LearningRateMonitor
+
+lr_monitor:
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
+ logging_interval: 'step'
diff --git a/configs/callbacks/model_checkpoint.yaml b/configs/callbacks/model_checkpoint.yaml
new file mode 100755
index 0000000..0a5a053
--- /dev/null
+++ b/configs/callbacks/model_checkpoint.yaml
@@ -0,0 +1,17 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
+
+model_checkpoint:
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
+ dirpath: # directory to save the model file
+ filename: # checkpoint filename
+ monitor: # name of the logged metric which determines when model is improving
+ verbose: false # verbosity mode
+ save_last: # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+ save_top_k: 10 # save k best models (determined by above metric)
+ mode: "min" # "max" means higher metric value is better, can be also "min"
+ auto_insert_metric_name: true # when True, the checkpoints filenames will contain the metric name
+ save_weights_only: false # if True, then only the model’s weights will be saved
+ every_n_train_steps: # number of training steps between checkpoints
+ train_time_interval: # checkpoints are monitored at the specified time interval
+ every_n_epochs: # number of epochs between checkpoints
+ save_on_train_epoch_end: # whether to run checkpointing at the end of the training epoch or the end of validation
diff --git a/configs/callbacks/model_summary.yaml b/configs/callbacks/model_summary.yaml
new file mode 100755
index 0000000..b75981d
--- /dev/null
+++ b/configs/callbacks/model_summary.yaml
@@ -0,0 +1,5 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
+
+model_summary:
+ _target_: lightning.pytorch.callbacks.RichModelSummary
+ max_depth: 1 # the maximum depth of layer nesting that the summary will include
diff --git a/configs/callbacks/none.yaml b/configs/callbacks/none.yaml
new file mode 100755
index 0000000..e69de29
diff --git a/configs/callbacks/rich_progress_bar.yaml b/configs/callbacks/rich_progress_bar.yaml
new file mode 100755
index 0000000..de6f1cc
--- /dev/null
+++ b/configs/callbacks/rich_progress_bar.yaml
@@ -0,0 +1,4 @@
+# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
+
+rich_progress_bar:
+ _target_: lightning.pytorch.callbacks.RichProgressBar
diff --git a/configs/callbacks/stochastic_weight_averaging.yaml b/configs/callbacks/stochastic_weight_averaging.yaml
new file mode 100755
index 0000000..a28826c
--- /dev/null
+++ b/configs/callbacks/stochastic_weight_averaging.yaml
@@ -0,0 +1,5 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.StochasticWeightAveraging.html#stochasticweightaveraging
+
+stochastic_weight_averaging:
+ _target_: lightning.pytorch.callbacks.StochasticWeightAveraging
+ swa_lrs: 0.05
diff --git a/configs/data/scannet.yaml b/configs/data/scannet.yaml
new file mode 100644
index 0000000..b5b82d1
--- /dev/null
+++ b/configs/data/scannet.yaml
@@ -0,0 +1,100 @@
+_target_: spa.data.base_cat_datamodule.BaseCatDataModule
+
+batch_size_train: 4 # per device
+batch_size_val: 1
+batch_size_test: 1
+num_workers: 16
+pin_memory: true
+
+data_scannet_processor_cfg:
+ enabled_proc_list:
+ train:
+ - imresize
+ - imcrop
+ - trans_to_local
+ - merge_trans_matrix
+ - filter_depth_outlier_old
+ - calc_ray_from_depth_v2
+ - calc_scene_bbox
+ - calc_voxel_size
+ - sample_ray
+ - collect
+ test:
+ - imresize
+ - imcrop
+ - trans_to_local
+ - merge_trans_matrix
+ - filter_depth_outlier_old
+ - calc_ray_from_depth_v2
+ - calc_scene_bbox
+ - calc_voxel_size
+ - sample_ray
+ - collect
+ proc_config:
+ imresize:
+ mv_consistency: true
+ extra_keys: ["depth"]
+ resize_scale:
+ train: [0.47, 0.5]
+ test: [0.47, 0.47]
+ imcrop: # (480, 640)
+ mv_consistency: true
+ extra_keys: ["depth"]
+ crop_size: [224, 224]
+ crop_pos:
+ train: [[0.0, 1.0], [0.0, 1.0]]
+ test: [[0.5, 0.5], [0.5, 0.5]]
+ trans_to_local:
+ mapping_keys: {}
+ to_local_key: ['world2cam', True]
+ merge_trans_matrix:
+ keys:
+ trans2d_matrix: [['cam2img', False], ]
+ trans3d_matrix: [['world2cam', True], ]
+ trans_normal:
+ key: ['world2cam', True]
+ filter_depth_outlier_old:
+ percentile: 0.05
+ calc_ray_from_depth_v2:
+ placeholder: null
+ calc_scene_bbox:
+ type: dynamic_depth
+ calc_voxel_size:
+ grid_size: [128, 128, 32]
+ sample_ray:
+ collider:
+ type: AABBBoxColliderNp
+ near_plane: 0.1
+ ray_nsample: 512
+ collect:
+ keys:
+ train: ['img', 'dataset_name', 'semantic_img', 'depth', 'ori_shape', 'world2cam', 'cam2img', 'trans2d_matrix', 'scene_name', 'frame_list', 'ray_depth', 'ray_o', 'ray_d', 'ray_p', 'ray_near', 'ray_far', 'ray_idx', 'ray_scale', 'point_cloud_range', 'voxel_size', 'grid_size']
+ test: ['img', 'dataset_name', 'semantic_img', 'depth', 'ori_shape', 'world2cam', 'cam2img', 'trans2d_matrix', 'scene_name', 'frame_list', 'ray_depth', 'ray_o', 'ray_d', 'ray_p', 'ray_near', 'ray_far', 'ray_idx', 'ray_scale', 'point_cloud_range', 'voxel_size', 'grid_size']
+
+train:
+ - _target_: spa.data.components.scannet.ScanNetMultiViewSPAPretrain
+ split: train
+ scene_root: data/scannet/
+ frame_interval: 20
+ downsample_ratio: 0.5
+ num_cameras: 8
+ loop: 3
+ data_processor_cfg: ${data.data_scannet_processor_cfg}
+ batch_max_num_img: 24
+ mode: train
+ scene_box_threshold: 0.05
+ depth_area_threshold: 0.1
+ semantic_size: [1024, 1024]
+
+val:
+ _target_: spa.data.components.scannet.ScanNetMultiViewSPAPretrain
+ split: val
+ scene_root: data/scannet/
+ frame_interval: 20
+ num_cameras: 8
+ loop: 1
+ downsample_ratio: 1
+ data_processor_cfg: ${data.data_scannet_processor_cfg}
+ batch_max_num_img: 24
+ mode: test
+ semantic_size: [1024, 1024]
diff --git a/configs/debug/default.yaml b/configs/debug/default.yaml
new file mode 100755
index 0000000..b90fa15
--- /dev/null
+++ b/configs/debug/default.yaml
@@ -0,0 +1,35 @@
+# @package _global_
+
+# default debugging setup, runs 1 full epoch
+# other debugging configs can inherit from this one
+
+# overwrite task name so debugging logs are stored in separate folder
+task_name: "debug"
+
+# disable callbacks and loggers during debugging
+callbacks:
+logger:
+
+extras:
+ ignore_warnings: false
+ enforce_tags: false
+
+# sets level of all command line loggers to 'DEBUG'
+# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
+hydra:
+ job_logging:
+ root:
+ level: DEBUG
+
+ # use this to also set hydra loggers to 'DEBUG'
+ # verbose: True
+
+trainer:
+ max_epochs: 1
+ accelerator: cpu # debuggers don't like gpus
+ devices: 1 # debuggers don't like multiprocessing
+ detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
+
+data:
+ num_workers: 0 # debuggers don't like multiprocessing
+ pin_memory: false # disable gpu memory pin
diff --git a/configs/debug/fdr.yaml b/configs/debug/fdr.yaml
new file mode 100755
index 0000000..7f2d34f
--- /dev/null
+++ b/configs/debug/fdr.yaml
@@ -0,0 +1,9 @@
+# @package _global_
+
+# runs 1 train, 1 validation and 1 test step
+
+defaults:
+ - default
+
+trainer:
+ fast_dev_run: true
diff --git a/configs/debug/limit.yaml b/configs/debug/limit.yaml
new file mode 100755
index 0000000..514d77f
--- /dev/null
+++ b/configs/debug/limit.yaml
@@ -0,0 +1,12 @@
+# @package _global_
+
+# uses only 1% of the training data and 5% of validation/test data
+
+defaults:
+ - default
+
+trainer:
+ max_epochs: 3
+ limit_train_batches: 0.01
+ limit_val_batches: 0.05
+ limit_test_batches: 0.05
diff --git a/configs/debug/overfit.yaml b/configs/debug/overfit.yaml
new file mode 100755
index 0000000..22a6b5b
--- /dev/null
+++ b/configs/debug/overfit.yaml
@@ -0,0 +1,13 @@
+# @package _global_
+
+# overfits to 3 batches
+
+defaults:
+ - default
+
+trainer:
+ max_epochs: 20
+ overfit_batches: 3
+
+# model ckpt and early stopping need to be disabled during overfitting
+callbacks:
diff --git a/configs/debug/profiler.yaml b/configs/debug/profiler.yaml
new file mode 100755
index 0000000..2bd7da8
--- /dev/null
+++ b/configs/debug/profiler.yaml
@@ -0,0 +1,12 @@
+# @package _global_
+
+# runs with execution time profiling
+
+defaults:
+ - default
+
+trainer:
+ max_epochs: 1
+ profiler: "simple"
+ # profiler: "advanced"
+ # profiler: "pytorch"
diff --git a/configs/experiment/spa_pretrain_vitb.yaml b/configs/experiment/spa_pretrain_vitb.yaml
new file mode 100755
index 0000000..4063a38
--- /dev/null
+++ b/configs/experiment/spa_pretrain_vitb.yaml
@@ -0,0 +1,70 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /data: scannet
+ - override /trainer: ddp
+ - override /model: spa_pretrain
+ - override /callbacks: default_ema
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+task_name: "spa_pretrain_vitb"
+
+default_ema:
+ ema:
+ decay: 0.999
+ validate_original_weights: false
+ every_n_steps: 1
+ cpu_offload: true
+
+trainer:
+ devices: 8
+ num_nodes: 1
+ max_epochs: 2000
+ check_val_every_n_epoch: 20
+ accelerator: gpu
+ strategy: auto
+ precision: 16-mixed
+ gradient_clip_val: 1.0
+ accumulate_grad_batches: 8
+ sync_batchnorm: false
+ limit_val_batches: 2
+ use_distributed_sampler: false
+
+lr_scale: 1.0
+
+data:
+ batch_size_train: 2
+ num_workers: 16
+
+model:
+ optimizer:
+ type: AdamW
+ lr: ${eval:'0.00001 * ${data.batch_size_train} * ${trainer.accumulate_grad_batches} * ${trainer.devices} * ${trainer.num_nodes} / 8'}
+ weight_decay: 0.04
+ betas: [0.9, 0.95]
+ layer_decay: 0.8
+ verbose: true
+ param_dicts: null
+ lr_scheduler:
+ scheduler:
+ type: OneCycleLR
+ max_lr: ${model.optimizer.lr}
+ pct_start: 0.05
+ anneal_strategy: cos
+ div_factor: 100.0
+ final_div_factor: 1000.0
+ # monitor: val/loss
+ interval: step
+ frequency: 1
+ model:
+ ckpt_name: spa-b
+ img_backbone:
+ embed_dim: 768
+ depth: 12
+ num_heads: 12
+
\ No newline at end of file
diff --git a/configs/experiment/spa_pretrain_vitl.yaml b/configs/experiment/spa_pretrain_vitl.yaml
new file mode 100755
index 0000000..9f6a4d4
--- /dev/null
+++ b/configs/experiment/spa_pretrain_vitl.yaml
@@ -0,0 +1,63 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+ - override /data: scannet
+ - override /trainer: ddp
+ - override /model: spa_pretrain
+ - override /callbacks: default_ema
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+task_name: "spa_pretrain_vitl"
+
+default_ema:
+ ema:
+ decay: 0.999
+ validate_original_weights: false
+ every_n_steps: 1
+ cpu_offload: true
+
+trainer:
+ devices: 8
+ num_nodes: 1
+ max_epochs: 2000
+ check_val_every_n_epoch: 20
+ accelerator: gpu
+ strategy: auto
+ precision: 16-mixed
+ gradient_clip_val: 1.0
+ accumulate_grad_batches: 8
+ sync_batchnorm: false
+ limit_val_batches: 2
+ use_distributed_sampler: false
+
+lr_scale: 1.0
+
+data:
+ batch_size_train: 2
+ num_workers: 16
+
+model:
+ optimizer:
+ type: AdamW
+ lr: ${eval:'0.000005 * ${data.batch_size_train} * ${trainer.accumulate_grad_batches} * ${trainer.devices} * ${trainer.num_nodes} / 8'}
+ weight_decay: 0.04
+ betas: [0.9, 0.95]
+ layer_decay: 0.8
+ verbose: true
+ param_dicts: null
+ lr_scheduler:
+ scheduler:
+ type: OneCycleLR
+ max_lr: ${model.optimizer.lr}
+ pct_start: 0.05
+ anneal_strategy: cos
+ div_factor: 100.0
+ final_div_factor: 1000.0
+ # monitor: val/loss
+ interval: step
+ frequency: 1
\ No newline at end of file
diff --git a/configs/extras/default.yaml b/configs/extras/default.yaml
new file mode 100755
index 0000000..7e94b39
--- /dev/null
+++ b/configs/extras/default.yaml
@@ -0,0 +1,8 @@
+# disable python warnings if they annoy you
+ignore_warnings: false
+
+# ask user for tags if none are provided in the config
+enforce_tags: true
+
+# pretty print config tree at the start of the run using Rich library
+print_config: true
diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml
new file mode 100755
index 0000000..a61e9b3
--- /dev/null
+++ b/configs/hydra/default.yaml
@@ -0,0 +1,19 @@
+# https://hydra.cc/docs/configure_hydra/intro/
+
+# enable color logging
+defaults:
+ - override hydra_logging: colorlog
+ - override job_logging: colorlog
+
+# output directory, generated dynamically on each run
+run:
+ dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
+sweep:
+ dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
+ subdir: ${hydra.job.num}
+
+job_logging:
+ handlers:
+ file:
+ # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
+ filename: ${hydra.runtime.output_dir}/${task_name}.log
diff --git a/configs/local/.gitkeep b/configs/local/.gitkeep
new file mode 100755
index 0000000..e69de29
diff --git a/configs/logger/aim.yaml b/configs/logger/aim.yaml
new file mode 100755
index 0000000..2abfa27
--- /dev/null
+++ b/configs/logger/aim.yaml
@@ -0,0 +1,28 @@
+# https://aimstack.io/
+
+# example usage in lightning module:
+# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py
+
+# open the Aim UI with the following command (run in the folder containing the `.aim` folder):
+# `aim up`
+
+aim:
+ _target_: aim.pytorch_lightning.AimLogger
+ repo: ${paths.root_dir} # .aim folder will be created here
+ # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html#
+
+ # aim allows to group runs under experiment name
+ experiment: # any string, set to "default" if not specified
+
+ train_metric_prefix: "train/"
+ val_metric_prefix: "val/"
+ test_metric_prefix: "test/"
+
+ # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.)
+ system_tracking_interval: 10 # set to null to disable system metrics tracking
+
+ # enable/disable logging of system params such as installed packages, git info, env vars, etc.
+ log_system_params: true
+
+ # enable/disable tracking console logs (default value is true)
+ capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550
diff --git a/configs/logger/comet.yaml b/configs/logger/comet.yaml
new file mode 100755
index 0000000..a669d67
--- /dev/null
+++ b/configs/logger/comet.yaml
@@ -0,0 +1,12 @@
+# https://www.comet.ml
+
+comet:
+ _target_: lightning.pytorch.loggers.comet.CometLogger
+ api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
+ save_dir: "${paths.output_dir}"
+ project_name: "lightning-hydra-template"
+ rest_api_key:
+ # experiment_name: ""
+ experiment_key: # set to resume experiment
+ offline: false
+ prefix: ""
diff --git a/configs/logger/csv.yaml b/configs/logger/csv.yaml
new file mode 100755
index 0000000..fa028e9
--- /dev/null
+++ b/configs/logger/csv.yaml
@@ -0,0 +1,7 @@
+# csv logger built in lightning
+
+csv:
+ _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
+ save_dir: "${paths.output_dir}"
+ name: "csv/"
+ prefix: ""
diff --git a/configs/logger/many_loggers.yaml b/configs/logger/many_loggers.yaml
new file mode 100755
index 0000000..934ff79
--- /dev/null
+++ b/configs/logger/many_loggers.yaml
@@ -0,0 +1,9 @@
+# train with many loggers at once
+
+defaults:
+ # - comet.yaml
+ - csv
+ # - mlflow.yaml
+ # - neptune.yaml
+ - tensorboard
+ - wandb
diff --git a/configs/logger/mlflow.yaml b/configs/logger/mlflow.yaml
new file mode 100755
index 0000000..ed2535a
--- /dev/null
+++ b/configs/logger/mlflow.yaml
@@ -0,0 +1,12 @@
+# https://mlflow.org
+
+mlflow:
+ _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger
+ # experiment_name: ""
+ # run_name: ""
+ tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
+ tags:
+ # save_dir: "./mlruns"
+ prefix: ""
+ artifact_location:
+ # run_id: ""
diff --git a/configs/logger/neptune.yaml b/configs/logger/neptune.yaml
new file mode 100755
index 0000000..dfd7f00
--- /dev/null
+++ b/configs/logger/neptune.yaml
@@ -0,0 +1,9 @@
+# https://neptune.ai
+
+neptune:
+ _target_: lightning.pytorch.loggers.neptune.NeptuneLogger
+ api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
+ project: username/lightning-hydra-template
+ # name: ""
+ log_model_checkpoints: true
+ prefix: ""
diff --git a/configs/logger/tensorboard.yaml b/configs/logger/tensorboard.yaml
new file mode 100755
index 0000000..97775e1
--- /dev/null
+++ b/configs/logger/tensorboard.yaml
@@ -0,0 +1,10 @@
+# https://www.tensorflow.org/tensorboard/
+
+tensorboard:
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+ save_dir: "${paths.output_dir}/tensorboard/"
+ name:
+ log_graph: false
+ default_hp_metric: true
+ prefix: ""
+ version: ${task_name}
diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml
new file mode 100755
index 0000000..c169bd9
--- /dev/null
+++ b/configs/logger/wandb.yaml
@@ -0,0 +1,16 @@
+# https://wandb.ai
+
+wandb:
+ _target_: lightning.pytorch.loggers.wandb.WandbLogger
+ # name: "" # name of the run (normally generated by wandb)
+ save_dir: "${paths.output_dir}"
+ offline: false
+ id: # pass correct id to resume experiment!
+ anonymous: # enable anonymous logging
+ project: "lightning-hydra-template"
+ log_model: false # upload lightning ckpts
+ prefix: "" # a string to put at the beginning of metric keys
+ # entity: "" # set to name of your wandb team
+ group: ""
+ tags: []
+ job_type: ""
diff --git a/configs/model/spa_pretrain.yaml b/configs/model/spa_pretrain.yaml
new file mode 100644
index 0000000..b0b0039
--- /dev/null
+++ b/configs/model/spa_pretrain.yaml
@@ -0,0 +1,185 @@
+_target_: spa.models.spa_pretrain_module.SPAPretrainModule
+
+optimizer:
+ type: AdamW
+ lr: ${eval:'0.00001 * ${data.batch_size_train} * ${trainer.accumulate_grad_batches} * ${trainer.devices} * ${trainer.num_nodes} / 8'}
+ weight_decay: 0.04
+ betas: [0.9, 0.95]
+ layer_decay: 0.8
+ verbose: true
+param_dicts: null
+
+lr_scheduler:
+ scheduler:
+ type: OneCycleLR
+ max_lr: ${model.optimizer.lr}
+ pct_start: 0.05
+ anneal_strategy: cos
+ div_factor: 100.0
+ final_div_factor: 1000.0
+ # monitor: val/loss
+ interval: step
+ frequency: 1
+
+model:
+ _target_: spa.models.components.SPA
+ ckpt_name: spa-l
+ data_processor_cfg:
+ enabled_proc_list:
+ train:
+ - random_photometric_distort
+ - imnormalize
+ test:
+ - imnormalize
+ proc_config:
+ random_photometric_distort:
+ mv_consistency: true
+ brightness: [0.875, 1.125]
+ contrast: [0.5, 1.5]
+ saturation: [0.5, 1.5]
+ hue: [-0.05, 0.05]
+ p: 0.5
+ imnormalize:
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ fp16_enabled_layers: ["img_backbone"]
+ img_backbone:
+ _target_: spa.models.components.img_backbones.vit.SPAViT
+ img_size: 224
+ patch_size: 16
+ in_chans: 3
+ embed_dim: 1024
+ depth: 24
+ num_heads: 16
+ # depth / mae shared decoder
+ decoder_embed_dim: 512
+ mlp_ratio: 4.0
+ qkv_bias: true
+ pretrained_weight:
+ mask_ratio: 0.5
+ out_feature_channels: 64
+ view_transform:
+ _target_: spa.models.components.view_transforms.lss_voxelformer.LSSVoxelformer
+ in_channels: ${model.model.img_backbone.out_feature_channels}
+ grid_size: [128, 128, 32]
+ feature_map_stride: 1 # bev_stride
+ transformer_cfg:
+ num_layers: 1
+ transformerlayers:
+ attn_cfgs:
+ - type: VoxelformerCrossAttention
+ embed_dims: ${model.model.img_backbone.out_feature_channels}
+ num_heads: 8
+ num_points: 5
+ num_levels: 1
+ - type: VoxelformerSelfAttention
+ embed_dims: ${model.model.img_backbone.out_feature_channels}
+ num_convs: 1
+ ffn_cfgs:
+ feedforward_channels: ${eval:'${model.model.img_backbone.out_feature_channels} * 4'}
+ num_fcs: 2
+ ffn_drop: 0.0
+ operation_order: ['cross_attn', 'norm', 'self_attn', 'norm', 'ffn', 'norm']
+ dense_head:
+ _target_: spa.models.components.dense_heads.render_head.RenderHead
+ in_channels: ${model.model.img_backbone.out_feature_channels}
+ feature_map_stride: ${model.model.view_transform.feature_map_stride}
+ val_ray_split: 8192
+ feature_type: 3d_to_3d
+ collider_cfg:
+ type: AABBBoxCollider
+ near_plane: 0.4
+ semantic_cfg:
+ use_semantic: true
+ type: radio
+ img_radio_cfg:
+ model: radio_v2
+ proj_cfg:
+ type: ExpConv3D
+ mid_channels: 64
+ sdf_channels: 33
+ rgb_channels: 27 # 3 * (sh_deg + 1) ** 2, sh_deg = 2
+ group_norm: true
+ semantic_channels: 1280
+ use_convnext_block: true
+ depth: 1
+ render_cfg:
+ type: NeuSModel
+ scene_scale: 1.6
+ field_cfg:
+ type: SDFFieldExp
+ beta_init: 0.3
+ use_gradient: True
+ volume_type: default
+ padding_mode: zeros
+ render_rgb: true
+ render_semantic: true
+ shared_volume: true
+ use_alpha: true
+ sampler_cfg:
+ type: NeuSSampler
+ initial_sampler: UniformSampler
+ num_samples: 72
+ num_samples_importance: 24
+ num_upsample_steps: 1
+ train_stratified: true
+ single_jitter: true
+ loss_cfg:
+ sensor_depth_truncation: 0.25
+ temperature: 0.01
+ weights:
+ eikonal_loss: 0.01
+ free_space_loss: 1.0
+ sdf_loss: 10.0
+ depth_loss: 1.0
+ rgb_loss: 10.0
+ semantic_loss: 1.0
+
+train_metrics:
+ _target_: spa.utils.metrics.Metrics
+ metrics:
+ - _target_: torchmetrics.MeanMetric
+ - _target_: torchmetrics.MeanMetric
+ - _target_: torchmetrics.MeanMetric
+ - _target_: torchmetrics.MeanMetric
+ - _target_: torchmetrics.MeanMetric
+ input_keys:
+ - loss
+ - rgb_loss
+ - psnr
+ - depth_loss
+ - semantic_loss
+ output_keys:
+ - train/loss
+ - train/rgb_loss
+ - train/psnr
+ - train/depth_loss
+ - train/semantic_loss
+val_metrics:
+ _target_: spa.utils.metrics.Metrics
+ metrics:
+ - _target_: torchmetrics.MeanMetric
+ - _target_: torchmetrics.MeanMetric
+ - _target_: torchmetrics.MeanMetric
+ - _target_: torchmetrics.MeanMetric
+ - _target_: torchmetrics.MeanMetric
+ input_keys:
+ - loss
+ - rgb_loss
+ - psnr
+ - depth_loss
+ - semantic_loss
+ output_keys:
+ - val/loss
+ - val/rgb_loss
+ - val/psnr
+ - val/depth_loss
+ - val/semantic_loss
+best_val_metrics:
+ _target_: spa.utils.metrics.Metrics
+ metrics:
+ - _target_: torchmetrics.MinMetric
+ input_keys:
+ - val/loss
+ output_keys:
+ - val/loss
diff --git a/configs/paths/default.yaml b/configs/paths/default.yaml
new file mode 100755
index 0000000..ec81db2
--- /dev/null
+++ b/configs/paths/default.yaml
@@ -0,0 +1,18 @@
+# path to root directory
+# this requires PROJECT_ROOT environment variable to exist
+# you can replace it with "." if you want the root to be the current working directory
+root_dir: ${oc.env:PROJECT_ROOT}
+
+# path to data directory
+data_dir: ${paths.root_dir}/data/
+
+# path to logging directory
+log_dir: ${paths.root_dir}/logs/
+
+# path to output directory, created dynamically by hydra
+# path generation pattern is specified in `configs/hydra/default.yaml`
+# use it to store all files generated during the run, like ckpts and metrics
+output_dir: ${hydra:runtime.output_dir}
+
+# path to working directory
+work_dir: ${hydra:runtime.cwd}
diff --git a/configs/train.yaml b/configs/train.yaml
new file mode 100755
index 0000000..e82ca8b
--- /dev/null
+++ b/configs/train.yaml
@@ -0,0 +1,52 @@
+# @package _global_
+
+# specify here default configuration
+# order of defaults determines the order in which configs override each other
+defaults:
+ - _self_
+ - data: scannet
+ - model: spa
+ - callbacks: default_ema
+ - logger: tensorboard # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+ - trainer: ddp
+ - paths: default
+ - extras: default
+ - hydra: default
+
+ # experiment configs allow for version control of specific hyperparameters
+ # e.g. best hyperparameters for given model and datamodule
+ - experiment:
+
+ # config for hyperparameter optimization
+ - hparams_search:
+
+ # optional local config for machine/user specific settings
+ # it's optional since it doesn't need to exist and is excluded from version control
+ - optional local: default
+
+ # debugging config (enable through command line, e.g. `python train.py debug=default`)
+ - debug:
+
+# task name, determines output directory path
+task_name: "train"
+
+# tags to help you identify your experiments
+# you can overwrite this in experiment configs
+# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
+tags: ["dev"]
+
+# set False to skip model training
+train: true
+
+# evaluate on test set, using best model weights achieved during training
+# lightning chooses best weights based on the metric specified in checkpoint callback
+test: false
+
+# compile model for faster training with pytorch 2.0
+compile: false
+
+# simply provide checkpoint path to resume training
+ckpt_path:
+
+# seed for random number generators in pytorch, numpy and python.random
+seed: 3407
diff --git a/configs/trainer/cpu.yaml b/configs/trainer/cpu.yaml
new file mode 100755
index 0000000..b7d6767
--- /dev/null
+++ b/configs/trainer/cpu.yaml
@@ -0,0 +1,5 @@
+defaults:
+ - default
+
+accelerator: cpu
+devices: 1
diff --git a/configs/trainer/ddp.yaml b/configs/trainer/ddp.yaml
new file mode 100755
index 0000000..303d0b9
--- /dev/null
+++ b/configs/trainer/ddp.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - default
+
+strategy: ddp # _find_unused_parameters_true # ddp
+
+accelerator: gpu
+devices: 4
+num_nodes: 1
+sync_batchnorm: true
+max_epochs: 100
+check_val_every_n_epoch: 2
+gradient_clip_val: 0.5
+accumulate_grad_batches: 1
+precision: 32-true
+log_every_n_steps: 50
\ No newline at end of file
diff --git a/configs/trainer/ddp_sim.yaml b/configs/trainer/ddp_sim.yaml
new file mode 100755
index 0000000..8404419
--- /dev/null
+++ b/configs/trainer/ddp_sim.yaml
@@ -0,0 +1,7 @@
+defaults:
+ - default
+
+# simulate DDP on CPU, useful for debugging
+accelerator: cpu
+devices: 2
+strategy: ddp_spawn
diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml
new file mode 100755
index 0000000..292f4d6
--- /dev/null
+++ b/configs/trainer/default.yaml
@@ -0,0 +1,19 @@
+_target_: lightning.pytorch.trainer.Trainer
+
+default_root_dir: ${paths.output_dir}
+
+min_epochs: 1 # prevents early stopping
+max_epochs: 10
+
+accelerator: cpu
+devices: 1
+
+# mixed precision for extra speed-up
+# precision: 16
+
+# perform a validation loop every N training epochs
+check_val_every_n_epoch: 1
+
+# set True to ensure deterministic results
+# makes training slower but gives more reproducibility than just setting seeds
+deterministic: false
diff --git a/configs/trainer/gpu.yaml b/configs/trainer/gpu.yaml
new file mode 100755
index 0000000..b238951
--- /dev/null
+++ b/configs/trainer/gpu.yaml
@@ -0,0 +1,5 @@
+defaults:
+ - default
+
+accelerator: gpu
+devices: 1
diff --git a/configs/trainer/mps.yaml b/configs/trainer/mps.yaml
new file mode 100755
index 0000000..1ecf6d5
--- /dev/null
+++ b/configs/trainer/mps.yaml
@@ -0,0 +1,5 @@
+defaults:
+ - default
+
+accelerator: mps
+devices: 1
diff --git a/libs/spa-ops/__init__.py b/libs/spa-ops/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/libs/spa-ops/deform_attn/ms_deform_attn_utils.py b/libs/spa-ops/deform_attn/ms_deform_attn_utils.py
new file mode 100644
index 0000000..406af6f
--- /dev/null
+++ b/libs/spa-ops/deform_attn/ms_deform_attn_utils.py
@@ -0,0 +1,60 @@
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from . import ms_deform_attn_cuda
+
+
+class MSDeformAttnFunction(Function):
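+    """Autograd wrapper for multi-scale deformable attention (from Deformable DETR),
+    delegating forward/backward to the compiled ``ms_deform_attn_cuda`` extension."""
+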
+ @staticmethod
+ def forward(
+ ctx,
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ im2col_step,
+ ):
+ ctx.im2col_step = im2col_step
+ output = ms_deform_attn_cuda.ms_deform_attn_forward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ ctx.im2col_step,
+ )
+ ctx.save_for_backward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ )
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ (
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ ) = ctx.saved_tensors
+ (
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight,
+ ) = ms_deform_attn_cuda.ms_deform_attn_backward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ grad_output,
+ ctx.im2col_step,
+ )
+
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
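+
+
+if __name__ == "__main__":
+    # Minimal smoke test (an illustrative sketch added for documentation, not part of the
+    # upstream Deformable DETR code). Shapes follow the CUDA kernel: `value` is
+    # (batch, sum of H*W over all levels, num_heads, channels), sampling locations are
+    # normalized to [0, 1], and the output is (batch, num_query, num_heads * channels).
+    # Run as a module (e.g. `python -m <package>.ms_deform_attn_utils`) so that the
+    # relative import of the compiled extension above resolves.
+    import torch
+
+    if torch.cuda.is_available():
+        batch, num_heads, channels, num_query, num_points = 2, 8, 32, 100, 4
+        spatial_shapes = torch.tensor([[32, 32], [16, 16]], dtype=torch.long, device="cuda")
+        level_start_index = torch.cat(
+            (
+                spatial_shapes.new_zeros(1),
+                (spatial_shapes[:, 0] * spatial_shapes[:, 1]).cumsum(0)[:-1],
+            )
+        )
+        num_levels = spatial_shapes.shape[0]
+        spatial_size = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
+        value = torch.rand(batch, spatial_size, num_heads, channels, device="cuda")
+        sampling_locations = torch.rand(
+            batch, num_query, num_heads, num_levels, num_points, 2, device="cuda"
+        )
+        attention_weights = torch.rand(
+            batch, num_query, num_heads, num_levels, num_points, device="cuda"
+        )
+        attention_weights = attention_weights / attention_weights.sum(dim=(-2, -1), keepdim=True)
+        output = MSDeformAttnFunction.apply(
+            value,
+            spatial_shapes,
+            level_start_index,
+            sampling_locations,
+            attention_weights,
+            64,  # im2col_step
+        )
+        print(output.shape)  # expected: torch.Size([2, 100, 256])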
diff --git a/libs/spa-ops/deform_attn/src/ms_deform_attn.cpp b/libs/spa-ops/deform_attn/src/ms_deform_attn.cpp
new file mode 100644
index 0000000..2201f63
--- /dev/null
+++ b/libs/spa-ops/deform_attn/src/ms_deform_attn.cpp
@@ -0,0 +1,16 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include "ms_deform_attn.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+}
diff --git a/libs/spa-ops/deform_attn/src/ms_deform_attn.h b/libs/spa-ops/deform_attn/src/ms_deform_attn.h
new file mode 100644
index 0000000..288f7b7
--- /dev/null
+++ b/libs/spa-ops/deform_attn/src/ms_deform_attn.h
@@ -0,0 +1,57 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include "ms_deform_attn_cuda.h"
+
+
+at::Tensor
+ms_deform_attn_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ if (value.is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_forward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ if (value.is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_backward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
diff --git a/libs/spa-ops/deform_attn/src/ms_deform_attn_cuda.cu b/libs/spa-ops/deform_attn/src/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000..af37ca5
--- /dev/null
+++ b/libs/spa-ops/deform_attn/src/ms_deform_attn_cuda.cu
@@ -0,0 +1,153 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+#include "ms_deform_im2col_cuda.cuh"
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+ value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data_ptr<int64_t>(),
+ level_start_index.data_ptr<int64_t>(),
+ sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ columns.data_ptr<scalar_t>());
+
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads*channels});
+
+ return output;
+}
+
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto grad_value = at::zeros_like(value);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+ grad_output_g.data_ptr<scalar_t>(),
+ value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data_ptr<int64_t>(),
+ level_start_index.data_ptr<int64_t>(),
+ sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+ grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
+
+ }));
+ }
+
+ return {
+ grad_value, grad_sampling_loc, grad_attn_weight
+ };
+}
\ No newline at end of file
diff --git a/libs/spa-ops/deform_attn/src/ms_deform_attn_cuda.h b/libs/spa-ops/deform_attn/src/ms_deform_attn_cuda.h
new file mode 100644
index 0000000..c7ae53f
--- /dev/null
+++ b/libs/spa-ops/deform_attn/src/ms_deform_attn_cuda.h
@@ -0,0 +1,30 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
diff --git a/libs/spa-ops/deform_attn/src/ms_deform_im2col_cuda.cuh b/libs/spa-ops/deform_attn/src/ms_deform_im2col_cuda.cuh
new file mode 100644
index 0000000..f504dd9
--- /dev/null
+++ b/libs/spa-ops/deform_attn/src/ms_deform_im2col_cuda.cuh
@@ -0,0 +1,1327 @@
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************
+* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
+* Copyright (c) 2018 Microsoft
+**************************************************************************
+*/
+
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <ATen/cuda/Atomic.cuh>
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+ return (N + num_threads - 1) / num_threads;
+}
+
+
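+// Bilinearly interpolate the value tensor at the fractional location (h, w)
+// for attention head m and channel c; out-of-range corners contribute zero.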
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+
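+// Backward of the bilinear sampling above: scatters the incoming gradient into
+// grad_value with atomicAdd and writes the gradients w.r.t. the sampling
+// location and attention weight into the provided per-thread buffers.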
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ *grad_attn_weight = top_grad * val;
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
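+// Same as ms_deform_attn_col2im_bilinear, but accumulates the sampling-location
+// and attention-weight gradients directly in global memory with atomicAdd.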
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
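+// Forward kernel: each thread computes one output element (batch, query, head,
+// channel) by accumulating attention-weighted bilinear samples over all levels
+// and sampling points.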
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ // const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ scalar_t *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+ scalar_t col = 0;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
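+// Backward kernel for small channel counts: blockSize is a compile-time constant,
+// per-thread partial gradients are staged in static shared memory, and thread 0
+// serially reduces them before writing the result.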
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ // const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
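+// Same as the v1 kernel, but the shared-memory partials are combined with a
+// parallel tree reduction instead of a serial loop on thread 0.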
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ // const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
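+// Backward kernel for arbitrary block sizes: partials live in dynamically
+// allocated shared memory and thread 0 performs a serial reduction.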
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ // const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
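+// Backward kernel with a tree reduction over dynamically allocated shared memory;
+// the extra (s << 1) term folds in the leftover element when the active width is
+// not a power of two.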
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ // const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
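+// Variant used when channels > 1024 and divisible by 1024: several blocks may
+// contribute to the same query, so the reduced block results are added to global
+// memory with atomicAdd.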
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ // const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
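+// Fallback backward kernel: no shared-memory reduction, every thread adds its
+// gradients straight into global memory via ms_deform_attn_col2im_bilinear_gm.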
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ // const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear_gm(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ grad_sampling_loc, grad_attn_weight);
+ }
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
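+// Host-side launcher for the forward pass: one thread per output element.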
+template <typename scalar_t>
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+ const scalar_t* data_value,
+ const int64_t* data_spatial_shapes,
+ const int64_t* data_level_start_index,
+ const scalar_t* data_sampling_loc,
+ const scalar_t* data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* data_col)
+{
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ const int num_threads = CUDA_NUM_THREADS;
+ ms_deformable_im2col_gpu_kernel<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
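+// Host-side launcher for the backward pass. The reduction kernel is chosen from
+// the channel count: power-of-two channels up to 1024 get the compile-time
+// blockSize kernels, other small counts use the dynamic shared-memory kernels,
+// and large channel counts fall back to the multi-block or global-memory paths.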
+template <typename scalar_t>
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+ const scalar_t* grad_col,
+ const scalar_t* data_value,
+ const int64_t * data_spatial_shapes,
+ const int64_t * data_level_start_index,
+ const scalar_t * data_sampling_loc,
+ const scalar_t * data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ if (channels > 1024)
+ {
+ if ((channels & 1023) == 0)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ else{
+ switch(channels)
+ {
+ case 1:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 2:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 4:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 8:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 16:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 32:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 64:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 128:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 256:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 512:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 1024:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ default:
+ if (channels < 64)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ }
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
\ No newline at end of file
diff --git a/libs/spa-ops/grid_sampler/__init__.py b/libs/spa-ops/grid_sampler/__init__.py
new file mode 100644
index 0000000..e6dfef8
--- /dev/null
+++ b/libs/spa-ops/grid_sampler/__init__.py
@@ -0,0 +1,3 @@
+from .grid_sampler import GridSampler2D, GridSampler3D
+
+__all__ = ["GridSampler2D", "GridSampler3D"]
diff --git a/libs/spa-ops/grid_sampler/grid_sampler.py b/libs/spa-ops/grid_sampler/grid_sampler.py
new file mode 100644
index 0000000..317d1a0
--- /dev/null
+++ b/libs/spa-ops/grid_sampler/grid_sampler.py
@@ -0,0 +1,156 @@
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+
+from . import grid_sampler_cuda
+
+padding_mode_enum = {"zeros": 0, "border": 1, "reflection": 2}
+
+
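+# The backward pass is wrapped in its own autograd Function so that a second
+# backward (gradient of the gradient) can be routed through the custom
+# grid_sampler_*_backward_backward kernels.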
+class GridSampler3DBackward(Function):
+ @staticmethod
+ def forward(
+ ctx,
+ input,
+ grid,
+ grad_output,
+ padding_mode="zeros",
+ align_corners=True,
+ ):
+ ctx.align_corners = align_corners
+ ctx.padding_mode = padding_mode
+ grad_input, grad_grid = grid_sampler_cuda.grid_sampler_3d_backward(
+ grad_output,
+ input,
+ grid,
+ padding_mode_enum[padding_mode],
+ ctx.align_corners,
+ )
+ ctx.save_for_backward(input, grid, grad_output)
+ return grad_input, grad_grid
+
+ @staticmethod
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
+ input, grid, grad_output = ctx.saved_tensors
+
+ (
+ grad_input,
+ grad_grid,
+ grad_grad_output,
+ ) = grid_sampler_cuda.grid_sampler_3d_backward_backward(
+ grad2_grad_input.contiguous(),
+ grad2_grad_grid.contiguous(),
+ input,
+ grid,
+ grad_output,
+ padding_mode_enum[ctx.padding_mode],
+ ctx.align_corners,
+ )
+ return grad_input, grad_grid, grad_grad_output, None, None
+
+
+class GridSampler3D(Function):
+ @staticmethod
+ def forward(
+ ctx,
+ input,
+ grid,
+ padding_mode="zeros",
+ align_corners=True,
+ ):
+ output = grid_sampler_cuda.grid_sampler_3d_forward(
+ input,
+ grid,
+ padding_mode_enum[padding_mode],
+ align_corners,
+ )
+ ctx.save_for_backward(input, grid)
+ ctx.align_corners = align_corners
+ ctx.padding_mode = padding_mode
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ input, grid = ctx.saved_tensors
+ d_input, d_grid = GridSampler3DBackward.apply(
+ input,
+ grid,
+ grad_out.contiguous(),
+ ctx.padding_mode,
+ ctx.align_corners,
+ )
+ return d_input, d_grid, None, None
+
+
+class GridSampler2DBackward(Function):
+ @staticmethod
+ def forward(
+ ctx,
+ input,
+ grid,
+ grad_output,
+ padding_mode="zeros",
+ align_corners=True,
+ ):
+ ctx.align_corners = align_corners
+ ctx.padding_mode = padding_mode
+ grad_input, grad_grid = grid_sampler_cuda.grid_sampler_2d_backward(
+ grad_output,
+ input,
+ grid,
+ padding_mode_enum[padding_mode],
+ ctx.align_corners,
+ )
+ ctx.save_for_backward(input, grid, grad_output)
+ return grad_input, grad_grid
+
+ @staticmethod
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
+ input, grid, grad_output = ctx.saved_tensors
+ (
+ grad_input,
+ grad_grid,
+ grad_grad_output,
+ ) = grid_sampler_cuda.grid_sampler_2d_backward_backward(
+ grad2_grad_input.contiguous(),
+ grad2_grad_grid.contiguous(),
+ input,
+ grid,
+ grad_output,
+ padding_mode_enum[ctx.padding_mode],
+ ctx.align_corners,
+ )
+ return grad_input, grad_grid, grad_grad_output, None, None
+
+
+class GridSampler2D(Function):
+ @staticmethod
+ def forward(
+ ctx,
+ input,
+ grid,
+ padding_mode="zeros",
+ align_corners=True,
+ ):
+ output = grid_sampler_cuda.grid_sampler_2d_forward(
+ input,
+ grid,
+ padding_mode_enum[padding_mode],
+ align_corners,
+ )
+ ctx.save_for_backward(input, grid)
+ ctx.align_corners = align_corners
+ ctx.padding_mode = padding_mode
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ input, grid = ctx.saved_tensors
+ d_input, d_grid = GridSampler2DBackward.apply(
+ input,
+ grid,
+ grad_out.contiguous(),
+ ctx.padding_mode,
+ ctx.align_corners,
+ )
+ return d_input, d_grid, None, None
diff --git a/libs/spa-ops/grid_sampler/src/grid_sampler.cpp b/libs/spa-ops/grid_sampler/src/grid_sampler.cpp
new file mode 100644
index 0000000..9557ddd
--- /dev/null
+++ b/libs/spa-ops/grid_sampler/src/grid_sampler.cpp
@@ -0,0 +1,133 @@
+#include <torch/extension.h>
+#include <tuple>
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA Tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
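+// Forward declarations of the kernel launchers implemented in grid_sampler_cuda.cu.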
+void launch_grid_sampler_2d_forward_kernel(
+ const torch::TensorBase &output, const torch::TensorBase &input, const torch::TensorBase &grid,
+ int64_t padding_mode, bool align_corners);
+
+void launch_grid_sampler_3d_forward_kernel(
+ const torch::TensorBase &output, const torch::TensorBase &input, const torch::TensorBase &grid,
+ int64_t padding_mode, bool align_corners);
+
+void launch_grid_sampler_2d_backward_kernel(
+ const torch::TensorBase &grad_input, const torch::TensorBase &grad_grid,
+ const torch::TensorBase &grad_output, const torch::TensorBase &input,
+ const torch::TensorBase &grid, int64_t padding_mode, bool align_corners);
+
+void launch_grid_sampler_3d_backward_kernel(
+ const torch::TensorBase &grad_input, const torch::TensorBase &grad_grid,
+ const torch::TensorBase &grad_output, const torch::TensorBase &input,
+ const torch::TensorBase &grid, int64_t padding_mode, bool align_corners);
+
+void launch_grid_sampler_2d_backward_backward_kernel(
+ const torch::TensorBase &grad_input, const torch::TensorBase &grad_grid, const torch::TensorBase &grad_grad_output,
+ const torch::TensorBase &grad2_grad_input, const torch::TensorBase &grad2_grad_grid,
+ const torch::TensorBase &input, const torch::TensorBase &grid, const torch::TensorBase &grad_output,
+ int64_t padding_mode, bool align_corners);
+
+void launch_grid_sampler_3d_backward_backward_kernel(
+ const torch::TensorBase &grad_input, const torch::TensorBase &grad_grid, const torch::TensorBase &grad_grad_output,
+ const torch::TensorBase &grad2_grad_input, const torch::TensorBase &grad2_grad_grid,
+ const torch::TensorBase &input, const torch::TensorBase &grid, const torch::TensorBase &grad_output,
+ int64_t padding_mode, bool align_corners);
+
+torch::Tensor grid_sampler_2d_forward(const torch::Tensor& input, const torch::Tensor& grid,
+ int64_t padding_mode, bool align_corners) {
+ CHECK_INPUT(input)
+ CHECK_INPUT(grid)
+ auto in_size = input.sizes();
+ auto grid_size = grid.sizes();
+ auto output = at::empty(
+ {in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());
+ launch_grid_sampler_2d_forward_kernel(
+ output, input, grid, padding_mode, align_corners);
+ return output;
+}
+
+torch::Tensor grid_sampler_3d_forward(const torch::Tensor &input, const torch::Tensor &grid,
+ int64_t padding_mode, bool align_corners) {
+ CHECK_INPUT(input)
+ CHECK_INPUT(grid)
+ auto in_size = input.sizes();
+ auto grid_size = grid.sizes();
+ auto output = torch::empty(
+ {in_size[0], in_size[1], grid_size[1], grid_size[2], grid_size[3]}, input.options());
+ launch_grid_sampler_3d_forward_kernel(
+ output, input, grid, padding_mode, align_corners);
+ return output;
+}
+
+std::tuple<torch::Tensor, torch::Tensor>
+grid_sampler_2d_backward(const torch::Tensor &grad_output, const torch::Tensor &input,
+ const torch::Tensor &grid, int64_t padding_mode, bool align_corners) {
+ CHECK_INPUT(grad_output)
+ CHECK_INPUT(input)
+ CHECK_INPUT(grid)
+ auto grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ auto grad_grid = torch::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ launch_grid_sampler_2d_backward_kernel(
+ grad_input, grad_grid, grad_output, input,
+ grid, padding_mode, align_corners);
+ return std::make_tuple(grad_input, grad_grid);
+}
+
+std::tuple<torch::Tensor, torch::Tensor>
+grid_sampler_3d_backward(const torch::Tensor &grad_output, const torch::Tensor &input,
+ const torch::Tensor &grid, int64_t padding_mode, bool align_corners) {
+ CHECK_INPUT(grad_output)
+ CHECK_INPUT(input)
+ CHECK_INPUT(grid)
+ auto grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ auto grad_grid = torch::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ launch_grid_sampler_3d_backward_kernel(
+ grad_input, grad_grid, grad_output, input,
+ grid, padding_mode, align_corners);
+ return std::make_tuple(grad_input, grad_grid);
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+grid_sampler_2d_backward_backward(const torch::Tensor &grad2_grad_input, const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &input, const torch::Tensor &grid, const torch::Tensor &grad_output,
+ int64_t padding_mode, bool align_corners) {
+ CHECK_INPUT(grad2_grad_input)
+ CHECK_INPUT(grad2_grad_grid)
+ CHECK_INPUT(input)
+ CHECK_INPUT(grid)
+ CHECK_INPUT(grad_output)
+ auto grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ auto grad_grid = torch::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ auto grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ launch_grid_sampler_2d_backward_backward_kernel(grad_input, grad_grid, grad_grad_output, grad2_grad_input, grad2_grad_grid,
+ input, grid, grad_output, padding_mode, align_corners);
+ return std::make_tuple(grad_input, grad_grid, grad_grad_output);
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+grid_sampler_3d_backward_backward(const torch::Tensor &grad2_grad_input, const torch::Tensor &grad2_grad_grid,
+ const torch::Tensor &input, const torch::Tensor &grid, const torch::Tensor &grad_output,
+ int64_t padding_mode, bool align_corners) {
+ CHECK_INPUT(grad2_grad_input)
+ CHECK_INPUT(grad2_grad_grid)
+ CHECK_INPUT(input)
+ CHECK_INPUT(grid)
+ CHECK_INPUT(grad_output)
+ auto grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ auto grad_grid = torch::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ auto grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ launch_grid_sampler_3d_backward_backward_kernel(grad_input, grad_grid, grad_grad_output, grad2_grad_input, grad2_grad_grid,
+ input, grid, grad_output, padding_mode, align_corners);
+ return std::make_tuple(grad_input, grad_grid, grad_grad_output);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("grid_sampler_3d_forward", &grid_sampler_3d_forward, "grid_sampler_3d_forward");
+ m.def("grid_sampler_3d_backward", &grid_sampler_3d_backward, "grid_sampler_3d_backward");
+ m.def("grid_sampler_3d_backward_backward", &grid_sampler_3d_backward_backward, "grid_sampler_3d_backward_backward");
+ m.def("grid_sampler_2d_forward", &grid_sampler_2d_forward, "grid_sampler_2d_forward");
+ m.def("grid_sampler_2d_backward", &grid_sampler_2d_backward, "grid_sampler_2d_backward");
+ m.def("grid_sampler_2d_backward_backward", &grid_sampler_2d_backward_backward, "grid_sampler_2d_backward_backward");
+}
\ No newline at end of file
diff --git a/libs/spa-ops/grid_sampler/src/grid_sampler_cuda.cu b/libs/spa-ops/grid_sampler/src/grid_sampler_cuda.cu
new file mode 100644
index 0000000..453f8e3
--- /dev/null
+++ b/libs/spa-ops/grid_sampler/src/grid_sampler_cuda.cu
@@ -0,0 +1,1276 @@
+#define TORCH_ASSERT_NO_OPERATORS
+#include <ATen/native/cuda/GridSampler.h>
+#include <ATen/native/GridSamplerUtils.h>
+#include <ATen/native/cuda/GridSampler.cuh>
+#include <ATen/native/cuda/UpSample.cuh>
+#include <ATen/core/TensorBase.h>
+#include <ATen/Dispatch.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/TensorInfo.cuh>
+#include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <c10/macros/Macros.h>
+
+
+using namespace at::cuda::detail;
+using at::native::detail::GridSamplerPadding;
+
+
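+// 2D forward kernel: each thread handles one output location (n, h, w) and
+// bilinearly interpolates all C channels, skipping out-of-bounds corners.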
+template <typename scalar_t, typename index_t>
+C10_LAUNCH_BOUNDS_1(256)
+__global__ void grid_sampler_2d_kernel(
+ const index_t nthreads,
+ TensorInfo<scalar_t, index_t> input,
+ TensorInfo<scalar_t, index_t> grid,
+ TensorInfo<scalar_t, index_t> output,
+ const GridSamplerPadding padding_mode,
+ bool align_corners) {
+
+ index_t C = input.sizes[1];
+ index_t inp_H = input.sizes[2];
+ index_t inp_W = input.sizes[3];
+ index_t out_H = grid.sizes[1];
+ index_t out_W = grid.sizes[2];
+ index_t inp_sN = input.strides[0];
+ index_t inp_sC = input.strides[1];
+ index_t inp_sH = input.strides[2];
+ index_t inp_sW = input.strides[3];
+ index_t grid_sN = grid.strides[0];
+ index_t grid_sH = grid.strides[1];
+ index_t grid_sW = grid.strides[2];
+ index_t grid_sCoor = grid.strides[3];
+ index_t out_sN = output.strides[0];
+ index_t out_sC = output.strides[1];
+ index_t out_sH = output.strides[2];
+ index_t out_sW = output.strides[3];
+
+ CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
+ const index_t w = index % out_W;
+ const index_t h = (index / out_W) % out_H;
+ const index_t n = index / (out_H * out_W);
+ const index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
+
+ // get the corresponding input x, y co-ordinates from grid
+ scalar_t x = grid.data[grid_offset];
+ scalar_t y = grid.data[grid_offset + grid_sCoor];
+
+ scalar_t ix = at::native::grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners);
+ scalar_t iy = at::native::grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners);
+
+ // get NE, NW, SE, SW pixel values from (x, y)
+ index_t ix_nw = static_cast<index_t>(::floor(ix));
+ index_t iy_nw = static_cast<index_t>(::floor(iy));
+ index_t ix_ne = ix_nw + 1;
+ index_t iy_ne = iy_nw;
+ index_t ix_sw = ix_nw;
+ index_t iy_sw = iy_nw + 1;
+ index_t ix_se = ix_nw + 1;
+ index_t iy_se = iy_nw + 1;
+
+ // get surfaces to each neighbor:
+ scalar_t nw = (ix_se - ix) * (iy_se - iy);
+ scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
+ scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
+ scalar_t se = (ix - ix_nw) * (iy - iy_nw);
+
+ // calculate bilinear weighted pixel value and set output pixel
+ auto inp_ptr_NC = input.data + n * inp_sN;
+ auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW;
+ for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
+ *out_ptr_NCHW = static_cast<scalar_t>(0);
+ if (at::native::within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
+ *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw;
+ }
+ if (at::native::within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
+ *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne;
+ }
+ if (at::native::within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
+ *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw;
+ }
+ if (at::native::within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
+ *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se;
+ }
+ }
+ }
+}
+
+template <typename scalar_t, typename index_t>
+C10_LAUNCH_BOUNDS_1(512)
+__global__ void grid_sampler_3d_kernel(
+ const index_t nthreads,
+ TensorInfo<scalar_t, index_t> input,
+ TensorInfo<scalar_t, index_t> grid,
+ TensorInfo<scalar_t, index_t> output,
+ const GridSamplerPadding padding_mode,
+ bool align_corners) {
+
+ index_t C = input.sizes[1];
+ index_t inp_D = input.sizes[2];
+ index_t inp_H = input.sizes[3];
+ index_t inp_W = input.sizes[4];
+ index_t out_D = grid.sizes[1];
+ index_t out_H = grid.sizes[2];
+ index_t out_W = grid.sizes[3];
+ index_t inp_sN = input.strides[0];
+ index_t inp_sC = input.strides[1];
+ index_t inp_sD = input.strides[2];
+ index_t inp_sH = input.strides[3];
+ index_t inp_sW = input.strides[4];
+ index_t grid_sN = grid.strides[0];
+ index_t grid_sD = grid.strides[1];
+ index_t grid_sH = grid.strides[2];
+ index_t grid_sW = grid.strides[3];
+ index_t grid_sCoor = grid.strides[4];
+ index_t out_sN = output.strides[0];
+ index_t out_sC = output.strides[1];
+ index_t out_sD = output.strides[2];
+ index_t out_sH = output.strides[3];
+ index_t out_sW = output.strides[4];
+
+ CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
+ const index_t w = index % out_W;
+ const index_t h = (index / out_W) % out_H;
+ const index_t d = (index / (out_H * out_W)) % out_D;
+ const index_t n = index / (out_D * out_H * out_W);
+ const index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
+
+ // get the corresponding input x, y, z co-ordinates from grid
+ scalar_t ix = grid.data[grid_offset];
+ scalar_t iy = grid.data[grid_offset + grid_sCoor];
+ scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
+
+ ix = at::native::grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
+ iy = at::native::grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
+ iz = at::native::grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners);
+
+ // get corner pixel values from (x, y, z)
+ // for 4d, we used north-east-south-west
+ // for 5d, we add top-bottom
+ index_t ix_tnw = static_cast<index_t>(::floor(ix));
+ index_t iy_tnw = static_cast<index_t>(::floor(iy));
+ index_t iz_tnw = static_cast<index_t>(::floor(iz));
+
+ index_t ix_tne = ix_tnw + 1;
+ index_t iy_tne = iy_tnw;
+ index_t iz_tne = iz_tnw;
+
+ index_t ix_tsw = ix_tnw;
+ index_t iy_tsw = iy_tnw + 1;
+ index_t iz_tsw = iz_tnw;
+
+ index_t ix_tse = ix_tnw + 1;
+ index_t iy_tse = iy_tnw + 1;
+ index_t iz_tse = iz_tnw;
+
+ index_t ix_bnw = ix_tnw;
+ index_t iy_bnw = iy_tnw;
+ index_t iz_bnw = iz_tnw + 1;
+
+ index_t ix_bne = ix_tnw + 1;
+ index_t iy_bne = iy_tnw;
+ index_t iz_bne = iz_tnw + 1;
+
+ index_t ix_bsw = ix_tnw;
+ index_t iy_bsw = iy_tnw + 1;
+ index_t iz_bsw = iz_tnw + 1;
+
+ index_t ix_bse = ix_tnw + 1;
+ index_t iy_bse = iy_tnw + 1;
+ index_t iz_bse = iz_tnw + 1;
+
+ // get surfaces to each neighbor:
+ scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
+ scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
+ scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
+ scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
+ scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
+ scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
+ scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
+ scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
+
+ auto inp_ptr_NC = input.data + n * inp_sN;
+ auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
+ for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
+ // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
+ // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
+ // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
+ // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
+ *out_ptr_NCDHW = static_cast<scalar_t>(0);
+ if (at::native::within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
+ *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
+ }
+ if (at::native::within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
+ *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
+ }
+ if (at::native::within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
+ *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
+ }
+ if (at::native::within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
+ *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
+ }
+ if (at::native::within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
+ *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
+ }
+ if (at::native::within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
+ *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
+ }
+ if (at::native::within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
+ *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
+ }
+ if (at::native::within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
+ *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
+ }
+ }
+ }
+}
+
+// Note [Passing pointer and offset to fastAtomicAdd]
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// For its internal bounds checking, fastAtomicAdd needs to know where the destination address
+// lies relative to the entire tensor, so we pass the base grad_input.data and full offset information,
+// including batch * channel offset (NC_offset).
+
+template <typename scalar_t, typename index_t>
+C10_LAUNCH_BOUNDS_1(256)
+__global__ void grid_sampler_2d_backward_kernel(
+ const index_t nthreads,
+ TensorInfo<scalar_t, index_t> grad_output,
+ TensorInfo<scalar_t, index_t> input,
+ TensorInfo<scalar_t, index_t> grid,
+ TensorInfo<scalar_t, index_t> grad_input, // initialized to zeros (or unused if input_requires_grad is false)
+ TensorInfo<scalar_t, index_t> grad_grid, // initialized to empty
+ const GridSamplerPadding padding_mode,
+ bool align_corners,
+ const index_t grad_input_memory_span) {
+
+ index_t C = input.sizes[1];
+ index_t inp_H = input.sizes[2];
+ index_t inp_W = input.sizes[3];
+ index_t out_H = grid.sizes[1];
+ index_t out_W = grid.sizes[2];
+ index_t inp_sN = input.strides[0];
+ index_t inp_sC = input.strides[1];
+ index_t inp_sH = input.strides[2];
+ index_t inp_sW = input.strides[3];
+ index_t grid_sN = grid.strides[0];
+ index_t grid_sH = grid.strides[1];
+ index_t grid_sW = grid.strides[2];
+ index_t grid_sCoor = grid.strides[3];
+ index_t gOut_sN = grad_output.strides[0];
+ index_t gOut_sC = grad_output.strides[1];
+ index_t gOut_sH = grad_output.strides[2];
+ index_t gOut_sW = grad_output.strides[3];
+ // gInp_* (and NC_offset below) are not really needed if input_requires_grad is false.
+ index_t gInp_sN = grad_input.strides[0];
+ index_t gInp_sC = grad_input.strides[1];
+ index_t gInp_sH = grad_input.strides[2];
+ index_t gInp_sW = grad_input.strides[3];
+
+ index_t gGrid_sW = grad_grid.strides[2];
+
+ CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
+ const index_t w = index % out_W;
+ const index_t h = (index / out_W) % out_H;
+ const index_t n = index / (out_H * out_W);
+ const auto grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
+
+ // get the corresponding input x, y co-ordinates from grid
+ scalar_t x = grid.data[grid_offset];
+ scalar_t y = grid.data[grid_offset + grid_sCoor];
+
+ // multipliers for gradients on ix and iy
+ scalar_t gix_mult, giy_mult;
+ scalar_t ix = at::native::grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult);
+ scalar_t iy = at::native::grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult);
+
+ // get NE, NW, SE, SW pixel values from (x, y)
+ index_t ix_nw = static_cast<index_t>(::floor(ix));
+ index_t iy_nw = static_cast<index_t>(::floor(iy));
+ index_t ix_ne = ix_nw + 1;
+ index_t iy_ne = iy_nw;
+ index_t ix_sw = ix_nw;
+ index_t iy_sw = iy_nw + 1;
+ index_t ix_se = ix_nw + 1;
+ index_t iy_se = iy_nw + 1;
+
+ // get surfaces to each neighbor:
+ scalar_t nw = (ix_se - ix) * (iy_se - iy);
+ scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
+ scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
+ scalar_t se = (ix - ix_nw) * (iy - iy_nw);
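+ // nw..se are the bilinear weights; their partial derivatives w.r.t. ix and iy
+ // are what gets accumulated into gix / giy in the channel loop below.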
+
+ scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
+ scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
+ index_t NC_offset = n * gInp_sN;
+ scalar_t *inp_ptr_NC = input.data + n * inp_sN;
+ for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) {
+ scalar_t gOut = *gOut_ptr_NCHW;
+
+ // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd].
+ at::native::safe_add_2d(grad_input.data, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut, NC_offset, grad_input_memory_span);
+ at::native::safe_add_2d(grad_input.data, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut, NC_offset, grad_input_memory_span);
+ at::native::safe_add_2d(grad_input.data, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut, NC_offset, grad_input_memory_span);
+ at::native::safe_add_2d(grad_input.data, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut, NC_offset, grad_input_memory_span);
+
+ // calculate grad_grid
+ if (at::native::within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
+ scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW];
+ gix -= nw_val * (iy_se - iy) * gOut;
+ giy -= nw_val * (ix_se - ix) * gOut;
+ }
+ if (at::native::within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
+ scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW];
+ gix += ne_val * (iy_sw - iy) * gOut;
+ giy -= ne_val * (ix - ix_sw) * gOut;
+ }
+ if (at::native::within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
+ scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW];
+ gix -= sw_val * (iy - iy_ne) * gOut;
+ giy += sw_val * (ix_ne - ix) * gOut;
+ }
+ if (at::native::within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
+ scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW];
+ gix += se_val * (iy - iy_nw) * gOut;
+ giy += se_val * (ix - ix_nw) * gOut;
+ }
+ }
+
+ // assuming grad_grid is contiguous
+ // thus we can
+ // 1. use index with gGrid_sW to directly compute gGrid_ptr_NHW
+ // 2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1]
+ scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
+ gGrid_ptr_NHW[0] = gix_mult * gix;
+ gGrid_ptr_NHW[1] = giy_mult * giy;
+ }
+}
+
+template <typename scalar_t, typename index_t>
+C10_LAUNCH_BOUNDS_1(256)
+__global__ void grid_sampler_3d_backward_kernel(
+ const index_t nthreads,
+ TensorInfo<scalar_t, index_t> grad_output,
+ TensorInfo<scalar_t, index_t> input,
+ TensorInfo<scalar_t, index_t> grid,
+ TensorInfo<scalar_t, index_t> grad_input, // initialized to zeros (or unused if input_requires_grad is false)
+ TensorInfo<scalar_t, index_t> grad_grid, // initialized to empty
+ const GridSamplerPadding padding_mode,
+ bool align_corners,
+ const index_t grad_input_memory_span) {
+
+ index_t C = input.sizes[1];
+ index_t inp_D = input.sizes[2];
+ index_t inp_H = input.sizes[3];
+ index_t inp_W = input.sizes[4];
+ index_t out_D = grid.sizes[1];
+ index_t out_H = grid.sizes[2];
+ index_t out_W = grid.sizes[3];
+ index_t inp_sN = input.strides[0];
+ index_t inp_sC = input.strides[1];
+ index_t inp_sD = input.strides[2];
+ index_t inp_sH = input.strides[3];
+ index_t inp_sW = input.strides[4];
+ index_t grid_sN = grid.strides[0];
+ index_t grid_sD = grid.strides[1];
+ index_t grid_sH = grid.strides[2];
+ index_t grid_sW = grid.strides[3];
+ index_t grid_sCoor = grid.strides[4];
+ index_t gOut_sN = grad_output.strides[0];
+ index_t gOut_sC = grad_output.strides[1];
+ index_t gOut_sD = grad_output.strides[2];
+ index_t gOut_sH = grad_output.strides[3];
+ index_t gOut_sW = grad_output.strides[4];
+ // gInp_* (and NC_offset below) are not really needed if input_requires_grad is false.
+ int64_t gInp_sN = grad_input.strides[0];
+ int64_t gInp_sC = grad_input.strides[1];
+ int64_t gInp_sD = grad_input.strides[2];
+ int64_t gInp_sH = grad_input.strides[3];
+ int64_t gInp_sW = grad_input.strides[4];
+
+ index_t gGrid_sW = grad_grid.strides[3];
+
+ CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
+ const index_t w = index % out_W;
+ const index_t h = (index / out_W) % out_H;
+ const index_t d = (index / (out_H * out_W)) % out_D;
+ const index_t n = index / (out_D * out_H * out_W);
+ const auto grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
+
+ // get the corresponding input x, y, z co-ordinates from grid
+ scalar_t ix = grid.data[grid_offset];
+ scalar_t iy = grid.data[grid_offset + grid_sCoor];
+ scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
+
+ // multipliers for gradients on ix, iy, and iz
+ scalar_t gix_mult, giy_mult, giz_mult;
+ ix = at::native::grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
+ iy = at::native::grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
+ iz = at::native::grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);
+
+ // get corner pixel values from (x, y, z)
+ // for 4d, we used north-east-south-west
+ // for 5d, we add top-bottom
+ index_t ix_tnw = static_cast<index_t>(::floor(ix));
+ index_t iy_tnw = static_cast<index_t>(::floor(iy));
+ index_t iz_tnw = static_cast<index_t>(::floor(iz));
+
+ index_t ix_tne = ix_tnw + 1;
+ index_t iy_tne = iy_tnw;
+ index_t iz_tne = iz_tnw;
+
+ index_t ix_tsw = ix_tnw;
+ index_t iy_tsw = iy_tnw + 1;
+ index_t iz_tsw = iz_tnw;
+
+ index_t ix_tse = ix_tnw + 1;
+ index_t iy_tse = iy_tnw + 1;
+ index_t iz_tse = iz_tnw;
+
+ index_t ix_bnw = ix_tnw;
+ index_t iy_bnw = iy_tnw;
+ index_t iz_bnw = iz_tnw + 1;
+
+ index_t ix_bne = ix_tnw + 1;
+ index_t iy_bne = iy_tnw;
+ index_t iz_bne = iz_tnw + 1;
+
+ index_t ix_bsw = ix_tnw;
+ index_t iy_bsw = iy_tnw + 1;
+ index_t iz_bsw = iz_tnw + 1;
+
+ index_t ix_bse = ix_tnw + 1;
+ index_t iy_bse = iy_tnw + 1;
+ index_t iz_bse = iz_tnw + 1;
+
+ // get surfaces to each neighbor:
+ scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
+ scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
+ scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
+ scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
+ scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
+ scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
+ scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
+ scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
+
+ scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
+ scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
+ index_t NC_offset = n * gInp_sN;
+ scalar_t *inp_ptr_NC = input.data + n * inp_sN;
+ // calculate bilinear weighted pixel value and set output pixel
+ for (index_t c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, NC_offset += gInp_sC, inp_ptr_NC += inp_sC) {
+ scalar_t gOut = *gOut_ptr_NCDHW;
+
+ // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd].
+ at::native::safe_add_3d(grad_input.data, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut,
+ NC_offset, grad_input_memory_span);
+ // calculate grad_grid
+ if (at::native::within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
+ scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
+ gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut;
+ giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut;
+ giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut;
+ }
+ if (at::native::within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
+ scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
+ gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut;
+ giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut;
+ giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut;
+ }
+ if (at::native::within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
+ scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
+ gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut;
+ giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut;
+ giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut;
+ }
+ if (at::native::within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
+ scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
+ gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut;
+ giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut;
+ giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut;
+ }
+ if (at::native::within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
+ scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
+ gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut;
+ giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut;
+ giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut;
+ }
+ if (at::native::within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
+ scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
+ gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut;
+ giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut;
+ giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut;
+ }
+ if (at::native::within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
+ scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
+ gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut;
+ giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut;
+ giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut;
+ }
+ if (at::native::within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
+ scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
+ gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut;
+ giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut;
+ giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut;
+ }
+ }
+
+ // assuming grad_grid is contiguous
+ // thus we can
+ // 1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW
+ // 2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2]
+ scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
+ gGrid_ptr_NDHW[0] = gix_mult * gix;
+ gGrid_ptr_NDHW[1] = giy_mult * giy;
+ gGrid_ptr_NDHW[2] = giz_mult * giz;
+ }
+}
+
+
+template <typename scalar_t, typename index_t>
+ C10_LAUNCH_BOUNDS_1(256)
+ __global__ void grid_sampler_2d_backward_backward_kernel(
+ const index_t nthreads,
+ TensorInfo<scalar_t, index_t> grad2_grad_input,
+ TensorInfo<scalar_t, index_t> grad2_grad_grid,
+ TensorInfo<scalar_t, index_t> input,
+ TensorInfo<scalar_t, index_t> grid,
+ TensorInfo<scalar_t, index_t> grad_output,
+ TensorInfo<scalar_t, index_t> grad_input,
+ TensorInfo<scalar_t, index_t> grad_grid,
+ TensorInfo<scalar_t, index_t> grad_grad_output,
+ const GridSamplerPadding padding_mode,
+ bool align_corners,
+ const index_t grad_input_memory_span) {
+
+ index_t C = input.sizes[1];
+ index_t inp_H = input.sizes[2];
+ index_t inp_W = input.sizes[3];
+
+ index_t out_H = grid.sizes[1];
+ index_t out_W = grid.sizes[2];
+
+ index_t g2inp_sN = grad2_grad_input.strides[0];
+ index_t g2inp_sC = grad2_grad_input.strides[1];
+ index_t g2inp_sH = grad2_grad_input.strides[2];
+ index_t g2inp_sW = grad2_grad_input.strides[3];
+
+ index_t g2grid_sN = grad2_grad_grid.strides[0];
+ index_t g2grid_sH = grad2_grad_grid.strides[1];
+ index_t g2grid_sW = grad2_grad_grid.strides[2];
+ index_t g2grid_sCoor = grad2_grad_grid.strides[3];
+
+ index_t gOut_sN = grad_output.strides[0];
+ index_t gOut_sC = grad_output.strides[1];
+ index_t gOut_sH = grad_output.strides[2];
+ index_t gOut_sW = grad_output.strides[3];
+
+ index_t inp_sN = input.strides[0];
+ index_t inp_sC = input.strides[1];
+ index_t inp_sH = input.strides[2];
+ index_t inp_sW = input.strides[3];
+
+ index_t grid_sN = grid.strides[0];
+ index_t grid_sH = grid.strides[1];
+ index_t grid_sW = grid.strides[2];
+ index_t grid_sCoor = grid.strides[3];
+
+ index_t gInp_sN = grad_input.strides[0];
+ index_t gInp_sC = grad_input.strides[1];
+ index_t gInp_sH = grad_input.strides[2];
+ index_t gInp_sW = grad_input.strides[3];
+
+ index_t gGrid_sW = grad_grid.strides[2];
+
+ index_t ggOut_sN = grad_grad_output.strides[0];
+ index_t ggOut_sC = grad_grad_output.strides[1];
+ index_t ggOut_sH = grad_grad_output.strides[2];
+ index_t ggOut_sW = grad_grad_output.strides[3];
+
+ CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
+ const index_t w = index % out_W;
+ const index_t h = (index / out_W) % out_H;
+ const index_t n = index / (out_H * out_W);
+
+ /* Grid related stuff */
+ index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
+
+ // get the corresponding input x, y co-ordinates from grid
+ scalar_t ix = grid.data[grid_offset];
+ scalar_t iy = grid.data[grid_offset + grid_sCoor];
+
+ // multipliers for gradients on ix and iy
+ scalar_t gix_mult, giy_mult;
+ ix = at::native::grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
+ iy = at::native::grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
+
+ // get NE, NW, SE, SW pixel values from (x, y)
+ index_t ix_nw = static_cast<index_t>(::floor(ix));
+ index_t iy_nw = static_cast<index_t>(::floor(iy));
+
+ index_t ix_ne = ix_nw + 1;
+ index_t iy_ne = iy_nw;
+
+ index_t ix_sw = ix_nw;
+ index_t iy_sw = iy_nw + 1;
+
+ index_t ix_se = ix_nw + 1;
+ index_t iy_se = iy_nw + 1;
+
+ // get surfaces to each neighbor:
+ scalar_t nw = (ix_se - ix) * (iy_se - iy);
+ scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
+ scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
+ scalar_t se = (ix - ix_nw) * (iy - iy_nw);
+
+ /* grad2_grad_input related init */
+ scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN;
+
+ /* grad2_grad_grid related init */
+ grid_offset = n * g2grid_sN + h * g2grid_sH + w * g2grid_sW;
+ scalar_t dx = grad2_grad_grid.data[grid_offset];
+ scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor];
+
+ dx = dx * gix_mult;
+ dy = dy * giy_mult;
+
+ /* grad_output related init */
+ scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
+
+ /* input related init */
+ scalar_t *inp_ptr_NC = input.data + n * inp_sN;
+
+ /* grad_grad_output related init */
+ scalar_t *ggOut_ptr_NCHW = grad_grad_output.data + n * ggOut_sN + h * ggOut_sH + w * ggOut_sW;
+
+ /* grad_input related init */
+ index_t NC_offset = n * gInp_sN;
+
+ /* grad_grid related init */
+ scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
+ scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
+
+ scalar_t nw_val = static_cast<scalar_t>(0), ne_val = static_cast<scalar_t>(0), sw_val = static_cast<scalar_t>(0), se_val = static_cast<scalar_t>(0);
+ scalar_t g2_nw_val = static_cast<scalar_t>(0), g2_ne_val = static_cast<scalar_t>(0), g2_sw_val = static_cast<scalar_t>(0), g2_se_val = static_cast<scalar_t>(0);
+
+ for (index_t c = 0; c < C; ++c, g2_inp_ptr_NC += g2inp_sC, inp_ptr_NC += inp_sC, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC, ggOut_ptr_NCHW += ggOut_sC) {
+ if (at::native::within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
+ nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW];
+ g2_nw_val = g2_inp_ptr_NC[iy_nw * g2inp_sH + ix_nw * g2inp_sW];
+ }
+ if (at::native::within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
+ ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW];
+ g2_ne_val = g2_inp_ptr_NC[iy_ne * g2inp_sH + ix_ne * g2inp_sW];
+ }
+ if (at::native::within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
+ sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW];
+ g2_sw_val = g2_inp_ptr_NC[iy_sw * g2inp_sH + ix_sw * g2inp_sW];
+ }
+ if (at::native::within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
+ se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW];
+ g2_se_val = g2_inp_ptr_NC[iy_se * g2inp_sH + ix_se * g2inp_sW];
+ }
+ // Computing gradient wrt grad_output = grad2_grad_input * x * y + grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val
+ // grad2_grad_input * x * y
+ *ggOut_ptr_NCHW = static_cast<scalar_t>(0);
+ *ggOut_ptr_NCHW += g2_nw_val * nw + g2_ne_val * ne + g2_sw_val * sw + g2_se_val * se;
+
+ scalar_t nw_tmp = -dx * (iy_se - iy) - dy * (ix_se - ix);
+ scalar_t ne_tmp = +dx * (iy_sw - iy) - dy * (ix - ix_sw);
+ scalar_t sw_tmp = -dx * (iy - iy_ne) + dy * (ix_ne - ix);
+ scalar_t se_tmp = +dx * (iy - iy_nw) + dy * (ix - ix_nw);
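+ // *_tmp is the directional derivative of each bilinear weight along the incoming
+ // grid perturbation (dx, dy), e.g. nw_tmp = dx * d(nw)/d(ix) + dy * d(nw)/d(iy).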
+
+ // grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val
+ *ggOut_ptr_NCHW += nw_val * nw_tmp + ne_tmp * ne_val + sw_tmp * sw_val + se_tmp * se_val;
+
+ // Computing gradient wrt input = grad2_grad_grid_x * grad_output * y + grad2_grad_grid_y * grad_output * x
+ scalar_t gOut = *gOut_ptr_NCHW;
+ //scalar_t val;
+ //val = gOut * (-dx * (iy_se - iy) - dy * (ix_se - ix));
+ at::native::safe_add_2d(grad_input.data, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw_tmp * gOut, NC_offset, grad_input_memory_span);
+ //val = gOut * (+dx * (iy_sw - iy) - dy * (ix - ix_sw));
+ at::native::safe_add_2d(grad_input.data, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne_tmp * gOut, NC_offset, grad_input_memory_span);
+ //val = gOut * (-dx * (iy - iy_ne) + dy * (ix_ne - ix));
+ at::native::safe_add_2d(grad_input.data, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw_tmp * gOut, NC_offset, grad_input_memory_span);
+ //val = gOut * (+dx * (iy - iy_nw) + dy * (ix - ix_nw));
+ at::native::safe_add_2d(grad_input.data, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se_tmp * gOut, NC_offset, grad_input_memory_span);
+
+ scalar_t dxy = nw_val - ne_val - sw_val + se_val;
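+ // dxy is the mixed second derivative d^2(out)/d(ix)d(iy) of the bilinear
+ // interpolant, which is constant within a cell.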
+ // Computing gradient wrt grid_x = grad2_grad_input * y * gOut + grad2_grad_grid_y * val * gOut
+ gix += gOut * (-g2_nw_val * (iy_se - iy) + g2_ne_val * (iy_sw - iy)
+ -g2_sw_val * (iy - iy_ne) + g2_se_val * (iy - iy_nw));
+ gix += gOut * dy * dxy;
+
+ // Computing gradient wrt grid_y = grad2_grad_input * x * gOut + grad2_grad_grid_x * val * gOut
+ giy += gOut * (-g2_nw_val * (ix_se - ix) - g2_ne_val * (ix - ix_sw)
+ +g2_sw_val * (ix_ne - ix) + g2_se_val * (ix - ix_nw));
+ giy += gOut * dx * dxy;
+ }
+
+ gGrid_ptr_NHW[0] = gix * gix_mult;
+ gGrid_ptr_NHW[1] = giy * giy_mult;
+ }
+}
+
+template <typename scalar_t, typename index_t>
+ C10_LAUNCH_BOUNDS_1(256)
+ __global__ void grid_sampler_3d_backward_backward_kernel(
+ const index_t nthreads,
+ TensorInfo<scalar_t, index_t> grad2_grad_input,
+ TensorInfo<scalar_t, index_t> grad2_grad_grid,
+ TensorInfo<scalar_t, index_t> input,
+ TensorInfo<scalar_t, index_t> grid,
+ TensorInfo<scalar_t, index_t> grad_output,
+ TensorInfo<scalar_t, index_t> grad_input,
+ TensorInfo<scalar_t, index_t> grad_grid,
+ TensorInfo<scalar_t, index_t> grad_grad_output,
+ const GridSamplerPadding padding_mode,
+ bool align_corners,
+ const index_t grad_input_memory_span) {
+
+ index_t C = input.sizes[1];
+ index_t inp_D = input.sizes[2];
+ index_t inp_H = input.sizes[3];
+ index_t inp_W = input.sizes[4];
+
+ index_t out_D = grid.sizes[1];
+ index_t out_H = grid.sizes[2];
+ index_t out_W = grid.sizes[3];
+
+ index_t g2inp_sN = grad2_grad_input.strides[0];
+ index_t g2inp_sC = grad2_grad_input.strides[1];
+ index_t g2inp_sD = grad2_grad_input.strides[2];
+ index_t g2inp_sH = grad2_grad_input.strides[3];
+ index_t g2inp_sW = grad2_grad_input.strides[4];
+
+ index_t g2grid_sN = grad2_grad_grid.strides[0];
+ index_t g2grid_sD = grad2_grad_grid.strides[1];
+ index_t g2grid_sH = grad2_grad_grid.strides[2];
+ index_t g2grid_sW = grad2_grad_grid.strides[3];
+ index_t g2grid_sCoor = grad2_grad_grid.strides[4];
+
+ index_t gOut_sN = grad_output.strides[0];
+ index_t gOut_sC = grad_output.strides[1];
+ index_t gOut_sD = grad_output.strides[2];
+ index_t gOut_sH = grad_output.strides[3];
+ index_t gOut_sW = grad_output.strides[4];
+
+ index_t inp_sN = input.strides[0];
+ index_t inp_sC = input.strides[1];
+ index_t inp_sD = input.strides[2];
+ index_t inp_sH = input.strides[3];
+ index_t inp_sW = input.strides[4];
+
+ index_t grid_sN = grid.strides[0];
+ index_t grid_sD = grid.strides[1];
+ index_t grid_sH = grid.strides[2];
+ index_t grid_sW = grid.strides[3];
+ index_t grid_sCoor = grid.strides[4];
+
+ index_t gInp_sN = grad_input.strides[0];
+ index_t gInp_sC = grad_input.strides[1];
+ index_t gInp_sD = grad_input.strides[2];
+ index_t gInp_sH = grad_input.strides[3];
+ index_t gInp_sW = grad_input.strides[4];
+
+ index_t gGrid_sW = grad_grid.strides[3];
+
+ index_t ggOut_sN = grad_grad_output.strides[0];
+ index_t ggOut_sC = grad_grad_output.strides[1];
+ index_t ggOut_sD = grad_grad_output.strides[2];
+ index_t ggOut_sH = grad_grad_output.strides[3];
+ index_t ggOut_sW = grad_grad_output.strides[4];
+
+ CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
+ const index_t w = index % out_W;
+ const index_t h = (index / out_W) % out_H;
+ const index_t d = (index / (out_H * out_W)) % out_D;
+ const index_t n = index / (out_D * out_H * out_W);
+
+ /* Grid related stuff */
+ index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
+
+ // get the corresponding input x, y, z co-ordinates from grid
+ scalar_t ix = grid.data[grid_offset];
+ scalar_t iy = grid.data[grid_offset + grid_sCoor];
+ scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
+
+ // multipliers for gradients on ix, iy, and iz
+ scalar_t gix_mult, giy_mult, giz_mult;
+ ix = at::native::grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
+ iy = at::native::grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
+ iz = at::native::grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);
+
+ // get corner pixel values from (x, y, z)
+ index_t ix_tnw = static_cast<index_t>(::floor(ix));
+ index_t iy_tnw = static_cast<index_t>(::floor(iy));
+ index_t iz_tnw = static_cast<index_t>(::floor(iz));
+
+ index_t ix_tne = ix_tnw + 1;
+ index_t iy_tne = iy_tnw;
+ index_t iz_tne = iz_tnw;
+
+ index_t ix_tsw = ix_tnw;
+ index_t iy_tsw = iy_tnw + 1;
+ index_t iz_tsw = iz_tnw;
+
+ index_t ix_tse = ix_tnw + 1;
+ index_t iy_tse = iy_tnw + 1;
+ index_t iz_tse = iz_tnw;
+
+ index_t ix_bnw = ix_tnw;
+ index_t iy_bnw = iy_tnw;
+ index_t iz_bnw = iz_tnw + 1;
+
+ index_t ix_bne = ix_tnw + 1;
+ index_t iy_bne = iy_tnw;
+ index_t iz_bne = iz_tnw + 1;
+
+ index_t ix_bsw = ix_tnw;
+ index_t iy_bsw = iy_tnw + 1;
+ index_t iz_bsw = iz_tnw + 1;
+
+ index_t ix_bse = ix_tnw + 1;
+ index_t iy_bse = iy_tnw + 1;
+ index_t iz_bse = iz_tnw + 1;
+
+ // get surfaces to each neighbor:
+ scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
+ scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
+ scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
+ scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
+ scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
+ scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
+ scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
+ scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
+
+ /* grad2_grad_input related init */
+ scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN;
+
+ /* grad2_grad_grid related init */
+ grid_offset = n * g2grid_sN + d * g2grid_sD + h * g2grid_sH + w * g2grid_sW;
+ scalar_t dx = grad2_grad_grid.data[grid_offset];
+ scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor];
+ scalar_t dz = grad2_grad_grid.data[grid_offset + 2 * g2grid_sCoor];
+
+ dx = dx * gix_mult;
+ dy = dy * giy_mult;
+ dz = dz * giz_mult;
+
+ /* grad_output related init */
+ scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
+
+ /* input related init */
+ scalar_t *inp_ptr_NC = input.data + n * inp_sN;
+
+ /* grad_grad_output related init */
+ scalar_t *ggOut_ptr_NCDHW = grad_grad_output.data + n * ggOut_sN + d * ggOut_sD + h * ggOut_sH + w * ggOut_sW;
+
+ /* grad_input related init */
+ index_t NC_offset = n * gInp_sN;
+
+ /* grad_grid related init */
+ scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
+ scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
+
+ scalar_t tnw_val = static_cast<scalar_t>(0), tne_val = static_cast<scalar_t>(0), tsw_val = static_cast<scalar_t>(0), tse_val = static_cast<scalar_t>(0);
+ scalar_t bnw_val = static_cast<scalar_t>(0), bne_val = static_cast<scalar_t>(0), bsw_val = static_cast<scalar_t>(0), bse_val = static_cast<scalar_t>(0);
+ scalar_t g2_tnw_val = static_cast<scalar_t>(0), g2_tne_val = static_cast<scalar_t>(0), g2_tsw_val = static_cast<scalar_t>(0), g2_tse_val = static_cast<scalar_t>(0);
+ scalar_t g2_bnw_val = static_cast<scalar_t>(0), g2_bne_val = static_cast<scalar_t>(0), g2_bsw_val = static_cast<scalar_t>(0), g2_bse_val = static_cast<scalar_t>(0);
+
+ for (index_t c = 0; c < C; ++c, g2_inp_ptr_NC += g2inp_sC, inp_ptr_NC += inp_sC, NC_offset += gInp_sC, gOut_ptr_NCDHW += gOut_sC, ggOut_ptr_NCDHW += ggOut_sC) {
+
+ if (at::native::within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
+ tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
+ g2_tnw_val = g2_inp_ptr_NC[iz_tnw * g2inp_sD + iy_tnw * g2inp_sH + ix_tnw * g2inp_sW];
+ }
+ if (at::native::within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
+ tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
+ g2_tne_val = g2_inp_ptr_NC[iz_tne * g2inp_sD + iy_tne * g2inp_sH + ix_tne * g2inp_sW];
+ }
+ if (at::native::within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
+ tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
+ g2_tsw_val = g2_inp_ptr_NC[iz_tsw * g2inp_sD + iy_tsw * g2inp_sH + ix_tsw * g2inp_sW];
+ }
+ if (at::native::within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
+ tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
+ g2_tse_val = g2_inp_ptr_NC[iz_tse * g2inp_sD + iy_tse * g2inp_sH + ix_tse * g2inp_sW];
+ }
+ if (at::native::within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
+ bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
+ g2_bnw_val = g2_inp_ptr_NC[iz_bnw * g2inp_sD + iy_bnw * g2inp_sH + ix_bnw * g2inp_sW];
+ }
+ if (at::native::within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
+ bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
+ g2_bne_val = g2_inp_ptr_NC[iz_bne * g2inp_sD + iy_bne * g2inp_sH + ix_bne * g2inp_sW];
+ }
+ if (at::native::within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
+ bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
+ g2_bsw_val = g2_inp_ptr_NC[iz_bsw * g2inp_sD + iy_bsw * g2inp_sH + ix_bsw * g2inp_sW];
+ }
+ if (at::native::within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
+ bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
+ g2_bse_val = g2_inp_ptr_NC[iz_bse * g2inp_sD + iy_bse * g2inp_sH + ix_bse * g2inp_sW];
+ }
+
+ // Computing gradient wrt grad_output =
+ // grad2_grad_input * x * y * z
+ *ggOut_ptr_NCDHW = static_cast<scalar_t>(0);
+ *ggOut_ptr_NCDHW += g2_tnw_val * tnw + g2_tne_val * tne + g2_tsw_val * tsw + g2_tse_val * tse
+ +g2_bnw_val * bnw + g2_bne_val * bne + g2_bsw_val * bsw + g2_bse_val * bse;
+
+ // +val * (grad2_grad_grid_x * y * z + grad2_grad_grid_y * x * z + grad2_grad_grid_z * x * y)
+ scalar_t tnw_tmp = (-dx * (iy_bse - iy) * (iz_bse - iz) - dy * (ix_bse - ix) * (iz_bse - iz) - dz * (ix_bse - ix) * (iy_bse - iy));
+ scalar_t tne_tmp = (+dx * (iy_bsw - iy) * (iz_bsw - iz) - dy * (ix - ix_bsw) * (iz_bsw - iz) - dz * (ix - ix_bsw) * (iy_bsw - iy));
+ scalar_t tsw_tmp = (-dx * (iy - iy_bne) * (iz_bne - iz) + dy * (ix_bne - ix) * (iz_bne - iz) - dz * (ix_bne - ix) * (iy - iy_bne));
+ scalar_t tse_tmp = (+dx * (iy - iy_bnw) * (iz_bnw - iz) + dy * (ix - ix_bnw) * (iz_bnw - iz) - dz * (ix - ix_bnw) * (iy - iy_bnw));
+ scalar_t bnw_tmp = (-dx * (iy_tse - iy) * (iz - iz_tse) - dy * (ix_tse - ix) * (iz - iz_tse) + dz * (ix_tse - ix) * (iy_tse - iy));
+ scalar_t bne_tmp = (+dx * (iy_tsw - iy) * (iz - iz_tsw) - dy * (ix - ix_tsw) * (iz - iz_tsw) + dz * (ix - ix_tsw) * (iy_tsw - iy));
+ scalar_t bsw_tmp = (-dx * (iy - iy_tne) * (iz - iz_tne) + dy * (ix_tne - ix) * (iz - iz_tne) + dz * (ix_tne - ix) * (iy - iy_tne));
+ scalar_t bse_tmp = (+dx * (iy - iy_tnw) * (iz - iz_tnw) + dy * (ix - ix_tnw) * (iz - iz_tnw) + dz * (ix - ix_tnw) * (iy - iy_tnw));
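+ // As in the 2D case, *_tmp is the directional derivative of each trilinear
+ // weight along the incoming grid perturbation (dx, dy, dz).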
+
+ *ggOut_ptr_NCDHW += tnw_val * tnw_tmp + tne_val * tne_tmp + tsw_val * tsw_tmp + tse_val * tse_tmp
+ +bnw_val * bnw_tmp + bne_val * bne_tmp + bsw_val * bsw_tmp + bse_val * bse_tmp;
+
+ // Computing gradient wrt input = grad2_grad_grid_x * grad_output * y * z + grad2_grad_grid_y * grad_output * x * z +
+ // grad2_grad_grid_z * grad_output * x * y
+ scalar_t gOut = *gOut_ptr_NCDHW;
+
+ at::native::safe_add_3d(grad_input.data, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+ at::native::safe_add_3d(grad_input.data, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse_tmp * gOut,
+ NC_offset, grad_input_memory_span);
+
+ //Computing gradient wrt grid
+ scalar_t dxy = (tnw_val * (iz_bse - iz) - tne_val * (iz_bsw - iz)
+ -tsw_val * (iz_bne - iz) + tse_val * (iz_bnw - iz)
+ +bnw_val * (iz - iz_tse) - bne_val * (iz - iz_tsw)
+ -bsw_val * (iz - iz_tne) + bse_val * (iz - iz_tnw));
+
+ scalar_t dxz = (tnw_val * (iy_bse - iy) - tne_val * (iy_bsw - iy)
+ +tsw_val * (iy - iy_bne) - tse_val * (iy - iy_bnw)
+ -bnw_val * (iy_tse - iy) + bne_val * (iy_tsw - iy)
+ -bsw_val * (iy - iy_tne) + bse_val * (iy - iy_tnw));
+
+ scalar_t dyz = (tnw_val * (ix_bse - ix) + tne_val * (ix - ix_bsw)
+ -tsw_val * (ix_bne - ix) - tse_val * (ix - ix_bnw)
+ -bnw_val * (ix_tse - ix) - bne_val * (ix - ix_tsw)
+ +bsw_val * (ix_tne - ix) + bse_val * (ix - ix_tnw));
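+ // dxy, dxz, dyz are the mixed second derivatives of the trilinear interpolant
+ // (d^2 out / d(ix)d(iy), etc.); they supply the second-order grid terms in the
+ // gix / giy / giz accumulation below.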
+
+
+ // Computing gradient wrt grid_x =
+ // grad2_grad_input * z * y * gOut
+ gix += gOut * (-g2_tnw_val * (iy_bse - iy) * (iz_bse - iz) + g2_tne_val * (iy_bsw - iy) * (iz_bsw - iz)
+ -g2_tsw_val * (iy - iy_bne) * (iz_bne - iz) + g2_tse_val * (iy - iy_bnw) * (iz_bnw - iz)
+ -g2_bnw_val * (iy_tse - iy) * (iz - iz_tse) + g2_bne_val * (iy_tsw - iy) * (iz - iz_tsw)
+ -g2_bsw_val * (iy - iy_tne) * (iz - iz_tne) + g2_bse_val * (iy - iy_tnw) * (iz - iz_tnw));
+
+ //+ grad2_grad_grid_z * y * val * gOut + grad2_grad_grid_y * z * val * gOut
+ gix += gOut * (dz * dxz + dy * dxy);
+
+ // Computing gradient wrt grid_y =
+ // grad2_grad_input * x * z * gOut
+ giy += gOut * (-g2_tnw_val * (ix_bse - ix) * (iz_bse - iz) - g2_tne_val * (ix - ix_bsw) * (iz_bsw - iz)
+ +g2_tsw_val * (ix_bne - ix) * (iz_bne - iz) + g2_tse_val * (ix - ix_bnw) * (iz_bnw - iz)
+ -g2_bnw_val * (ix_tse - ix) * (iz - iz_tse) - g2_bne_val * (ix - ix_tsw) * (iz - iz_tsw)
+ +g2_bsw_val * (ix_tne - ix) * (iz - iz_tne) + g2_bse_val * (ix - ix_tnw) * (iz - iz_tnw));
+ //+ grad2_grad_grid_x * z * val * gOut + grad2_grad_grid_z * x * val * gOut
+ giy += gOut * (dx * dxy + dz * dyz);
+
+ // Computing gradient wrt grid_z =
+ // grad2_grad_input * x * y * gOut
+ giz += gOut * (-g2_tnw_val * (ix_bse - ix) * (iy_bse - iy) - g2_tne_val * (ix - ix_bsw) * (iy_bsw - iy)
+ -g2_tsw_val * (ix_bne - ix) * (iy - iy_bne) - g2_tse_val * (ix - ix_bnw) * (iy - iy_bnw)
+ +g2_bnw_val * (ix_tse - ix) * (iy_tse - iy) + g2_bne_val * (ix - ix_tsw) * (iy_tsw - iy)
+ +g2_bsw_val * (ix_tne - ix) * (iy - iy_tne) + g2_bse_val * (ix - ix_tnw) * (iy - iy_tnw));
+ //+ grad2_grad_grid_x * y * val * gOut + grad2_grad_grid_y * x * val * gOut
+ giz += gOut * (dx * dxz + dy * dyz);
+ }
+
+ gGrid_ptr_NDHW[0] = gix * gix_mult;
+ gGrid_ptr_NDHW[1] = giy * giy_mult;
+ gGrid_ptr_NDHW[2] = giz * giz_mult;
+ }
+}
+
+void launch_grid_sampler_2d_forward_kernel(
+ const at::TensorBase &output, const at::TensorBase &input, const at::TensorBase &grid,
+ int64_t padding_mode, bool align_corners) {
+ auto N = input.size(0);
+ auto H = grid.size(1);
+ auto W = grid.size(2);
+ int64_t count = N * H * W;
+ if (count > 0) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_cuda", [&] {
+ if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
+ canUse32BitIndexMath(output)) {
+ grid_sampler_2d_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ static_cast<int>(count),
+ getTensorInfo<scalar_t, int>(input),
+ getTensorInfo<scalar_t, int>(grid),
+ getTensorInfo<scalar_t, int>(output),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ } else {
+ grid_sampler_2d_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ count,
+ getTensorInfo<scalar_t, int64_t>(input),
+ getTensorInfo<scalar_t, int64_t>(grid),
+ getTensorInfo<scalar_t, int64_t>(output),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }
+ });
+ }
+}
+
+void launch_grid_sampler_3d_forward_kernel(
+ const at::TensorBase &output, const at::TensorBase &input, const at::TensorBase &grid,
+ int64_t padding_mode, bool align_corners) {
+ auto N = input.size(0);
+ auto D = grid.size(1);
+ auto H = grid.size(2);
+ auto W = grid.size(3);
+ int64_t count = N * D * H * W;
+ if (count > 0) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_cuda", [&] {
+ if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
+ canUse32BitIndexMath(output)) {
+ grid_sampler_3d_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ static_cast<int>(count),
+ getTensorInfo<scalar_t, int>(input),
+ getTensorInfo<scalar_t, int>(grid),
+ getTensorInfo<scalar_t, int>(output),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ } else {
+ grid_sampler_3d_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ count,
+ getTensorInfo<scalar_t, int64_t>(input),
+ getTensorInfo<scalar_t, int64_t>(grid),
+ getTensorInfo<scalar_t, int64_t>(output),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }
+ });
+ }
+}
+
+void launch_grid_sampler_2d_backward_kernel(
+ const at::TensorBase &grad_input, const at::TensorBase &grad_grid,
+ const at::TensorBase &grad_output, const at::TensorBase &input,
+ const at::TensorBase &grid, int64_t padding_mode, bool align_corners) {
+ auto N = input.size(0);
+ auto H = grid.size(1);
+ auto W = grid.size(2);
+ int64_t count = N * H * W;
+ if (count > 0) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] {
+ if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
+ canUse32BitIndexMath(grad_output)) {
+ grid_sampler_2d_backward_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ static_cast<int>(count),
+ getTensorInfo<scalar_t, int>(grad_output),
+ getTensorInfo<scalar_t, int>(input),
+ getTensorInfo<scalar_t, int>(grid),
+ getTensorInfo<scalar_t, int>(grad_input),
+ getTensorInfo<scalar_t, int>(grad_grid),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners,
+ static_cast<int>(grad_input.numel()));
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ } else {
+ grid_sampler_2d_backward_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ count,
+ getTensorInfo<scalar_t, int64_t>(grad_output),
+ getTensorInfo<scalar_t, int64_t>(input),
+ getTensorInfo<scalar_t, int64_t>(grid),
+ getTensorInfo<scalar_t, int64_t>(grad_input),
+ getTensorInfo<scalar_t, int64_t>(grad_grid),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners,
+ grad_input.numel());
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }
+ });
+ }
+}
+
+void launch_grid_sampler_3d_backward_kernel(
+ const at::TensorBase &grad_input, const at::TensorBase &grad_grid,
+ const at::TensorBase &grad_output, const at::TensorBase &input,
+ const at::TensorBase &grid, int64_t padding_mode, bool align_corners) {
+ auto N = input.size(0);
+ auto D = grid.size(1);
+ auto H = grid.size(2);
+ auto W = grid.size(3);
+ int64_t count = N * D * H * W;
+ if (count > 0) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_backward_cuda", [&] {
+ if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
+ canUse32BitIndexMath(grad_output)) {
+ grid_sampler_3d_backward_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ static_cast<int>(count),
+ getTensorInfo<scalar_t, int>(grad_output),
+ getTensorInfo<scalar_t, int>(input),
+ getTensorInfo<scalar_t, int>(grid),
+ getTensorInfo<scalar_t, int>(grad_input),
+ getTensorInfo<scalar_t, int>(grad_grid),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners,
+ static_cast<int>(grad_input.numel()));
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ } else {
+ grid_sampler_3d_backward_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ count,
+ getTensorInfo<scalar_t, int64_t>(grad_output),
+ getTensorInfo<scalar_t, int64_t>(input),
+ getTensorInfo<scalar_t, int64_t>(grid),
+ getTensorInfo<scalar_t, int64_t>(grad_input),
+ getTensorInfo<scalar_t, int64_t>(grad_grid),
+ static_cast<GridSamplerPadding>(padding_mode),
+ align_corners,
+ grad_input.numel());
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }
+ });
+ }
+}
+
+void launch_grid_sampler_2d_backward_backward_kernel(
+ const at::TensorBase &grad_input, const at::TensorBase &grad_grid, const at::TensorBase &grad_grad_output,
+ const at::TensorBase &grad2_grad_input, const at::TensorBase &grad2_grad_grid,
+ const at::TensorBase &input, const at::TensorBase &grid, const at::TensorBase &grad_output,
+ int64_t padding_mode, bool align_corners) {
+
+ auto N = input.size(0);
+ auto H = grid.size(1);
+ auto W = grid.size(2);
+ int64_t count = N * H * W;
+ if (count > 0) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_backward_backward_cuda", [&] {
+ if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
+ canUse32BitIndexMath(grad_output)) {
+ grid_sampler_2d_backward_backward_kernel<scalar_t>
+ <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
+ static_cast<int>(count),
+ getTensorInfo<scalar_t, int>(grad2_grad_input),
+ getTensorInfo<scalar_t, int>(grad2_grad_grid),
+ getTensorInfo<scalar_t, int>(input),
+ getTensorInfo<scalar_t, int>(grid),
+ getTensorInfo<scalar_t, int>(grad_output),
+ getTensorInfo<scalar_t, int>(grad_input),
+ getTensorInfo<scalar_t, int>(grad_grid),
+ getTensorInfo