99 Star 800 Fork 1.4K

MindSpore / models

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
yolo_dataset.py 7.34 KB
一键复制 编辑 原始数据 按行查看 历史
zhaoting 提交于 2022-11-17 14:18 . move official models
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOV3 dataset."""
import os
import multiprocessing
import cv2
from PIL import Image
import numpy as np
from pycocotools.coco import COCO
import mindspore.dataset as ds
from src.distributed_sampler import DistributedSampler
from src.transforms import reshape_fn, MultiScaleTrans
# Minimum number of visible keypoints (summed over an image's annotations)
# required by has_valid_annotation when the annotations carry "keypoints".
min_keypoints_per_image = 10
def _has_only_empty_bbox(anno):
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
def _count_visible_keypoints(anno):
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
def has_valid_annotation(anno):
    """Check annotation file.

    Decide whether an image's list of annotation dicts is usable.

    Args:
        anno (list[dict]): Annotation dicts for one image.

    Returns:
        bool: False when ``anno`` is empty or every box is degenerate
        (see ``_has_only_empty_bbox``). If the first annotation has a
        "keypoints" entry, True only when at least
        ``min_keypoints_per_image`` keypoints are visible; otherwise True.
    """
    # Empty list, or all boxes have a side of at most one unit.
    if not anno or _has_only_empty_bbox(anno):
        return False
    # Keypoint annotations apply the stricter visible-keypoint threshold.
    if "keypoints" in anno[0]:
        return _count_visible_keypoints(anno) >= min_keypoints_per_image
    # Plain detection annotations need nothing further.
    return True
class COCOYoloDataset:
    """YOLOV3 Dataset for COCO.

    Loads a COCO-format annotation file through ``pycocotools`` and serves
    samples by integer index.

    Args:
        root (str): Directory joined with each image's ``file_name``.
        ann_file (str): Path to the COCO annotation file.
        remove_images_without_annotations (bool): If True, keep only image
            ids whose annotations pass ``has_valid_annotation``.
        filter_crowd_anno (bool): If True, annotations with ``iscrowd != 0``
            are excluded from training targets.
        is_training (bool): Selects the sample layout returned by
            ``__getitem__``.
    """

    def __init__(self, root, ann_file, remove_images_without_annotations=True,
                 filter_crowd_anno=True, is_training=True):
        self.coco = COCO(ann_file)
        self.root = root
        # sorted() already returns a new list.
        self.img_ids = sorted(self.coco.imgs.keys())
        self.filter_crowd_anno = filter_crowd_anno
        self.is_training = is_training

        # Filter out images whose annotations fail has_valid_annotation.
        if remove_images_without_annotations:
            kept_ids = []
            for img_id in self.img_ids:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
                if has_valid_annotation(self.coco.loadAnns(ann_ids)):
                    kept_ids.append(img_id)
            self.img_ids = kept_ids

        self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()}
        # Category ids from the file need not be contiguous; map them to
        # 0..N-1 (in getCatIds() order) and keep the inverse mapping too.
        self.cat_ids_to_continuous_ids = {
            v: i for i, v in enumerate(self.coco.getCatIds())
        }
        self.continuous_ids_cat_ids = {
            v: k for k, v in self.cat_ids_to_continuous_ids.items()
        }

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Args:
            index (int): Position in ``self.img_ids``.

        Returns:
            tuple: When ``is_training`` is False, ``(img, img_id)`` where
            ``img`` is an RGB PIL image. When ``is_training`` is True,
            ``(img, out_target, [], [], [], [], [], [])`` where ``img`` is the
            raw (still encoded) file content as a 1-D ``int8`` numpy array and
            ``out_target`` is a list of ``[x_min, y_min, x_max, y_max, label]``
            rows with ``label`` a contiguous class id. The six empty lists are
            placeholders for the remaining training columns (presumably filled
            later by the batch map — confirm against create_yolo_dataset).
        """
        coco = self.coco
        img_id = self.img_ids[index]
        img_path = os.path.join(self.root, coco.loadImgs(img_id)[0]["file_name"])

        if not self.is_training:
            # convert() yields a new, fully loaded image, so the source file
            # handle can be closed deterministically here.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert("RGB")
            return img, img_id

        # Raw encoded bytes; decoding happens later in the data pipeline.
        img = np.fromfile(img_path, dtype="int8")
        annos = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        if self.filter_crowd_anno:
            annos = [anno for anno in annos if anno["iscrowd"] == 0]

        classes = [self.cat_ids_to_continuous_ids[anno["category_id"]] for anno in annos]
        out_target = []
        for anno, label in zip(annos, classes):
            # [x_min, y_min, x_max, y_max, label]
            out_target.append(self._convert_top_down(anno["bbox"]) + [int(label)])
        return img, out_target, [], [], [], [], [], []

    def __len__(self):
        """Number of image ids served by this dataset."""
        return len(self.img_ids)

    def _convert_top_down(self, bbox):
        """Convert ``[x_min, y_min, w, h]`` to ``[x_min, y_min, x_max, y_max]``."""
        x_min, y_min, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
        return [x_min, y_min, x_min + w, y_min + h]

    # Backward-compatible alias for the original (misspelled) method name.
    _convetTopDown = _convert_top_down
def create_yolo_dataset(image_dir, anno_path, batch_size, device_num, rank,
                        config=None, is_training=True, shuffle=True):
    """Create dataset for YOLOV3.

    Args:
        image_dir (str): Directory holding the images.
        anno_path (str): Path to the COCO annotation file.
        batch_size (int): Samples per batch; the final incomplete batch is dropped.
        device_num (int): Number of shards, passed to DistributedSampler and
            used to split the CPU cores between devices.
        rank (int): Shard index, passed to DistributedSampler.
        config: Configuration object. Required despite the ``None`` default:
            ``config.dataset_size`` is set here, and the object is passed to
            ``MultiScaleTrans`` (training) or ``reshape_fn`` (evaluation).
        is_training (bool): Build the training pipeline if True, otherwise the
            evaluation pipeline.
        shuffle (bool): Forwarded to DistributedSampler.

    Returns:
        The batched MindSpore dataset.

    Raises:
        ValueError: If ``config`` is None.
    """
    # Fail fast, before the (potentially large) annotation file is loaded.
    if config is None:
        raise ValueError("create_yolo_dataset requires a config object")
    cv2.setNumThreads(0)

    # Crowd filtering and empty-annotation pruning only apply to training.
    training = bool(is_training)
    yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=training,
                                   remove_images_without_annotations=training, is_training=is_training)
    config.dataset_size = len(yolo_dataset)

    # Keep at least one worker even when device_num exceeds the CPU count.
    num_parallel_workers = max(1, multiprocessing.cpu_count() // device_num)
    distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)

    if is_training:
        multi_scale_trans = MultiScaleTrans(config, device_num)
        dataset_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3",
                                "gt_box1", "gt_box2", "gt_box3"]
        # Batch-map worker cap: 8 on 8-device runs, 32 otherwise.
        max_batch_workers = 8 if device_num == 8 else 32
        dataset = ds.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, sampler=distributed_sampler)
        dataset = dataset.map(operations=ds.vision.Decode(), input_columns=["image"])
        dataset = dataset.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
                                num_parallel_workers=min(max_batch_workers, num_parallel_workers),
                                drop_remainder=True)
    else:
        hwc_to_chw = ds.vision.HWC2CHW()
        dataset = ds.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"], sampler=distributed_sampler)
        compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
        dataset = dataset.map(operations=compose_map_func, input_columns=["image", "img_id"],
                              output_columns=["image", "image_shape", "img_id"],
                              num_parallel_workers=8)
        dataset = dataset.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=8)
        dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mindspore/models.git
git@gitee.com:mindspore/models.git
mindspore
models
models
r2.2

搜索帮助