pytorch-3dunet

PyTorch implementation of a standard 3D U-Net based on:

3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation by Özgün Çiçek et al.

as well as Residual 3D U-Net based on:

Superhuman Accuracy on the SNEMI3D Connectomics Challenge by Kisuk Lee et al.

Prerequisites

  • Linux
  • NVIDIA GPU
  • CUDA + cuDNN

Getting Started

Dependencies

  • pytorch (0.4.1+)
  • torchvision (0.2.1+)
  • tensorboardx (1.6+)
  • h5py
  • scipy
  • scikit-image
  • pyyaml
  • pytest

Set up a new conda environment with the required dependencies via:

conda create -n 3dunet pytorch torchvision tensorboardx h5py scipy scikit-image pyyaml pytest -c conda-forge -c pytorch

Activate the newly created conda environment via:

source activate 3dunet

Supported model architectures

  • To train a standard 3D U-Net, specify name: UNet3D in the model section of the config file (see the sketch after this list)
  • To train a Residual 3D U-Net, specify name: ResidualUNet3D in the model section of the config file
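
A minimal, hedged sketch of switching between the two architectures by editing one of the example config files; only the model name key is documented here, and the path simply points at the example config shipped in resources/:

```python
# Illustrative only: toggle the architecture in an existing YAML config.
# Only model.name is documented above; the rest of the file is left untouched.
import yaml

config_path = "resources/train_config_ce.yaml"
with open(config_path) as f:
    config = yaml.safe_load(f)

config["model"]["name"] = "ResidualUNet3D"  # or "UNet3D" for the standard variant

with open(config_path, "w") as f:
    yaml.safe_dump(config, f, default_flow_style=False)
```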

Supported Loss Functions

For a detailed explanation of the loss functions used, see: Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations by Carole H. Sudre, Wenqi Li, Tom Vercauteren, Sebastien Ourselin, M. Jorge Cardoso

  • WeightedCrossEntropyLoss (see 'Weighted cross-entropy (WCE)' in the above paper for a detailed explanation; one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config)
  • CrossEntropyLoss (one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config)
  • PixelWiseCrossEntropyLoss (one can specify not only class weights but also per pixel weights in order to give more/less gradient in some regions of the ground truth)
  • BCEWithLogitsLoss
  • DiceLoss (standard Dice loss; see 'Dice Loss' in the above paper for a detailed explanation; a generic sketch is shown after this list)
  • GeneralizedDiceLoss (see 'Generalized Dice Loss (GDL)' in the above paper for a detailed explanation; one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config). Note: use this loss function only if the labels in the training dataset are very imbalanced, e.g. one class having at least 3 orders of magnitude more voxels than the others. Otherwise use the standard DiceLoss, which works better than GDL most of the time.
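
For intuition, here is a generic soft Dice loss sketch in PyTorch. This is not the repository's implementation, just an illustration of the per-channel overlap that DiceLoss measures between the sigmoid probabilities and a one-binary-mask-per-channel target:

```python
# Generic soft Dice loss sketch (illustrative only, not the repo's code).
import torch

def soft_dice_loss(logits, target, epsilon=1e-6):
    # logits, target: (N, C, D, H, W); target holds one binary mask per channel
    probs = torch.sigmoid(logits)
    dims = (0, 2, 3, 4)                        # sum over batch and spatial dims
    intersection = (probs * target).sum(dims)
    denominator = probs.sum(dims) + target.sum(dims)
    dice_per_channel = (2 * intersection + epsilon) / (denominator + epsilon)
    return 1.0 - dice_per_channel.mean()      # average over channels
```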

Supported Evaluation Metrics

  • MeanIoU - Mean intersection over union
  • DiceCoefficient - Dice Coefficient (computes per channel Dice Coefficient and returns the average)
  • BoundaryAveragePrecision - Average Precision (normally used for evaluating instance segmentation; however, it can be used when the 3D U-Net is used to predict the boundary signal from the instance segmentation ground truth)
  • AdaptedRandError - Adapted Rand Error (see http://brainiac2.mit.edu/SNEMI3D/evaluation for a detailed explanation)

If no evaluation metric is specified, MeanIoU is used by default.
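
As a rough illustration (not the repository's exact implementation), MeanIoU can be thought of as the per-class intersection over union, averaged over the classes present:

```python
# Rough MeanIoU sketch for integer label volumes (illustrative only).
import torch

def mean_iou(pred_labels, target_labels, num_classes):
    # pred_labels, target_labels: integer tensors of shape (D, H, W)
    ious = []
    for c in range(num_classes):
        pred_c = pred_labels == c
        target_c = target_labels == c
        union = (pred_c | target_c).sum().item()
        if union == 0:
            continue  # class absent from both prediction and ground truth
        intersection = (pred_c & target_c).sum().item()
        ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else 0.0
```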

Train

E.g. to fit the network to a randomly generated 3D volume and a random segmentation mask from random_label3D.h5, run:

python train.py --config resources/train_config_ce.yaml # train with CrossEntropyLoss

or:

python train.py --config resources/train_config_dice.yaml # train with DiceLoss

See the train_config_ce.yaml for more info.

In order to train on your own data, just provide the paths to your HDF5 training and validation datasets in train_config_ce.yaml. The HDF5 files should contain the raw/label datasets in the following axis order: DHW (in the case of 3D) or CDHW (in the case of 4D).
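
A minimal sketch (with made-up shapes and file name) of packaging your own data into that layout with h5py; the dataset names 'raw' and 'label' follow the description above, but check that they match what your config expects:

```python
# Illustrative only: write a single training volume in DHW order to HDF5.
import h5py
import numpy as np

raw = np.random.rand(80, 160, 160).astype("float32")                 # D, H, W
label = np.random.randint(0, 2, size=(80, 160, 160)).astype("int64") # D, H, W

with h5py.File("my_train_volume.h5", "w") as f:
    f.create_dataset("raw", data=raw, compression="gzip")
    f.create_dataset("label", data=label, compression="gzip")
```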

Monitor progress with TensorBoard: tensorboard --logdir ./3dunet/logs/ --port 8666 (you need tensorflow installed in your conda env).

Training tips

  1. In order to train with BCEWithLogitsLoss, DiceLoss or GeneralizedDiceLoss, the label data has to be 4D (one target binary mask per channel). If you have 3D binary data (foreground/background), you can just change the ToTensor transform for the label to contain expand_dims: true; see e.g. train_config_dice.yaml and the sketch after this list.

  2. When training with binary-based losses (BCEWithLogitsLoss, DiceLoss, GeneralizedDiceLoss), final_sigmoid=True has to be present in the training config, since every output channel gives the probability of the foreground. When training with cross-entropy-based losses (WeightedCrossEntropyLoss, CrossEntropyLoss, PixelWiseCrossEntropyLoss), set final_sigmoid=False so that Softmax normalization is applied to the output.
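
A small sketch of both points (made-up shapes, not code from the repository): expanding a 3D binary label to 4D, and how the final sigmoid differs from the softmax normalization of the output channels:

```python
# Illustrative only: label expansion and the two output normalizations.
import numpy as np
import torch
import torch.nn.functional as F

label_3d = np.random.randint(0, 2, size=(64, 64, 64))
label_4d = np.expand_dims(label_3d, axis=0)   # (1, D, H, W): one binary mask per channel

logits = torch.randn(1, 2, 64, 64, 64)        # (N, C, D, H, W) network output
foreground_prob = torch.sigmoid(logits)       # final_sigmoid=True: per-channel foreground probability
class_prob = F.softmax(logits, dim=1)         # final_sigmoid=False: probabilities across channels
```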

Test

Test on a randomly generated 3D volume (just for demonstration purposes) from random_label3D.h5:

python predict.py --config resources/test_config_ce.yaml

or if you trained with DiceLoss:

python predict.py --config resources/test_config_dice.yaml

Prediction masks will be saved to resources/random_label3D_probabilities.h5.

In order to predict on your own raw dataset, provide the path to your model as well as the paths to your HDF5 test datasets in test_config_ce.yaml.

Prediction tips

In order to avoid block artifacts in the output prediction masks, the patch predictions are averaged, so make sure that the patch/stride params lead to overlapping blocks, e.g. patch: [64 128 128] with stride: [32 96 96] will give you a 'halo' of 32 voxels in each direction.
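
As a quick sanity check (illustrative only), the halo is simply the patch size minus the stride in each dimension:

```python
# Overlap between neighbouring patches for the example above.
patch = [64, 128, 128]
stride = [32, 96, 96]
halo = [p - s for p, s in zip(patch, stride)]
print(halo)  # [32, 32, 32] -> 32 voxels of overlap in each direction
```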

Contribute

If you want to contribute back, please make a pull request.

Cite

If you use this code for your research, please cite as:

Adrian Wolny. (2019, May 7). wolny/pytorch-3dunet: PyTorch implementation of 3D U-Net (Version v1.0.0). Zenodo. http://doi.org/10.5281/zenodo.2671581

MIT License

Copyright (c) 2018 Adrian Wolny

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
