代码拉取完成,页面将自动刷新
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32, run_distribute=False):
    """
    Create a train or eval dataset from an image-folder directory.

    Args:
        dataset_path (str): the path of the dataset, read with ``ds.ImageFolderDataset``.
        do_train (bool): whether the dataset is used for train (True) or eval (False).
        config: configuration object; its ``image_height`` and ``image_width``
            attributes give the output image size.
        device_target (str): the device target. Only "GPU" is supported.
        repeat_num (int): the repeat times of dataset. Default: 1
        batch_size (int): the batch size of dataset. Default: 32
        run_distribute (bool): when training, read only this device's shard
            (``num_shards=get_group_size()``, ``shard_id=get_rank()``). Default: False

    Returns:
        dataset

    Raises:
        ValueError: if ``device_target`` is not "GPU".
    """
    if device_target != "GPU":
        raise ValueError("Unsupported device_target.")

    # Sharding applies only to distributed training; evaluation reads the full set.
    shard_kwargs = {}
    if do_train and run_distribute:
        from mindspore.communication.management import get_rank, get_group_size
        shard_kwargs = {"num_shards": get_group_size(), "shard_id": get_rank()}
    data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, **shard_kwargs)

    resize_height = config.image_height
    resize_width = config.image_width
    # Pass an explicit (height, width) pair to both the train crop and the eval
    # crop so that train and eval produce the same output shape even when the
    # configured height and width differ (an int size would yield a square).
    output_size = (resize_height, resize_width)
    buffer_size = 1000

    # define map operations shared by both branches
    normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                               std=[0.229 * 255, 0.224 * 255, 0.225 * 255])
    change_swap_op = C.HWC2CHW()
    if do_train:
        trans = [
            # Decode and random-resized-crop in a single op.
            C.RandomCropDecodeResize(output_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
            normalize_op,
            change_swap_op,
        ]
    else:
        trans = [
            C.Decode(),
            # NOTE(review): the resize target is fixed at 256 regardless of the
            # configured size; confirm the configured height/width stay <= 256.
            C.Resize(256),
            C.CenterCrop(output_size),
            normalize_op,
            change_swap_op,
        ]
    type_cast_op = C2.TypeCast(mstype.int32)
    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
    # apply shuffle operations
    data_set = data_set.shuffle(buffer_size=buffer_size)
    # apply batch operations; incomplete trailing batches are dropped
    data_set = data_set.batch(batch_size, drop_remainder=True)
    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)
    return data_set
def create_dataset_cifar(dataset_path,
                         do_train,
                         repeat_num=1,
                         batch_size=32,
                         target="CPU"):
    """
    Create a train or evaluate cifar10 dataset.

    Args:
        dataset_path (str): the path of the dataset, read with ``ds.Cifar10Dataset``.
        do_train (bool): whether dataset is used for train (True) or eval (False).
        repeat_num (int): the repeat times of dataset. Default: 1
        batch_size (int): the batch size of dataset. Default: 32
        target (str): the device target. Default: "CPU". Currently not used
            by this function; kept for interface compatibility.

    Returns:
        dataset
    """
    data_set = ds.Cifar10Dataset(dataset_path,
                                 num_parallel_workers=8,
                                 shuffle=True)

    # Steps shared by train and eval: resize to 224x224, multiply pixel values
    # by 1/255, then normalize per channel with the given mean/std.
    resize_and_normalize = [
        C.Resize((224, 224)),
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    ]
    # define map operations
    if do_train:
        trans = [
            # Pad 4 pixels on each side, then randomly crop back to 32x32.
            C.RandomCrop((32, 32), (4, 4, 4, 4)),
            C.RandomHorizontalFlip(prob=0.5),
            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
            *resize_and_normalize,
            C.CutOut(112),
            C.HWC2CHW(),
        ]
    else:
        trans = [*resize_and_normalize, C.HWC2CHW()]

    type_cast_op = C2.TypeCast(mstype.int32)
    data_set = data_set.map(operations=type_cast_op,
                            input_columns="label",
                            num_parallel_workers=8)
    data_set = data_set.map(operations=trans,
                            input_columns="image",
                            num_parallel_workers=8)
    # apply batch operations; incomplete trailing batches are dropped
    data_set = data_set.batch(batch_size, drop_remainder=True)
    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)
    return data_set
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。