1 Star 0 Fork 0

shopping / nanodet_comments

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
dsl_assigner.py 9.69 KB
一键复制 编辑 原始数据 按行查看 历史
shopping 提交于 2022-11-29 09:08 . update dsl_assigner.py.
import torch
import torch.nn.functional as F
from ...loss.iou_loss import bbox_overlaps
from .assign_result import AssignResult
from .base_assigner import BaseAssigner
class DynamicSoftLabelAssigner(BaseAssigner):
    """Computes matching between predictions and ground truth with
    dynamic soft label assignment.

    Args:
        topk (int): Select top-k predictions to calculate dynamic k
            best matches for each gt. Default 13.
        iou_factor (float): The scale factor of iou cost. Default 3.0.
    """

    def __init__(self, topk=13, iou_factor=3.0):
        self.topk = topk
        self.iou_factor = iou_factor

    def assign(
        self,
        pred_scores,
        priors,
        decoded_bboxes,
        gt_bboxes,
        gt_labels,
    ):
        """Assign gt to priors with dynamic soft label assignment.

        Only priors whose point ``priors[:, :2]`` lies strictly inside at
        least one gt box are candidates. Among the candidates, each gt picks
        its dynamic-k lowest-cost priors (see ``dynamic_k_matching``). The
        cost is a soft-label BCE classification cost plus
        ``iou_factor * -log(IoU)``.

        Args:
            pred_scores (Tensor): Classification scores of one image,
                a 2D-Tensor with shape [num_priors, num_classes]. They are
                passed straight to ``F.binary_cross_entropy``, so they must
                already be probabilities in [0, 1].
            priors (Tensor): All priors of one image, a 2D-Tensor with shape
                [num_priors, 4] in [x, y, stride_w, stride_h] format. Only the
                (x, y) point is used here. NOTE(review): whether (x, y) is the
                cell center or its top-left corner depends on the prior
                generator; confirm against the caller.
            decoded_bboxes (Tensor): Predicted bboxes, a 2D-Tensor with shape
                [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
                with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (Tensor): Ground truth labels of one image, a Tensor
                with shape [num_gts].

        Returns:
            :obj:`AssignResult`: Built from ``num_gt`` plus three per-prior
            tensors:

            - assigned gt indices: 0 means background, ``i + 1`` means the
              i-th gt.
            - IoU with the matched gt: ``-INF`` for unmatched priors, or all
              zeros on the empty path.
            - labels: -1 means background, or ``None`` on the empty path
              when ``gt_labels`` is None.
        """
        INF = 100000000
        num_gt = gt_bboxes.size(0)
        num_bboxes = decoded_bboxes.size(0)

        # Every prior starts as background (index 0).
        assigned_gt_inds = decoded_bboxes.new_full(
            (num_bboxes,), 0, dtype=torch.long
        )

        # Candidate filter. A prior is valid if its point lies strictly inside
        # at least one gt box, i.e. all four distances to the box edges are > 0.
        prior_center = priors[:, :2]  # (num_priors, 2)
        lt_ = prior_center[:, None] - gt_bboxes[:, :2]  # (num_priors, num_gt, 2)
        rb_ = gt_bboxes[:, 2:] - prior_center[:, None]  # (num_priors, num_gt, 2)
        deltas = torch.cat([lt_, rb_], dim=-1)  # (num_priors, num_gt, 4)
        is_in_gts = deltas.min(dim=-1).values > 0  # (num_priors, num_gt)
        valid_mask = is_in_gts.sum(dim=1) > 0  # (num_priors,)

        valid_decoded_bbox = decoded_bboxes[valid_mask]  # (num_valid, 4)
        valid_pred_scores = pred_scores[valid_mask]  # (num_valid, num_classes)
        num_valid = valid_decoded_bbox.size(0)

        if num_gt == 0 or num_bboxes == 0 or num_valid == 0:
            # Nothing to match: no gt, no boxes, or no prior inside any gt.
            # All priors stay background (assigned_gt_inds is already all 0).
            max_overlaps = decoded_bboxes.new_zeros((num_bboxes,))
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = decoded_bboxes.new_full(
                    (num_bboxes,), -1, dtype=torch.long
                )
            return AssignResult(
                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels
            )

        # IoU between each candidate prediction and each gt: (num_valid, num_gt).
        pairwise_ious = bbox_overlaps(valid_decoded_bbox, gt_bboxes)
        # Higher IoU -> lower cost. The epsilon keeps log() finite at IoU == 0.
        iou_cost = -torch.log(pairwise_ious + 1e-7)

        # Classification cost with a soft label: the target for the gt's class
        # is the IoU, and every other class targets 0. Broadcasting and
        # expand() replace repeat(), so no extra
        # (num_valid, num_gt, num_classes) copies are materialised. The values
        # are identical.
        gt_onehot_label = F.one_hot(
            gt_labels.to(torch.int64), pred_scores.shape[-1]
        ).float()  # (num_gt, num_classes)
        # (1, num_gt, C) * (num_valid, num_gt, 1) -> (num_valid, num_gt, C)
        soft_label = gt_onehot_label[None] * pairwise_ious[..., None]
        # Zero-copy view of shape (num_valid, num_gt, num_classes).
        valid_pred_scores = valid_pred_scores[:, None].expand(-1, num_gt, -1)
        scale_factor = soft_label - valid_pred_scores
        cls_cost = F.binary_cross_entropy(
            valid_pred_scores, soft_label, reduction="none"
        ) * scale_factor.abs().pow(2.0)
        cls_cost = cls_cost.sum(dim=-1)  # (num_valid, num_gt)

        # Combined cost. The IoU term is weighted by iou_factor.
        cost_matrix = cls_cost + iou_cost * self.iou_factor

        # dynamic_k_matching narrows ``valid_mask`` IN PLACE to the priors that
        # end up matched. The indexing below relies on that updated mask.
        matched_pred_ious, matched_gt_inds = self.dynamic_k_matching(
            cost_matrix, pairwise_ious, num_gt, valid_mask
        )

        # Convert to AssignResult format: 0 = background, i + 1 = i-th gt.
        assigned_gt_inds[valid_mask] = matched_gt_inds + 1
        assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1)
        assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
        max_overlaps = assigned_gt_inds.new_full(
            (num_bboxes,), -INF, dtype=torch.float32
        )
        max_overlaps[valid_mask] = matched_pred_ious
        return AssignResult(
            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels
        )

    def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask):
        """Use sum of topk pred iou as dynamic k. Refer from OTA and YOLOX.

        Args:
            cost (Tensor): Cost matrix, shape (num_valid, num_gt).
            pairwise_ious (Tensor): Pairwise iou matrix,
                shape (num_valid, num_gt).
            num_gt (int): Number of gt.
            valid_mask (Tensor): Bool mask of candidate priors,
                shape (num_priors,). Its True entries correspond, in order,
                to the num_valid rows of ``cost``. The mask is MODIFIED IN
                PLACE: candidates that end up unmatched are set to False.

        Returns:
            tuple[Tensor, Tensor]: ``(matched_pred_ious, matched_gt_inds)``.
            Both have shape (num_matched,) and are ordered like the True
            entries of the updated ``valid_mask``. They hold each matched
            prior's IoU with its gt and that gt's column index.
        """
        matching_matrix = torch.zeros_like(cost)  # (num_valid, num_gt)

        # Dynamic k per gt: the sum of its top-k candidate IoUs, truncated to
        # int and clamped to >= 1 so every gt selects at least one candidate.
        candidate_topk = min(self.topk, pairwise_ious.size(0))
        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)  # (num_gt,)
        for gt_idx in range(num_gt):
            # Take the dynamic_k lowest-cost candidates for this gt.
            _, pos_idx = torch.topk(
                cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False
            )
            matching_matrix[pos_idx, gt_idx] = 1.0
        # Free temporaries early. Do not del the loop variable: it is unbound
        # when num_gt == 0.
        del topk_ious, dynamic_ks

        # A prior selected by several gts keeps only its lowest-cost gt.
        prior_match_gt_mask = matching_matrix.sum(1) > 1  # (num_valid,)
        if prior_match_gt_mask.any():
            _, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1)
            matching_matrix[prior_match_gt_mask, :] = 0.0
            matching_matrix[prior_match_gt_mask, cost_argmin] = 1.0

        # Candidates matched to some gt are foreground.
        fg_mask_inboxes = matching_matrix.sum(1) > 0.0  # (num_valid,)
        # Propagate the result to the caller's mask in place. The True
        # positions of valid_mask are exactly the num_valid candidates, in
        # order, so only the unmatched ones are switched off.
        valid_mask[valid_mask.clone()] = fg_mask_inboxes

        # Column index of the single gt each foreground prior is matched to.
        matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
        # IoU of each foreground prior with its matched gt.
        matched_pred_ious = (matching_matrix * pairwise_ious).sum(1)[
            fg_mask_inboxes
        ]
        return matched_pred_ious, matched_gt_inds
Python
1
https://gitee.com/shopping-tang/nanodet_comments.git
git@gitee.com:shopping-tang/nanodet_comments.git
shopping-tang
nanodet_comments
nanodet_comments
master

搜索帮助