78 Star 603 Fork 1.2K

Ascend/pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
test_anchor_generator.py 3.40 KB
一键复制 编辑 原始数据 按行查看 历史
王夏夏 提交于 3年前 . add 5个亲和库
# Copyright (c) 2020 Huawei Technologies Co., Ltd
# All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.contrib.function import npu_single_level_responsible_flags
class TestAnchorGenerator(TestCase):
    def single_level_responsible_flags(self,
                                       featmap_size,
                                       gt_bboxes,
                                       stride,
                                       num_base_anchors,
                                       device='cpu'):
        """Generate the responsible flags of anchors in a single feature map.

        CPU reference implementation used to validate the NPU-accelerated
        ``npu_single_level_responsible_flags``.

        Args:
            featmap_size (tuple[int]): The size of the feature map as (h, w).
            gt_bboxes (Tensor): Ground truth boxes, shape (n, 4) in
                (x1, y1, x2, y2) order.
            stride (tuple[int]): Stride of the current level as (x, y).
            num_base_anchors (int): The number of base anchors per grid cell.
            device (str, optional): Device where the flags will be put on.
                Defaults to 'cpu'.

        Returns:
            torch.Tensor: uint8 flags of each anchor in a single-level
            feature map, shape (feat_h * feat_w * num_base_anchors,).
        """
        feat_h, feat_w = featmap_size
        # Box centers, moved to the target device.
        gt_bboxes_cx = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5).to(device)
        gt_bboxes_cy = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5).to(device)
        # Grid cell each center falls into at this level's stride.
        gt_bboxes_grid_x = torch.floor(gt_bboxes_cx / stride[0]).long()
        gt_bboxes_grid_y = torch.floor(gt_bboxes_cy / stride[1]).long()
        # Row-major flattened index of the responsible cell.
        gt_bboxes_grid_idx = gt_bboxes_grid_y * feat_w + gt_bboxes_grid_x
        responsible_grid = torch.zeros(
            feat_h * feat_w, dtype=torch.uint8, device=device)
        responsible_grid[gt_bboxes_grid_idx] = 1
        # Replicate each cell's flag once per base anchor of that cell.
        responsible_grid = responsible_grid[:, None].expand(
            responsible_grid.size(0), num_base_anchors).contiguous().view(-1)
        return responsible_grid

    def test_anchor_generator(self):
        """Compare the CPU reference against the NPU implementation across
        three feature-map levels."""
        featmap_sizes = [[10, 10], [20, 20], [40, 40]]
        stride = [[32, 32], [16, 16], [8, 8]]
        # Move the boxes to the NPU once; the original converted them on
        # every loop iteration although the value is loop-invariant.
        gt_bboxes = torch.randint(0, 100, size=(128, 4)).npu()
        num_base_anchors = 3
        for i in range(len(featmap_sizes)):
            cpuout = self.single_level_responsible_flags(featmap_sizes[i],
                                                         gt_bboxes,
                                                         stride[i],
                                                         num_base_anchors)
            npuout = npu_single_level_responsible_flags(featmap_sizes[i],
                                                        gt_bboxes,
                                                        stride[i],
                                                        num_base_anchors)
            self.assertRtolEqual(cpuout, npuout.cpu())
# Run the test suite via torch_npu's test runner when executed directly.
if __name__ == "__main__":
    run_tests()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/ascend/pytorch.git
git@gitee.com:ascend/pytorch.git
ascend
pytorch
pytorch
v2.0.1-5.0.rc3

搜索帮助