# big_transfer

**Repository Path**: shi4180/big_transfer

## Basic Information

- **Project Name**: big_transfer
- **Description**: Medical image transfer learning
- **Primary Language**: Python
- **License**: Apache-2.0
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2021-06-04
- **Last Updated**: 2021-06-04

## Categories & Tags

**Categories**: Uncategorized
**Tags**: None

## README

## Big Transfer (BiT): General Visual Representation Learning

*by Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly, Neil Houlsby*

**Update 08/02/2021:** We also release ALL BiT-M models fine-tuned on ALL 19 VTAB-1k datasets, see below.

## Introduction

In this repository we release multiple models from the [Big Transfer (BiT): General Visual Representation Learning](https://arxiv.org/abs/1912.11370) paper that were pre-trained on the [ILSVRC-2012](http://www.image-net.org/challenges/LSVRC/2012/) and [ImageNet-21k](http://www.image-net.org/) datasets. We provide the code for fine-tuning the released models in the major deep learning frameworks [TensorFlow 2](https://www.tensorflow.org/), [PyTorch](https://pytorch.org/) and [Jax](https://jax.readthedocs.io/en/latest/index.html)/[Flax](http://flax.readthedocs.io).

We hope that the computer vision community will benefit from employing more powerful ImageNet-21k pretrained models, as opposed to conventional models pre-trained on the ILSVRC-2012 dataset.

We also provide colabs for more exploratory, interactive use: a [TensorFlow 2 colab](https://colab.research.google.com/github/google-research/big_transfer/blob/master/colabs/big_transfer_tf2.ipynb), a [PyTorch colab](https://colab.research.google.com/github/google-research/big_transfer/blob/master/colabs/big_transfer_pytorch.ipynb), and a [Jax colab](https://colab.research.google.com/github/google-research/big_transfer/blob/master/colabs/big_transfer_jax.ipynb).

## Installation

Make sure you have `Python>=3.6` installed on your machine. To set up [Tensorflow 2](https://github.com/tensorflow/tensorflow), [PyTorch](https://github.com/pytorch/pytorch) or [Jax](https://github.com/google/jax), follow the instructions provided in the corresponding repository linked here.

In addition, install the Python dependencies by running (select `tf2`, `pytorch` or `jax` in the command below):

```
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
```

## How to fine-tune BiT

First, download the BiT model. We provide models pre-trained on ILSVRC-2012 (BiT-S) or ImageNet-21k (BiT-M) for 5 different architectures: ResNet-50x1, ResNet-101x1, ResNet-50x3, ResNet-101x3, and ResNet-152x4.

For example, if you would like to download the ResNet-50x1 pre-trained on ImageNet-21k, run the following command:

```
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
```

Other models can be downloaded accordingly by plugging the name of the model (BiT-S or BiT-M) and the architecture into the above command. Note that we provide models in two formats: `npz` (for PyTorch and Jax) and `h5` (for TF2). By default we expect model weights to be stored in the root folder of this repository.

Then, you can fine-tune the downloaded model on your dataset of interest in any of the three frameworks. All frameworks share the command line interface

```
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
```
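If you prefer to drive fine-tuning from your own Python code rather than the command line, the sketch below shows one way to build a model and load the downloaded `.npz` weights in PyTorch. The module layout (`bit_pytorch.models`, its `KNOWN_MODELS` registry, and the `load_from` method) is an assumption based on this repository's code; check `bit_pytorch/models.py` if the names differ.

```
# Minimal sketch: build a BiT model and load pre-trained weights in PyTorch.
# Assumes bit_pytorch/models.py from this repository is importable and that
# BiT-M-R50x1.npz has been downloaded to the repository root (see above).
import numpy as np
import torch

from bit_pytorch import models  # assumption: module layout as in this repository

# head_size is the number of classes in *your* dataset; zero_head re-initializes
# the classification head for fine-tuning.
model = models.KNOWN_MODELS["BiT-M-R50x1"](head_size=10, zero_head=True)
model.load_from(np.load("BiT-M-R50x1.npz"))

# Forward pass on a dummy batch at 128x128 (roughly what BiT-HyperRule uses for
# CIFAR-sized inputs; any reasonable resolution works for a smoke test).
logits = model(torch.zeros(1, 3, 128, 128))
print(logits.shape)  # expected: torch.Size([1, 10])
```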
Currently, all frameworks will automatically download the CIFAR-10 and CIFAR-100 datasets. Other public or custom datasets can be easily integrated: in TF2 and JAX we rely on the extensible [tensorflow datasets library](https://github.com/tensorflow/datasets/). In PyTorch, we use [torchvision's data input pipeline](https://pytorch.org/docs/stable/torchvision/index.html).

Note that our code uses all available GPUs for fine-tuning.

We also support training in the low-data regime: the `--examples_per_class` option randomly draws K samples per class for training.

To see a detailed list of all available flags, run `python3 -m bit_{pytorch|jax|tf2}.train --help`.

### BiT-M models fine-tuned on ILSVRC-2012

For convenience, we provide BiT-M models that were already fine-tuned on the ILSVRC-2012 dataset. These models can be downloaded by adding the `-ILSVRC2012` suffix, e.g.

```
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
```

### Available architectures

We release all architectures mentioned in the paper, so that you can choose between accuracy and speed: R50x1, R101x1, R50x3, R101x3, R152x4. In the above path to the model file, simply replace `R50x1` by your architecture of choice.

We further investigated more architectures after the paper's publication and found R152x2 to have a nice trade-off between speed and accuracy, hence we also include this in the release and provide a few numbers below.

### BiT-M models fine-tuned on the 19 VTAB-1k tasks

We also release the fine-tuned models for each of the 19 tasks included in the VTAB-1k benchmark. We ran each model three times and release each of these runs. This means we release a total of 5x19x3=285 models, and hope these can be useful in further analysis of transfer learning.

The files can be downloaded via the following pattern:

```
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
```

We did not convert these models to TF2 (hence there is no corresponding `.h5` file); however, we also uploaded [TFHub](http://tfhub.dev) models, which can be used in both TF1 and TF2. An example sequence of commands for downloading one such model is:

```
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
```
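As a quick sanity check in TF2, the sketch below loads the SavedModel directory assembled above with `tensorflow_hub`. The input convention (images scaled to [0, 1]) and the meaning of the outputs are assumptions about these fine-tuned modules; inspect the module's signatures if unsure.

```
# Minimal sketch: load a fine-tuned BiT TF-Hub module in TF2.
# Assumes the BiT-M-R50x1-run0-caltech101.tfhub directory was assembled as shown above.
import numpy as np
import tensorflow_hub as hub

module = hub.KerasLayer("BiT-M-R50x1-run0-caltech101.tfhub")

# Dummy batch of images in [0, 1]. BiT accepts flexible input resolutions;
# 224x224 is used here purely as a smoke test.
images = np.zeros([1, 224, 224, 3], dtype=np.float32)
outputs = module(images)  # task outputs for the fine-tuned dataset (assumption)
print(outputs.shape)
```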
### Hyper-parameters

For reproducibility, our training script uses the hyper-parameters (BiT-HyperRule) that were used in the original paper. Note, however, that BiT models were trained and fine-tuned using Cloud TPU hardware, so for a typical GPU setup our default hyper-parameters could require too much memory or result in very slow progress. Moreover, BiT-HyperRule is designed to generalize across many datasets, so it is typically possible to devise more efficient application-specific hyper-parameters. Thus, we encourage the user to try lighter-weight settings, as they require far fewer resources and often result in similar accuracy.

For example, we tested our code using an 8xV100 GPU machine on the CIFAR-10 and CIFAR-100 datasets, while reducing the batch size from 512 to 128 and the learning rate from 0.003 to 0.001. This setup resulted in nearly identical performance (see [Expected results](#expected-results) below) in comparison to BiT-HyperRule, despite being less computationally demanding.

Below, we provide more suggestions on how to optimize our paper's setup.

### Tips for optimizing memory or speed

The default BiT-HyperRule was developed on Cloud TPUs and is quite memory-hungry. This is mainly due to the large batch size (512) and image resolution (up to 480x480). Here are some tips if you are running out of memory:

1. In `bit_hyperrule.py` we specify the input resolution. By reducing it, one can save a lot of memory and compute, at the expense of accuracy.
2. The batch size can be reduced in order to reduce memory consumption. However, one then also needs to adjust the learning rate and schedule (steps) in order to maintain the desired accuracy.
3. The PyTorch codebase supports a batch-splitting technique ("micro-batching") via the `--batch_split` option. For example, running the fine-tuning with `--batch_split 8` reduces the memory requirement by a factor of 8 (see the sketch after this list).
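To make the micro-batching idea in tip 3 concrete, here is a generic gradient-accumulation sketch in PyTorch. It only illustrates the technique; it is not the repository's exact `--batch_split` implementation, and the model, loss, and optimizer are placeholders supplied by the caller.

```
# Generic gradient-accumulation ("micro-batching") sketch in PyTorch.
# Splitting each batch into batch_split chunks keeps the effective batch size
# (and thus the learning-rate schedule) unchanged while cutting peak memory.
import torch

def train_step(model, criterion, optimizer, images, labels, batch_split=8):
    """One optimizer step over a full batch, processed in batch_split chunks."""
    optimizer.zero_grad()
    for img_chunk, lbl_chunk in zip(images.chunk(batch_split),
                                    labels.chunk(batch_split)):
        logits = model(img_chunk)
        # Scale the loss so the accumulated gradient matches a full-batch step.
        loss = criterion(logits, lbl_chunk) / batch_split
        loss.backward()  # gradients accumulate across chunks
    optimizer.step()
```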
## Expected results

We verified that when using the BiT-HyperRule, the code in this repository reproduces the paper's results.

### CIFAR results (few-shot and full)

For these common benchmarks, the aforementioned changes to the BiT-HyperRule (`--batch 128 --base_lr 0.001`) lead to the following, very similar results. The table shows the min←**median**→max result of at least five runs. **NOTE**: This is not a comparison of frameworks, just evidence that all code-bases can be trusted to reproduce results.

#### BiT-M-R101x3

| Dataset | Ex/cls | TF2 | Jax | PyTorch |
| :--- | :---: | :---: | :---: | :---: |
| CIFAR10 | 1 | 52.5 ← **55.8** → 60.2 | 48.7 ← **53.9** → 65.0 | 56.4 ← **56.7** → 73.1 |
| CIFAR10 | 5 | 85.3 ← **87.2** → 89.1 | 80.2 ← **85.8** → 88.6 | 84.8 ← **85.8** → 89.6 |
| CIFAR10 | full | **98.5** | **98.4** | 98.5 ← **98.6** → 98.6 |
| CIFAR100 | 1 | 34.8 ← **35.7** → 37.9 | 32.1 ← **35.0** → 37.1 | 31.6 ← **33.8** → 36.9 |
| CIFAR100 | 5 | 68.8 ← **70.4** → 71.4 | 68.6 ← **70.8** → 71.6 | 70.6 ← **71.6** → 71.7 |
| CIFAR100 | full | **90.8** | **91.2** | 91.1 ← **91.2** → 91.4 |

#### BiT-M-R152x2

| Dataset | Ex/cls | Jax | PyTorch |
| :--- | :---: | :---: | :---: |
| CIFAR10 | 1 | 44.0 ← **56.7** → 65.0 | 50.9 ← **55.5** → 59.5 |
| CIFAR10 | 5 | 85.3 ← **87.0** → 88.2 | 85.3 ← **85.8** → 88.6 |
| CIFAR10 | full | **98.5** | 98.5 ← **98.5** → 98.6 |
| CIFAR100 | 1 | 36.4 ← **37.2** → 38.9 | 34.3 ← **36.8** → 39.0 |
| CIFAR100 | 5 | 69.3 ← **70.5** → 72.0 | 70.3 ← **72.0** → 72.3 |
| CIFAR100 | full | **91.2** | 91.2 ← **91.3** → 91.4 |

(TF2 models not yet available.)

#### BiT-M-R50x1

| Dataset | Ex/cls | TF2 | Jax | PyTorch |
| :--- | :---: | :---: | :---: | :---: |
| CIFAR10 | 1 | 49.9 ← **54.4** → 60.2 | 48.4 ← **54.1** → 66.1 | 45.8 ← **57.9** → 65.7 |
| CIFAR10 | 5 | 80.8 ← **83.3** → 85.5 | 76.7 ← **82.4** → 85.4 | 80.3 ← **82.3** → 84.9 |
| CIFAR10 | full | **97.2** | **97.3** | **97.4** |
| CIFAR100 | 1 | 35.3 ← **37.1** → 38.2 | 32.0 ← **35.2** → 37.8 | 34.6 ← **35.2** → 38.6 |
| CIFAR100 | 5 | 63.8 ← **65.0** → 66.5 | 63.4 ← **64.8** → 66.5 | 64.7 ← **65.5** → 66.0 |
| CIFAR100 | full | **86.5** | **86.4** | **86.6** |

### ImageNet results

These results were obtained using BiT-HyperRule. However, because this implies a large batch size and a large resolution, memory can be an issue. The PyTorch code supports batch-splitting, and hence we can still run things there without resorting to Cloud TPUs by adding the `--batch_split N` flag, where `N` is a power of two. For instance, the following command produces a validation accuracy of `80.68` on a machine with 8 V100 GPUs:

```
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
```

Increase further to `--batch_split 8` when running with 4 V100 GPUs, etc. Full results achieved that way in some test runs were:

| Ex/cls | R50x1 | R152x2 | R101x3 |
| :---: | :---: | :---: | :---: |
| 1 | 18.36 | 24.5 | 25.55 |
| 5 | 50.64 | 64.5 | 64.18 |
| full | 80.68 | 85.15 | WIP |

### VTAB-1k results

These are re-runs and not the exact paper models. The expected VTAB scores for two of the models are:

| Model | Full | Natural | Structured | Specialized |
| :--- | :---: | :---: | :---: | :---: |
| BiT-M-R152x4 | 73.51 | 80.77 | 61.08 | 85.67 |
| BiT-M-R101x3 | 72.65 | 80.29 | 59.40 | 85.75 |

## Out of context dataset

In Appendix G of our paper, we investigate whether BiT improves out-of-context robustness. To do this, we created a dataset comprising foreground objects corresponding to 21 ILSVRC-2012 classes pasted onto 41 miscellaneous backgrounds.

To download the dataset, run

```
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
```

Images from each of the 21 classes are kept in a directory with the name of the class.
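Since the archive follows the standard one-directory-per-class layout, it can be read with, for example, `torchvision.datasets.ImageFolder`. The extracted directory name below is an assumption about how the zip unpacks; adjust it to the actual path.

```
# Minimal sketch: load the out-of-context images with torchvision.
# The extracted directory name is an assumption; point `root` at the folder
# that contains the 21 per-class subdirectories.
from torchvision import datasets, transforms

dataset = datasets.ImageFolder(
    root="bit_out_of_context_dataset",
    transform=transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]),
)
print(len(dataset), dataset.classes[:5])
```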