PyTorch implementation of a standard 3D U-Net, based on:
"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" by Özgün Çiçek et al.,
as well as a Residual 3D U-Net, based on:
"Superhuman Accuracy on the SNEMI3D Connectomics Challenge" by Kisuk Lee et al.
Set up a new conda environment with the required dependencies via:

```
conda create -n 3dunet pytorch torchvision tensorboardx h5py scipy scikit-image pyyaml pytest -c conda-forge -c pytorch
```

Activate the newly created conda environment via:

```
source activate 3dunet
```
To train the standard 3D U-Net, set `name: UNet3D` in the `model` section of the config file; to train the Residual 3D U-Net, set `name: ResidualUNet3D` in the `model` section of the config file.
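As a minimal sketch, the corresponding config entry might look like this (only the `name` key comes from the text above; any additional model parameters in the shipped configs are omitted):

```yaml
model:
  # switch to ResidualUNet3D to train the residual variant
  name: UNet3D
```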
For a detailed explanation of the loss functions used, see: "Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations" by Carole H. Sudre, Wenqi Li, Tom Vercauteren, Sebastien Ourselin, and M. Jorge Cardoso.
For the weighted losses, class weights can be passed via `weight: [w_1, ..., w_k]` in the `loss` section of the config (see the sketch below).
Note: use GeneralizedDiceLoss only if the labels in the training dataset are very imbalanced, e.g. one class having at least 3 orders of magnitude more voxels than the others; otherwise use the standard DiceLoss, which works better than GDL most of the time. If no evaluation metric is specified in the config, MeanIoU will be used by default.
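As an illustrative sketch (not copied verbatim from the shipped configs), a loss section with class weights and an explicit evaluation metric might look like:

```yaml
loss:
  name: CrossEntropyLoss
  # optional per-class weights, one entry per output class
  weight: [0.3, 0.7]

eval_metric:
  # MeanIoU is also the default if this section is omitted
  name: MeanIoU
```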
E.g., to fit a randomly generated 3D volume and a random segmentation mask from random_label3D.h5, run:

```
python train.py --config resources/train_config_ce.yaml # train with CrossEntropyLoss
```

or:

```
python train.py --config resources/train_config_dice.yaml # train with DiceLoss
```

See train_config_ce.yaml for more info.
In order to train on your own data, just provide the paths to your HDF5 training and validation datasets in train_config_ce.yaml. The HDF5 files should contain the raw/label datasets in the following axis order: DHW (in the case of 3D) or CDHW (in the case of 4D).
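As a minimal Python sketch of preparing such a file with h5py (the `raw`/`label` dataset names come from the text above; shapes and the output filename are arbitrary):

```python
import h5py
import numpy as np

# random 3D volume in DHW order and a matching integer segmentation mask
raw = np.random.rand(64, 128, 128).astype(np.float32)                 # D, H, W
label = np.random.randint(0, 2, size=(64, 128, 128), dtype=np.int64)  # D, H, W

with h5py.File('my_train_data.h5', 'w') as f:
    f.create_dataset('raw', data=raw)
    f.create_dataset('label', data=label)
```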
Monitor progress with TensorBoard:

```
tensorboard --logdir ./3dunet/logs/ --port 8666
```

(you need `tensorflow` installed in your conda env).
In order to train with `BCEWithLogitsLoss`, `DiceLoss` or `GeneralizedDiceLoss`, the label data has to be 4D (one target binary mask per channel). If you have 3D binary data (foreground/background), you can just change the `ToTensor` transform for the label to contain `expand_dims: true`; see e.g. train_config_dice.yaml, and the sketch below.
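As a rough sketch, the label transform entry would then contain something like the following (the exact nesting around this snippet follows the shipped train_config_dice.yaml and may differ):

```yaml
label:
  # add a singleton channel axis so a 3D binary mask becomes 1 x D x H x W
  - name: ToTensor
    expand_dims: true
```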
When training with binary-based losses (`BCEWithLogitsLoss`, `DiceLoss`, `GeneralizedDiceLoss`), `final_sigmoid=True` has to be present in the training config, since every output channel gives the probability of the foreground. When training with cross-entropy-based losses (`WeightedCrossEntropyLoss`, `CrossEntropyLoss`, `PixelWiseCrossEntropyLoss`), set `final_sigmoid=False` so that `Softmax` normalization is applied to the output.
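A minimal sketch of the pairing, assuming `final_sigmoid` sits in the `model` section of the config:

```yaml
# for binary-based losses (BCEWithLogitsLoss, DiceLoss, GeneralizedDiceLoss):
model:
  final_sigmoid: true   # each output channel is a foreground probability (sigmoid)
# cross-entropy-based losses would instead use final_sigmoid: false (softmax over channels)
```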
Test on a randomly generated 3D volume (just for demonstration purposes) from random_label3D.h5:

```
python predict.py --config resources/test_config_ce.yaml
```

or, if you trained with `DiceLoss`:

```
python predict.py --config resources/test_config_dice.yaml
```

Prediction masks will be saved to resources/random_label3D_probabilities.h5.
In order to predict on your own raw dataset, provide the path to your model as well as the paths to the HDF5 test datasets in test_config_ce.yaml. In order to avoid block artifacts in the output prediction masks, the patch predictions are averaged, so make sure that the `patch`/`stride` params lead to overlapping blocks, e.g. `patch: [64, 128, 128]` with `stride: [32, 96, 96]` will give you a 'halo' of 32 voxels in each direction (see the sketch below).
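A rough sketch of those test-config entries (the `model_path` key, its value, and the nesting are assumptions for illustration; only the patch/stride values come from the text):

```yaml
model_path: 3dunet/best_checkpoint.pytorch  # hypothetical path to the trained model
# overlap per axis = patch - stride = [32, 32, 32], i.e. a 32-voxel halo
patch: [64, 128, 128]
stride: [32, 96, 96]
```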
If you want to contribute back, please make a pull request.
If you use this code for your research, please cite as:
Adrian Wolny. (2019, May 7). wolny/pytorch-3dunet: PyTorch implementation of 3D U-Net (Version v1.0.0). Zenodo. http://doi.org/10.5281/zenodo.2671581