We use Python files as configs and incorporate a modular, inheritance-based design into our config system, which makes it convenient to conduct various experiments.
You can find all the provided configs under $PROJECT/configs.
We can logically divide the configuration file into components:
The following content explains these configuration components one by one.
training: the training configuration contains all parameters that control model training, including the optimizer, hooks, runner and so on.
import os
from datetime import datetime

method = 'nerf'  # which NeRF method to run

# ---- optimizer ----
optimizer = {'type': 'Adam', 'lr': 5e-4, 'betas': (0.9, 0.999)}
optimizer_config = {'grad_clip': None}

# ---- schedule ----
max_iters = 20000  # total number of training iterations
# step decay: multiply lr by gamma every `step` iterations
# NOTE(review): step (500k) is far beyond max_iters (20k), so the decay
# never actually triggers during training -- confirm this is intended
lr_config = {'policy': 'step', 'step': 500 * 1000, 'gamma': 0.1, 'by_epoch': False}
checkpoint_config = {'interval': 5000, 'by_epoch': False}  # save a checkpoint every 5k iters

# ---- logging ----
log_level = 'INFO'
log_config = {
    'interval': 5000,
    'by_epoch': False,
    'hooks': [{'type': 'TextLoggerHook'}],
}
workflow = [('train', 5000), ('val', 1)]  # loop: train 5000 iters, then validate 1 iter

# ---- hooks ----
# 'params' hold plain values; 'variables' name objects in the local environment
train_hooks = [
    {'type': 'SetValPipelineHook',
     'params': {},
     'variables': {'valset': 'valset'}},
    {'type': 'ValidateHook',
     'params': {'save_folder': 'visualizations/validation'}},
    {'type': 'SaveSpiralHook',
     'params': {'save_folder': 'visualizations/spiral'}},
    {'type': 'PassIterHook', 'params': {}},  # passes the current iteration number to the dataset
    {'type': 'OccupationHook', 'params': {}},  # not needed for the open-source version
]
test_hooks = [
    {'type': 'SetValPipelineHook',
     'params': {},
     'variables': {'valset': 'testset'}},
    {'type': 'TestHook', 'params': {}},
]

# ---- runners ----
train_runner = {'type': 'NerfTrainRunner'}
test_runner = {'type': 'NerfTestRunner'}

# ---- runtime settings ----
num_gpus = 1
distributed = (num_gpus > 1)  # use DDP only when more than one GPU is available
work_dir = './work_dirs/nerfsv3/nerf_#DATANAME#_base01/'  # where ckpts, images, videos and logs go
timestamp = datetime.now().strftime('%d-%b-%H-%M')  # unique per run, to keep log files separate

# ---- params shared by model and data (defined once to avoid duplication) ----
dataset_type = 'blender'
no_batching = True  # only take random rays from 1 image at a time
no_ndc = True
white_bkgd = True  # render synthetic data on a white background (always used for deepvoxels)
is_perturb = True  # set to 0. for no jitter, 1. for jitter of z-vals
use_viewdirs = True  # use full 5D input (position + view direction) instead of 3D
N_rand_per_sampler = 1024 * 4  # how many rays are drawn in get_item()
lindisp = False  # sample linearly in disparity rather than depth
N_samples = 64  # number of coarse samples per ray
# resume_from = os.path.join(work_dir, 'latest.pth')
# load_from = os.path.join(work_dir, 'latest.pth')
model: defines the network structure; a network is usually composed of an embedder, an MLP and a render.
# ---- network architecture ----
# A NerfNetwork bundles a runtime cfg, a coarse MLP, an optional fine MLP
# and a renderer.
model = {
    'type': 'NerfNetwork',  # network class name
    'cfg': {
        'phase': 'train',  # 'train' or 'test'
        'N_importance': 128,  # number of additional fine samples per ray
        'is_perturb': is_perturb,  # see shared params above
        'chunk': 1024 * 32,  # chunked inference, mainly for val, to avoid OOM
        'bs_data': 'rays_o',  # key whose shape gives the real batch size (= num of rays)
    },
    'mlp': {  # coarse MLP
        'type': 'NerfMLP',  # mlp class name
        'skips': [4],  # layer indices with skip connections
        'netdepth': 8,  # layers in network
        'netwidth': 256,  # channels per layer
        'netchunk': 1024 * 32,  # chunked forward to avoid OOM
        'output_ch': 5,  # 5 if cfg.N_importance > 0 else 4
        'use_viewdirs': use_viewdirs,
        'embedder': {
            'type': 'BaseEmbedder',  # embedder class name
            'i_embed': 0,  # 0 for default positional encoding, -1 for none
            'multires': 10,  # log2 of max freq for positional encoding (3D location)
            'multires_dirs': 4,  # 'multires_views' in the original code; log2 of max freq for 2D direction encoding
        },
    },
    'mlp_fine': {  # fine MLP, same architecture as the coarse one
        'type': 'NerfMLP',
        'skips': [4],
        'netdepth': 8,
        'netwidth': 256,
        'netchunk': 1024 * 32,
        'output_ch': 5,
        'use_viewdirs': use_viewdirs,
        'embedder': {
            'type': 'BaseEmbedder',
            'i_embed': 0,
            'multires': 10,
            'multires_dirs': 4,
        },
    },
    'render': {
        'type': 'NerfRender',  # render class name
        'white_bkgd': white_bkgd,  # see shared params above
        'raw_noise_std': 0,  # std dev of noise added to regularize sigma_a output, 1e0 recommended
    },
}
data: defines the datasets and the data-processing pipelines used for training, validation and testing.
# ---- dataset settings shared by all splits ----
basedata_cfg = {
    'dataset_type': dataset_type,
    'datadir': 'data/nerf_synthetic/#DATANAME#',
    'half_res': True,  # load blender synthetic data at 400x400 instead of 800x800
    'testskip': 8,  # load 1/N images from test/val sets, useful for large datasets like deepvoxels
    'white_bkgd': white_bkgd,
    'is_batching': False,  # True for blender, False for llff
    'mode': 'train',
}
# each split starts from a copy of the base config and overrides a few fields
traindata_cfg = {**basedata_cfg}
valdata_cfg = {**basedata_cfg, 'mode': 'val'}
testdata_cfg = {**basedata_cfg, 'mode': 'test', 'testskip': 0}
# ---- per-sample processing for training ----
# sample rays from one image, then build the point queries fed to the network
train_pipeline = [
    {'type': 'Sample'},
    {'type': 'DeleteUseless', 'keys': ['images', 'poses', 'i_data', 'idx']},
    {'type': 'ToTensor', 'keys': ['pose', 'target_s']},
    {'type': 'GetRays'},
    # during the first 500 iters, only select rays inside the image center
    {'type': 'SelectRays',
     'sel_n': N_rand_per_sampler,
     'precrop_iters': 500,
     'precrop_frac': 0.5},
    {'type': 'GetViewdirs', 'enable': use_viewdirs},
    {'type': 'ToNDC', 'enable': (not no_ndc)},
    {'type': 'GetBounds'},
    # N_samples: number of coarse samples per ray
    {'type': 'GetZvals', 'lindisp': lindisp, 'N_samples': N_samples},
    {'type': 'PerturbZvals', 'enable': is_perturb},
    {'type': 'GetPts'},
    {'type': 'DeleteUseless', 'keys': ['pose', 'iter_n']},
]
# ---- per-sample processing for val/test ----
# renders full images, so rays are flattened and z-vals are never perturbed
test_pipeline = [
    {'type': 'ToTensor', 'keys': ['pose']},
    {'type': 'GetRays'},
    {'type': 'FlattenRays'},
    {'type': 'GetViewdirs', 'enable': use_viewdirs},
    {'type': 'ToNDC', 'enable': (not no_ndc)},
    {'type': 'GetBounds'},
    {'type': 'GetZvals', 'lindisp': lindisp, 'N_samples': N_samples},
    {'type': 'PerturbZvals', 'enable': False},  # do not perturb when testing
    {'type': 'GetPts'},
    {'type': 'DeleteUseless', 'keys': ['pose']},
]
# ---- wire each split's cfg and pipeline into dataset/loader definitions ----
data = {
    'train_loader': {'batch_size': 1, 'num_workers': 4},
    'train': {
        'type': 'SceneBaseDataset',
        'cfg': traindata_cfg,
        'pipeline': train_pipeline,
    },
    'val_loader': {'batch_size': 1, 'num_workers': 0},
    'val': {
        'type': 'SceneBaseDataset',
        'cfg': valdata_cfg,
        'pipeline': test_pipeline,
    },
    'test_loader': {'batch_size': 1, 'num_workers': 0},
    'test': {
        'type': 'SceneBaseDataset',
        'cfg': testdata_cfg,
        'pipeline': test_pipeline,  # test shares the validation pipeline
    },
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。