1 Star 0 Fork 0

qiaodl/panoptic-deeplab-pytorch

加入 Gitee
与超过 1400万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
loss.py 4.28 KB
一键复制 编辑 原始数据 按行查看 历史
toluwajosh 提交于 2020-03-23 05:09 +08:00 . updated readme and removed unecessary comments
import torch
import torch.nn as nn
# Module-level criterion instances, both using PyTorch's default "mean" reduction.
mse_loss, l1_loss = nn.MSELoss(), nn.L1Loss()
class SegmentationLosses(object):
    """Semantic-segmentation losses: cross-entropy or focal loss.

    Args:
        weight: optional per-class weight tensor (one entry per class).
        size_average: if truthy (or None), average the per-element losses
            over the non-ignored elements ("mean" reduction); if falsy, sum
            them ("sum" reduction). This mirrors the legacy PyTorch flag.
        batch_average: if True, additionally divide the reduced loss by the
            batch size ``logit.size(0)``. With ``size_average=True`` the loss
            is already a mean, so this scales it down a second time.
        ignore_index: target value whose elements are excluded from the loss.
        cuda: if True, the cross-entropy criterion (and thus ``weight``) is
            moved to the default CUDA device before use.
    """

    def __init__(
        self,
        weight=None,
        size_average=True,
        batch_average=True,
        ignore_index=255,
        cuda=False,
    ):
        self.ignore_index = ignore_index
        self.weight = weight
        self.size_average = size_average
        self.batch_average = batch_average
        self.cuda = cuda

    def _reduction(self):
        """Translate the legacy ``size_average`` flag to a ``reduction`` string."""
        # Legacy PyTorch semantics: None means "use the default", i.e. mean.
        if self.size_average is None or self.size_average:
            return "mean"
        return "sum"

    def build_loss(self, mode="ce"):
        """Return the bound loss method for ``mode``.

        Choices: 'ce' (cross-entropy) or 'focal'.

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        if mode == "ce":
            return self.CrossEntropyLoss
        elif mode == "focal":
            return self.FocalLoss
        raise NotImplementedError(
            "Unknown loss mode {!r}; choose 'ce' or 'focal'".format(mode)
        )

    def CrossEntropyLoss(self, logit, target):
        """Cross-entropy between class scores and integer targets.

        Args:
            logit: scores shaped (N, C, ...) as accepted by ``nn.CrossEntropyLoss``.
            target: class ids shaped (N, ...); cast to ``long`` before use.

        Returns:
            Scalar loss, reduced per ``size_average`` and optionally divided
            by the batch size.
        """
        n = logit.size(0)
        criterion = nn.CrossEntropyLoss(
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self._reduction(),
        )
        if self.cuda:
            criterion = criterion.cuda()
        loss = criterion(logit, target.long())
        if self.batch_average:
            loss = loss / n
        return loss

    def FocalLoss(self, logit, target, gamma=2, alpha=0.5):
        """Focal loss computed per element, then reduced.

        Each non-ignored element contributes
        ``alpha * (1 - p_t) ** gamma * -log(p_t)``, where ``p_t`` is the softmax
        probability of its target class (``alpha=None`` drops the factor).
        With ``gamma=0`` and ``alpha=None`` the result equals
        ``CrossEntropyLoss`` for the same settings.

        Class weights scale each element's term. For the mean reduction they
        also form the normaliser, matching ``nn.CrossEntropyLoss``. The weight
        tensor is moved to ``logit``'s device here, so ``cuda`` is not
        consulted.

        Args:
            logit: scores shaped (N, C, ...).
            target: class ids shaped (N, ...); cast to ``long`` before use.
            gamma: focusing exponent.
            alpha: constant scale factor, or None.
        """
        n = logit.size(0)
        target = target.long()
        # Unweighted per-element CE is exactly -log(p_t); ignored elements give 0.
        ce = nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction="none")(
            logit, target
        )
        pt = torch.exp(-ce)
        focal = (1 - pt) ** gamma * ce
        if alpha is not None:
            focal = focal * alpha

        valid = target != self.ignore_index
        valid_f = valid.to(focal.dtype)
        if self.weight is None:
            elem_weight = valid_f
        else:
            weight = self.weight.to(device=logit.device, dtype=focal.dtype)
            # ignore_index may be out of class range: index with 0 there, then mask.
            elem_weight = weight[target.masked_fill(~valid, 0)] * valid_f
        focal = focal * elem_weight

        if self._reduction() == "mean":
            loss = focal.sum() / elem_weight.sum()
        else:
            loss = focal.sum()
        if self.batch_average:
            loss = loss / n
        return loss
class PanopticLosses(object):
    """Weighted loss terms for a panoptic head: semantic, center, x/y offsets.

    The semantic term is cross-entropy by default. ``build_loss("focal")``
    switches it to focal loss. The center and offset terms are mean-squared
    errors.

    Args:
        weight: optional per-class weight tensor for the semantic loss.
        size_average: if truthy (or None), the semantic loss averages over
            non-ignored elements; if falsy, it sums them.
        batch_average: if True, the semantic loss is further divided by the
            batch size.
        ignore_index: label value excluded from the semantic loss.
        cuda: if True, the cross-entropy criterion (and ``weight``) is moved
            to the default CUDA device.
        semantic_weight, center_weight, offset_weight: multipliers applied to
            the semantic, center and (both) offset terms returned by
            ``forward``.
    """

    def __init__(
        self,
        weight=None,
        size_average=True,
        batch_average=True,
        ignore_index=255,
        cuda=False,
        semantic_weight=10.0,
        center_weight=0.15,
        offset_weight=0.01,
    ):
        self.ignore_index = ignore_index
        self.weight = weight
        self.size_average = size_average
        self.batch_average = batch_average
        self.cuda = cuda
        self.semantic_weight = semantic_weight
        self.center_weight = center_weight
        self.offset_weight = offset_weight
        # Cross-entropy is the default semantic loss; build_loss() can swap it.
        self.semantic_loss = self.CrossEntropyLoss

    def _reduction(self):
        """Translate the legacy ``size_average`` flag to a ``reduction`` string."""
        # Legacy PyTorch semantics: None means "use the default", i.e. mean.
        if self.size_average is None or self.size_average:
            return "mean"
        return "sum"

    def build_loss(self, mode="ce"):
        """Select the semantic loss and return ``self`` for chaining.

        Choices: 'ce' (cross-entropy) or 'focal'.

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        if mode == "ce":
            self.semantic_loss = self.CrossEntropyLoss
            return self
        elif mode == "focal":
            self.semantic_loss = self.FocalLoss
            return self
        raise NotImplementedError(
            "Unknown loss mode {!r}; choose 'ce' or 'focal'".format(mode)
        )

    def CrossEntropyLoss(self, logit, target):
        """Cross-entropy between (N, C, ...) scores and (N, ...) class ids."""
        n = logit.size(0)
        criterion = nn.CrossEntropyLoss(
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self._reduction(),
        )
        if self.cuda:
            criterion = criterion.cuda()
        loss = criterion(logit, target.long())
        if self.batch_average:
            loss = loss / n
        return loss

    def FocalLoss(self, logit, target, gamma=2, alpha=0.5):
        """Per-element focal loss, reduced like ``CrossEntropyLoss``.

        Each non-ignored element contributes
        ``alpha * (1 - p_t) ** gamma * -log(p_t)`` (``alpha=None`` drops the
        factor). With ``gamma=0`` and ``alpha=None`` this equals
        ``CrossEntropyLoss``. Class weights scale each term and, for the mean
        reduction, the normaliser. The weight tensor is moved to ``logit``'s
        device, so ``cuda`` is not consulted.
        """
        n = logit.size(0)
        target = target.long()
        # Unweighted per-element CE is exactly -log(p_t); ignored elements give 0.
        ce = nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction="none")(
            logit, target
        )
        pt = torch.exp(-ce)
        focal = (1 - pt) ** gamma * ce
        if alpha is not None:
            focal = focal * alpha

        valid = target != self.ignore_index
        valid_f = valid.to(focal.dtype)
        if self.weight is None:
            elem_weight = valid_f
        else:
            weight = self.weight.to(device=logit.device, dtype=focal.dtype)
            # ignore_index may be out of class range: index with 0 there, then mask.
            elem_weight = weight[target.masked_fill(~valid, 0)] * valid_f
        focal = focal * elem_weight

        if self._reduction() == "mean":
            loss = focal.sum() / elem_weight.sum()
        else:
            loss = focal.sum()
        if self.batch_average:
            loss = loss / n
        return loss

    def forward(self, prediction, label, center, x_reg, y_reg):
        """Compute the four weighted loss terms.

        Args:
            prediction: sequence ``(semantic, center, x_reg, y_reg)`` of outputs.
            label: semantic class ids for the semantic loss.
            center, x_reg, y_reg: regression targets. Each gets a new axis at
                dim 1 before the MSE (presumably (N, H, W) targets against
                (N, 1, H, W) predictions; confirm with the caller).

        Returns:
            Tuple ``(semantic, center, x_offset, y_offset)`` of weighted losses.
        """
        semantic_predict, center_predict, x_reg_pred, y_reg_pred = prediction
        semantic_loss = self.semantic_loss(semantic_predict, label)
        center_loss = nn.functional.mse_loss(center_predict, center.unsqueeze(1))
        x_reg_loss = nn.functional.mse_loss(x_reg_pred, x_reg.unsqueeze(1))
        y_reg_loss = nn.functional.mse_loss(y_reg_pred, y_reg.unsqueeze(1))
        return (
            semantic_loss * self.semantic_weight,
            center_loss * self.center_weight,
            x_reg_loss * self.offset_weight,
            y_reg_loss * self.offset_weight,
        )
if __name__ == "__main__":
    # Quick smoke test; runs on the GPU when one is available, otherwise on the CPU.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    loss = SegmentationLosses(cuda=use_cuda)
    a = torch.rand(1, 3, 7, 7, device=device)
    # Integer class ids in [0, 3), matching the 3 channels of ``a``.
    b = torch.randint(0, 3, (1, 7, 7), device=device)
    # With gamma=0 and alpha=None the focal loss reduces to plain
    # cross-entropy, so the first two printed values should be equal.
    print(loss.CrossEntropyLoss(a, b).item())
    print(loss.FocalLoss(a, b, gamma=0, alpha=None).item())
    print(loss.FocalLoss(a, b, gamma=2, alpha=0.5).item())
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/qiaodl/panoptic-deeplab-pytorch.git
git@gitee.com:qiaodl/panoptic-deeplab-pytorch.git
qiaodl
panoptic-deeplab-pytorch
panoptic-deeplab-pytorch
master

搜索帮助