Ai
6 Star 69 Fork 19

OpenMMLab/mmsegmentation

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
test_iou_metric.py 3.66 KB
一键复制 编辑 原始数据 按行查看 历史
Miao Zheng 提交于 2023-03-17 22:58 +08:00 . [Features]Support dump segment predition (#2712)
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
from unittest import TestCase
import numpy as np
import torch
from mmengine.structures import PixelData
from mmseg.evaluation import IoUMetric
from mmseg.structures import SegDataSample
class TestIoUMetric(TestCase):
def _demo_mm_inputs(self,
batch_size=2,
image_shapes=(3, 64, 64),
num_classes=5):
"""Create a superset of inputs needed to run test or train batches.
Args:
batch_size (int): batch size. Default to 2.
image_shapes (List[tuple], Optional): image shape.
Default to (3, 64, 64)
num_classes (int): number of different classes.
Default to 5.
"""
if isinstance(image_shapes, list):
assert len(image_shapes) == batch_size
else:
image_shapes = [image_shapes] * batch_size
data_samples = []
for idx in range(batch_size):
image_shape = image_shapes[idx]
_, h, w = image_shape
data_sample = SegDataSample()
gt_semantic_seg = np.random.randint(
0, num_classes, (1, h, w), dtype=np.uint8)
gt_semantic_seg = torch.LongTensor(gt_semantic_seg)
gt_sem_seg_data = dict(data=gt_semantic_seg)
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
data_samples.append(data_sample.to_dict())
return data_samples
def _demo_mm_model_output(self,
data_samples,
batch_size=2,
image_shapes=(3, 64, 64),
num_classes=5):
_, h, w = image_shapes
for data_sample in data_samples:
data_sample['seg_logits'] = dict(
data=torch.randn(num_classes, h, w))
data_sample['pred_sem_seg'] = dict(
data=torch.randint(0, num_classes, (1, h, w)))
data_sample[
'img_path'] = 'tests/data/pseudo_dataset/imgs/00000_img.jpg'
return data_samples
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
data_samples = self._demo_mm_inputs()
data_samples = self._demo_mm_model_output(data_samples)
iou_metric = IoUMetric(iou_metrics=['mIoU'])
iou_metric.dataset_meta = dict(
classes=['wall', 'building', 'sky', 'floor', 'tree'],
label_map=dict(),
reduce_zero_label=False)
iou_metric.process([0] * len(data_samples), data_samples)
res = iou_metric.evaluate(2)
self.assertIsInstance(res, dict)
# test save segment file in output_dir
iou_metric = IoUMetric(iou_metrics=['mIoU'], output_dir='tmp')
iou_metric.dataset_meta = dict(
classes=['wall', 'building', 'sky', 'floor', 'tree'],
label_map=dict(),
reduce_zero_label=False)
iou_metric.process([0] * len(data_samples), data_samples)
assert osp.exists('tmp')
assert osp.isfile('tmp/00000_img.png')
shutil.rmtree('tmp')
# test format_only
iou_metric = IoUMetric(
iou_metrics=['mIoU'], output_dir='tmp', format_only=True)
iou_metric.dataset_meta = dict(
classes=['wall', 'building', 'sky', 'floor', 'tree'],
label_map=dict(),
reduce_zero_label=False)
iou_metric.process([0] * len(data_samples), data_samples)
assert iou_metric.results == []
assert osp.exists('tmp')
assert osp.isfile('tmp/00000_img.png')
shutil.rmtree('tmp')
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/open-mmlab/mmsegmentation.git
git@gitee.com:open-mmlab/mmsegmentation.git
open-mmlab
mmsegmentation
mmsegmentation
main

搜索帮助