qiaodl/panoptic-deeplab-pytorch · train_panoptic.py (18.27 KB)
Last commit: toluwajosh, 5 years ago ("code final cleanup")

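"""Training entry point for Panoptic-DeepLab in PyTorch.

Builds a model with a semantic-segmentation head plus instance-center and
x/y offset-regression heads, trains it with SGD, and logs all four losses
to TensorBoard.

Example invocations (a sketch; the flags are defined in main() below, and
dataset paths are resolved through mypath.Path):

    python train_panoptic.py --dataset pascal --task segmentation
    python train_panoptic.py --dataset cityscapes --task panoptic \
        --backbone resnet_3stage --gpu-ids 0,1
"""
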
import argparse
import os

import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

from dataloaders import make_data_loader
from modeling.panoptic_deeplab import PanopticDeepLab
from modeling.sync_batchnorm.replicate import patch_replication_callback
from mypath import Path

# sic: spelling matches the function's definition in utils/calculate_weights.py
from utils.calculate_weights import calculate_weigths_labels
from utils.loss import PanopticLosses, SegmentationLosses
from utils.lr_scheduler import LR_Scheduler
from utils.metrics import Evaluator
from utils.saver import Saver
from utils.summaries import TensorboardSummary

torch.cuda.empty_cache()


class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {"num_workers": args.workers, "pin_memory": True}
        (
            self.train_loader,
            self.val_loader,
            self.test_loader,
            self.nclass,
        ) = make_data_loader(args, **kwargs)

        # Define network
        model = PanopticDeepLab(
            num_classes=self.nclass,
            backbone=args.backbone,
            output_stride=args.out_stride,
            sync_bn=args.sync_bn,
            freeze_bn=args.freeze_bn,
        )
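
        # Optional two-group optimizer setup: by the usual DeepLab convention
        # (assumed here from the method names), get_1x_lr_params() covers the
        # backbone and get_10x_lr_params() the decoder heads, which train at
        # ten times the base learning rate.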
        if args.create_params:
            train_params = [
                {"params": model.get_1x_lr_params(), "lr": args.lr},
                {"params": model.get_10x_lr_params(), "lr": args.lr * 10},
            ]
        else:
            train_params = model.parameters()

        # Define Optimizer
        optimizer = torch.optim.SGD(
            train_params,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )
        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + "_classes_weights.npy",
            )
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(
                    args.dataset, self.train_loader, self.nclass
                )
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = PanopticLosses(
            weight=weight, cuda=args.cuda
        ).build_loss(mode=args.loss_type)
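        # The built criterion returns four losses per batch: semantic
        # (cross-entropy or focal, per --loss-type), center heatmap, and the
        # x/y offset regressions; see its use in training() below.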
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        self.scheduler = ReduceLROnPlateau(
            optimizer, mode="max", factor=0.89, patience=2, verbose=True
        )
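        # The plateau scheduler watches the running best mIoU (mode="max"):
        # after patience=2 epochs without improvement, the lr is multiplied
        # by factor=0.89. main() calls trainer.scheduler.step(trainer.best_pred).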
        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(
                self.model, device_ids=self.args.gpu_ids
            )
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError(
                    "=> no checkpoint found at '{}'".format(args.resume)
                )
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            if args.cuda:
                self.model.module.load_state_dict(checkpoint["state_dict"])
            else:
                self.model.load_state_dict(checkpoint["state_dict"])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.best_pred = checkpoint["best_pred"]
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
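        # Note: with --ft the model weights are loaded from the checkpoint but
        # the optimizer state and epoch counter are reset, so fine-tuning on a
        # new dataset starts from epoch 0 with a fresh optimizer.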

    def training(self, epoch):
        train_loss = 0.0
        semantic_loss_out = 0.0
        center_loss_out = 0.0
        reg_x_loss_out = 0.0
        reg_y_loss_out = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, label, center, x_reg, y_reg = (
                sample["image"],
                sample["label"],
                sample["center"],
                sample["x_reg"],
                sample["y_reg"],
            )
            if self.args.cuda:
                image, label, center, x_reg, y_reg = (
                    image.cuda(),
                    label.cuda(),
                    center.cuda(),
                    x_reg.cuda(),
                    y_reg.cuda(),
                )
            self.optimizer.zero_grad()
            try:
                output = self.model(image)
            except ValueError as err:
                # skip batches whose input size the network cannot handle
                print("Error: ", err)
                continue
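            # output is a 4-tuple; judging from how it is consumed below
            # (output[0] ... output[3]), the order is: semantic logits,
            # center heatmap, x-offset map, y-offset map.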
            (
                semantic_loss,
                center_loss,
                reg_x_loss,
                reg_y_loss,
            ) = self.criterion.forward(output, label, center, x_reg, y_reg)
            # total loss
            loss = semantic_loss + center_loss + reg_x_loss + reg_y_loss
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            semantic_loss_out += semantic_loss.item()
            center_loss_out += center_loss.item()
            reg_x_loss_out += reg_x_loss.item()
            reg_y_loss_out += reg_y_loss.item()
            tbar.set_description(
                "Losses -> Train: %.3f, Semantic: %.3f, Center: %.3f, x_reg: %.3f, y_reg: %.3f"
                % (
                    train_loss / (i + 1),
                    semantic_loss_out / (i + 1),
                    center_loss_out / (i + 1),
                    reg_x_loss_out / (i + 1),
                    reg_y_loss_out / (i + 1),
                )
            )
            self.writer.add_scalar(
                "train/semantic_loss_iter",
                semantic_loss.item(),
                i + num_img_tr * epoch,
            )
            self.writer.add_scalar(
                "train/center_loss_iter",
                center_loss.item(),
                i + num_img_tr * epoch,
            )
            self.writer.add_scalar(
                "train/reg_x_loss_iter",
                reg_x_loss.item(),
                i + num_img_tr * epoch,
            )
            self.writer.add_scalar(
                "train/reg_y_loss_iter",
                reg_y_loss.item(),
                i + num_img_tr * epoch,
            )
            self.writer.add_scalar(
                "train/total_loss_iter", loss.item(), i + num_img_tr * epoch
            )

            # Show inference results roughly 10 times per epoch
            # (max(..., 1) guards against epochs shorter than 10 iterations)
            if i % max(num_img_tr // 10, 1) == 0:
                global_step = i + num_img_tr * epoch
                self.summary.visualize_image(
                    self.writer,
                    self.args.dataset,
                    image,
                    label,
                    output[0],
                    global_step,
                    centers=output[1],
                    reg_x=output[2],
                    reg_y=output[3],
                )

        self.writer.add_scalar("train/total_loss_epoch", train_loss, epoch)
        print(
            "[Epoch: %d, numImages: %5d]"
            % (epoch, i * self.args.batch_size + image.data.shape[0])
        )
        print("Loss: %.3f" % train_loss)
        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    # unwrap DataParallel only when it was applied
                    "state_dict": self.model.module.state_dict()
                    if self.args.cuda
                    else self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "best_pred": self.best_pred,
                },
                is_best,
            )

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc="\r")
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, label, center, x_reg, y_reg = (
                sample["image"],
                sample["label"],
                sample["center"],
                sample["x_reg"],
                sample["y_reg"],
            )
            if self.args.cuda:
                image, label, center, x_reg, y_reg = (
                    image.cuda(),
                    label.cuda(),
                    center.cuda(),
                    x_reg.cuda(),
                    y_reg.cuda(),
                )
            with torch.no_grad():
                try:
                    output = self.model(image)
                except ValueError as err:
                    # skip batches whose input size the network cannot handle
                    print("Error: ", err)
                    continue
            (
                semantic_loss,
                center_loss,
                reg_x_loss,
                reg_y_loss,
            ) = self.criterion.forward(output, label, center, x_reg, y_reg)
            # total loss
            loss = semantic_loss + center_loss + reg_x_loss + reg_y_loss
            test_loss += loss.item()
            tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))
            pred = output[0].data.cpu().numpy()
            label = label.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(label, pred)
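
        # Only the semantic head (output[0]) is scored here; the mIoU computed
        # below drives best-checkpoint selection and the plateau scheduler.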
        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar("val/total_loss_epoch", test_loss, epoch)
        self.writer.add_scalar("val/mIoU", mIoU, epoch)
        self.writer.add_scalar("val/Acc", Acc, epoch)
        self.writer.add_scalar("val/Acc_class", Acc_class, epoch)
        self.writer.add_scalar("val/fwIoU", FWIoU, epoch)
        print("Validation:")
        print(
            "[Epoch: %d, numImages: %5d]"
            % (epoch, i * self.args.batch_size + image.data.shape[0])
        )
        print(
            "Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
                Acc, Acc_class, mIoU, FWIoU
            )
        )
        print("Loss: %.3f" % test_loss)

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    # unwrap DataParallel only when it was applied
                    "state_dict": self.model.module.state_dict()
                    if self.args.cuda
                    else self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "best_pred": self.best_pred,
                },
                is_best,
            )


def main():
    parser = argparse.ArgumentParser(
        description="PyTorch Panoptic Deeplab Training"
    )
    parser.add_argument(
        "--backbone",
        type=str,
        default="resnet_3stage",
        choices=["xception_3stage", "mobilenet_3stage", "resnet_3stage"],
        help="backbone name (default: resnet_3stage)",
    )
    parser.add_argument(
        "--out-stride",
        type=int,
        default=16,
        help="network output stride (default: 16)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="pascal",
        choices=["pascal", "coco", "cityscapes"],
        help="dataset name (default: pascal)",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="segmentation",
        choices=["segmentation", "panoptic"],
        help="training task (default: segmentation)",
    )
    parser.add_argument(
        "--use-sbd",
        action="store_true",
        default=True,
        help="whether to use SBD dataset (default: True)",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        metavar="N",
        help="dataloader threads",
    )
    parser.add_argument(
        "--base-size", type=int, default=513, help="base image size"
    )
    parser.add_argument(
        "--crop-size", type=int, default=513, help="crop image size"
    )
    # NB: argparse's type=bool converts any non-empty string to True, so
    # "--sync-bn False" would still enable it (same for --freeze-bn); omit
    # the flag to use the automatic default set below.
    parser.add_argument(
        "--sync-bn",
        type=bool,
        default=None,
        help="whether to use sync bn (default: auto)",
    )
    parser.add_argument(
        "--freeze-bn",
        type=bool,
        default=False,
        help="whether to freeze bn parameters (default: False)",
    )
    parser.add_argument(
        "--loss-type",
        type=str,
        default="ce",
        choices=["ce", "focal"],
        help="loss func type (default: ce)",
    )
    # training hyper params
    parser.add_argument(
        "--epochs",
        type=int,
        default=None,
        metavar="N",
        help="number of epochs to train (default: auto)",
    )
    parser.add_argument(
        "--start_epoch",
        type=int,
        default=0,
        metavar="N",
        help="start epoch (default: 0)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=None,
        metavar="N",
        help="input batch size for training (default: auto)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=None,
        metavar="N",
        help="input batch size for testing (default: auto)",
    )
    parser.add_argument(
        "--use-balanced-weights",
        action="store_true",
        default=False,
        help="whether to use balanced weights (default: False)",
    )
    # optimizer params
    parser.add_argument(
        "--lr",
        type=float,
        default=None,
        metavar="LR",
        help="learning rate (default: auto)",
    )
    parser.add_argument(
        "--lr-scheduler",
        type=str,
        default="poly",
        choices=["poly", "step", "cos"],
        help="lr scheduler mode (default: poly)",
    )
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.9,
        metavar="M",
        help="momentum (default: 0.9)",
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        default=5e-4,
        metavar="M",
        help="weight decay (default: 5e-4)",
    )
    parser.add_argument(
        "--nesterov",
        action="store_true",
        default=False,
        help="whether to use Nesterov momentum (default: False)",
    )
    # cuda, seed and logging
    parser.add_argument(
        "--no-cuda",
        action="store_true",
        default=False,
        help="disables CUDA training",
    )
    parser.add_argument(
        "--gpu-ids",
        type=str,
        default="0",
        help="which GPUs to train on, as a comma-separated list of "
        "integers (default: 0)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        metavar="S",
        help="random seed (default: 1)",
    )
    # checkpointing
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="path of the checkpoint file to resume from, if needed",
    )
    parser.add_argument(
        "--checkname", type=str, default=None, help="set the checkpoint name"
    )
    # finetuning pre-trained models
    parser.add_argument(
        "--ft",
        action="store_true",
        default=False,
        help="finetuning on a different dataset",
    )
    # evaluation option
    parser.add_argument(
        "--eval-interval",
        type=int,
        default=1,
        help="evaluation interval (default: 1)",
    )
    parser.add_argument(
        "--no-val",
        action="store_true",
        default=False,
        help="skip validation during training",
    )
    parser.add_argument(
        "--create-params",
        action="store_true",
        default=False,
        help="create manual parameter groups for the optimizer",
    )
    parser.add_argument(
        "--lr-step",
        type=float,
        default=50,
        help="step size for the 'step' lr scheduler mode",
    )
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        try:
            args.gpu_ids = [int(s) for s in args.gpu_ids.split(",")]
        except ValueError:
            raise ValueError(
                "Argument --gpu-ids must be a comma-separated list of integers only"
            )

    if args.sync_bn is None:
        if args.cuda and len(args.gpu_ids) > 1:
            args.sync_bn = True
        else:
            args.sync_bn = False

    # default settings for epochs, batch_size and lr
    if args.epochs is None:
        epochs = {
            "coco": 30,
            "cityscapes": 200,
            "pascal": 50,
        }
        args.epochs = epochs[args.dataset.lower()]
    if args.batch_size is None:
        args.batch_size = 4 * len(args.gpu_ids)
    if args.test_batch_size is None:
        args.test_batch_size = args.batch_size
    if args.lr is None:
        lrs = {
            "coco": 0.1,
            "cityscapes": 0.01,
            "pascal": 0.007,
        }
        args.lr = (
            lrs[args.dataset.lower()]
            / (4 * len(args.gpu_ids))
            * args.batch_size
        )
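        # Worked example: cityscapes on 2 GPUs with the auto batch size
        # (4 * 2 = 8) gives lr = 0.01 / 8 * 8 = 0.01; the base rates above
        # are scaled linearly with the actual batch size.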
    if args.checkname is None:
        args.checkname = "panoptic-deeplab-" + str(args.backbone)
    print(args)
    torch.manual_seed(args.seed)
    trainer = Trainer(args)
    print("Starting Epoch:", trainer.args.start_epoch)
    print("Total Epochs:", trainer.args.epochs)
    for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
        trainer.training(epoch)
        if not trainer.args.no_val and epoch % args.eval_interval == (
            args.eval_interval - 1
        ):
            trainer.validation(epoch)
            trainer.scheduler.step(trainer.best_pred)
        print(
            "\n=> Epoch %i, previous best = %.4f" % (epoch, trainer.best_pred)
        )
    trainer.writer.close()


if __name__ == "__main__":
    main()