代码拉取完成,页面将自动刷新
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Preprocess an I3D model dataset: write each clip to a raw ``.bin`` file and
save all clip labels to a single ``.npy`` file.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil

import numpy as np

import src.data_factory as data_factory
from src.transforms.spatial_transforms import Compose, RandomHorizontalFlip, RandomCrop, CenterCrop
from src.transforms.target_transforms import ClassLabel
from src.transforms.temporal_transforms import TemporalRandomCrop
from src.utils import print_config, write_config, prepare_output_dirs
from config import parse_opts
def _reset_dir(path):
    """Delete ``path`` if it exists, then recreate it as an empty directory.

    Uses the standard library instead of shelling out to ``rm -rf``: the path
    is built from command-line config values, so a shell string could be
    mis-split on spaces or abused for command injection, and a failed removal
    would previously have been silently ignored.
    """
    if os.path.exists(path):
        print("=====================flag=================")
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        else:
            # `rm -rf` also removed plain files / symlinks; keep that behavior.
            os.remove(path)
    os.makedirs(path)


def run():
    """Preprocess the configured dataset and dump it to disk.

    Parses the command-line config, builds the spatial/temporal/target
    transforms, obtains the dataset from ``data_factory.get_dataset`` and
    iterates over ``dataset['train']``, writing:

    * each clip (``data['clip']``) as a raw ``.bin`` file under
      ``<script dir>/preprocess_Result/data/<dataset>/<mode>/``;
    * all targets (``data['target']``), collected in iteration order, as one
      ``.npy`` file under
      ``<script dir>/preprocess_Result/label/<dataset>/<mode>/``.

    Both output directories are wiped and recreated on every run.
    """
    config = parse_opts()
    if config.dataset == 'ucf101':
        config.finetune_num_classes = 101
    if config.distributed:
        config.save_dir = './output_distribute/'
    config = prepare_output_dirs(config)
    print_config(config)
    write_config(config, os.path.join(config.save_dir, 'config.json'))

    train_transforms = {
        'spatial': Compose([RandomCrop(config.spatial_size), RandomHorizontalFlip()]),
        'temporal': TemporalRandomCrop(config.train_sample_duration),
        'target': ClassLabel(),
    }
    # NOTE(review): validation also uses TemporalRandomCrop; confirm that a
    # deterministic temporal crop is not intended for evaluation data.
    validation_transforms = {
        'spatial': Compose([CenterCrop(config.spatial_size)]),
        'temporal': TemporalRandomCrop(config.test_sample_duration),
        'target': ClassLabel(),
    }
    dataset = data_factory.get_dataset(config, train_transforms, validation_transforms)

    base_dir = os.path.abspath(os.path.dirname(__file__))
    data_path = os.path.join(base_dir, "preprocess_Result", "data",
                             str(config.dataset), str(config.mode))
    label_path = os.path.join(base_dir, "preprocess_Result", "label",
                              str(config.dataset), str(config.mode))
    print(data_path)
    file_path = os.path.join(
        label_path, "label_bs" + str(config.mode) + str(config.batch_size) + ".npy")

    _reset_dir(data_path)
    _reset_dir(label_path)

    label_list = []
    # Clip files are numbered from 1 in iteration order.
    # NOTE(review): '_bs' is followed by the mode, not the batch size; names
    # are kept as-is because downstream consumers may rely on them.
    for i, data in enumerate(dataset['train'].create_dict_iterator(output_numpy=True), start=1):
        clip = data['clip']
        file_name = str(config.dataset) + '_bs' + str(config.mode) + '_' + str(i) + '.bin'
        print('clip file ', i, 'writing')
        clip.tofile(os.path.join(data_path, file_name))
        print('clip preprocess down.')
        label_list.append(data['target'])
    np.save(file_path, label_list)
    # NOTE(review): message says "training" although this script only preprocesses.
    print('Finished training.')
if __name__ == "__main__":
    # Script entry point: run the full preprocessing pipeline.
    run()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。