1 Star 0 Fork 0

shopping / nanodet_comments

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
gfocal_loss.py 7.75 KB
一键复制 编辑 原始数据 按行查看 历史
shopping 提交于 2022-06-10 08:16 . 损失函数核心部分
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import weighted_loss
@weighted_loss
def quality_focal_loss(pred, target, beta=2.0):
    r"""Quality Focal Loss (QFL).

    From `Generalized Focal Loss: Learning Qualified and Distributed
    Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    Args:
        pred (torch.Tensor): Predicted joint representation of
            classification and quality (IoU) estimation, shape (N, C)
            where C is the number of classes.
        target (tuple([torch.Tensor])): Pair of (category label with
            shape (N,), quality label with shape (N,)).
        beta (float): Exponent of the modulating factor. Defaults to 2.0.

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
    """
    assert (
        len(target) == 2
    ), """target for QFL must be a tuple of two elements,
including category label and quality label, respectively"""
    # `label` is the category id; `score` is the quality (IoU) score.
    label, score = target
    sigmoid_pred = pred.sigmoid()

    # Start by supervising every (anchor, class) entry as a negative
    # (quality score 0), modulated by sigmoid(pred)^beta.
    zero_target = sigmoid_pred.new_zeros(pred.shape)
    loss = F.binary_cross_entropy_with_logits(
        pred, zero_target, reduction="none"
    ) * sigmoid_pred.pow(beta)

    # FG cat_id: [0, num_classes - 1]; BG cat_id: num_classes.
    num_classes = pred.size(1)
    fg_mask = (label >= 0) & (label < num_classes)
    pos_inds = torch.nonzero(fg_mask, as_tuple=False).squeeze(1)
    pos_labels = label[pos_inds].long()

    # Positives are supervised by the bbox quality (IoU) score; the gap
    # between prediction and quality acts as the modulating factor, so
    # harder examples (larger gap) get larger weight.
    quality = score[pos_inds]
    gap = quality - sigmoid_pred[pos_inds, pos_labels]
    loss[pos_inds, pos_labels] = F.binary_cross_entropy_with_logits(
        pred[pos_inds, pos_labels], quality, reduction="none"
    ) * gap.abs().pow(beta)

    return loss.sum(dim=1)
@weighted_loss
def distribution_focal_loss(pred, label):
    r"""Distribution Focal Loss (DFL).

    From `Generalized Focal Loss: Learning Qualified and Distributed
    Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    Args:
        pred (torch.Tensor): Predicted general distribution of bounding
            boxes (before softmax), shape (N, n+1) with n the max value
            of the integral set `{0, ..., n}` in the paper.
        label (torch.Tensor): Continuous target distance label for each
            box, shape (N,).

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
    """
    # The fractional label sits between two integer bins: floor and
    # floor + 1.
    left = label.long()
    right = left + 1
    # Linear interpolation weights: the closer the label is to a bin,
    # the larger that bin's weight.
    w_left = right.float() - label
    w_right = label - left.float()
    # Two cross-entropy terms pull the predicted distribution toward the
    # two neighbouring bins, weighted so their expectation matches label.
    return (
        w_left * F.cross_entropy(pred, left, reduction="none")
        + w_right * F.cross_entropy(pred, right, reduction="none")
    )
class QualityFocalLoss(nn.Module):
    r"""Quality Focal Loss (QFL) module.

    A variant of `Generalized Focal Loss: Learning Qualified and
    Distributed Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    Args:
        use_sigmoid (bool): Whether sigmoid operation is conducted in
            QFL. Defaults to True (only sigmoid is supported).
        beta (float): Exponent of the modulating factor. Defaults to 2.0.
        reduction (str): One of "none", "mean", "sum".
        loss_weight (float): Scalar weight applied to the loss.
    """

    def __init__(self, use_sigmoid=True, beta=2.0, reduction="mean", loss_weight=1.0):
        super().__init__()
        assert use_sigmoid is True, "Only sigmoid in QFL supported now."
        self.use_sigmoid = use_sigmoid
        self.beta = beta
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(
        self, pred, target, weight=None, avg_factor=None, reduction_override=None
    ):
        """Compute the weighted QFL.

        Args:
            pred (torch.Tensor): Predicted joint representation of
                classification and quality (IoU) estimation, shape
                (N, C) with C the number of classes.
            target (tuple([torch.Tensor])): Pair of (category label with
                shape (N,), quality label with shape (N,)).
            weight (torch.Tensor, optional): Per-prediction loss weight.
            avg_factor (int, optional): Normalization factor for the
                averaged loss.
            reduction_override (str, optional): Overrides the module's
                reduction method when given.
        """
        assert reduction_override in (None, "none", "mean", "sum")
        reduction = reduction_override if reduction_override else self.reduction
        if not self.use_sigmoid:
            raise NotImplementedError
        return self.loss_weight * quality_focal_loss(
            pred,
            target,
            weight,
            beta=self.beta,
            reduction=reduction,
            avg_factor=avg_factor,
        )
class DistributionFocalLoss(nn.Module):
    r"""Distribution Focal Loss (DFL) module.

    A variant of `Generalized Focal Loss: Learning Qualified and
    Distributed Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    Args:
        reduction (str): One of `'none'`, `'mean'`, `'sum'`.
        loss_weight (float): Scalar weight applied to the loss.
    """

    def __init__(self, reduction="mean", loss_weight=1.0):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(
        self, pred, target, weight=None, avg_factor=None, reduction_override=None
    ):
        """Compute the weighted DFL.

        Args:
            pred (torch.Tensor): Predicted general distribution of
                bounding boxes (before softmax), shape (N, n+1) with n
                the max value of the integral set `{0, ..., n}` in the
                paper.
            target (torch.Tensor): Target distance label for bounding
                boxes, shape (N,).
            weight (torch.Tensor, optional): Per-prediction loss weight.
            avg_factor (int, optional): Normalization factor for the
                averaged loss.
            reduction_override (str, optional): Overrides the module's
                reduction method when given.
        """
        assert reduction_override in (None, "none", "mean", "sum")
        reduction = reduction_override if reduction_override else self.reduction
        return self.loss_weight * distribution_focal_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor
        )
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/shopping-tang/nanodet_comments.git
git@gitee.com:shopping-tang/nanodet_comments.git
shopping-tang
nanodet_comments
nanodet_comments
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891