Run the following to generate and test WordNet hierarchies for CIFAR10, CIFAR100, and TinyImagenet200. The script also downloads the NLTK WordNet corpus.
```bash
bash scripts/generate_hierarchies_wordnet.sh
```
The commands below explain what `generate_hierarchies_wordnet.sh` does, using CIFAR10 as an example. You do not need to run them if you have already run the bash script above.
```bash
# Generate mapping from classes to WNID. This is required for CIFAR10 and CIFAR100.
nbdt-wnids --dataset=CIFAR10
# Generate hierarchy, using the WNIDs. This is required for all datasets: CIFAR10, CIFAR100, TinyImagenet200
nbdt-hierarchy --method=wordnet --dataset=CIFAR10
```
See an example WordNet visualization. [click to expand]
We can generate a visualization with a slightly larger zoom and with WordNet IDs as sublabels. By default, the script builds the WordNet hierarchy for CIFAR10.
```bash
nbdt-hierarchy --method=wordnet --vis-zoom=1.25 --vis-sublabels
```
Generate random hierarchy. [click to expand]
Use `--method=random` to randomly generate a binary-ish hierarchy. Optionally, use the `--seed` (`--seed=-1` to *not* shuffle leaves) and `--branching-factor` flags. When debugging, we set branching factor to the number of classes. For example, the sanity check hierarchy for CIFAR10 is
```bash
nbdt-hierarchy --seed=-1 --branching-factor=10 --dataset=CIFAR10
```
## 2. Tree Supervision Loss
In the training commands below, we uniformly pass `--path-resume` (or `--pretrained`) and `--lr=0.01` to fine-tune instead of training from scratch. Our results using the recently state-of-the-art pretrained WideResNet checkpoint were obtained by fine-tuning. Run the following to fine-tune WideResNet with the soft tree supervision loss on CIFAR10:
```bash
python main.py --lr=0.01 --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --pretrained --loss=SoftTreeSupLoss
```
See how it works and how to configure. [click to expand]

The tree supervision loss has two variants: hard and soft. Pass `--loss=HardTreeSupLoss` or `--loss=SoftTreeSupLoss`, depending on which you want.
```bash
# fine-tune the wrn pretrained checkpoint on CIFAR10 with hard tree supervision loss
python main.py --lr=0.01 --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --pretrained --loss=HardTreeSupLoss
# fine-tune the wrn pretrained checkpoint on CIFAR10 with soft tree supervision loss
python main.py --lr=0.01 --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --pretrained --loss=SoftTreeSupLoss
```
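For intuition, the soft rules that `SoftTreeSupLoss` supervises can be sketched in a few lines of NumPy. This is a conceptual sketch, not the library's implementation: the 4-class tree, the weights, and all names below are hypothetical. Each node splits its probability mass over its children via a softmax of inner products between the featurized sample and each child's representative vector (the mean of the final fully-connected rows over leaves in that child's subtree); a leaf's probability is the product of the child probabilities along its root-to-leaf path.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, dim = 4, 8
W = rng.normal(size=(num_classes, dim))  # rows of the final fully-connected layer
x = rng.normal(size=dim)                 # featurized sample (penultimate layer output)

# Hypothetical induced hierarchy over 4 classes: root -> two inner nodes -> leaves.
tree = {"root": ["left", "right"], "left": [0, 1], "right": [2, 3]}

def leaves(node):
    """All leaf class indices beneath a node."""
    if isinstance(node, int):
        return [node]
    return [leaf for child in tree[node] for leaf in leaves(child)]

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def soft_path_probs(node="root", prob=1.0, out=None):
    """Soft decision rules: at each node, split probability mass over children
    via a softmax of inner products with each child's representative vector."""
    if out is None:
        out = np.zeros(num_classes)
    if isinstance(node, int):  # reached a leaf
        out[node] = prob
        return out
    reps = [W[leaves(child)].mean(axis=0) for child in tree[node]]
    for child, p in zip(tree[node], softmax(np.array([x @ r for r in reps]))):
        soft_path_probs(child, prob * p, out)
    return out

probs = soft_path_probs()          # a distribution over the 4 leaves
tree_sup_loss = -np.log(probs[2])  # cross-entropy term for a true label of 2
```

The soft tree supervision term is this cross-entropy over path probabilities, added to the usual cross-entropy on the raw logits; the hard variant instead applies a cross-entropy at each node along the path to the true label.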
To train from scratch, use `--lr=0.1` and do not pass the `--path-resume` or `--pretrained` flags. We fine-tune WideResNet on CIFAR10 and CIFAR100; where the baseline neural network accuracy is reproducible from scratch, we train from scratch.
## 3. Inference
As with the tree supervision loss, there are two inference variants: hard and soft. Below, we run soft inference on the model we just trained with the soft loss.
Run the following command to obtain these numbers.
```bash
python main.py --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --loss=SoftTreeSupLoss --eval --resume --analysis=SoftEmbeddedDecisionRules
```
See how it works and how to configure. [click to expand]

Note that the following commands are nearly identical to the corresponding training commands -- we drop the `--lr` and `--pretrained` flags and add `--resume`, `--eval`, and the `--analysis` type (hard or soft inference). The best results in our paper, oddly enough, were obtained by running *both* hard and soft inference on the neural network supervised by the soft tree supervision loss. This is reflected in the commands below.
```bash
# running soft inference on soft-supervised model
python main.py --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --loss=SoftTreeSupLoss --eval --resume --analysis=SoftEmbeddedDecisionRules
# running hard inference on soft-supervised model
python main.py --dataset=CIFAR10 --arch=wrn28_10_cifar10 --hierarchy=induced-wrn28_10_cifar10 --loss=SoftTreeSupLoss --eval --resume --analysis=HardEmbeddedDecisionRules
```
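For intuition, hard inference can be sketched as a greedy root-to-leaf traversal. Again, this is a conceptual sketch with a hypothetical 4-class tree, not the library's code: at each node, we follow the single child whose representative vector (the mean of its subtree's fully-connected rows) has the largest inner product with the featurized sample.

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(4, 8))  # rows of the final fully-connected layer
x = rng.normal(size=8)       # featurized sample
tree = {"root": ["left", "right"], "left": [0, 1], "right": [2, 3]}

def leaves(node):
    """All leaf class indices beneath a node."""
    if isinstance(node, int):
        return [node]
    return [leaf for child in tree[node] for leaf in leaves(child)]

def hard_inference(node="root"):
    """Greedily descend: at each node, take the best-scoring child."""
    while not isinstance(node, int):
        node = max(tree[node], key=lambda child: x @ W[leaves(child)].mean(axis=0))
    return node

prediction = hard_inference()  # a class index in 0..3
```

Soft inference instead keeps every branch, weighting each path by its probability, and predicts the leaf with the highest path probability.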
Logging maximum and minimum 'path entropy' samples. [click to expand]
```bash
# get min and max entropy samples for baseline neural network
python main.py --pretrained --dataset=TinyImagenet200 --eval --dataset-test=Imagenet1000 --disable-test-eval --analysis=TopEntropy # or Entropy, or TopDifference
# download public checkpoint
wget https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-TinyImagenet200-ResNet18-induced-ResNet18-SoftTreeSupLoss-tsw10.0.pth -O checkpoint/ckpt-TinyImagenet200-ResNet18-induced-ResNet18-SoftTreeSupLoss-tsw10.0.pth
# get min and max 'path entropy' samples for NBDT
python main.py --dataset TinyImagenet200 --resume --path-resume checkpoint/ckpt-TinyImagenet200-ResNet18-induced-ResNet18-SoftTreeSupLoss-tsw10.0.pth --eval --analysis NBDTEntropyMaxMin --dataset-test=Imagenet1000 --disable-test-eval --hierarchy induced-ResNet18
```
Running zero-shot evaluation on superclasses. [click to expand]
```bash
# get wnids for animal and vehicle -- use the outputted wnids for below commands
nbdt-wnids --classes animal vehicle
# evaluate CIFAR10-trained ResNet18 on "Animal vs. Vehicle" superclasses, with images from TinyImagenet200
python main.py --dataset-test=TinyImagenet200 --dataset=CIFAR10 --disable-test-eval --eval --analysis=Superclass --superclass-wnids n00015388 n04524313 --pretrained
# download public checkpoint
wget https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR100-ResNet18-induced-ResNet18-SoftTreeSupLoss.pth -O checkpoint/ckpt-CIFAR10-ResNet18-induced-SoftTreeSupLoss.pth
# evaluate CIFAR10-trained NBDT-ResNet18 on "Animal vs. Vehicle" superclasses, with images from TinyImagenet200
python main.py --dataset-test=TinyImagenet200 --dataset=CIFAR10 --disable-test-eval --eval --analysis=SuperclassNBDT --superclass-wnids n00015388 n04524313 --loss=SoftTreeSupLoss --resume
```
Visualize decision nodes using 'prototypical' samples. [click to expand]
```bash
# get wnids for animal and vehicle -- use the outputted wnids for below commands
nbdt-wnids --classes animal vehicle
# find samples representative for CIFAR10-trained ResNet18, from animal and vehicle ImageNet images
python main.py --dataset-test=Imagenet1000 --dataset=CIFAR10 --disable-test-eval --eval --analysis=VisualizeDecisionNode --vdnw=n00015388 --pretrained --superclass-wnids n00015388 n04524313 # samples for "animal" node
python main.py --dataset-test=Imagenet1000 --dataset=CIFAR10 --disable-test-eval --eval --analysis=VisualizeDecisionNode --vdnw=n00015388 --pretrained --superclass-wnids n00015388 n04524313 # samples for "ungulate" node
# download public checkpoint
wget https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR100-ResNet18-induced-ResNet18-SoftTreeSupLoss.pth -O checkpoint/ckpt-CIFAR10-ResNet18-induced-SoftTreeSupLoss.pth
# find samples representative for CIFAR10-trained NBDT with ResNet18 backbone, from animal and vehicle ImageNet images
python main.py --dataset-test=Imagenet1000 --dataset=CIFAR10 --disable-test-eval --eval --analysis=VisualizeDecisionNode --vdnw=n01466257 --loss=SoftTreeSupLoss --resume --hierarchy=induced-ResNet18 --superclass-wnids n00015388 n04524313 # samples for "animal" node
```
Visualize inference probabilities in hierarchy. [click to expand]
```bash
python main.py --analysis=VisualizeHierarchyInference --eval --pretrained # soft inference by default
```
# Results
We compare against all previous decision-tree-based methods that report results on CIFAR10, CIFAR100, and/or ImageNet, using the numbers reported in the original papers (except DNDF, which did not report CIFAR or ImageNet top-1 accuracy):
- Deep Neural Decision Forest (DNDF, updated with ResNet18)
- Explainable Observer-Classifier (XOC)
- Deep Convolutional Decision Jungle (DCDJ)
- Network of Experts (NofE)
- Deep Decision Network (DDN)
- Adaptive Neural Trees (ANT)
- Oblique Decision Trees (ODT)
- Classic Decision Trees
| | CIFAR10 | CIFAR100 | TinyImagenet200 | ImageNet |
|----------------------|---------|----------|-----------------|----------|
| NBDT (Ours) | 97.55% | 82.97% | 67.72% | 76.60% |
| Best Pre-NBDT Acc | 94.32% | 76.24% | 44.56% | 61.29% |
| Best Pre-NBDT Method | DNDF | NofE | DNDF | NofE |
| Our improvement | 3.23% | 6.73% | 23.16% | **15.31%** |
Our pretrained checkpoints (CIFAR10, CIFAR100, and TinyImagenet200) may deviate from these numbers by 0.1-0.2%, as we retrained all models for public release.
# Customize Repository for Your Application
As discussed above, you can use the `nbdt` python library to integrate NBDT training into any existing training pipeline, like ClassyVision ([ClassyVision + NBDT Imagenet example](https://github.com/alvinwan/neural-backed-decision-trees/tree/master/examples/imagenet)). However, if you wish to use the barebones training utilities here, refer to the following sections for adding custom models and datasets.
If you have not already, start by cloning the repository and installing all requirements. As a sample, we've included a copy of the WideResNet bash script adapted for ResNet18.
```bash
git clone git@github.com:alvinwan/neural-backed-decision-trees.git  # or use the HTTPS URL if you don't have SSH keys set up with GitHub
cd neural-backed-decision-trees
python setup.py develop
bash scripts/gen_train_eval_resnet.sh
```
For any model that has pretrained checkpoints for the datasets of interest (e.g., CIFAR10, CIFAR100, and ImageNet models from `pytorchcv` or ImageNet models from `torchvision`), modify `scripts/gen_train_eval_pretrained.sh`; it suffices to change the model name. For any model that does not have a pretrained checkpoint for the dataset of interest, modify `scripts/gen_train_eval_nopretrained.sh`.
## Models
Without any modifications to `main.py`, you can replace ResNet18 with your favorite network: pass any [`torchvision.models`](https://pytorch.org/docs/stable/torchvision/models.html) model or any [`pytorchcv`](https://github.com/osmr/imgclsmob/tree/master/pytorch) model to `--arch`, as we directly support both model zoos. Note that the former only supports models pretrained on ImageNet. The latter supports models pretrained on CIFAR10, CIFAR100, and ImageNet; for each dataset, the corresponding model name includes the dataset, e.g., `wrn28_10_cifar10`. However, neither supports models pretrained on TinyImagenet.
To add a new model from scratch:
1. Create a new file containing your network, such as `./nbdt/models/yournet.py`. This file should contain an `__all__` only exposing functions that return a model. These functions should accept `pretrained: bool` and `progress: bool`, then forward all other keyword arguments to the model constructor.
2. Expose your new file via `./nbdt/models/__init__.py`: `from .yournet import *`.
3. Train the original neural network on the target dataset. e.g., `python main.py --arch=yournet18`.
## Dataset
Without any modifications to `main.py`, you can use any image classification dataset found at [`torchvision.datasets`](https://pytorch.org/docs/stable/torchvision/datasets.html) by passing it to `--dataset`. To add a new dataset from scratch:
1. Create a new file containing your dataset, such as `./nbdt/data/yourdata.py`. Say the data class is `YourData10`. Like before, only expose the dataset class via `__all__`. This dataset class should support a `.classes` attribute which returns a list of human-readable class names.
2. Expose your new file via `./nbdt/data/__init__.py`: `from .yourdata import *`.
3. Modify `nbdt.utils.DATASETS` to include the name of your dataset, which is `YourData10` in this example.
4. Also in `nbdt/utils.py`, modify `DATASET_TO_NUM_CLASSES` and `DATASET_TO_CLASSES` to include your new dataset.
5. (Optional) Create a text file with wordnet IDs in `./nbdt/wnids/{dataset}.txt`. This list should be in the same order as your dataset's `.classes`. You may optionally use the `nbdt-wnids` utility to generate wnids (see the note below).
6. Train the original neural network on the target dataset. e.g., `python main.py --dataset=YourData10`
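As a sketch of steps 1 and 2 above, a minimal `./nbdt/data/yourdata.py` might look like this. The class name, constructor arguments, and random-tensor "images" are placeholders; the only contract assumed is the `__all__` export and the `.classes` attribute of human-readable names in a fixed order.

```python
import torch
from torch.utils.data import Dataset

__all__ = ("YourData10",)  # only expose the dataset class

class YourData10(Dataset):
    """Tiny stand-in dataset with 10 classes and random 32x32 'images'."""

    classes = [f"class{i}" for i in range(10)]  # human-readable names, fixed order

    def __init__(self, root="./data", train=True, transform=None, size=100):
        self.transform = transform
        self.images = torch.randn(size, 3, 32, 32)
        self.labels = torch.randint(len(self.classes), (size,))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image, label = self.images[index], int(self.labels[index])
        if self.transform is not None:
            image = self.transform(image)
        return image, label
```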
> **Note**: You may optionally use the utility `nbdt-wnids` to generate wnids:
> ```bash
> nbdt-wnids --dataset=YourData10
> ```
> where `YourData10` is your dataset name. If a provided class name from `YourData10.classes` does not exist in the WordNet corpus, the script will generate a fake wnid. This does not affect training, but subsequent analysis scripts will be unable to provide WordNet-imputed node meanings.
## Tests
To run tests, use the following command:
```bash
pytest nbdt tests
```
# Citation
If you find this work useful for your research, please cite our [paper](http://nbdt.alvinwan.com/paper/):
```
@misc{nbdt,
    title={NBDT: Neural-Backed Decision Trees},
    author={Alvin Wan and Lisa Dunlap and Daniel Ho and Jihan Yin and Scott Lee and Henry Jin and Suzanne Petryk and Sarah Adel Bargal and Joseph E. Gonzalez},
    year={2020},
    eprint={2004.00221},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```