# RETapp

**Repository Path**: hao203/RETapp

## Basic Information

- **Project Name**: RETapp
- **Description**: RETapp - a revised Gradio-based app for training models to predict diseases from retinal images (a project for automatic disease prediction on your own retinal image dataset)
- **Primary Language**: Python
- **License**: Not specified
- **Default Branch**: main
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2024-10-01
- **Last Updated**: 2024-10-14

## Categories & Tags

**Categories**: Uncategorized

**Tags**: None

## README

RETapp - a revised Gradio-based app for training models to predict diseases from retinal images.

## 🚀Introduction

This repository is a revised Gradio-based app for training models that predict diseases from retinal imaging (source: https://github.com/rmaphoh/RETFound_MAE). The aim is to let doctors easily train their own models on their own datasets of retinal images. We also updated the packages to be compatible with Python 3.8+, CUDA 11.7, and Ubuntu 22.04.

## Installation

First, install the dependencies with `pip install -r requirements.txt`.

Then install CUDA. Ubuntu 22.04 (64-bit) supports CUDA 11.7+; here we install CUDA 11.7 and pytorch==1.13.1+cu117:

```
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
# or
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
```

## How to use

1. Create a new folder for the dataset under `data/` in the repo. For example, we downloaded the OCTID dataset ([available here](BENCHMARK.md)), split it into three folders (train, val, and test), and organised each into five classes: ANormal, ARMD, CSR, Diabetic_retinopathy, Macular_Hole.

![split folder](./pic/split.png)
![label folder](./pic/label.png)

The structure should look like this (a helper script for producing this layout is sketched below):

```
├── data folder
    ├── train
        ├── class_a
        ├── class_b
        ├── class_c
    ├── val
        ├── class_a
        ├── class_b
        ├── class_c
    ├── test
        ├── class_a
        ├── class_b
        ├── class_c
```

> Note: the folder names must match the class names.
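If your images start out in a single folder per class, the split can be scripted. The following is a minimal sketch, not part of RETapp; the source path `data/OCTID_raw` and the 80/10/10 ratios are assumptions to adjust for your own dataset.

```python
import random
import shutil
from pathlib import Path

# Hypothetical paths -- adjust for your own dataset.
SRC = Path("data/OCTID_raw")         # one sub-folder per class
DST = Path("data/OCTID")             # will receive train/ val/ test/
RATIOS = {"train": 0.8, "val": 0.1}  # the remainder goes to test

random.seed(42)  # reproducible split

for class_dir in sorted(p for p in SRC.iterdir() if p.is_dir()):
    images = sorted(p for p in class_dir.glob("*") if p.is_file())
    random.shuffle(images)
    n_train = int(len(images) * RATIOS["train"])
    n_val = int(len(images) * RATIOS["val"])
    splits = {
        "train": images[:n_train],
        "val": images[n_train:n_train + n_val],
        "test": images[n_train + n_val:],
    }
    for split, files in splits.items():
        out_dir = DST / split / class_dir.name  # folder name == class name
        out_dir.mkdir(parents=True, exist_ok=True)
        for f in files:
            shutil.copy2(f, out_dir / f.name)
```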
2. Training

Run `python train_web.py` to launch the Gradio app and enter the parameters.

![train](./pic/train.png)

The parameters are listed below:

- `--batch_size 16`
- `--world_size 1`
- `--model vit_large_patch16`
- `--epochs 50`
- `--blr 0`
- `--layer_decay 0.65`
- `--weight_decay 0.05`
- `--nb_classes 5` (number of classes)
- `--data_path ./data/OCTID/`
- `--task ./finetune_OCTID/` (path to the task folder, which holds metrics and checkpoints)
- `--finetune ./models/RETFound_cfp_weights.pth` (path to the pretrained weights)
- `--input_size 224`
- `--drop_path 0.1`
- `--device cuda`

You can follow the training progress (TensorBoard) and the log in the `task` folder.

![result](./pic/result.png)

3. Prediction

Modify the parameters in `app.py`. This one selects the fine-tuned model:

![app modi](./pic/modifine.png)

This one selects the task you want and the base model:

![task aim](./pic/basemodel.png)

Run `python app.py` to launch the Gradio app (after modifying the parameters):

![appui](./pic/appui.png)

Enjoy!!!

---

**The following is the original readme file from the official repo:**

---

## RETFound - A foundation model for retinal imaging

Official repo for [RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae).

Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.

A Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE).

### 📝Key features

- RETFound is pre-trained on 1.6 million retinal images with self-supervised learning
- RETFound has been validated in multiple disease detection tasks
- RETFound can be efficiently adapted to customised tasks

### 🎉News

- 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/RETFound_Feature.ipynb) is now online!
- 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online!
- 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation!
- 2023/09: a [visualisation demo](https://github.com/rmaphoh/RETFound_MAE/blob/main/RETFound_visualize.ipynb) was added
- 2023/10: changed the hyperparameter [input_size](https://github.com/rmaphoh/RETFound_MAE#:~:text=finetune%20./RETFound_cfp_weights.pth%20%5C-,%2D%2Dinput_size%20224,-For%20evaluation%20only) to support any image size

### 🔧Install environment

1. Create environment with conda:

```
conda create -n retfound python=3.7.5 -y
conda activate retfound
```

2. Install dependencies:

```
git clone https://github.com/rmaphoh/RETFound_MAE/
cd RETFound_MAE
pip install -r requirement.txt
```
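Before moving on to fine-tuning, it is worth checking that the installed PyTorch build can actually see your GPU. A quick check with plain PyTorch (nothing repo-specific):

```python
import torch

print("torch version:", torch.__version__)          # e.g. 1.13.1+cu117
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
```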
### 🌱Fine-tuning with RETFound weights

To fine-tune RETFound on your own data, follow these steps:

1. Download the RETFound pre-trained weights:

| Model | Pre-training data | Weights |
| :-- | :-- | :-- |
| ViT-Large | Colour fundus image | download |
| ViT-Large | OCT | download |
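Optionally, sanity-check the downloaded checkpoint before continuing. As in the loading snippet at the end of this readme, the file stores the pre-trained ViT weights under the `model` key; this sketch assumes the colour fundus checkpoint was saved as `RETFound_cfp_weights.pth` in the working directory.

```python
import torch

# Assumes the colour fundus checkpoint sits in the current directory.
checkpoint = torch.load("RETFound_cfp_weights.pth", map_location="cpu")
state_dict = checkpoint["model"]  # pre-trained ViT weights live under 'model'

print(f"{len(state_dict)} parameter tensors")
for name, tensor in list(state_dict.items())[:5]:  # peek at the first few
    print(f"  {name}: {tuple(tensor.shape)}")
```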
2. Organise your data into this directory structure (public datasets used in this study can be [downloaded here](BENCHMARK.md)):

```
├── data folder
    ├── train
        ├── class_a
        ├── class_b
        ├── class_c
    ├── val
        ├── class_a
        ├── class_b
        ├── class_c
    ├── test
        ├── class_a
        ├── class_b
        ├── class_c
```

3. Start fine-tuning (using IDRiD as an example). A fine-tuned checkpoint will be saved during training, and evaluation will be run after training:

```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
    --batch_size 16 \
    --world_size 1 \
    --model vit_large_patch16 \
    --epochs 50 \
    --blr 5e-3 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.2 \
    --nb_classes 5 \
    --data_path ./IDRiD_data/ \
    --task ./finetune_IDRiD/ \
    --finetune ./RETFound_cfp_weights.pth \
    --input_size 224
```

4. For evaluation only (download the data and model checkpoints [here](BENCHMARK.md); change the paths below):

```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
    --eval --batch_size 16 \
    --world_size 1 \
    --model vit_large_patch16 \
    --epochs 50 \
    --blr 5e-3 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.2 \
    --nb_classes 5 \
    --data_path ./IDRiD_data/ \
    --task ./internal_IDRiD/ \
    --resume ./finetune_IDRiD/checkpoint-best.pth \
    --input_size 224
```

### Load the model and weights (if you want to call the model in your code)

```python
import torch
import models_vit
from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_

# build the model
model = models_vit.__dict__['vit_large_patch16'](
    num_classes=2,
    drop_path_rate=0.2,
    global_pool=True,
)

# load RETFound weights
checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()

# drop the classification head if its shape does not match the new task
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]

# interpolate position embedding
interpolate_pos_embed(model, checkpoint_model)

# load pre-trained model
msg = model.load_state_dict(checkpoint_model, strict=False)
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}

# manually initialize fc layer
trunc_normal_(model.head.weight, std=2e-5)

print("Model = %s" % str(model))
```

### 📃Citation

If you find this repository useful, please consider citing this paper:

```
@article{zhou2023foundation,
  title={A foundation model for generalizable disease detection from retinal images},
  author={Zhou, Yukun and Chia, Mark A and Wagner, Siegfried K and Ayhan, Murat S and Williamson, Dominic J and Struyven, Robbert R and Liu, Timing and Xu, Moucheng and Lozano, Mateo G and Woodward-Court, Peter and others},
  journal={Nature},
  volume={622},
  number={7981},
  pages={156--163},
  year={2023},
  publisher={Nature Publishing Group UK London}
}
```