| # JaxNeRF | |
| This is a [JAX](https://github.com/google/jax) implementation of | |
| [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://www.matthewtancik.com/nerf). | |
| This code is created and maintained by | |
| [Boyang Deng](https://boyangdeng.com/), | |
| [Jon Barron](https://jonbarron.info/), | |
| and [Pratul Srinivasan](https://people.eecs.berkeley.edu/~pratul/). | |
| <div align="center"> | |
| <img width="95%" alt="NeRF Teaser" src="https://raw.githubusercontent.com/bmild/nerf/master/imgs/pipeline.jpg"> | |
| </div> | |
| Our JAX implementation currently supports: | |
| <table class="tg"> | |
| <thead> | |
| <tr> | |
| <th class="tg-0lax"><span style="font-weight:bold">Platform</span></th> | |
| <th class="tg-0lax" colspan="2"><span style="font-weight:bold">Single-Host GPU</span></th> | |
| <th class="tg-0lax" colspan="2"><span style="font-weight:bold">Multi-Device TPU</span></th> | |
| </tr> | |
| </thead> | |
| <tbody> | |
| <tr> | |
| <td class="tg-0lax"><span style="font-weight:bold">Type</span></td> | |
| <td class="tg-0lax">Single-Device</td> | |
| <td class="tg-0lax">Multi-Device</td> | |
| <td class="tg-0lax">Single-Host</td> | |
| <td class="tg-0lax">Multi-Host</td> | |
| </tr> | |
| <tr> | |
| <td class="tg-0lax"><span style="font-weight:bold">Training</span></td> | |
| <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td> | |
| <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td> | |
| <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td> | |
| <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td> | |
| </tr> | |
| <tr> | |
| <td class="tg-0lax"><span style="font-weight:bold">Evaluation</span></td> | |
| <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td> | |
| <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td> | |
| <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td> | |
| <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td> | |
| </tr> | |
| </tbody> | |
| </table> | |
| Training on 128 TPUv2 cores takes **2.5 hours (vs. 3 days for TF | |
| NeRF)** for 1 million optimization steps. In other words, JaxNeRF reaches the best quality while also training very fast. | |
| As for inference speed, here are the statistics of rendering an image with | |
| 800x800 resolution (numbers are averaged over 50 rendering passes): | |
| | Platform | 1 x NVIDIA V100 | 8 x NVIDIA V100 | 128 x TPUv2 | | |
| |----------|:---------------:|:-----------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------:| | |
| | TF NeRF | 27.74 secs | <img src="http://storage.googleapis.com/gresearch/jaxnerf/cross.png" alt="Not Supported" width=18px height=18px> | <img src="http://storage.googleapis.com/gresearch/jaxnerf/cross.png" alt="Not Supported" width=18px height=18px> | | |
| | JaxNeRF | 20.77 secs | 2.65 secs | 0.35 secs | | |
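To make the table easier to read, the rendering times can be turned into speedup factors. The snippet below is only a small arithmetic sketch over the numbers listed above; the dictionary keys and the `speedup` helper are names chosen for this illustration and are not part of the JaxNeRF code.

```python
# Rendering times in seconds for one 800x800 image, copied from the table above.
render_secs = {
    "TF NeRF, 1 x V100": 27.74,
    "JaxNeRF, 1 x V100": 20.77,
    "JaxNeRF, 8 x V100": 2.65,
    "JaxNeRF, 128 x TPUv2": 0.35,
}


def speedup(baseline: str, candidate: str) -> float:
    """Return how many times faster `candidate` renders than `baseline`."""
    return render_secs[baseline] / render_secs[candidate]


# Compare every setup against TF NeRF on a single V100.
for name in render_secs:
    factor = speedup("TF NeRF, 1 x V100", name)
    print(f"{name}: {factor:.1f}x vs. TF NeRF on 1 x V100")
```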
| The code has been carefully tested and reviewed to match the | |
| [original TF NeRF implementation](https://github.com/bmild/nerf). | |
| If you have any issues using this code, please do not open an issue as the repo | |
| is shared by all projects under Google Research. Instead, just email | |
| [email protected]. | |
| ## Installation | |
| We recommend using [Anaconda](https://www.anaconda.com/products/individual) to set | |
| up the environment. Run the following commands: | |
| ``` | |
| # Clone the repo | |
| svn export https://github.com/google-research/google-research/trunk/jaxnerf | |
| # Create a conda environment. Use Python 3.6-3.8, because one of the | |
| # dependencies (TensorFlow) does not support Python 3.9 yet. | |
| conda create --name jaxnerf python=3.6.12; conda activate jaxnerf | |
| # Prepare pip | |
| conda install pip; pip install --upgrade pip | |
| # Install requirements | |
| pip install -r jaxnerf/requirements.txt | |
| # [Optional] Install GPU and TPU support for Jax | |
| # Remember to change cuda101 to your CUDA version, e.g. cuda110 for CUDA 11.0. | |
| pip install --upgrade jax jaxlib==0.1.57+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html | |
| ``` | |
| Then, you'll need to download the datasets | |
| from the [NeRF official Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). | |
| Download `nerf_synthetic.zip` and `nerf_llff_data.zip` and unzip them | |
| wherever you like. Let's assume they are placed under `/tmp/jaxnerf/data/`. | |
| That's it for installation. You're good to go. **Note:** Run the commands in the following instructions from the parent folder; you don't need to enter the `jaxnerf` folder. | |
| ## Two Commands for Everything | |
| ``` | |
| bash jaxnerf/train.sh demo /tmp/jaxnerf/data | |
| bash jaxnerf/eval.sh demo /tmp/jaxnerf/data | |
| ``` | |
| Once both jobs are done running (which may take a while if you only have 1 GPU | |
| or CPU), you'll have a folder, `/tmp/jaxnerf/data/demo`, with: | |
| * Trained NeRF models for all scenes in the blender dataset. | |
| * Rendered images and depth maps for all test views. | |
| * The collected PSNRs of all scenes in a TXT file. | |
| Note that we used the `demo` config here, which is basically the `blender` config | |
| in the paper except with a smaller batch size and far fewer training steps. Of course, you | |
| can use other configs to replace `demo` and other data locations to replace | |
| `/tmp/jaxnerf/data`. | |
| We provide 2 configurations in the folder `configs` which match the original | |
| configurations used in the paper for the blender dataset and the LLFF dataset. | |
| Be careful when you use them. Their batch sizes are large, so you may get an out-of-memory (OOM) error if you have limited resources, for example, 1 GPU with little memory. They also run for a very large number of training steps, so training all scenes may take days. | |
| ## Play with One Scene | |
| You can also train NeRF on only one scene. The easiest way is to use one of the given configs: | |
| ``` | |
| python -m jaxnerf.train \ | |
| --data_dir=/PATH/TO/YOUR/SCENE/DATA \ | |
| --train_dir=/PATH/TO/THE/PLACE/YOU/WANT/TO/SAVE/CHECKPOINTS \ | |
| --config=configs/CONFIG_YOU_LIKE | |
| ``` | |
| Evaluating NeRF on one scene is similar: | |
| ``` | |
| python -m jaxnerf.eval \ | |
| --data_dir=/PATH/TO/YOUR/SCENE/DATA \ | |
| --train_dir=/PATH/TO/THE/PLACE/YOU/SAVED/CHECKPOINTS \ | |
| --config=configs/CONFIG_YOU_LIKE \ | |
| --chunk=4096 | |
| ``` | |
| The `chunk` parameter defines how many rays are fed to the model in one go. | |
| We recommend using the largest value that fits in your device's memory; | |
| smaller values also work, only a bit more slowly. | |
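The idea behind `chunk` can be illustrated with a minimal, framework-free sketch: the rays of an image are split into fixed-size batches, each batch is rendered separately, and the results are concatenated. This is not the evaluation code from this repository; `render_in_chunks`, `num_chunks`, and `toy_render_fn` are hypothetical names used only for illustration, and the toy "model" simply scales its input.

```python
from typing import Callable, List, Sequence


def num_chunks(num_rays: int, chunk: int) -> int:
    """Number of batches needed to cover `num_rays` rays (ceiling division)."""
    return -(-num_rays // chunk)


def render_in_chunks(
    rays: Sequence[float],
    render_fn: Callable[[Sequence[float]], List[float]],
    chunk: int,
) -> List[float]:
    """Render `rays` in batches of at most `chunk` rays and concatenate the results."""
    if chunk <= 0:
        raise ValueError("chunk must be a positive integer")
    outputs: List[float] = []
    for start in range(0, len(rays), chunk):
        # Only `chunk` rays are passed to the model at a time, which bounds peak memory.
        outputs.extend(render_fn(rays[start:start + chunk]))
    return outputs


def toy_render_fn(rays: Sequence[float]) -> List[float]:
    """Stand-in for a model call: one output value per input ray."""
    return [0.5 * float(r) for r in rays]


# An 800x800 image has 640,000 rays; this prints how many batches chunk=4096 needs.
print(num_chunks(800 * 800, 4096))
```

The output does not depend on the chunk size; only the number of model calls and the peak memory use change, which is why a larger `chunk` is faster as long as it fits on the device.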
| You can also define your own configurations by passing command line flags. Please refer to the `define_flags` function in `nerf/utils.py` for all the flags and their meanings. | |
| **Note**: We noticed that the ficus scene in the blender dataset is sensitive to initialization, | |
| e.g. to the choice of random seed, when using the original learning rate schedule from the paper. | |
| Therefore, we provide a simple tweak (turned off by default) for more stable training: using `lr_delay_steps` and `lr_delay_mult`. | |
| This allows the training to start from a smaller learning rate (`lr_init` * `lr_delay_mult`) in the first `lr_delay_steps`. | |
| We didn't use them for our pretrained models, | |
| but we tested `lr_delay_steps=5000` with `lr_delay_mult=0.2` and it works quite smoothly. | |
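As a rough illustration of how such a delay could work, the sketch below scales a decaying learning rate by a factor that starts at `lr_delay_mult` and ramps up to 1 over the first `lr_delay_steps` steps. The shape of the ramp (a quarter sine wave) and the base schedule (log-linear decay from `lr_init` to `lr_final`) are assumptions made for this example, as are the function name `delayed_learning_rate` and the values of `lr_init`, `lr_final`, and `max_steps`; refer to the code for the actual behavior.

```python
import math


def delayed_learning_rate(
    step: int,
    lr_init: float,
    lr_final: float,
    max_steps: int,
    lr_delay_steps: int = 0,
    lr_delay_mult: float = 1.0,
) -> float:
    """Decaying learning rate with an optional reduced-rate warm-up (illustrative only)."""
    if lr_delay_steps > 0:
        # Assumed ramp: goes smoothly from lr_delay_mult at step 0 to 1.0 at lr_delay_steps.
        progress = min(max(step / lr_delay_steps, 0.0), 1.0)
        delay_rate = lr_delay_mult + (1.0 - lr_delay_mult) * math.sin(0.5 * math.pi * progress)
    else:
        # Delay turned off: the base schedule is used unchanged.
        delay_rate = 1.0
    # Assumed base schedule: interpolate between lr_init and lr_final in log space.
    t = min(max(step / max_steps, 0.0), 1.0)
    base_lr = math.exp(math.log(lr_init) * (1.0 - t) + math.log(lr_final) * t)
    return delay_rate * base_lr


# Illustrative values; only lr_delay_steps=5000 and lr_delay_mult=0.2 come from the text above.
for step in (0, 2500, 5000, 100000):
    lr = delayed_learning_rate(step, 5e-4, 5e-6, 1_000_000,
                               lr_delay_steps=5000, lr_delay_mult=0.2)
    print(step, lr)
```

With these settings the schedule starts at `lr_init * lr_delay_mult` and coincides with the undelayed schedule from step `lr_delay_steps` onward.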
| ## Pretrained Models | |
| We provide a collection of pretrained NeRF models that match the numbers | |
| reported in the [paper](https://arxiv.org/abs/2003.08934). Actually, ours are | |
| slightly better overall because we trained for more iterations (while still | |
| being much faster!). You can find our pretrained models | |
| [here](http://storage.googleapis.com/gresearch/jaxnerf/jaxnerf_pretrained_models.zip). | |
| The performance (in PSNR) of our pretrained NeRF models is listed below: | |
| ### Blender | |
| | Scene | Chair | Drums | Ficus | Hotdog | Lego | Materials | Mic | Ship | Mean | | |
| |---------|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:| | |
| | TF NeRF | 33.00 | 25.01 | 30.13 | 36.18 | 32.54 | 29.62 | 32.91 | 28.65 | 31.01 | | |
| | JaxNeRF | **34.08** | **25.03** | **30.43** | **36.92** | **33.28** | **29.91** | **34.53** | **29.36** | **31.69** | | |
| ### LLFF | |
| | Scene | Room | Fern | Leaves | Fortress | Orchids | Flower | T-Rex | Horns | Mean | | |
| |---------|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:| | |
| | TF NeRF | 32.70 | **25.17** | 20.92 | 31.16 | **20.36** | 27.40 | 26.80 | 27.45 | 26.50 | | |
| | JaxNeRF | **33.04** | 24.83 | **21.23** | **31.76** | 20.27 | **28.07** | **27.42** | **28.10** | **26.84** | | |
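For readers unfamiliar with the metric, PSNR (peak signal-to-noise ratio, in dB) is derived from the mean squared error between a rendered image and the ground truth; higher is better. The sketch below shows the standard definition for pixel values in the range [0, 1]. It is not taken from this repository, and the helper names `mse` and `psnr` are chosen for this example; how the per-image values are aggregated into the per-scene numbers above is not covered here.

```python
import math
from typing import Sequence


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between two equally sized sequences of pixel values."""
    if len(pred) != len(target) or not pred:
        raise ValueError("inputs must be non-empty and have the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def psnr(pred: Sequence[float], target: Sequence[float], max_value: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB for pixel values in [0, max_value]."""
    err = mse(pred, target)
    if err == 0.0:
        # Identical images: PSNR is unbounded.
        return math.inf
    return 10.0 * math.log10(max_value ** 2 / err)


# Toy example: every pixel is off by 0.1, so the MSE is 0.01.
print(psnr([0.5, 0.5, 0.5, 0.5], [0.6, 0.6, 0.6, 0.6]))
```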
| ## Citation | |
| If you use this software package, please cite it as: | |
| ``` | |
| @software{jaxnerf2020github, | |
| author = {Boyang Deng and Jonathan T. Barron and Pratul P. Srinivasan}, | |
| title = {{JaxNeRF}: an efficient {JAX} implementation of {NeRF}}, | |
| url = {https://github.com/google-research/google-research/tree/master/jaxnerf}, | |
| version = {0.0}, | |
| year = {2020}, | |
| } | |
| ``` | |
| and also cite the original NeRF paper: | |
| ``` | |
| @inproceedings{mildenhall2020nerf, | |
| title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis}, | |
| author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng}, | |
| year={2020}, | |
| booktitle={ECCV}, | |
| } | |
| ``` | |
| ## Acknowledgement | |
| We'd like to thank | |
| [Daniel Duckworth](http://www.stronglyconvex.com/), | |
| [Dan Gnanapragasam](https://research.google/people/DanGnanapragasam/), | |
| and [James Bradbury](https://twitter.com/jekbradbury) | |
| for their help reviewing and optimizing this code. | |
| We'd also like to thank the amazing [JAX](https://github.com/google/jax) team for | |
| very insightful and helpful discussions on how to use JAX for NeRF. | |