import argparse
import os
import random
import time
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from sklearn.metrics import confusion_matrix
from tensorboardX import SummaryWriter

import models
from dataset.imbalance_cifar import SemiSupervisedImbalanceCIFAR10
from dataset.imbalance_svhn import SemiSupervisedImbalanceSVHN
from losses import LDAMLoss, FocalLoss
from utils import *
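
# All lowercase, non-dunder callables exported by `models` are selectable architectures.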
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
parser = argparse.ArgumentParser(description='Semi-supervised imbalanced training on CIFAR-10 / SVHN')
parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'svhn'])
parser.add_argument('--data_path', type=str, default='./data')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet32', choices=model_names,
help='model architecture: ' + ' | '.join(model_names))
parser.add_argument('--loss_type', default="CE", type=str, choices=['CE', 'Focal', 'LDAM'])
parser.add_argument('--imb_type', default="exp", type=str, help='imbalance type')
parser.add_argument('--imb_factor', default=0.01, type=float, help='imbalance factor')
parser.add_argument('--imb_factor_unlabel', default=0.01, type=float, help='imbalance factor for unlabeled data')
parser.add_argument('--train_rule', default='None', type=str,
choices=['None', 'Resample', 'Reweight', 'DRW'])
parser.add_argument('--rand_number', default=0, type=int, help='random seed used for data sampling')
parser.add_argument('--exp_str', default='semi', type=str, help='(additional) name to indicate experiment')
parser.add_argument('--gpu', default=0, type=int, help='GPU id to use')
parser.add_argument('--pretrained_model', type=str, default='')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', dest='lr', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='SGD momentum')
parser.add_argument('--wd', '--weight-decay', default=2e-4, type=float, metavar='W', dest='weight_decay', help='weight decay (default: 2e-4)')
parser.add_argument('-p', '--print-freq', default=10, type=int, metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')
parser.add_argument('--root_log', type=str, default='log')
parser.add_argument('--root_model', type=str, default='./checkpoint')
best_acc1 = 0
def main():
args = parser.parse_args()
args.store_name = '_'.join([args.dataset, args.arch, args.loss_type, args.train_rule, args.imb_type,
str(args.imb_factor), str(args.imb_factor_unlabel), args.exp_str])
prepare_folders(args)
if args.seed is not None:
random.seed(args.seed)
torch.manual_seed(args.seed)
cudnn.deterministic = True
warnings.warn('You have chosen to seed training. '
'This will turn on the CUDNN deterministic setting, which can slow down training considerably! '
'You may see unexpected behavior when restarting from checkpoints.')
if args.gpu is not None:
warnings.warn('You have chosen a specific GPU. This will completely disable data parallelism.')
main_worker(args.gpu, args)
def main_worker(gpu, args):
global best_acc1
args.gpu = gpu
if args.gpu is not None:
print(f"Use GPU: {args.gpu} for training")
print(f"===> Creating model '{args.arch}'")
if args.dataset in {'cifar10', 'svhn'}:
num_classes = 10
else:
raise NotImplementedError
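    # LDAM is designed to be paired with a normalized (cosine-similarity) classifier head.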
    use_norm = args.loss_type == 'LDAM'
model = models.__dict__[args.arch](num_classes=num_classes, use_norm=use_norm)
if args.gpu is not None:
torch.cuda.set_device(args.gpu)
model = model.cuda()
else:
model = torch.nn.DataParallel(model).cuda()
optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum, weight_decay=args.weight_decay)
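    # Channel statistics: standard CIFAR-10 values; SVHN falls back to a plain 0.5 mean/std.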
mean = [0.4914, 0.4822, 0.4465] if args.dataset.startswith('cifar') else [.5, .5, .5]
std = [0.2023, 0.1994, 0.2010] if args.dataset.startswith('cifar') else [.5, .5, .5]
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean, std),
])
transform_val = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean, std),
])
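    # Build the semi-supervised imbalanced training set and the (balanced) test set.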
    if args.dataset == 'cifar10':
        train_dataset = SemiSupervisedImbalanceCIFAR10(
            root=args.data_path,
            imb_type=args.imb_type, imb_factor=args.imb_factor, unlabel_imb_factor=args.imb_factor_unlabel,
            rand_number=args.rand_number, train=True, download=True, transform=transform_train)
        val_dataset = datasets.CIFAR10(root=args.data_path, train=False,
                                       download=True, transform=transform_val)
    elif args.dataset == 'svhn':
        train_dataset = SemiSupervisedImbalanceSVHN(
            root=args.data_path,
            imb_type=args.imb_type, imb_factor=args.imb_factor, unlabel_imb_factor=args.imb_factor_unlabel,
            rand_number=args.rand_number, split='train', download=True, transform=transform_train)
        val_dataset = datasets.SVHN(root=args.data_path, split='test',
                                    download=True, transform=transform_val)
    else:
        raise NotImplementedError(f"Dataset {args.dataset} is not supported!")

    # The loader setup is identical for both datasets, so it is shared here.
    train_sampler = ImbalancedDatasetSampler(train_dataset) if args.train_rule == 'Resample' else None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=False,
                                             num_workers=args.workers, pin_memory=True)
# evaluate only
if args.evaluate:
assert args.resume, 'Specify a trained model using [args.resume]'
        checkpoint = torch.load(args.resume, map_location=torch.device(f'cuda:{args.gpu}'))
model.load_state_dict(checkpoint['state_dict'])
print(f"===> Checkpoint '{args.resume}' loaded, testing...")
validate(val_loader, model, nn.CrossEntropyLoss(), 0, args)
return
if args.resume:
if os.path.isfile(args.resume):
print(f"===> Loading checkpoint '{args.resume}'")
            checkpoint = torch.load(args.resume, map_location=torch.device(f'cuda:{args.gpu}'))
args.start_epoch = checkpoint['epoch']
best_acc1 = checkpoint['best_acc1']
if args.gpu is not None:
best_acc1 = best_acc1.to(args.gpu)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print(f"===> Loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})")
else:
raise ValueError(f"No checkpoint found at '{args.resume}'")
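    # Optionally warm-start from pre-trained weights; classifier layers are dropped so the
    # backbone is reused while the head is freshly initialized.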
if args.pretrained_model:
        checkpoint = torch.load(args.pretrained_model, map_location=torch.device(f'cuda:{args.gpu}'))
if 'moco_ckpt' not in args.pretrained_model:
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
if 'linear' not in k and 'fc' not in k:
new_state_dict[k] = v
model.load_state_dict(new_state_dict, strict=False)
            print(f'===> Pretrained weights found in total: [{len(new_state_dict)}]')
else:
# rename moco pre-trained keys
state_dict = checkpoint['state_dict']
for k in list(state_dict.keys()):
# retain only encoder_q up to before the embedding layer
if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
# remove prefix
state_dict[k[len("module.encoder_q."):]] = state_dict[k]
# delete renamed or unused k
del state_dict[k]
msg = model.load_state_dict(state_dict, strict=False)
if use_norm:
assert set(msg.missing_keys) == {"fc.weight"}
else:
assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
print(f'===> Pre-trained model loaded: {args.pretrained_model}')
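    # Let cuDNN autotune kernels for the fixed 32x32 input size.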
cudnn.benchmark = True
if args.dataset.startswith(('cifar', 'svhn')):
cls_num_list = train_dataset.get_cls_num_list()
print('cls num list:')
print(cls_num_list)
args.cls_num_list = cls_num_list
# init log for training
log_training = open(os.path.join(args.root_log, args.store_name, 'log_train.csv'), 'w')
log_testing = open(os.path.join(args.root_log, args.store_name, 'log_test.csv'), 'w')
with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
f.write(str(args))
tf_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))
for epoch in range(args.start_epoch, args.epochs):
adjust_learning_rate(optimizer, epoch, args)
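        # Class-balanced weights from the effective number of samples, E_n = (1 - beta^n) / (1 - beta)
        # (Cui et al., 2019). 'DRW' defers re-weighting until epoch 160 (Cao et al., 2019).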
if args.train_rule == 'Reweight':
beta = 0.9999
effective_num = 1.0 - np.power(beta, cls_num_list)
per_cls_weights = (1.0 - beta) / np.array(effective_num)
per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
elif args.train_rule == 'DRW':
idx = epoch // 160
betas = [0, 0.9999]
effective_num = 1.0 - np.power(betas[idx], cls_num_list)
per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
else:
per_cls_weights = None
if args.loss_type == 'CE':
criterion = nn.CrossEntropyLoss(weight=per_cls_weights).cuda(args.gpu)
elif args.loss_type == 'LDAM':
criterion = LDAMLoss(cls_num_list=cls_num_list, max_m=0.5, s=30, weight=per_cls_weights).cuda(args.gpu)
elif args.loss_type == 'Focal':
criterion = FocalLoss(weight=per_cls_weights, gamma=1).cuda(args.gpu)
else:
warnings.warn('Loss type is not listed')
return
train(train_loader, model, criterion, optimizer, epoch, args, log_training, tf_writer)
acc1 = validate(val_loader, model, criterion, epoch, args, log_testing, tf_writer)
is_best = acc1 > best_acc1
best_acc1 = max(acc1, best_acc1)
tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
output_best = 'Best Prec@1: %.3f\n' % best_acc1
print(output_best)
log_testing.write(output_best + '\n')
log_testing.flush()
save_checkpoint(args, {
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_acc1': best_acc1,
'optimizer': optimizer.state_dict(),
}, is_best)
def train(train_loader, model, criterion, optimizer, epoch, args, log, tf_writer):
batch_time = AverageMeter('Time', ':6.3f')
data_time = AverageMeter('Data', ':6.3f')
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
model.train()
end = time.time()
for i, (inputs, target) in enumerate(train_loader):
data_time.update(time.time() - end)
inputs = inputs.cuda()
target = target.cuda()
output = model(inputs)
loss = criterion(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(acc1[0], inputs.size(0))
top5.update(acc5[0], inputs.size(0))
optimizer.zero_grad()
loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
            log_msg = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                       'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                       'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                           epoch, i, len(train_loader), batch_time=batch_time,
                           data_time=data_time, loss=losses, top1=top1, top5=top5,
                           lr=optimizer.param_groups[-1]['lr']))  # report the actual current lr
            print(log_msg)
            log.write(log_msg + '\n')
            log.flush()
tf_writer.add_scalar('loss/train', losses.avg, epoch)
tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)
tf_writer.add_scalar('acc/train_top5', top5.avg, epoch)
tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)
def validate(val_loader, model, criterion, epoch, args, log=None, tf_writer=None, flag='val'):
batch_time = AverageMeter('Time', ':6.3f')
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
# switch to evaluate mode
model.eval()
all_preds = []
all_targets = []
with torch.no_grad():
end = time.time()
for i, (inputs, target) in enumerate(val_loader):
inputs = inputs.cuda()
target = target.cuda()
output = model(inputs)
loss = criterion(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(acc1[0], inputs.size(0))
top5.update(acc5[0], inputs.size(0))
batch_time.update(time.time() - end)
end = time.time()
_, pred = torch.max(output, 1)
all_preds.extend(pred.cpu().numpy())
all_targets.extend(target.cpu().numpy())
if i % args.print_freq == 0:
                log_msg = ('Test: [{0}/{1}]\t'
                           'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                           'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                           'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                           'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                               i, len(val_loader), batch_time=batch_time, loss=losses,
                               top1=top1, top5=top5))
                print(log_msg)
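        # Per-class accuracy from the confusion matrix (rows index the ground-truth class).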
cf = confusion_matrix(all_targets, all_preds).astype(float)
cls_cnt = cf.sum(axis=1)
cls_hit = np.diag(cf)
cls_acc = cls_hit / cls_cnt
        log_msg = ('{flag} Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
                   .format(flag=flag, top1=top1, top5=top5, loss=losses))
        out_cls_acc = '%s Class Accuracy: %s' % (
            flag, (np.array2string(cls_acc, separator=',', formatter={'float_kind': lambda x: "%.3f" % x})))
        print(log_msg)
        print(out_cls_acc)
        if log is not None:
            log.write(log_msg + '\n')
            log.write(out_cls_acc + '\n')
            log.flush()
if tf_writer is not None:
tf_writer.add_scalar('loss/test_' + flag, losses.avg, epoch)
tf_writer.add_scalar('acc/test_' + flag + '_top1', top1.avg, epoch)
tf_writer.add_scalar('acc/test_' + flag + '_top5', top5.avg, epoch)
tf_writer.add_scalars('acc/test_' + flag + '_cls_acc', {str(i): x for i, x in enumerate(cls_acc)}, epoch)
return top1.avg
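
# Example invocation (the script name here is illustrative; adjust to the actual file name):
#   python train.py --dataset cifar10 --loss_type LDAM --train_rule DRW \
#       --imb_factor 0.01 --imb_factor_unlabel 0.01 --gpu 0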
if __name__ == '__main__':
main()