# Tiny CUDA Neural Networks ![](https://github.com/NVlabs/tiny-cuda-nn/workflows/CI/badge.svg)

This is a small, self-contained framework for training and querying neural networks. Most notably, it contains a lightning fast ["fully fused" multi-layer perceptron](https://raw.githubusercontent.com/NVlabs/tiny-cuda-nn/master/data/readme/fully-fused-mlp-diagram.png) ([technical paper](https://tom94.net/data/publications/mueller21realtime/mueller21realtime.pdf)), a versatile [multiresolution hash encoding](https://raw.githubusercontent.com/NVlabs/tiny-cuda-nn/master/data/readme/multiresolution-hash-encoding-diagram.png) ([technical paper](https://nvlabs.github.io/instant-ngp/assets/mueller2022instant.pdf)), as well as support for various other input encodings, losses, and optimizers.

## Performance

![Image](data/readme/fully-fused-vs-tensorflow.png)

_Fully fused networks vs. TensorFlow v2.5.0 w/ XLA. Measured on multi-layer perceptrons that are 64 (solid line) and 128 (dashed line) neurons wide, on an RTX 3090. Generated by `benchmarks/bench_ours.cu` and `benchmarks/bench_tensorflow.py` using `data/config_oneblob.json`._

## Usage

Tiny CUDA neural networks have a simple C++/CUDA API:

```cpp
#include <tiny-cuda-nn/common.h>

// Configure the model
nlohmann::json config = {
    {"loss", {
        {"otype", "L2"}
    }},
    {"optimizer", {
        {"otype", "Adam"},
        {"learning_rate", 1e-3},
    }},
    {"encoding", {
        {"otype", "HashGrid"},
        {"n_levels", 16},
        {"n_features_per_level", 2},
        {"log2_hashmap_size", 19},
        {"base_resolution", 16},
        {"per_level_scale", 2.0},
    }},
    {"network", {
        {"otype", "FullyFusedMLP"},
        {"activation", "ReLU"},
        {"output_activation", "None"},
        {"n_neurons", 64},
        {"n_hidden_layers", 2},
    }},
};

using namespace tcnn;

auto model = create_from_config(n_input_dims, n_output_dims, config);
model->set_jit_fusion(supports_jit_fusion()); // Optional: accelerate with JIT fusion

// Train the model (batch_size must be a multiple of tcnn::BATCH_SIZE_GRANULARITY)
GPUMatrix<float> training_batch_inputs(n_input_dims, batch_size);
GPUMatrix<float> training_batch_targets(n_output_dims, batch_size);

for (int i = 0; i < n_training_steps; ++i) {
    generate_training_batch(&training_batch_inputs, &training_batch_targets); // <-- your code

    float loss;
    model.trainer->training_step(training_batch_inputs, training_batch_targets, &loss);
    std::cout << "iteration=" << i << " loss=" << loss << std::endl;
}

// Use the model
GPUMatrix<float> inference_inputs(n_input_dims, batch_size);
generate_inputs(&inference_inputs); // <-- your code

GPUMatrix<float> inference_outputs(n_output_dims, batch_size);
model.network->inference(inference_inputs, inference_outputs);
```
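As the comment in the example notes, `batch_size` must be a multiple of `tcnn::BATCH_SIZE_GRANULARITY`. Below is a minimal sketch of rounding an arbitrary sample count up to a valid batch size; it assumes `BATCH_SIZE_GRANULARITY` is an integer constant, as the comment suggests, and the raw count `n_samples` is hypothetical:

```cpp
#include <cstdint>

// Hypothetical number of samples available for this training step.
uint32_t n_samples = 100000;

// Round up to the next multiple of the required granularity so that the
// matrices passed to training_step() have a valid batch size.
uint32_t granularity = tcnn::BATCH_SIZE_GRANULARITY;
uint32_t batch_size = ((n_samples + granularity - 1) / granularity) * granularity;
```

The padded rows still need to be filled with valid data (for example by repeating samples) inside your `generate_training_batch`.
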
## JIT fusion

JIT fusion is a new, optional feature of tiny-cuda-nn v2.0 and later. It is *almost always* recommended to enable [automatic JIT fusion](#automatic-jit-fusion) for a performance boost of 1.5x to 2.5x, depending on the model and GPU. Newer GPUs exhibit larger speedups.

If your model has very large hash grids (~20 million+ parameters) or MLPs (layer sizes larger than 128 neurons), or if your GPU is from the RTX 3000 series or earlier, JIT fusion *can* slow down training and, more rarely, inference. In this case, it is recommended to try enabling JIT fusion separately for training and for inference and to measure whether it is faster. Please [open an issue](https://github.com/NVlabs/tiny-cuda-nn/issues) if you encounter a slowdown in a different situation or other problems with JIT fusion enabled.

### Automatic JIT fusion

To enable JIT fusion, set the `jit_fusion` property of your model to `true`. All future uses of the model, whether inference or training, will then use JIT mode. Note that if an error occurs during JIT compilation, a warning is emitted and JIT compilation mode is automatically turned off; your code will still run using the tiny-cuda-nn 1.X code path.

```cpp
auto model = tcnn::create_from_config(...);
model->set_jit_fusion(tcnn::supports_jit_fusion()); // Enable JIT if the system supports it
```

JIT fusion can also be enabled via the PyTorch bindings, but the speed-up will be lower, particularly during training. This is because the JIT compiler does not have access to the whole compute graph and can therefore fuse and optimize less.

```python
import tinycudann as tcnn

model = tcnn.NetworkWithInputEncoding(...) # Or any other tcnn model
model.jit_fusion = tcnn.supports_jit_fusion() # Enable JIT if the system supports it
```

### Manual JIT fusion

Even larger speed-ups are possible when applications integrate more tightly with JIT fusion. For example, [Instant NGP](https://github.com/nvlabs/instant-ngp) achieves a 5x speedup by fusing the entire NeRF ray marcher into a single kernel.

JIT fusion works by converting a given tiny-cuda-nn model to a CUDA device function and then compiling it into a kernel using CUDA's runtime compilation (RTC) feature. To integrate a tiny-cuda-nn model with a larger kernel in your app, you need to

1. turn your kernel into a string,
2. prepend the tiny-cuda-nn model's device function,
3. pass the result to tiny-cuda-nn's runtime compilation API.

Here is an example that implements a minimal kernel using a tiny-cuda-nn model with 32 input dimensions and 16 output dimensions:

```cpp
#include <tiny-cuda-nn/common.h>

auto model = tcnn::create_from_config(32 /* input dims */, 16 /* output dims */, ...);

auto fused_kernel = tcnn::CudaRtcKernel(
    "your_kernel",
    fmt::format(R"({MODEL_DEVICE_FUNCTION}

__global__ void your_kernel(...) {{
    // Get input to model from either registers or memory.
    tcnn::hvec<32> input = ...;

    // Call tiny-cuda-nn model. All 32 threads of the warp must be active here.
    tcnn::hvec<16> output = model_fun(input, params);

    // Do something with the model output.
}})",
        fmt::arg("MODEL_DEVICE_FUNCTION", model->generate_device_function("model_fun"))
    )
);

uint32_t blocks = 1;
uint32_t threads = 128; // Must be a multiple of 32 for neural networks to work.
uint32_t shmem_size = 0; // Can be any size that your_kernel needs.
cudaStream_t stream = nullptr; // Can be any stream.
fused_kernel.launch(blocks, threads, shmem_size, stream, ... /* params of your_kernel */);
```
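Choosing the launch configuration follows the usual CUDA conventions. Below is a minimal sketch that derives `blocks` from a hypothetical element count while keeping `threads` a multiple of 32, and then queries the CUDA runtime for launch errors; `CudaRtcKernel::launch` is called exactly as in the example above:

```cpp
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Hypothetical number of work items that your_kernel processes.
uint32_t n_elements = 1u << 20;

uint32_t threads = 128;                                  // Must stay a multiple of 32, as noted above.
uint32_t blocks  = (n_elements + threads - 1) / threads; // Ceiling division so every element is covered.

// ... fused_kernel.launch(...) as in the example above ...

// Query the CUDA runtime right after the launch instead of waiting for the
// next synchronizing call to surface an error.
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
    std::printf("your_kernel launch failed: %s\n", cudaGetErrorString(err));
}
```
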
And here is Instant NGP's NeRF integration with the JIT compiler for reference:

- [src/testbed_nerf.cu](https://github.com/NVlabs/instant-ngp/blob/d6bbefb0b68e6322711b518eac7f9ab4c1cc7b1e/src/testbed_nerf.cu#L1931)
- [include/neural-graphics-primitives/fused_kernels/render_nerf.cuh](https://github.com/NVlabs/instant-ngp/blob/master/include/neural-graphics-primitives/fused_kernels/render_nerf.cuh)

## Example: learning a 2D image

We provide a sample application that learns an image function _(x,y) -> (R,G,B)_. It can be run via

```sh
tiny-cuda-nn$ ./build/mlp_learning_an_image data/images/albert.jpg data/config_hash.json
```

producing an image every couple of training steps. Each 1000 steps should take a bit over 1 second with the default configuration on an RTX 4090.

| 10 steps | 100 steps | 1000 steps | Reference image |
|:---:|:---:|:---:|:---:|
| ![10steps](data/readme/10.jpg) | ![100steps](data/readme/100.jpg) | ![1000steps](data/readme/1000.jpg) | ![reference](data/images/albert.jpg) |

## Requirements

- An __NVIDIA GPU__; tensor cores increase performance when available. All shown results come from an RTX 3090.
- A __C++14__ capable compiler. The following choices are recommended and have been tested:
    - __Windows:__ Visual Studio 2019 or 2022
    - __Linux:__ GCC/G++ 8 or higher
- A recent version of __[CUDA](https://developer.nvidia.com/cuda-toolkit)__. The following choices are recommended and have been tested:
    - __Windows:__ CUDA 11.5 or higher
    - __Linux:__ CUDA 10.2 or higher
- __[CMake](https://cmake.org/) v3.21 or higher__.
- The fully fused MLP component of this framework requires a __very large__ amount of shared memory in its default configuration. It will likely only work on an RTX 3090, an RTX 2080 Ti, or higher-end GPUs. Lower-end cards must reduce the `n_neurons` parameter or use the `CutlassMLP` (better compatibility, but slower) instead; see the configuration sketch after the compilation instructions below.

If you are using Linux, install the following packages

```sh
sudo apt-get install build-essential git
```

We also recommend installing [CUDA](https://developer.nvidia.com/cuda-toolkit) in `/usr/local/` and adding the CUDA installation to your PATH. For example, if you have CUDA 12.6.3, add the following to your `~/.bashrc`

```sh
export PATH="/usr/local/cuda-12.6.3/bin:$PATH"
export LD_LIBRARY_PATH="/usr/local/cuda-12.6.3/lib64:$LD_LIBRARY_PATH"
```

## Compilation (Windows & Linux)

Begin by cloning this repository and all its submodules using the following command:

```sh
$ git clone --recursive https://github.com/nvlabs/tiny-cuda-nn
$ cd tiny-cuda-nn
```

Then, use CMake to build the project (on Windows, this must be done in a [developer command prompt](https://docs.microsoft.com/en-us/cpp/build/building-on-the-command-line?view=msvc-160#developer_command_prompt)):

```sh
tiny-cuda-nn$ cmake . -B build -DCMAKE_BUILD_TYPE=RelWithDebInfo
tiny-cuda-nn$ cmake --build build --config RelWithDebInfo -j
```

If compilation fails inexplicably or takes longer than an hour, you might be running out of memory. Try running the above command without `-j` in that case.
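If the fully fused MLP does not run on your GPU because of the shared-memory requirement listed under requirements above, the network can be switched to `CutlassMLP` entirely through the JSON configuration (or `n_neurons` can be reduced). A minimal sketch, reusing the field names from the usage example at the top of this README; the values are illustrative:

```cpp
// Hypothetical fallback network configuration for GPUs that lack the shared
// memory needed by the fully fused MLP. Same fields as in the usage example,
// only the "otype" differs.
nlohmann::json network_config = {
    {"otype", "CutlassMLP"}, // Better compatibility than "FullyFusedMLP", but slower.
    {"activation", "ReLU"},
    {"output_activation", "None"},
    {"n_neurons", 64},
    {"n_hidden_layers", 2},
};
```

The resulting `network_config` replaces the `"network"` block of the configuration shown in the usage section.
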
## PyTorch extension

__tiny-cuda-nn__ comes with a [PyTorch](https://github.com/pytorch/pytorch) extension that allows using the fast MLPs and input encodings from within a [Python](https://www.python.org/) context. These bindings can be significantly faster than full Python implementations, in particular for the [multiresolution hash encoding](https://raw.githubusercontent.com/NVlabs/tiny-cuda-nn/master/data/readme/multiresolution-hash-encoding-diagram.png).

> The overheads of Python/PyTorch can nonetheless be extensive if the batch size is small.
> For example, with a batch size of 64k, the bundled `mlp_learning_an_image` example is __~2x slower__ through PyTorch than native CUDA.
> With a batch size of 256k and higher (the default), the performance is much closer.

Begin by setting up a Python 3.X environment with a recent, CUDA-enabled version of PyTorch. Then, invoke

```sh
pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
```

Alternatively, if you would like to install from a local clone of __tiny-cuda-nn__, invoke

```sh
tiny-cuda-nn$ cd bindings/torch
tiny-cuda-nn/bindings/torch$ python setup.py install
```

Upon success, you can use __tiny-cuda-nn__ models as in the following example:

```py
import commentjson as json
import tinycudann as tcnn
import torch

with open("data/config_hash.json") as f:
    config = json.load(f)

# Option 1: efficient Encoding+Network combo.
model = tcnn.NetworkWithInputEncoding(
    n_input_dims, n_output_dims,
    config["encoding"], config["network"]
)

# Option 2: separate modules. Slower but more flexible.
encoding = tcnn.Encoding(n_input_dims, config["encoding"])
network = tcnn.Network(encoding.n_output_dims, n_output_dims, config["network"])
model = torch.nn.Sequential(encoding, network)

model.jit_fusion = tcnn.supports_jit_fusion() # Optional: accelerate with JIT fusion
```

See `samples/mlp_learning_an_image_pytorch.py` for an example.

## Components

The following is a summary of the components of this framework. [The JSON documentation](DOCUMENTATION.md) lists their configuration options.

| Networks |   |   |
| :--- | :---------- | :----- |
| Fully fused MLP | `src/fully_fused_mlp.cu` | Lightning fast implementation of small multi-layer perceptrons (MLPs). |
| CUTLASS MLP | `src/cutlass_mlp.cu` | MLP based on [CUTLASS](https://github.com/NVIDIA/cutlass)' GEMM routines. Slower than the fully fused MLP, but handles larger networks and is still reasonably fast. |

| Input encodings |   |   |
| :--- | :---------- | :----- |
| Composite | `include/tiny-cuda-nn/encodings/composite.h` | Allows composing multiple encodings. Can, for example, be used to assemble the Neural Radiance Caching encoding [[Müller et al. 2021]](https://tom94.net/). |
| Frequency | `include/tiny-cuda-nn/encodings/frequency.h` | NeRF's [[Mildenhall et al. 2020]](https://www.matthewtancik.com/nerf) positional encoding applied equally to all dimensions. |
| Grid | `include/tiny-cuda-nn/encodings/grid.h` | Encoding based on trainable multiresolution grids. Used for [Instant Neural Graphics Primitives [Müller et al. 2022]](https://nvlabs.github.io/instant-ngp/). The grids can be backed by hashtables, dense storage, or tiled storage. |
| Identity | `include/tiny-cuda-nn/encodings/identity.h` | Leaves values untouched. |
| Oneblob | `include/tiny-cuda-nn/encodings/oneblob.h` | From Neural Importance Sampling [[Müller et al. 2019]](https://tom94.net/data/publications/mueller18neural/mueller18neural-v4.pdf) and Neural Control Variates [[Müller et al. 2020]](https://tom94.net/data/publications/mueller20neural/mueller20neural.pdf). |
| SphericalHarmonics | `include/tiny-cuda-nn/encodings/spherical_harmonics.h` | A frequency-space encoding that is more suitable to direction vectors than component-wise encodings. |
| TriangleWave | `include/tiny-cuda-nn/encodings/triangle_wave.h` | Low-cost alternative to NeRF's encoding. Used in Neural Radiance Caching [[Müller et al. 2021]](https://tom94.net/). |
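Encodings are selected purely through the JSON configuration. As a sketch of the Composite encoding from the table above, the snippet below combines a hash grid over the first three input dimensions with spherical harmonics over the remaining three; the `nested`, `n_dims_to_encode`, and `degree` field names follow my reading of [the JSON documentation](DOCUMENTATION.md) and should be verified there:

```cpp
// Hypothetical composite encoding: positions -> hash grid, directions -> spherical harmonics.
// Field names ("nested", "n_dims_to_encode", "degree") are assumptions based on the JSON
// documentation; consult DOCUMENTATION.md for the authoritative schema.
nlohmann::json encoding_config = {
    {"otype", "Composite"},
    {"nested", {
        {
            {"otype", "HashGrid"},
            {"n_dims_to_encode", 3},
            {"n_levels", 16},
            {"n_features_per_level", 2},
            {"log2_hashmap_size", 19},
            {"base_resolution", 16},
            {"per_level_scale", 2.0},
        },
        {
            {"otype", "SphericalHarmonics"},
            {"n_dims_to_encode", 3},
            {"degree", 4},
        },
    }},
};
```

The resulting `encoding_config` can be dropped into the `"encoding"` slot of the configuration shown in the usage section.
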
| Losses |   |   |
| :--- | :---------- | :----- |
| L1 | `include/tiny-cuda-nn/losses/l1.h` | Standard L1 loss. |
| Relative L1 | `include/tiny-cuda-nn/losses/l1.h` | Relative L1 loss normalized by the network prediction. |
| MAPE | `include/tiny-cuda-nn/losses/mape.h` | Mean absolute percentage error (MAPE). The same as Relative L1, but normalized by the target. |
| SMAPE | `include/tiny-cuda-nn/losses/smape.h` | Symmetric mean absolute percentage error (SMAPE). The same as Relative L1, but normalized by the mean of the prediction and the target. |
| L2 | `include/tiny-cuda-nn/losses/l2.h` | Standard L2 loss. |
| Relative L2 | `include/tiny-cuda-nn/losses/relative_l2.h` | Relative L2 loss normalized by the network prediction [[Lehtinen et al. 2018]](https://github.com/NVlabs/noise2noise). |
| Relative L2 Luminance | `include/tiny-cuda-nn/losses/relative_l2_luminance.h` | Same as above, but normalized by the luminance of the network prediction. Only applicable when the network prediction is RGB. Used in Neural Radiance Caching [[Müller et al. 2021]](https://tom94.net/). |
| Cross Entropy | `include/tiny-cuda-nn/losses/cross_entropy.h` | Standard cross entropy loss. Only applicable when the network prediction is a PDF. |
| Variance | `include/tiny-cuda-nn/losses/variance_is.h` | Standard variance loss. Only applicable when the network prediction is a PDF. |

| Optimizers |   |   |
| :--- | :---------- | :----- |
| Adam | `include/tiny-cuda-nn/optimizers/adam.h` | Implementation of Adam [[Kingma and Ba 2014]](https://arxiv.org/abs/1412.6980), generalized to AdaBound [[Luo et al. 2019]](https://github.com/Luolc/AdaBound). |
| Novograd | `include/tiny-cuda-nn/optimizers/novograd.h` | Implementation of Novograd [[Ginsburg et al. 2019]](https://arxiv.org/abs/1905.11286). |
| SGD | `include/tiny-cuda-nn/optimizers/sgd.h` | Standard stochastic gradient descent (SGD). |
| Shampoo | `include/tiny-cuda-nn/optimizers/shampoo.h` | Implementation of the 2nd-order Shampoo optimizer [[Gupta et al. 2018]](https://arxiv.org/abs/1802.09568) with home-grown optimizations as well as those by [Anil et al. [2020]](https://arxiv.org/abs/2002.09018). |
| Average | `include/tiny-cuda-nn/optimizers/average.h` | Wraps another optimizer and computes a linear average of the weights over the last N iterations. The average is used for inference only (it does not feed back into training). |
| Batched | `include/tiny-cuda-nn/optimizers/batched.h` | Wraps another optimizer, invoking the nested optimizer once every N steps on the averaged gradient. Has the same effect as increasing the batch size but requires only a constant amount of memory. |
| Composite | `include/tiny-cuda-nn/optimizers/composite.h` | Allows using several optimizers on different parameters. |
| EMA | `include/tiny-cuda-nn/optimizers/average.h` | Wraps another optimizer and computes an exponential moving average of the weights. The average is used for inference only (it does not feed back into training). |
| Exponential Decay | `include/tiny-cuda-nn/optimizers/exponential_decay.h` | Wraps another optimizer and performs piecewise-constant exponential learning-rate decay. |
| Lookahead | `include/tiny-cuda-nn/optimizers/lookahead.h` | Wraps another optimizer, implementing the lookahead algorithm [[Zhang et al. 2019]](https://arxiv.org/abs/1907.08610). |
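The wrapper optimizers in the table above can be nested via the configuration. Below is a sketch of an exponential-moving-average wrapper around exponentially decayed Adam; the `nested`, `decay`, `decay_start`, `decay_interval`, and `decay_base` field names follow my reading of [the JSON documentation](DOCUMENTATION.md) and should be checked there:

```cpp
// Hypothetical nested optimizer: EMA of the weights, wrapping exponential
// learning-rate decay, wrapping Adam. Field names are assumptions based on
// the JSON documentation; the values are illustrative.
nlohmann::json optimizer_config = {
    {"otype", "EMA"},
    {"decay", 0.99},
    {"nested", {
        {"otype", "ExponentialDecay"},
        {"decay_start", 10000},
        {"decay_interval", 5000},
        {"decay_base", 0.33},
        {"nested", {
            {"otype", "Adam"},
            {"learning_rate", 1e-3},
        }},
    }},
};
```

The resulting `optimizer_config` replaces the `"optimizer"` block of the configuration shown in the usage section.
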
## License and Citation

This framework is licensed under the BSD 3-clause license. Please see `LICENSE.txt` for details.

If you use it in your research, we would appreciate a citation via

```bibtex
@software{tiny-cuda-nn,
    author = {M\"uller, Thomas},
    license = {BSD-3-Clause},
    month = {4},
    title = {{tiny-cuda-nn}},
    url = {https://github.com/NVlabs/tiny-cuda-nn},
    version = {2.0},
    year = {2021}
}
```

For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)

## Publications & Software

Among others, this framework powers the following publications:

> __Instant Neural Graphics Primitives with a Multiresolution Hash Encoding__
> [Thomas Müller](https://tom94.net), [Alex Evans](https://research.nvidia.com/person/alex-evans), [Christoph Schied](https://research.nvidia.com/person/christoph-schied), [Alexander Keller](https://research.nvidia.com/person/alex-keller)
> _ACM Transactions on Graphics (__SIGGRAPH__), July 2022_
> __[Website](https://nvlabs.github.io/instant-ngp/) / [Paper](https://nvlabs.github.io/instant-ngp/assets/mueller2022instant.pdf) / [Code](https://github.com/NVlabs/instant-ngp) / [Video](https://nvlabs.github.io/instant-ngp/assets/mueller2022instant.mp4) / [BibTeX](https://nvlabs.github.io/instant-ngp/assets/mueller2022instant.bib)__

> __Extracting Triangular 3D Models, Materials, and Lighting From Images__
> [Jacob Munkberg](https://research.nvidia.com/person/jacob-munkberg), [Jon Hasselgren](https://research.nvidia.com/person/jon-hasselgren), [Tianchang Shen](http://www.cs.toronto.edu/~shenti11/), [Jun Gao](http://www.cs.toronto.edu/~jungao/), [Wenzheng Chen](http://www.cs.toronto.edu/~wenzheng/), [Alex Evans](https://research.nvidia.com/person/alex-evans), [Thomas Müller](https://tom94.net), [Sanja Fidler](https://www.cs.toronto.edu/~fidler/)
> __CVPR (Oral)__, June 2022
> __[Website](https://nvlabs.github.io/nvdiffrec/) / [Paper](https://nvlabs.github.io/nvdiffrec/assets/paper.pdf) / [Video](https://nvlabs.github.io/nvdiffrec/assets/video.mp4) / [BibTeX](https://nvlabs.github.io/nvdiffrec/assets/bib.txt)__

> __Real-time Neural Radiance Caching for Path Tracing__
> [Thomas Müller](https://tom94.net), [Fabrice Rousselle](https://research.nvidia.com/person/fabrice-rousselle), [Jan Novák](http://jannovak.info), [Alexander Keller](https://research.nvidia.com/person/alex-keller)
> _ACM Transactions on Graphics (__SIGGRAPH__), August 2021_
> __[Paper](https://tom94.net/data/publications/mueller21realtime/mueller21realtime.pdf) / [GTC talk](https://gtc21.event.nvidia.com/media/Fully%20Fused%20Neural%20Network%20for%20Radiance%20Caching%20in%20Real%20Time%20Rendering%20%5BE31307%5D/1_liqy6k1c) / [Video](https://tom94.net/data/publications/mueller21realtime/mueller21realtime.mp4) / [Interactive results viewer](https://tom94.net/data/publications/mueller21realtime/interactive-viewer/) / [BibTeX](https://tom94.net/data/publications/mueller21realtime/mueller21realtime.bib)__

As well as the following software:

> __NerfAcc: A General NeRF Acceleration Toolbox__
> [Ruilong Li](https://www.liruilong.cn/), [Matthew Tancik](https://www.matthewtancik.com/about-me), [Angjoo Kanazawa](https://people.eecs.berkeley.edu/~kanazawa/)
> __https://github.com/KAIR-BAIR/nerfacc__

> __Nerfstudio: A Framework for Neural Radiance Field Development__
> [Matthew Tancik*](https://www.matthewtancik.com/about-me), [Ethan Weber*](https://ethanweber.me/), [Evonne Ng*](http://people.eecs.berkeley.edu/~evonne_ng/), [Ruilong Li](https://www.liruilong.cn/), Brent Yi, Terrance Wang, Alexander Kristoffersen, Jake Austin, Kamyar Salahi, Abhik Ahuja, David McAllister, [Angjoo Kanazawa](https://people.eecs.berkeley.edu/~kanazawa/)
> __https://github.com/nerfstudio-project/nerfstudio__
Please feel free to make a pull request if your publication or software is not listed.

## Acknowledgments

Special thanks go to the NRC authors for helpful discussions and to [Nikolaus Binder](https://research.nvidia.com/person/nikolaus-binder) for providing part of the infrastructure of this framework, as well as for help with utilizing TensorCores from within CUDA.