zweien/pytorch-3dunet
train_config_dice.yaml
# use a fixed random seed to guarantee that when you run the code twice you will get the same outcome
manual_seed: 0
# model configuration
model:
  # model class, e.g. UNet3D, ResidualUNet3D
  name: UNet3D
  # number of input channels to the model
  in_channels: 1
  # number of output channels
  out_channels: 1
  # determines the order of operators in a single layer (gcr - GroupNorm+Conv3d+ReLU)
  layer_order: gcr
  # feature maps scale factor
  f_maps: 32
  # number of groups in the groupnorm
  num_groups: 8
  # apply element-wise nn.Sigmoid after the final 1x1 convolution, otherwise apply nn.Softmax
  final_sigmoid: true
  # if True, applies the final normalization layer (sigmoid or softmax); otherwise the network returns the output of the final convolution layer; use False for regression problems, e.g. de-noising
  is_segmentation: true
# trainer configuration
trainer:
  # path to the checkpoint directory
  checkpoint_dir: 3dunet
  # path to the latest checkpoint; if provided, training will be resumed from that checkpoint
  resume: null
  # how many iterations between validations
  validate_after_iters: 20
  # how many iterations between tensorboard logging
  log_after_iters: 20
  # max number of epochs
  epochs: 50
  # max number of iterations
  iters: 100000
  # model with higher eval score is considered better
  eval_score_higher_is_better: True
# optimizer configuration
optimizer:
  # initial learning rate
  learning_rate: 0.0002
  # weight decay
  weight_decay: 0.0001
# loss function configuration
loss:
  # loss function to be used during training
  name: DiceLoss
  # a manual rescaling weight given to each class
  weight: null
  # a target value that is ignored and does not contribute to the input gradient
  ignore_index: null
# evaluation metric configuration
eval_metric:
  name: MeanIoU
  # a target label that is ignored during metric evaluation
  ignore_index: null
# learning rate scheduler configuration
lr_scheduler:
  name: MultiStepLR
  milestones: [10, 30, 60]
  gamma: 0.2
# data loaders configuration
loaders:
  # class of the HDF5 dataset; currently StandardHDF5Dataset and LazyHDF5Dataset are supported.
  # When using LazyHDF5Dataset make sure to set `num_workers = 1`, due to a bug in h5py which corrupts the data
  # when reading from multiple threads.
  dataset: StandardHDF5Dataset
  # batch dimension; if the number of GPUs is N > 1, an effective batch size of N * batch_size will automatically be used for DataParallel
  batch_size: 1
  # how many subprocesses to use for data loading
  num_workers: 4
  # path to the raw data within the H5
  raw_internal_path: raw
  # path to the label data within the H5
  # (the h5py sketch after this config shows the expected file layout)
  label_internal_path: label
  # path to the pixel-wise weight map within the H5, if present
  weight_internal_path: null
  # configuration of the train loader
  train:
    # absolute paths to the training datasets; if a given path is a directory, all H5 files ('*.h5', '*.hdf', '*.hdf5', '*.hd5')
    # inside this directory will be included as well (non-recursively)
    file_paths:
      - '../resources/random3D.h5'
    # SliceBuilder configuration, i.e. how to iterate over the input volume patch by patch
    # (see the tiling sketch after this config)
    slice_builder:
      # SliceBuilder class
      name: SliceBuilder
      # train patch size given to the network (adapt to fit in your GPU memory; generally, the bigger the patch the better)
      patch_shape: [32, 64, 64]
      # train stride between patches
      stride_shape: [8, 16, 16]
    # data transformations/augmentations
    transformer:
      raw:
        # re-scale the values to be 0-mean and 1-std
        - name: Standardize
        # randomly flip the image across a randomly chosen axis
        - name: RandomFlip
        # rotate the image by 90 degrees around a randomly chosen plane
        - name: RandomRotate90
        # rotate the image by a random angle drawn from the (-angle_spectrum, angle_spectrum) interval
        - name: RandomRotate
          # rotate only in the ZY plane since most volumetric data is anisotropic
          axes: [[2, 1]]
          angle_spectrum: 15
          mode: reflect
        # apply elastic deformations of 3D patches on a per-voxel mesh
        - name: ElasticDeformation
          spline_order: 3
        # randomly adjust contrast
        - name: RandomContrast
        # apply additive Gaussian noise
        - name: AdditiveGaussianNoise
        # apply additive Poisson noise
        - name: AdditivePoissonNoise
        # convert to torch tensor
        - name: ToTensor
          # add an additional 'channel' axis when the input data is 3D
          expand_dims: true
      label:
        - name: RandomFlip
        - name: RandomRotate90
        - name: RandomRotate
          # rotate only in the ZY plane since most volumetric data is anisotropic
          axes: [[2, 1]]
          angle_spectrum: 15
          mode: reflect
        - name: ElasticDeformation
          spline_order: 0
        - name: ToTensor
          expand_dims: true
  # configuration of the validation loaders
  val:
    # paths to the validation datasets; if a given path is a directory, all H5 files ('*.h5', '*.hdf', '*.hdf5', '*.hd5')
    # inside this directory will be included as well (non-recursively)
    file_paths:
      - '../resources/random3D_copy.h5'
    # SliceBuilder configuration
    slice_builder:
      # SliceBuilder class
      name: SliceBuilder
      # validation patch (can be bigger than the train patch since there is no backprop)
      patch_shape: [32, 64, 64]
      # validation stride (validation patches don't need to overlap)
      stride_shape: [32, 64, 64]
    transformer:
      raw:
        - name: Standardize
        - name: ToTensor
          expand_dims: true
      label:
        - name: ToTensor
          expand_dims: true
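
The nesting above matters when the file is parsed: each top-level key (model, trainer, optimizer, loss, eval_metric, lr_scheduler, loaders) becomes a dictionary of its indented children. A minimal sketch of reading this file with PyYAML and pulling out a few values; only the key names come from the config above, the rest is plain PyYAML usage:

import yaml

# Parse the config into nested Python dicts/lists.
with open('train_config_dice.yaml') as f:
    config = yaml.safe_load(f)

print(config['model']['name'])                            # UNet3D
print(config['loaders']['train']['slice_builder'])        # {'name': 'SliceBuilder', 'patch_shape': [32, 64, 64], ...}
print(config['trainer']['eval_score_higher_is_better'])   # True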
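
The loader settings expect each listed H5 file (e.g. '../resources/random3D.h5') to contain a 3D volume under raw_internal_path ('raw') and a matching volume under label_internal_path ('label'). A minimal sketch, using h5py and NumPy, of how such a file could be generated; the volume shape and the random content are placeholders chosen for illustration, only the file path and the internal dataset names come from the config above:

import h5py
import numpy as np

# Placeholder volume size; must be at least as large as patch_shape [32, 64, 64].
shape = (64, 128, 128)

raw = np.random.rand(*shape).astype(np.float32)           # single-channel raw volume
label = (np.random.rand(*shape) > 0.5).astype(np.uint8)   # binary mask, suitable for DiceLoss + sigmoid

with h5py.File('../resources/random3D.h5', 'w') as f:
    f.create_dataset('raw', data=raw)       # matches raw_internal_path
    f.create_dataset('label', data=label)   # matches label_internal_path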
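
How the slice_builder covers a volume can be worked out from patch_shape and stride_shape alone: patch start positions are spaced one stride apart along each axis, so the training settings (stride smaller than patch) produce overlapping patches, while the validation settings (stride equal to patch) tile without overlap. The sketch below illustrates only that counting under the assumption that patches are not shifted to cover any remainder; num_patches is a hypothetical helper and the volume shape is a placeholder, this is not the repository's SliceBuilder implementation:

import math

def num_patches(volume_shape, patch_shape, stride_shape):
    """Count how many full patches fit per axis when starts are spaced one stride apart."""
    return [math.floor((dim - patch) / stride) + 1
            for dim, patch, stride in zip(volume_shape, patch_shape, stride_shape)]

volume = (64, 128, 128)                                   # placeholder volume shape
print(num_patches(volume, (32, 64, 64), (8, 16, 16)))     # train: [5, 5, 5] -> 125 overlapping patches
print(num_patches(volume, (32, 64, 64), (32, 64, 64)))    # val:   [2, 2, 2] -> 8 non-overlapping patches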