Ascend
GPU
CPU
Data Preparation
MindSpore provides multiple samplers to help you sample datasets for various purposes to meet training requirements and solve problems such as oversized datasets and uneven distribution of sample categories. You only need to import the sampler object when loading the dataset for sampling the data.
The following table lists part of the common samplers supported by MindSpore. In addition, you can define your own sampler class as required. For more samplers, see MindSpore API.
Sampler | Description |
---|---|
RandomSampler | Random sampler, which randomly samples a specified amount of data from a dataset. |
WeightedRandomSampler | Weighted random sampler, which randomly samples a specified amount of data from the first N samples based on the specified probability list with the length of N. |
SubsetRandomSampler | Subset random sampler, which randomly samples a specified amount of data within a specified index range. |
PKSampler | PK sampler, which samples K pieces of data from the specified P categories. |
DistributedSampler | Distributed sampler, which samples dataset shards in distributed training. |
The following uses the CIFAR-10 as an example to introduce several common MindSpore samplers.
Download the CIFAR-10 data set and unzip it to the specified path, execute the following command:
wget -N https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz --no-check-certificate
mkdir -p datasets
tar -xzf cifar-10-binary.tar.gz -C datasets
mkdir -p datasets/cifar-10-batches-bin/train datasets/cifar-10-batches-bin/test
mv -f datasets/cifar-10-batches-bin/test_batch.bin datasets/cifar-10-batches-bin/test
mv -f datasets/cifar-10-batches-bin/data_batch*.bin datasets/cifar-10-batches-bin/batches.meta.txt datasets/cifar-10-batches-bin/train
tree ./datasets/cifar-10-batches-bin
./datasets/cifar-10-batches-bin
├── readme.html
├── test
│ └── test_batch.bin
└── train
├── batches.meta.txt
├── data_batch_1.bin
├── data_batch_2.bin
├── data_batch_3.bin
├── data_batch_4.bin
└── data_batch_5.bin
2 directories, 8 files
Randomly samples a specified amount of data from the index sequence.
The following example uses a random sampler to randomly sample five pieces of data from the CIFAR-10 dataset with and without replacement, and displays shapes and labels of the loaded data.
import mindspore.dataset as ds
ds.config.set_seed(0)
DATA_DIR = "./datasets/cifar-10-batches-bin/train/"
print("------ Without Replacement ------")
sampler = ds.RandomSampler(num_samples=5)
dataset1 = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)
for data in dataset1.create_dict_iterator():
print("Image shape:", data['image'].shape, ", Label:", data['label'])
print("------ With Replacement ------")
sampler = ds.RandomSampler(replacement=True, num_samples=5)
dataset2 = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)
for data in dataset2.create_dict_iterator():
print("Image shape:", data['image'].shape, ", Label:", data['label'])
The output is as follows:
------ Without Replacement ------
Image shape: (32, 32, 3) , Label: 1
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 0
Image shape: (32, 32, 3) , Label: 4
------ With Replacement ------
Image shape: (32, 32, 3) , Label: 0
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 3
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 6
Specifies a sampling probability list with the length of N and randomly samples a specified amount of data from the first N samples based on the probability.
The following example uses a weighted random sampler to obtain 6 samples based on probability from the first 10 samples in the CIFAR-10 dataset, and displays shapes and labels of the loaded data.
import mindspore.dataset as ds
ds.config.set_seed(1)
DATA_DIR = "./datasets/cifar-10-batches-bin/train/"
weights = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
sampler = ds.WeightedRandomSampler(weights, num_samples=6)
dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)
for data in dataset.create_dict_iterator():
print("Image shape:", data['image'].shape, ", Label:", data['label'])
The output is as follows:
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 6
Randomly samples a specified amount of data from the specified index subset.
The following example uses a subset random sampler to obtain 3 samples from the specified subset in the CIFAR-10 dataset, and displays shapes and labels of the loaded data.
import mindspore.dataset as ds
ds.config.set_seed(2)
DATA_DIR = "./datasets/cifar-10-batches-bin/train/"
indices = [0, 1, 2, 3, 4, 5]
sampler = ds.SubsetRandomSampler(indices, num_samples=3)
dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)
for data in dataset.create_dict_iterator():
print("Image shape:", data['image'].shape, ", Label:", data['label'])
The output is as follows:
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 9
Samples K pieces of data from the specified P categories.
The following example uses the PK sampler to obtain 2 samples from each category in the CIFAR-10 dataset, not more than 20 samples in total, and displays shapes and labels of the read data.
import mindspore.dataset as ds
ds.config.set_seed(3)
DATA_DIR = "./datasets/cifar-10-batches-bin/train/"
sampler = ds.PKSampler(num_val=2, class_column='label', num_samples=20)
dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)
for data in dataset.create_dict_iterator():
print("Image shape:", data['image'].shape, ", Label:", data['label'])
The output is as follows:
Image shape: (32, 32, 3) , Label: 0
Image shape: (32, 32, 3) , Label: 0
Image shape: (32, 32, 3) , Label: 1
Image shape: (32, 32, 3) , Label: 1
Image shape: (32, 32, 3) , Label: 2
Image shape: (32, 32, 3) , Label: 2
Image shape: (32, 32, 3) , Label: 3
Image shape: (32, 32, 3) , Label: 3
Image shape: (32, 32, 3) , Label: 4
Image shape: (32, 32, 3) , Label: 4
Image shape: (32, 32, 3) , Label: 5
Image shape: (32, 32, 3) , Label: 5
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 7
Image shape: (32, 32, 3) , Label: 7
Image shape: (32, 32, 3) , Label: 8
Image shape: (32, 32, 3) , Label: 8
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 9
Samples dataset shards in distributed training.
The following example uses a distributed sampler to divide a generated dataset into three shards, obtains no more than three data samples in each shard, and displays the loaded data on shard number 0.
import numpy as np
import mindspore.dataset as ds
data_source = [0, 1, 2, 3, 4, 5, 6, 7, 8]
sampler = ds.DistributedSampler(num_shards=3, shard_id=0, shuffle=False, num_samples=3)
dataset = ds.NumpySlicesDataset(data_source, column_names=["data"], sampler=sampler)
for data in dataset.create_dict_iterator():
print(data)
The output is as follows:
{'data': Tensor(shape=[], dtype=Int64, value= 0)}
{'data': Tensor(shape=[], dtype=Int64, value= 3)}
{'data': Tensor(shape=[], dtype=Int64, value= 6)}
You can inherit the Sampler
base class and define the sampling method of the sampler by implementing the __iter__
method.
The following example defines a sampler with an interval of 2 samples from subscript 0 to subscript 9, applies the sampler to the CIFAR-10 dataset, and displays shapes and labels of the read data.
import mindspore.dataset as ds
class MySampler(ds.Sampler):
def __iter__(self):
for i in range(0, 10, 2):
yield i
DATA_DIR = "./datasets/cifar-10-batches-bin/train/"
dataset = ds.Cifar10Dataset(DATA_DIR, sampler=MySampler())
for data in dataset.create_dict_iterator():
print("Image shape:", data['image'].shape, ", Label:", data['label'])
The output is as follows:
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 1
Image shape: (32, 32, 3) , Label: 2
Image shape: (32, 32, 3) , Label: 8
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。