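# NOTE: stray web-page banner captured with this file ("Code pull complete; the page will refresh automatically."); not part of the module.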
# Copyright (c) 2020 Huawei Technologies Co., Ltd
# All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.contrib.function import npu_single_level_responsible_flags
class TestAnchorGenerator(TestCase):
    """Compare ``npu_single_level_responsible_flags`` with a plain PyTorch reference."""

    def single_level_responsible_flags(self,
                                       featmap_size,
                                       gt_bboxes,
                                       stride,
                                       num_base_anchors,
                                       device='cpu'):
        """Generate the responsible flags of anchor in a single feature map.

        Reference implementation built only from stock PyTorch ops; the
        NPU operator under test is expected to match its output.

        Args:
            featmap_size (tuple[int]): The size of feature maps, unpacked
                as (feat_h, feat_w).
            gt_bboxes (Tensor): Ground truth boxes, shape (n, 4). Columns
                0/2 are averaged into the x centre and columns 1/3 into
                the y centre.
            stride (tuple(int)): Stride of current level; index 0 divides
                the x centre, index 1 divides the y centre.
            num_base_anchors (int): The number of base anchors.
            device (str, optional): Device where the flags will be put on.
                Defaults to 'cpu'.

        Returns:
            torch.Tensor: uint8 tensor of length
            ``feat_h * feat_w * num_base_anchors`` holding the flag of each
            anchor in a single level feature map.
        """
        feat_h, feat_w = featmap_size
        # Move the boxes first so that every step of the reference runs on
        # ``device`` rather than on whatever device the caller's tensor is on.
        gt_bboxes = gt_bboxes.to(device)
        gt_bboxes_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5
        gt_bboxes_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5
        gt_bboxes_grid_x = torch.floor(gt_bboxes_cx / stride[0]).long()
        gt_bboxes_grid_y = torch.floor(gt_bboxes_cy / stride[1]).long()
        # row major indexing
        gt_bboxes_grid_idx = gt_bboxes_grid_y * feat_w + gt_bboxes_grid_x

        responsible_grid = torch.zeros(
            feat_h * feat_w, dtype=torch.uint8, device=device)
        responsible_grid[gt_bboxes_grid_idx] = 1
        # Repeat each cell's flag once per base anchor, then flatten.
        responsible_grid = responsible_grid[:, None].expand(
            responsible_grid.size(0), num_base_anchors).contiguous().view(-1)
        return responsible_grid

    def test_anchor_generator(self):
        """Check the NPU op against the CPU reference on three feature levels."""
        featmap_sizes = [[10, 10], [20, 20], [40, 40]]
        strides = [[32, 32], [16, 16], [8, 8]]
        num_base_anchors = 3

        gt_bboxes = torch.randint(0, 100, size=(128, 4))
        # Transfer once, outside the loop, and keep the CPU copy so the
        # reference never touches the device under test.
        gt_bboxes_npu = gt_bboxes.npu()

        for featmap_size, stride in zip(featmap_sizes, strides):
            cpuout = self.single_level_responsible_flags(featmap_size,
                                                         gt_bboxes,
                                                         stride,
                                                         num_base_anchors)
            npuout = npu_single_level_responsible_flags(featmap_size,
                                                        gt_bboxes_npu,
                                                        stride,
                                                        num_base_anchors)
            self.assertRtolEqual(cpuout, npuout.cpu())
# Script entry point: hand control to the torch_npu test runner when this
# file is executed directly.
if __name__ == "__main__":
    run_tests()
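# NOTE: stray web-page notice captured with this file; not part of the module:
# "The content here may be unsuitable for display, so the page does not show it. You can review and modify it using the relevant editing features."
# "If you confirm the content contains no inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false or worthless content, and nothing that violates national laws and regulations, you can click submit to appeal, and we will handle it as soon as possible."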