2.6K Star 8.5K Fork 4.8K

GVPMindSpore/mindspore

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
dataset.py 3.28 KB
一键复制 编辑 原始数据 按行查看 历史
YueYu 提交于 2020-12-08 16:59 . Fix some non-minddata typos
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations used by train.py and eval.py.
"""
import os
import numpy as np
import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter
from .imdb import ImdbParser
def lstm_create_dataset(data_home, batch_size, repeat_num=1, training=True):
    """Load an IMDB MindRecord file as a shuffled, batched dataset.

    Args:
        data_home: Directory containing the MindRecord files.
        batch_size: Number of rows per batch; an incomplete final batch is
            dropped.
        repeat_num: Repeat count passed to the dataset's ``repeat``.
        training: When truthy, read "aclImdb_train.mindrecord0"; otherwise
            read "aclImdb_test.mindrecord0".

    Returns:
        The dataset after the shuffle, batch and repeat operations, exposing
        the "feature" and "label" columns.
    """
    # The dataset config seed is set to 1 before the pipeline is built.
    ds.config.set_seed(1)

    record_name = "aclImdb_train.mindrecord0" if training else "aclImdb_test.mindrecord0"
    record_path = os.path.join(data_home, record_name)

    source = ds.MindDataset(record_path, columns_list=["feature", "label"], num_parallel_workers=4)
    # The shuffle buffer is sized to the dataset size reported by the source.
    shuffled = source.shuffle(buffer_size=source.get_dataset_size())
    batched = shuffled.batch(batch_size=batch_size, drop_remainder=True)
    return batched.repeat(count=repeat_num)
def _convert_to_mindrecord(data_home, features, labels, weight_np=None, training=True):
    """Write paired features and labels to a MindRecord file in ``data_home``.

    Args:
        data_home: Output directory for the MindRecord file (and for
            'weight.txt' when ``weight_np`` is given).
        features: Iterable of objects providing ``reshape(-1)``, one per
            sample (presumably numpy arrays; verify against caller).
        labels: Iterable of values convertible with ``int``, paired with
            ``features`` by position.
        weight_np: Optional array saved with ``np.savetxt`` as 'weight.txt';
            skipped when None.
        training: When truthy, the output base name is
            "aclImdb_train.mindrecord"; otherwise "aclImdb_test.mindrecord".
    """
    if weight_np is not None:
        np.savetxt(os.path.join(data_home, 'weight.txt'), weight_np)

    # Schema of a single record: a running id, the label, and a flat
    # variable-length feature vector.
    record_schema = {
        "id": {"type": "int32"},
        "label": {"type": "int32"},
        "feature": {"type": "int32", "shape": [-1]},
    }

    file_name = "aclImdb_train.mindrecord" if training else "aclImdb_test.mindrecord"
    writer = FileWriter(os.path.join(data_home, file_name), shard_num=4)

    # One record per (label, feature) pair; the id is the position in the input.
    records = [
        {"id": index, "label": int(label), "feature": feature.reshape(-1)}
        for index, (label, feature) in enumerate(zip(labels, features))
    ]

    writer.add_schema(record_schema, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(records)
    writer.commit()
def convert_to_mindrecord(embed_size, aclimdb_path, preprocess_path, glove_path):
    """Parse the IMDB dataset and write its train and test splits as MindRecord.

    Args:
        embed_size: Passed to ``ImdbParser`` as its third argument.
        aclimdb_path: Passed to ``ImdbParser`` as its first argument.
        preprocess_path: Output directory; created when it does not exist.
        glove_path: Passed to ``ImdbParser`` as its second argument.
    """
    imdb = ImdbParser(aclimdb_path, glove_path, embed_size)
    imdb.parse()

    output_missing = not os.path.exists(preprocess_path)
    if output_missing:
        print(f"preprocess path {preprocess_path} is not exist")
        os.makedirs(preprocess_path)

    # Train split: the third value returned by the parser is forwarded as
    # weight_np, so it is saved alongside the records.
    train_x, train_y, train_weights = imdb.get_datas('train')
    _convert_to_mindrecord(preprocess_path, train_x, train_y, train_weights)

    # Test split: the third returned value is discarded and no weights are saved.
    test_x, test_y, _ = imdb.get_datas('test')
    _convert_to_mindrecord(preprocess_path, test_x, test_y, training=False)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/mindspore/mindspore.git
git@gitee.com:mindspore/mindspore.git
mindspore
mindspore
mindspore
r1.1

搜索帮助