48 Star 506 Fork 1.3K

Ascend/ModelZoo-PyTorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
CGAN_preprocess.py 2.60 KB
一键复制 编辑 原始数据 按行查看 历史
王姜奔 提交于 3年前 . init
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import numpy as np
import os
import argparse
def parse_args(argv=None):
    """Parse the command-line options of the CGAN preprocessing script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            exactly what the original zero-argument call did. Passing an
            explicit list allows programmatic use.

    Returns:
        argparse.Namespace: holds ``input_dim``, ``output_dim``,
        ``input_size``, ``class_num``, ``pth_path``, ``onnx_path`` and
        ``save_path``.
    """
    desc = "Pytorch implementation of CGAN collections"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--input_dim', type=int, default=62, help="The input_dim")
    parser.add_argument('--output_dim', type=int, default=3, help="The output_dim")
    parser.add_argument('--input_size', type=int, default=28, help="The image size of MNIST")
    parser.add_argument('--class_num', type=int, default=10, help="The num of classes of MNIST")
    parser.add_argument('--pth_path', type=str, default='CGAN_G.pth', help='pth model path')
    parser.add_argument('--onnx_path', type=str, default="CGAN.onnx", help='onnx model path')
    parser.add_argument('--save_path', type=str, default="data", help='processed data path')
    return parser.parse_args(argv)
# fixed noise & condition
def prep_preocess(args):
    """Build the fixed generator input and dump it as a raw float32 file.

    The result has ``class_num ** 2`` rows and ``input_dim + class_num``
    columns. Rows are organised in ``class_num`` consecutive groups of
    ``class_num`` rows each:

    * every row of a group shares one noise vector of length ``input_dim``
      drawn with ``torch.rand`` (uniform in ``[0, 1)``);
    * within a group the one-hot condition runs through the labels
      ``0 .. class_num - 1`` in order.

    The matrix is written to ``<save_path>/input.bin`` with
    ``numpy.ndarray.tofile`` (raw float32 values, no header).

    Args:
        args: Namespace-like object providing ``class_num``, ``input_dim``
            and ``save_path``.
    """
    class_num = args.class_num
    z_dim = args.input_dim
    sample_num = class_num ** 2

    # One random draw per group, broadcast over the group's rows. Keeping a
    # separate torch.rand(1, z_dim) call per group preserves the original
    # random-number call sequence.
    sample_z_ = torch.zeros((sample_num, z_dim))
    for group in range(class_num):
        start = group * class_num
        sample_z_[start:start + class_num] = torch.rand(1, z_dim)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.save_path, exist_ok=True)

    # Row r must be one-hot at column (r mod class_num): that is the
    # identity matrix stacked class_num times.
    sample_y_ = torch.eye(class_num).repeat(class_num, 1)

    generator_input = torch.cat([sample_z_, sample_y_], 1)
    generator_input = generator_input.numpy().astype(np.float32)
    generator_input.tofile(os.path.join(args.save_path, 'input.bin'))
if __name__ == "__main__":
    # Script entry point: read the CLI options, generate the input file,
    # then report where it was written.
    cli_args = parse_args()
    prep_preocess(cli_args)
    print("data preprocessed stored in", cli_args.save_path)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/ascend/ModelZoo-PyTorch.git
git@gitee.com:ascend/ModelZoo-PyTorch.git
ascend
ModelZoo-PyTorch
ModelZoo-PyTorch
master

搜索帮助