test_tdt_data_transfer.py
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import time

import numpy as np

from mindspore import context, nn, Tensor
from mindspore import log as logger
from mindspore.common.api import _cell_graph_executor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
import mindspore.dataset as de
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from tests.mark_utils import arg_mark

DATA_DIR = "/home/workspace/mindspore_dataset/cifar-10-verify-bin"


def dataset_cifar(dataset_path=None, batch_size=32, repeat_num=1, num_rows=9600, distribution_num=None, shard_id=None,
                  drop_remainder=True, usage=None, shuffle=False, num_workers=8, resize_size=32):
    """Build a CIFAR-10 pipeline: load, cast labels, transform images, batch, repeat."""
    if dataset_path is None:
        dataset_path = DATA_DIR
    ds = de.Cifar10Dataset(dataset_path, num_samples=num_rows, num_shards=distribution_num, shard_id=shard_id,
                           shuffle=shuffle, usage=usage, num_parallel_workers=num_workers)

    # Cast the label column to int32.
    typecast_op = transforms.TypeCast(mstype.int32)
    ds = ds.map(input_columns="label", operations=typecast_op, num_parallel_workers=num_workers)

    # Resize, rescale to [0, 1], normalize, and convert HWC -> CHW.
    image_op_list = [vision.Resize(resize_size),
                     vision.Rescale(1.0 / 255.0, 0.0),
                     vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                     vision.HWC2CHW()]
    ds = ds.map(input_columns="image", operations=image_op_list, num_parallel_workers=num_workers)

    ds = ds.batch(batch_size, drop_remainder=drop_remainder, num_parallel_workers=num_workers)
    ds = ds.repeat(repeat_num)
    return ds
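

# A quick standalone sanity check of the pipeline above (a sketch, assuming the
# CIFAR-10 binaries exist at DATA_DIR): each batch should yield an NCHW float32
# image tensor together with an int32 label vector.
def _example_inspect_batch():
    ds = dataset_cifar(batch_size=4, num_rows=8)
    for image, label in ds.create_tuple_iterator(output_numpy=True):
        assert image.shape == (4, 3, 32, 32)  # NCHW after HWC2CHW
        assert label.dtype == np.int32        # int32 after TypeCast
        break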


def op_network_with_epoch(network, step_num):
    """Run the network for step_num steps; each call consumes one batch via GetNext."""
    iter_num = 0
    network.set_train()
    for _ in range(step_num):
        op_return = network()
        op_return = op_return.asnumpy()
        logger.info("Op_return is : %s", op_return)
        iter_num += 1
        logger.info("Iter Num : %s", iter_num)
    return iter_num


def convert_type(shapes, types):
    """Map numpy dtypes to MindSpore dtypes by round-tripping through a Tensor."""
    ms_types = []
    for np_shape, np_type in zip(shapes, types):
        input_np = np.zeros(np_shape, np_type)
        tensor = Tensor(input_np)
        ms_types.append(tensor.dtype)
    return ms_types
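

# For example (a sketch): convert_type([[2, 3]], [np.float32]) returns
# [mstype.float32].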


def get_dataset_base_value(dataset):
    dataset_size = dataset.get_dataset_size()
    batch_size = dataset.get_batch_size()
    return dataset_size, batch_size


def dataset_send_tdt(dataset):
    time.sleep(1)
    dataset.send(1)


def get_dataset_shapes_and_types(dataset):
    dataset_shapes = dataset.output_shapes()
    np_types = dataset.output_types()
    dataset_types = convert_type(dataset_shapes, np_types)
    return dataset_shapes, dataset_types
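

# For the default arguments of dataset_cifar above, these helpers report
# roughly the following (a sketch; exact values depend on batch_size and
# resize_size):
#   dataset_shapes: [[32, 3, 32, 32], [32]]
#   dataset_types:  [mstype.float32, mstype.int32]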


class SingleOpNetwork(nn.Cell):
    def __init__(self, shapes):
        super(SingleOpNetwork, self).__init__()
        self.shapes = tuple(shapes[0])
        self.Op_Reshape_network = P.Reshape()

    def construct(self, network_input):
        return self.Op_Reshape_network(network_input, self.shapes)


class NetWithTDT(nn.Cell):
    def __init__(self, network, dataset_types, dataset_shapes, shared_name=''):
        super(NetWithTDT, self).__init__()
        self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_shapes), shared_name)
        self.Op_network = network

    def construct(self):
        next_input, _ = self.get_next()
        return self.Op_network(next_input)


def op_network_with_step_num(dataset, step_num):
    dataset_shapes, dataset_types = get_dataset_shapes_and_types(dataset)
    _, batch_size = get_dataset_base_value(dataset)
    dataset = dataset.device_que()
    queue_name = dataset.queue_name
    net = SingleOpNetwork(dataset_shapes)
    net_with_dataset = NetWithTDT(net, dataset_types, dataset_shapes, queue_name)
    # When the device type is Davinci, the net should have a GetNext operation
    # before init_dataset is called.
    _cell_graph_executor.init_dataset(dataset.queue_name, 1, batch_size, dataset_types, dataset_shapes, (), "")
    dataset_send_tdt(dataset)
    return op_network_with_epoch(net_with_dataset, step_num)
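

# A minimal sketch (not part of the original tests) of how the helpers above
# fit together; it assumes the CIFAR-10 binaries exist at DATA_DIR and an
# Ascend device is available, so it is illustrative rather than a test case.
def _example_sink_run():
    context.set_context(mode=context.GRAPH_MODE)
    # 320 rows / batch size 32 = 10 batches, so 10 GetNext steps can succeed.
    ds = dataset_cifar(batch_size=32, repeat_num=1, num_rows=320)
    return op_network_with_step_num(ds, step_num=10)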


@arg_mark(plat_marks=['platform_ascend'], level_mark='level1', card_mark='onecard', essential_mark='essential')
def test_tdt_produce_beyond_consume():
    """
    Feature: Test dataset sink mode.
    Description: Send more batches into TDT than the network consumes, then run 10 steps.
    Expectation: The number of completed iterations equals the requested step number.
    """
    context.set_context(mode=context.GRAPH_MODE)
    batch_size = 64
    repeat_num = 1
    num_rows = 6400
    beyond_step_num = 10
    ds = dataset_cifar(batch_size=batch_size, repeat_num=repeat_num, num_rows=num_rows)
    iter_num = op_network_with_step_num(ds, step_num=beyond_step_num)

    logger.info("out_iter_num:%s", iter_num)
    assert iter_num == 10


@arg_mark(plat_marks=['platform_ascend'], level_mark='level1', card_mark='onecard', essential_mark='essential')
def test_tdt_consume_beyond_produce():
    """
    Feature: Test dataset sink mode.
    Description: The number of source batches is less than the train loop count.
    Expectation: The run fails and raises an exception.
    """
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(op_timeout=30)
    batch_size = 64
    repeat_num = 1
    num_rows = 640
    beyond_step_num = 1000
    ds = dataset_cifar(batch_size=batch_size, repeat_num=repeat_num, num_rows=num_rows)

    try:
        iter_num = op_network_with_step_num(ds, step_num=beyond_step_num)
        logger.info("out_iter_num:%s", iter_num)
        assert False
    except RuntimeError as e:
        logger.info("when dataset batch num is less than train loop, error msg is %s", e)
        assert True