108 Star 868 Fork 1.5K

MindSpore/models

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
preprocess.py 3.38 KB
一键复制 编辑 原始数据 按行查看 历史
党翌洲 提交于 3年前 . 2022.7.27 update
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Preprocessing of I3D model datasets
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil

import numpy as np

import src.data_factory as data_factory
from src.transforms.spatial_transforms import Compose, RandomHorizontalFlip, RandomCrop, CenterCrop
from src.transforms.target_transforms import ClassLabel
from src.transforms.temporal_transforms import TemporalRandomCrop
from src.utils import print_config, write_config, prepare_output_dirs
from config import parse_opts
def _recreate_dir(path):
    """Remove ``path`` if present, then create it again as an empty directory."""
    if os.path.lexists(path):
        print("=====================flag=================")
        # Delete in-process instead of via ``os.system('rm -rf ' + path)``:
        # no shell is spawned, so path components taken from the config
        # (dataset name, mode) cannot be mis-split or interpreted as shell
        # syntax, failures raise instead of being ignored, and it is portable.
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        else:
            # Like ``rm -rf``: a plain file or a symlink is removed itself.
            os.remove(path)
    os.makedirs(path)


def run():
    """Export ``dataset['train']`` as raw ``.bin`` clips plus one ``.npy`` label file.

    Options come from ``parse_opts()``. The resolved config is printed and
    written to ``<config.save_dir>/config.json``. The dataset is built with
    ``data_factory.get_dataset`` using random-crop/flip training transforms
    and center-crop validation transforms.

    Every element of ``dataset['train']`` has its ``clip`` array written with
    ``ndarray.tofile`` to
    ``<script dir>/preprocess_Result/data/<dataset>/<mode>/<dataset>_bs<mode>_<i>.bin``
    (``i`` counts from 1). All ``target`` values are gathered in iteration
    order and saved once with ``np.save`` to
    ``<script dir>/preprocess_Result/label/<dataset>/<mode>/label_bs<mode><batch_size>.npy``.
    Both output directories are deleted and recreated on every call.
    """
    config = parse_opts()
    if config.dataset == 'ucf101':
        config.finetune_num_classes = 101
    if config.distributed:
        config.save_dir = './output_distribute/'
    config = prepare_output_dirs(config)
    print_config(config)
    write_config(config, os.path.join(config.save_dir, 'config.json'))

    train_transforms = {'spatial': Compose([RandomCrop(config.spatial_size), RandomHorizontalFlip()]),
                        'temporal': TemporalRandomCrop(config.train_sample_duration),
                        'target': ClassLabel()}
    # NOTE(review): validation also uses TemporalRandomCrop (not a
    # deterministic crop) -- confirm this is intended.
    validation_transforms = {'spatial': Compose([CenterCrop(config.spatial_size)]),
                             'temporal': TemporalRandomCrop(config.test_sample_duration),
                             'target': ClassLabel()}
    dataset = data_factory.get_dataset(config, train_transforms, validation_transforms)

    # Outputs are placed next to this script, independent of the current
    # working directory.
    dataset_name = str(config.dataset)
    mode = str(config.mode)
    result_root = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'preprocess_Result')
    data_path = os.path.join(result_root, 'data', dataset_name, mode)
    label_path = os.path.join(result_root, 'label', dataset_name, mode)
    print(data_path)
    file_path = os.path.join(label_path, "label_bs" + mode + str(config.batch_size) + ".npy")

    _recreate_dir(data_path)
    _recreate_dir(label_path)

    label_list = []
    # NOTE(review): the 'train' split is exported regardless of config.mode
    # (mode only affects the output paths) -- confirm this is intended.
    for i, data in enumerate(dataset['train'].create_dict_iterator(output_numpy=True), start=1):
        file_name = dataset_name + '_bs' + mode + '_' + str(i) + '.bin'
        print('clip file ', i, 'writing')
        data['clip'].tofile(os.path.join(data_path, file_name))
        print('clip preprocess down.')
        label_list.append(data['target'])
    # Saved once after the loop, so label entry k matches clip file k+1.
    np.save(file_path, label_list)
    print('Finished training.')
if __name__ == "__main__":
    # Script entry point: runs the export only when this file is executed
    # directly, never on import.
    run()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mindspore/models.git
git@gitee.com:mindspore/models.git
mindspore
models
models
master

搜索帮助