PyTorch implementation of the 3D U-Net and its variants:
- Standard 3D U-Net, based on 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation (Özgün Çiçek et al.)
- Residual 3D U-Net, based on Superhuman Accuracy on the SNEMI3D Connectomics Challenge (Kisuk Lee et al.)
The code allows for training the U-Net for both semantic segmentation (binary and multi-class) and regression problems (e.g. de-noising, learning deconvolutions).
Training the standard 2D U-Net is also possible; see train_config_2d for an example configuration. Just make sure to keep the singleton z-dimension in your H5 dataset (i.e. (1, Y, X) instead of (Y, X)), because data loading / data augmentation always requires tensors of rank 3.
The package has not been tested on Windows, however some users have reported running it on Windows. One thing to keep in mind: when training with CrossEntropyLoss, the label type in the config file should be changed from long to int64, otherwise the following error will be raised: RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target'.
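A hedged sketch of where the label dtype is typically set: in the example configs it is a parameter of the label ToTensor transform inside the loaders section; this placement is an assumption, so adjust it to your own config file:

label:
  - name: ToTensor
    expand_dims: false
    dtype: int64   # use int64 instead of long when training on Windows with CrossEntropyLoss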
The following loss functions can be specified in the loss section of the config:
- DiceLoss: defined as 1 - DiceCoefficient; used for binary semantic segmentation. When more than 2 classes are present in the ground truth, it computes the DiceLoss per channel and averages the values.
- BCEDiceLoss: linear combination of BCE and Dice losses, i.e. alpha * BCE + beta * Dice; alpha and beta can be specified in the loss section of the config.
- CrossEntropyLoss: class weights can be specified via weight: [w_1, ..., w_k] in the loss section of the config.
- WeightedCrossEntropyLoss: class weights can be specified via weight: [w_1, ..., w_k] in the loss section of the config.
- GeneralizedDiceLoss: class weights can be specified via weight: [w_1, ..., w_k] in the loss section of the config. Note: use this loss function only if the labels in the training dataset are very imbalanced, e.g. one class having at least 3 orders of magnitude more voxels than the others; otherwise use the standard DiceLoss.

For a detailed explanation of some of the supported loss functions see: Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations, Carole H. Sudre, Wenqi Li, Tom Vercauteren, Sebastien Ourselin, M. Jorge Cardoso.
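A minimal sketch of the loss section for two of the losses above (the name / alpha / beta / weight keys follow the descriptions in this README and the example configs in resources/; the numeric values are only illustrative):

loss:
  name: BCEDiceLoss
  alpha: 1.0      # weight of the BCE term
  beta: 1.0       # weight of the Dice term

# or, with per-class weights for a cross-entropy based loss:
# loss:
#   name: CrossEntropyLoss
#   weight: [0.2, 0.8]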
If no evaluation metric is specified in the config, MeanIoU will be used by default.
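For example, to select the metric explicitly (the eval_metric section name is assumed from the example configs in resources/):

eval_metric:
  name: MeanIoU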
The easiest way to install the pytorch-3dunet package is via conda:
conda create -n 3dunet -c conda-forge -c awolny python=3.7 pytorch-3dunet
conda activate 3dunet
After installation the following commands are accessible within the conda environment: train3dunet for training the network and predict3dunet for prediction (see below).
One can also install directly from source:
python setup.py install
Make sure that the installed pytorch is compatible with your CUDA version, otherwise training/prediction will fail to run on the GPU. You can re-install a pytorch build compatible with your CUDA version in the 3dunet env by:
conda install -c pytorch torchvision cudatoolkit=<YOUR_CUDA_VERSION> pytorch
Given that the pytorch-3dunet package was installed via conda as described above, one can train the network by simply invoking:
train3dunet --config <CONFIG>
where CONFIG
is the path to a YAML configuration file, which specifies all aspects of the training procedure.
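A hedged sketch of the top-level structure of such a configuration file (section and key names are taken from the example configs in resources/; the values below are placeholders, consult the example files for the full set of options):

model:
  name: UNet3D            # or ResidualUNet3D
  in_channels: 1
  out_channels: 2
  final_sigmoid: false    # see the notes on loss functions below
loss:
  name: CrossEntropyLoss
optimizer:
  learning_rate: 0.0002
trainer:
  checkpoint_dir: /path/to/checkpoint_dir   # placeholder path
loaders:
  # paths to the training/validation HDF5 files and data augmentation (see below)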
See e.g. train_config_ce.yaml which describes how to train a standard 3D U-Net on a randomly generated 3D volume and random segmentation mask (random_label3D.h5) with cross-entropy loss (just a demo).
In order to train on your own data just provide the paths to your HDF5 training and validation datasets in the train_config_ce.yaml.
The HDF5 files should contain the raw/label data sets in the following axis order: DHW (in case of 3D), CDHW (in case of 4D).
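A hedged sketch of the data-loading part of the config (file paths and internal dataset names are placeholders; key names are assumed from the example configs and may differ slightly between versions):

loaders:
  raw_internal_path: raw      # name of the raw dataset inside the HDF5 files
  label_internal_path: label  # name of the label dataset inside the HDF5 files
  train:
    file_paths:
      - /path/to/my_train_volume.h5   # datasets stored in DHW (3D) or CDHW (4D) order
  val:
    file_paths:
      - /path/to/my_val_volume.h5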
One can monitor the training progress with Tensorboard: tensorboard --logdir <checkpoint_dir>/logs/ (you need tensorflow installed in your conda env), where checkpoint_dir is the path to the checkpoint directory specified in the config.
To try out training on randomly generated data right away, just check out the repository and run:
cd pytorch3dunet
train3dunet --config ../resources/train_config_ce.yaml # train with CrossEntropyLoss (segmentation)
#train3dunet --config ../resources/train_config_dice.yaml # train with DiceLoss (segmentation)
#train3dunet --config ../resources/train_config_regression.yaml # train with SmoothL1Loss (regression)
To try out a boundary prediction task given a sample 3D confocal volume of plant cells (cell membrane marker), run:
cd pytorch3dunet
train3dunet --config ../resources/train_boundary.yaml
When training with binary-based losses, i.e. BCEWithLogitsLoss, DiceLoss, BCEDiceLoss, GeneralizedDiceLoss:
- the ToTensor transform for the label has to contain expand_dims: true, see e.g. train_config_dice.yaml;
- final_sigmoid=True has to be present in the model section of the config, since every output channel gives the probability of the foreground.
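A hedged sketch of these two settings for the binary-loss case (the transformer nesting under the loaders section is an assumption; see train_config_dice.yaml for the full file):

model:
  final_sigmoid: true          # each output channel is a sigmoid-activated foreground probability
loaders:
  train:
    transformer:
      label:
        - name: ToTensor
          expand_dims: true    # add a channel dimension to the 3D label volume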
When training with cross-entropy based losses (WeightedCrossEntropyLoss, CrossEntropyLoss, PixelWiseCrossEntropyLoss), set final_sigmoid=False so that Softmax normalization is applied to the output.

Given that the pytorch-3dunet
package was installed via conda as described above, one can run the prediction via:
predict3dunet --config <CONFIG>
To run the prediction on the randomly generated 3D volume (just for demonstration purposes) from random_label3D.h5 using a network trained with cross-entropy loss:
cd pytorch3dunet
predict3dunet --config ../resources/test_config_ce.yaml
or if trained with DiceLoss
:
cd pytorch3dunet
predict3dunet --config ../resources/test_config_dice.yaml
The predicted volume will be saved to resources/random_label3D_probabilities.h5.
In order to predict on your own data, just provide the path to your model as well as the paths to the HDF5 test files (see test_config_ce.yaml).
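A hedged sketch of the prediction-specific entries (key names are assumed from test_config_ce.yaml and may differ between versions; the paths are placeholders):

model_path: /path/to/best_checkpoint.pytorch   # trained model to load
loaders:
  test:
    file_paths:
      - /path/to/my_test_volume.h5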
In order to avoid checkerboard artifacts in the output prediction masks, the patch predictions are averaged, so make sure that the patch/stride params lead to overlapping blocks, e.g. patch: [64, 128, 128], stride: [32, 96, 96] will give you a 'halo' of 32 voxels in each direction.
By default, if multiple GPUs are available training/prediction will be run on all the GPUs using DataParallel.
If training/prediction on all available GPUs is not desirable, restrict the number of GPUs using CUDA_VISIBLE_DEVICES
, e.g.
CUDA_VISIBLE_DEVICES=0,1 train3dunet --config <CONFIG>
or
CUDA_VISIBLE_DEVICES=0,1 predict3dunet --config <CONFIG>
If you want to contribute back, please make a pull request.
If you use this code for your research, please cite as:
@article {Wolny2020.01.17.910562,
author = {Wolny, Adrian and Cerrone, Lorenzo and Vijayan, Athul and Tofanelli, Rachele and Barro,
Amaya Vilches and Louveaux, Marion and Wenzl, Christian and Steigleder, Susanne and Pape,
Constantin and Bailoni, Alberto and Duran-Nebreda, Salva and Bassel, George and Lohmann,
Jan U. and Hamprecht, Fred A. and Schneitz, Kay and Maizel, Alexis and Kreshuk, Anna},
title = {Accurate And Versatile 3D Segmentation Of Plant Tissues At Cellular Resolution},
elocation-id = {2020.01.17.910562},
year = {2020},
doi = {10.1101/2020.01.17.910562},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2020/01/18/2020.01.17.910562},
eprint = {https://www.biorxiv.org/content/early/2020/01/18/2020.01.17.910562.full.pdf},
journal = {bioRxiv}
}