代码拉取完成,页面将自动刷新
# -*- coding: utf-8 -*-
from __future__ import print_function, division
import argparse
import torch
from torchvision import datasets, transforms
import time
import os
version = torch.__version__  # PyTorch version string, kept for reference

######################################################################
# Command-line options
# --------------------
# NOTE(review): --color_jitter is parsed, but no transform in this script
# reads opt.color_jitter.
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--data_dir', type=str, default='/home/zzd/Market/pytorch',
                    help='training dir path')
parser.add_argument('--train_all', action='store_true',
                    help='use all training data')
parser.add_argument('--color_jitter', action='store_true',
                    help='use color jitter in training')
parser.add_argument('--batchsize', type=int, default=128,
                    help='batchsize')
opt = parser.parse_args()
data_dir = opt.data_dir
######################################################################
# Load Data
# ---------
#
# Training-split pipeline: bicubic resize (interpolation=3 is PIL's
# Image.BICUBIC), then ToTensor only. There is no augmentation and no
# Normalize step, so statistics gathered from this split reflect raw
# [0, 1] pixel values.
transform_train_list = [
    transforms.Resize((288, 144), interpolation=3),
    transforms.ToTensor(),
]
# Validation pipeline: bicubic resize, then normalize with the standard
# ImageNet mean/std values.
transform_val_list = [
    transforms.Resize(size=(256, 128), interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
print(transform_train_list)

data_transforms = dict(
    train=transforms.Compose(transform_train_list),
    val=transforms.Compose(transform_val_list),
)

# --train_all switches the training folder from "train" to "train_all".
train_all = '_all' if opt.train_all else ''

image_datasets = {
    'train': datasets.ImageFolder(os.path.join(data_dir, 'train' + train_all),
                                  data_transforms['train']),
    'val': datasets.ImageFolder(os.path.join(data_dir, 'val'),
                                data_transforms['val']),
}

# One shuffled loader per split, all sharing the --batchsize option.
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=opt.batchsize,
        shuffle=True,
        num_workers=16,
    )
    for split in ('train', 'val')
}
dataset_sizes = dict(
    (split, len(image_datasets[split])) for split in ('train', 'val'))
class_names = image_datasets['train'].classes
use_gpu = torch.cuda.is_available()
######################################################################
# Dataset statistics
# ------------------
#
# ``prepare_model`` makes one pass over the training split and prints
# per-channel mean and standard-deviation statistics of the images. These
# values are meant to be plugged into ``transforms.Normalize``.
#
# The training transforms above deliberately omit ``Normalize``, so the
# statistics are measured on raw ``ToTensor`` output.
def prepare_model(phase='train', loader=None):
    """Compute and print the per-channel pixel mean and std of a dataset split.

    Every pixel of every image is pooled per channel. The results are
    therefore the dataset-level statistics that ``transforms.Normalize``
    expects. Averaging per-image stds, a common shortcut, ignores the
    variance *between* images and underestimates this value.

    Args:
        phase: key into the module-level ``dataloaders`` dict. It is only
            consulted when ``loader`` is None.
        loader: optional iterable of ``(inputs, labels)`` batches, where
            ``inputs`` is a 4-D ``(batch, channels, height, width)`` tensor.
            Defaults to ``dataloaders[phase]``.

    Returns:
        ``(mean, std)`` as float32 tensors of shape ``(channels,)``.

    Raises:
        ValueError: if the loader yields no pixels.
    """
    since = time.time()
    if loader is None:
        loader = dataloaders[phase]

    # Accumulate in float64: summing (squared) values of millions of pixels
    # in float32 loses noticeable precision.
    channel_sum = 0.0
    channel_sq_sum = 0.0
    n_pixels = 0
    for inputs, _ in loader:
        batch = inputs.double()
        now_batch_size, _, h, w = batch.shape
        # Reduce over batch, height and width, keeping one value per channel.
        channel_sum = channel_sum + batch.sum(dim=(0, 2, 3))
        channel_sq_sum = channel_sq_sum + (batch * batch).sum(dim=(0, 2, 3))
        n_pixels += now_batch_size * h * w

    if n_pixels == 0:
        raise ValueError('no pixels yielded for phase %r' % (phase,))

    mean = channel_sum / n_pixels
    # Var[x] = E[x^2] - E[x]^2 (population variance). The clamp guards
    # against tiny negative values caused by floating-point rounding.
    std = (channel_sq_sum / n_pixels - mean * mean).clamp(min=0).sqrt()
    mean, std = mean.float(), std.float()
    print(mean)
    print(std)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    return mean, std
# Script entry point: the statistics pass runs whenever this module is
# executed (there is no __main__ guard).
prepare_model()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。