diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..ba5282e Binary files /dev/null and b/.DS_Store differ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..868d7eb --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +*.pyc +*.so +*.pkl +.DS_Store/ +__pycache__/ +*.pt +.idea/ +.vscode/ +build/ +*.egg-info/ +images/ +*.blend1 +.vscode/ +.history/ +tools/ +fb_sweep/ +fb_scripts/ \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..c7540fd --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,45 @@ +# Open Source Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment include: + +Using welcoming and inclusive language +Being respectful of differing viewpoints and experiences +Gracefully accepting constructive criticism +Focusing on what is best for the community +Showing empathy towards other community members +Examples of unacceptable behavior by participants include: + +The use of sexualized language or imagery and unwelcome sexual attention or advances +Trolling, insulting/derogatory comments, and personal or political attacks +Public or private harassment +Publishing others’ private information, such as a physical or electronic address, without explicit permission +Other conduct which could reasonably be considered inappropriate in a professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at opensource-conduct@fb.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..2701446 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,28 @@ +# Contributing to Neural Sparse Voxel Fields (NSVF) +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `master`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +## License +By contributing to Neural Sparse Voxel Fields, +you agree that your contributions will be licensed under the LICENSE file in +the root directory of this \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..87cbf53 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Facebook, Inc. and its affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..34a94ae --- /dev/null +++ b/README.md @@ -0,0 +1,221 @@ +# Neural Sparse Voxel Fields (NSVF) + +Photo-realistic free-viewpoint rendering of real-world scenes using classical computer graphics techniques is a challenging problem because it requires the difficult step of capturing detailed appearance and geometry models. +Neural rendering is an emerging field that employs deep neural networks to implicitly learn scene representations encapsulating both geometry and appearance from 2D observations with or without a coarse geometry. +However, existing approaches in this field often show blurry renderings or suffer from slow rendering process. We propose [Neural Sparse Voxel Fields (NSVF)](https://arxiv.org/abs/2007.11571), a new neural scene representation for fast and high-quality free-viewpoint rendering. + +Here is the official repo for the paper: + +* [Neural Sparse Voxel Fields (Liu et al., 2020)](https://arxiv.org/abs/2007.11571). + + + +## Requirements and Installation + +This code is implemented in PyTorch using [fairseq framework](https://github.com/pytorch/fairseq). + +The code has been tested on the following system: + +* Python >= 3.6 +* PyTorch 1.4.0 +* [Nvidia apex library](https://github.com/NVIDIA/apex) (optional) +* Nvidia GPU (Tesla V100 32GB) CUDA 10.1 + +Only learning and rendering on GPUs are supported. + +To install, first clone this repo and install all dependencies: + +```bash +pip install -r requirements.txt +``` + +Then, run + +```bash +pip install --editable ./ +``` + +Or if you want to install the code locally, run: + +```bash +python setup.py build_ext --inplace +``` + +## Dataset + +You can download the pre-processed synthetic and real datasets used in our paper. +Please also cite the original papers if you use any of them in your work. + +Dataset | Download Link | Notes on Dataset Split +---|---|--- +Synthetic-NSVF | [download (.zip)](https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NSVF.zip) | 0_\* (training) 1_\* (validation) 2_\* (testing) +[Synthetic-NeRF](https://github.com/bmild/nerf) | [download (.zip)](https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NeRF.zip) | 0_\* (training) 1_\* (validation) 2_\* (testing) +[BlendedMVS](https://github.com/YoYo000/BlendedMVS) | [download (.zip)](https://dl.fbaipublicfiles.com/nsvf/dataset/BlendedMVS.zip) | 0_\* (training) 1_\* (testing) +[Tanks&Temples](https://www.tanksandtemples.org/) | [download (.zip)](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip) | 0_\* (training) 1_\* (testing) + +### Prepare your own dataset + +To prepare a new dataset of a single scene for training and testing, please follow the data structure: + +```bash + +|-- bbox.txt # bounding-box file +|-- intrinsics.txt # 4x4 camera intrinsics +|-- rgb + |-- 0.png # target image for each view + |-- 1.png + ... +|-- pose + |-- 0.txt # camera pose for each view (4x4 matrices) + |-- 1.txt + ... +[optional] +|-- test_traj.txt # camera pose for free-view rendering demonstration (4N x 4) +``` + +where the ``bbox.txt`` file contains a line describing the initial bounding box and voxel size: + +```bash +x_min y_min z_min x_max y_max z_max initial_voxel_size +``` + +Note that the file names of target images and those of the corresponding camera pose files are not required to be exactly the same. However, the orders of these two kinds of files (sorted by string) must match. The datasets are split with view indices. +For example, "``train (0..100)``, ``valid (100..200)`` and ``test (200..400)``" mean the first 100 views for training, 100-199th views for validation, and 200-399th views for testing. + +## Train a new model + +Given the dataset of a single scene (``{DATASET}``), we use the following command for training an NSVF model to synthesize novel views at ``800x800`` pixels, with a batch size of ``4`` images per GPU and ``2048`` rays per image. By default, the code will automatically detect all available GPUs. + +In the following example, we use a pre-defined architecture ``nsvf_base`` with specific arguments: + +* By setting ``--no-sampling-at-reader``, the model only samples pixels in the projected image region of sparse voxels for training. +* By default, we set the ray-marching step size to be the ratio ``1/8 (0.125)`` of the voxel size which is typically described in the ``bbox.txt`` file. +* It is optional to turn on ``--use-octree``. It will build a sparse voxel octree to speed-up the ray-voxel intersection especially when the number of voxels is greater than ``10000``. +* By setting ``--pruning-every-steps`` as ``2500``, the model performs self-pruning at every ``2500`` steps. +* By setting ``--half-voxel-size-at`` and ``--reduce-step-size-at`` as ``5000,25000,75000``, the voxel size and step size are halved at ``5k``, ``25k`` and ``75k``, respectively. + +Note that, although above parameter settings are used for most of the experiments in the paper, it is possible to tune these parameters to achieve better quality. Besides the above parameters, other parameters can also use default settings. + +Besides the architecture ``nsvf_base``, you may check other architectures or define your own architectures in the file ``fairnr/models/nsvf.py``. + +```bash +python -u train.py ${DATASET} \ + --user-dir fairnr \ + --task single_object_rendering \ + --train-views "0..100" --view-resolution "800x800" \ + --max-sentences 1 --view-per-batch 4 --pixel-per-view 2048 \ + --no-preload \ + --sampling-on-mask 1.0 --no-sampling-at-reader \ + --valid-views "100..200" --valid-view-resolution "400x400" \ + --valid-view-per-batch 1 \ + --transparent-background "1.0,1.0,1.0" --background-stop-gradient \ + --arch nsvf_base \ + --initial-boundingbox ${DATASET}/bbox.txt \ + --use-octree \ + --raymarching-stepsize-ratio 0.125 \ + --discrete-regularization \ + --color-weight 128.0 --alpha-weight 1.0 \ + --optimizer "adam" --adam-betas "(0.9, 0.999)" \ + --lr 0.001 --lr-scheduler "polynomial_decay" --total-num-update 150000 \ + --criterion "srn_loss" --clip-norm 0.0 \ + --num-workers 0 \ + --seed 2 \ + --save-interval-updates 500 --max-update 150000 \ + --virtual-epoch-steps 5000 --save-interval 1 \ + --half-voxel-size-at "5000,25000,75000" \ + --reduce-step-size-at "5000,25000,75000" \ + --pruning-every-steps 2500 \ + --keep-interval-updates 5 --keep-last-epochs 5 \ + --log-format simple --log-interval 1 \ + --save-dir ${SAVE} \ + --tensorboard-logdir ${SAVE}/tensorboard \ + | tee -a $SAVE/train.log +``` + +The checkpoints are saved in ``{SAVE}``. You can launch tensorboard to check training progress: + +```bash +tensorboard --logdir=${SAVE}/tensorboard --port=10000 +``` + +There are more examples of training scripts to reproduce the results of our paper under [examples](./examples/train/). + +## Evaluation + +Once the model is trained, the following command is used to evaluate rendering quality on the test views given the ``{MODEL_PATH}``. + +```bash +python validate.py ${DATASET} \ + --user-dir fairnr \ + --valid-views "200..400" \ + --valid-view-resolution "800x800" \ + --no-preload \ + --task single_object_rendering \ + --max-sentences 1 \ + --valid-view-per-batch 1 \ + --path ${MODEL_PATH} \ + --model-overrides '{"chunk_size":512,"raymarching_tolerance":0.01,"tensorboard_logdir":"","eval_lpips":True}' \ +``` + +Note that we override the ``raymarching_tolerance`` to ``0.01`` to enable early termination for rendering speed-up. + +## Free Viewpoint Rendering + +Free-viewpoint rendering can be achieved once a model is trained and a rendering trajectory is specified. For example, the following command is for rendering with a circle trajectory (angular speed 3 degree/frame, 15 frames per GPU). This outputs per-view rendered images and merge the images into a ``.mp4`` video in ``${SAVE}/output`` as follows: + + + +By default, the code can detect all available GPUs. + +```bash +python render.py ${DATASET} \ + --user-dir fairnr \ + --task single_object_rendering \ + --path ${MODEL_PATH} \ + --model-overrides '{"chunk_size":512,"raymarching_tolerance":0.01}' \ + --render-beam 1 --render-angular-speed 3 --render-num-frames 15 \ + --render-save-fps 24 \ + --render-resolution "800x800" \ + --render-path-style "circle" \ + --render-path-args "{'radius': 3, 'h': 2, 'axis': 'z', 't0': -2, 'r':-1}" \ + --render-output ${SAVE}/output \ + --render-output-types "color" "depth" "voxel" "normal" --render-combine-output \ + --log-format "simple" +``` + +Our code also supports rendering for given camera poses. +For instance, the following command is for rendering with the camera poses defined in the 200-399th files under folder ``${DATASET}/pose``: + +```bash +python render.py ${DATASET} \ + --user-dir fairnr \ + --task single_object_rendering \ + --path ${MODEL_PATH} \ + --model-overrides '{"chunk_size":512,"raymarching_tolerance":0.01}' \ + --render-save-fps 24 \ + --render-resolution "800x800" \ + --render-camera-poses ${DATASET}/pose \ + --render-views "200..400" \ + --render-output ${SAVE}/output \ + --render-output-types "color" "depth" "voxel" "normal" --render-combine-output \ + --log-format "simple" +``` + +The code also supports rendering with camera poses defined in a ``.txt`` file. Please refer to this [example](./examples/render/render_jade.sh). + +## License + +NSVF is MIT-licensed. +The license applies to the pre-trained models as well. + +## Citation + +Please cite as +```bibtex +@article{liu2020neural, + title={Neural Sparse Voxel Fields}, + author={Liu, Lingjie and Gu, Jiatao and Lin, Kyaw Zaw and Chua, Tat-Seng and Theobalt, Christian}, + journal={NeurIPS}, + year={2020} +} +``` diff --git a/docs/figs/framework.png b/docs/figs/framework.png new file mode 100644 index 0000000..c565246 Binary files /dev/null and b/docs/figs/framework.png differ diff --git a/docs/figs/results.gif b/docs/figs/results.gif new file mode 100644 index 0000000..c8cdacd Binary files /dev/null and b/docs/figs/results.gif differ diff --git a/examples/render/render_jade.sh b/examples/render/render_jade.sh new file mode 100755 index 0000000..64e93eb --- /dev/null +++ b/examples/render/render_jade.sh @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# just for debugging +DATA="Jade" +RES="576x768" +ARCH="nsvf_base" +SUFFIX="v1" +DATASET=/private/home/jgu/data/shapenet/release/BlendedMVS/${DATA} +SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA +MODEL=$ARCH$SUFFIX +MODEL_PATH=$SAVE/$MODEL/checkpoint_last.pt + +# additional rendering args +MODELTEMP='{"chunk_size":%d,"raymarching_tolerance":%.3f,"use_octree":True}' +MODELARGS=$(printf "$MODELTEMP" 256 0.0) + +# rendering with pre-defined testing trajectory +python render.py ${DATASET} \ + --user-dir fairnr \ + --task single_object_rendering \ + --path ${MODEL_PATH} \ + --render-beam 1 \ + --render-save-fps 24 \ + --render-camera-poses $DATASET/test_traj.txt \ + --model-overrides $MODELARGS \ + --render-resolution $RES \ + --render-output ${SAVE}/$ARCH/output \ + --render-output-types "color" "depth" "voxel" "normal" \ + --render-combine-output --log-format "simple" \ No newline at end of file diff --git a/examples/render/render_wineholder.sh b/examples/render/render_wineholder.sh new file mode 100755 index 0000000..e234801 --- /dev/null +++ b/examples/render/render_wineholder.sh @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# just for debugging +DATA="Wineholder" +RES="800x800" +ARCH="nsvf_base" +SUFFIX="v1" +DATASET=/private/home/jgu/data/shapenet/release/Synthetic_NSVF/${DATA} +SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA +MODEL=$ARCH$SUFFIX +MODEL_PATH=$SAVE/$MODEL/checkpoint_last.pt + +# CUDA_VISIBLE_DEVICES=0 \ +python render.py ${DATASET} \ + --user-dir fairnr \ + --task single_object_rendering \ + --path ${MODEL_PATH} \ + --render-beam 1 \ + --render-save-fps 24 \ + --render-camera-poses ${DATASET}/pose \ + --render-views "200..400" \ + --model-overrides '{"chunk_size":256,"raymarching_tolerance":0.01}' \ + --render-resolution $RES \ + --render-output ${SAVE}/output \ + --render-output-types "color" "depth" "voxel" "normal" \ + --render-combine-output --log-format "simple" \ No newline at end of file diff --git a/examples/train/train_family.sh b/examples/train/train_family.sh new file mode 100644 index 0000000..d654670 --- /dev/null +++ b/examples/train/train_family.sh @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# just for debugging +DATA="Family" +RES="1080x1920" +VALIDRES="540x960" # the original size maybe too slow for evaluation + # we can optionally half the image size only for validation +ARCH="nsvf_base" +SUFFIX="v1" +DATASET=/private/home/jgu/data/shapenet/release/TanksAndTemple/${DATA} +SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA +MODEL=$ARCH$SUFFIX +mkdir -p $SAVE/$MODEL + +# start training locally +python train.py ${DATASET} \ + --slurm-args ${SLURM_ARGS//[[:space:]]/} \ + --user-dir fairnr \ + --task single_object_rendering \ + --train-views "0..133" \ + --view-resolution $RES \ + --max-sentences 1 \ + --view-per-batch 2 \ + --pixel-per-view 2048 \ + --valid-chunk-size 128 \ + --no-preload\ + --sampling-on-mask 1.0 --no-sampling-at-reader \ + --valid-view-resolution $VALIDRES \ + --valid-views "133..152" \ + --valid-view-per-batch 1 \ + --transparent-background "1.0,1.0,1.0" \ + --background-stop-gradient \ + --arch $ARCH \ + --initial-boundingbox ${DATASET}/bbox.txt \ + --raymarching-stepsize-ratio 0.125 \ + --discrete-regularization \ + --color-weight 128.0 \ + --alpha-weight 1.0 \ + --optimizer "adam" \ + --adam-betas "(0.9, 0.999)" \ + --lr-scheduler "polynomial_decay" \ + --total-num-update 150000 \ + --lr 0.001 \ + --clip-norm 0.0 \ + --criterion "srn_loss" \ + --num-workers 0 \ + --seed 2 \ + --save-interval-updates 500 --max-update 150000 \ + --virtual-epoch-steps 5000 --save-interval 1 \ + --half-voxel-size-at "5000,25000,75000" \ + --reduce-step-size-at "5000,25000,75000" \ + --pruning-every-steps 2500 \ + --keep-interval-updates 5 \ + --log-format simple --log-interval 1 \ + --tensorboard-logdir ${SAVE}/tensorboard/${MODEL} \ + --save-dir ${SAVE}/${MODEL} diff --git a/examples/train/train_jade.sh b/examples/train/train_jade.sh new file mode 100644 index 0000000..baf3155 --- /dev/null +++ b/examples/train/train_jade.sh @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# just for debugging +DATA="Jade" +RES="576x768" +ARCH="nsvf_base" +SUFFIX="v1" +DATASET=/private/home/jgu/data/shapenet/release/BlendedMVS/${DATA} +SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA +MODEL=$ARCH$SUFFIX +mkdir -p $SAVE/$MODEL + +# start training locally +python train.py ${DATASET} \ + --user-dir fairnr \ + --task single_object_rendering \ + --train-views "0..50" \ + --view-resolution $RES \ + --max-sentences 1 \ + --view-per-batch 4 \ + --pixel-per-view 2048 \ + --no-preload \ + --sampling-on-mask 1.0 --no-sampling-at-reader \ + --valid-view-resolution $RES \ + --valid-views "50..58" \ + --valid-view-per-batch 1 \ + --transparent-background "0.0,0.0,0.0" \ + --background-stop-gradient \ + --arch $ARCH \ + --initial-boundingbox ${DATASET}/bbox.txt \ + --raymarching-stepsize-ratio 0.125 \ + --use-octree \ + --discrete-regularization \ + --color-weight 128.0 \ + --alpha-weight 1.0 \ + --optimizer "adam" \ + --adam-betas "(0.9, 0.999)" \ + --lr-scheduler "polynomial_decay" \ + --total-num-update 150000 \ + --lr 0.001 \ + --clip-norm 0.0 \ + --criterion "srn_loss" \ + --num-workers 0 \ + --seed 2 \ + --save-interval-updates 500 --max-update 100000 \ + --virtual-epoch-steps 5000 --save-interval 1 \ + --half-voxel-size-at "5000,25000" \ + --reduce-step-size-at "5000,25000" \ + --pruning-every-steps 2500 \ + --keep-interval-updates 5 \ + --log-format simple --log-interval 1 \ + --tensorboard-logdir ${SAVE}/tensorboard/${MODEL} \ + --save-dir ${SAVE}/${MODEL} diff --git a/examples/train/train_wineholder.sh b/examples/train/train_wineholder.sh new file mode 100644 index 0000000..3d6eec6 --- /dev/null +++ b/examples/train/train_wineholder.sh @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# just for debugging +DATA="Wineholder" +RES="800x800" +ARCH="nsvf_base" +SUFFIX="v1" +DATASET=/private/home/jgu/data/shapenet/release/Synthetic_NSVF/${DATA} +SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA +MODEL=$ARCH$SUFFIX +mkdir -p $SAVE/$MODEL + +# start training locally +python train.py ${DATASET} \ + --user-dir fairnr \ + --task single_object_rendering \ + --train-views "0..100" \ + --view-resolution $RES \ + --max-sentences 1 \ + --view-per-batch 2 \ + --pixel-per-view 2048 \ + --no-preload \ + --sampling-on-mask 1.0 --no-sampling-at-reader \ + --valid-view-resolution $RES \ + --valid-views "100..200" \ + --valid-view-per-batch 1 \ + --transparent-background "1.0,1.0,1.0" \ + --background-stop-gradient \ + --arch $ARCH \ + --initial-boundingbox ${DATASET}/bbox.txt \ + --raymarching-stepsize-ratio 0.125 \ + --use-octree \ + --discrete-regularization \ + --color-weight 128.0 \ + --alpha-weight 1.0 \ + --optimizer "adam" \ + --adam-betas "(0.9, 0.999)" \ + --lr-scheduler "polynomial_decay" \ + --total-num-update 150000 \ + --lr 0.001 \ + --clip-norm 0.0 \ + --criterion "srn_loss" \ + --num-workers 0 \ + --seed 2 \ + --save-interval-updates 500 --max-update 150000 \ + --virtual-epoch-steps 5000 --save-interval 1 \ + --half-voxel-size-at "5000,25000,75000" \ + --reduce-step-size-at "5000,25000,75000" \ + --pruning-every-steps 2500 \ + --keep-interval-updates 5 \ + --log-format simple --log-interval 1 \ + --tensorboard-logdir ${SAVE}/tensorboard/${MODEL} \ + --save-dir ${SAVE}/${MODEL} diff --git a/examples/train/train_wineholder_with_slurm.sh b/examples/train/train_wineholder_with_slurm.sh new file mode 100644 index 0000000..4df31d0 --- /dev/null +++ b/examples/train/train_wineholder_with_slurm.sh @@ -0,0 +1,71 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# just for debugging +DATA="Wineholder" +RES="800x800" +ARCH="nsvf_base" +SUFFIX="v1" +DATASET=/private/home/jgu/data/shapenet/release/Synthetic_NSVF/${DATA} +SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA +MODEL=$ARCH$SUFFIX +mkdir -p $SAVE/$MODEL + +# By defining the following environment variables +# The code will automatically detect it and trying to submit the code in slurm-based clusters +# We don't need to change the main body of the training code. +export SLURM_ARGS="""{ + 'job-name': '${DATA}-${MODEL}', + 'partition': 'priority', + 'comment': 'NeurIPS2020 open-source', + 'nodes': 1, + 'gpus': 8, + 'output': '$SAVE/$MODEL/train.out', + 'error': '$SAVE/$MODEL/train.stderr.%j', + 'constraint': 'volta32gb', + 'local': False} +""" + +# start training based on SLURM_ARGS +python train.py ${DATASET} \ + --user-dir fairnr \ + --task single_object_rendering \ + --train-views "0..100" \ + --view-resolution $RES \ + --max-sentences 1 \ + --view-per-batch 2 \ + --pixel-per-view 2048 \ + --no-preload \ + --sampling-on-mask 1.0 --no-sampling-at-reader \ + --valid-view-resolution $RES \ + --valid-views "100..200" \ + --valid-view-per-batch 1 \ + --transparent-background "1.0,1.0,1.0" \ + --background-stop-gradient \ + --arch $ARCH \ + --initial-boundingbox ${DATASET}/bbox.txt \ + --raymarching-stepsize-ratio 0.125 \ + --use-octree \ + --discrete-regularization \ + --color-weight 128.0 \ + --alpha-weight 1.0 \ + --optimizer "adam" \ + --adam-betas "(0.9, 0.999)" \ + --lr-scheduler "polynomial_decay" \ + --total-num-update 150000 \ + --lr 0.001 \ + --clip-norm 0.0 \ + --criterion "srn_loss" \ + --num-workers 0 \ + --seed 2 \ + --save-interval-updates 500 --max-update 150000 \ + --virtual-epoch-steps 5000 --save-interval 1 \ + --half-voxel-size-at "5000,25000,75000" \ + --reduce-step-size-at "5000,25000,75000" \ + --pruning-every-steps 2500 \ + --keep-interval-updates 5 \ + --log-format simple --log-interval 1 \ + --tensorboard-logdir ${SAVE}/tensorboard/${MODEL} \ + --save-dir ${SAVE}/${MODEL} diff --git a/examples/valid/valid_wineholder.sh b/examples/valid/valid_wineholder.sh new file mode 100644 index 0000000..7995c79 --- /dev/null +++ b/examples/valid/valid_wineholder.sh @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# just for debugging +DATA="Steamtrain" +RES="800x800" +ARCH="nsvf_base" +SUFFIX="v1" +DATASET=/private/home/jgu/data/shapenet/release/Synthetic_NSVF/${DATA} +SAVE=/checkpoint/jgu/space/neuralrendering/new_release/$DATA +MODEL=$ARCH$SUFFIX +MODEL_PATH=$SAVE/$MODEL/checkpoint_last.pt + +# start validating a trained model with target images. +# CUDA_VISIBLE_DEVICES=0 \ +python validate.py ${DATASET} \ + --user-dir fairnr \ + --valid-views "200..400" \ + --valid-view-resolution "800x800" \ + --no-preload \ + --task single_object_rendering \ + --max-sentences 1 \ + --valid-view-per-batch 1 \ + --path ${MODEL_PATH} \ + --model-overrides '{"chunk_size":1024,"raymarching_tolerance":0.01,"tensorboard_logdir":"","eval_lpips":True}' \ \ No newline at end of file diff --git a/extract.py b/extract.py new file mode 100644 index 0000000..ac1d14c --- /dev/null +++ b/extract.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairnr_cli.extract import cli_main + + +if __name__ == '__main__': + cli_main() diff --git a/fairnr/__init__.py b/fairnr/__init__.py new file mode 100644 index 0000000..a4fc7d7 --- /dev/null +++ b/fairnr/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +class ResetTrainerException(Exception): + pass + + +from . import data, tasks, models, modules, criterions diff --git a/fairnr/clib/__init__.py b/fairnr/clib/__init__.py new file mode 100644 index 0000000..6c811d3 --- /dev/null +++ b/fairnr/clib/__init__.py @@ -0,0 +1,365 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +''' Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch ''' +from __future__ import ( + division, + absolute_import, + with_statement, + print_function, + unicode_literals, +) +import os, sys +import torch +import torch.nn.functional as F +from torch.autograd import Function +import torch.nn as nn +import sys +import numpy as np + +try: + import builtins +except: + import __builtin__ as builtins + +try: + import fairnr.clib._ext as _ext +except ImportError: + raise ImportError( + "Could not import _ext module.\n" + "Please see the setup instructions in the README" + ) + +class BallRayIntersect(Function): + @staticmethod + def forward(ctx, radius, n_max, points, ray_start, ray_dir): + r""" + + Parameters + ---------- + radius : float + radius of the balls + n_max: int + maximum number of points to intersect. + xyz : torch.Tensor + (B, N, 3) xyz coordinates of the features + new_xyz : torch.Tensor + (B, npoint, 3) centers of the ball query + + Returns + ------- + torch.Tensor + (B, npoint) tensor with the nearest indicies of the features that form the query balls + """ + inds, min_depth, max_depth = _ext.ball_intersect( + ray_start.float(), ray_dir.float(), points.float(), radius, n_max) + min_depth = min_depth.type_as(ray_start) + max_depth = max_depth.type_as(ray_start) + + ctx.mark_non_differentiable(inds) + ctx.mark_non_differentiable(min_depth) + ctx.mark_non_differentiable(max_depth) + return inds, min_depth, max_depth + + @staticmethod + def backward(ctx, a, b, c): + return None, None, None, None, None + +ball_ray_intersect = BallRayIntersect.apply + + +class AABBRayIntersect(Function): + @staticmethod + def forward(ctx, voxelsize, n_max, points, ray_start, ray_dir): + r""" + + Parameters + ---------- + radius : float + radius of the balls + n_max: int + maximum number of points to intersect. + xyz : torch.Tensor + (B, N, 3) xyz coordinates of the features + new_xyz : torch.Tensor + (B, npoint, 3) centers of the ball query + + Returns + ------- + torch.Tensor + (B, npoint) tensor with the nearest indicies of the features that form the query balls + """ + # HACK: speed-up ray-voxel intersection by batching... + G = 2048 + S, N = ray_start.shape[:2] + K = int(np.ceil(N / G)) + H = K * G + if H > N: + ray_start = torch.cat([ray_start, ray_start[:, :H-N]], 1) + ray_dir = torch.cat([ray_dir, ray_dir[:, :H-N]], 1) + ray_start = ray_start.reshape(S * G, K, 3) + ray_dir = ray_dir.reshape(S * G, K, 3) + points = points.expand(S * G, *points.size()[1:]).contiguous() + + inds, min_depth, max_depth = _ext.aabb_intersect( + ray_start.float(), ray_dir.float(), points.float(), voxelsize, n_max) + min_depth = min_depth.type_as(ray_start) + max_depth = max_depth.type_as(ray_start) + + inds = inds.reshape(S, H, -1) + min_depth = min_depth.reshape(S, H, -1) + max_depth = max_depth.reshape(S, H, -1) + if H > N: + inds = inds[:, :N] + min_depth = min_depth[:, :N] + max_depth = max_depth[:, :N] + + ctx.mark_non_differentiable(inds) + ctx.mark_non_differentiable(min_depth) + ctx.mark_non_differentiable(max_depth) + return inds, min_depth, max_depth + + @staticmethod + def backward(ctx, a, b, c): + return None, None, None, None, None + +aabb_ray_intersect = AABBRayIntersect.apply + + +class SparseVoxelOctreeRayIntersect(Function): + @staticmethod + def forward(ctx, voxelsize, n_max, points, children, ray_start, ray_dir): + r""" + + Parameters + ---------- + radius : float + radius of the balls + n_max: int + maximum number of points to intersect. + xyz : torch.Tensor + (B, N, 3) xyz coordinates of the features + new_xyz : torch.Tensor + (B, npoint, 3) centers of the ball query + + Returns + ------- + torch.Tensor + (B, npoint) tensor with the nearest indicies of the features that form the query balls + """ + # HACK: speed-up ray-voxel intersection by batching... + G = 2048 + S, N = ray_start.shape[:2] + K = int(np.ceil(N / G)) + H = K * G + if H > N: + ray_start = torch.cat([ray_start, ray_start[:, :H-N]], 1) + ray_dir = torch.cat([ray_dir, ray_dir[:, :H-N]], 1) + ray_start = ray_start.reshape(S * G, K, 3) + ray_dir = ray_dir.reshape(S * G, K, 3) + points = points.expand(S * G, *points.size()[1:]).contiguous() + children = children.expand(S * G, *children.size()[1:]).contiguous() + inds, min_depth, max_depth = _ext.svo_intersect( + ray_start.float(), ray_dir.float(), points.float(), children.int(), voxelsize, n_max) + # from fairseq import pdb; pdb.set_trace() + + min_depth = min_depth.type_as(ray_start) + max_depth = max_depth.type_as(ray_start) + + inds = inds.reshape(S, H, -1) + min_depth = min_depth.reshape(S, H, -1) + max_depth = max_depth.reshape(S, H, -1) + if H > N: + inds = inds[:, :N] + min_depth = min_depth[:, :N] + max_depth = max_depth[:, :N] + + ctx.mark_non_differentiable(inds) + ctx.mark_non_differentiable(min_depth) + ctx.mark_non_differentiable(max_depth) + return inds, min_depth, max_depth + + @staticmethod + def backward(ctx, a, b, c): + return None, None, None, None, None + +svo_ray_intersect = SparseVoxelOctreeRayIntersect.apply + + +class TriangleRayIntersect(Function): + @staticmethod + def forward(ctx, cagesize, blur_ratio, n_max, points, faces, ray_start, ray_dir): + # HACK: speed-up ray-voxel intersection by batching... + G = 2048 + S, N = ray_start.shape[:2] + K = int(np.ceil(N / G)) + H = K * G + if H > N: + ray_start = torch.cat([ray_start, ray_start[:, :H-N]], 1) + ray_dir = torch.cat([ray_dir, ray_dir[:, :H-N]], 1) + ray_start = ray_start.reshape(S * G, K, 3) + ray_dir = ray_dir.reshape(S * G, K, 3) + face_points = F.embedding(faces.reshape(-1, 3), points.reshape(-1, 3)) + face_points = face_points.unsqueeze(0).expand(S * G, *face_points.size()).contiguous() + inds, depth, uv = _ext.triangle_intersect( + ray_start.float(), ray_dir.float(), face_points.float(), cagesize, blur_ratio, n_max) + depth = depth.type_as(ray_start) + uv = uv.type_as(ray_start) + + inds = inds.reshape(S, H, -1) + depth = depth.reshape(S, H, -1, 3) + uv = uv.reshape(S, H, -1) + if H > N: + inds = inds[:, :N] + depth = depth[:, :N] + uv = uv[:, :N] + + ctx.mark_non_differentiable(inds) + ctx.mark_non_differentiable(depth) + ctx.mark_non_differentiable(uv) + return inds, depth, uv + + @staticmethod + def backward(ctx, a, b, c): + return None, None, None, None, None, None + +triangle_ray_intersect = TriangleRayIntersect.apply + + +class UniformRaySampling(Function): + @staticmethod + def forward(ctx, pts_idx, min_depth, max_depth, step_size, max_ray_length, deterministic=False): + G, N, P = 256, pts_idx.size(0), pts_idx.size(1) + H = int(np.ceil(N / G)) * G + if H > N: + pts_idx = torch.cat([pts_idx, pts_idx[:H-N]], 0) + min_depth = torch.cat([min_depth, min_depth[:H-N]], 0) + max_depth = torch.cat([max_depth, max_depth[:H-N]], 0) + pts_idx = pts_idx.reshape(G, -1, P) + min_depth = min_depth.reshape(G, -1, P) + max_depth = max_depth.reshape(G, -1, P) + + # pre-generate noise + max_steps = int(max_ray_length / step_size) + max_steps = max_steps + min_depth.size(-1) * 2 + noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps) + if deterministic: + noise += 0.5 + else: + noise = noise.uniform_() + + # call cuda function + sampled_idx, sampled_depth, sampled_dists = _ext.uniform_ray_sampling( + pts_idx, min_depth.float(), max_depth.float(), noise.float(), step_size, max_steps) + sampled_depth = sampled_depth.type_as(min_depth) + sampled_dists = sampled_dists.type_as(min_depth) + + sampled_idx = sampled_idx.reshape(H, -1) + sampled_depth = sampled_depth.reshape(H, -1) + sampled_dists = sampled_dists.reshape(H, -1) + if H > N: + sampled_idx = sampled_idx[: N] + sampled_depth = sampled_depth[: N] + sampled_dists = sampled_dists[: N] + + max_len = sampled_idx.ne(-1).sum(-1).max() + sampled_idx = sampled_idx[:, :max_len] + sampled_depth = sampled_depth[:, :max_len] + sampled_dists = sampled_dists[:, :max_len] + + ctx.mark_non_differentiable(sampled_idx) + ctx.mark_non_differentiable(sampled_depth) + ctx.mark_non_differentiable(sampled_dists) + return sampled_idx, sampled_depth, sampled_dists + + @staticmethod + def backward(ctx, a, b, c): + return None, None, None, None, None, None + +uniform_ray_sampling = UniformRaySampling.apply + + +# back-up for ray point sampling +@torch.no_grad() +def _parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False): + # uniform sampling + _min_depth = min_depth.min(1)[0] + _max_depth = max_depth.masked_fill(max_depth.eq(MAX_DEPTH), 0).max(1)[0] + max_ray_length = (_max_depth - _min_depth).max() + + delta = torch.arange(int(max_ray_length / MARCH_SIZE), device=min_depth.device, dtype=min_depth.dtype) + delta = delta[None, :].expand(min_depth.size(0), delta.size(-1)) + if deterministic: + delta = delta + 0.5 + else: + delta = delta + delta.clone().uniform_().clamp(min=0.01, max=0.99) + delta = delta * MARCH_SIZE + sampled_depth = min_depth[:, :1] + delta + sampled_idx = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1 + sampled_idx = pts_idx.gather(1, sampled_idx) + + # include all boundary points + sampled_depth = torch.cat([min_depth, max_depth, sampled_depth], -1) + sampled_idx = torch.cat([pts_idx, pts_idx, sampled_idx], -1) + + # reorder + sampled_depth, ordered_index = sampled_depth.sort(-1) + sampled_idx = sampled_idx.gather(1, ordered_index) + sampled_dists = sampled_depth[:, 1:] - sampled_depth[:, :-1] # distances + sampled_depth = .5 * (sampled_depth[:, 1:] + sampled_depth[:, :-1]) # mid-points + + # remove all invalid depths + min_ids = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1 + max_ids = (sampled_depth[:, :, None] >= max_depth[:, None, :]).sum(-1) + + sampled_depth.masked_fill_( + (max_ids.ne(min_ids)) | + (sampled_depth > _max_depth[:, None]) | + (sampled_dists == 0.0) + , MAX_DEPTH) + sampled_depth, ordered_index = sampled_depth.sort(-1) # sort again + sampled_masks = sampled_depth.eq(MAX_DEPTH) + num_max_steps = (~sampled_masks).sum(-1).max() + + sampled_depth = sampled_depth[:, :num_max_steps] + sampled_dists = sampled_dists.gather(1, ordered_index).masked_fill_(sampled_masks, 0.0)[:, :num_max_steps] + sampled_idx = sampled_idx.gather(1, ordered_index).masked_fill_(sampled_masks, -1)[:, :num_max_steps] + + return sampled_idx, sampled_depth, sampled_dists + + +@torch.no_grad() +def parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False): + chunk_size=4096 + full_size = min_depth.shape[0] + if full_size <= chunk_size: + return _parallel_ray_sampling(MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=deterministic) + + outputs = zip(*[ + _parallel_ray_sampling( + MARCH_SIZE, + pts_idx[i:i+chunk_size], min_depth[i:i+chunk_size], max_depth[i:i+chunk_size], + deterministic=deterministic) + for i in range(0, full_size, chunk_size)]) + sampled_idx, sampled_depth, sampled_dists = outputs + + def padding_points(xs, pad): + if len(xs) == 1: + return xs[0] + + maxlen = max([x.size(1) for x in xs]) + full_size = sum([x.size(0) for x in xs]) + xt = xs[0].new_ones(full_size, maxlen).fill_(pad) + st = 0 + for i in range(len(xs)): + xt[st: st + xs[i].size(0), :xs[i].size(1)] = xs[i] + st += xs[i].size(0) + return xt + + sampled_idx = padding_points(sampled_idx, -1) + sampled_depth = padding_points(sampled_depth, MAX_DEPTH) + sampled_dists = padding_points(sampled_dists, 0.0) + return sampled_idx, sampled_depth, sampled_dists + diff --git a/fairnr/clib/include/cuda_utils.h b/fairnr/clib/include/cuda_utils.h new file mode 100644 index 0000000..d4c4bb4 --- /dev/null +++ b/fairnr/clib/include/cuda_utils.h @@ -0,0 +1,46 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#ifndef _CUDA_UTILS_H +#define _CUDA_UTILS_H + +#include +#include +#include + +#include +#include + +#include + +#define TOTAL_THREADS 512 + +inline int opt_n_threads(int work_size) { + const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); + + return max(min(1 << pow_2, TOTAL_THREADS), 1); +} + +inline dim3 opt_block_config(int x, int y) { + const int x_threads = opt_n_threads(x); + const int y_threads = + max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); + dim3 block_config(x_threads, y_threads, 1); + + return block_config; +} + +#define CUDA_CHECK_ERRORS() \ + do { \ + cudaError_t err = cudaGetLastError(); \ + if (cudaSuccess != err) { \ + fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ + cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ + __FILE__); \ + exit(-1); \ + } \ + } while (0) + +#endif diff --git a/fairnr/clib/include/cutil_math.h b/fairnr/clib/include/cutil_math.h new file mode 100644 index 0000000..d8748b9 --- /dev/null +++ b/fairnr/clib/include/cutil_math.h @@ -0,0 +1,793 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +/* + * Copyright 1993-2009 NVIDIA Corporation. All rights reserved. + * + * NVIDIA Corporation and its licensors retain all intellectual property and + * proprietary rights in and to this software and related documentation and + * any modifications thereto. Any use, reproduction, disclosure, or distribution + * of this software and related documentation without an express license + * agreement from NVIDIA Corporation is strictly prohibited. + * + */ + +/* + This file implements common mathematical operations on vector types + (float3, float4 etc.) since these are not provided as standard by CUDA. + + The syntax is modelled on the Cg standard library. +*/ + +#ifndef CUTIL_MATH_H +#define CUTIL_MATH_H + +#include "cuda_runtime.h" + +//////////////////////////////////////////////////////////////////////////////// +typedef unsigned int uint; +typedef unsigned short ushort; + +#ifndef __CUDACC__ +#include + +inline float fminf(float a, float b) +{ + return a < b ? a : b; +} + +inline float fmaxf(float a, float b) +{ + return a > b ? a : b; +} + +inline int max(int a, int b) +{ + return a > b ? a : b; +} + +inline int min(int a, int b) +{ + return a < b ? a : b; +} + +inline float rsqrtf(float x) +{ + return 1.0f / sqrtf(x); +} + +#endif + +// float functions +//////////////////////////////////////////////////////////////////////////////// + +// lerp +inline __device__ __host__ float lerp(float a, float b, float t) +{ + return a + t*(b-a); +} + +// clamp +inline __device__ __host__ float clamp(float f, float a, float b) +{ + return fmaxf(a, fminf(f, b)); +} + +inline __device__ __host__ void swap(float &a, float &b) +{ + float c = a; + a = b; + b = c; +} + +inline __device__ __host__ void swap(int &a, int &b) +{ + float c = a; + a = b; + b = c; +} + + +// int2 functions +//////////////////////////////////////////////////////////////////////////////// + +// negate +inline __host__ __device__ int2 operator-(int2 &a) +{ + return make_int2(-a.x, -a.y); +} + +// addition +inline __host__ __device__ int2 operator+(int2 a, int2 b) +{ + return make_int2(a.x + b.x, a.y + b.y); +} +inline __host__ __device__ void operator+=(int2 &a, int2 b) +{ + a.x += b.x; a.y += b.y; +} + +// subtract +inline __host__ __device__ int2 operator-(int2 a, int2 b) +{ + return make_int2(a.x - b.x, a.y - b.y); +} +inline __host__ __device__ void operator-=(int2 &a, int2 b) +{ + a.x -= b.x; a.y -= b.y; +} + +// multiply +inline __host__ __device__ int2 operator*(int2 a, int2 b) +{ + return make_int2(a.x * b.x, a.y * b.y); +} +inline __host__ __device__ int2 operator*(int2 a, int s) +{ + return make_int2(a.x * s, a.y * s); +} +inline __host__ __device__ int2 operator*(int s, int2 a) +{ + return make_int2(a.x * s, a.y * s); +} +inline __host__ __device__ void operator*=(int2 &a, int s) +{ + a.x *= s; a.y *= s; +} + +// float2 functions +//////////////////////////////////////////////////////////////////////////////// + +// additional constructors +inline __host__ __device__ float2 make_float2(float s) +{ + return make_float2(s, s); +} +inline __host__ __device__ float2 make_float2(int2 a) +{ + return make_float2(float(a.x), float(a.y)); +} + +// negate +inline __host__ __device__ float2 operator-(float2 &a) +{ + return make_float2(-a.x, -a.y); +} + +// addition +inline __host__ __device__ float2 operator+(float2 a, float2 b) +{ + return make_float2(a.x + b.x, a.y + b.y); +} +inline __host__ __device__ void operator+=(float2 &a, float2 b) +{ + a.x += b.x; a.y += b.y; +} + +// subtract +inline __host__ __device__ float2 operator-(float2 a, float2 b) +{ + return make_float2(a.x - b.x, a.y - b.y); +} +inline __host__ __device__ void operator-=(float2 &a, float2 b) +{ + a.x -= b.x; a.y -= b.y; +} + +// multiply +inline __host__ __device__ float2 operator*(float2 a, float2 b) +{ + return make_float2(a.x * b.x, a.y * b.y); +} +inline __host__ __device__ float2 operator*(float2 a, float s) +{ + return make_float2(a.x * s, a.y * s); +} +inline __host__ __device__ float2 operator*(float s, float2 a) +{ + return make_float2(a.x * s, a.y * s); +} +inline __host__ __device__ void operator*=(float2 &a, float s) +{ + a.x *= s; a.y *= s; +} + +// divide +inline __host__ __device__ float2 operator/(float2 a, float2 b) +{ + return make_float2(a.x / b.x, a.y / b.y); +} +inline __host__ __device__ float2 operator/(float2 a, float s) +{ + float inv = 1.0f / s; + return a * inv; +} +inline __host__ __device__ float2 operator/(float s, float2 a) +{ + float inv = 1.0f / s; + return a * inv; +} +inline __host__ __device__ void operator/=(float2 &a, float s) +{ + float inv = 1.0f / s; + a *= inv; +} + +// lerp +inline __device__ __host__ float2 lerp(float2 a, float2 b, float t) +{ + return a + t*(b-a); +} + +// clamp +inline __device__ __host__ float2 clamp(float2 v, float a, float b) +{ + return make_float2(clamp(v.x, a, b), clamp(v.y, a, b)); +} + +inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b) +{ + return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); +} + +// dot product +inline __host__ __device__ float dot(float2 a, float2 b) +{ + return a.x * b.x + a.y * b.y; +} + +// length +inline __host__ __device__ float length(float2 v) +{ + return sqrtf(dot(v, v)); +} + +// normalize +inline __host__ __device__ float2 normalize(float2 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} + +// floor +inline __host__ __device__ float2 floor(const float2 v) +{ + return make_float2(floor(v.x), floor(v.y)); +} + +// reflect +inline __host__ __device__ float2 reflect(float2 i, float2 n) +{ + return i - 2.0f * n * dot(n,i); +} + +// absolute value +inline __host__ __device__ float2 fabs(float2 v) +{ + return make_float2(fabs(v.x), fabs(v.y)); +} + +// float3 functions +//////////////////////////////////////////////////////////////////////////////// + +// additional constructors +inline __host__ __device__ float3 make_float3(float s) +{ + return make_float3(s, s, s); +} +inline __host__ __device__ float3 make_float3(float2 a) +{ + return make_float3(a.x, a.y, 0.0f); +} +inline __host__ __device__ float3 make_float3(float2 a, float s) +{ + return make_float3(a.x, a.y, s); +} +inline __host__ __device__ float3 make_float3(float4 a) +{ + return make_float3(a.x, a.y, a.z); // discards w +} +inline __host__ __device__ float3 make_float3(int3 a) +{ + return make_float3(float(a.x), float(a.y), float(a.z)); +} + +// negate +inline __host__ __device__ float3 operator-(float3 &a) +{ + return make_float3(-a.x, -a.y, -a.z); +} + +// min +static __inline__ __host__ __device__ float3 fminf(float3 a, float3 b) +{ + return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z)); +} + +// max +static __inline__ __host__ __device__ float3 fmaxf(float3 a, float3 b) +{ + return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z)); +} + +// addition +inline __host__ __device__ float3 operator+(float3 a, float3 b) +{ + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ float3 operator+(float3 a, float b) +{ + return make_float3(a.x + b, a.y + b, a.z + b); +} +inline __host__ __device__ void operator+=(float3 &a, float3 b) +{ + a.x += b.x; a.y += b.y; a.z += b.z; +} + +// subtract +inline __host__ __device__ float3 operator-(float3 a, float3 b) +{ + return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); +} +inline __host__ __device__ float3 operator-(float3 a, float b) +{ + return make_float3(a.x - b, a.y - b, a.z - b); +} +inline __host__ __device__ void operator-=(float3 &a, float3 b) +{ + a.x -= b.x; a.y -= b.y; a.z -= b.z; +} + +// multiply +inline __host__ __device__ float3 operator*(float3 a, float3 b) +{ + return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ float3 operator*(float3 a, float s) +{ + return make_float3(a.x * s, a.y * s, a.z * s); +} +inline __host__ __device__ float3 operator*(float s, float3 a) +{ + return make_float3(a.x * s, a.y * s, a.z * s); +} +inline __host__ __device__ void operator*=(float3 &a, float s) +{ + a.x *= s; a.y *= s; a.z *= s; +} +inline __host__ __device__ void operator*=(float3 &a, float3 b) +{ + a.x *= b.x; a.y *= b.y; a.z *= b.z;; +} + +// divide +inline __host__ __device__ float3 operator/(float3 a, float3 b) +{ + return make_float3(a.x / b.x, a.y / b.y, a.z / b.z); +} +inline __host__ __device__ float3 operator/(float3 a, float s) +{ + float inv = 1.0f / s; + return a * inv; +} +inline __host__ __device__ float3 operator/(float s, float3 a) +{ + float inv = 1.0f / s; + return a * inv; +} +inline __host__ __device__ void operator/=(float3 &a, float s) +{ + float inv = 1.0f / s; + a *= inv; +} + +// lerp +inline __device__ __host__ float3 lerp(float3 a, float3 b, float t) +{ + return a + t*(b-a); +} + +// clamp +inline __device__ __host__ float3 clamp(float3 v, float a, float b) +{ + return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} + +inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b) +{ + return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} + +// dot product +inline __host__ __device__ float dot(float3 a, float3 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} + +// cross product +inline __host__ __device__ float3 cross(float3 a, float3 b) +{ + return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x); +} + +// length +inline __host__ __device__ float length(float3 v) +{ + return sqrtf(dot(v, v)); +} + +// normalize +inline __host__ __device__ float3 normalize(float3 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} + +// floor +inline __host__ __device__ float3 floor(const float3 v) +{ + return make_float3(floor(v.x), floor(v.y), floor(v.z)); +} + +// reflect +inline __host__ __device__ float3 reflect(float3 i, float3 n) +{ + return i - 2.0f * n * dot(n,i); +} + +// absolute value +inline __host__ __device__ float3 fabs(float3 v) +{ + return make_float3(fabs(v.x), fabs(v.y), fabs(v.z)); +} + +// float4 functions +//////////////////////////////////////////////////////////////////////////////// + +// additional constructors +inline __host__ __device__ float4 make_float4(float s) +{ + return make_float4(s, s, s, s); +} +inline __host__ __device__ float4 make_float4(float3 a) +{ + return make_float4(a.x, a.y, a.z, 0.0f); +} +inline __host__ __device__ float4 make_float4(float3 a, float w) +{ + return make_float4(a.x, a.y, a.z, w); +} +inline __host__ __device__ float4 make_float4(int4 a) +{ + return make_float4(float(a.x), float(a.y), float(a.z), float(a.w)); +} + +// negate +inline __host__ __device__ float4 operator-(float4 &a) +{ + return make_float4(-a.x, -a.y, -a.z, -a.w); +} + +// min +static __inline__ __host__ __device__ float4 fminf(float4 a, float4 b) +{ + return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w)); +} + +// max +static __inline__ __host__ __device__ float4 fmaxf(float4 a, float4 b) +{ + return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w)); +} + +// addition +inline __host__ __device__ float4 operator+(float4 a, float4 b) +{ + return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); +} +inline __host__ __device__ void operator+=(float4 &a, float4 b) +{ + a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w; +} + +// subtract +inline __host__ __device__ float4 operator-(float4 a, float4 b) +{ + return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); +} +inline __host__ __device__ void operator-=(float4 &a, float4 b) +{ + a.x -= b.x; a.y -= b.y; a.z -= b.z; a.w -= b.w; +} + +// multiply +inline __host__ __device__ float4 operator*(float4 a, float s) +{ + return make_float4(a.x * s, a.y * s, a.z * s, a.w * s); +} +inline __host__ __device__ float4 operator*(float s, float4 a) +{ + return make_float4(a.x * s, a.y * s, a.z * s, a.w * s); +} +inline __host__ __device__ void operator*=(float4 &a, float s) +{ + a.x *= s; a.y *= s; a.z *= s; a.w *= s; +} + +// divide +inline __host__ __device__ float4 operator/(float4 a, float4 b) +{ + return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w); +} +inline __host__ __device__ float4 operator/(float4 a, float s) +{ + float inv = 1.0f / s; + return a * inv; +} +inline __host__ __device__ float4 operator/(float s, float4 a) +{ + float inv = 1.0f / s; + return a * inv; +} +inline __host__ __device__ void operator/=(float4 &a, float s) +{ + float inv = 1.0f / s; + a *= inv; +} + +// lerp +inline __device__ __host__ float4 lerp(float4 a, float4 b, float t) +{ + return a + t*(b-a); +} + +// clamp +inline __device__ __host__ float4 clamp(float4 v, float a, float b) +{ + return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); +} + +inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b) +{ + return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); +} + +// dot product +inline __host__ __device__ float dot(float4 a, float4 b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; +} + +// length +inline __host__ __device__ float length(float4 r) +{ + return sqrtf(dot(r, r)); +} + +// normalize +inline __host__ __device__ float4 normalize(float4 v) +{ + float invLen = rsqrtf(dot(v, v)); + return v * invLen; +} + +// floor +inline __host__ __device__ float4 floor(const float4 v) +{ + return make_float4(floor(v.x), floor(v.y), floor(v.z), floor(v.w)); +} + +// absolute value +inline __host__ __device__ float4 fabs(float4 v) +{ + return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w)); +} + +// int3 functions +//////////////////////////////////////////////////////////////////////////////// + +// additional constructors +inline __host__ __device__ int3 make_int3(int s) +{ + return make_int3(s, s, s); +} +inline __host__ __device__ int3 make_int3(float3 a) +{ + return make_int3(int(a.x), int(a.y), int(a.z)); +} + +// negate +inline __host__ __device__ int3 operator-(int3 &a) +{ + return make_int3(-a.x, -a.y, -a.z); +} + +// min +inline __host__ __device__ int3 min(int3 a, int3 b) +{ + return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z)); +} + +// max +inline __host__ __device__ int3 max(int3 a, int3 b) +{ + return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z)); +} + +// addition +inline __host__ __device__ int3 operator+(int3 a, int3 b) +{ + return make_int3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ void operator+=(int3 &a, int3 b) +{ + a.x += b.x; a.y += b.y; a.z += b.z; +} + +// subtract +inline __host__ __device__ int3 operator-(int3 a, int3 b) +{ + return make_int3(a.x - b.x, a.y - b.y, a.z - b.z); +} + +inline __host__ __device__ void operator-=(int3 &a, int3 b) +{ + a.x -= b.x; a.y -= b.y; a.z -= b.z; +} + +// multiply +inline __host__ __device__ int3 operator*(int3 a, int3 b) +{ + return make_int3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ int3 operator*(int3 a, int s) +{ + return make_int3(a.x * s, a.y * s, a.z * s); +} +inline __host__ __device__ int3 operator*(int s, int3 a) +{ + return make_int3(a.x * s, a.y * s, a.z * s); +} +inline __host__ __device__ void operator*=(int3 &a, int s) +{ + a.x *= s; a.y *= s; a.z *= s; +} + +// divide +inline __host__ __device__ int3 operator/(int3 a, int3 b) +{ + return make_int3(a.x / b.x, a.y / b.y, a.z / b.z); +} +inline __host__ __device__ int3 operator/(int3 a, int s) +{ + return make_int3(a.x / s, a.y / s, a.z / s); +} +inline __host__ __device__ int3 operator/(int s, int3 a) +{ + return make_int3(a.x / s, a.y / s, a.z / s); +} +inline __host__ __device__ void operator/=(int3 &a, int s) +{ + a.x /= s; a.y /= s; a.z /= s; +} + +// clamp +inline __device__ __host__ int clamp(int f, int a, int b) +{ + return max(a, min(f, b)); +} + +inline __device__ __host__ int3 clamp(int3 v, int a, int b) +{ + return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} + +inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b) +{ + return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} + + +// uint3 functions +//////////////////////////////////////////////////////////////////////////////// + +// additional constructors +inline __host__ __device__ uint3 make_uint3(uint s) +{ + return make_uint3(s, s, s); +} +inline __host__ __device__ uint3 make_uint3(float3 a) +{ + return make_uint3(uint(a.x), uint(a.y), uint(a.z)); +} + +// min +inline __host__ __device__ uint3 min(uint3 a, uint3 b) +{ + return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z)); +} + +// max +inline __host__ __device__ uint3 max(uint3 a, uint3 b) +{ + return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z)); +} + +// addition +inline __host__ __device__ uint3 operator+(uint3 a, uint3 b) +{ + return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z); +} +inline __host__ __device__ void operator+=(uint3 &a, uint3 b) +{ + a.x += b.x; a.y += b.y; a.z += b.z; +} + +// subtract +inline __host__ __device__ uint3 operator-(uint3 a, uint3 b) +{ + return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z); +} + +inline __host__ __device__ void operator-=(uint3 &a, uint3 b) +{ + a.x -= b.x; a.y -= b.y; a.z -= b.z; +} + +// multiply +inline __host__ __device__ uint3 operator*(uint3 a, uint3 b) +{ + return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z); +} +inline __host__ __device__ uint3 operator*(uint3 a, uint s) +{ + return make_uint3(a.x * s, a.y * s, a.z * s); +} +inline __host__ __device__ uint3 operator*(uint s, uint3 a) +{ + return make_uint3(a.x * s, a.y * s, a.z * s); +} +inline __host__ __device__ void operator*=(uint3 &a, uint s) +{ + a.x *= s; a.y *= s; a.z *= s; +} + +// divide +inline __host__ __device__ uint3 operator/(uint3 a, uint3 b) +{ + return make_uint3(a.x / b.x, a.y / b.y, a.z / b.z); +} +inline __host__ __device__ uint3 operator/(uint3 a, uint s) +{ + return make_uint3(a.x / s, a.y / s, a.z / s); +} +inline __host__ __device__ uint3 operator/(uint s, uint3 a) +{ + return make_uint3(a.x / s, a.y / s, a.z / s); +} +inline __host__ __device__ void operator/=(uint3 &a, uint s) +{ + a.x /= s; a.y /= s; a.z /= s; +} + +// clamp +inline __device__ __host__ uint clamp(uint f, uint a, uint b) +{ + return max(a, min(f, b)); +} + +inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b) +{ + return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); +} + +inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b) +{ + return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); +} + + + +#endif \ No newline at end of file diff --git a/fairnr/clib/include/intersect.h b/fairnr/clib/include/intersect.h new file mode 100644 index 0000000..bd23d4e --- /dev/null +++ b/fairnr/clib/include/intersect.h @@ -0,0 +1,19 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include + +std::tuple ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, + const float radius, const int n_max); +std::tuple aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, + const float voxelsize, const int n_max); +std::tuple svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, at::Tensor children, + const float voxelsize, const int n_max); +std::tuple< at::Tensor, at::Tensor, at::Tensor > triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points, + const float cagesize, const float blur, const int n_max); +std::tuple uniform_ray_sampling(at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise, + const float step_size, const int max_steps); \ No newline at end of file diff --git a/fairnr/clib/include/octree.h b/fairnr/clib/include/octree.h new file mode 100644 index 0000000..429053e --- /dev/null +++ b/fairnr/clib/include/octree.h @@ -0,0 +1,10 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include + +std::tuple build_octree(at::Tensor center, at::Tensor points, int depth); \ No newline at end of file diff --git a/fairnr/clib/include/utils.h b/fairnr/clib/include/utils.h new file mode 100644 index 0000000..925f769 --- /dev/null +++ b/fairnr/clib/include/utils.h @@ -0,0 +1,30 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include + +#define CHECK_CUDA(x) \ + do { \ + TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ + } while (0) + +#define CHECK_CONTIGUOUS(x) \ + do { \ + TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ + } while (0) + +#define CHECK_IS_INT(x) \ + do { \ + TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ + #x " must be an int tensor"); \ + } while (0) + +#define CHECK_IS_FLOAT(x) \ + do { \ + TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ + #x " must be a float tensor"); \ + } while (0) diff --git a/fairnr/clib/src/binding.cpp b/fairnr/clib/src/binding.cpp new file mode 100644 index 0000000..b906b1f --- /dev/null +++ b/fairnr/clib/src/binding.cpp @@ -0,0 +1,18 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "intersect.h" +#include "octree.h" + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ball_intersect", &ball_intersect); + m.def("aabb_intersect", &aabb_intersect); + m.def("svo_intersect", &svo_intersect); + m.def("triangle_intersect", &triangle_intersect); + m.def("uniform_ray_sampling", &uniform_ray_sampling); + + m.def("build_octree", &build_octree); +} \ No newline at end of file diff --git a/fairnr/clib/src/intersect.cpp b/fairnr/clib/src/intersect.cpp new file mode 100644 index 0000000..d10d586 --- /dev/null +++ b/fairnr/clib/src/intersect.cpp @@ -0,0 +1,186 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "intersect.h" +#include "utils.h" +#include + +void ball_intersect_point_kernel_wrapper( + int b, int n, int m, float radius, int n_max, + const float *ray_start, const float *ray_dir, const float *points, + int *idx, float *min_depth, float *max_depth); + +std::tuple< at::Tensor, at::Tensor, at::Tensor > ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, + const float radius, const int n_max){ + CHECK_CONTIGUOUS(ray_start); + CHECK_CONTIGUOUS(ray_dir); + CHECK_CONTIGUOUS(points); + CHECK_IS_FLOAT(ray_start); + CHECK_IS_FLOAT(ray_dir); + CHECK_IS_FLOAT(points); + CHECK_CUDA(ray_start); + CHECK_CUDA(ray_dir); + CHECK_CUDA(points); + + at::Tensor idx = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Int)); + at::Tensor min_depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + at::Tensor max_depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + ball_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1), + radius, n_max, + ray_start.data_ptr (), ray_dir.data_ptr (), points.data_ptr (), + idx.data_ptr (), min_depth.data_ptr (), max_depth.data_ptr ()); + return std::make_tuple(idx, min_depth, max_depth); +} + + +void aabb_intersect_point_kernel_wrapper( + int b, int n, int m, float voxelsize, int n_max, + const float *ray_start, const float *ray_dir, const float *points, + int *idx, float *min_depth, float *max_depth); + +std::tuple< at::Tensor, at::Tensor, at::Tensor > aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, + const float voxelsize, const int n_max){ + CHECK_CONTIGUOUS(ray_start); + CHECK_CONTIGUOUS(ray_dir); + CHECK_CONTIGUOUS(points); + CHECK_IS_FLOAT(ray_start); + CHECK_IS_FLOAT(ray_dir); + CHECK_IS_FLOAT(points); + CHECK_CUDA(ray_start); + CHECK_CUDA(ray_dir); + CHECK_CUDA(points); + + at::Tensor idx = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Int)); + at::Tensor min_depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + at::Tensor max_depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + aabb_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1), + voxelsize, n_max, + ray_start.data_ptr (), ray_dir.data_ptr (), points.data_ptr (), + idx.data_ptr (), min_depth.data_ptr (), max_depth.data_ptr ()); + return std::make_tuple(idx, min_depth, max_depth); +} + + +void svo_intersect_point_kernel_wrapper( + int b, int n, int m, float voxelsize, int n_max, + const float *ray_start, const float *ray_dir, const float *points, const int *children, + int *idx, float *min_depth, float *max_depth); + + +std::tuple< at::Tensor, at::Tensor, at::Tensor > svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, + at::Tensor children, const float voxelsize, const int n_max){ + CHECK_CONTIGUOUS(ray_start); + CHECK_CONTIGUOUS(ray_dir); + CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(children); + CHECK_IS_FLOAT(ray_start); + CHECK_IS_FLOAT(ray_dir); + CHECK_IS_FLOAT(points); + CHECK_CUDA(ray_start); + CHECK_CUDA(ray_dir); + CHECK_CUDA(points); + CHECK_CUDA(children); + + at::Tensor idx = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Int)); + at::Tensor min_depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + at::Tensor max_depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + svo_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1), + voxelsize, n_max, + ray_start.data_ptr (), ray_dir.data_ptr (), points.data_ptr (), + children.data_ptr (), idx.data_ptr (), min_depth.data_ptr (), max_depth.data_ptr ()); + return std::make_tuple(idx, min_depth, max_depth); +} + + +void triangle_intersect_point_kernel_wrapper( + int b, int n, int m, float cagesize, float blur, int n_max, + const float *ray_start, const float *ray_dir, const float *face_points, + int *idx, float *depth, float *uv); + +std::tuple< at::Tensor, at::Tensor, at::Tensor > triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points, + const float cagesize, const float blur, const int n_max){ + CHECK_CONTIGUOUS(ray_start); + CHECK_CONTIGUOUS(ray_dir); + CHECK_CONTIGUOUS(face_points); + CHECK_IS_FLOAT(ray_start); + CHECK_IS_FLOAT(ray_dir); + CHECK_IS_FLOAT(face_points); + CHECK_CUDA(ray_start); + CHECK_CUDA(ray_dir); + CHECK_CUDA(face_points); + + at::Tensor idx = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max}, + at::device(ray_start.device()).dtype(at::ScalarType::Int)); + at::Tensor depth = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 3}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + at::Tensor uv = + torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 2}, + at::device(ray_start.device()).dtype(at::ScalarType::Float)); + triangle_intersect_point_kernel_wrapper(face_points.size(0), face_points.size(1), ray_start.size(1), + cagesize, blur, n_max, + ray_start.data_ptr (), ray_dir.data_ptr (), face_points.data_ptr (), + idx.data_ptr (), depth.data_ptr (), uv.data_ptr ()); + return std::make_tuple(idx, depth, uv); +} + + +void uniform_ray_sampling_kernel_wrapper( + int b, int num_rays, int max_hits, int max_steps, float step_size, + const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise, + int *sampled_idx, float *sampled_depth, float *sampled_dists); + + +std::tuple< at::Tensor, at::Tensor, at::Tensor> uniform_ray_sampling( + at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise, + const float step_size, const int max_steps){ + + CHECK_CONTIGUOUS(pts_idx); + CHECK_CONTIGUOUS(min_depth); + CHECK_CONTIGUOUS(max_depth); + CHECK_CONTIGUOUS(uniform_noise); + CHECK_IS_FLOAT(min_depth); + CHECK_IS_FLOAT(max_depth); + CHECK_IS_FLOAT(uniform_noise); + CHECK_CUDA(pts_idx); + CHECK_CUDA(min_depth); + CHECK_CUDA(max_depth); + CHECK_CUDA(uniform_noise); + + at::Tensor sampled_idx = + -torch::ones({pts_idx.size(0), pts_idx.size(1), max_steps}, + at::device(pts_idx.device()).dtype(at::ScalarType::Int)); + at::Tensor sampled_depth = + torch::zeros({min_depth.size(0), min_depth.size(1), max_steps}, + at::device(min_depth.device()).dtype(at::ScalarType::Float)); + at::Tensor sampled_dists = + torch::zeros({min_depth.size(0), min_depth.size(1), max_steps}, + at::device(min_depth.device()).dtype(at::ScalarType::Float)); + uniform_ray_sampling_kernel_wrapper(min_depth.size(0), min_depth.size(1), min_depth.size(2), sampled_depth.size(2), + step_size, + pts_idx.data_ptr (), min_depth.data_ptr (), max_depth.data_ptr (), + uniform_noise.data_ptr (), sampled_idx.data_ptr (), + sampled_depth.data_ptr (), sampled_dists.data_ptr ()); + return std::make_tuple(sampled_idx, sampled_depth, sampled_dists); +} \ No newline at end of file diff --git a/fairnr/clib/src/intersect_gpu.cu b/fairnr/clib/src/intersect_gpu.cu new file mode 100644 index 0000000..9581ee3 --- /dev/null +++ b/fairnr/clib/src/intersect_gpu.cu @@ -0,0 +1,503 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include "cuda_utils.h" +#include "cutil_math.h" // required for float3 vector math + + +__global__ void ball_intersect_point_kernel( + int b, int n, int m, float radius, + int n_max, + const float *__restrict__ ray_start, + const float *__restrict__ ray_dir, + const float *__restrict__ points, + int *__restrict__ idx, + float *__restrict__ min_depth, + float *__restrict__ max_depth) { + + int batch_index = blockIdx.x; + points += batch_index * n * 3; + ray_start += batch_index * m * 3; + ray_dir += batch_index * m * 3; + idx += batch_index * m * n_max; + min_depth += batch_index * m * n_max; + max_depth += batch_index * m * n_max; + + int index = threadIdx.x; + int stride = blockDim.x; + float radius2 = radius * radius; + + for (int j = index; j < m; j += stride) { + + float x0 = ray_start[j * 3 + 0]; + float y0 = ray_start[j * 3 + 1]; + float z0 = ray_start[j * 3 + 2]; + float xw = ray_dir[j * 3 + 0]; + float yw = ray_dir[j * 3 + 1]; + float zw = ray_dir[j * 3 + 2]; + + for (int l = 0; l < n_max; ++l) { + idx[j * n_max + l] = -1; + } + + for (int k = 0, cnt = 0; k < n && cnt < n_max; ++k) { + float x = points[k * 3 + 0] - x0; + float y = points[k * 3 + 1] - y0; + float z = points[k * 3 + 2] - z0; + float d2 = x * x + y * y + z * z; + float d2_proj = pow(x * xw + y * yw + z * zw, 2); + float r2 = d2 - d2_proj; + + if (r2 < radius2) { + idx[j * n_max + cnt] = k; + + float depth = sqrt(d2_proj); + float depth_blur = sqrt(radius2 - r2); + + min_depth[j * n_max + cnt] = depth - depth_blur; + max_depth[j * n_max + cnt] = depth + depth_blur; + ++cnt; + } + } + } +} + + +__device__ float2 RayAABBIntersection( + const float3 &ori, + const float3 &dir, + const float3 ¢er, + float half_voxel) { + + float f_low = 0; + float f_high = 100000.; + float f_dim_low, f_dim_high, temp, inv_ray_dir, start, aabb; + + for (int d = 0; d < 3; ++d) { + switch (d) { + case 0: + inv_ray_dir = __fdividef(1.0f, dir.x); start = ori.x; aabb = center.x; break; + case 1: + inv_ray_dir = __fdividef(1.0f, dir.y); start = ori.y; aabb = center.y; break; + case 2: + inv_ray_dir = __fdividef(1.0f, dir.z); start = ori.z; aabb = center.z; break; + } + + f_dim_low = (aabb - half_voxel - start) * inv_ray_dir; + f_dim_high = (aabb + half_voxel - start) * inv_ray_dir; + + // Make sure low is less than high + if (f_dim_high < f_dim_low) { + temp = f_dim_low; + f_dim_low = f_dim_high; + f_dim_high = temp; + } + + // If this dimension's high is less than the low we got then we definitely missed. + if (f_dim_high < f_low) { + return make_float2(-1.0f, -1.0f); + } + + // Likewise if the low is less than the high. + if (f_dim_low > f_high) { + return make_float2(-1.0f, -1.0f); + } + + // Add the clip from this dimension to the previous results + f_low = (f_dim_low > f_low) ? f_dim_low : f_low; + f_high = (f_dim_high < f_high) ? f_dim_high : f_high; + + if (f_low > f_high) { + return make_float2(-1.0f, -1.0f); + } + } + return make_float2(f_low, f_high); +} + + +__global__ void aabb_intersect_point_kernel( + int b, int n, int m, float voxelsize, + int n_max, + const float *__restrict__ ray_start, + const float *__restrict__ ray_dir, + const float *__restrict__ points, + int *__restrict__ idx, + float *__restrict__ min_depth, + float *__restrict__ max_depth) { + + int batch_index = blockIdx.x; + points += batch_index * n * 3; + ray_start += batch_index * m * 3; + ray_dir += batch_index * m * 3; + idx += batch_index * m * n_max; + min_depth += batch_index * m * n_max; + max_depth += batch_index * m * n_max; + + int index = threadIdx.x; + int stride = blockDim.x; + float half_voxel = voxelsize * 0.5; + + for (int j = index; j < m; j += stride) { + for (int l = 0; l < n_max; ++l) { + idx[j * n_max + l] = -1; + } + + for (int k = 0, cnt = 0; k < n && cnt < n_max; ++k) { + float2 depths = RayAABBIntersection( + make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]), + make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]), + make_float3(points[k * 3 + 0], points[k * 3 + 1], points[k * 3 + 2]), + half_voxel); + + if (depths.x > -1.0f){ + idx[j * n_max + cnt] = k; + min_depth[j * n_max + cnt] = depths.x; + max_depth[j * n_max + cnt] = depths.y; + ++cnt; + } + } + } +} + + +__global__ void svo_intersect_point_kernel( + int b, int n, int m, float voxelsize, + int n_max, + const float *__restrict__ ray_start, + const float *__restrict__ ray_dir, + const float *__restrict__ points, + const int *__restrict__ children, + int *__restrict__ idx, + float *__restrict__ min_depth, + float *__restrict__ max_depth) { + /* + TODO: this is an inefficient implementation of the + navie Ray -- Sparse Voxel Octree Intersection. + It can be further improved using: + + Revelles, Jorge, Carlos Urena, and Miguel Lastra. + "An efficient parametric algorithm for octree traversal." (2000). + */ + int batch_index = blockIdx.x; + points += batch_index * n * 3; + children += batch_index * n * 9; + ray_start += batch_index * m * 3; + ray_dir += batch_index * m * 3; + idx += batch_index * m * n_max; + min_depth += batch_index * m * n_max; + max_depth += batch_index * m * n_max; + + int index = threadIdx.x; + int stride = blockDim.x; + float half_voxel = voxelsize * 0.5; + + for (int j = index; j < m; j += stride) { + for (int l = 0; l < n_max; ++l) { + idx[j * n_max + l] = -1; + } + int stack[256] = {-1}; // DFS, initialize the stack + int ptr = 0, cnt = 0, k = -1; + stack[ptr] = n - 1; // ROOT node is always the last + while (ptr > -1 && cnt < n_max) { + assert((ptr < 256)); + + // evaluate the current node + k = stack[ptr]; + float2 depths = RayAABBIntersection( + make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]), + make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]), + make_float3(points[k * 3 + 0], points[k * 3 + 1], points[k * 3 + 2]), + half_voxel * float(children[k * 9 + 8])); + stack[ptr] = -1; ptr--; + + if (depths.x > -1.0f) { // ray did not miss the voxel + // TODO: here it should be able to know which children is ok, further optimize the code + if (children[k * 9 + 8] == 1) { // this is a terminal node + idx[j * n_max + cnt] = k; + min_depth[j * n_max + cnt] = depths.x; + max_depth[j * n_max + cnt] = depths.y; + ++cnt; continue; + } + + for (int u = 0; u < 8; u++) { + if (children[k * 9 + u] > -1) { + ptr++; stack[ptr] = children[k * 9 + u]; // push child to the stack + } + } + } + } + } +} + + +__device__ float3 RayTriangleIntersection( + const float3 &ori, + const float3 &dir, + const float3 &v0, + const float3 &v1, + const float3 &v2, + float blur) { + + float3 v0v1 = v1 - v0; + float3 v0v2 = v2 - v0; + float3 v0O = ori - v0; + float3 dir_crs_v0v2 = cross(dir, v0v2); + + float det = dot(v0v1, dir_crs_v0v2); + det = __fdividef(1.0f, det); // CUDA intrinsic function + + float u = dot(v0O, dir_crs_v0v2) * det; + if (u < 0.0f - blur || u > 1.0f + blur) + return make_float3(-1.0f, 0.0f, 0.0f); + + float3 v0O_crs_v0v1 = cross(v0O, v0v1); + float v = dot(dir, v0O_crs_v0v1) * det; + if (v < 0.0f - blur || v > 1.0f + blur) + return make_float3(-1.0f, 0.0f, 0.0f); + + if ((u + v) < 0.0f - blur || (u + v) > 1.0f + blur) + return make_float3(-1.0f, 0.0f, 0.0f); + + float t = dot(v0v2, v0O_crs_v0v1) * det; + return make_float3(t, u, v); +} + + +__global__ void triangle_intersect_point_kernel( + int b, int n, int m, float cagesize, + float blur, int n_max, + const float *__restrict__ ray_start, + const float *__restrict__ ray_dir, + const float *__restrict__ face_points, + int *__restrict__ idx, + float *__restrict__ depth, + float *__restrict__ uv) { + + int batch_index = blockIdx.x; + face_points += batch_index * n * 9; + ray_start += batch_index * m * 3; + ray_dir += batch_index * m * 3; + idx += batch_index * m * n_max; + depth += batch_index * m * n_max * 3; + uv += batch_index * m * n_max * 2; + + int index = threadIdx.x; + int stride = blockDim.x; + for (int j = index; j < m; j += stride) { + // go over rays + for (int l = 0; l < n_max; ++l) { + idx[j * n_max + l] = -1; + } + + int cnt = 0; + for (int k = 0; k < n && cnt < n_max; ++k) { + // go over triangles + float3 tuv = RayTriangleIntersection( + make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]), + make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]), + make_float3(face_points[k * 9 + 0], face_points[k * 9 + 1], face_points[k * 9 + 2]), + make_float3(face_points[k * 9 + 3], face_points[k * 9 + 4], face_points[k * 9 + 5]), + make_float3(face_points[k * 9 + 6], face_points[k * 9 + 7], face_points[k * 9 + 8]), + blur); + + if (tuv.x > 0) { + int ki = k; + float d = tuv.x, u = tuv.y, v = tuv.z; + + // sort + for (int l = 0; l < cnt; l++) { + if (d < depth[j * n_max * 3 + l * 3]) { + swap(ki, idx[j * n_max + l]); + swap(d, depth[j * n_max * 3 + l * 3]); + swap(u, uv[j * n_max * 2 + l * 2]); + swap(v, uv[j * n_max * 2 + l * 2 + 1]); + } + } + idx[j * n_max + cnt] = ki; + depth[j * n_max * 3 + cnt * 3] = d; + uv[j * n_max * 2 + cnt * 2] = u; + uv[j * n_max * 2 + cnt * 2 + 1] = v; + cnt++; + } + } + + for (int l = 0; l < cnt; l++) { + // compute min_depth + if (l == 0) + depth[j * n_max * 3 + l * 3 + 1] = -cagesize; + else + depth[j * n_max * 3 + l * 3 + 1] = -fminf(cagesize, + .5 * (depth[j * n_max * 3 + l * 3] - depth[j * n_max * 3 + l * 3 - 3])); + + // compute max_depth + if (l == cnt - 1) + depth[j * n_max * 3 + l * 3 + 2] = cagesize; + else + depth[j * n_max * 3 + l * 3 + 2] = fminf(cagesize, + .5 * (depth[j * n_max * 3 + l * 3 + 3] - depth[j * n_max * 3 + l * 3])); + } + } +} + +void ball_intersect_point_kernel_wrapper( + int b, int n, int m, float radius, int n_max, + const float *ray_start, const float *ray_dir, const float *points, + int *idx, float *min_depth, float *max_depth) { + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + ball_intersect_point_kernel<<>>( + b, n, m, radius, n_max, ray_start, ray_dir, points, idx, min_depth, max_depth); + + CUDA_CHECK_ERRORS(); +} + + +void aabb_intersect_point_kernel_wrapper( + int b, int n, int m, float voxelsize, int n_max, + const float *ray_start, const float *ray_dir, const float *points, + int *idx, float *min_depth, float *max_depth) { + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + aabb_intersect_point_kernel<<>>( + b, n, m, voxelsize, n_max, ray_start, ray_dir, points, idx, min_depth, max_depth); + + CUDA_CHECK_ERRORS(); +} + + +void svo_intersect_point_kernel_wrapper( + int b, int n, int m, float voxelsize, int n_max, + const float *ray_start, const float *ray_dir, const float *points, const int *children, + int *idx, float *min_depth, float *max_depth) { + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + svo_intersect_point_kernel<<>>( + b, n, m, voxelsize, n_max, ray_start, ray_dir, points, children, idx, min_depth, max_depth); + + CUDA_CHECK_ERRORS(); +} + + +void triangle_intersect_point_kernel_wrapper( + int b, int n, int m, float cagesize, float blur, int n_max, + const float *ray_start, const float *ray_dir, const float *face_points, + int *idx, float *depth, float *uv) { + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + triangle_intersect_point_kernel<<>>( + b, n, m, cagesize, blur, n_max, ray_start, ray_dir, face_points, idx, depth, uv); + + CUDA_CHECK_ERRORS(); +} + + +__global__ void uniform_ray_sampling_kernel( + int b, int num_rays, + int max_hits, + int max_steps, + float step_size, + const int *__restrict__ pts_idx, + const float *__restrict__ min_depth, + const float *__restrict__ max_depth, + const float *__restrict__ uniform_noise, + int *__restrict__ sampled_idx, + float *__restrict__ sampled_depth, + float *__restrict__ sampled_dists) { + + int batch_index = blockIdx.x; + int index = threadIdx.x; + int stride = blockDim.x; + + pts_idx += batch_index * num_rays * max_hits; + min_depth += batch_index * num_rays * max_hits; + max_depth += batch_index * num_rays * max_hits; + + uniform_noise += batch_index * num_rays * max_steps; + sampled_idx += batch_index * num_rays * max_steps; + sampled_depth += batch_index * num_rays * max_steps; + sampled_dists += batch_index * num_rays * max_steps; + + // loop over all rays + for (int j = index; j < num_rays; j += stride) { + int H = j * max_hits, K = j * max_steps; + int s = 0, ucur = 0, umin = 0, umax = 0; + float last_min_depth, last_max_depth, curr_depth; + + // sort all depths + while (true) { + if (pts_idx[H + umax] == -1 || umax == max_hits || ucur == max_steps) { + break; // reach the maximum + } + if (umin < max_hits) { + last_min_depth = min_depth[H + umin]; + } + if (umax < max_hits) { + last_max_depth = max_depth[H + umax]; + } + if (ucur < max_steps) { + curr_depth = min_depth[H] + (float(ucur) + uniform_noise[K + ucur]) * step_size; + } + + if (last_max_depth <= curr_depth && last_max_depth <= last_min_depth) { + sampled_depth[K + s] = last_max_depth; + sampled_idx[K + s] = pts_idx[H + umax]; + umax++; s++; continue; + } + if (curr_depth <= last_min_depth && curr_depth <= last_max_depth) { + sampled_depth[K + s] = curr_depth; + sampled_idx[K + s] = pts_idx[H + umin - 1]; + ucur++; s++; continue; + } + if (last_min_depth <= curr_depth && last_min_depth <= last_max_depth) { + sampled_depth[K + s] = last_min_depth; + sampled_idx[K + s] = pts_idx[H + umin]; + umin++; s++; continue; + } + } + + float l_depth, r_depth; + int step = 0; + for (ucur = 0, umin = 0, umax = 0; ucur < max_steps - 1; ucur++) { + l_depth = sampled_depth[K + ucur]; + r_depth = sampled_depth[K + ucur + 1]; + sampled_depth[K + ucur] = (l_depth + r_depth) * .5; + sampled_dists[K + ucur] = (r_depth - l_depth); + if (sampled_depth[K + ucur] >= min_depth[H + umin] && umin < max_hits) umin++; + if (sampled_depth[K + ucur] >= max_depth[H + umax] && umax < max_hits) umax++; + if (umax == max_hits || pts_idx[H + umax] == -1) break; + if (umin - 1 == umax && sampled_dists[K + ucur] > 0) { + sampled_depth[K + step] = sampled_depth[K + ucur]; + sampled_dists[K + step] = sampled_dists[K + ucur]; + sampled_idx[K + step] = sampled_idx[K + ucur]; + step++; + } + } + for (int s = step; s < max_steps; s++) { + sampled_idx[K + s] = -1; + } + } + + +} + + +void uniform_ray_sampling_kernel_wrapper( + int b, int num_rays, int max_hits, int max_steps, float step_size, + const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise, + int *sampled_idx, float *sampled_depth, float *sampled_dists) { + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + uniform_ray_sampling_kernel<<>>( + b, num_rays, max_hits, max_steps, step_size, pts_idx, + min_depth, max_depth, uniform_noise, sampled_idx, sampled_depth, sampled_dists); + + CUDA_CHECK_ERRORS(); +} + diff --git a/fairnr/clib/src/octree.cpp b/fairnr/clib/src/octree.cpp new file mode 100644 index 0000000..e1c8ab0 --- /dev/null +++ b/fairnr/clib/src/octree.cpp @@ -0,0 +1,136 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "octree.h" +#include "utils.h" +#include +#include +using namespace std::chrono; + + +typedef struct OcTree +{ + int depth; + int index; + at::Tensor center; + struct OcTree *children[8]; + void init(at::Tensor center, int d, int i) { + this->center = center; + this->depth = d; + this->index = i; + for (int i=0; i<8; i++) this->children[i] = nullptr; + } +}OcTree; + +class EasyOctree { + public: + OcTree *root; + int total; + int terminal; + + at::Tensor all_centers; + at::Tensor all_children; + + EasyOctree(at::Tensor center, int depth) { + root = new OcTree; + root->init(center, depth, -1); + total = -1; + terminal = -1; + } + ~EasyOctree() { + OcTree *p = root; + destory(p); + } + void destory(OcTree * &p); + void insert(OcTree * &p, at::Tensor point, int index); + void finalize(); + std::pair count(OcTree * &p); +}; + +void EasyOctree::destory(OcTree * &p){ + if (p != nullptr) { + for (int i=0; i<8; i++) { + if (p->children[i] != nullptr) destory(p->children[i]); + } + delete p; + p = nullptr; + } +} + +void EasyOctree::insert(OcTree * &p, at::Tensor point, int index) { + at::Tensor diff = (point > p->center).to(at::kInt); + int idx = diff[0].item() + 2 * diff[1].item() + 4 * diff[2].item(); + if (p->depth == 0) { + p->children[idx] = new OcTree; + p->children[idx]->init(point, -1, index); + } else { + if (p->children[idx] == nullptr) { + int length = 1 << (p->depth - 1); + at::Tensor new_center = p->center + (2 * diff - 1) * length; + p->children[idx] = new OcTree; + p->children[idx]->init(new_center, p->depth-1, -1); + } + insert(p->children[idx], point, index); + } +} + +std::pair EasyOctree::count(OcTree * &p) { + int total = 0, terminal = 0; + for (int i=0; i<8; i++) { + if (p->children[i] != nullptr) { + std::pair sub = count(p->children[i]); + total += sub.first; + terminal += sub.second; + } + } + total += 1; + if (p->depth == -1) terminal += 1; + return std::make_pair(total, terminal); +} + +void EasyOctree::finalize() { + std::pair outs = count(root); + total = outs.first; terminal = outs.second; + + all_centers = + torch::zeros({outs.first, 3}, at::device(root->center.device()).dtype(at::ScalarType::Int)); + all_children = + -torch::ones({outs.first, 9}, at::device(root->center.device()).dtype(at::ScalarType::Int)); + + int node_idx = outs.first - 1; + root->index = node_idx; + + std::queue all_leaves; all_leaves.push(root); + while (!all_leaves.empty()) { + OcTree* node_ptr = all_leaves.front(); + all_leaves.pop(); + for (int i=0; i<8; i++) { + if (node_ptr->children[i] != nullptr) { + if (node_ptr->children[i]->depth > -1) { + node_idx--; + node_ptr->children[i]->index = node_idx; + } + all_leaves.push(node_ptr->children[i]); + all_children[node_ptr->index][i] = node_ptr->children[i]->index; + } + } + all_children[node_ptr->index][8] = 1 << (node_ptr->depth + 1); + all_centers[node_ptr->index] = node_ptr->center; + } + assert (node_idx == outs.second); +}; + +std::tuple build_octree(at::Tensor center, at::Tensor points, int depth) { + auto start = high_resolution_clock::now(); + EasyOctree tree(center, depth); + for (int k=0; k(stop - start); + printf("Building EasyOctree done. total #nodes = %d, terminal #nodes = %d (time taken %f s)\n", + tree.total, tree.terminal, float(duration.count()) / 1000000.); + return std::make_tuple(tree.all_centers, tree.all_children); +} \ No newline at end of file diff --git a/fairnr/criterions/__init__.py b/fairnr/criterions/__init__.py new file mode 100644 index 0000000..f439fc6 --- /dev/null +++ b/fairnr/criterions/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + criterion_name = file[: file.find(".py")] + importlib.import_module( + "fairnr.criterions." + criterion_name + ) diff --git a/fairnr/criterions/perceptual_loss.py b/fairnr/criterions/perceptual_loss.py new file mode 100644 index 0000000..836ea3a --- /dev/null +++ b/fairnr/criterions/perceptual_loss.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torchvision + +class VGGPerceptualLoss(torch.nn.Module): + def __init__(self, resize=False): + super(VGGPerceptualLoss, self).__init__() + blocks = [] + blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) + self.blocks = torch.nn.ModuleList(blocks) + self.transform = torch.nn.functional.interpolate + self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)) + self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)) + self.resize = resize + + # NO GRADIENT! + for param in self.parameters(): + param.requires_grad = False + + def forward(self, input, target, level=2): + # print(input.device, input.dtype, self.mean.device, self.mean.dtype, self.std, self.std.dtype) + if input.shape[1] != 3: + input = input.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + input = (input-self.mean) / self.std + target = (target-self.mean) / self.std + + if self.resize: + input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) + target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) + + loss = 0.0 + x = input + y = target + for i, block in enumerate(self.blocks): + x = block(x) + y = block(y) + if i < level: + loss += torch.nn.functional.mse_loss(x, y) + else: + break + return loss diff --git a/fairnr/criterions/rendering_loss.py b/fairnr/criterions/rendering_loss.py new file mode 100644 index 0000000..91f4869 --- /dev/null +++ b/fairnr/criterions/rendering_loss.py @@ -0,0 +1,197 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import math + +import torch.nn.functional as F +import torch +from torch import Tensor + +from fairseq import metrics +from fairseq.utils import item +from fairseq.criterions import FairseqCriterion, register_criterion +import fairnr.criterions.utils as utils + +class RenderingCriterion(FairseqCriterion): + + def __init__(self, args, task): + super().__init__(task) + self.args = args + + @classmethod + def build_criterion(cls, args, task): + """Construct a criterion from command-line args.""" + return cls(args, task) + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + pass + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample) + loss, loss_output = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = 1 + + logging_output = { + 'loss': loss.data.item() if reduce else loss.data, + 'nsentences': sample['alpha'].size(0), + 'ntokens': sample['alpha'].size(1), + 'npixels': sample['alpha'].size(2), + 'sample_size': sample_size, + } + for w in loss_output: + logging_output[w] = loss_output[w] + + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + raise NotImplementedError + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + + summed_logging_outputs = { + w: sum(log.get(w, 0) for log in logging_outputs) + for w in logging_outputs[0] + } + sample_size = summed_logging_outputs['sample_size'] + + for w in summed_logging_outputs: + if '_loss' in w: + metrics.log_scalar(w[:5].split('_')[0], summed_logging_outputs[w] / sample_size, sample_size, round=3) + elif '_weight' in w: + metrics.log_scalar('w_' + w[:3], summed_logging_outputs[w] / sample_size, sample_size, round=3) + elif '_acc' in w: + metrics.log_scalar('a_' + w[:3], summed_logging_outputs[w] / sample_size, sample_size, round=3) + elif w == 'loss': + metrics.log_scalar('loss', summed_logging_outputs['loss'] / sample_size, sample_size, priority=0, round=3) + elif '_log' in w: + metrics.log_scalar(w[:3], summed_logging_outputs[w] / sample_size, sample_size, priority=1, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True + + +@register_criterion('srn_loss') +class SRNLossCriterion(RenderingCriterion): + + def __init__(self, args, task): + super().__init__(args, task) + # HACK: to avoid warnings in c10d + self.dummy_loss = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32), requires_grad=True) + if args.vgg_weight > 0: + from fairnr.criterions.perceptual_loss import VGGPerceptualLoss + self.vgg = VGGPerceptualLoss(resize=False) + + if args.eval_lpips: + from lpips_pytorch import LPIPS + self.lpips = LPIPS(net_type='alex', version='0.1') + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + parser.add_argument('--L1', action='store_true', + help='if enabled, use L1 instead of L2 for RGB loss') + parser.add_argument('--color-weight', type=float, default=256.0) + parser.add_argument('--depth-weight', type=float, default=0.0) + parser.add_argument('--depth-weight-decay', type=str, default=None, + help="""if set, use tuple to set (final_ratio, steps). + For instance, (0, 30000) + """) + parser.add_argument('--alpha-weight', type=float, default=0.0) + parser.add_argument('--vgg-weight', type=float, default=0.0) + parser.add_argument('--vgg-level', type=int, choices=[1,2,3,4], default=2) + parser.add_argument('--eval-lpips', action='store_true', + help="evaluate LPIPS scores in validation") + parser.add_argument('--no-background-loss', action='store_true') + + def compute_loss(self, model, net_output, sample, reduce=True): + losses, other_logs = {}, {} + + # prepare data before computing loss + sampled_uv = net_output['sampled_uv'] # S, V, 2, N, P, P (patch-size) + S, V, _, N, P1, P2 = sampled_uv.size() + H, W = int(sample['size'][0, 0, 0]), int(sample['size'][0, 0, 1]) + L = N * P1 * P2 + flatten_uv = sampled_uv.view(S, V, 2, L) + flatten_index = (flatten_uv[:,:,0] + flatten_uv[:,:,1] * W).long() + + assert 'colors' in sample and sample['colors'] is not None, "ground-truth colors not provided" + target_colors = sample['colors'] + masks = (sample['alpha'] > 0) if self.args.no_background_loss else None + if L < target_colors.size(2): + target_colors = target_colors.gather(2, flatten_index.unsqueeze(-1).repeat(1,1,1,3)) + masks = masks.gather(2, flatten_uv) if masks is not None else None + + if 'other_logs' in net_output: + other_logs.update(net_output['other_logs']) + + # computing loss + if self.args.color_weight > 0: + color_loss = utils.rgb_loss( + net_output['colors'], target_colors, + masks, self.args.L1) + losses['color_loss'] = (color_loss, self.args.color_weight) + + if self.args.alpha_weight > 0: + _alpha = net_output['missed'].reshape(-1) + alpha_loss = torch.log1p( + 1. / 0.11 * _alpha.float() * (1 - _alpha.float()) + ).mean().type_as(_alpha) + losses['alpha_loss'] = (alpha_loss, self.args.alpha_weight) + + if self.args.depth_weight > 0: + if sample['depths'] is not None: + target_depths = target_depths.gather(2, flatten_index) + depth_mask = masks & (target_depths > 0) + depth_loss = utils.depth_loss(net_output['depths'], target_depths, depth_mask) + + else: + # no depth map is provided, depth loss only applied on background based on masks + max_depth_target = self.args.max_depth * torch.ones_like(net_output['depths']) + if sample['mask'] is not None: + depth_loss = utils.depth_loss(net_output['depths'], max_depth_target, (1 - sample['mask']).bool()) + else: + depth_loss = utils.depth_loss(net_output['depths'], max_depth_target, ~masks) + + depth_weight = self.args.depth_weight + if self.args.depth_weight_decay is not None: + final_factor, final_steps = eval(self.args.depth_weight_decay) + depth_weight *= max(0, 1 - (1 - final_factor) * self.task._num_updates / final_steps) + other_logs['depth_weight'] = depth_weight + + losses['depth_loss'] = (depth_loss, depth_weight) + + + if self.args.vgg_weight > 0: + assert P1 * P2 > 1, "we have to use a patch-based sampling for VGG loss" + target_colors = target_colors.reshape(-1, P1, P2, 3).permute(0, 3, 1, 2) * .5 + .5 + output_colors = net_output['colors'].reshape(-1, P1, P2, 3).permute(0, 3, 1, 2) * .5 + .5 + vgg_loss = self.vgg(output_colors, target_colors) + losses['vgg_loss'] = (vgg_loss, self.args.vgg_weight) + + loss = sum(losses[key][0] * losses[key][1] for key in losses) + + # add a dummy loss + loss = loss + model.dummy_loss + self.dummy_loss * 0. + logging_outputs = {key: item(losses[key][0]) for key in losses} + logging_outputs.update(other_logs) + return loss, logging_outputs \ No newline at end of file diff --git a/fairnr/criterions/utils.py b/fairnr/criterions/utils.py new file mode 100644 index 0000000..df266b4 --- /dev/null +++ b/fairnr/criterions/utils.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + +TINY = 1e-7 + + +def rgb_loss(predicts, rgbs, masks=None, L1=False, sum=False): + if masks is not None: + if masks.sum() == 0: + return predicts.new_zeros(1).mean() + predicts = predicts[masks] + rgbs = rgbs[masks] + + if L1: + loss = torch.abs(predicts - rgbs).sum(-1) + else: + loss = ((predicts - rgbs) ** 2).sum(-1) + + return loss.mean() if not sum else loss.sum() + + +def depth_loss(depths, depth_gt, masks=None, sum=False): + if masks is not None: + if masks.sum() == 0: + return depths.new_zeros(1).mean() + depth_gt = depth_gt[masks] + depths = depths[masks] + + loss = (depths[masks] - depth_gt[masks]) ** 2 + return loss.mean() if not sum else loss.sum() \ No newline at end of file diff --git a/fairnr/data/__init__.py b/fairnr/data/__init__.py new file mode 100644 index 0000000..cd9c291 --- /dev/null +++ b/fairnr/data/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .shape_dataset import ( + ShapeDataset, ShapeViewDataset, ShapeViewStreamDataset, + SampledPixelDataset, WorldCoordDataset, + InfiniteDataset +) + +__all__ = [ + 'ShapeDataset', + 'ShapeViewDataset', + 'ShapeViewStreamDataset', + 'SampledPixelDataset', + 'WorldCoordDataset', +] diff --git a/fairnr/data/data_utils.py b/fairnr/data/data_utils.py new file mode 100644 index 0000000..1f1e558 --- /dev/null +++ b/fairnr/data/data_utils.py @@ -0,0 +1,345 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import functools +import cv2 +import math +import numpy as np +import imageio +from glob import glob +import os +import copy +import shutil +import skimage.metrics +import pandas as pd +import pylab as plt +import fairseq.distributed_utils as du +from fairseq.meters import StopwatchMeter + +def get_rank(): + try: + return du.get_rank() + except AssertionError: + return 0 + + +def get_world_size(): + try: + return du.get_world_size() + except AssertionError: + return 1 + + +def parse_views(view_args): + output = [] + try: + xx = view_args.split(':') + ids = xx[0].split(',') + for id in ids: + if '..' in id: + a, b = id.split('..') + output += list(range(int(a), int(b))) + else: + output += [int(id)] + if len(xx) > 1: + output = output[::int(xx[-1])] + except Exception as e: + raise Exception("parse view args error: {}".format(e)) + + return output + + +def get_uv(H, W, h, w): + """ + H, W: real image (intrinsics) + h, w: resized image + """ + uv = np.flip(np.mgrid[0: h, 0: w], axis=0).astype(np.float32) + uv[0] = uv[0] * float(W / w) + uv[1] = uv[1] * float(H / h) + return uv, [float(H / h), float(W / w)] + + +def load_rgb( + path, + resolution=None, + with_alpha=True, + bg_color=[1.0, 1.0, 1.0], + min_rgb=-1, + interpolation='AREA'): + if with_alpha: + img = imageio.imread(path) # RGB-ALPHA + else: + img = imageio.imread(path)[:, :, :3] + + img = skimage.img_as_float32(img).astype('float32') + H, W, D = img.shape + h, w = resolution + + if D == 3: + img = np.concatenate([img, np.ones((img.shape[0], img.shape[1], 1))], -1).astype('float32') + + uv, ratio = get_uv(H, W, h, w) + if (h < H) or (w < W): + # img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST).astype('float32') + img = cv2.resize(img, (w, h), interpolation=cv2.INTER_AREA).astype('float32') + + if min_rgb == -1: # 0, 1 --> -1, 1 + img[:, :, :3] -= 0.5 + img[:, :, :3] *= 2. + + img[:, :, :3] = img[:, :, :3] * img[:, :, 3:] + np.asarray(bg_color)[None, None, :] * (1 - img[:, :, 3:]) + img[:, :, 3] = img[:, :, 3] * (img[:, :, :3] != np.asarray(bg_color)[None, None, :]).any(-1) + img = img.transpose(2, 0, 1) + + return img, uv, ratio + + +def load_depth(path, resolution=None, depth_plane=5): + if path is None: + return None + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) + # ret, img = cv2.threshold(img, depth_plane, depth_plane, cv2.THRESH_TRUNC) + + H, W = img.shape[:2] + h, w = resolution + if (h < H) or (w < W): + img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST).astype('float32') + #img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR) + + if len(img.shape) ==3: + img = img[:,:,:1] + img = img.transpose(2,0,1) + else: + img = img[None,:,:] + return img + + +def load_mask(path, resolution=None): + if path is None: + return None + + img = cv2.imread(path, cv2.IMREAD_GRAYSCALE).astype(np.float32) + h, w = resolution + H, W = img.shape[:2] + if (h < H) or (w < W): + img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST).astype('float32') + img = img / (img.max() + 1e-7) + return img + + +def load_matrix(path): + return np.array([[float(w) for w in line.strip().split()] for line in open(path)]).astype(np.float32) + + +def load_intrinsics(filepath, resized_width=None, invert_y=False): + try: + intrinsics = load_matrix(filepath) + if intrinsics.shape[0] == 3 and intrinsics.shape[1] == 3: + _intrinsics = np.zeros((4, 4), np.float32) + _intrinsics[:3, :3] = intrinsics + _intrinsics[3, 3] = 1 + intrinsics = _intrinsics + return intrinsics + except ValueError: + pass + + # Get camera intrinsics + with open(filepath, 'r') as file: + f, cx, cy, _ = map(float, file.readline().split()) + fx = f + if invert_y: + fy = -f + else: + fy = f + + # Build the intrinsic matrices + full_intrinsic = np.array([[fx, 0., cx, 0.], + [0., fy, cy, 0], + [0., 0, 1, 0], + [0, 0, 0, 1]]) + return full_intrinsic + + +def unflatten_img(img, width=512): + sizes = img.size() + height = sizes[-1] // width + return img.reshape(*sizes[:-1], height, width) + + +def square_crop_img(img): + if img.shape[0] == img.shape[1]: + return img # already square + + min_dim = np.amin(img.shape[:2]) + center_coord = np.array(img.shape[:2]) // 2 + img = img[center_coord[0] - min_dim // 2:center_coord[0] + min_dim // 2, + center_coord[1] - min_dim // 2:center_coord[1] + min_dim // 2] + return img + + +def sample_pixel_from_image( + num_pixel, num_sample, + mask=None, ratio=1.0, + use_bbox=False, + center_ratio=1.0, + width=512, + patch_size=1): + + if patch_size > 1: + assert (num_pixel % (patch_size * patch_size) == 0) \ + and (num_sample % (patch_size * patch_size) == 0), "size must match" + _num_pixel = num_pixel // (patch_size * patch_size) + _num_sample = num_sample // (patch_size * patch_size) + height = num_pixel // width + + _mask = None if mask is None else \ + mask.reshape(height, width).reshape( + height//patch_size, patch_size, width//patch_size, patch_size + ).any(1).any(-1).reshape(-1) + _width = width // patch_size + _out = sample_pixel_from_image(_num_pixel, _num_sample, _mask, ratio, use_bbox, _width) + _x, _y = _out % _width, _out // _width + x, y = _x * patch_size, _y * patch_size + x = x[:, None, None] + np.arange(patch_size)[None, :, None] + y = y[:, None, None] + np.arange(patch_size)[None, None, :] + out = x + y * width + return out.reshape(-1) + + if center_ratio < 1.0: + r = (1 - center_ratio) / 2.0 + H, W = num_pixel // width, width + mask0 = np.zeros((H, W)) + mask0[int(H * r): H - int(H * r), int(W * r): W - int(W * r)] = 1 + mask0 = mask0.reshape(-1) + + if mask is None: + mask = mask0 + else: + mask = mask * mask0 + + if mask is not None: + mask = (mask > 0.0).astype('float32') + + if (mask is None) or \ + (ratio <= 0.0) or \ + (mask.sum() == 0) or \ + ((1 - mask).sum() == 0): + return np.random.choice(num_pixel, num_sample) + + if use_bbox: + mask = mask.reshape(-1, width) + x, y = np.where(mask == 1) + mask = np.zeros_like(mask) + mask[x.min(): x.max()+1, y.min(): y.max()+1] = 1.0 + mask = mask.reshape(-1) + + try: + probs = mask * ratio / (mask.sum()) + (1 - mask) / (num_pixel - mask.sum()) * (1 - ratio) + # x = np.random.choice(num_pixel, num_sample, True, p=probs) + return np.random.choice(num_pixel, num_sample, True, p=probs) + + except Exception: + return np.random.choice(num_pixel, num_sample) + + +def colormap(dz): + # return plt.cm.jet(dz) + # return plt.cm.viridis(dz) + return plt.cm.gray(dz) + + +def recover_image(img, min_val=-1, max_val=1, width=512, bg=None, weight=None): + sizes = img.size() + height = sizes[0] // width + img = img.float().to('cpu') + + if len(sizes) == 1 and (bg is not None): + bg_mask = img.eq(bg)[:, None].type_as(img) + + img = ((img - min_val) / (max_val - min_val)).clamp(min=0, max=1) + if len(sizes) == 1: + img = torch.from_numpy(colormap(img.numpy())[:, :3]) + if weight is not None: + weight = weight.float().to('cpu') + img = img * weight[:, None] + + if bg is not None: + img = img * (1 - bg_mask) + bg_mask + img = img.reshape(height, width, -1) + return img + + +def write_images(writer, images, updates): + for tag in images: + img = images[tag] + tag, dataform = tag.split(':') + writer.add_image(tag, img, updates, dataformats=dataform) + + +def compute_psnr(p, t): + """Compute PSNR of model image predictions. + :param prediction: Return value of forward pass. + :param ground_truth: Ground truth. + :return: (psnr, ssim): tuple of floats + """ + ssim = skimage.metrics.structural_similarity(p, t, multichannel=True, data_range=1) + psnr = skimage.metrics.peak_signal_noise_ratio(p, t, data_range=1) + return ssim, psnr + + +class InfIndex(object): + + def __init__(self, index_list, shuffle=False): + self.index_list = index_list + self.size = len(index_list) + self.shuffle = shuffle + self.reset_permutation() + + def reset_permutation(self): + if self.shuffle: + self._perm = np.random.permutation(self.index_list).tolist() + else: + self._perm = copy.deepcopy(self.index_list) + + def __iter__(self): + return self + + def __next__(self): + if len(self._perm) == 0: + self.reset_permutation() + return self._perm.pop() + + def __len__(self): + return self.size + + +class Timer(StopwatchMeter): + def __enter__(self): + """Start a new timer as a context manager""" + self.start() + return self + + def __exit__(self, *exc_info): + """Stop the context manager timer""" + self.stop() + + +class GPUTimer(object): + def __enter__(self): + """Start a new timer as a context manager""" + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + self.start.record() + self.sum = 0 + return self + + def __exit__(self, *exc_info): + """Stop the context manager timer""" + self.end.record() + torch.cuda.synchronize() + self.sum = self.start.elapsed_time(self.end) / 1000. diff --git a/fairnr/data/geometry.py b/fairnr/data/geometry.py new file mode 100644 index 0000000..c6b1e2d --- /dev/null +++ b/fairnr/data/geometry.py @@ -0,0 +1,311 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn.functional as F + +from fairnr.data import data_utils as D + +INF = 1000.0 + + +def ones_like(x): + T = torch if isinstance(x, torch.Tensor) else np + return T.ones_like(x) + + +def stack(x): + T = torch if isinstance(x[0], torch.Tensor) else np + return T.stack(x) + + +def matmul(x, y): + T = torch if isinstance(x, torch.Tensor) else np + return T.matmul(x, y) + + +def cross(x, y, axis=0): + T = torch if isinstance(x, torch.Tensor) else np + return T.cross(x, y, axis) + + +def cat(x, axis=1): + if isinstance(x[0], torch.Tensor): + return torch.cat(x, dim=axis) + return np.concatenate(x, axis=axis) + + +def normalize(x, axis=-1, order=2): + if isinstance(x, torch.Tensor): + l2 = x.norm(p=order, dim=axis, keepdim=True) + return x / (l2 + 1e-8), l2 + + else: + l2 = np.linalg.norm(x, order, axis) + l2 = np.expand_dims(l2, axis) + l2[l2==0] = 1 + return x / l2, l2 + + +def parse_extrinsics(extrinsics, world2camera=True): + """ this function is only for numpy for now""" + if extrinsics.shape[0] == 3 and extrinsics.shape[1] == 4: + extrinsics = np.vstack([extrinsics, np.array([[0, 0, 0, 1.0]])]) + if extrinsics.shape[0] == 1 and extrinsics.shape[1] == 16: + extrinsics = extrinsics.reshape(4, 4) + if world2camera: + extrinsics = np.linalg.inv(extrinsics).astype(np.float32) + return extrinsics + + +def parse_intrinsics(intrinsics): + fx = intrinsics[0, 0] + fy = intrinsics[1, 1] + cx = intrinsics[0, 2] + cy = intrinsics[1, 2] + return fx, fy, cx, cy + + +def uv2cam(uv, z, intrinsics, homogeneous=False): + fx, fy, cx, cy = parse_intrinsics(intrinsics) + x_lift = (uv[0] - cx) / fx * z + y_lift = (uv[1] - cy) / fy * z + z_lift = ones_like(x_lift) * z + + if homogeneous: + return stack([x_lift, y_lift, z_lift, ones_like(z_lift)]) + else: + return stack([x_lift, y_lift, z_lift]) + + +def cam2world(xyz_cam, inv_RT): + return matmul(inv_RT, xyz_cam)[:3] + + +def r6d2mat(d6: torch.Tensor) -> torch.Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalisation per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def get_ray_direction(ray_start, uv, intrinsics, inv_RT, depths=None): + if depths is None: + depths = 1 + rt_cam = uv2cam(uv, depths, intrinsics, True) + rt = cam2world(rt_cam, inv_RT) + ray_dir, _ = normalize(rt - ray_start[:, None], axis=0) + return ray_dir + + +def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False): + """ + This function takes a vector 'camera_position' which specifies the location + of the camera in world coordinates and two vectors `at` and `up` which + indicate the position of the object and the up directions of the world + coordinate system respectively. The object is assumed to be centered at + the origin. + + The output is a rotation matrix representing the transformation + from world coordinates -> view coordinates. + + Input: + camera_position: 3 + at: 1 x 3 or N x 3 (0, 0, 0) in default + up: 1 x 3 or N x 3 (0, 1, 0) in default + """ + + if at is None: + at = torch.zeros_like(camera_position) + else: + at = torch.tensor(at).type_as(camera_position) + if up is None: + up = torch.zeros_like(camera_position) + up[2] = -1 + else: + up = torch.tensor(up).type_as(camera_position) + + z_axis = normalize(at - camera_position)[0] + x_axis = normalize(cross(up, z_axis))[0] + y_axis = normalize(cross(z_axis, x_axis))[0] + + R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1) + return R + + +def ray(ray_start, ray_dir, depths): + return ray_start + ray_dir * depths + + +def compute_normal_map(ray_start, ray_dir, depths, RT, width=512, proj=False): + # TODO: + # this function is pytorch-only (for not) + wld_coords = ray(ray_start, ray_dir, depths.unsqueeze(-1)).transpose(0, 1) + cam_coords = matmul(RT[:3, :3], wld_coords) + RT[:3, 3].unsqueeze(-1) + cam_coords = D.unflatten_img(cam_coords, width) + + # estimate local normal + shift_l = cam_coords[:, 2:, :] + shift_r = cam_coords[:, :-2, :] + shift_u = cam_coords[:, :, 2: ] + shift_d = cam_coords[:, :, :-2] + diff_hor = normalize(shift_r - shift_l, axis=0)[0][:, :, 1:-1] + diff_ver = normalize(shift_u - shift_d, axis=0)[0][:, 1:-1, :] + normal = cross(diff_hor, diff_ver) + _normal = normal.new_zeros(*cam_coords.size()) + _normal[:, 1:-1, 1:-1] = normal + _normal = _normal.reshape(3, -1).transpose(0, 1) + + # compute the projected color + if proj: + _normal = normalize(_normal, axis=1)[0] + wld_coords0 = ray(ray_start, ray_dir, 0).transpose(0, 1) + cam_coords0 = matmul(RT[:3, :3], wld_coords0) + RT[:3, 3].unsqueeze(-1) + cam_coords0 = D.unflatten_img(cam_coords0, width) + cam_raydir = normalize(cam_coords - cam_coords0, 0)[0].reshape(3, -1).transpose(0, 1) + proj_factor = (_normal * cam_raydir).sum(-1).abs() * 0.8 + 0.2 + return proj_factor + return _normal + + +def trilinear_interp(p, q, point_feats): + weights = (p * q + (1 - p) * (1 - q)).prod(dim=-1, keepdim=True) + if point_feats.dim() == 2: + point_feats = point_feats.view(point_feats.size(0), 8, -1) + point_feats = (weights * point_feats).sum(1) + return point_feats + + +# helper functions for encoder + +def padding_points(xs, pad): + if len(xs) == 1: + return xs[0].unsqueeze(0) + + maxlen = max([x.size(0) for x in xs]) + xt = xs[0].new_ones(len(xs), maxlen, xs[0].size(1)).fill_(pad) + for i in range(len(xs)): + xt[i, :xs[i].size(0)] = xs[i] + return xt + + +def pruning_points(feats, points, scores, depth=0, th=0.5): + if depth > 0: + g = int(8 ** depth) + scores = scores.reshape(scores.size(0), -1, g).sum(-1, keepdim=True) + scores = scores.expand(*scores.size()[:2], g).reshape(scores.size(0), -1) + alpha = (1 - torch.exp(-scores)) > th + feats = [feats[i][alpha[i]] for i in range(alpha.size(0))] + points = [points[i][alpha[i]] for i in range(alpha.size(0))] + points = padding_points(points, INF) + feats = padding_points(feats, 0) + return feats, points + + +def offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2): + c = torch.arange(1, 2 * bits, 2, device=point_xyz.device) + ox, oy, oz = torch.meshgrid([c, c, c]) + offset = (torch.cat([ + ox.reshape(-1, 1), + oy.reshape(-1, 1), + oz.reshape(-1, 1)], 1).type_as(point_xyz) - bits) / float(bits - 1) + if not offset_only: + return point_xyz.unsqueeze(1) + offset.unsqueeze(0).type_as(point_xyz) * quarter_voxel + return offset.type_as(point_xyz) * quarter_voxel + + +def splitting_points(point_xyz, point_feats, values, half_voxel): + # generate new centers + quarter_voxel = half_voxel * .5 + new_points = offset_points(point_xyz, quarter_voxel).reshape(-1, 3) + old_coords = (point_xyz / quarter_voxel).floor_().long() + new_coords = offset_points(old_coords).reshape(-1, 3) + new_keys0 = offset_points(new_coords).reshape(-1, 3) + + # get unique keys and inverse indices (for original key0, where it maps to in keys) + new_keys, new_feats = torch.unique(new_keys0, dim=0, sorted=True, return_inverse=True) + new_keys_idx = new_feats.new_zeros(new_keys.size(0)).scatter_( + 0, new_feats, torch.arange(new_keys0.size(0), device=new_feats.device) // 64) + + # recompute key vectors using trilinear interpolation + new_feats = new_feats.reshape(-1, 8) + + if values is not None: + p = (new_keys - old_coords[new_keys_idx]).type_as(point_xyz).unsqueeze(1) * .25 + 0.5 # (1/4 voxel size) + q = offset_points(p, .5, offset_only=True).unsqueeze(0) + 0.5 # BUG? + point_feats = point_feats[new_keys_idx] + point_feats = F.embedding(point_feats, values).view(point_feats.size(0), -1) + new_values = trilinear_interp(p, q, point_feats) + else: + new_values = None + return new_points, new_feats, new_values, new_keys + + +def expand_points(voxel_points, voxel_size): + _voxel_size = min([ + torch.sqrt(((voxel_points[j:j+1] - voxel_points[j+1:]) ** 2).sum(-1).min()) + for j in range(100)]) + depth = int(np.round(torch.log2(_voxel_size / voxel_size))) + if depth > 0: + half_voxel = _voxel_size / 2.0 + for _ in range(depth): + voxel_points = offset_points(voxel_points, half_voxel / 2.0).reshape(-1, 3) + half_voxel = half_voxel / 2.0 + + return voxel_points, depth + + +def get_edge(depth_pts, voxel_pts, voxel_size, th=0.05): + voxel_pts = offset_points(voxel_pts, voxel_size / 2.0) + diff_pts = (voxel_pts - depth_pts[:, None, :]).norm(dim=2) + ab = diff_pts.sort(dim=1)[0][:, :2] + a, b = ab[:, 0], ab[:, 1] + c = voxel_size + p = (ab.sum(-1) + c) / 2.0 + h = (p * (p - a) * (p - b) * (p - c)) ** 0.5 / c + return h < (th * voxel_size) + + +# fill-in image +def fill_in(shape, hits, input, initial=1.0): + if isinstance(initial, torch.Tensor): + output = initial.expand(*shape) + else: + output = input.new_ones(*shape) * initial + if input is not None: + if len(shape) == 1: + return output.masked_scatter(hits, input) + return output.masked_scatter(hits.unsqueeze(-1).expand(*shape), input) + return output + + +def build_easy_octree(points, half_voxel): + from fairnr.clib._ext import build_octree + coords = (points / half_voxel).floor_().long() # works easier in int space + residual = (points - coords.float() * half_voxel).mean(0, keepdim=True) + ranges = coords.max(0)[0] - coords.min(0)[0] + depths = torch.log2(ranges.max().float()).ceil_().long() - 1 + center = (coords.max(0)[0] + coords.min(0)[0]) / 2 + centers, children = build_octree(center, coords, int(depths)) + centers = centers.float() * half_voxel + residual # transform back to float + return centers, children \ No newline at end of file diff --git a/fairnr/data/shape_dataset.py b/fairnr/data/shape_dataset.py new file mode 100644 index 0000000..ace4135 --- /dev/null +++ b/fairnr/data/shape_dataset.py @@ -0,0 +1,525 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os, glob +import copy +import numpy as np +import torch +import logging + +from collections import defaultdict +from fairseq.data import FairseqDataset, BaseWrapperDataset +from . import data_utils, geometry, trajectory + + +logger = logging.getLogger(__name__) + + +class ShapeDataset(FairseqDataset): + """ + A dataset that only returns data per shape + """ + def __init__(self, + paths, + preload=True, + repeat=1, + subsample_valid=-1, + ids=None): + + if os.path.isdir(paths): + self.paths = [paths] + else: + self.paths = [line.strip() for line in open(paths)] + + self.subsample_valid = subsample_valid + self.total_num_shape = len(self.paths) + self.cache = None + self.repeat = repeat + + # -- load per-shape data + _data_per_shape = {} + _data_per_shape['shape'] = list(range(len(self.paths))) + _ixts = self.find_intrinsics() + if len(_ixts) > 0: + _data_per_shape['ixt'] = _ixts + + if self.subsample_valid > -1: + for key in _data_per_shape: + _data_per_shape[key] = _data_per_shape[key][::self.subsample_valid] + self.paths = self.paths[::self.subsample_valid] + self.total_num_shape = len(self.paths) + + # group the data.. + data_list = [] + for r in range(repeat): + # HACK: making several copies to enable multi-GPU usage. + if r == 0 and preload: + self.cache = [] + logger.info('pre-load the dataset into memory.') + + for id in range(self.total_num_shape): + element = {} + for key in _data_per_shape: + element[key] = _data_per_shape[key][id] + data_list.append(element) + + if r == 0 and preload: + self.cache += [self._load_batch(data_list, id)] + + # group the data together + self.data = data_list + + def find_intrinsics(self): + ixt_list = [] + for path in self.paths: + if os.path.exists(path + '/intrinsic.txt'): + ixt_list.append(path + '/intrinsic.txt') + elif os.path.exists(path + '/intrinsics.txt'): + ixt_list.append(path + '/intrinsics.txt') + return ixt_list + + def _load_shape(self, packed_data): + intrinsics = data_utils.load_intrinsics(packed_data['ixt']).astype('float32') \ + if packed_data.get('ixt', None) is not None else None + shape_id = packed_data['shape'] + return {'intrinsics': intrinsics, 'id': shape_id} + + def _load_batch(self, data, index): + return index, self._load_shape(data[index]) + + def __getitem__(self, index): + if self.cache is not None: + return self.cache[index % self.total_num_shape][0], \ + self.cache[index % self.total_num_shape][1] + return self._load_batch(self.data, index) + + def __len__(self): + return len(self.data) + + def num_tokens(self, index): + return 1 + + def _collater(self, samples): + results = {} + + results['shape'] = torch.from_numpy(np.array([s[0] for s in samples])) + for key in samples[0][1]: + if samples[0][1][key] is not None: + results[key] = torch.from_numpy( + np.array([s[1][key] for s in samples])) + else: + results[key] = None + return results + + def collater(self, samples): + try: + results = self._collater(samples) + except IndexError: + results = None + return results + + +class ShapeViewDataset(ShapeDataset): + """ + A dataset contains a series of images renderred offline for an object. + """ + + def __init__(self, + paths, + views, + num_view, + subsample_valid=-1, + resolution=None, + load_depth=False, + load_mask=False, + train=True, + preload=True, + repeat=1, + binarize=True, + bg_color="1,1,1", + min_color=-1, + ids=None): + + super().__init__(paths, False, repeat, subsample_valid, ids) + + self.train = train + self.load_depth = load_depth + self.load_mask = load_mask + self.views = views + self.num_view = num_view + + if isinstance(resolution, str): + self.resolution = [int(r) for r in resolution.split('x')] + else: + self.resolution = [resolution, resolution] + self.world2camera = True + self.cache_view = None + + bg_color = [float(b) for b in bg_color.split(',')] \ + if isinstance(bg_color, str) else [bg_color] + if min_color == -1: + bg_color = [b * 2 - 1 for b in bg_color] + if len(bg_color) == 1: + bg_color = bg_color + bg_color + bg_color + self.bg_color = bg_color + self.min_color = min_color + self.apply_mask_color = (self.bg_color[0] >= -1) & (self.bg_color[0] <= 1) # if need to apply + + # -- load per-view data + _data_per_view = {} + _data_per_view['rgb'] = self.find_rgb() + _data_per_view['ext'] = self.find_extrinsics() + if self.find_intrinsics_per_view() is not None: + _data_per_view['ixt_v'] = self.find_intrinsics_per_view() + if self.load_depth: + _data_per_view['dep'] = self.find_depth() + if self.load_mask: + _data_per_view['mask'] = self.find_mask() + _data_per_view['view'] = self.summary_view_data(_data_per_view) + + # group the data. + _index = 0 + for r in range(repeat): + # HACK: making several copies to enable multi-GPU usage. + if r == 0 and preload: + self.cache = [] + logger.info('pre-load the dataset into memory.') + + for id in range(self.total_num_shape): + element = {} + total_num_view = len(_data_per_view['rgb'][id]) + perm_ids = np.random.permutation(total_num_view) if train else np.arange(total_num_view) + for key in _data_per_view: + element[key] = [_data_per_view[key][id][i] for i in perm_ids] + self.data[_index].update(element) + + if r == 0 and preload: + phase_name = f"{'train' if self.train else 'valid'}" + \ + f".{self.resolution[0]}x{self.resolution[1]}" + \ + f"{'.d' if load_depth else ''}" + \ + f"{'.m' if load_mask else ''}" + \ + f"{'b' if not self.apply_mask_color else ''}" + \ + "_full" + logger.info("preload {}-{}".format(id, phase_name)) + if binarize: + cache = self._load_binary(id, np.arange(total_num_view), phase_name) + else: + cache = self._load_batch(self.data, id, np.arange(total_num_view)) + self.cache += [cache] + _index += 1 + + # group the data together + self.data_index = [] + for i, d in enumerate(self.data): + if self.train: + index_list = list(range(len(d['rgb']))) + self.data_index.append( + data_utils.InfIndex(index_list, shuffle=True) + ) + else: + copy_id = i // self.total_num_shape + index_list = [] + for j in range(copy_id * num_view, copy_id * num_view + num_view): + index_list.append(j % len(d['rgb'])) + self.data_index.append( + data_utils.InfIndex(index_list, shuffle=False) + ) + + def _load_binary(self, id, views, phase='train'): + root = os.path.dirname(self.data[id]['shape']) + npzfile = os.path.join(root, '{}.npz'.format(phase)) + try: + with np.load(npzfile, allow_pickle=True) as f: + return f['cache'] + except Exception: + cache = self._load_batch(self.data, id, views) + if data_utils.get_rank() == 0: + np.savez(npzfile, cache=cache) + return cache + + def select(self, file_list): + if len(file_list[0]) == 0: + raise FileNotFoundError + return [[files[i] for i in self.views] for files in file_list] + + def find_rgb(self): + try: + return self.select([sorted(glob.glob(path + '/rgb/*.*g')) for path in self.paths]) + except FileNotFoundError: + try: + return self.select([sorted(glob.glob(path + '/color/*.*g')) for path in self.paths]) + except FileNotFoundError: + raise FileNotFoundError("CANNOT find rendered images.") + + def find_depth(self): + try: + return self.select([sorted(glob.glob(path + '/depth/*.exr')) for path in self.paths]) + except FileNotFoundError: + raise FileNotFoundError("CANNOT find estimated depths images") + + def find_mask(self): + try: + return self.select([sorted(glob.glob(path + '/mask/*')) for path in self.paths]) + except FileNotFoundError: + raise FileNotFoundError("CANNOT find precomputed mask images") + + def find_extrinsics(self): + try: + return self.select([sorted(glob.glob(path + '/extrinsic/*.txt')) for path in self.paths]) + except FileNotFoundError: + try: + self.world2camera = False + return self.select([sorted(glob.glob(path + '/pose/*.txt')) for path in self.paths]) + except FileNotFoundError: + raise FileNotFoundError('world2camera or camera2world matrices not found.') + + def find_intrinsics_per_view(self): + try: + return self.select([sorted(glob.glob(path + '/intrinsic/*.txt')) for path in self.paths]) + except FileNotFoundError: + return None + + def summary_view_data(self, _data_per_view): + keys = [k for k in _data_per_view if _data_per_view[k] is not None] + num_of_objects = len(_data_per_view[keys[0]]) + for k in range(num_of_objects): + assert len(set([len(_data_per_view[key][k]) for key in keys])) == 1, "numer of views must be consistent." + return [list(range(len(_data_per_view[keys[0]][k]))) for k in range(num_of_objects)] + + def num_tokens(self, index): + return self.num_view + + def _load_view(self, packed_data, view_idx): + image, uv, ratio = data_utils.load_rgb( + packed_data['rgb'][view_idx], + resolution=self.resolution, + bg_color=self.bg_color, + min_rgb=self.min_color) + rgb, alpha = image[:3], image[3] # C x H x W for RGB + extrinsics = data_utils.load_matrix(packed_data['ext'][view_idx]) + extrinsics = geometry.parse_extrinsics(extrinsics, self.world2camera).astype('float32') # this is C2W + intrinsics = data_utils.load_intrinsics(packed_data['ixt_v'][view_idx]).astype('float32') \ + if packed_data.get('ixt_v', None) is not None else None + + z, mask = None, None + if packed_data.get('dep', None) is not None: + z = data_utils.load_depth(packed_data['dep'][view_idx], resolution=self.resolution) + if packed_data.get('mask', None) is not None: + mask = data_utils.load_mask(packed_data['mask'][view_idx], resolution=self.resolution) + if self.apply_mask_color: # we can also not apply mask + rgb = rgb * mask[None, :, :] + (1 - mask[None, :, :]) * np.asarray(self.bg_color)[:, None, None] + + return { + 'path': packed_data['rgb'][view_idx], + 'view': view_idx, + 'uv': uv.reshape(2, -1), + 'colors': rgb.reshape(3, -1), + 'alpha': alpha.reshape(-1), + 'extrinsics': extrinsics, + 'intrinsics': intrinsics, + 'depths': z.reshape(-1) if z is not None else None, + 'mask': mask.reshape(-1) if mask is not None else None, + 'size': np.array([rgb.shape[1], rgb.shape[2]] + ratio, dtype=np.float32) + } + + def _load_batch(self, data, index, view_ids=None): + if view_ids is None: + view_ids = [next(self.data_index[index]) for _ in range(self.num_view)] + return index, self._load_shape(data[index]), [self._load_view(data[index], view_id) for view_id in view_ids] + + def __getitem__(self, index): + if self.cache is not None: + view_ids = [next(self.data_index[index]) for _ in range(self.num_view)] + return copy.deepcopy(self.cache[index % self.total_num_shape][0]), \ + copy.deepcopy(self.cache[index % self.total_num_shape][1]), \ + [copy.deepcopy(self.cache[index % self.total_num_shape][2][i]) for i in view_ids] + return self._load_batch(self.data, index) + + def collater(self, samples): + results = super().collater(samples) + if results is None: + return results + + for key in samples[0][2][0]: + if key == 'path': + results[key] = [[d[key] for d in s[2]] for s in samples] + + elif samples[0][2][0][key] is not None: + results[key] = torch.from_numpy( + np.array([[d[key] for d in s[2]] for s in samples]) + ) + + results['colors'] = results['colors'].transpose(2, 3) + if results.get('full_rgb', None) is not None: + results['full_rgb'] = results['full_rgb'].transpose(2, 3) + return results + + +class ShapeViewStreamDataset(ShapeViewDataset): + """ + Different from ShapeViewDataset. + We merge all the views together into one dataset regardless of the shapes. + + ** HACK **: an alternative of the ShapeViewDataset + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert self.repeat == 1, "Comboned dataset does not support repeating" + assert self.num_view == 1, "StreamDataset only supports one view per shape at a time." + + # reset the data_index + self.data_index = [] + for i, d in enumerate(self.data): + for j, _ in enumerate(d['rgb']): + self.data_index.append((i, j)) # shape i, view j + + def __len__(self): + return len(self.data_index) + + def _load_batch(self, data, shape_id, view_ids): + return shape_id, self._load_shape(data[shape_id]), [self._load_view(data[shape_id], view_id) for view_id in view_ids] + + def __getitem__(self, index): + shape_id, view_id = self.data_index[index] + if self.cache is not None: + return copy.deepcopy(self.cache[shape_id % self.total_num_shape][0]), \ + copy.deepcopy(self.cache[shape_id % self.total_num_shape][1]), \ + [copy.deepcopy(self.cache[shape_id % self.total_num_shape][2][view_id])] + return self._load_batch(self.data, shape_id, [view_id]) + + def _load_binary(self, id, views, phase='train'): + root = os.path.dirname(self.data[id]['ixt']) + npzfile = os.path.join(root, '{}.npz'.format(phase)) + try: + with np.load(npzfile, allow_pickle=True) as f: + return f['cache'] + except Exception: + caches = [self._load_batch(self.data, id, view_id) for view_id in views] + cache = [caches[0][0], caches[0][1], [caches[i][2][0] for i in range(len(views))]] + + if data_utils.get_rank() == 0: + np.savez(npzfile, cache=cache) + return cache + + +class SampledPixelDataset(BaseWrapperDataset): + """ + A wrapper dataset, which split rendered images into pixels + """ + + def __init__(self, + dataset, + num_sample=None, + sampling_on_mask=1.0, + sampling_on_bbox=False, + sampling_at_center=1.0, + resolution=512, + patch_size=1): + + super().__init__(dataset) + self.num_sample = num_sample + self.sampling_on_mask = sampling_on_mask + self.sampling_on_bbox = sampling_on_bbox + self.sampling_at_center = sampling_at_center + self.patch_size = patch_size + self.res = resolution + + def __getitem__(self, index): + index, data_per_shape, data_per_view = self.dataset[index] + + # sample pixels from the original images + sample_index = [ + data_utils.sample_pixel_from_image( + data['alpha'].shape[-1], + self.num_sample, + data.get('mask', None) + if data.get('mask', None) is not None + else data.get('alpha', None), + self.sampling_on_mask, + self.sampling_on_bbox, + self.sampling_at_center, + width=int(data['size'][1]), + patch_size=self.patch_size) + for data in data_per_view + ] + + for i, data in enumerate(data_per_view): + data_per_view[i]['full_rgb'] = copy.deepcopy(data['colors']) + for key in data: + if data[key] is not None \ + and (key != 'extrinsics' and key != 'view' and key != 'full_rgb') \ + and data[key].shape[-1] > self.num_sample: + + if len(data[key].shape) == 2: + data_per_view[i][key] = data[key][:, sample_index[i]] + else: + data_per_view[i][key] = data[key][sample_index[i]] + data_per_view[i]['index'] = sample_index[i] + return index, data_per_shape, data_per_view + + def num_tokens(self, index): + return self.dataset.num_view * self.num_sample + + +class WorldCoordDataset(BaseWrapperDataset): + """ + A wrapper dataset. transform UV space into World space + """ + def __getitem__(self, index): + index, data_per_shape, data_per_view = self.dataset[index] + + def camera2world(data): + inv_RT = data['extrinsics'] + intrinsics = data_per_shape['intrinsics'] + + # get camera center (XYZ) + ray_start = inv_RT[:3, 3] + + # get points at a random depth (=1) + ray_dir = geometry.get_ray_direction( + ray_start, data['uv'], intrinsics, inv_RT, 1 + ) + + # here we still keep the original data for tracking purpose + data.update({ + 'ray_start': ray_start, + 'ray_dir': ray_dir, + }) + return data + + return index, data_per_shape, [camera2world(data) for data in data_per_view] + + def collater(self, samples): + results = self.dataset.collater(samples) + if results is None: + return results + + results['ray_start'] = results['ray_start'].unsqueeze(-2) + results['ray_dir'] = results['ray_dir'].transpose(2, 3) + results['colors'] = results['colors'].transpose(2, 3) + if results.get('full_rgb', None) is not None: + results['full_rgb'] = results['full_rgb'].transpose(2, 3) + return results + + +class InfiniteDataset(BaseWrapperDataset): + """ + A wrapper dataset which supports infnite sampling from a dataset. + No epochs in this case. + """ + def __init__(self, dataset, max_len=1000000): + super().__init__(dataset) + self.MAXLEN = max_len + + def __len__(self): + return self.MAXLEN + + def ordered_indices(self): + return np.arange(self.MAXLEN) + + def __getitem__(self, index): + actual_length = len(self.dataset) + return self.dataset[index % actual_length] \ No newline at end of file diff --git a/fairnr/data/trajectory.py b/fairnr/data/trajectory.py new file mode 100644 index 0000000..f37d833 --- /dev/null +++ b/fairnr/data/trajectory.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np + +TRAJECTORY_REGISTRY = {} + + +def register_traj(name): + def register_traj_fn(fn): + if name in TRAJECTORY_REGISTRY: + raise ValueError('Cannot register duplicate trajectory ({})'.format(name)) + TRAJECTORY_REGISTRY[name] = fn + return fn + return register_traj_fn + + +def get_trajectory(name): + return TRAJECTORY_REGISTRY.get(name, None) + + +@register_traj('circle') +def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1): + if axis == 'z': + return lambda t: [radius * np.cos(r * t+t0), radius * np.sin(r * t+t0), h] + elif axis == 'y': + return lambda t: [radius * np.cos(r * t+t0), h, radius * np.sin(r * t+t0)] + else: + return lambda t: [h, radius * np.cos(r * t+t0), radius * np.sin(r * t+t0)] + + +@register_traj('zoomin_circle') +def zoomin_circle(radius=3.5, h=0.0, axis='z', t0=0, r=1): + ra = lambda t: 0.1 + abs(4.0 - t * 2 / np.pi) + + if axis == 'z': + return lambda t: [radius * ra(t) * np.cos(r * t+t0), radius * ra(t) * np.sin(r * t+t0), h] + elif axis == 'y': + return lambda t: [radius * ra(t) * np.cos(r * t+t0), h, radius * ra(t) * np.sin(r * t+t0)] + else: + return lambda t: [h, radius * (4.2 - t * 2 / np.pi) * np.cos(r * t+t0), radius * (4.2 - t * 2 / np.pi) * np.sin(r * t+t0)] + + +@register_traj('zoomin_line') +def zoomin_line(radius=3.5, h=0.0, axis='z', t0=0, r=1, min_r=0.0001, max_r=10, step_r=10): + ra = lambda t: min_r + (max_r - min_r) * t * 180 / np.pi / step_r + + if axis == 'z': + return lambda t: [radius * ra(t) * np.cos(t0), radius * ra(t) * np.sin(t0), h * ra(t)] + elif axis == 'y': + return lambda t: [radius * ra(t) * np.cos(t0), h, radius * ra(t) * np.sin(t0)] + else: + return lambda t: [h, radius * (4.2 - t * 2 / np.pi) * np.cos(r * t+t0), radius * (4.2 - t * 2 / np.pi) * np.sin(r * t+t0)] diff --git a/fairnr/models/__init__.py b/fairnr/models/__init__.py new file mode 100644 index 0000000..e021f44 --- /dev/null +++ b/fairnr/models/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)): + model_name = file[:file.find('.py')] if file.endswith('.py') else file + module = importlib.import_module('fairnr.models.' + model_name) diff --git a/fairnr/models/fairnr_model.py b/fairnr/models/fairnr_model.py new file mode 100644 index 0000000..758663a --- /dev/null +++ b/fairnr/models/fairnr_model.py @@ -0,0 +1,205 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Base classes for various models. + +The basic principle of differentiable rendering is two components: + -- an field or so-called geometric field (GE) + -- an raymarcher or so-called differentiable ray-marcher (RM) +So it can be composed as a GERM model +""" + +import logging +import torch +import torch.nn as nn +import skimage.metrics +import imageio, os +import numpy as np + +from fairseq.models import BaseFairseqModel +from fairnr.modules.encoder import get_encoder +from fairnr.modules.field import get_field +from fairnr.modules.renderer import get_renderer +from fairnr.modules.reader import get_reader +from fairnr.data.geometry import ray, compute_normal_map, compute_normal_map +from fairnr.data.data_utils import recover_image + +logger = logging.getLogger(__name__) + + +class BaseModel(BaseFairseqModel): + """Base class""" + + ENCODER = 'abstract_encoder' + FIELD = 'abstract_field' + RAYMARCHER = 'abstract_renderer' + READER = 'abstract_reader' + + def __init__(self, args, reader, encoder, field, raymarcher): + super().__init__() + self.args = args + self.reader = reader + self.encoder = encoder + self.field = field + self.raymarcher = raymarcher + self.cache = None + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + reader = get_reader(cls.READER)(args) + encoder = get_encoder(cls.ENCODER)(args) + field = get_field(cls.FIELD)(args) + raymarcher = get_renderer(cls.RAYMARCHER)(args) + return cls(args, reader, encoder, field, raymarcher) + + @classmethod + def add_args(cls, parser): + get_reader(cls.READER).add_args(parser) + get_renderer(cls.RAYMARCHER).add_args(parser) + get_encoder(cls.ENCODER).add_args(parser) + get_field(cls.FIELD).add_args(parser) + + @property + def dummy_loss(self): + return sum([p.sum() for p in self.parameters()]) * 0.0 + + # def forward(self, ray_start, ray_dir, ray_split=1, **kwargs): + def forward(self, ray_split=1, **kwargs): + ray_start, ray_dir, uv = self.reader(**kwargs) + + if ray_split == 1: + results = self._forward(ray_start, ray_dir, **kwargs) + else: + total_rays = ray_dir.shape[2] + chunk_size = total_rays // ray_split + results = [ + self._forward( + ray_start, ray_dir[:, :, i: i+chunk_size], **kwargs) + for i in range(0, total_rays, chunk_size) + ] + results = self.merge_outputs(results) + + if results.get('sampled_uv', None) is None: + results['sampled_uv'] = uv + results['ray_start'] = ray_start + results['ray_dir'] = ray_dir + + # caching the prediction + self.cache = { + w: results[w].detach() + if isinstance(w, torch.Tensor) + else results[w] + for w in results + } + return results + + def _forward(self, ray_start, ray_dir, **kwargs): + raise NotImplementedError + + def merge_outputs(self, outputs): + new_output = {} + for key in outputs[0]: + if isinstance(outputs[0][key], torch.Tensor) and outputs[0][key].dim() > 2: + new_output[key] = torch.cat([o[key] for o in outputs], 2) + else: + new_output[key] = outputs[0][key] + return new_output + + @torch.no_grad() + def visualize(self, sample, output=None, shape=0, view=0, **kwargs): + width = int(sample['size'][shape, view][1].item()) + img_id = '{}_{}'.format(sample['shape'][shape], sample['view'][shape, view]) + + if output is None: + assert self.cache is not None, "need to run forward-pass" + output = self.cache # make sure to run forward-pass. + + images = {} + images = self._visualize(images, sample, output, [img_id, shape, view, width, 'render']) + images = self._visualize(images, sample, sample, [img_id, shape, view, width, 'target']) + images = { + tag: recover_image(width=width, **images[tag]) + for tag in images if images[tag] is not None + } + return images + + def _visualize(self, images, sample, output, state, **kwargs): + img_id, shape, view, width, name = state + if 'colors' in output and output['colors'] is not None: + images['{}_color/{}:HWC'.format(name, img_id)] ={ + 'img': output['colors'][shape, view]} + + if 'depths' in output and output['depths'] is not None: + min_depth, max_depth = output['depths'].min(), output['depths'].max() + images['{}_depth/{}:HWC'.format(name, img_id)] = { + 'img': output['depths'][shape, view], + 'min_val': min_depth, + 'max_val': max_depth} + normals = compute_normal_map( + output['ray_start'][shape, view].float(), + output['ray_dir'][shape, view].float(), + output['depths'][shape, view].float(), + sample['extrinsics'][shape, view].float().inverse(), width) + images['{}_normal/{}:HWC'.format(name, img_id)] = { + 'img': normals, 'min_val': -1, 'max_val': 1} + + return images + + def add_eval_scores(self, logging_output, sample, output, criterion, scores=['ssim', 'psnr', 'lpips'], outdir=None): + predicts, targets = output['colors'], sample['colors'] + ssims, psnrs, lpips, rmses = [], [], [], [] + + for s in range(predicts.size(0)): + for v in range(predicts.size(1)): + width = int(sample['size'][s, v][1]) + p = recover_image(predicts[s, v], width=width) + t = recover_image(targets[s, v], width=width) + pn, tn = p.numpy(), t.numpy() + p, t = p.to(predicts.device), t.to(targets.device) + + if 'ssim' in scores: + ssims += [skimage.metrics.structural_similarity(pn, tn, multichannel=True, data_range=1)] + if 'psnr' in scores: + psnrs += [skimage.metrics.peak_signal_noise_ratio(pn, tn, data_range=1)] + if 'lpips' in scores and hasattr(criterion, 'lpips'): + with torch.no_grad(): + lpips += [criterion.lpips( + 2 * p.unsqueeze(-1).permute(3,2,0,1) - 1, + 2 * t.unsqueeze(-1).permute(3,2,0,1) - 1).item()] + if 'depths' in sample: + td = sample['depths'][sample['depths'] > 0] + pd = output['depths'][sample['depths'] > 0] + rmses += [torch.sqrt(((td - pd) ** 2).mean()).item()] + + if outdir is not None: + def imsave(filename, image): + imageio.imsave(os.path.join(outdir, filename), (image * 255).astype('uint8')) + + figname = '-{:03d}_{:03d}.png'.format(sample['id'][s], sample['view'][s, v]) + imsave('output' + figname, pn) + imsave('target' + figname, tn) + imsave('normal' + figname, recover_image(compute_normal_map( + output['ray_start'][s, v].float(), output['ray_dir'][s, v].float(), + output['depths'][s, v].float(), sample['extrinsics'][s, v].float().inverse(), width=width), + min_val=-1, max_val=1, width=width).numpy()) + + if len(ssims) > 0: + logging_output['ssim_loss'] = np.mean(ssims) + if len(psnrs) > 0: + logging_output['psnr_loss'] = np.mean(psnrs) + if len(lpips) > 0: + logging_output['lpips_loss'] = np.mean(lpips) + if len(rmses) > 0: + logging_output['rmses_loss'] = np.mean(rmses) + + def adjust(self, **kwargs): + raise NotImplementedError + + @property + def text(self): + return "fairnr BaseModel" + + diff --git a/fairnr/models/multi_nsvf.py b/fairnr/models/multi_nsvf.py new file mode 100644 index 0000000..ad167fa --- /dev/null +++ b/fairnr/models/multi_nsvf.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +logger = logging.getLogger(__name__) + +import torch + +from fairseq.models import ( + register_model, + register_model_architecture +) +from fairnr.models.nsvf import NSVFModel, base_architecture + + +@register_model('multi_nsvf') +class MultiNSVFModel(NSVFModel): + + ENCODER = 'multi_sparsevoxel_encoder' + + @torch.no_grad() + def split_voxels(self): + logger.info("half the global voxel size {:.4f} -> {:.4f}".format( + self.encoder.all_voxels[0].voxel_size.item(), + self.encoder.all_voxels[0].voxel_size.item() * .5)) + self.encoder.splitting() + for id in range(len(self.encoder.all_voxels)): + self.encoder.all_voxels[id].voxel_size *= .5 + self.encoder.all_voxels[id].max_hits *= 1.5 + + @torch.no_grad() + def reduce_stepsize(self): + logger.info("reduce the raymarching step size {:.4f} -> {:.4f}".format( + self.encoder.all_voxels[0].step_size.item(), + self.encoder.all_voxels[0].step_size.item() * .5)) + for id in range(len(self.encoder.all_voxels)): + self.encoder.all_voxels[id].step_size *= .5 + + +@register_model("shared_nsvf") +class SharedNSVFModel(MultiNSVFModel): + + ENCODER = 'shared_sparsevoxel_encoder' + + +@register_model_architecture('multi_nsvf', "multi_nsvf_base") +def multi_base_architecture(args): + base_architecture(args) + + +@register_model_architecture('shared_nsvf', 'shared_nsvf') +def shared_base_architecture(args): + # encoder + args.context_embed_dim = getattr(args, "context_embed_dim", 96) + + # field + args.inputs_to_density = getattr(args, "inputs_to_density", "emb:6:32, context:0:96") + args.hypernetwork = getattr(args, "hypernetwork", False) + base_architecture(args) \ No newline at end of file diff --git a/fairnr/models/nmf.py b/fairnr/models/nmf.py new file mode 100644 index 0000000..a8c1d5a --- /dev/null +++ b/fairnr/models/nmf.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +logger = logging.getLogger(__name__) + +import torch +from fairseq.models import ( + register_model, + register_model_architecture +) +from fairnr.models.nsvf import NSVFModel, MAX_DEPTH + + +@register_model('nmf') +class NMFModel(NSVFModel): + """ + Experimental code: Neural Mesh Field + """ + ENCODER = 'triangle_mesh_encoder' + + @torch.no_grad() + def prune_voxels(self, *args, **kwargs): + pass + + @torch.no_grad() + def split_voxels(self): + logger.info("half the global cage size {:.4f} -> {:.4f}".format( + self.encoder.cage_size.item(), self.encoder.cage_size.item() * .5)) + self.encoder.cage_size *= .5 + + +@register_model_architecture("nmf", "nmf_base") +def base_architecture(args): + # parameter needs to be changed + args.max_hits = getattr(args, "max_hits", 60) + args.raymarching_stepsize = getattr(args, "raymarching_stepsize", 0.01) + + # encoder default parameter + args.voxel_embed_dim = getattr(args, "voxel_embed_dim", 0) + args.voxel_path = getattr(args, "voxel_path", None) + + # field + args.inputs_to_density = getattr(args, "inputs_to_density", "pos:10") + args.inputs_to_texture = getattr(args, "inputs_to_texture", "feat:0:256, pos:10, ray:4") + args.feature_embed_dim = getattr(args, "feature_embed_dim", 256) + args.density_embed_dim = getattr(args, "density_embed_dim", 128) + args.texture_embed_dim = getattr(args, "texture_embed_dim", 256) + + args.feature_layers = getattr(args, "feature_layers", 1) + args.texture_layers = getattr(args, "texture_layers", 3) + + args.background_stop_gradient = getattr(args, "background_stop_gradient", False) + args.background_depth = getattr(args, "background_depth", 5.0) + + args.saperate_specular = getattr(args, "saperate_specular", False) + args.specular_dropout = getattr(args, "specular_dropout", 0.0) + + # raymarcher + args.discrete_regularization = getattr(args, "discrete_regularization", False) + args.deterministic_step = getattr(args, "deterministic_step", False) + args.raymarching_tolerance = getattr(args, "raymarching_tolerance", 0) + + # reader + args.pixel_per_view = getattr(args, "pixel_per_view", 2048) + args.sampling_on_mask = getattr(args, "sampling_on_mask", 0.0) + args.sampling_at_center = getattr(args, "sampling_at_center", 1.0) + args.sampling_on_bbox = getattr(args, "sampling_on_bbox", False) + args.sampling_patch_size = getattr(args, "sampling_patch_size", 1) + args.sampling_skipping_size = getattr(args, "sampling_skipping_size", 1) + + # others + args.chunk_size = getattr(args, "chunk_size", 64) \ No newline at end of file diff --git a/fairnr/models/nsvf.py b/fairnr/models/nsvf.py new file mode 100644 index 0000000..ba7d6c5 --- /dev/null +++ b/fairnr/models/nsvf.py @@ -0,0 +1,248 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +logger = logging.getLogger(__name__) + +import cv2, math, time +import numpy as np +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq.models import ( + register_model, + register_model_architecture +) + +from fairnr.data.data_utils import Timer, GPUTimer +from fairnr.data.geometry import compute_normal_map, fill_in +from fairnr.models.fairnr_model import BaseModel + +MAX_DEPTH = 10000.0 + + +@register_model('nsvf') +class NSVFModel(BaseModel): + + READER = 'image_reader' + ENCODER = 'sparsevoxel_encoder' + FIELD = 'radiance_field' + RAYMARCHER = 'volume_rendering' + + def _forward(self, ray_start, ray_dir, **kwargs): + S, V, P, _ = ray_dir.size() + assert S == 1, "naive NeRF only supports single object." + + # voxel encoder (precompute for each voxel if needed) + encoder_states = self.encoder.precompute(**kwargs) + + # ray-voxel intersection + with GPUTimer() as timer0: + ray_start, ray_dir, intersection_outputs, hits = \ + self.encoder.ray_intersect(ray_start, ray_dir, encoder_states) + + if self.reader.no_sampling and self.training: # sample points after ray-voxel intersection + uv, size = kwargs['uv'], kwargs['size'] + mask = hits.reshape(*uv.size()[:2], uv.size(-1)) + + # sample rays based on voxel intersections + sampled_uv, sampled_masks = self.reader.sample_pixels( + uv, size, mask=mask, return_mask=True) + sampled_masks = sampled_masks.reshape(uv.size(0), -1).bool() + hits, sampled_masks = hits[sampled_masks].reshape(S, -1), sampled_masks.unsqueeze(-1) + intersection_outputs = {name: outs[sampled_masks.expand_as(outs)].reshape(S, -1, outs.size(-1)) + for name, outs in intersection_outputs.items()} + ray_start = ray_start[sampled_masks.expand_as(ray_start)].reshape(S, -1, 3) + ray_dir = ray_dir[sampled_masks.expand_as(ray_dir)].reshape(S, -1, 3) + P = hits.size(-1) // V # the number of pixels per image + else: + sampled_uv = None + + # neural ray-marching + fullsize = S * V * P + + BG_DEPTH = self.field.bg_color.depth + bg_color = self.field.bg_color(ray_dir) + + all_results = defaultdict(lambda: None) + if hits.sum() > 0: # check if ray missed everything + intersection_outputs = {name: outs[hits] for name, outs in intersection_outputs.items()} + ray_start, ray_dir = ray_start[hits], ray_dir[hits] + + # sample evalution points along the ray + samples = self.encoder.ray_sample(intersection_outputs) + encoder_states = {name: s.reshape(-1, s.size(-1)) if s is not None else None + for name, s in encoder_states.items()} + + # rendering + all_results = self.raymarcher( + self.encoder, self.field, ray_start, ray_dir, samples, encoder_states) + all_results['depths'] = all_results['depths'] + BG_DEPTH * all_results['missed'] + all_results['voxel_edges'] = self.encoder.get_edge(ray_start, ray_dir, samples, encoder_states) + all_results['voxel_depth'] = samples['sampled_point_depth'][:, 0] + + # fill out the full size + hits = hits.reshape(fullsize) + all_results['missed'] = fill_in((fullsize, ), hits, all_results['missed'], 1.0).view(S, V, P) + all_results['depths'] = fill_in((fullsize, ), hits, all_results['depths'], BG_DEPTH).view(S, V, P) + all_results['voxel_depth'] = fill_in((fullsize, ), hits, all_results['voxel_depth'], BG_DEPTH).view(S, V, P) + all_results['voxel_edges'] = fill_in((fullsize, 3), hits, all_results['voxel_edges'], 1.0).view(S, V, P, 3) + all_results['colors'] = fill_in((fullsize, 3), hits, all_results['colors'], 0.0).view(S, V, P, 3) + all_results['bg_color'] = bg_color.reshape(fullsize, 3).view(S, V, P, 3) + all_results['colors'] += all_results['missed'].unsqueeze(-1) * all_results['bg_color'] + if 'normal' in all_results: + all_results['normal'] = fill_in((fullsize, 3), hits, all_results['normal'], 0.0).view(S, V, P, 3) + + # other logs + all_results['other_logs'] = { + 'voxs_log': self.encoder.voxel_size.item(), + 'stps_log': self.encoder.step_size.item(), + 'tvox_log': timer0.sum, + 'asf_log': (all_results['ae'].float() / fullsize).item(), + 'ash_log': (all_results['ae'].float() / hits.sum()).item(), + 'nvox_log': self.encoder.num_voxels, + } + all_results['sampled_uv'] = sampled_uv + return all_results + + def _visualize(self, images, sample, output, state, **kwargs): + img_id, shape, view, width, name = state + images = super()._visualize(images, sample, output, state, **kwargs) + if 'voxel_edges' in output and output['voxel_edges'] is not None: + # voxel hitting visualization + images['{}_voxel/{}:HWC'.format(name, img_id)] = { + 'img': output['voxel_edges'][shape, view].float(), + 'min_val': 0, + 'max_val': 1, + 'weight': + compute_normal_map( + output['ray_start'][shape, view].float(), + output['ray_dir'][shape, view].float(), + output['voxel_depth'][shape, view].float(), + sample['extrinsics'][shape, view].float().inverse(), + width, proj=True) + } + if 'normal' in output and output['normal'] is not None: + images['{}_predn/{}:HWC'.format(name, img_id)] = { + 'img': output['normal'][shape, view], 'min_val': -1, 'max_val': 1} + return images + + @torch.no_grad() + def prune_voxels(self, th=0.5): + self.encoder.pruning(self.field, th) + + @torch.no_grad() + def split_voxels(self): + logger.info("half the global voxel size {:.4f} -> {:.4f}".format( + self.encoder.voxel_size.item(), self.encoder.voxel_size.item() * .5)) + self.encoder.splitting() + self.encoder.voxel_size *= .5 + self.encoder.max_hits *= 1.5 + + @torch.no_grad() + def reduce_stepsize(self): + logger.info("reduce the raymarching step size {:.4f} -> {:.4f}".format( + self.encoder.step_size.item(), self.encoder.step_size.item() * .5)) + self.encoder.step_size *= .5 + + +@register_model_architecture("nsvf", "nsvf_base") +def base_architecture(args): + # parameter needs to be changed + args.voxel_size = getattr(args, "voxel_size", 0.25) + args.max_hits = getattr(args, "max_hits", 60) + args.raymarching_stepsize = getattr(args, "raymarching_stepsize", 0.01) + args.raymarching_stepsize_ratio = getattr(args, "raymarching_stepsize_ratio", 0.0) + + # encoder default parameter + args.voxel_embed_dim = getattr(args, "voxel_embed_dim", 32) + args.voxel_path = getattr(args, "voxel_path", None) + args.initial_boundingbox = getattr(args, "initial_boundingbox", None) + + # field + args.inputs_to_density = getattr(args, "inputs_to_density", "emb:6:32") + args.inputs_to_texture = getattr(args, "inputs_to_texture", "feat:0:256, ray:4") + args.feature_embed_dim = getattr(args, "feature_embed_dim", 256) + args.density_embed_dim = getattr(args, "density_embed_dim", 128) + args.texture_embed_dim = getattr(args, "texture_embed_dim", 256) + + args.feature_layers = getattr(args, "feature_layers", 1) + args.texture_layers = getattr(args, "texture_layers", 3) + + args.background_stop_gradient = getattr(args, "background_stop_gradient", False) + args.background_depth = getattr(args, "background_depth", 5.0) + + # raymarcher + args.discrete_regularization = getattr(args, "discrete_regularization", False) + args.deterministic_step = getattr(args, "deterministic_step", False) + args.raymarching_tolerance = getattr(args, "raymarching_tolerance", 0) + args.use_octree = getattr(args, "use_octree", False) + + # reader + args.pixel_per_view = getattr(args, "pixel_per_view", 2048) + args.sampling_on_mask = getattr(args, "sampling_on_mask", 0.0) + args.sampling_at_center = getattr(args, "sampling_at_center", 1.0) + args.sampling_on_bbox = getattr(args, "sampling_on_bbox", False) + args.sampling_patch_size = getattr(args, "sampling_patch_size", 1) + args.sampling_skipping_size = getattr(args, "sampling_skipping_size", 1) + + # others + args.chunk_size = getattr(args, "chunk_size", 64) + args.valid_chunk_size = getattr(args, "valid_chunk_size", 64) + + +@register_model_architecture("nsvf", "nsvf_xyz") +def nerf2_architecture(args): + args.voxel_embed_dim = getattr(args, "voxel_embed_dim", 0) + args.inputs_to_density = getattr(args, "inputs_to_density", "pos:10") + args.inputs_to_texture = getattr(args, "inputs_to_texture", "feat:0:256, pos:10, ray:4") + base_architecture(args) + + +@register_model_architecture("nsvf", "nsvf_xyzn") +def nerf3_architecture(args): + args.inputs_to_texture = getattr(args, "inputs_to_texture", "feat:0:256, pos:10, normal:4, ray:4") + nerf2_architecture(args) + + +@register_model_architecture("nsvf", "nsvf_embn") +def nerf4_architecture(args): + args.inputs_to_density = getattr(args, "inputs_to_density", "emb:6:32") + args.inputs_to_texture = getattr(args, "inputs_to_texture", "feat:0:256, normal:4, ray:4") + base_architecture(args) + + +@register_model_architecture("nsvf", "nsvf_emb0") +def nerf5_architecture(args): + args.voxel_embed_dim = getattr(args, "voxel_embed_dim", 384) + args.inputs_to_density = getattr(args, "inputs_to_density", "emb:0:384") + args.inputs_to_texture = getattr(args, "inputs_to_texture", "feat:0:256, ray:4") + base_architecture(args) + + +@register_model('rnsvf') +class ResampledNSVFModel(NSVFModel): + + RAYMARCHER = "resampled_volume_rendering" + + +@register_model_architecture("rnsvf", "rnsvf_base") +def rnsvf_architecture(args): + base_architecture(args) + + +@register_model('disco_nsvf') +class ResampledNSVFModel(NSVFModel): + + FIELD = "disentangled_radiance_field" + + +@register_model_architecture("disco_nsvf", "disco_nsvf") +def disco_nsvf_architecture(args): + args.compressed_light_dim = getattr(args, "compressed_light_dim", 64) + nerf3_architecture(args) diff --git a/fairnr/modules/__init__.py b/fairnr/modules/__init__.py new file mode 100644 index 0000000..3d449d4 --- /dev/null +++ b/fairnr/modules/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)): + model_name = file[:file.find('.py')] if file.endswith('.py') else file + module = importlib.import_module('fairnr.modules.' + model_name) \ No newline at end of file diff --git a/fairnr/modules/encoder.py b/fairnr/modules/encoder.py new file mode 100644 index 0000000..31d3964 --- /dev/null +++ b/fairnr/modules/encoder.py @@ -0,0 +1,680 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +import sys +import os +import math +import logging +logger = logging.getLogger(__name__) + +from pathlib import Path +from fairnr.data.data_utils import load_matrix +from fairnr.data.geometry import ( + trilinear_interp, splitting_points, offset_points, + get_edge, build_easy_octree +) +from fairnr.clib import ( + aabb_ray_intersect, triangle_ray_intersect, + uniform_ray_sampling, svo_ray_intersect +) +from fairnr.modules.linear import FCBlock, Linear, Embedding + +MAX_DEPTH = 10000.0 +ENCODER_REGISTRY = {} + +def register_encoder(name): + def register_encoder_cls(cls): + if name in ENCODER_REGISTRY: + raise ValueError('Cannot register duplicate module ({})'.format(name)) + ENCODER_REGISTRY[name] = cls + return cls + return register_encoder_cls + + +def get_encoder(name): + if name not in ENCODER_REGISTRY: + raise ValueError('Cannot find module {}'.format(name)) + return ENCODER_REGISTRY[name] + + +@register_encoder('abstract_encoder') +class Encoder(nn.Module): + """ + backbone network + """ + def __init__(self, args): + super().__init__() + self.args = args + + def forward(self, **kwargs): + raise NotImplementedError + + @staticmethod + def add_args(parser): + pass + + +@register_encoder('sparsevoxel_encoder') +class SparseVoxelEncoder(Encoder): + + def __init__(self, args, voxel_path=None, bbox_path=None, shared_values=None): + super().__init__(args) + self.voxel_path = voxel_path if voxel_path is not None else args.voxel_path + self.bbox_path = bbox_path if bbox_path is not None else getattr(args, "initial_boundingbox", None) + assert (self.bbox_path is not None) or (self.voxel_path is not None), \ + "at least initial bounding box or pretrained voxel files are required." + + self.voxel_index = None + if self.voxel_path is not None: + assert os.path.exists(self.voxel_path), "voxel file must exist" + assert getattr(args, "voxel_size", None) is not None, "final voxel size is essential." + + voxel_size = args.voxel_size + + if Path(self.voxel_path).suffix == '.ply': + from plyfile import PlyData, PlyElement + plydata = PlyData.read(self.voxel_path)['vertex'] + fine_points = torch.from_numpy( + np.stack([plydata['x'], plydata['y'], plydata['z']]).astype('float32').T) + try: + self.voxel_index = torch.from_numpy(plydata['quality']).long() + except ValueError: + pass + else: + # supporting the old version voxel points + fine_points = torch.from_numpy(np.loadtxt(self.voxel_path)[:, 3:].astype('float32')) + else: + bbox = np.loadtxt(self.bbox_path) + voxel_size = bbox[-1] + fine_points = torch.from_numpy(bbox2voxels(bbox[:6], voxel_size)) + + half_voxel = voxel_size * .5 + fine_length = fine_points.size(0) + + # transform from voxel centers to voxel corners (key/values) + fine_coords = (fine_points / half_voxel).floor_().long() + fine_res = (fine_points - (fine_points / half_voxel).floor_() * half_voxel).mean(0, keepdim=True) + fine_keys0 = offset_points(fine_coords, 1.0).reshape(-1, 3) + fine_keys, fine_feats = torch.unique(fine_keys0, dim=0, sorted=True, return_inverse=True) + fine_feats = fine_feats.reshape(-1, 8) + num_keys = torch.scalar_tensor(fine_keys.size(0)).long() + + self.use_octree = getattr(args, "use_octree", False) + self.flatten_centers, self.flatten_children = None, None + + # assign values + points = fine_points + feats = fine_feats.long() + keep = fine_feats.new_ones(fine_feats.size(0)).long() + keys = fine_keys.long() + + # ray-marching step size + if getattr(args, "raymarching_stepsize_ratio", 0) > 0: + step_size = args.raymarching_stepsize_ratio * voxel_size + else: + step_size = args.raymarching_stepsize + + # register parameters + self.register_buffer("points", points) # voxel centers + self.register_buffer("feats", feats) # for each voxel, 8 vertexs + self.register_buffer("keys", keys) + self.register_buffer("keep", keep) + self.register_buffer("num_keys", num_keys) + + self.register_buffer("voxel_size", torch.scalar_tensor(voxel_size)) + self.register_buffer("step_size", torch.scalar_tensor(step_size)) + self.register_buffer("max_hits", torch.scalar_tensor(args.max_hits)) + + # set-up other hyperparameters + self.embed_dim = getattr(args, "voxel_embed_dim", None) + self.deterministic_step = getattr(args, "deterministic_step", False) + + if shared_values is None and self.embed_dim > 0: + self.values = Embedding(num_keys, self.embed_dim, None) + else: + self.values = shared_values + + def upgrade_state_dict_named(self, state_dict, name): + # update the voxel embedding shapes + if self.values is not None: + loaded_values = state_dict[name + '.values.weight'] + self.values.weight = nn.Parameter(self.values.weight.new_zeros(*loaded_values.size())) + self.values.num_embeddings = self.values.weight.size(0) + self.total_size = self.values.weight.size(0) + self.num_keys = self.num_keys * 0 + self.total_size + + if self.voxel_index is not None: + state_dict[name + '.points'] = state_dict[name + '.points'][self.voxel_index] + state_dict[name + '.feats'] = state_dict[name + '.feats'][self.voxel_index] + state_dict[name + '.keep'] = state_dict[name + '.keep'][self.voxel_index] + + # update the buffers shapes + self.points = self.points.new_zeros(*state_dict[name + '.points'].size()) + self.feats = self.feats.new_zeros(*state_dict[name + '.feats'].size()) + self.keys = self.keys.new_zeros(*state_dict[name + '.keys'].size()) + self.keep = self.keep.new_zeros(*state_dict[name + '.keep'].size()) + + @staticmethod + def add_args(parser): + parser.add_argument('--initial-boundingbox', type=str, help='the initial bounding box to initialize the model') + parser.add_argument('--voxel-size', type=float, metavar='D', help='voxel size of the input points (initial') + parser.add_argument('--voxel-path', type=str, help='path for pretrained voxel file. if provided no update') + parser.add_argument('--voxel-embed-dim', type=int, metavar='N', help="embedding size") + parser.add_argument('--deterministic-step', action='store_true', + help='if set, the model runs fixed stepsize, instead of sampling one') + parser.add_argument('--max-hits', type=int, metavar='N', help='due to restrictions we set a maximum number of hits') + parser.add_argument('--raymarching-stepsize', type=float, metavar='D', + help='ray marching step size for sparse voxels') + parser.add_argument('--raymarching-stepsize-ratio', type=float, metavar='D', + help='if the concrete step size is not given (=0), we use the ratio to the voxel size as step size.') + parser.add_argument('--use-octree', action='store_true', help='if set, instead of looping over the voxels, we build an octree.') + + def precompute(self, id=None, *args, **kwargs): + feats = self.feats[self.keep.bool()] + points = self.points[self.keep.bool()] + values = self.values.weight[: self.num_keys] if self.values is not None else None + if (self.flatten_centers is None or self.flatten_children is None) and self.use_octree: + # octree is not built. rebuild + centers, children = build_easy_octree(points, self.voxel_size / 2.0) + self.flatten_centers, self.flatten_children = centers, children + + if id is not None: + # extend size to support multi-objects + feats = feats.unsqueeze(0).expand(id.size(0), *feats.size()).contiguous() + points = points.unsqueeze(0).expand(id.size(0), *points.size()).contiguous() + values = values.unsqueeze(0).expand(id.size(0), *values.size()).contiguous() if values is not None else None + + # moving to multiple objects + if id.size(0) > 1: + feats = feats + self.num_keys * torch.arange(id.size(0), + device=feats.device, dtype=feats.dtype)[:, None, None] + encoder_states = { + 'voxel_vertex_idx': feats, + 'voxel_center_xyz': points, + 'voxel_vertex_emb': values + } + + if self.use_octree: + flatten_centers, flatten_children = self.flatten_centers.clone(), self.flatten_children.clone() + if id is not None: + flatten_centers = flatten_centers.unsqueeze(0).expand(id.size(0), *flatten_centers.size()).contiguous() + flatten_children = flatten_children.unsqueeze(0).expand(id.size(0), *flatten_children.size()).contiguous() + encoder_states['voxel_octree_center_xyz'] = flatten_centers + encoder_states['voxel_octree_children_idx'] = flatten_children + return encoder_states + + def extract_voxels(self): + voxel_index = torch.arange(self.keep.size(0), device=self.keep.device) + voxel_index = voxel_index[self.keep.bool()] + voxel_point = self.points[self.keep.bool()] + return voxel_index, voxel_point + + def get_edge(self, ray_start, ray_dir, samples, encoder_states): + outs = get_edge( + ray_start + ray_dir * samples['sampled_point_depth'][:, :1], + encoder_states['voxel_center_xyz'].reshape(-1, 3)[samples['sampled_point_voxel_idx'][:, 0].long()], + self.voxel_size).type_as(ray_dir) # get voxel edges/depth (for visualization) + outs = (1 - outs[:, None].expand(outs.size(0), 3)) * 0.7 + return outs + + def ray_intersect(self, ray_start, ray_dir, encoder_states): + point_feats = encoder_states['voxel_vertex_idx'] + point_xyz = encoder_states['voxel_center_xyz'] + S, V, P, _ = ray_dir.size() + _, H, D = point_feats.size() + + # ray-voxel intersection + ray_start = ray_start.expand_as(ray_dir).contiguous().view(S, V * P, 3).contiguous() + ray_dir = ray_dir.reshape(S, V * P, 3).contiguous() + + if self.use_octree: # ray-voxel intersection with SVO + flatten_centers = encoder_states['voxel_octree_center_xyz'] + flatten_children = encoder_states['voxel_octree_children_idx'] + pts_idx, min_depth, max_depth = svo_ray_intersect( + self.voxel_size, self.max_hits, flatten_centers, flatten_children, + ray_start, ray_dir) + else: # ray-voxel intersection with all voxels + pts_idx, min_depth, max_depth = aabb_ray_intersect( + self.voxel_size, self.max_hits, point_xyz, ray_start, ray_dir) + + # sort the depths + min_depth.masked_fill_(pts_idx.eq(-1), MAX_DEPTH) + max_depth.masked_fill_(pts_idx.eq(-1), MAX_DEPTH) + min_depth, sorted_idx = min_depth.sort(dim=-1) + max_depth = max_depth.gather(-1, sorted_idx) + pts_idx = pts_idx.gather(-1, sorted_idx) + hits = pts_idx.ne(-1).any(-1) # remove all points that completely miss the object + + if S > 1: # extend the point-index to multiple shapes (just in case) + pts_idx = (pts_idx + H * torch.arange(S, + device=pts_idx.device, dtype=pts_idx.dtype)[:, None, None] + ).masked_fill_(pts_idx.eq(-1), -1) + + intersection_outputs = { + "min_depth": min_depth, + "max_depth": max_depth, + "intersected_voxel_idx": pts_idx + } + return ray_start, ray_dir, intersection_outputs, hits + + def ray_sample(self, intersection_outputs): + min_depth = intersection_outputs['min_depth'] + max_depth = intersection_outputs['max_depth'] + pts_idx = intersection_outputs['intersected_voxel_idx'] + + max_ray_length = (max_depth.masked_fill(max_depth.eq(MAX_DEPTH), 0).max(-1)[0] - min_depth.min(-1)[0]).max() + sampled_idx, sampled_depth, sampled_dists = uniform_ray_sampling( + pts_idx, min_depth, max_depth, self.step_size, max_ray_length, + self.deterministic_step or (not self.training)) + sampled_dists = sampled_dists.clamp(min=0.0) + sampled_depth.masked_fill_(sampled_idx.eq(-1), MAX_DEPTH) + sampled_dists.masked_fill_(sampled_idx.eq(-1), 0.0) + + samples = { + 'sampled_point_depth': sampled_depth, + 'sampled_point_distance': sampled_dists, + 'sampled_point_voxel_idx': sampled_idx, + } + return samples + + @torch.enable_grad() + def forward(self, samples, encoder_states): + # encoder states + point_feats = encoder_states['voxel_vertex_idx'] + point_xyz = encoder_states['voxel_center_xyz'] + values = encoder_states['voxel_vertex_emb'] + + # ray point samples + sampled_idx = samples['sampled_point_voxel_idx'].long() + sampled_xyz = samples['sampled_point_xyz'].requires_grad_(True) + sampled_dir = samples['sampled_point_ray_direction'] + + # prepare inputs for implicit field + inputs = {'pos': sampled_xyz, 'ray': sampled_dir} + if values is not None: + # resample point features + point_xyz = F.embedding(sampled_idx, point_xyz) + point_feats = F.embedding(F.embedding(sampled_idx, point_feats), values).view(point_xyz.size(0), -1) + + # tri-linear interpolation + p = ((sampled_xyz - point_xyz) / self.voxel_size + .5).unsqueeze(1) + q = offset_points(p, .5, offset_only=True).unsqueeze(0) + .5 # BUG (FIX) + inputs.update({'emb': trilinear_interp(p, q, point_feats)}) + + return inputs + + @torch.no_grad() + def pruning(self, field_fn, th=0.5, encoder_states=None): + logger.info("pruning...") + if encoder_states is None: + encoder_states = self.precompute(id=None) + + feats = encoder_states['voxel_vertex_idx'] + points = encoder_states['voxel_center_xyz'] + values = encoder_states['voxel_vertex_emb'] + chunk_size, bits = 64, 16 + + if self.use_octree: # clean the octree, need to be rebuilt + self.flatten_centers, self.flatten_children = None, None + + def prune_once(feats, points, values): + # sample points inside voxels + sampled_xyz = offset_points(points, self.voxel_size / 2.0, bits=bits) + sampled_idx = torch.arange(points.size(0), device=points.device)[:, None].expand(*sampled_xyz.size()[:2]) + sampled_xyz, sampled_idx = sampled_xyz.reshape(-1, 3), sampled_idx.reshape(-1) + + field_inputs = self.forward( + {'sampled_point_xyz': sampled_xyz, + 'sampled_point_voxel_idx': sampled_idx, + 'sampled_point_ray_direction': None}, + {'voxel_vertex_idx': feats, + 'voxel_center_xyz': points, + 'voxel_vertex_emb': values}) # get field inputs + + # evaluation with density + field_outputs = field_fn(field_inputs, outputs=['sigma']) + free_energy = -torch.relu(field_outputs['sigma']).reshape(-1, bits ** 3).max(-1)[0] + + # prune voxels if needed + return (1 - torch.exp(free_energy)) > th + + keep = torch.cat([prune_once(feats[i: i + chunk_size], points[i: i + chunk_size], values) + for i in range(0, points.size(0), chunk_size)], 0) + self.keep.masked_scatter_(self.keep.bool(), keep.long()) + logger.info("pruning done. # of voxels before: {}, after: {} voxels".format(points.size(0), keep.sum())) + + @torch.no_grad() + def splitting(self): + logger.info("splitting...") + encoder_states = self.precompute(id=None) + feats, points, values = encoder_states['voxel_vertex_idx'], encoder_states['voxel_center_xyz'], encoder_states['voxel_vertex_emb'] + new_points, new_feats, new_values, new_keys = splitting_points(points, feats, values, self.voxel_size / 2.0) + new_num_keys = new_keys.size(0) + new_point_length = new_points.size(0) + + # set new voxel embeddings + if new_values is not None: + self.values.weight = nn.Parameter(new_values) + self.values.num_embeddings = self.values.weight.size(0) + + if self.use_octree: # clean the octree, need to be rebuilt + self.flatten_centers, self.flatten_children = None, None + + self.total_size = new_num_keys + self.num_keys = self.num_keys * 0 + self.total_size + + self.points = new_points + self.feats = new_feats + self.keep = self.keep.new_ones(new_point_length) + logger.info("splitting done. # of voxels before: {}, after: {} voxels".format(points.size(0), self.keep.sum())) + + @property + def feature_dim(self): + return self.embed_dim + + @property + def dummy_loss(self): + if self.values is not None: + return self.values.weight[0,0] * 0.0 + return 0.0 + + @property + def num_voxels(self): + return self.keep.long().sum() + + +@register_encoder('multi_sparsevoxel_encoder') +class MultiSparseVoxelEncoder(Encoder): + def __init__(self, args): + super().__init__(args) + + self.voxel_lists = open(args.voxel_path).readlines() + self.all_voxels = nn.ModuleList( + [SparseVoxelEncoder(args, vox.strip()) for vox in self.voxel_lists]) + self.cid = None + + @staticmethod + def add_args(parser): + SparseVoxelEncoder.add_args(parser) + + def precompute(self, id, *args, **kwargs): + # TODO: this is a HACK for simplicity + assert id.size(0) == 1, "for now, only works for one object" + self.cid = id[0] + return self.all_voxels[id[0]].precompute(id, *args, **kwargs) + + def ray_intersect(self, *args, **kwargs): + return self.all_voxels[self.cid].ray_intersect(*args, **kwargs) + + def ray_sample(self, *args, **kwargs): + return self.all_voxels[self.cid].ray_sample(*args, **kwargs) + + def forward(self, samples, encoder_states): + return self.all_voxels[self.cid].forward(samples, encoder_states) + + @torch.no_grad() + def pruning(self, field_fn, th=0.5): + for id in range(len(self.all_voxels)): + self.all_voxels[id].pruning(field_fn, th) + + @torch.no_grad() + def splitting(self): + for id in range(len(self.all_voxels)): + self.all_voxels[id].splitting() + + @property + def feature_dim(self): + return self.all_voxels[0].embed_dim + + @property + def dummy_loss(self): + return sum([d.dummy_loss for d in self.all_voxels]) + + @property + def voxel_size(self): + return self.all_voxels[0].voxel_size + + @property + def step_size(self): + return self.all_voxels[0].step_size + + @property + def num_voxels(self): + return self.all_voxels[self.cid].num_voxels + + +@register_encoder('shared_sparsevoxel_encoder') +class SharedSparseVoxelEncoder(Encoder): + """ + Different from MultiSparseVoxelEncoder, we assume a shared list + of voxels across all models. Usually useful to learn a video sequence. + """ + def __init__(self, args): + super().__init__(args) + + # using a shared voxel + self.voxel_path = args.voxel_path + self.num_frames = args.num_frames + self.all_voxels = [SparseVoxelEncoder(args, self.voxel_path)] + self.all_voxels = nn.ModuleList(self.all_voxels + [ + SparseVoxelEncoder(args, self.voxel_path, shared_values=self.all_voxels[0].values) + for i in range(self.num_frames - 1)]) + self.context_embed_dim = args.context_embed_dim + self.contexts = nn.Embedding(self.num_frames, self.context_embed_dim, None) + self.cid = None + + def precompute(self, id, *args, **kwargs): + # TODO: this is a HACK for simplicity + assert id.size(0) == 1, "for now, only works for one object" + self.cid = id[0] + return self.all_voxels[id[0]].precompute(id, *args, **kwargs) + + def ray_intersect(self, *args, **kwargs): + return self.all_voxels[self.cid].ray_intersect(*args, **kwargs) + + def ray_sample(self, *args, **kwargs): + return self.all_voxels[self.cid].ray_sample(*args, **kwargs) + + def forward(self, samples, encoder_states): + inputs = self.all_voxels[self.cid].forward(samples, encoder_states) + inputs.update({'context': self.contexts(self.cid).unsqueeze(0)}) + return inputs + + @torch.no_grad() + def pruning(self, field_fn, th=0.5): + for cid in range(len(self.all_voxels)): + id = torch.tensor([cid], device=self.contexts.weight.device) + self.all_voxels[cid].pruning(field_fn, th, + encoder_states={name: v[0] for name, v in self.precompute(id).items()}) + + @torch.no_grad() + def splitting(self): + logger.info("splitting...") + all_feats, all_points = [], [] + for id in range(len(self.all_voxels)): + feats, points, values = self.all_voxels[id].precompute(id=None) + all_feats.append(feats) + all_points.append(points) + feats, points = torch.cat(all_feats, 0), torch.cat(all_points, 0) + unique_feats, unique_idx = torch.unique(feats, dim=0, return_inverse=True) + unique_points = points[ + unique_feats.new_zeros(unique_feats.size(0)).scatter_( + 0, unique_idx, torch.arange(unique_idx.size(0), device=unique_feats.device) + )] + new_points, new_feats, new_values = splitting_points(unique_points, unique_feats, values, self.voxel_size / 2.0) + new_num_keys = new_values.size(0) + new_point_length = new_points.size(0) + + # set new voxel embeddings (shared voxels) + self.all_voxels[0].values.weight = nn.Parameter(new_values) + self.all_voxels[0].values.num_embeddings = new_num_keys + + for id in range(len(self.all_voxels)): + self.all_voxels[id].total_size = new_num_keys + self.all_voxels[id].num_keys = self.all_voxels[id].num_keys * 0 + self.all_voxels[id].total_size + + self.all_voxels[id].points = new_points + self.all_voxels[id].feats = new_feats + self.all_voxels[id].keep = self.all_voxels[id].keep.new_ones(new_point_length) + + logger.info("splitting done. # of voxels before: {}, after: {} voxels".format( + unique_points.size(0), new_point_length)) + + @property + def feature_dim(self): + return self.all_voxels[0].embed_dim + self.context_embed_dim + + @property + def dummy_loss(self): + return sum([d.dummy_loss for d in self.all_voxels]) + + @property + def voxel_size(self): + return self.all_voxels[0].voxel_size + + @property + def step_size(self): + return self.all_voxels[0].step_size + + @property + def num_voxels(self): + return self.all_voxels[self.cid].num_voxels + + @staticmethod + def add_args(parser): + SparseVoxelEncoder.add_args(parser) + parser.add_argument('--num-frames', type=int, help='the total number of frames') + parser.add_argument('--context-embed-dim', type=int, help='context embedding for each view') + + +@register_encoder('triangle_mesh_encoder') +class TriangleMeshEncoder(SparseVoxelEncoder): + """ + Training on fixed mesh model. Cannot pruning.. + """ + def __init__(self, args, mesh_path=None, shared_values=None): + super(SparseVoxelEncoder, self).__init__(args) + self.mesh_path = mesh_path if mesh_path is not None else args.mesh_path + assert (self.mesh_path is not None) and os.path.exists(self.mesh_path) + + import open3d as o3d + mesh = o3d.io.read_triangle_mesh(self.mesh_path) + vertices = torch.from_numpy(np.asarray(mesh.vertices, dtype=np.float32)) + faces = torch.from_numpy(np.asarray(mesh.triangles, dtype=np.long)) + + step_size = args.raymarching_stepsize + cage_size = step_size * 10 # truncated space around the triangle surfaces + self.register_buffer("cage_size", torch.scalar_tensor(cage_size)) + self.register_buffer("step_size", torch.scalar_tensor(step_size)) + self.register_buffer("max_hits", torch.scalar_tensor(args.max_hits)) + + self.vertices = nn.Parameter(vertices, requires_grad=getattr(args, "trainable_vertices", False)) + self.faces = nn.Parameter(faces, requires_grad=False) + + # set-up other hyperparameters + self.embed_dim = getattr(args, "voxel_embed_dim", None) + self.deterministic_step = getattr(args, "deterministic_step", False) + self.values = None + self.blur_ratio = getattr(args, "blur_ratio", 0.0) + + def upgrade_state_dict_named(self, state_dict, name): + pass + + @staticmethod + def add_args(parser): + parser.add_argument('--mesh-path', type=str, help='path for initial mesh file') + parser.add_argument('--voxel-embed-dim', type=int, metavar='N', help="embedding size") + parser.add_argument('--deterministic-step', action='store_true', + help='if set, the model runs fixed stepsize, instead of sampling one') + parser.add_argument('--max-hits', type=int, metavar='N', help='due to restrictions we set a maximum number of hits') + parser.add_argument('--raymarching-stepsize', type=float, metavar='D', + help='ray marching step size for sparse voxels') + parser.add_argument('--blur-ratio', type=float, default=0, + help="it is possible to shoot outside the triangle. default=0") + parser.add_argument('--trainable-vertices', action='store_true', + help='if set, making the triangle trainable. experimental code. not ideal.') + + def precompute(self, id=None, *args, **kwargs): + feats, points, values = self.faces, self.vertices, self.values + if id is not None: + # extend size to support multi-objects + feats = feats.unsqueeze(0).expand(id.size(0), *feats.size()).contiguous() + points = points.unsqueeze(0).expand(id.size(0), *points.size()).contiguous() + values = values.unsqueeze(0).expand(id.size(0), *values.size()).contiguous() if values is not None else None + + # moving to multiple objects + if id.size(0) > 1: + feats = feats + points.size(1) * torch.arange(id.size(0), + device=feats.device, dtype=feats.dtype)[:, None, None] + + encoder_states = { + 'mesh_face_vertex_idx': feats, + 'mesh_vertex_xyz': points, + } + return encoder_states + + def get_edge(self, ray_start, ray_dir, *args, **kwargs): + return torch.ones_like(ray_dir) * 0.7 + + @property + def voxel_size(self): + return self.cage_size + + def ray_intersect(self, ray_start, ray_dir, encoder_states): + point_xyz = encoder_states['mesh_vertex_xyz'] + point_feats =encoder_states['mesh_face_vertex_idx'] + + S, V, P, _ = ray_dir.size() + F, G = point_feats.size(1), point_xyz.size(1) + + # ray-voxel intersection + ray_start = ray_start.expand_as(ray_dir).contiguous().view(S, V * P, 3).contiguous() + ray_dir = ray_dir.reshape(S, V * P, 3).contiguous() + pts_idx, depth, uv = triangle_ray_intersect( + self.cage_size, self.blur_ratio, self.max_hits, point_xyz, point_feats, ray_start, ray_dir) + min_depth = (depth[:,:,:,0] + depth[:,:,:,1]).masked_fill_(pts_idx.eq(-1), MAX_DEPTH) + max_depth = (depth[:,:,:,0] + depth[:,:,:,2]).masked_fill_(pts_idx.eq(-1), MAX_DEPTH) + hits = pts_idx.ne(-1).any(-1) # remove all points that completely miss the object + + if S > 1: # extend the point-index to multiple shapes (just in case) + pts_idx = (pts_idx + G * torch.arange(S, + device=pts_idx.device, dtype=pts_idx.dtype)[:, None, None] + ).masked_fill_(pts_idx.eq(-1), -1) + + intersection_outputs = { + "min_depth": min_depth, + "max_depth": max_depth, + "intersected_voxel_idx": pts_idx + } + return ray_start, ray_dir, intersection_outputs, hits + + @torch.enable_grad() + def forward(self, samples, encoder_states): + # TODO: enable mesh embedding learning + + sampled_xyz = samples['sampled_point_xyz'].requires_grad_(True) + sampled_dir = samples['sampled_point_ray_direction'] + + # prepare inputs for implicit field + inputs = {'pos': sampled_xyz, 'ray': sampled_dir} + return inputs + + @property + def num_voxels(self): + return self.vertices.size(0) + +def bbox2voxels(bbox, voxel_size): + vox_min, vox_max = bbox[:3], bbox[3:] + steps = ((vox_max - vox_min) / voxel_size).round().astype('int64') + 1 + x, y, z = [c.reshape(-1).astype('float32') for c in np.meshgrid(np.arange(steps[0]), np.arange(steps[1]), np.arange(steps[2]))] + x, y, z = x * voxel_size + vox_min[0], y * voxel_size + vox_min[1], z * voxel_size + vox_min[2] + return np.stack([x, y, z]).T.astype('float32') + + diff --git a/fairnr/modules/field.py b/fairnr/modules/field.py new file mode 100644 index 0000000..2d65c90 --- /dev/null +++ b/fairnr/modules/field.py @@ -0,0 +1,282 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.autograd import grad +from collections import OrderedDict +from fairnr.modules.implicit import ( + ImplicitField, SignedDistanceField, + TextureField, HyperImplicitField, BackgroundField +) +from fairnr.modules.linear import NeRFPosEmbLinear + +FIELD_REGISTRY = {} + +def register_field(name): + def register_field_cls(cls): + if name in FIELD_REGISTRY: + raise ValueError('Cannot register duplicate module ({})'.format(name)) + FIELD_REGISTRY[name] = cls + return cls + return register_field_cls + + +def get_field(name): + if name not in FIELD_REGISTRY: + raise ValueError('Cannot find module {}'.format(name)) + return FIELD_REGISTRY[name] + + +@register_field('abstract_field') +class Field(nn.Module): + """ + Abstract class for implicit functions + """ + def __init__(self, args): + super().__init__() + self.args = args + + def forward(self, **kwargs): + raise NotImplementedError + + @staticmethod + def add_args(parser): + pass + + +@register_field('radiance_field') +class RaidanceField(Field): + + def __init__(self, args): + super().__init__(args) + + # additional arguments + self.chunk_size = getattr(args, "chunk_size", 256) * 256 + self.deterministic_step = getattr(args, "deterministic_step", False) + + # background field + self.min_color = getattr(args, "min_color", -1) + self.trans_bg = getattr(args, "transparent_background", "1.0,1.0,1.0") + self.sgbg = getattr(args, "background_stop_gradient", False) + self.bg_color = BackgroundField(bg_color=self.trans_bg, min_color=self.min_color, stop_grad=self.sgbg) + self.den_filters, self.den_ori_dims, self.den_input_dims = self.parse_inputs(args.inputs_to_density) + self.tex_filters, self.tex_ori_dims, self.tex_input_dims = self.parse_inputs(args.inputs_to_texture) + self.den_filters, self.tex_filters = nn.ModuleDict(self.den_filters), nn.ModuleDict(self.tex_filters) + den_input_dim, tex_input_dim = sum(self.den_input_dims), sum(self.tex_input_dims) + den_feat_dim = self.tex_input_dims[0] + + # build networks + if not getattr(args, "hypernetwork", False): + self.feature_field = ImplicitField(den_input_dim, den_feat_dim, + args.feature_embed_dim, args.feature_layers) + else: + den_contxt_dim = self.den_input_dims[-1] + self.feature_field = HyperImplicitField(den_contxt_dim, den_input_dim - den_contxt_dim, + den_feat_dim, args.feature_embed_dim, args.feature_layers) + self.predictor = SignedDistanceField(den_feat_dim, args.density_embed_dim, recurrent=False) + self.renderer = TextureField(tex_input_dim, args.texture_embed_dim, args.texture_layers) + + def parse_inputs(self, arguments): + def fillup(p): + assert len(p) > 0 + default = 'b' if (p[0] != 'ray') and (p[0] != 'normal') else 'a' + + if len(p) == 1: + return [p[0], 0, 3, default] + elif len(p) == 2: + return [p[0], int(p[1]), 3, default] + elif len(p) == 3: + return [p[0], int(p[1]), int(p[2]), default] + return [p[0], int(p[1]), int(p[2]), p[4]] + + filters, input_dims, output_dims = OrderedDict(), [], [] + for p in arguments.split(','): + name, pos_dim, base_dim, pos_type = fillup([a.strip() for a in p.strip().split(':')]) + + if pos_dim > 0: # use positional embedding + func = NeRFPosEmbLinear( + base_dim, base_dim * pos_dim * 2, + angular=(pos_type == 'a'), + no_linear=True, + cat_input=(pos_type != 'a')) + odim = func.out_dim + func.in_dim if func.cat_input else func.out_dim + + else: + func = nn.Identity() + odim = base_dim + + input_dims += [base_dim] + output_dims += [odim] + filters[name] = func + return filters, input_dims, output_dims + + @staticmethod + def add_args(parser): + parser.add_argument('--inputs-to-density', type=str, + help=""" + Types of inputs to predict the density. + Choices of types are emb or pos. + use first . to assign sinsudoal frequency. + use second : to assign the input dimension (in default 3). + use third : to set the type -> basic, angular or gaussian + Size must match + e.g. --inputs-to-density emb:6:32,pos:4 + """) + parser.add_argument('--inputs-to-texture', type=str, + help=""" + Types of inputs to predict the texture. + Choices of types are feat, emb, ray, pos or normal. + """) + + parser.add_argument('--feature-embed-dim', type=int, metavar='N', + help='field hidden dimension for FFN') + parser.add_argument('--density-embed-dim', type=int, metavar='N', + help='hidden dimension of density prediction'), + parser.add_argument('--texture-embed-dim', type=int, metavar='N', + help='hidden dimension of texture prediction') + + parser.add_argument('--input-embed-dim', type=int, metavar='N', + help='number of features for query (in default 3, xyz)') + parser.add_argument('--output-embed-dim', type=int, metavar='N', + help='number of features the field returns') + parser.add_argument('--raydir-embed-dim', type=int, metavar='N', + help='the number of dimension to encode the ray directions') + parser.add_argument('--disable-raydir', action='store_true', + help='if set, not use view direction as additional inputs') + parser.add_argument('--add-pos-embed', type=int, metavar='N', + help='using periodic activation augmentation') + parser.add_argument('--feature-layers', type=int, metavar='N', + help='number of FC layers used to encode') + parser.add_argument('--texture-layers', type=int, metavar='N', + help='number of FC layers used to predict colors') + + # specific parameters (hypernetwork does not work right now) + parser.add_argument('--hypernetwork', action='store_true', + help='use hypernetwork to model feature') + parser.add_argument('--hyper-feature-embed-dim', type=int, metavar='N', + help='feature dimension used to predict the hypernetwork. consistent with context embedding') + + # backgound parameters + parser.add_argument('--background-depth', type=float, + help='the depth of background. used for depth visualization') + parser.add_argument('--background-stop-gradient', action='store_true', + help='do not optimize the background color') + + @torch.enable_grad() # tracking the gradient in case we need to have normal at testing time. + def forward(self, inputs, outputs=['sigma', 'texture']): + filtered_inputs, context = [], None + if 'feat' not in inputs: + for i, name in enumerate(self.den_filters): + d_in, func = self.den_ori_dims[i], self.den_filters[name] + assert (name in inputs), "the encoder must contain target inputs" + assert inputs[name].size(-1) == d_in, "{} dimension must match {} v.s. {}".format( + name, inputs[name].size(-1), d_in) + if name == 'context': + assert (i == (len(self.den_filters) - 1)), "we force context as the last input" + assert inputs[name].size(0) == 1, "context is object level" + context = func(inputs[name]) + else: + filtered_inputs += [func(inputs[name])] + + filtered_inputs = torch.cat(filtered_inputs, -1) + if context is not None: + if getattr(self.args, "hypernetwork", False): + filtered_inputs = (filtered_inputs, context) + else: + filtered_inputs = (torch.cat([filtered_inputs, context.repeat(filtered_inputs.size(0), 1)], -1),) + else: + filtered_inputs = (filtered_inputs, ) + inputs['feat'] = self.feature_field(*filtered_inputs) + + if 'sigma' in outputs: + assert 'feat' in inputs, "feature must be pre-computed" + inputs['sigma'] = self.predictor(inputs['feat'])[0] + + if (('texture' in outputs) and ("normal" in self.tex_filters)) or ("normal" in outputs): + assert 'sigma' in inputs, "sigma must be pre-computed" + assert 'pos' in inputs, "position is used to compute sigma" + grad_pos, = grad( + outputs=inputs['sigma'], inputs=inputs['pos'], + grad_outputs=torch.ones_like(inputs['sigma']), + retain_graph=True) + inputs['normal'] = F.normalize(-grad_pos, p=2, dim=1) # BUG: gradient direction reversed. + + if 'texture' in outputs: + filtered_inputs = [] + for i, name in enumerate(self.tex_filters): + d_in, func = self.tex_ori_dims[i], self.tex_filters[name] + assert (name in inputs), "the encoder must contain target inputs" + assert inputs[name].size(-1) == d_in, "dimension must match" + + filtered_inputs += [func(inputs[name])] + + filtered_inputs = torch.cat(filtered_inputs, -1) + inputs['texture'] = self.renderer(filtered_inputs) + + return inputs + + + +@register_field('disentangled_radiance_field') +class DisentangledRaidanceField(RaidanceField): + + def __init__(self, args): + super().__init__(args) + + # for now we fix the input types + assert [name for name in self.tex_filters] == ['feat', 'pos', 'normal', 'ray'] + + # rebuild the renderer + self.D = getattr(args, "compressed_light_dim", 64) # D + self.renderer = nn.ModuleDict( + { + "light-transport": nn.Sequential( + ImplicitField( + in_dim=sum([self.tex_input_dims[t] for t in [2, 3]]), + out_dim=self.D * 3, + hidden_dim=args.texture_embed_dim, + num_layers=args.texture_layers, + outmost_linear=True + ), nn.Sigmoid()), # f(v, n, w) + "visibility": nn.Sequential( + ImplicitField( + in_dim=sum([self.tex_input_dims[t] for t in [0, 1]]), + out_dim=self.D, + hidden_dim=args.texture_embed_dim, + num_layers=args.texture_layers, + outmost_linear=True + ), nn.Sigmoid()), # v(x, w) + "lighting": nn.Sequential( + BackgroundField( + out_dim=self.D * 3, min_color=0 + ), nn.ReLU()) # L(w) + } + ) + + @staticmethod + def add_args(parser): + RaidanceField.add_args(parser) + parser.add_argument('---compressed-light-dim', type=int, + help='instead of sampling light directions physically, we compressed the light directions') + + @torch.enable_grad() # tracking the gradient in case we need to have normal at testing time. + def forward(self, inputs, outputs=['sigma', 'texture']): + inputs = super().forward(inputs, outputs=['sigma', 'normal']) + if 'texture' in outputs: + lt = self.renderer['light-transport']( + torch.cat([self.tex_filters['normal'](inputs['normal']), + self.tex_filters['ray'](inputs['ray'])], -1)).reshape(-1, self.D, 3) + vs = self.renderer['visibility']( + torch.cat([self.tex_filters['feat'](inputs['feat']), + self.tex_filters['pos'](inputs['pos'])], -1)).reshape(-1, self.D, 1) + light = self.renderer['lighting'](inputs['ray']).reshape(-1, self.D, 3) + texture = (lt * vs * light).mean(1) + if self.min_color == -1: + texture = 2 * texture - 1 + inputs['texture'] = texture + return inputs \ No newline at end of file diff --git a/fairnr/modules/hyper.py b/fairnr/modules/hyper.py new file mode 100644 index 0000000..0d76c33 --- /dev/null +++ b/fairnr/modules/hyper.py @@ -0,0 +1,244 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +''' +Pytorch implementations of hyper-network modules. +This code is largely adapted from +https://github.com/vsitzmann/scene-representation-networks +''' + +import torch +import torch.nn as nn +import functools + +from fairnr.modules.linear import FCBlock + + +def partialclass(cls, *args, **kwds): + + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwds) + + return NewCls + + +class LookupLayer(nn.Module): + def __init__(self, in_ch, out_ch, num_objects): + super().__init__() + + self.out_ch = out_ch + self.lookup_lin = LookupLinear(in_ch, + out_ch, + num_objects=num_objects) + self.norm_nl = nn.Sequential( + nn.LayerNorm([self.out_ch], elementwise_affine=False), + nn.ReLU(inplace=True) + ) + + def forward(self, obj_idx): + net = nn.Sequential( + self.lookup_lin(obj_idx), + self.norm_nl + ) + return net + + +class LookupFC(nn.Module): + def __init__(self, + hidden_ch, + num_hidden_layers, + num_objects, + in_ch, + out_ch, + outermost_linear=False): + super().__init__() + self.layers = nn.ModuleList() + self.layers.append(LookupLayer(in_ch=in_ch, out_ch=hidden_ch, num_objects=num_objects)) + + for i in range(num_hidden_layers): + self.layers.append(LookupLayer(in_ch=hidden_ch, out_ch=hidden_ch, num_objects=num_objects)) + + if outermost_linear: + self.layers.append(LookupLinear(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)) + else: + self.layers.append(LookupLayer(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)) + + def forward(self, obj_idx): + net = [] + for i in range(len(self.layers)): + net.append(self.layers[i](obj_idx)) + + return nn.Sequential(*net) + + +class LookupLinear(nn.Module): + def __init__(self, + in_ch, + out_ch, + num_objects): + super().__init__() + self.in_ch = in_ch + self.out_ch = out_ch + + self.hypo_params = nn.Embedding(num_objects, in_ch * out_ch + out_ch) + + for i in range(num_objects): + nn.init.kaiming_normal_(self.hypo_params.weight.data[i, :self.in_ch * self.out_ch].view(self.out_ch, self.in_ch), + a=0.0, + nonlinearity='relu', + mode='fan_in') + self.hypo_params.weight.data[i, self.in_ch * self.out_ch:].fill_(0.) + + def forward(self, obj_idx): + hypo_params = self.hypo_params(obj_idx) + + # Indices explicit to catch erros in shape of output layer + weights = hypo_params[..., :self.in_ch * self.out_ch] + biases = hypo_params[..., self.in_ch * self.out_ch:(self.in_ch * self.out_ch)+self.out_ch] + + biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch) + weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch) + + return BatchLinear(weights=weights, biases=biases) + + +class HyperLayer(nn.Module): + '''A hypernetwork that predicts a single Dense Layer, including LayerNorm and a ReLU.''' + def __init__(self, + in_ch, + out_ch, + hyper_in_ch, + hyper_num_hidden_layers, + hyper_hidden_ch): + super().__init__() + + self.hyper_linear = HyperLinear(in_ch=in_ch, + out_ch=out_ch, + hyper_in_ch=hyper_in_ch, + hyper_num_hidden_layers=hyper_num_hidden_layers, + hyper_hidden_ch=hyper_hidden_ch) + self.norm_nl = nn.Sequential( + nn.LayerNorm([out_ch], elementwise_affine=False), + nn.ReLU(inplace=True) + ) + + def forward(self, hyper_input): + ''' + :param hyper_input: input to hypernetwork. + :return: nn.Module; predicted fully connected network. + ''' + return nn.Sequential(self.hyper_linear(hyper_input), self.norm_nl) + + +class HyperFC(nn.Module): + '''Builds a hypernetwork that predicts a fully connected neural network. + ''' + def __init__(self, + hyper_in_ch, + hyper_num_hidden_layers, + hyper_hidden_ch, + hidden_ch, + num_hidden_layers, + in_ch, + out_ch, + outermost_linear=False): + super().__init__() + + PreconfHyperLinear = partialclass(HyperLinear, + hyper_in_ch=hyper_in_ch, + hyper_num_hidden_layers=hyper_num_hidden_layers, + hyper_hidden_ch=hyper_hidden_ch) + PreconfHyperLayer = partialclass(HyperLayer, + hyper_in_ch=hyper_in_ch, + hyper_num_hidden_layers=hyper_num_hidden_layers, + hyper_hidden_ch=hyper_hidden_ch) + + self.layers = nn.ModuleList() + self.layers.append(PreconfHyperLayer(in_ch=in_ch, out_ch=hidden_ch)) + + for i in range(num_hidden_layers): + self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=hidden_ch)) + + if outermost_linear: + self.layers.append(PreconfHyperLinear(in_ch=hidden_ch, out_ch=out_ch)) + else: + self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=out_ch)) + + + def forward(self, hyper_input): + ''' + :param hyper_input: Input to hypernetwork. + :return: nn.Module; Predicted fully connected neural network. + ''' + net = [] + for i in range(len(self.layers)): + net.append(self.layers[i](hyper_input)) + + return nn.Sequential(*net) + + +class BatchLinear(nn.Module): + def __init__(self, + weights, + biases): + '''Implements a batch linear layer. + + :param weights: Shape: (batch, out_ch, in_ch) + :param biases: Shape: (batch, 1, out_ch) + ''' + super().__init__() + + self.weights = weights + self.biases = biases + + def __repr__(self): + return "BatchLinear(batch=%d, in_ch=%d, out_ch=%d)"%( + self.weights.shape[0], self.weights.shape[-1], self.weights.shape[-2]) + + def forward(self, input): + output = input.matmul(self.weights.permute(*[i for i in range(len(self.weights.shape)-2)], -1, -2)) + output += self.biases + return output + + +def last_hyper_layer_init(m): + if type(m) == nn.Linear: + nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') + m.weight.data *= 1e-1 + + +class HyperLinear(nn.Module): + '''A hypernetwork that predicts a single linear layer (weights & biases).''' + def __init__(self, + in_ch, + out_ch, + hyper_in_ch, + hyper_num_hidden_layers, + hyper_hidden_ch): + + super().__init__() + self.in_ch = in_ch + self.out_ch = out_ch + + self.hypo_params = FCBlock( + in_features=hyper_in_ch, + hidden_ch=hyper_hidden_ch, + num_hidden_layers=hyper_num_hidden_layers, + out_features=(in_ch * out_ch) + out_ch, + outermost_linear=True) + self.hypo_params[-1].apply(last_hyper_layer_init) + + def forward(self, hyper_input): + hypo_params = self.hypo_params(hyper_input.cuda()) + + # Indices explicit to catch erros in shape of output layer + weights = hypo_params[..., :self.in_ch * self.out_ch] + biases = hypo_params[..., self.in_ch * self.out_ch:(self.in_ch * self.out_ch)+self.out_ch] + + biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch) + weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch) + + return BatchLinear(weights=weights, biases=biases) diff --git a/fairnr/modules/implicit.py b/fairnr/modules/implicit.py new file mode 100644 index 0000000..1f25f1c --- /dev/null +++ b/fairnr/modules/implicit.py @@ -0,0 +1,191 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq.utils import get_activation_fn +from fairnr.modules.hyper import HyperFC +from fairnr.modules.linear import ( + NeRFPosEmbLinear, FCLayer, ResFCLayer +) + + +class BackgroundField(nn.Module): + """ + Background (we assume a uniform color) + """ + def __init__(self, out_dim=3, bg_color="1.0,1.0,1.0", min_color=-1, stop_grad=False, background_depth=5.0): + super().__init__() + + if out_dim == 3: # directly model RGB + bg_color = [float(b) for b in bg_color.split(',')] if isinstance(bg_color, str) else [bg_color] + if min_color == -1: + bg_color = [b * 2 - 1 for b in bg_color] + if len(bg_color) == 1: + bg_color = bg_color + bg_color + bg_color + bg_color = torch.tensor(bg_color) + else: + bg_color = torch.ones(out_dim).uniform_() + if min_color == -1: + bg_color = bg_color * 2 - 1 + self.out_dim = out_dim + self.bg_color = nn.Parameter(bg_color, requires_grad=not stop_grad) + self.depth = background_depth + + def forward(self, x, **kwargs): + return self.bg_color.unsqueeze(0).expand( + *x.size()[:-1], self.out_dim) + + +class ImplicitField(nn.Module): + + """ + An implicit field is a neural network that outputs a vector given any query point. + """ + def __init__(self, in_dim, out_dim, hidden_dim, num_layers, outmost_linear=False, pos_proj=0): + super().__init__() + if pos_proj > 0: + new_in_dim = in_dim * 2 * pos_proj + self.nerfpos = NeRFPosEmbLinear(in_dim, new_in_dim, no_linear=True) + in_dim = new_in_dim + in_dim + else: + self.nerfpos = None + + self.net = [] + self.net.append(FCLayer(in_dim, hidden_dim)) + for _ in range(num_layers): + self.net.append(FCLayer(hidden_dim, hidden_dim)) + + if not outmost_linear: + self.net.append(FCLayer(hidden_dim, out_dim)) + else: + self.net.append(nn.Linear(hidden_dim, out_dim)) + + self.net = nn.Sequential(*self.net) + self.net.apply(self.init_weights) + + def init_weights(self, m): + if type(m) == nn.Linear: + nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') + + def forward(self, x): + if self.nerfpos is not None: + x = torch.cat([x, self.nerfpos(x)], -1) + return self.net(x) + + +class HyperImplicitField(nn.Module): + + def __init__(self, hyper_in_dim, in_dim, out_dim, hidden_dim, num_layers, outmost_linear=False, pos_proj=0): + super().__init__() + + self.hyper_in_dim = hyper_in_dim + self.in_dim = in_dim + + if pos_proj > 0: + new_in_dim = in_dim * 2 * pos_proj + self.nerfpos = NeRFPosEmbLinear(in_dim, new_in_dim, no_linear=True) + in_dim = new_in_dim + in_dim + else: + self.nerfpos = None + + self.net = HyperFC( + hyper_in_dim, + 1, 256, + hidden_dim, + num_layers, + in_dim, + out_dim, + outermost_linear=outmost_linear + ) + + def forward(self, x, c): + assert (x.size(-1) == self.in_dim) and (c.size(-1) == self.hyper_in_dim) + if self.nerfpos is not None: + x = torch.cat([x, self.nerfpos(x)], -1) + return self.net(c)(x.unsqueeze(0)).squeeze(0) + + +class SignedDistanceField(nn.Module): + + def __init__(self, in_dim, hidden_dim, recurrent=False): + super().__init__() + self.recurrent = recurrent + + if recurrent: + self.hidden_layer = nn.LSTMCell(input_size=in_dim, hidden_size=hidden_dim) + self.hidden_layer.apply(init_recurrent_weights) + lstm_forget_gate_init(self.hidden_layer) + else: + self.hidden_layer = FCLayer(in_dim, hidden_dim) + + self.output_layer = nn.Linear(hidden_dim, 1) + + def forward(self, x, state=None): + if self.recurrent: + shape = x.size() + state = self.hidden_layer(x.view(-1, shape[-1]), state) + if state[0].requires_grad: + state[0].register_hook(lambda x: x.clamp(min=-5, max=5)) + + return self.output_layer(state[0].view(*shape[:-1], -1)).squeeze(-1), state + + else: + + return self.output_layer(self.hidden_layer(x)).squeeze(-1), None + + +class TextureField(ImplicitField): + """ + Pixel generator based on 1x1 conv networks + """ + def __init__(self, in_dim, hidden_dim, num_layers, with_alpha=False): + out_dim = 3 if not with_alpha else 4 + super().__init__(in_dim, out_dim, hidden_dim, num_layers, outmost_linear=True) + + +# bash scripts/generate/generate_lego.sh $MODEL bulldozer6 2 & +class OccupancyField(ImplicitField): + """ + Occupancy Network which predicts 0~1 at every space + """ + def __init__(self, in_dim, hidden_dim, num_layers): + super().__init__(in_dim, 1, hidden_dim, num_layers, outmost_linear=True) + + def forward(self, x): + return torch.sigmoid(super().forward(x)).squeeze(-1) + + +# ------------------ # +# helper functions # +# ------------------ # +def init_recurrent_weights(self): + for m in self.modules(): + if type(m) in [nn.GRU, nn.LSTM, nn.RNN]: + for name, param in m.named_parameters(): + if 'weight_ih' in name: + nn.init.kaiming_normal_(param.data) + elif 'weight_hh' in name: + nn.init.orthogonal_(param.data) + elif 'bias' in name: + param.data.fill_(0) + + +def lstm_forget_gate_init(lstm_layer): + for name, parameter in lstm_layer.named_parameters(): + if not "bias" in name: continue + n = parameter.size(0) + start, end = n // 4, n // 2 + parameter.data[start:end].fill_(1.) + + +def clip_grad_norm_hook(x, max_norm=10): + total_norm = x.norm() + total_norm = total_norm ** (1 / 2.) + clip_coef = max_norm / (total_norm + 1e-6) + if clip_coef < 1: + return x * clip_coef \ No newline at end of file diff --git a/fairnr/modules/linear.py b/fairnr/modules/linear.py new file mode 100644 index 0000000..2eef443 --- /dev/null +++ b/fairnr/modules/linear.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq.modules import LayerNorm +from fairseq.utils import get_activation_fn + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m + + +def Embedding(num_embeddings, embedding_dim, padding_idx=None): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + return m + + +class PosEmbLinear(nn.Module): + + def __init__(self, in_dim, out_dim, no_linear=False, scale=1024): + super().__init__() + assert out_dim % (2 * in_dim) == 0, "dimension must be dividable" + half_dim = out_dim // 2 // in_dim + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + + self.emb = nn.Parameter(emb, requires_grad=False) + self.linear = Linear(out_dim, out_dim) if not no_linear else None + self.scale = scale + self.in_dim = in_dim + self.out_dim = out_dim + + def forward(self, x): + assert x.size(-1) == self.in_dim, "size must match" + sizes = x.size() + x = self.scale * x.unsqueeze(-1) @ self.emb.unsqueeze(0) + x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1) + x = x.view(*sizes[:-1], self.out_dim) + if self.linear is not None: + return self.linear(x) + return x + + +class NeRFPosEmbLinear(nn.Module): + + def __init__(self, in_dim, out_dim, angular=False, no_linear=False, cat_input=False): + super().__init__() + assert out_dim % (2 * in_dim) == 0, "dimension must be dividable" + L = out_dim // 2 // in_dim + emb = torch.exp(torch.arange(L, dtype=torch.float) * math.log(2.)) + if not angular: + emb = emb * math.pi + + self.emb = nn.Parameter(emb, requires_grad=False) + self.angular = angular + self.linear = Linear(out_dim, out_dim) if not no_linear else None + self.in_dim = in_dim + self.out_dim = out_dim + self.cat_input = cat_input + + def forward(self, x): + assert x.size(-1) == self.in_dim, "size must match" + sizes = x.size() + inputs = x.clone() + + if self.angular: + x = torch.acos(x.clamp(-1 + 1e-6, 1 - 1e-6)) + x = x.unsqueeze(-1) @ self.emb.unsqueeze(0) + x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1) + x = x.view(*sizes[:-1], self.out_dim) + if self.linear is not None: + x = self.linear(x) + if self.cat_input: + x = torch.cat([x, inputs], -1) + return x + + def extra_repr(self) -> str: + outstr = 'Sinusoidal (in={}, out={}, angular={})'.format( + self.in_dim, self.out_dim, self.angular) + if self.cat_input: + outstr = 'Cat({}, {})'.format(self.in_dim, outstr) + return outstr + + +class FCLayer(nn.Module): + """ + Reference: + https://github.com/vsitzmann/pytorch_prototyping/blob/10f49b1e7df38a58fd78451eac91d7ac1a21df64/pytorch_prototyping.py + """ + def __init__(self, in_dim, out_dim): + super().__init__() + + self.net = nn.Sequential( + nn.Linear(in_dim, out_dim), + nn.LayerNorm([out_dim]), + nn.ReLU(inplace=True)) + + def forward(self, x): + return self.net(x) + + +class FCBlock(nn.Module): + def __init__(self, + hidden_ch, + num_hidden_layers, + in_features, + out_features, + outermost_linear=False): + super().__init__() + + self.net = [] + self.net.append(FCLayer(in_features, hidden_ch)) + + for i in range(num_hidden_layers): + self.net.append(FCLayer(hidden_ch, hidden_ch)) + + if outermost_linear: + self.net.append(Linear(hidden_ch, out_features)) + else: + self.net.append(FCLayer(hidden_ch, out_features)) + + self.net = nn.Sequential(*self.net) + self.net.apply(self.init_weights) + + def __getitem__(self,item): + return self.net[item] + + def init_weights(self, m): + if type(m) == nn.Linear: + nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') + + def forward(self, input): + return self.net(input) + + +class ResFCLayer(nn.Module): + """ + Reference: + https://github.com/autonomousvision/occupancy_networks/blob/master/im2mesh/layers.py + """ + def __init__(self, in_dim, out_dim, hidden_dim, act='relu', dropout=0.0): + super().__init__() + + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + # self.layernorm = LayerNorm(out_dim) + self.nonlinear = get_activation_fn(activation=act) + self.dropout = dropout + + # Initialization (?) + nn.init.zeros_(self.fc2.weight) + + def forward(self, x): + residual = x + x = self.fc1(self.nonlinear(x)) + x = self.fc2(self.nonlinear(x)) + if self.dropout > 0: + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + # return self.layernorm(x) + return x \ No newline at end of file diff --git a/fairnr/modules/reader.py b/fairnr/modules/reader.py new file mode 100644 index 0000000..8c636d0 --- /dev/null +++ b/fairnr/modules/reader.py @@ -0,0 +1,184 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import random, os, glob + +from fairnr.data.geometry import get_ray_direction, r6d2mat + +torch.autograd.set_detect_anomaly(True) +TINY = 1e-9 +READER_REGISTRY = {} + +def register_reader(name): + def register_reader_cls(cls): + if name in READER_REGISTRY: + raise ValueError('Cannot register duplicate module ({})'.format(name)) + READER_REGISTRY[name] = cls + return cls + return register_reader_cls + + +def get_reader(name): + if name not in READER_REGISTRY: + raise ValueError('Cannot find module {}'.format(name)) + return READER_REGISTRY[name] + + +@register_reader('abstract_reader') +class Reader(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + + def forward(self, **kwargs): + raise NotImplementedError + + @staticmethod + def add_args(parser): + pass + + +@register_reader('image_reader') +class ImageReader(Reader): + """ + basic image reader + """ + def __init__(self, args): + super().__init__(args) + self.num_pixels = args.pixel_per_view + self.no_sampling = getattr(args, "no_sampling_at_reader", False) + + self.deltas = None + if getattr(args, "trainable_extrinsics", False): + self.all_data = self.find_data() + self.all_data_idx = {data_img: (s, v) + for s, data in enumerate(self.all_data) + for v, data_img in enumerate(data)} + self.deltas = nn.ParameterList([ + nn.Parameter(torch.tensor( + [[1., 0., 0., 0., 1., 0., 0., 0., 0.]]).repeat(len(data), 1)) + for data in self.all_data]) + + def find_data(self): + paths = self.args.data + if os.path.isdir(paths): + self.paths = [paths] + else: + self.paths = [line.strip() for line in open(paths)] + return [sorted(glob.glob("{}/rgb/*".format(p))) for p in self.paths] + + @staticmethod + def add_args(parser): + parser.add_argument('--pixel-per-view', type=float, metavar='N', + help='number of pixels sampled for each view') + parser.add_argument("--sampling-on-mask", nargs='?', const=0.9, type=float, + help="this value determined the probability of sampling rays on masks") + parser.add_argument("--sampling-at-center", type=float, + help="only useful for training where we restrict sampling at center of the image") + parser.add_argument("--sampling-on-bbox", action='store_true', + help="sampling points to close to the mask") + parser.add_argument("--sampling-patch-size", type=int, + help="sample pixels based on patches instead of independent pixels") + parser.add_argument("--sampling-skipping-size", type=int, + help="sample pixels if we have skipped pixels") + parser.add_argument("--no-sampling-at-reader", action='store_true', + help="do not perform sampling.") + parser.add_argument("--trainable-extrinsics", action='store_true', + help="if set, we assume extrinsics are trainable. We use 6D representations for rotation") + + def forward(self, uv, intrinsics, extrinsics, size, path=None, **kwargs): + S, V = uv.size()[:2] + if (not self.training) or self.no_sampling: + uv = uv.reshape(S, V, 2, -1, 1, 1) + flatten_uv = uv.reshape(S, V, 2, -1) + else: + uv, _ = self.sample_pixels(uv, size, **kwargs) + flatten_uv = uv.reshape(S, V, 2, -1) + + # go over all shapes + ray_start, ray_dir = [[] for _ in range(S)], [[] for _ in range(S)] + for s in range(S): + for v in range(V): + ixt = intrinsics[s] if intrinsics.dim() == 3 else intrinsics[s, v] + ext = extrinsics[s, v] + translation, rotation = ext[:3, 3], ext[:3, :3] + if (self.deltas is not None) and (path is not None): + shape_id, view_id = self.all_data_idx[path[s][v]] + delta = self.deltas[shape_id][view_id] + d_t, d_r = delta[6:], r6d2mat(delta[None, :6]).squeeze(0) + rotation = rotation @ d_r + translation = translation + d_t + ext = torch.cat([torch.cat([rotation, translation[:, None]], 1), ext[3:]], 0) + + ray_start[s] += [translation] + ray_dir[s] += [get_ray_direction(translation, flatten_uv[s, v], ixt, ext, 1)] + ray_start = torch.stack([torch.stack(r) for r in ray_start]) + ray_dir = torch.stack([torch.stack(r) for r in ray_dir]) + return ray_start.unsqueeze(-2), ray_dir.transpose(2, 3), uv + + @torch.no_grad() + def sample_pixels(self, uv, size, alpha=None, mask=None, **kwargs): + H, W = int(size[0,0,0]), int(size[0,0,1]) + S, V = uv.size()[:2] + + if mask is None: + if alpha is not None: + mask = (alpha > 0) + else: + mask = uv.new_ones(S, V, uv.size(-1)).bool() + mask = mask.float().reshape(S, V, H, W) + + if self.args.sampling_at_center < 1.0: + r = (1 - self.args.sampling_at_center) / 2.0 + mask0 = mask.new_zeros(S, V, H, W) + mask0[:, :, int(H * r): H - int(H * r), int(W * r): W - int(W * r)] = 1 + mask = mask * mask0 + + if self.args.sampling_on_bbox: + x_has_points = mask.sum(2, keepdim=True) > 0 + y_has_points = mask.sum(3, keepdim=True) > 0 + mask = (x_has_points & y_has_points).float() + + probs = mask / (mask.sum() + 1e-8) + if self.args.sampling_on_mask > 0.0: + probs = self.args.sampling_on_mask * probs + (1 - self.args.sampling_on_mask) * 1.0 / (H * W) + + num_pixels = int(self.args.pixel_per_view) + patch_size, skip_size = self.args.sampling_patch_size, self.args.sampling_skipping_size + C = patch_size * skip_size + + if C > 1: + probs = probs.reshape(S, V, H // C, C, W // C, C).sum(3).sum(-1) + num_pixels = num_pixels // patch_size // patch_size + + flatten_probs = probs.reshape(S, V, -1) + sampled_index = sampling_without_replacement(torch.log(flatten_probs+ TINY), num_pixels) + sampled_masks = torch.zeros_like(flatten_probs).scatter_(-1, sampled_index, 1).reshape(S, V, H // C, W // C) + + if C > 1: + sampled_masks = sampled_masks[:, :, :, None, :, None].repeat( + 1, 1, 1, patch_size, 1, patch_size).reshape(S, V, H // skip_size, W // skip_size) + if skip_size > 1: + full_datamask = sampled_masks.new_zeros(S, V, skip_size * skip_size, H // skip_size, W // skip_size) + full_index = torch.randint(skip_size*skip_size, (S, V)) + for i in range(S): + for j in range(V): + full_datamask[i, j, full_index[i, j]] = sampled_masks[i, j] + sampled_masks = full_datamask.reshape( + S, V, skip_size, skip_size, H // skip_size, W // skip_size).permute(0, 1, 4, 2, 5, 3).reshape(S, V, H, W) + + X, Y = uv[:,:,0].reshape(S, V, H, W), uv[:,:,1].reshape(S, V, H, W) + X = X[sampled_masks>0].reshape(S, V, 1, -1, patch_size, patch_size) + Y = Y[sampled_masks>0].reshape(S, V, 1, -1, patch_size, patch_size) + return torch.cat([X, Y], 2), sampled_masks + + +def sampling_without_replacement(logp, k): + def gumbel_like(u): + return -torch.log(-torch.log(torch.rand_like(u) + TINY) + TINY) + scores = logp + gumbel_like(logp) + return scores.topk(k, dim=-1)[1] \ No newline at end of file diff --git a/fairnr/modules/renderer.py b/fairnr/modules/renderer.py new file mode 100644 index 0000000..066c0c6 --- /dev/null +++ b/fairnr/modules/renderer.py @@ -0,0 +1,262 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairnr.modules.linear import FCLayer +from fairnr.data.geometry import ray +from torchsearchsorted import searchsorted + +MAX_DEPTH = 10000.0 +RENDERER_REGISTRY = {} + +def register_renderer(name): + def register_renderer_cls(cls): + if name in RENDERER_REGISTRY: + raise ValueError('Cannot register duplicate module ({})'.format(name)) + RENDERER_REGISTRY[name] = cls + return cls + return register_renderer_cls + + +def get_renderer(name): + if name not in RENDERER_REGISTRY: + raise ValueError('Cannot find module {}'.format(name)) + return RENDERER_REGISTRY[name] + + +@register_renderer('abstract_renderer') +class Renderer(nn.Module): + """ + Abstract class for ray marching + """ + def __init__(self, args): + super().__init__() + self.args = args + + def forward(self, **kwargs): + raise NotImplementedError + + @staticmethod + def add_args(parser): + pass + + +@register_renderer('volume_rendering') +class VolumeRenderer(Renderer): + + def __init__(self, args): + super().__init__(args) + self.chunk_size = 1024 * getattr(args, "chunk_size", 64) + self.valid_chunk_size = 1024 * getattr(args, "valid_chunk_size", self.chunk_size // 1024) + self.discrete_reg = getattr(args, "discrete_regularization", False) + self.raymarching_tolerance = getattr(args, "raymarching_tolerance", 0.0) + + @staticmethod + def add_args(parser): + # ray-marching parameters + parser.add_argument('--discrete-regularization', action='store_true', + help='if set, a zero mean unit variance gaussian will be added to encougrage discreteness') + + # additional arguments + parser.add_argument('--chunk-size', type=int, metavar='D', + help='set chunks to go through the network (~K forward passes). trade time for memory. ') + parser.add_argument('--valid-chunk-size', type=int, metavar='D', + help='chunk size used when no training. In default the same as chunk-size.') + parser.add_argument('--raymarching-tolerance', type=float, default=0) + + def forward_once( + self, input_fn, field_fn, ray_start, ray_dir, samples, encoder_states, + early_stop=None, output_types=['sigma', 'texture'] + ): + """ + chunks: set > 1 if out-of-memory. it can save some memory by time. + """ + sampled_depth = samples['sampled_point_depth'] + sampled_dists = samples['sampled_point_distance'] + sampled_idx = samples['sampled_point_voxel_idx'].long() + + # only compute when the ray hits + sample_mask = sampled_idx.ne(-1) + if early_stop is not None: + sample_mask = sample_mask & (~early_stop.unsqueeze(-1)) + if sample_mask.sum() == 0: # miss everything skip + return None, 0 + + sampled_xyz = ray(ray_start.unsqueeze(1), ray_dir.unsqueeze(1), sampled_depth.unsqueeze(2)) + sampled_dir = ray_dir.unsqueeze(1).expand(*sampled_depth.size(), ray_dir.size()[-1]) + samples['sampled_point_xyz'] = sampled_xyz + samples['sampled_point_ray_direction'] = sampled_dir + + # apply mask + samples = {name: s[sample_mask] for name, s in samples.items()} + + # get encoder features as inputs + field_inputs = input_fn(samples, encoder_states) + + # forward implicit fields + field_outputs = field_fn(field_inputs, outputs=output_types) + outputs = {'sample_mask': sample_mask} + + def masked_scatter(mask, x): + B, K = mask.size() + if x.dim() == 1: + return x.new_zeros(B, K).masked_scatter(mask, x) + return x.new_zeros(B, K, x.size(-1)).masked_scatter( + mask.unsqueeze(-1).expand(B, K, x.size(-1)), x) + + # post processing + if 'sigma' in field_outputs: + sigma, sampled_dists= field_outputs['sigma'], samples['sampled_point_distance'] + noise = 0 if not self.discrete_reg and (not self.training) else torch.zeros_like(sigma).normal_() + free_energy = torch.relu(noise + sigma) * sampled_dists + # (optional) free_energy = (F.elu(sigma - 3, alpha=1) + 1) * dists + outputs['free_energy'] = masked_scatter(sample_mask, free_energy) + if 'texture' in field_outputs: + outputs['texture'] = masked_scatter(sample_mask, field_outputs['texture']) + if 'normal' in field_outputs: + outputs['normal'] = masked_scatter(sample_mask, field_outputs['normal']) + return outputs, sample_mask.sum() + + def forward_chunk( + self, input_fn, field_fn, ray_start, ray_dir, samples, encoder_states, + gt_depths=None, output_types=['sigma', 'texture'], global_weights=None, + ): + sampled_depth = samples['sampled_point_depth'] + sampled_idx = samples['sampled_point_voxel_idx'].long() + + tolerance = self.raymarching_tolerance + chunk_size = self.chunk_size if self.training else self.valid_chunk_size + early_stop = None + if tolerance > 0: + tolerance = -math.log(tolerance) + + hits = sampled_idx.ne(-1).long() + outputs = defaultdict(lambda: []) + size_so_far, start_step = 0, 0 + accumulated_free_energy = 0 + accumulated_evaluations = 0 + for i in range(hits.size(1) + 1): + if ((i == hits.size(1)) or (size_so_far + hits[:, i].sum() > chunk_size)) and (i > start_step): + _outputs, _evals = self.forward_once( + input_fn, field_fn, + ray_start, ray_dir, + {name: s[:, start_step: i] + for name, s in samples.items()}, + encoder_states, + early_stop=early_stop, + output_types=output_types) + if _outputs is not None: + accumulated_evaluations += _evals + + if 'free_energy' in _outputs: + accumulated_free_energy += _outputs['free_energy'].sum(1) + if tolerance > 0: + early_stop = accumulated_free_energy > tolerance + hits[early_stop] *= 0 + + for key in _outputs: + outputs[key] += [_outputs[key]] + else: + for key in outputs: + outputs[key] += [outputs[key][-1].new_zeros( + outputs[key][-1].size(0), + sampled_depth[:, start_step: i].size(1), + *outputs[key][-1].size()[2:] + )] + start_step, size_so_far = i, 0 + + if (i < hits.size(1)): + size_so_far += hits[:, i].sum() + + outputs = {key: torch.cat(outputs[key], 1) for key in outputs} + results = {} + + if 'free_energy' in outputs: + free_energy = outputs['free_energy'] + shifted_free_energy = torch.cat([free_energy.new_zeros(sampled_depth.size(0), 1), free_energy[:, :-1]], dim=-1) # shift one step + a = 1 - torch.exp(-free_energy.float()) # probability of it is not empty here + b = torch.exp(-torch.cumsum(shifted_free_energy.float(), dim=-1)) # probability of everything is empty up to now + probs = (a * b).type_as(free_energy) # probability of the ray hits something here + else: + probs = outputs['sample_mask'].type_as(sampled_depth) / sampled_depth.size(-1) # assuming a uniform distribution + + if global_weights is not None: + probs = probs * global_weights + + depth = (sampled_depth * probs).sum(-1) + missed = 1 - probs.sum(-1) + results.update({'probs': probs, 'depths': depth, 'missed': missed, 'ae': accumulated_evaluations}) + + if 'texture' in outputs: + results['colors'] = (outputs['texture'] * probs.unsqueeze(-1)).sum(-2) + if 'normal' in outputs: + results['normal'] = (outputs['normal'] * probs.unsqueeze(-1)).sum(-2) + return results + + def forward(self, input_fn, field_fn, ray_start, ray_dir, samples, *args, **kwargs): + chunk_size = self.chunk_size if self.training else self.valid_chunk_size + if ray_start.size(0) <= chunk_size: + return self.forward_chunk(input_fn, field_fn, ray_start, ray_dir, samples, *args, **kwargs) + + # the number of rays is larger than maximum forward passes. pre-chuncking.. + results = [ + self.forward_chunk(input_fn, field_fn, + ray_start[i: i+chunk_size], ray_dir[i: i+chunk_size], + {name: s[i: i+chunk_size] for name, s in samples.items()}, *args, **kwargs) + for i in range(0, ray_start.size(0), chunk_size) + ] + return {name: torch.cat([r[name] for r in results], 0) + if results[0][name].dim() > 0 else sum([r[name] for r in results]) + for name in results[0]} + + +@register_renderer("resampled_volume_rendering") +class ResampledVolumeRenderer(VolumeRenderer): + + def forward_chunk(self, input_fn, field_fn, ray_start, ray_dir, samples, encoder_states, gt_depths=None): + results0 = super().forward_chunk( + input_fn, field_fn, ray_start, ray_dir, samples, + encoder_states, output_types=['sigma']) # infer probability + # resample based on piecewise distribution with inverse CDF (only sample non-missing points) + new_samples = resample_pdf(results0['probs'], samples, n_samples=16, deterministic=True) + return super().forward_chunk(input_fn, field_fn, ray_start, ray_dir, new_samples, + encoder_states, output_types=['texture'], + global_weights=results0['probs'].sum(-1, keepdims=True)) # get texture + + +def resample_pdf(probs, samples, n_samples=32, deterministic=False): + sampled_depth, sampled_idx, sampled_dists = samples + + # compute CDF + pdf = probs / (probs.sum(-1, keepdims=True) + 1e-7) + cdf = torch.cat([torch.zeros_like(pdf[...,:1]), torch.cumsum(pdf, -1)], -1) + + # generate random samples + z = torch.arange(n_samples, device=cdf.device, dtype=cdf.dtype).expand( + cdf.size(0), n_samples).contiguous() + if deterministic: + z = z + 0.5 + else: + z = z + z.clone().uniform_() + z = z / float(n_samples) + + # inverse transform sampling + inds = searchsorted(cdf, z) - 1 + inds_miss = inds.eq(sampled_idx.size(1)) + inds_safe = inds.clamp(max=sampled_idx.size(1)-1) + resampled_below, resampled_above = cdf.gather(1, inds_safe), cdf.gather(1, inds_safe + 1) + resampled_idx = sampled_idx.gather(1, inds_safe).masked_fill(inds_miss, -1) + resampled_depth = sampled_depth.gather(1, inds_safe).masked_fill(inds_miss, MAX_DEPTH) + resampled_dists = sampled_dists.gather(1, inds_safe).masked_fill(inds_miss, 0.0) + + # reparameterization + resampled_depth = ((z - resampled_below) / (resampled_above - resampled_below + 1e-7) - 0.5) * resampled_dists + resampled_depth + return resampled_depth, resampled_idx, resampled_depth diff --git a/fairnr/options.py b/fairnr/options.py new file mode 100644 index 0000000..1a77444 --- /dev/null +++ b/fairnr/options.py @@ -0,0 +1,50 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import sys +import torch + + +from fairseq import options + + +def parse_args_and_arch(*args, **kwargs): + return options.parse_args_and_arch(*args, **kwargs) + + +def get_rendering_parser(default_task="single_object_rendering"): + parser = options.get_parser("Rendering", default_task) + options.add_dataset_args(parser, gen=True) + add_rendering_args(parser) + return parser + + +def add_rendering_args(parser): + group = parser.add_argument_group("Rendering") + options.add_common_eval_args(group) + group.add_argument("--render-beam", default=5, type=int, metavar="N", + help="beam size for parallel rendering") + group.add_argument("--render-resolution", default="512x512", type=str, metavar="N", help='if provide two numbers, means H x W') + group.add_argument("--render-angular-speed", default=1, type=float, metavar="D", + help="angular speed when rendering around the object") + group.add_argument("--render-num-frames", default=500, type=int, metavar="N") + group.add_argument("--render-path-style", default="circle", choices=["circle", "zoomin_circle", "zoomin_line"], type=str) + group.add_argument("--render-path-args", default="{'radius': 2.5, 'h': 0.0}", + help="specialized arguments for rendering paths") + group.add_argument("--render-output", default=None, type=str) + group.add_argument("--render-at-vector", default="(0,0,0)", type=str) + group.add_argument("--render-up-vector", default="(0,0,-1)", type=str) + group.add_argument("--render-output-types", nargs="+", type=str, default=["rgb"], + choices=["target", "color", "depth", "normal", "voxel", "predn"]) + group.add_argument("--render-raymarching-steps", default=None, type=int) + group.add_argument("--render-save-fps", default=24, type=int) + group.add_argument("--render-combine-output", action='store_true', + help="if set, concat the images into one file.") + group.add_argument("--render-camera-poses", default=None, type=str, + help="text file saved for the testing trajectories") + group.add_argument("--render-camera-intrinsics", default=None, type=str) + group.add_argument("--render-views", type=str, default=None, + help="views sampled for rendering, you can set specific view id, or a range") \ No newline at end of file diff --git a/fairnr/renderer.py b/fairnr/renderer.py new file mode 100644 index 0000000..f321052 --- /dev/null +++ b/fairnr/renderer.py @@ -0,0 +1,235 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This file is to simulate "generator" in fairseq +""" + +import os, tempfile, shutil, glob +import time +import torch +import numpy as np +import logging +import imageio + +from torchvision.utils import save_image +from fairnr.data import trajectory, geometry, data_utils +from fairseq.meters import StopwatchMeter +from fairnr.data.data_utils import recover_image, get_uv, parse_views +from pathlib import Path + +logger = logging.getLogger(__name__) + + +class NeuralRenderer(object): + + def __init__(self, + resolution="512x512", + frames=501, + speed=5, + raymarching_steps=None, + path_gen=None, + beam=10, + at=(0,0,0), + up=(0,1,0), + output_dir=None, + output_type=None, + fps=24, + test_camera_poses=None, + test_camera_intrinsics=None, + test_camera_views=None): + + self.frames = frames + self.speed = speed + self.raymarching_steps = raymarching_steps + self.path_gen = path_gen + + if isinstance(resolution, str): + self.resolution = [int(r) for r in resolution.split('x')] + else: + self.resolution = [resolution, resolution] + + self.beam = beam + self.output_dir = output_dir + self.output_type = output_type + self.at = at + self.up = up + self.fps = fps + + if self.path_gen is None: + self.path_gen = trajectory.circle() + if self.output_type is None: + self.output_type = ["rgb"] + + if test_camera_intrinsics is not None: + self.test_int = data_utils.load_intrinsics(test_camera_intrinsics) + else: + self.test_int = None + + self.test_frameids = None + if test_camera_poses is not None: + if os.path.isdir(test_camera_poses): + self.test_poses = [ + np.loadtxt(f)[None, :, :] for f in sorted(glob.glob(test_camera_poses + "/*.txt"))] + self.test_poses = np.concatenate(self.test_poses, 0) + else: + self.test_poses = data_utils.load_matrix(test_camera_poses) + if self.test_poses.shape[1] == 17: + self.test_frameids = self.test_poses[:, -1].astype(np.int32) + self.test_poses = self.test_poses[:, :-1] + self.test_poses = self.test_poses.reshape(-1, 4, 4) + + if test_camera_views is not None: + render_views = parse_views(test_camera_views) + self.test_poses = np.stack([self.test_poses[r] for r in render_views]) + + else: + self.test_poses = None + + def generate_rays(self, t, intrinsics, img_size, inv_RT=None, action='none'): + if inv_RT is None: + cam_pos = torch.tensor(self.path_gen(t * self.speed / 180 * np.pi), + device=intrinsics.device, dtype=intrinsics.dtype) + cam_rot = geometry.look_at_rotation(cam_pos, at=self.at, up=self.up, inverse=True, cv=True) + + inv_RT = cam_pos.new_zeros(4, 4) + inv_RT[:3, :3] = cam_rot + inv_RT[:3, 3] = cam_pos + inv_RT[3, 3] = 1 + else: + inv_RT = torch.from_numpy(inv_RT).type_as(intrinsics) + + h, w, rh, rw = img_size[0], img_size[1], img_size[2], img_size[3] + if self.test_int is not None: + uv = torch.from_numpy(get_uv(h, w, h, w)[0]).type_as(intrinsics) + intrinsics = self.test_int + else: + uv = torch.from_numpy(get_uv(h * rh, w * rw, h, w)[0]).type_as(intrinsics) + + uv = uv.reshape(2, -1) + return uv, inv_RT + + def parse_sample(self,sample): + if len(sample) == 1: + return sample[0], 0, self.frames + elif len(sample) == 2: + return sample[0], sample[1], self.frames + elif len(sample) == 3: + return sample[0], sample[1], sample[2] + else: + raise NotImplementedError + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + model = models[0] + model.eval() + + logger.info("rendering starts. {}".format(model.text)) + output_path = self.output_dir + image_names = [] + sample, step, frames = self.parse_sample(sample) + + # fix the rendering size + a = sample['size'][0,0,0] / self.resolution[0] + b = sample['size'][0,0,1] / self.resolution[1] + sample['size'][:, :, 0] /= a + sample['size'][:, :, 1] /= b + sample['size'][:, :, 2] *= a + sample['size'][:, :, 3] *= b + + for shape in range(sample['shape'].size(0)): + max_step = step + frames + while step < max_step: + next_step = min(step + self.beam, max_step) + uv, inv_RT = zip(*[ + self.generate_rays( + k, + sample['intrinsics'][shape], + sample['size'][shape, 0], + self.test_poses[k] if self.test_poses is not None else None) + for k in range(step, next_step) + ]) + if self.test_frameids is not None: + assert next_step - step == 1 + ids = torch.tensor(self.test_frameids[step: next_step]).type_as(sample['id']) + else: + ids = sample['id'][shape:shape+1] + + real_images = sample['full_rgb'] if 'full_rgb' in sample else sample['colors'] + real_images = real_images.transpose(2, 3) if real_images.size(-1) != 3 else real_images + + _sample = { + 'id': ids, + 'colors': torch.cat([real_images[shape:shape+1] for _ in range(step, next_step)], 1), + 'intrinsics': sample['intrinsics'][shape:shape+1], + 'extrinsics': torch.stack(inv_RT, 0).unsqueeze(0), + 'uv': torch.stack(uv, 0).unsqueeze(0), + 'shape': sample['shape'][shape:shape+1], + 'view': torch.arange( + step, next_step, + device=sample['shape'].device).unsqueeze(0), + 'size': torch.cat([sample['size'][shape:shape+1] for _ in range(step, next_step)], 1), + 'step': step + } + with data_utils.GPUTimer() as timer: + outs = model(**_sample) + logger.info("rendering frame={}\ttotal time={:.4f}\tvoxel={:.4f}".format(step, timer.sum, outs['other_logs']['tvox_log'])) + + for k in range(step, next_step): + images = model.visualize(_sample, None, 0, k-step) + image_name = "{:04d}".format(k) + + for key in images: + name, type = key.split('/')[0].split('_') + if type in self.output_type and name == 'render': + prefix = os.path.join(output_path, type) + Path(prefix).mkdir(parents=True, exist_ok=True) + image = images[key].permute(2, 0, 1) \ + if images[key].dim() == 3 else torch.stack(3*[images[key]], 0) + save_image(image, os.path.join(prefix, image_name + '.png'), format=None) + image_names.append(os.path.join(prefix, image_name + '.png')) + + # save pose matrix + prefix = os.path.join(output_path, 'pose') + Path(prefix).mkdir(parents=True, exist_ok=True) + pose = self.test_poses[k] if self.test_poses is not None else inv_RT[k-step].cpu().numpy() + np.savetxt(os.path.join(prefix, image_name + '.txt'), pose) + + step = next_step + + logger.info("done") + return step, image_names + + def save_images(self, output_files, steps=None, combine_output=True): + if not os.path.exists(self.output_dir): + os.mkdir(self.output_dir) + timestamp = time.strftime('%Y-%m-%d.%H-%M-%S',time.localtime(time.time())) + if steps is not None: + timestamp = "step_{}.".format(steps) + timestamp + + if not combine_output: + for type in self.output_type: + images = [imageio.imread(file_path) for file_path in output_files if type in file_path] + # imageio.mimsave('{}/{}_{}.gif'.format(self.output_dir, type, timestamp), images, fps=self.fps) + imageio.mimwrite('{}/{}_{}.mp4'.format(self.output_dir, type, timestamp), images, fps=self.fps, quality=8) + else: + images = [[imageio.imread(file_path) for file_path in output_files if type in file_path] for type in self.output_type] + images = [np.concatenate([images[j][i] for j in range(len(images))], 1) for i in range(len(images[0]))] + imageio.mimwrite('{}/{}_{}.mp4'.format(self.output_dir, 'full', timestamp), images, fps=self.fps, quality=8) + + return timestamp + + def merge_videos(self, timestamps): + logger.info("mergining mp4 files..") + timestamp = time.strftime('%Y-%m-%d.%H-%M-%S',time.localtime(time.time())) + writer = imageio.get_writer( + os.path.join(self.output_dir, 'full_' + timestamp + '.mp4'), fps=self.fps) + for timestamp in timestamps: + tempfile = os.path.join(self.output_dir, 'full_' + timestamp + '.mp4') + reader = imageio.get_reader(tempfile) + for im in reader: + writer.append_data(im) + os.remove(tempfile) + writer.close() \ No newline at end of file diff --git a/fairnr/tasks/__init__.py b/fairnr/tasks/__init__.py new file mode 100644 index 0000000..9dfde61 --- /dev/null +++ b/fairnr/tasks/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith('.py') and not file.startswith('_'): + task_name = file[:file.find('.py')] + importlib.import_module('fairnr.tasks.' + task_name) diff --git a/fairnr/tasks/neural_rendering.py b/fairnr/tasks/neural_rendering.py new file mode 100644 index 0000000..471c78d --- /dev/null +++ b/fairnr/tasks/neural_rendering.py @@ -0,0 +1,293 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import json +import torch +import imageio +import numpy as np +from collections import defaultdict +from torchvision.utils import save_image +from argparse import Namespace + +from fairseq.tasks import FairseqTask, register_task +from fairseq.optim.fp16_optimizer import FP16Optimizer + +from fairnr.data import ( + ShapeViewDataset, SampledPixelDataset, ShapeViewStreamDataset, + WorldCoordDataset, ShapeDataset, InfiniteDataset +) +from fairnr.data.data_utils import write_images, recover_image, parse_views +from fairnr.data.geometry import ray, compute_normal_map +from fairnr.renderer import NeuralRenderer +from fairnr.data.trajectory import get_trajectory +from fairnr import ResetTrainerException + + +@register_task("single_object_rendering") +class SingleObjRenderingTask(FairseqTask): + """ + Task for remembering & rendering a single object. + """ + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser""" + parser.add_argument("data", help='data-path or data-directoy') + parser.add_argument("--object-id-path", type=str, help='path to object indices', default=None) + parser.add_argument("--no-preload", action="store_true") + parser.add_argument("--no-load-binary", action="store_true") + parser.add_argument("--load-depth", action="store_true", + help="load depth images if exists") + parser.add_argument("--transparent-background", type=str, default="1.0", + help="background color if the image is transparent") + parser.add_argument("--load-mask", action="store_true", + help="load pre-computed masks which is useful for subsampling during training.") + parser.add_argument("--train-views", type=str, default="0..50", + help="views sampled for training, you can set specific view id, or a range") + parser.add_argument("--valid-views", type=str, default="0..50", + help="views sampled for validation, you can set specific view id, or a range") + parser.add_argument("--test-views", type=str, default="0", + help="views sampled for rendering, only used for showing rendering results.") + parser.add_argument("--subsample-valid", type=int, default=-1, + help="if set > -1, subsample the validation (when training set is too large)") + parser.add_argument("--view-per-batch", type=int, default=6, + help="number of views training each batch (each GPU)") + parser.add_argument("--valid-view-per-batch", type=int, default=1, + help="number of views training each batch (each GPU)") + parser.add_argument("--view-resolution", type=str, default='64x64', + help="width for the squared image. downsampled from the original.") + parser.add_argument('--valid-view-resolution', type=str, default=None, + help="if not set, if valid view resolution will be train view resolution") + parser.add_argument("--min-color", choices=(0, -1), default=-1, type=int, + help="RGB range used in the model. conventionally used -1 ~ 1") + parser.add_argument("--virtual-epoch-steps", type=int, default=None, + help="virtual epoch used in Infinite Dataset. if None, set max-update") + parser.add_argument("--pruning-every-steps", type=int, default=None, + help="if the model supports pruning, prune unecessary voxels") + parser.add_argument("--half-voxel-size-at", type=str, default=None, + help='specific detailed number of updates to half the voxel sizes') + parser.add_argument("--reduce-step-size-at", type=str, default=None, + help='specific detailed number of updates to reduce the raymarching step sizes') + parser.add_argument("--rendering-every-steps", type=int, default=None, + help="if set, enables rendering online with default parameters") + parser.add_argument("--rendering-args", type=str, metavar='JSON') + parser.add_argument("--pruning-th", type=float, default=0.5, + help="if larger than this, we choose keep the voxel.") + parser.add_argument("--output-valid", type=str, default=None) + + def __init__(self, args): + super().__init__(args) + + # check dataset + self.train_data = self.val_data = self.test_data = args.data + self.object_ids = None if args.object_id_path is None else \ + {line.strip(): i for i, line in enumerate(open(args.object_id_path))} + self.output_valid = getattr(args, "output_valid", None) + + if os.path.isdir(args.data): + if os.path.exists(args.data + '/train.txt'): + self.train_data = args.data + '/train.txt' + if os.path.exists(args.data + '/val.txt'): + self.val_data = args.data + '/val.txt' + if os.path.exists(args.data + '/test.txt'): + self.test_data = args.data + '/test.txt' + if self.object_ids is None and os.path.exists(args.data + '/object_ids.txt'): + self.object_ids = {line.strip(): i for i, line in enumerate(open(args.data + '/object_ids.txt'))} + if self.object_ids is not None: + self.ids_object = {self.object_ids[o]: o for o in self.object_ids} + else: + self.ids_object = {0: 'model'} + + if len(self.args.tensorboard_logdir) > 0 and getattr(args, "distributed_rank", -1) == 0: + from tensorboardX import SummaryWriter + self.writer = SummaryWriter(self.args.tensorboard_logdir + '/images') + else: + self.writer = None + + self._num_updates = {'pv': 0, 'sv': 0, 'rs': 0, 're': 0} + self.pruning_every_steps = getattr(self.args, "pruning_every_steps", None) + self.pruning_th = getattr(self.args, "pruning_th", 0.5) + self.rendering_every_steps = getattr(self.args, "rendering_every_steps", None) + self.steps_to_half_voxels = getattr(self.args, "half_voxel_size_at", None) + self.steps_to_reduce_step = getattr(self.args, "reduce_step_size_at", None) + + if self.steps_to_half_voxels is not None: + self.steps_to_half_voxels = [int(s) for s in self.steps_to_half_voxels.split(',')] + if self.steps_to_reduce_step is not None: + self.steps_to_reduce_step = [int(s) for s in self.steps_to_reduce_step.split(',')] + + if self.rendering_every_steps is not None: + gen_args = { + 'path': args.save_dir, + 'render_beam': 1, 'render_resolution': '512x512', + 'render_num_frames': 120, 'render_angular_speed': 3, + 'render_output_types': ["rgb"], 'render_raymarching_steps': 10, + 'render_at_vector': "(0,0,0)", 'render_up_vector': "(0,0,-1)", + 'render_path_args': "{'radius': 1.5, 'h': 0.5}", + 'render_path_style': 'circle', "render_output": None + } + gen_args.update(json.loads(getattr(args, 'rendering_args', '{}') or '{}')) + self.renderer = self.build_generator(Namespace(**gen_args)) + else: + self.renderer = None + + self.train_views = parse_views(args.train_views) + self.valid_views = parse_views(args.valid_views) + self.test_views = parse_views(args.test_views) + + @classmethod + def setup_task(cls, args, **kwargs): + """ + Setup the task + """ + return cls(args) + + def repeat_dataset(self, split): + return 1 + + def load_dataset(self, split, **kwargs): + """ + Load a given dataset split (train, valid, test) + """ + DataLoader = ShapeViewStreamDataset if split == 'valid' else ShapeViewDataset + self.datasets[split] = DataLoader( + self.train_data if split == 'train' else \ + self.val_data if split == 'valid' else self.test_data, + views=self.train_views if split == 'train' else \ + self.valid_views if split == 'valid' else self.test_views, + num_view=self.args.view_per_batch if split == 'train' else \ + self.args.valid_view_per_batch if split == 'valid' else 1, + resolution=self.args.view_resolution if split == 'train' else \ + getattr(self.args, "valid_view_resolution", self.args.view_resolution) if split == 'valid' else \ + getattr(self.args, "render_resolution", self.args.view_resolution), + subsample_valid=self.args.subsample_valid if split == 'valid' else -1, + train=(split=='train'), + load_depth=self.args.load_depth and (split!='test'), + load_mask=self.args.load_mask and (split!='test'), + repeat=self.repeat_dataset(split), + preload=(not getattr(self.args, "no_preload", False)) and (split!='test'), + binarize=(not getattr(self.args, "no_load_binary", False)) and (split!='test'), + bg_color=getattr(self.args, "transparent_background", "1,1,1"), + min_color=getattr(self.args, "min_color", -1), + ids=self.object_ids + ) + + if split == 'train': + max_step = getattr(self.args, "virtual_epoch_steps", None) + if max_step is not None: + total_num_models = max_step * self.args.distributed_world_size * self.args.max_sentences + else: + total_num_models = 10000000 + self.datasets[split] = InfiniteDataset(self.datasets[split], total_num_models) + + + def build_generator(self, args): + """ + build a neural renderer for visualization + """ + return NeuralRenderer( + beam=args.render_beam, + resolution=args.render_resolution, + frames=args.render_num_frames, + speed=args.render_angular_speed, + raymarching_steps=args.render_raymarching_steps, + path_gen=get_trajectory(args.render_path_style)( + **eval(args.render_path_args) + ), + at=eval(args.render_at_vector), + up=eval(args.render_up_vector), + fps=getattr(args, "render_save_fps", 24), + output_dir=args.render_output if args.render_output is not None + else os.path.join(args.path, "output"), + output_type=args.render_output_types, + test_camera_poses=getattr(args, "render_camera_poses", None), + test_camera_intrinsics=getattr(args, "render_camera_intrinsics", None), + test_camera_views=getattr(args, "render_views", None) + ) + + @property + def source_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return None + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return None + + def update_step(self, num_updates, name='re'): + """Task level update when number of updates increases. + + This is called after the optimization step and learning rate + update at each iteration. + """ + self._num_updates[name] = num_updates + + def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False): + + if self.pruning_every_steps is not None and \ + (update_num % self.pruning_every_steps == 0) and \ + (update_num > 0) and \ + (update_num > self._num_updates['pv']) and \ + hasattr(model, 'prune_voxels'): + + model.prune_voxels(self.pruning_th) + self.update_step(update_num, 'pv') + + if self.steps_to_half_voxels is not None and \ + (update_num in self.steps_to_half_voxels) and \ + (update_num > self._num_updates['sv']): + + model.split_voxels() + self.update_step(update_num, 'sv') + raise ResetTrainerException + + if self.rendering_every_steps is not None and \ + (update_num % self.rendering_every_steps == 0) and \ + (update_num > 0) and \ + self.renderer is not None and \ + (update_num > self._num_updates['re']): + + sample_clone = {key: sample[key].clone() if sample[key] is not None else None for key in sample } + outputs = self.inference_step(self.renderer, [model], [sample_clone, 0])[1] + if getattr(self.args, "distributed_rank", -1) == 0: # save only for master + self.renderer.save_images(outputs, update_num) + self.steps_to_half_voxels = [a for a in self.steps_to_half_voxels if a != update_num] + + if self.steps_to_reduce_step is not None and \ + update_num in self.steps_to_reduce_step and \ + (update_num > self._num_updates['rs']): + + model.reduce_stepsize() + self.update_step(update_num, 'rs') + + self.update_step(update_num, 'step') + return super().train_step(sample, model, criterion, optimizer, update_num, ignore_grad) + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + model.add_eval_scores(logging_output, sample, model.cache, criterion, outdir=self.output_valid) + if self.writer is not None: + images = model.visualize(sample, shape=0, view=0) + if images is not None: + write_images(self.writer, images, self._num_updates['step']) + + return loss, sample_size, logging_output + + def save_image(self, img, id, view, group='gt'): + object_name = self.ids_object[id.item()] + def _mkdir(x): + if not os.path.exists(x): + os.mkdir(x) + _mkdir(self.output_valid) + _mkdir(os.path.join(self.output_valid, group)) + _mkdir(os.path.join(self.output_valid, group, object_name)) + imageio.imsave(os.path.join( + self.output_valid, group, object_name, + '{:04d}.png'.format(view)), + (img * 255).astype(np.uint8)) + diff --git a/fairnr_cli/__init__.py b/fairnr_cli/__init__.py new file mode 100644 index 0000000..6264236 --- /dev/null +++ b/fairnr_cli/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/fairnr_cli/extract.py b/fairnr_cli/extract.py new file mode 100644 index 0000000..a0c4b42 --- /dev/null +++ b/fairnr_cli/extract.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +This code is used for extact voxels/meshes from the learne model +""" +import logging +import numpy as np +import torch +import sys, os +import argparse +import open3d as o3d + +from fairseq import options +from fairseq import checkpoint_utils +from plyfile import PlyData, PlyElement + +logging.basicConfig( + format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.INFO, + stream=sys.stdout, +) +logger = logging.getLogger('fairnr_cli.extract') + + +def cli_main(): + parser = argparse.ArgumentParser(description='Extract geometry from a trained model (only for learnable embeddings).') + parser.add_argument('--path', type=str, required=True) + parser.add_argument('--output', type=str, required=True) + parser.add_argument('--name', type=str, default='sparsevoxel') + parser.add_argument('--user-dir', default='fairnr') + args = options.parse_args_and_arch(parser) + + models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( + [args.path], suffix=getattr(args, "checkpoint_suffix", ""), + ) + model = models[0] + voxel_idx, voxel_pts = model.encoder.extract_voxels() + + # write to ply file. + points = [ + (voxel_pts[k, 0], voxel_pts[k, 1], voxel_pts[k, 2], voxel_idx[k]) + for k in range(voxel_idx.size(0)) + ] + vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('quality', 'f4')]) + PlyData([PlyElement.describe(vertex, 'vertex')], text=True).write(os.path.join(args.output, args.name + '.ply')) + + # model = torch.load(args.path) + # voxel_pts = model['model']['encoder.points'][model['model']['encoder.keep'].bool()] + # from fairseq import pdb;pdb.set_trace() + # # write to ply file. + # points = [ + # (voxel_pts[k, 0], voxel_pts[k, 1], voxel_pts[k, 2]) + # for k in range(voxel_pts.size(0)) + # ] + # vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) + # PlyData([PlyElement.describe(vertex, 'vertex')], text=True).write(os.path.join(args.output, args.name + '.ply')) + +if __name__ == '__main__': + cli_main() diff --git a/fairnr_cli/launch_slurm.py b/fairnr_cli/launch_slurm.py new file mode 100644 index 0000000..9002f3b --- /dev/null +++ b/fairnr_cli/launch_slurm.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import random, shlex +import os, sys, subprocess + + +def launch_cluster(slurm_args, model_args): + # prepare + jobname = slurm_args.get('job-name', 'test') + train_log = slurm_args.get('output', None) + train_stderr = slurm_args.get('error', None) + nodes, gpus = slurm_args.get('nodes', 1), slurm_args.get('gpus', 8) + if not slurm_args.get('local', False): + assert (train_log is not None) and (train_stderr is not None) + + # parse slurm + train_cmd = ['python', 'train.py', ] + train_cmd.extend(['--distributed-world-size', str(nodes * gpus)]) + if nodes > 1: + train_cmd.extend(['--distributed-port', str(get_random_port())]) + + train_cmd += model_args + + base_srun_cmd = [ + 'srun', + '--job-name', jobname, + '--output', train_log, + '--error', train_stderr, + '--open-mode', 'append', + '--unbuffered', + ] + srun_cmd = base_srun_cmd + train_cmd + srun_cmd_str = ' '.join(map(shlex.quote, srun_cmd)) + srun_cmd_str = srun_cmd_str + ' &' + + sbatch_cmd = [ + 'sbatch', + '--job-name', jobname, + '--partition', slurm_args.get('partition', 'learnfair'), + '--gres', 'gpu:volta:{}'.format(gpus), + '--nodes', str(nodes), + '--ntasks-per-node', '1', + '--cpus-per-task', '48', + '--output', train_log, + '--error', train_stderr, + '--open-mode', 'append', + '--signal', 'B:USR1@180', + '--time', slurm_args.get('time', '4320'), + '--mem', slurm_args.get('mem', '500gb'), + '--exclusive', + ] + if 'constraint' in slurm_args: + sbatch_cmd += ['-C', slurm_args.get('constraint')] + if 'comment' in slurm_args: + sbatch_cmd += ['--comment', slurm_args.get('comment')] + + wrapped_cmd = requeue_support() + '\n' + srun_cmd_str + ' \n wait $! \n sleep 610 & \n wait $!' + sbatch_cmd += ['--wrap', wrapped_cmd] + sbatch_cmd_str = ' '.join(map(shlex.quote, sbatch_cmd)) + + # start training + env = os.environ.copy() + env['OMP_NUM_THREADS'] = '2' + if env.get('SLURM_ARGS', None) is not None: + del env['SLURM_ARGS'] + + if nodes > 1: + env['NCCL_SOCKET_IFNAME'] = '^docker0,lo' + env['NCCL_DEBUG'] = 'INFO' + + if slurm_args.get('dry-run', False): + print(sbatch_cmd_str) + + elif slurm_args.get('local', False): + assert nodes == 1, 'distributed training cannot be combined with local' + if 'CUDA_VISIBLE_DEVICES' not in env: + env['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, range(gpus))) + env['NCCL_DEBUG'] = 'INFO' + + if train_log is not None: + train_proc = subprocess.Popen(train_cmd, env=env, stdout=subprocess.PIPE) + tee_proc = subprocess.Popen(['tee', '-a', train_log], stdin=train_proc.stdout) + train_proc.stdout.close() + train_proc.wait() + tee_proc.wait() + else: + train_proc = subprocess.Popen(train_cmd, env=env) + train_proc.wait() + else: + with open(train_log, 'a') as train_log_h: + print(f'running command: {sbatch_cmd_str}\n') + with subprocess.Popen(sbatch_cmd, stdout=subprocess.PIPE, env=env) as train_proc: + stdout = train_proc.stdout.read().decode('utf-8') + print(stdout, file=train_log_h) + try: + job_id = int(stdout.rstrip().split()[-1]) + return job_id + except IndexError: + return None + + +def launch(slurm_args, model_args): + job_id = launch_cluster(slurm_args, model_args) + if job_id is not None: + print('Launched {}'.format(job_id)) + else: + print('Failed.') + + +def requeue_support(): + return """ + trap_handler () { + echo "Caught signal: " $1 + # SIGTERM must be bypassed + if [ "$1" = "TERM" ]; then + echo "bypass sigterm" + else + # Submit a new job to the queue + echo "Requeuing " $SLURM_JOB_ID + scontrol requeue $SLURM_JOB_ID + fi + } + + + # Install signal handler + trap 'trap_handler USR1' USR1 + trap 'trap_handler TERM' TERM + """ + + +def get_random_port(): + old_state = random.getstate() + random.seed() + port = random.randint(10000, 20000) + random.setstate(old_state) + return port diff --git a/fairnr_cli/render.py b/fairnr_cli/render.py new file mode 100644 index 0000000..ec1e53c --- /dev/null +++ b/fairnr_cli/render.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +This is a copy of fairseq-generate while simpler for other usage. +""" + + +import logging +import math +import os +import sys +import time +import torch +import imageio +import numpy as np + +from fairseq import checkpoint_utils, progress_bar, tasks, utils +from fairseq.meters import StopwatchMeter, TimeMeter +from fairnr import options + + +def main(args): + assert args.path is not None, '--path required for generation!' + + if args.results_path is not None: + os.makedirs(args.results_path, exist_ok=True) + output_path = os.path.join(args.results_path, 'generate-{}.txt'.format(args.gen_subset)) + with open(output_path, 'w', buffering=1) as h: + return _main(args, h) + else: + return _main(args, sys.stdout) + + +def _main(args, output_file): + logging.basicConfig( + format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.INFO, + stream=output_file, + ) + logger = logging.getLogger('fairnr_cli.render') + + utils.import_user_module(args) + + if args.max_tokens is None and args.max_sentences is None: + args.max_tokens = 12000 + logger.info(args) + + use_cuda = torch.cuda.is_available() and not args.cpu + + # Load dataset splits + task = tasks.setup_task(args) + task.load_dataset(args.gen_subset) + + + # Load ensemble + logger.info('loading model(s) from {}'.format(args.path)) + models, _model_args = checkpoint_utils.load_model_ensemble( + args.path.split(os.pathsep), + arg_overrides=eval(args.model_overrides), + task=task, + ) + + # Optimize ensemble for generation + for model in models: + if args.fp16: + model.half() + if use_cuda: + model.cuda() + + # Load dataset (possibly sharded) + itr = task.get_batch_iterator( + dataset=task.dataset(args.gen_subset), + max_tokens=args.max_tokens, + max_sentences=args.max_sentences, + max_positions=utils.resolve_max_positions( + task.max_positions(), + *[model.max_positions() for model in models] + ), + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=args.required_batch_size_multiple, + num_shards=args.num_shards, + shard_id=args.shard_id, + num_workers=args.num_workers, + ).next_epoch_itr(shuffle=False) + + # Initialize generator + gen_timer = StopwatchMeter() + generator = task.build_generator(args) + + + output_files, step= [], 0 + with progress_bar.build_progress_bar(args, itr) as t: + wps_meter = TimeMeter() + for i, sample in enumerate(t): + sample = utils.move_to_cuda(sample) if use_cuda else sample + gen_timer.start() + + step, _output_files = task.inference_step(generator, models, [sample, step]) + output_files += _output_files + + gen_timer.stop(500) + wps_meter.update(500) + t.log({'wps': round(wps_meter.avg)}) + + break + # if i > 5: + # break + + generator.save_images(output_files, combine_output=args.render_combine_output) + +def cli_main(): + parser = options.get_rendering_parser() + args = options.parse_args_and_arch(parser) + main(args) + + +if __name__ == '__main__': + cli_main() diff --git a/fairnr_cli/render_multigpu.py b/fairnr_cli/render_multigpu.py new file mode 100644 index 0000000..0ad5638 --- /dev/null +++ b/fairnr_cli/render_multigpu.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +This is a copy of fairseq-generate while simpler for other usage. +""" + + +import logging +import math +import os +import sys +import time +import torch +import imageio +import numpy as np + +from fairseq import checkpoint_utils, progress_bar, tasks, utils, distributed_utils +from fairseq.meters import StopwatchMeter, TimeMeter +from fairseq.options import add_distributed_training_args +from fairnr import options + + +def main(args, *kwargs): + assert args.path is not None, '--path required for generation!' + + if args.results_path is not None: + os.makedirs(args.results_path, exist_ok=True) + output_path = os.path.join(args.results_path, 'generate-{}.txt'.format(args.gen_subset)) + with open(output_path, 'w', buffering=1) as h: + return _main(args, h) + else: + return _main(args, sys.stdout) + + +def _main(args, output_file): + logging.basicConfig( + format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.INFO, + stream=output_file, + ) + logger = logging.getLogger('fairnr_cli.render') + + utils.import_user_module(args) + + if args.max_tokens is None and args.max_sentences is None: + args.max_tokens = 12000 + logger.info(args) + + use_cuda = torch.cuda.is_available() and not args.cpu + + # Load dataset splits + task = tasks.setup_task(args) + task.load_dataset(args.gen_subset) + + + # Load ensemble + logger.info('loading model(s) from {}'.format(args.path)) + models, _model_args = checkpoint_utils.load_model_ensemble( + args.path.split(os.pathsep), + arg_overrides=eval(args.model_overrides), + task=task, + ) + + # Optimize ensemble for generation + for model in models: + if args.fp16: + model.half() + if use_cuda: + model.cuda() + + logging.info(model) + + # Load dataset (possibly sharded) + itr = task.get_batch_iterator( + dataset=task.dataset(args.gen_subset), + max_tokens=args.max_tokens, + max_sentences=args.max_sentences, + max_positions=utils.resolve_max_positions( + task.max_positions(), + *[model.max_positions() for model in models] + ), + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=args.required_batch_size_multiple, + seed=args.seed, + num_workers=args.num_workers + ).next_epoch_itr(shuffle=False) + + # Initialize generator + gen_timer = StopwatchMeter() + generator = task.build_generator(args) + shard_id, world_size = args.distributed_rank, args.distributed_world_size + output_files = [] + if generator.test_poses is not None: + total_frames = generator.test_poses.shape[0] + _frames = int(np.floor(total_frames / world_size)) + step = shard_id * _frames + frames = _frames if shard_id < (world_size - 1) else total_frames - step + else: + step = shard_id * args.render_num_frames + frames = args.render_num_frames + + with progress_bar.build_progress_bar(args, itr) as t: + wps_meter = TimeMeter() + for i, sample in enumerate(t): + sample = utils.move_to_cuda(sample) if use_cuda else sample + gen_timer.start() + + step, _output_files = task.inference_step( + generator, models, [sample, step, frames]) + output_files += _output_files + + gen_timer.stop(500) + wps_meter.update(500) + t.log({'wps': round(wps_meter.avg)}) + + timestamp = generator.save_images( + output_files, steps='shard{}'.format(shard_id), combine_output=args.render_combine_output) + + # join videos from all GPUs and delete temp files + try: + timestamps = distributed_utils.all_gather_list(timestamp) + except: + timestamps = [timestamp] + + if shard_id == 0: + generator.merge_videos(timestamps) + +def cli_main(): + parser = options.get_rendering_parser() + add_distributed_training_args(parser) + args = options.parse_args_and_arch(parser) + + distributed_utils.call_main(args, main) + + +if __name__ == '__main__': + cli_main() diff --git a/fairnr_cli/train.py b/fairnr_cli/train.py new file mode 100644 index 0000000..69777a2 --- /dev/null +++ b/fairnr_cli/train.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Train a new model on one or across multiple GPUs. +This file is mostly copied from the original fairseq code +""" + +import logging +import math +import os +import random +import sys + +import numpy as np +import torch + +from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils +from fairseq.data import iterators +from fairseq.logging import meters, metrics, progress_bar +from fairseq.trainer import Trainer +from fairseq.model_parallel.megatron_trainer import MegatronTrainer + +from fairnr import ResetTrainerException + +logging.basicConfig( + format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.INFO, + stream=sys.stdout, +) +logger = logging.getLogger('fairnr_cli.train') + + +def main(args, init_distributed=False): + utils.import_user_module(args) + + assert args.max_tokens is not None or args.max_sentences is not None, \ + 'Must specify batch size either with --max-tokens or --max-sentences' + metrics.reset() + + # Initialize CUDA and distributed training + if torch.cuda.is_available() and not args.cpu: + torch.cuda.set_device(args.device_id) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if init_distributed: + args.distributed_rank = distributed_utils.distributed_init(args) + + if distributed_utils.is_master(args): + checkpoint_utils.verify_checkpoint_directory(args.save_dir) + + # Print args + logger.info(args) + + # Setup task, e.g., translation, language modeling, etc. + task = tasks.setup_task(args) + + # Load valid dataset (we load training data below, based on the latest checkpoint) + for valid_sub_split in args.valid_subset.split(','): + task.load_dataset(valid_sub_split, combine=False, epoch=1) + + # Build model and criterion + model = task.build_model(args) + criterion = task.build_criterion(args) + logger.info(model) + logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) + logger.info('num. model params: {} (num. trained: {})'.format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + )) + + # Build trainer + if args.model_parallel_size == 1: + trainer = Trainer(args, task, model, criterion) + else: + trainer = MegatronTrainer(args, task, model, criterion) + + logger.info('training on {} GPUs'.format(args.distributed_world_size)) + logger.info('max tokens per GPU = {} and max sentences per GPU = {}'.format( + args.max_tokens, + args.max_sentences, + )) + + # Load the latest checkpoint if one is available and restore the + # corresponding train iterator + extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) + + # Train until the learning rate gets too small + max_epoch = args.max_epoch or math.inf + lr = trainer.get_lr() + train_meter = meters.StopwatchMeter() + train_meter.start() + valid_subsets = args.valid_subset.split(',') + while ( + lr > args.min_lr + and epoch_itr.next_epoch_idx <= max_epoch + ): + # train for one epoch + should_end_training = train(args, trainer, task, epoch_itr) + + valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets) + + # only use first validation loss to update the learning rate + lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) + + epoch_itr = trainer.get_train_iterator( + epoch_itr.next_epoch_idx, + # sharded data: get train iterator for next epoch + load_dataset=(os.pathsep in getattr(args, 'data', '')), + ) + + if should_end_training: + break + train_meter.stop() + logger.info('done training in {:.1f} seconds'.format(train_meter.sum)) + + +def should_stop_early(args, valid_loss): + # skip check if no validation was done in the current epoch + if valid_loss is None: + return False + if args.patience <= 0: + return False + + def is_better(a, b): + return a > b if args.maximize_best_checkpoint_metric else a < b + + prev_best = getattr(should_stop_early, 'best', None) + if prev_best is None or is_better(valid_loss, prev_best): + should_stop_early.best = valid_loss + should_stop_early.num_runs = 0 + return False + else: + should_stop_early.num_runs += 1 + if should_stop_early.num_runs >= args.patience: + logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience)) + return True + else: + return False + + +@metrics.aggregate('train') +def train(args, trainer, task, epoch_itr): + """Train the model for one epoch.""" + # Initialize data iterator + itr = epoch_itr.next_epoch_itr( + fix_batches_to_gpus=args.fix_batches_to_gpus, + shuffle=(epoch_itr.next_epoch_idx > args.curriculum), + ) + update_freq = ( + args.update_freq[epoch_itr.epoch - 1] + if epoch_itr.epoch <= len(args.update_freq) + else args.update_freq[-1] + ) + itr = iterators.GroupedIterator(itr, update_freq) + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + epoch=epoch_itr.epoch, + tensorboard_logdir=( + args.tensorboard_logdir if distributed_utils.is_master(args) else None + ), + default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), + ) + + # task specific setup per epoch + task.begin_epoch(epoch_itr.epoch, trainer.get_model()) + + valid_subsets = args.valid_subset.split(',') + max_update = args.max_update or math.inf + should_end_training = False + for samples in progress: + with metrics.aggregate('train_inner'): + try: + log_output = trainer.train_step(samples) + + except ResetTrainerException: + trainer._wrapped_criterion = None + trainer._wrapped_model = None + trainer._optimizer = None + + logger.info("reset the trainer at {}".format(trainer.get_num_updates())) + log_output = trainer.train_step(samples) + + if log_output is None: # OOM, overflow, ... + continue + + # log mid-epoch stats + num_updates = trainer.get_num_updates() + if num_updates % args.log_interval == 0: + stats = get_training_stats(metrics.get_smoothed_values('train_inner')) + progress.log(stats, tag='train_inner', step=num_updates) + + # reset mid-epoch stats after each log interval + # the end-of-epoch stats will still be preserved + metrics.reset_meters('train_inner') + + valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets) + if should_stop_early(args, valid_losses[0]) or num_updates >= max_update: + should_end_training = True + break + + # log end-of-epoch stats + stats = get_training_stats(metrics.get_smoothed_values('train')) + progress.print(stats, tag='train', step=num_updates) + + # reset epoch-level meters + metrics.reset_meters('train') + return should_end_training + + +def validate_and_save(args, trainer, task, epoch_itr, valid_subsets): + num_updates = trainer.get_num_updates() + do_save = ( + ( + args.save_interval_updates > 0 + and num_updates > 0 + and num_updates % args.save_interval_updates == 0 + ) + or ( + epoch_itr.end_of_epoch() + and epoch_itr.epoch % args.save_interval == 0 + ) + ) + do_validate = ( + ( + do_save # saving requires validation + or ( + epoch_itr.end_of_epoch() + and epoch_itr.epoch % args.validate_interval == 0 + ) + ) + and not args.disable_validation + ) + + # Validate + valid_losses = [None] + if do_validate: + valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) + # Save + if do_save: + checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) + return valid_losses + + +def get_training_stats(stats): + if 'nll_loss' in stats and 'ppl' not in stats: + stats['ppl'] = utils.get_perplexity(stats['nll_loss']) + stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0) + return stats + + +def validate(args, trainer, task, epoch_itr, subsets): + """Evaluate the model on the validation set(s) and return the losses.""" + if args.fixed_validation_seed is not None: + # set fixed seed for every validation + utils.set_torch_seed(args.fixed_validation_seed) + + # reset dummy batch only for validation + trainer._dummy_batch = "DUMMY" # reset dummy batch + + valid_losses = [] + for subset in subsets: + # Initialize data iterator + itr = task.get_batch_iterator( + dataset=task.dataset(subset), + max_tokens=args.max_tokens_valid, + max_sentences=args.max_sentences_valid, + max_positions=utils.resolve_max_positions( + task.max_positions(), + trainer.get_model().max_positions(), + ), + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=args.required_batch_size_multiple, + seed=args.seed, + num_shards=args.distributed_world_size, + shard_id=args.distributed_rank, + num_workers=args.num_workers, + ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + epoch=epoch_itr.epoch, + prefix=f"valid on '{subset}' subset", + tensorboard_logdir=( + args.tensorboard_logdir if distributed_utils.is_master(args) else None + ), + default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), + ) + + # create a new root metrics aggregator so validation metrics + # don't pollute other aggregators (e.g., train meters) + with metrics.aggregate(new_root=True) as agg: + for step, sample in enumerate(progress): + trainer.valid_step(sample) + stats = get_training_stats(agg.get_smoothed_values()) + progress.log(stats, tag='valid', step=step) + + # log validation stats + stats = get_valid_stats(args, trainer, agg.get_smoothed_values()) + progress.print(stats, tag=subset, step=trainer.get_num_updates()) + + valid_losses.append(stats[args.best_checkpoint_metric]) + + # reset dummy batch again for continuing training + trainer._dummy_batch = "DUMMY" + return valid_losses + + +def get_valid_stats(args, trainer, stats): + if 'nll_loss' in stats and 'ppl' not in stats: + stats['ppl'] = utils.get_perplexity(stats['nll_loss']) + stats['num_updates'] = trainer.get_num_updates() + if hasattr(checkpoint_utils.save_checkpoint, 'best'): + key = 'best_{0}'.format(args.best_checkpoint_metric) + best_function = max if args.maximize_best_checkpoint_metric else min + stats[key] = best_function( + checkpoint_utils.save_checkpoint.best, + stats[args.best_checkpoint_metric], + ) + return stats + + +def distributed_main(i, args, start_rank=0): + args.device_id = i + if args.distributed_rank is None: # torch.multiprocessing.spawn + args.distributed_rank = start_rank + i + main(args, init_distributed=True) + + +def cli_main(modify_parser=None): + parser = options.get_training_parser() + args = options.parse_args_and_arch(parser, modify_parser=modify_parser) + + if args.distributed_init_method is None: + distributed_utils.infer_init_method(args) + + if args.distributed_init_method is not None: + # distributed training + if torch.cuda.device_count() > 1 and not args.distributed_no_spawn: + start_rank = args.distributed_rank + args.distributed_rank = None # assign automatically + torch.multiprocessing.spawn( + fn=distributed_main, + args=(args, start_rank), + nprocs=torch.cuda.device_count(), + ) + else: + distributed_main(args.device_id, args) + elif args.distributed_world_size > 1: + # fallback for single node with multiple GPUs + assert args.distributed_world_size <= torch.cuda.device_count() + port = random.randint(10000, 20000) + args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) + args.distributed_rank = None # set based on device id + torch.multiprocessing.spawn( + fn=distributed_main, + args=(args, ), + nprocs=args.distributed_world_size, + ) + else: + # single GPU training + main(args) + + +if __name__ == '__main__': + cli_main() diff --git a/fairnr_cli/validate.py b/fairnr_cli/validate.py new file mode 100644 index 0000000..465253c --- /dev/null +++ b/fairnr_cli/validate.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import sys + +import numpy as np +import torch +from itertools import chain +from fairseq import checkpoint_utils, distributed_utils, options, utils +from fairseq.logging import metrics, progress_bar +from fairseq.options import add_distributed_training_args + +logging.basicConfig( + format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.INFO, + stream=sys.stdout, +) +logger = logging.getLogger('fairnr_cli.validate') + + +def main(args, override_args=None): + utils.import_user_module(args) + + assert args.max_tokens is not None or args.max_sentences is not None, \ + 'Must specify batch size either with --max-tokens or --max-sentences' + + use_fp16 = args.fp16 + use_cuda = torch.cuda.is_available() and not args.cpu + + if override_args is not None: + try: + override_args = override_args['override_args'] + except TypeError: + override_args = override_args + overrides = vars(override_args) + overrides.update(eval(getattr(override_args, 'model_overrides', '{}'))) + else: + overrides = None + + # Load ensemble + logger.info('loading model(s) from {}'.format(args.path)) + models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( + [args.path], + arg_overrides=overrides, + suffix=getattr(args, "checkpoint_suffix", ""), + ) + model = models[0] + + # Move models to GPU + for model in models: + if use_fp16: + model.half() + if use_cuda: + model.cuda() + + # Print args + logger.info(model_args) + + # Build criterion + criterion = task.build_criterion(model_args) + if use_fp16: + criterion.half() + if use_cuda: + criterion.cuda() + criterion.eval() + + for subset in args.valid_subset.split(','): + try: + task.load_dataset(subset, combine=False, epoch=1) + dataset = task.dataset(subset) + except KeyError: + raise Exception('Cannot find dataset: ' + subset) + + # Initialize data iterator + itr = task.get_batch_iterator( + dataset=dataset, + max_tokens=args.max_tokens, + max_sentences=args.max_sentences, + max_positions=utils.resolve_max_positions( + task.max_positions(), + *[m.max_positions() for m in models], + ), + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=args.required_batch_size_multiple, + seed=args.seed, + num_workers=args.num_workers, + num_shards=args.distributed_world_size, + shard_id=args.distributed_rank + ).next_epoch_itr(shuffle=False) + + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + prefix=f"valid on '{subset}' subset", + default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), + ) + + log_outputs = [] + for i, sample in enumerate(progress): + sample = utils.move_to_cuda(sample) if use_cuda else sample + sample = utils.apply_to_sample( + lambda t: t.half() if t.dtype is torch.float32 else t, sample) if use_fp16 else sample + try: + with torch.no_grad(): # do not save backward passes + max_num_rays = 900 * 900 + if sample['uv'].shape[3] > max_num_rays: + sample['ray_split'] = sample['uv'].shape[3] // max_num_rays + _loss, _sample_size, log_output = task.valid_step(sample, model, criterion) + + progress.log(log_output, step=i) + log_outputs.append(log_output) + + except TypeError: + break + + with metrics.aggregate() as agg: + task.reduce_metrics(log_outputs, criterion) + log_output = agg.get_smoothed_values() + + + # summarize all the gpus + if args.distributed_world_size > 1: + all_log_output = list(zip(*distributed_utils.all_gather_list([log_output])))[0] + log_output = { + key: np.mean([log[key] for log in all_log_output]) + for key in all_log_output[0] + } + + progress.print(log_output, tag=subset, step=i) + + + +def cli_main(): + parser = options.get_validation_parser() + args = options.parse_args_and_arch(parser) + + # only override args that are explicitly given on the command line + override_parser = options.get_validation_parser() + override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) + + # support multi-gpu validation, use all available gpus + default_world_size = max(1, torch.cuda.device_count()) + if args.distributed_world_size < default_world_size: + args.distributed_world_size = default_world_size + override_args.distributed_world_size = default_world_size + + distributed_utils.call_main(args, main, override_args=override_args) + + +if __name__ == '__main__': + cli_main() diff --git a/render.py b/render.py new file mode 100644 index 0000000..e26b183 --- /dev/null +++ b/render.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairnr_cli.render_multigpu import cli_main + + +if __name__ == '__main__': + cli_main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..325a809 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +open3d==0.10.0 +opencv_python==4.2.0.32 +tqdm==4.43.0 +pandas==0.25.3 +imageio==2.6.1 +scikit_image==0.16.2 +scipy==1.4.1 +plyfile==0.7.1 +matplotlib==3.1.2 +numpy==1.16.4 +mathutils==2.81.2 +tensorboardX==2.0 +imageio-ffmpeg==0.4.2 +git+https://github.com/pytorch/fairseq.git@8aa06aa03b596de58d106d3f55ff43e2b9aa0b80 +git+https://github.com/aliutkus/torchsearchsorted +git+https://github.com/MultiPath/lpips-pytorch.git \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..63ef501 --- /dev/null +++ b/setup.py @@ -0,0 +1,38 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import glob + +# build clib +_ext_src_root = "fairnr/clib" +_ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( + "{}/src/*.cu".format(_ext_src_root) +) +_ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) + +setup( + name='fairnr', + ext_modules=[ + CUDAExtension( + name='fairnr.clib._ext', + sources=_ext_sources, + extra_compile_args={ + "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], + "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], + }, + ) + ], + cmdclass={ + 'build_ext': BuildExtension + }, + entry_points={ + 'console_scripts': [ + 'fairnr-render = fairnr_cli.render:cli_main', + 'fairnr-train = fairseq_cli.train:cli_main' + ], + }, +) diff --git a/train.py b/train.py new file mode 100644 index 0000000..f5a508a --- /dev/null +++ b/train.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import sys, os +from fairnr_cli.train import cli_main +from fairnr_cli.launch_slurm import launch + +if __name__ == '__main__': + if os.getenv('SLURM_ARGS') is not None: + slurm_arg = eval(os.getenv('SLURM_ARGS')) + all_args = sys.argv[1:] + + print(slurm_arg) + print(all_args) + launch(slurm_arg, all_args) + + else: + cli_main() diff --git a/validate.py b/validate.py new file mode 100644 index 0000000..1cffd84 --- /dev/null +++ b/validate.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairnr_cli.validate import cli_main + + +if __name__ == '__main__': + cli_main()