From 07f110d2ab2c90c050e78f00dc0eddcb6bf9515e Mon Sep 17 00:00:00 2001 From: may Date: Fri, 21 Apr 2023 11:18:22 +0800 Subject: [PATCH 1/5] Add CBAM --- cv/classification/cbam/pytorch/LICENSE | 21 ++ cv/classification/cbam/pytorch/MODELS/bam.py | 49 +++ cv/classification/cbam/pytorch/MODELS/cbam.py | 95 +++++ .../cbam/pytorch/MODELS/model_resnet.py | 205 +++++++++++ cv/classification/cbam/pytorch/README.md | 35 ++ .../scripts/train_imagenet_resnet50_bam.sh | 9 + .../scripts/train_imagenet_resnet50_cbam.sh | 9 + .../cbam/pytorch/train_imagenet.py | 328 ++++++++++++++++++ 8 files changed, 751 insertions(+) create mode 100644 cv/classification/cbam/pytorch/LICENSE create mode 100644 cv/classification/cbam/pytorch/MODELS/bam.py create mode 100644 cv/classification/cbam/pytorch/MODELS/cbam.py create mode 100644 cv/classification/cbam/pytorch/MODELS/model_resnet.py create mode 100644 cv/classification/cbam/pytorch/README.md create mode 100755 cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_bam.sh create mode 100755 cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_cbam.sh create mode 100644 cv/classification/cbam/pytorch/train_imagenet.py diff --git a/cv/classification/cbam/pytorch/LICENSE b/cv/classification/cbam/pytorch/LICENSE new file mode 100644 index 000000000..f91eab172 --- /dev/null +++ b/cv/classification/cbam/pytorch/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Jongchan Park + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/cv/classification/cbam/pytorch/MODELS/bam.py b/cv/classification/cbam/pytorch/MODELS/bam.py
new file mode 100644
index 000000000..cbda3d060
--- /dev/null
+++ b/cv/classification/cbam/pytorch/MODELS/bam.py
@@ -0,0 +1,48 @@
+import torch
+import math
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Flatten(nn.Module):
+    def forward(self, x):
+        return x.view(x.size(0), -1)
+class ChannelGate(nn.Module):
+    def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):
+        super(ChannelGate, self).__init__()
+        self.gate_c = nn.Sequential()
+        self.gate_c.add_module( 'flatten', Flatten() )
+        gate_channels = [gate_channel]
+        gate_channels += [gate_channel // reduction_ratio] * num_layers
+        gate_channels += [gate_channel]
+        for i in range( len(gate_channels) - 2 ):
+            self.gate_c.add_module( 'gate_c_fc_%d'%i, nn.Linear(gate_channels[i], gate_channels[i+1]) )
+            self.gate_c.add_module( 'gate_c_bn_%d'%(i+1), nn.BatchNorm1d(gate_channels[i+1]) )
+            self.gate_c.add_module( 'gate_c_relu_%d'%(i+1), nn.ReLU() )
+        self.gate_c.add_module( 'gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]) )
+    def forward(self, in_tensor):
+        avg_pool = F.avg_pool2d( in_tensor, in_tensor.size(2), stride=in_tensor.size(2) )
+        return self.gate_c( avg_pool ).unsqueeze(2).unsqueeze(3).expand_as(in_tensor)
+
+class SpatialGate(nn.Module):
+    def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4):
+        super(SpatialGate, self).__init__()
+        self.gate_s = nn.Sequential()
+        self.gate_s.add_module( 'gate_s_conv_reduce0', nn.Conv2d(gate_channel, gate_channel//reduction_ratio, kernel_size=1))
+        self.gate_s.add_module( 'gate_s_bn_reduce0', nn.BatchNorm2d(gate_channel//reduction_ratio) )
+        self.gate_s.add_module( 'gate_s_relu_reduce0',nn.ReLU() )
+        for i in range( dilation_conv_num ):
+            self.gate_s.add_module( 'gate_s_conv_di_%d'%i, nn.Conv2d(gate_channel//reduction_ratio, gate_channel//reduction_ratio, kernel_size=3, \
+                padding=dilation_val, dilation=dilation_val) )
+            self.gate_s.add_module( 'gate_s_bn_di_%d'%i, nn.BatchNorm2d(gate_channel//reduction_ratio) )
+            self.gate_s.add_module( 'gate_s_relu_di_%d'%i, nn.ReLU() )
+        self.gate_s.add_module( 'gate_s_conv_final', nn.Conv2d(gate_channel//reduction_ratio, 1, kernel_size=1) )
+    def forward(self, in_tensor):
+        return self.gate_s( in_tensor ).expand_as(in_tensor)
+class BAM(nn.Module):
+    def __init__(self, gate_channel):
+        super(BAM, self).__init__()
+        self.channel_att = ChannelGate(gate_channel)
+        self.spatial_att = SpatialGate(gate_channel)
+    def forward(self,in_tensor):
+        att = 1 + F.sigmoid( self.channel_att(in_tensor) * self.spatial_att(in_tensor) )
+        return att * in_tensor
diff --git a/cv/classification/cbam/pytorch/MODELS/cbam.py b/cv/classification/cbam/pytorch/MODELS/cbam.py
new file mode 100644
index 000000000..3124c04b9
--- /dev/null
+++ b/cv/classification/cbam/pytorch/MODELS/cbam.py
@@ -0,0 +1,95 @@
+import torch
+import math
+import torch.nn as nn
+import torch.nn.functional as F
+
+class BasicConv(nn.Module):
+    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
+        super(BasicConv, self).__init__()
+        self.out_channels = out_planes
+        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
+        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
+        self.relu = 
nn.ReLU() if relu else None + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + +class ChannelGate(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + nn.ReLU(), + nn.Linear(gate_channels // reduction_ratio, gate_channels) + ) + self.pool_types = pool_types + def forward(self, x): + channel_att_sum = None + for pool_type in self.pool_types: + if pool_type=='avg': + avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( avg_pool ) + elif pool_type=='max': + max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( max_pool ) + elif pool_type=='lp': + lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( lp_pool ) + elif pool_type=='lse': + # LSE pool only + lse_pool = logsumexp_2d(x) + channel_att_raw = self.mlp( lse_pool ) + + if channel_att_sum is None: + channel_att_sum = channel_att_raw + else: + channel_att_sum = channel_att_sum + channel_att_raw + + scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) + return x * scale + +def logsumexp_2d(tensor): + tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) + s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) + outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() + return outputs + +class ChannelPool(nn.Module): + def forward(self, x): + return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) + +class SpatialGate(nn.Module): + def __init__(self): + super(SpatialGate, self).__init__() + kernel_size = 7 + self.compress = ChannelPool() + self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) + def forward(self, x): + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + scale = F.sigmoid(x_out) # broadcasting + return x * scale + +class CBAM(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): + super(CBAM, self).__init__() + self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) + self.no_spatial=no_spatial + if not no_spatial: + self.SpatialGate = SpatialGate() + def forward(self, x): + x_out = self.ChannelGate(x) + if not self.no_spatial: + x_out = self.SpatialGate(x_out) + return x_out diff --git a/cv/classification/cbam/pytorch/MODELS/model_resnet.py b/cv/classification/cbam/pytorch/MODELS/model_resnet.py new file mode 100644 index 000000000..650664031 --- /dev/null +++ b/cv/classification/cbam/pytorch/MODELS/model_resnet.py @@ -0,0 +1,205 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from torch.nn import init +from .cbam import * +from .bam import * + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False): + super(BasicBlock, self).__init__() + self.conv1 = 
conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + if use_cbam: + self.cbam = CBAM( planes, 16 ) + else: + self.cbam = None + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + if not self.cbam is None: + out = self.cbam(out) + + out += residual + out = self.relu(out) + + return out + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + if use_cbam: + self.cbam = CBAM( planes * 4, 16 ) + else: + self.cbam = None + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + if not self.cbam is None: + out = self.cbam(out) + + out += residual + out = self.relu(out) + + return out + +class ResNet(nn.Module): + def __init__(self, block, layers, network_type, num_classes, att_type=None): + self.inplanes = 64 + super(ResNet, self).__init__() + self.network_type = network_type + # different model config between ImageNet and CIFAR + if network_type == "ImageNet": + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.avgpool = nn.AvgPool2d(7) + else: + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + + if att_type=='BAM': + self.bam1 = BAM(64*block.expansion) + self.bam2 = BAM(128*block.expansion) + self.bam3 = BAM(256*block.expansion) + else: + self.bam1, self.bam2, self.bam3 = None, None, None + + self.layer1 = self._make_layer(block, 64, layers[0], att_type=att_type) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, att_type=att_type) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, att_type=att_type) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, att_type=att_type) + + self.fc = nn.Linear(512 * block.expansion, num_classes) + + init.kaiming_normal(self.fc.weight) + for key in self.state_dict(): + if key.split('.')[-1]=="weight": + if "conv" in key: + init.kaiming_normal(self.state_dict()[key], mode='fan_out') + if "bn" in key: + if "SpatialGate" in key: + self.state_dict()[key][...] = 0 + else: + self.state_dict()[key][...] = 1 + elif key.split(".")[-1]=='bias': + self.state_dict()[key][...] 
= 0
+
+    def _make_layer(self, block, planes, blocks, stride=1, att_type=None):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, use_cbam=att_type=='CBAM'))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes, use_cbam=att_type=='CBAM'))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        if self.network_type == "ImageNet":
+            x = self.maxpool(x)
+
+        x = self.layer1(x)
+        if self.bam1 is not None:
+            x = self.bam1(x)
+
+        x = self.layer2(x)
+        if self.bam2 is not None:
+            x = self.bam2(x)
+
+        x = self.layer3(x)
+        if self.bam3 is not None:
+            x = self.bam3(x)
+
+        x = self.layer4(x)
+
+        if self.network_type == "ImageNet":
+            x = self.avgpool(x)
+        else:
+            x = F.avg_pool2d(x, 4)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+        return x
+
+def ResidualNet(network_type, depth, num_classes, att_type):
+
+    assert network_type in ["ImageNet", "CIFAR10", "CIFAR100"], "network type should be ImageNet or CIFAR10 / CIFAR100"
+    assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101'
+
+    if depth == 18:
+        model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, att_type)
+
+    elif depth == 34:
+        model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, att_type)
+
+    elif depth == 50:
+        model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, att_type)
+
+    elif depth == 101:
+        model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, att_type)
+
+    return model
diff --git a/cv/classification/cbam/pytorch/README.md b/cv/classification/cbam/pytorch/README.md
new file mode 100644
index 000000000..c85c4e126
--- /dev/null
+++ b/cv/classification/cbam/pytorch/README.md
@@ -0,0 +1,35 @@
+# CBAM
+
+## Model description
+Official PyTorch code for "[CBAM: Convolutional Block Attention Module (ECCV2018)](http://openaccess.thecvf.com/content_ECCV_2018/html/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.html)"
+
+
+## Step 1: Installing
+
+```bash
+pip3 install torch
+pip3 install torchvision
+```
+
+## Step 2: Training
+
+ResNet50-based examples are included; example scripts live in the ```./scripts/``` directory.
+ImageNet data should be placed under ```./data/ImageNet/```, with subfolders named ```train``` and ```val``` laid out as shown below.
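+
+Both splits follow the standard torchvision ```ImageFolder``` layout, with one subdirectory per class (WordNet synset IDs for ImageNet). A sketch of the expected tree (file names are illustrative):
+
+```
+./data/ImageNet/
+    train/
+        n01440764/
+            xxx.JPEG
+            ...
+        ...
+    val/
+        n01440764/
+            yyy.JPEG
+            ...
+        ...
+```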
+
+```bash
+# To train with CBAM (ResNet50 backbone)
+# For 8 GPUs
+python3 train_imagenet.py --ngpu 8 --workers 20 --arch resnet --depth 50 --epochs 100 --batch-size 256 --lr 0.1 --att-type CBAM --prefix RESNET50_IMAGENET_CBAM ./data/ImageNet
+# For 1 GPU
+python3 train_imagenet.py --ngpu 1 --workers 20 --arch resnet --depth 50 --epochs 100 --batch-size 64 --lr 0.1 --att-type CBAM --prefix RESNET50_IMAGENET_CBAM ./data/ImageNet
+```
+
+## Result
+
+| GPU | FP32 |
+| ----------- | ------------------------------------ |
+| 8 cards | Prec@1 76.216 |
+
+## Reference
+
+- [MXNet implementation of CBAM with several modifications](https://github.com/bruinxiong/Modified-CBAMnet.mxnet) by [bruinxiong](https://github.com/bruinxiong)
diff --git a/cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_bam.sh b/cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_bam.sh
new file mode 100755
index 000000000..b5b5b6669
--- /dev/null
+++ b/cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_bam.sh
@@ -0,0 +1,9 @@
+python3 train_imagenet.py \
+    --ngpu 8 \
+    --workers 20 \
+    --arch resnet --depth 50 \
+    --epochs 100 \
+    --batch-size 256 --lr 0.1 \
+    --att-type BAM \
+    --prefix RESNET50_IMAGENET_BAM \
+    ./data/ImageNet/
diff --git a/cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_cbam.sh b/cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_cbam.sh
new file mode 100755
index 000000000..40d2c2d1a
--- /dev/null
+++ b/cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_cbam.sh
@@ -0,0 +1,9 @@
+python3 train_imagenet.py \
+    --ngpu 8 \
+    --workers 20 \
+    --arch resnet --depth 50 \
+    --epochs 100 \
+    --batch-size 256 --lr 0.1 \
+    --att-type CBAM \
+    --prefix RESNET50_IMAGENET_CBAM \
+    ./data/ImageNet/
diff --git a/cv/classification/cbam/pytorch/train_imagenet.py b/cv/classification/cbam/pytorch/train_imagenet.py
new file mode 100644
index 000000000..46fdcd056
--- /dev/null
+++ b/cv/classification/cbam/pytorch/train_imagenet.py
@@ -0,0 +1,328 @@
+# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import argparse
+import os
+import shutil
+import time
+import random
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+import torchvision.models as models
+from MODELS.model_resnet import *
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+model_names = sorted(name for name in models.__dict__
+    if name.islower() and not name.startswith("__")
+    and callable(models.__dict__[name]))
+
+parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
+parser.add_argument('data', metavar='DIR',
+                    help='path to dataset')
+parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet',
+                    help='model architecture: ' +
+                        ' | '.join(model_names) +
+                        ' (default: resnet)')
+parser.add_argument('--depth', default=50, type=int, metavar='D',
+                    help='model depth')
+parser.add_argument('--ngpu', default=4, type=int, metavar='G',
+                    help='number of gpus to use')
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                    help='number of data loading workers (default: 4)')
+parser.add_argument('--epochs', default=90, type=int, metavar='N',
+                    help='number of total epochs to run')
+parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                    help='manual epoch number (useful on restarts)')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+                    metavar='N', help='mini-batch size (default: 256)')
+parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                    metavar='LR', help='initial learning rate')
+parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                    help='momentum')
+parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
+                    metavar='W', help='weight decay (default: 1e-4)')
+parser.add_argument('--print-freq', '-p', default=10, type=int,
+                    metavar='N', help='print frequency (default: 10)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+parser.add_argument("--seed", type=int, default=1234, metavar='S', help='random seed (default: 1234)')
+parser.add_argument("--prefix", type=str, required=True, metavar='PFX', help='prefix for logging & checkpoint saving')
+parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluation only')
+parser.add_argument('--att-type', type=str, choices=['BAM', 'CBAM'], default=None)
+best_prec1 = 0
+
+if not os.path.exists('./checkpoints'):
+    os.mkdir('./checkpoints')
+
+def main():
+    global args, best_prec1
+    global viz, train_lot, test_lot
+    args = parser.parse_args()
+    print ("args", args)
+
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+    random.seed(args.seed)
+
+    # create model
+    if args.arch == "resnet":
+        model = ResidualNet( 'ImageNet', args.depth, 1000, args.att_type )
+
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss().cuda()
+
+    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                momentum=args.momentum,
+                                weight_decay=args.weight_decay)
+    model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
+    #model = torch.nn.DataParallel(model).cuda()
+    model = model.cuda()
+    print ("model")
+    print (model)
+
+    # get the number of model parameters
+    print('Number of model parameters: {}'.format(
+        sum([p.data.nelement() for p in model.parameters()])))
+
+    # optionally resume 
from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume) + args.start_epoch = checkpoint['epoch'] + best_prec1 = checkpoint['best_prec1'] + model.load_state_dict(checkpoint['state_dict']) + if 'optimizer' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + + cudnn.benchmark = True + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + # import pdb + # pdb.set_trace() + val_loader = torch.utils.data.DataLoader( + datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])), + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + if args.evaluate: + validate(val_loader, model, criterion, 0) + return + + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(256), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler) + + for epoch in range(args.start_epoch, args.epochs): + adjust_learning_rate(optimizer, epoch) + + # train for one epoch + train(args, train_loader, model, criterion, optimizer, epoch) + + # evaluate on validation set + prec1 = validate(val_loader, model, criterion, epoch) + + # remember best prec@1 and save checkpoint + is_best = prec1 > best_prec1 + best_prec1 = max(prec1, best_prec1) + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': args.arch, + 'state_dict': model.state_dict(), + 'best_prec1': best_prec1, + 'optimizer' : optimizer.state_dict(), + }, is_best, args.prefix) + + +def train(args, train_loader, model, criterion, optimizer, epoch): + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + # switch to train mode + model.train() + all_fps = [] + + end = time.time() + for i, (input, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + target = target.cuda() + input_var = torch.autograd.Variable(input) + target_var = torch.autograd.Variable(target) + + # compute output + output = model(input_var) + loss = criterion(output, target_var) + + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(prec1[0], input.size(0)) + top5.update(prec5[0], input.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + fps = input.size(0) * args.ngpu / (time.time() - end) + all_fps.append(fps) + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Epoch: [{0}][{1}/{2}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 
{top5.val:.3f} ({top5.avg:.3f})'.format( + epoch, i, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses, top1=top1, top5=top5)) + print(f"EPOCH {epoch} Avg img/s: {sum(all_fps) / len(all_fps)}") + +def validate(val_loader, model, criterion, epoch): + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + # switch to evaluate mode + model.eval() + + end = time.time() + for i, (input, target) in enumerate(val_loader): + target = target.cuda() + with torch.no_grad(): + input_var = torch.autograd.Variable(input) + target_var = torch.autograd.Variable(target) + + # compute output + output = model(input_var) + loss = criterion(output, target_var) + + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(prec1[0], input.size(0)) + top5.update(prec5[0], input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(val_loader), batch_time=batch_time, loss=losses, + top1=top1, top5=top5)) + + print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + + return top1.avg + + +def save_checkpoint(state, is_best, prefix): + filename='./checkpoints/%s_checkpoint.pth.tar'%prefix + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, './checkpoints/%s_model_best.pth.tar'%prefix) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def adjust_learning_rate(optimizer, epoch): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + lr = args.lr * (0.1 ** (epoch // 30)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + main() -- Gitee From e2fcc8b0a9348abcf9ab5224b7fd0c41ead6341d Mon Sep 17 00:00:00 2001 From: wwwmayyyyyyyy Date: Mon, 24 Apr 2023 10:37:50 +0800 Subject: [PATCH 2/5] Add cspdarknet53 --- cv/classification/cspdarknet53 | 1 + 1 file changed, 1 insertion(+) create mode 160000 cv/classification/cspdarknet53 diff --git a/cv/classification/cspdarknet53 b/cv/classification/cspdarknet53 new file mode 160000 index 000000000..7ec2b2d0b --- /dev/null +++ b/cv/classification/cspdarknet53 @@ -0,0 +1 @@ +Subproject commit 7ec2b2d0b089d5f6b2eb9087c230863e79474339 -- Gitee From f68c3b6d5cf4847e00555ba9c2387554727c5694 Mon Sep 17 00:00:00 2001 From: wwwmayyyyyyyy Date: Mon, 24 Apr 2023 10:55:57 +0800 Subject: [PATCH 3/5] Update --- cv/classification/cspdarknet53 | 1 - 1 file changed, 1 deletion(-) delete 
mode 160000 cv/classification/cspdarknet53 diff --git a/cv/classification/cspdarknet53 b/cv/classification/cspdarknet53 deleted file mode 160000 index 7ec2b2d0b..000000000 --- a/cv/classification/cspdarknet53 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 7ec2b2d0b089d5f6b2eb9087c230863e79474339 -- Gitee From 96574c17777d01ecf40f20f152622018eeee4e6a Mon Sep 17 00:00:00 2001 From: wwwmayyyyyyyy Date: Mon, 24 Apr 2023 10:57:25 +0800 Subject: [PATCH 4/5] Add cspdarknet53 --- .../cspdarknet53/pytorch/README.md | 38 +++ .../cspdarknet53/pytorch/__init__.py | 20 ++ .../pytorch/common_utils/__init__.py | 38 +++ .../cspdarknet53/pytorch/common_utils/dist.py | 156 +++++++++ .../pytorch/common_utils/metric_logger.py | 106 ++++++ .../cspdarknet53/pytorch/common_utils/misc.py | 26 ++ .../pytorch/common_utils/smooth_value.py | 88 +++++ .../pytorch/dataloader/__init__.py | 14 + .../pytorch/dataloader/classification.py | 120 +++++++ .../pytorch/dataloader/dali_classification.py | 129 +++++++ .../pytorch/dataloader/utils/__init__.py | 14 + .../utils/presets_classification.py | 54 +++ .../cspdarknet53/pytorch/model/csdarknet53.py | 67 ++++ .../cspdarknet53/pytorch/model/cslayers.py | 176 ++++++++++ .../cspdarknet53/pytorch/train.py | 318 ++++++++++++++++++ .../cspdarknet53/pytorch/utils_.py | 169 ++++++++++ 16 files changed, 1533 insertions(+) create mode 100644 cv/classification/cspdarknet53/pytorch/README.md create mode 100644 cv/classification/cspdarknet53/pytorch/__init__.py create mode 100644 cv/classification/cspdarknet53/pytorch/common_utils/__init__.py create mode 100644 cv/classification/cspdarknet53/pytorch/common_utils/dist.py create mode 100644 cv/classification/cspdarknet53/pytorch/common_utils/metric_logger.py create mode 100644 cv/classification/cspdarknet53/pytorch/common_utils/misc.py create mode 100644 cv/classification/cspdarknet53/pytorch/common_utils/smooth_value.py create mode 100644 cv/classification/cspdarknet53/pytorch/dataloader/__init__.py create mode 100644 cv/classification/cspdarknet53/pytorch/dataloader/classification.py create mode 100644 cv/classification/cspdarknet53/pytorch/dataloader/dali_classification.py create mode 100644 cv/classification/cspdarknet53/pytorch/dataloader/utils/__init__.py create mode 100644 cv/classification/cspdarknet53/pytorch/dataloader/utils/presets_classification.py create mode 100644 cv/classification/cspdarknet53/pytorch/model/csdarknet53.py create mode 100644 cv/classification/cspdarknet53/pytorch/model/cslayers.py create mode 100644 cv/classification/cspdarknet53/pytorch/train.py create mode 100644 cv/classification/cspdarknet53/pytorch/utils_.py diff --git a/cv/classification/cspdarknet53/pytorch/README.md b/cv/classification/cspdarknet53/pytorch/README.md new file mode 100644 index 000000000..5fd274239 --- /dev/null +++ b/cv/classification/cspdarknet53/pytorch/README.md @@ -0,0 +1,38 @@ +# CspDarknet53 + +## Model description + +This is an implementation of CSPDarknet53 in pytorch. 
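+
+At its core, a cross-stage partial (CSP) stage splits the feature map into two channel groups, sends only one group through the stage's convolutional blocks, then concatenates the two paths, which reduces computation while preserving gradient diversity. Below is a minimal, self-contained sketch of that idea; the class and parameter names (```CSPStageSketch```, ```num_blocks```) are illustrative only, and the actual blocks used by this model live in ```model/cslayers.py```:
+
+```python
+import torch
+import torch.nn as nn
+
+class CSPStageSketch(nn.Module):
+    # Illustrative only: split -> transform one part -> concat -> fuse.
+    def __init__(self, channels, num_blocks):
+        super().__init__()
+        half = channels // 2
+        self.part1 = nn.Conv2d(channels, half, 1, bias=False)  # bypass path
+        self.part2 = nn.Conv2d(channels, half, 1, bias=False)  # transformed path
+        self.blocks = nn.Sequential(*[
+            nn.Sequential(
+                nn.Conv2d(half, half, 3, padding=1, bias=False),
+                nn.BatchNorm2d(half),
+                nn.ReLU(inplace=True),
+            ) for _ in range(num_blocks)
+        ])
+        self.fuse = nn.Conv2d(channels, channels, 1, bias=False)
+
+    def forward(self, x):
+        a = self.part1(x)               # skips the heavy blocks
+        b = self.blocks(self.part2(x))  # goes through the blocks
+        return self.fuse(torch.cat([a, b], dim=1))
+```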
+
+## Step 1: Installing
+
+```bash
+pip3 install torch
+pip3 install torchvision
+```
+
+## Step 2: Training
+
+### Single GPU
+
+```bash
+export CUDA_VISIBLE_DEVICES=0
+python3 train.py --batch-size 64 --epochs 120 --data-path /home/datasets/cv/imagenet
+```
+
+### 8 GPUs on one machine
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+python3 -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --batch-size 64 --epochs 120 --data-path /home/datasets/cv/imagenet
+```
+
+## Result
+
+| GPU | FP32 |
+| ----------- | ------------------------------------ |
+| 8 cards | Acc@1 76.644 fps 830 |
+| 8 cards | fps 148 |
+
+## Reference
+
+https://github.com/WongKinYiu/CrossStagePartialNetworks
diff --git a/cv/classification/cspdarknet53/pytorch/__init__.py b/cv/classification/cspdarknet53/pytorch/__init__.py
new file mode 100644
index 000000000..011573976
--- /dev/null
+++ b/cv/classification/cspdarknet53/pytorch/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from .utils_ import *
+from .common_utils import *
+from .dataloader import *
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
diff --git a/cv/classification/cspdarknet53/pytorch/common_utils/__init__.py b/cv/classification/cspdarknet53/pytorch/common_utils/__init__.py
new file mode 100644
index 000000000..7d2e011f7
--- /dev/null
+++ b/cv/classification/cspdarknet53/pytorch/common_utils/__init__.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
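+
+# Usage sketch (illustrative): call manual_seed() once at startup to seed the
+# Python, NumPy and PyTorch RNGs, e.g.:
+#
+#     from common_utils import manual_seed
+#     manual_seed(1234, deterministic=False)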
+ +import random + +import numpy as np + +from .dist import * +from .metric_logger import * +from .misc import * +from .smooth_value import * + +def manual_seed(seed, deterministic=False): + random.seed(seed) + np.random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + else: + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True diff --git a/cv/classification/cspdarknet53/pytorch/common_utils/dist.py b/cv/classification/cspdarknet53/pytorch/common_utils/dist.py new file mode 100644 index 000000000..767b6ce00 --- /dev/null +++ b/cv/classification/cspdarknet53/pytorch/common_utils/dist.py @@ -0,0 +1,156 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from collections import defaultdict, deque +import datetime +import errno +import os +import time + +import torch +import torch.distributed as dist + + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def get_dist_backend(args=None): + DIST_BACKEND_ENV = "PT_DIST_BACKEND" + if DIST_BACKEND_ENV in os.environ: + return os.environ[DIST_BACKEND_ENV] + + if args is None: + args = dict() + + backend_attr_name = "dist_backend" + + if hasattr(args, backend_attr_name): + return getattr(args, backend_attr_name) + + if backend_attr_name in args: + return args[backend_attr_name] + + return "nccl" + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + dist_backend = get_dist_backend(args) + print('| distributed init (rank {}): 
{}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + data_list = [None] * world_size + dist.all_gather_object(data_list, data) + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/cv/classification/cspdarknet53/pytorch/common_utils/metric_logger.py b/cv/classification/cspdarknet53/pytorch/common_utils/metric_logger.py new file mode 100644 index 000000000..ab9c61b0b --- /dev/null +++ b/cv/classification/cspdarknet53/pytorch/common_utils/metric_logger.py @@ -0,0 +1,106 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from collections import defaultdict +import datetime +import time + +import torch +from .smooth_value import SmoothedValue + +""" +Examples: + +logger = MetricLogger(" ") + +>>> # For iter dataloader +>>> metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}')) +>>> header = 'Epoch: [{}]'.format(epoch) +>>> for image, target in metric_logger.log_every(data_loader, print_freq, header): +>>> ... 
+>>> logger.metric_logger.meters['img/s'].update(fps) + +""" + +class MetricLogger(object): + + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {}'.format(header, total_time_str)) diff --git a/cv/classification/cspdarknet53/pytorch/common_utils/misc.py b/cv/classification/cspdarknet53/pytorch/common_utils/misc.py new file mode 100644 index 000000000..8c5e49bcd --- /dev/null +++ b/cv/classification/cspdarknet53/pytorch/common_utils/misc.py @@ -0,0 +1,26 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import sys +import errno + + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise diff --git a/cv/classification/cspdarknet53/pytorch/common_utils/smooth_value.py b/cv/classification/cspdarknet53/pytorch/common_utils/smooth_value.py new file mode 100644 index 000000000..8c1fadae2 --- /dev/null +++ b/cv/classification/cspdarknet53/pytorch/common_utils/smooth_value.py @@ -0,0 +1,88 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +from collections import defaultdict, deque +import datetime +import errno +import os +import time + +import torch +import torch.distributed as dist +from .dist import is_dist_avail_and_initialized + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float32, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) diff --git a/cv/classification/cspdarknet53/pytorch/dataloader/__init__.py b/cv/classification/cspdarknet53/pytorch/dataloader/__init__.py new file mode 100644 index 000000000..66b217b85 --- /dev/null +++ b/cv/classification/cspdarknet53/pytorch/dataloader/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. \ No newline at end of file diff --git a/cv/classification/cspdarknet53/pytorch/dataloader/classification.py b/cv/classification/cspdarknet53/pytorch/dataloader/classification.py new file mode 100644 index 000000000..317a9874e --- /dev/null +++ b/cv/classification/cspdarknet53/pytorch/dataloader/classification.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import os +import time + +import torch +import torchvision +from .utils import presets_classification as presets + +""" +Examples: + +>>> dataset_train, dataset_val = load_data(train_dir, val_dir, args) +""" + + +def get_datasets(traindir, + valdir, + resize_size=256, + crop_size=224, + auto_augment_policy=None, + random_erase_prob=0.): + # Data loading code + print("Loading data") + print("Loading training data") + dataset = torchvision.datasets.ImageFolder( + traindir, + presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=auto_augment_policy, + random_erase_prob=random_erase_prob)) + + print("Loading validation data") + dataset_test = torchvision.datasets.ImageFolder( + valdir, + presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size)) + + return dataset, dataset_test + + +def get_input_size(model): + biger_input_size_models = ['inception'] + resize_size = 256 + crop_size = 224 + for bi_model in biger_input_size_models: + if bi_model in model: + resize_size = 342 + crop_size = 299 + + return resize_size, crop_size + + +def load_data(train_dir, val_dir, args): + auto_augment_policy = getattr(args, "auto_augment", None) + random_erase_prob = getattr(args, "random_erase", 0.0) + resize_size, crop_size = get_input_size(args.model) + dataset, dataset_test = get_datasets(train_dir, val_dir, + auto_augment_policy=auto_augment_policy, + random_erase_prob=random_erase_prob, + resize_size=resize_size, + crop_size=crop_size) + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) + test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) + else: + train_sampler = torch.utils.data.RandomSampler(dataset) + test_sampler = torch.utils.data.SequentialSampler(dataset_test) + + return dataset, dataset_test, train_sampler, test_sampler + + +def _create_torch_dataloader(train_dir, val_dir, args): + dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) + data_loader = torch.utils.data.DataLoader( + dataset, batch_size=args.batch_size, + sampler=train_sampler, num_workers=args.workers, pin_memory=True) + + data_loader_test = torch.utils.data.DataLoader( + dataset_test, batch_size=args.batch_size, + sampler=test_sampler, num_workers=args.workers, pin_memory=True) + + return data_loader, data_loader_test + + +def _create_dali_dataloader(train_dir, val_dir, args): + from .dali_classification import get_imagenet_iter_dali + device = torch.cuda.current_device() + _, crop_size = get_input_size(args.model) + data_loader = get_imagenet_iter_dali('train', train_dir, args.batch_size, + num_threads=args.workers, + device_id=device, + size=crop_size) + data_loader_test = get_imagenet_iter_dali('val', train_dir, args.batch_size, + num_threads=args.workers, + device_id=device, + size=crop_size) + + return data_loader, data_loader_test + + +def create_dataloader(train_dir, val_dir, args): + print("Creating data 
loaders") + if args.dali: + train_dir = os.path.dirname(train_dir) + val_dir = os.path.dirname(val_dir) + return _create_dali_dataloader(train_dir, val_dir, args) + return _create_torch_dataloader(train_dir, val_dir, args) diff --git a/cv/classification/cspdarknet53/pytorch/dataloader/dali_classification.py b/cv/classification/cspdarknet53/pytorch/dataloader/dali_classification.py new file mode 100644 index 000000000..2918bef66 --- /dev/null +++ b/cv/classification/cspdarknet53/pytorch/dataloader/dali_classification.py @@ -0,0 +1,129 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + + +import os + +import nvidia.dali.ops as ops +import nvidia.dali.types as types +from nvidia.dali.pipeline import Pipeline +from nvidia.dali.plugin.pytorch import DALIClassificationIterator, DALIGenericIterator + +class HybridTrainPipe(Pipeline): + def __init__(self, batch_size, num_threads, device_id, data_dir, size): + super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id) + self.input = ops.FileReader(file_root=data_dir, random_shuffle=True) + self.decode = ops.ImageDecoder(device="cpu", output_type=types.RGB) + self.res = ops.RandomResizedCrop(device="gpu", size=size, random_area=[0.08, 1.25]) + self.cmnp = ops.CropMirrorNormalize(device="gpu", + output_dtype=types.FLOAT, + output_layout=types.NCHW, + image_type=types.RGB, + mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], + std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) + + def define_graph(self): + self.jpegs, self.labels = self.input(name="Reader") + + images = self.decode(self.jpegs) + images = self.res(images.gpu()) + output = self.cmnp(images) + return [output, self.labels] + + +class HybridValPipe(Pipeline): + def __init__(self, batch_size, num_threads, device_id, data_dir, size): + super(HybridValPipe, self).__init__(batch_size, num_threads, device_id) + self.input = ops.FileReader(file_root=data_dir, random_shuffle=False) + self.decode = ops.ImageDecoder(device="cpu", output_type=types.RGB) + self.res = ops.Resize(device="gpu", resize_x=size, resize_y=size) + self.cmnp = ops.CropMirrorNormalize(device="gpu", + output_dtype=types.FLOAT, + output_layout=types.NCHW, + crop=(size, size), + image_type=types.RGB, + mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], + std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) + + def define_graph(self): + self.jpegs, self.labels = self.input(name="Reader") + + images = self.decode(self.jpegs) + images = self.res(images.gpu()) + output = self.cmnp(images) + return [output, self.labels] + + +def get_imagenet_iter_dali(type, image_dir, batch_size, num_threads, device_id, size): + if type == 'train': + pip_train = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id, + data_dir = os.path.join(image_dir, "train"), + size=size) + pip_train.build() + dali_iter_train = DALIClassificationIterator(pip_train, 
size=pip_train.epoch_size("Reader")) + return dali_iter_train + elif type == 'val': + pip_val = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id, + data_dir = os.path.join(image_dir, "val"), + size=size) + pip_val.build() + dali_iter_val = DALIClassificationIterator(pip_val, size=pip_val.epoch_size("Reader")) + return dali_iter_val + + +def main(arguments): + parser = argparse.ArgumentParser() + parser.add_argument('--data_dir', help='directory to save data to', type=str, default='classification data') + args = parser.parse_args(arguments) + + train_loader = get_imagenet_iter_dali(type='train', image_dir=args.data_dir, + batch_size=256, + num_threads=4, size=224, device_id=3) + + val_loader = get_imagenet_iter_dali(type="val", image_dir=args.data_dir, + batch_size=256, + num_threads=4, size=224, device_id=3) + + print('start dali train dataloader.') + start = time.time() + for epoch in range(20): + for i, data in enumerate(train_loader): + images = data[0]["data"].cuda(non_blocking=True) + labels = data[0]["label"].squeeze().long().cuda(non_blocking=True) + + # WARN: Very important + train_loader.reset() + print("Epoch", epoch) + print('dali iterate time: %fs' % (time.time() - start)) + print('end dali train dataloader.') + + + print('start dali val dataloader.') + start = time.time() + for i, data in enumerate(val_loader): + images = data[0]["data"].cuda(non_blocking=True) + print(images.shape) + labels = data[0]["label"].squeeze().long().cuda(non_blocking=True) + print(labels.shape) + print('dali iterate time: %fs' % (time.time() - start)) + print('end dali val dataloader.') + + +if __name__ == '__main__': + import os, time, sys + import argparse + sys.exit(main(sys.argv[1:])) diff --git a/cv/classification/cspdarknet53/pytorch/dataloader/utils/__init__.py b/cv/classification/cspdarknet53/pytorch/dataloader/utils/__init__.py new file mode 100644 index 000000000..66b217b85 --- /dev/null +++ b/cv/classification/cspdarknet53/pytorch/dataloader/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. \ No newline at end of file diff --git a/cv/classification/cspdarknet53/pytorch/dataloader/utils/presets_classification.py b/cv/classification/cspdarknet53/pytorch/dataloader/utils/presets_classification.py new file mode 100644 index 000000000..59688a959 --- /dev/null +++ b/cv/classification/cspdarknet53/pytorch/dataloader/utils/presets_classification.py @@ -0,0 +1,54 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
diff --git a/cv/classification/cspdarknet53/pytorch/dataloader/utils/presets_classification.py b/cv/classification/cspdarknet53/pytorch/dataloader/utils/presets_classification.py
new file mode 100644
index 000000000..59688a959
--- /dev/null
+++ b/cv/classification/cspdarknet53/pytorch/dataloader/utils/presets_classification.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+
+from torchvision.transforms import autoaugment, transforms
+
+
+class ClassificationPresetTrain:
+    def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), hflip_prob=0.5,
+                 auto_augment_policy=None, random_erase_prob=0.0):
+        trans = [transforms.RandomResizedCrop(crop_size)]
+        if hflip_prob > 0:
+            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+        if auto_augment_policy is not None:
+            aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
+            trans.append(autoaugment.AutoAugment(policy=aa_policy))
+        trans.extend([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+        if random_erase_prob > 0:
+            trans.append(transforms.RandomErasing(p=random_erase_prob))
+
+        self.transforms = transforms.Compose(trans)
+
+    def __call__(self, img):
+        return self.transforms(img)
+
+
+class ClassificationPresetEval:
+    def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+        self.transforms = transforms.Compose([
+            transforms.Resize(resize_size),
+            transforms.CenterCrop(crop_size),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+
+    def __call__(self, img):
+        return self.transforms(img)
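A quick sketch of the presets in isolation (the sample image path is hypothetical): `ClassificationPresetTrain` takes a PIL image and returns a normalized tensor, while the eval preset applies a deterministic resize and center crop:

```python
from PIL import Image
from dataloader.utils.presets_classification import (
    ClassificationPresetTrain, ClassificationPresetEval)

train_tf = ClassificationPresetTrain(crop_size=224, auto_augment_policy="imagenet",
                                     random_erase_prob=0.1)
eval_tf = ClassificationPresetEval(crop_size=224, resize_size=256)

img = Image.open("sample.jpg").convert("RGB")  # hypothetical input image
x_train = train_tf(img)  # float tensor of shape (3, 224, 224), normalized
x_eval = eval_tf(img)    # deterministic resize(256) + center crop(224)
```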
diff --git a/cv/classification/cspdarknet53/pytorch/model/csdarknet53.py b/cv/classification/cspdarknet53/pytorch/model/csdarknet53.py
new file mode 100644
index 000000000..74bd36cc1
--- /dev/null
+++ b/cv/classification/cspdarknet53/pytorch/model/csdarknet53.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+import torch.backends.cudnn as cudnn
+from torchsummary import summary
+
+try:
+    from .cslayers import *
+except ImportError:
+    # Fallback so the file can also be run directly as a script.
+    from cslayers import *
+
+__all__ = ['CsDarkNet53']
+
+class CsDarkNet53(nn.Module):
+    def __init__(self, num_classes):
+        super(CsDarkNet53, self).__init__()
+
+        input_channels = 32
+
+        # Network
+        self.stage1 = Conv2dBatchLeaky(3, input_channels, 3, 1, activation='mish')
+        self.stage2 = Stage2(input_channels)
+        self.stage3 = Stage3(4*input_channels)
+        self.stage4 = Stage(4*input_channels, 8)
+        self.stage5 = Stage(8*input_channels, 8)
+        self.stage6 = Stage(16*input_channels, 4)
+
+        self.conv = Conv2dBatchLeaky(32*input_channels, 32*input_channels, 1, 1, activation='mish')
+        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
+        self.fc = nn.Linear(1024, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        stage1 = self.stage1(x)
+        stage2 = self.stage2(stage1)
+        stage3 = self.stage3(stage2)
+        stage4 = self.stage4(stage3)
+        stage5 = self.stage5(stage4)
+        stage6 = self.stage6(stage5)
+
+        conv = self.conv(stage6)
+        x = self.avgpool(conv)
+        x = x.view(-1, 1024)
+        x = self.fc(x)
+
+        return x
+
+if __name__ == "__main__":
+    use_cuda = torch.cuda.is_available()
+    if use_cuda:
+        device = torch.device("cuda")
+        cudnn.benchmark = True
+    else:
+        device = torch.device("cpu")
+
+    darknet = CsDarkNet53(num_classes=10)
+    darknet = darknet.to(device)
+    with torch.no_grad():
+        darknet.eval()
+        data = torch.rand(1, 3, 256, 256)
+        data = data.to(device)
+        try:
+            #print(darknet)
+            summary(darknet, (3, 256, 256), device="cuda" if use_cuda else "cpu")
+            print(darknet(data))
+        except Exception as e:
+            print(e)
diff --git a/cv/classification/cspdarknet53/pytorch/model/cslayers.py b/cv/classification/cspdarknet53/pytorch/model/cslayers.py
new file mode 100644
index 000000000..0f5f7c78a
--- /dev/null
+++ b/cv/classification/cspdarknet53/pytorch/model/cslayers.py
@@ -0,0 +1,178 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# mish(x) = x * tanh(log(1 + e^x))
+class Mish(nn.Module):
+    def __init__(self):
+        super(Mish, self).__init__()
+
+    def forward(self, x):
+        return x * torch.tanh(F.softplus(x))
+
+class Conv2dBatchLeaky(nn.Module):
+    """
+    This convenience layer groups a 2D convolution, a batchnorm and a leaky ReLU.
+    """
+    def __init__(self, in_channels, out_channels, kernel_size, stride, activation='leaky', leaky_slope=0.1):
+        super(Conv2dBatchLeaky, self).__init__()
+
+        # Parameters
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        if isinstance(kernel_size, (list, tuple)):
+            self.padding = [int(k/2) for k in kernel_size]
+        else:
+            self.padding = int(kernel_size/2)
+        self.leaky_slope = leaky_slope
+        # self.mish = Mish()
+
+        # Layer
+        if activation == "leaky":
+            self.layers = nn.Sequential(
+                nn.Conv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, bias=False),
+                nn.BatchNorm2d(self.out_channels),
+                nn.LeakyReLU(self.leaky_slope, inplace=True)
+            )
+        elif activation == "mish":
+            self.layers = nn.Sequential(
+                nn.Conv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, bias=False),
+                nn.BatchNorm2d(self.out_channels),
+                Mish()
+            )
+        elif activation == "linear":
+            self.layers = nn.Sequential(
+                nn.Conv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, bias=False)
+            )
+        else:
+            raise ValueError("Unsupported activation: {}".format(activation))
+
+    def __repr__(self):
+        s = '{name} ({in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}, padding={padding}, negative_slope={leaky_slope})'
+        return s.format(name=self.__class__.__name__, **self.__dict__)
+
+    def forward(self, x):
+        x = self.layers(x)
+        return x
+
+class SmallBlock(nn.Module):
+
+    def __init__(self, nchannels):
+        super().__init__()
+        self.features = nn.Sequential(
+            Conv2dBatchLeaky(nchannels, nchannels, 1, 1, activation='mish'),
+            Conv2dBatchLeaky(nchannels, nchannels, 3, 1, activation='mish')
+        )
+        # conv_shortcut
+        '''
+        Reference: https://github.com/bubbliiiing/yolov4-pytorch
+        No conv layer follows the shortcut addition.
+        '''
+        # self.active_linear = Conv2dBatchLeaky(nchannels, nchannels, 1, 1, activation='linear')
+        # self.conv_shortcut = Conv2dBatchLeaky(nchannels, nchannels, 1, 1, activation='mish')
+
+
+    def forward(self, data):
+        short_cut = data + self.features(data)
+        # active_linear = self.conv_shortcut(short_cut)
+
+        return short_cut
+
+# Stage1 conv [256,256,3]->[256,256,32]
+
+class Stage2(nn.Module):
+
+    def __init__(self, nchannels):
+        super().__init__()
+        # stage2 32
+        self.conv1 = Conv2dBatchLeaky(nchannels, 2*nchannels, 3, 2, activation='mish')
+        self.split0 = Conv2dBatchLeaky(2*nchannels, 2*nchannels, 1, 1, activation='mish')
+        self.split1 = Conv2dBatchLeaky(2*nchannels, 2*nchannels, 1, 1, activation='mish')
+
+        self.conv2 = Conv2dBatchLeaky(2*nchannels, nchannels, 1, 1, activation='mish')
+        self.conv3 = Conv2dBatchLeaky(nchannels, 2*nchannels, 3, 1, activation='mish')
+
+        self.conv4 = Conv2dBatchLeaky(2*nchannels, 2*nchannels, 1, 1, 
activation='mish') + + + def forward(self, data): + conv1 = self.conv1(data) + split0 = self.split0(conv1) + split1 = self.split1(conv1) + conv2 = self.conv2(split1) + conv3 = self.conv3(conv2) + + shortcut = split1 + conv3 + conv4 = self.conv4(shortcut) + + route = torch.cat([split0, conv4], dim=1) + return route + +class Stage3(nn.Module): + def __init__(self, nchannels): + super().__init__() + # stage3 128 + self.conv1 = Conv2dBatchLeaky(nchannels, int(nchannels/2), 1, 1, activation='mish') + self.conv2 = Conv2dBatchLeaky(int(nchannels/2), nchannels, 3, 2, activation='mish') + + self.split0 = Conv2dBatchLeaky(nchannels, int(nchannels/2), 1, 1, activation='mish') + self.split1 = Conv2dBatchLeaky(nchannels, int(nchannels/2), 1, 1, activation='mish') + + self.block1 = SmallBlock(int(nchannels/2)) + self.block2 = SmallBlock(int(nchannels/2)) + + self.conv3 = Conv2dBatchLeaky(int(nchannels/2), int(nchannels/2), 1, 1, activation='mish') + + def forward(self, data): + conv1 = self.conv1(data) + conv2 = self.conv2(conv1) + + split0 = self.split0(conv2) + split1 = self.split1(conv2) + + block1 = self.block1(split1) + block2 = self.block2(block1) + + conv3 = self.conv3(block2) + + route = torch.cat([split0, conv3], dim=1) + + return route + +# Stage4 Stage5 Stage6 +class Stage(nn.Module): + def __init__(self, nchannels, nblocks): + super().__init__() + # stage4 : 128 + # stage5 : 256 + # stage6 : 512 + self.conv1 = Conv2dBatchLeaky(nchannels, nchannels, 1, 1, activation='mish') + self.conv2 = Conv2dBatchLeaky(nchannels, 2*nchannels, 3, 2, activation='mish') + self.split0 = Conv2dBatchLeaky(2*nchannels, nchannels, 1, 1, activation='mish') + self.split1 = Conv2dBatchLeaky(2*nchannels, nchannels, 1, 1, activation='mish') + blocks = [] + for i in range(nblocks): + blocks.append(SmallBlock(nchannels)) + self.blocks = nn.Sequential(*blocks) + self.conv4 = Conv2dBatchLeaky(nchannels, nchannels, 1, 1, activation='mish') + + def forward(self,data): + conv1 = self.conv1(data) + conv2 = self.conv2(conv1) + + split0 = self.split0(conv2) + split1 = self.split1(conv2) + blocks = self.blocks(split1) + conv4 = self.conv4(blocks) + route = torch.cat([split0, conv4], dim=1) + + return route + + + + + + + diff --git a/cv/classification/cspdarknet53/pytorch/train.py b/cv/classification/cspdarknet53/pytorch/train.py new file mode 100644 index 000000000..9afa8dc9d --- /dev/null +++ b/cv/classification/cspdarknet53/pytorch/train.py @@ -0,0 +1,318 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
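To make the channel bookkeeping in `cslayers.py` above concrete, an illustrative sketch (not part of the patch): each `Stage` halves the spatial resolution with a stride-2 conv, splits the result into two branches, and concatenates them back, doubling the channel count overall:

```python
import torch
from model.cslayers import Stage

x = torch.randn(1, 128, 32, 32)  # e.g. the output of stage3 when input_channels=32
stage4 = Stage(128, 8)           # 8 residual SmallBlocks on one branch
y = stage4(x)
print(y.shape)  # torch.Size([1, 256, 16, 16]): channels doubled, spatial halved
```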
+ +import datetime +import os + +import time + +import torch +import torch.utils.data + +try: + from torch.cuda.amp import autocast, GradScaler + scaler = GradScaler() +except: + autocast = None + scaler = None + + +from torch import nn +import torch.distributed as dist +import torchvision + +from utils_ import (MetricLogger, SmoothedValue, accuracy, mkdir,\ + init_distributed_mode, manual_seed,\ + is_main_process, save_on_master, get_world_size) + +from dataloader.classification import get_datasets, create_dataloader + +from model.csdarknet53 import CsDarkNet53 + +def compute_loss(model, image, target, criterion): + output = model(image) + if not isinstance(output, (tuple, list)): + output = [output] + losses = [] + for out in output: + losses.append(criterion(out, target)) + loss = sum(losses) + return loss, output[0] + + +def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, use_dali=False): + model.train() + metric_logger = MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value}')) + metric_logger.add_meter('img/s', SmoothedValue(window_size=10, fmt='{value}')) + + header = 'Epoch: [{}]'.format(epoch) + all_fps = [] + for data in metric_logger.log_every(data_loader, print_freq, header): + if use_dali: + image, target = data[0]["data"], data[0]["label"][:, 0].long() + else: + image, target = data + start_time = time.time() + image, target = image.to(device), target.to(device) + if autocast is None or not amp: + loss, output = compute_loss(model, image, target, criterion) + else: + with autocast(): + loss, output = compute_loss(model, image, target, criterion) + + optimizer.zero_grad() + if scaler is not None and amp: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + optimizer.step() + + torch.cuda.synchronize() + end_time = time.time() + + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + batch_size = image.shape[0] + metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) + metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) + metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + fps = batch_size / (end_time - start_time) * get_world_size() + metric_logger.meters['img/s'].update(fps) + all_fps.append(fps) + + print(header, 'Avg img/s:', sum(all_fps) / len(all_fps)) + + +def evaluate(model, criterion, data_loader, device, print_freq=100, use_dali=False): + model.eval() + metric_logger = MetricLogger(delimiter=" ") + header = 'Test:' + with torch.no_grad(): + for data in metric_logger.log_every(data_loader, print_freq, header): + if use_dali: + image, target = data[0]["data"], data[0]["label"][:, 0].long() + else: + image, target = data + image = image.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + output = model(image) + loss = criterion(output, target) + + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + # FIXME need to take into account that the datasets + # could have been padded in distributed setup + batch_size = image.shape[0] + metric_logger.update(loss=loss.item()) + metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) + metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + # gather the stats from all processes + metric_logger.synchronize_between_processes() + + print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}' + .format(top1=metric_logger.acc1, top5=metric_logger.acc5)) + return metric_logger.acc1.global_avg + + +def 
_get_cache_path(filepath):
+    import hashlib
+    h = hashlib.sha1(filepath.encode()).hexdigest()
+    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
+    cache_path = os.path.expanduser(cache_path)
+    return cache_path
+
+
+def main(args):
+    if args.output_dir:
+        mkdir(args.output_dir)
+
+    init_distributed_mode(args)
+    print(args)
+
+    device = torch.device(args.device)
+
+    manual_seed(args.seed, deterministic=False)
+    # torch.backends.cudnn.benchmark = True
+
+    # WARN: the world size determines the effective global batch size below.
+    if dist.is_initialized():
+        num_gpu = dist.get_world_size()
+    else:
+        num_gpu = 1
+
+    global_batch_size = num_gpu * args.batch_size
+
+    train_dir = os.path.join(args.data_path, 'train')
+    val_dir = os.path.join(args.data_path, 'val')
+
+    num_classes = len(os.listdir(train_dir))
+    if 0 < num_classes < 13:
+        if global_batch_size > 512:
+            if is_main_process():
+                print("WARN: Capping the global batch size at 512 to avoid non-convergence on small datasets.")
+            args.batch_size = 512 // num_gpu
+
+    if args.pretrained:
+        num_classes = 1000
+
+    data_loader, data_loader_test = create_dataloader(train_dir, val_dir, args)
+
+    print("Creating model")
+    model = CsDarkNet53(num_classes=num_classes)
+    print(model)
+    model.to(device)
+    if args.distributed and args.sync_bn:
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+    criterion = nn.CrossEntropyLoss()
+
+    opt_name = args.opt.lower()
+    if opt_name == 'sgd':
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    elif opt_name == 'rmsprop':
+        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
+                                        weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
+    else:
+        raise RuntimeError("Invalid optimizer {}. 
Only SGD and RMSprop are supported.".format(args.opt)) + + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + if args.resume: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + args.start_epoch = checkpoint['epoch'] + 1 + + if args.test_only: + evaluate(model, criterion, data_loader_test, device=device) + return + + print("Start training") + start_time = time.time() + for epoch in range(args.start_epoch, args.epochs): + epoch_start_time = time.time() + if args.distributed and not args.dali: + data_loader.sampler.set_epoch(epoch) + train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, use_dali=args.dali) + lr_scheduler.step() + evaluate(model, criterion, data_loader_test, device=device, use_dali=args.dali) + if args.output_dir: + checkpoint = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'args': args} + #save_on_master( + # checkpoint, + # os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) + save_on_master( + checkpoint, + os.path.join(args.output_dir, 'checkpoint.pth')) + epoch_total_time = time.time() - epoch_start_time + epoch_total_time_str = str(datetime.timedelta(seconds=int(epoch_total_time))) + print('epoch time {}'.format(epoch_total_time_str)) + + if args.dali: + data_loader.reset() + data_loader_test.reset() + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +def get_args_parser(add_help=True): + import argparse + parser = argparse.ArgumentParser(description='PyTorch Classification Training', add_help=add_help) + + parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset') + parser.add_argument('--model', default='', help='model') + parser.add_argument('--device', default='cuda', help='device') + parser.add_argument('-b', '--batch-size', default=32, type=int) + parser.add_argument('--epochs', default=90, type=int, metavar='N', + help='number of total epochs to run') + parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 4)') + parser.add_argument('--opt', default='sgd', type=str, help='optimizer') + parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') + parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') + parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') + parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs') + parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') + parser.add_argument('--print-freq', default=10, type=int, help='print frequency') + parser.add_argument('--output-dir', default='.', help='path where to save') + parser.add_argument('--resume', default='', help='resume from checkpoint') + parser.add_argument('--start-epoch', 
default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument( + "--cache-dataset", + dest="cache_dataset", + help="Cache the datasets for quicker initialization. It also serializes the transforms", + action="store_true", + ) + parser.add_argument( + "--sync-bn", + dest="sync_bn", + help="Use sync batch norm", + action="store_true", + ) + parser.add_argument( + "--test-only", + dest="test_only", + help="Only test the model", + action="store_true", + ) + parser.add_argument( + "--pretrained", + dest="pretrained", + help="Use pre-trained models from the modelzoo", + action="store_true", + ) + parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)') + parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)') + parser.add_argument( + "--dali", + help="Use dali as dataloader", + default=False, + action="store_true", + ) + + # distributed training parameters + parser.add_argument('--local_rank', default=-1, type=int, + help='Local rank') + parser.add_argument('--world-size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') + parser.add_argument('--amp', action='store_true', help='Automatic Mixed Precision training') + parser.add_argument('--seed', default=42, type=int, help='Random seed') + return parser + + +if __name__ == "__main__": + args = get_args_parser().parse_args() + main(args) diff --git a/cv/classification/cspdarknet53/pytorch/utils_.py b/cv/classification/cspdarknet53/pytorch/utils_.py new file mode 100644 index 000000000..2a25edf2a --- /dev/null +++ b/cv/classification/cspdarknet53/pytorch/utils_.py @@ -0,0 +1,169 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +from collections import defaultdict, deque, OrderedDict +import copy +import datetime +import hashlib +import time +import torch +import torch.distributed as dist + +import errno +import os + +from common_utils import * + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target[None]) + + res = [] + for k in topk: + correct_k = correct[:k].flatten().sum(dtype=torch.float32) + res.append(correct_k * (100.0 / batch_size)) + return res + + +def average_checkpoints(inputs): + """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from: + https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16 + + Args: + inputs (List[str]): An iterable of string paths of checkpoints to load from. 
+ Returns: + A dict of string keys mapping to various values. The 'model' key + from the returned dict should correspond to an OrderedDict mapping + string parameter names to torch Tensors. + """ + params_dict = OrderedDict() + params_keys = None + new_state = None + num_models = len(inputs) + for fpath in inputs: + with open(fpath, "rb") as f: + state = torch.load( + f, + map_location=( + lambda s, _: torch.serialization.default_restore_location(s, "cpu") + ), + ) + # Copies over the settings from the first checkpoint + if new_state is None: + new_state = state + model_params = state["model"] + model_params_keys = list(model_params.keys()) + if params_keys is None: + params_keys = model_params_keys + elif params_keys != model_params_keys: + raise KeyError( + "For checkpoint {}, expected list of params: {}, " + "but found: {}".format(f, params_keys, model_params_keys) + ) + for k in params_keys: + p = model_params[k] + if isinstance(p, torch.HalfTensor): + p = p.float() + if k not in params_dict: + params_dict[k] = p.clone() + # NOTE: clone() is needed in case of p is a shared parameter + else: + params_dict[k] += p + averaged_params = OrderedDict() + for k, v in params_dict.items(): + averaged_params[k] = v + if averaged_params[k].is_floating_point(): + averaged_params[k].div_(num_models) + else: + averaged_params[k] //= num_models + new_state["model"] = averaged_params + return new_state + + +def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=True): + """ + This method can be used to prepare weights files for new models. It receives as + input a model architecture and a checkpoint from the training script and produces + a file with the weights ready for release. + + Examples: + from torchvision import models as M + + # Classification + model = M.mobilenet_v3_large(pretrained=False) + print(store_model_weights(model, './class.pth')) + + # Quantized Classification + model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False) + model.fuse_model() + model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack') + _ = torch.quantization.prepare_qat(model, inplace=True) + print(store_model_weights(model, './qat.pth')) + + # Object Detection + model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False) + print(store_model_weights(model, './obj.pth')) + + # Segmentation + model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, aux_loss=True) + print(store_model_weights(model, './segm.pth', strict=False)) + + Args: + model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes. + checkpoint_path (str): The path of the checkpoint we will load. + checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored. + Default: "model". + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + + Returns: + output_path (str): The location where the weights are saved. + """ + # Store the new model next to the checkpoint_path + checkpoint_path = os.path.abspath(checkpoint_path) + output_dir = os.path.dirname(checkpoint_path) + + # Deep copy to avoid side-effects on the model object. 
+    model = copy.deepcopy(model)
+    checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+    # Load the weights to the model to validate that everything works
+    # and remove unnecessary weights (such as auxiliaries, etc)
+    model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
+
+    tmp_path = os.path.join(output_dir, str(model.__hash__()))
+    torch.save(model.state_dict(), tmp_path)
+
+    sha256_hash = hashlib.sha256()
+    with open(tmp_path, "rb") as f:
+        # Read and update hash string value in blocks of 4K
+        for byte_block in iter(lambda: f.read(4096), b""):
+            sha256_hash.update(byte_block)
+        hh = sha256_hash.hexdigest()
+
+    output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
+    os.replace(tmp_path, output_path)
+
+    return output_path
-- 
Gitee
From 58627b08161c4864f02522933662fb0d06aeded8 Mon Sep 17 00:00:00 2001
From: wwwmayyyyyyyy
Date: Mon, 24 Apr 2023 11:03:10 +0800
Subject: [PATCH 5/5] Remove the CBAM implementation

---
 cv/classification/cbam/pytorch/LICENSE        |  21 --
 cv/classification/cbam/pytorch/MODELS/bam.py  |  49 ---
 cv/classification/cbam/pytorch/MODELS/cbam.py |  95 -----
 .../cbam/pytorch/MODELS/model_resnet.py       | 205 -----------
 cv/classification/cbam/pytorch/README.md      |  35 --
 .../scripts/train_imagenet_resnet50_bam.sh    |   9 -
 .../scripts/train_imagenet_resnet50_cbam.sh   |   9 -
 .../cbam/pytorch/train_imagenet.py            | 328 ------------------
 8 files changed, 751 deletions(-)
 delete mode 100644 cv/classification/cbam/pytorch/LICENSE
 delete mode 100644 cv/classification/cbam/pytorch/MODELS/bam.py
 delete mode 100644 cv/classification/cbam/pytorch/MODELS/cbam.py
 delete mode 100644 cv/classification/cbam/pytorch/MODELS/model_resnet.py
 delete mode 100644 cv/classification/cbam/pytorch/README.md
 delete mode 100755 cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_bam.sh
 delete mode 100755 cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_cbam.sh
 delete mode 100644 cv/classification/cbam/pytorch/train_imagenet.py

diff --git a/cv/classification/cbam/pytorch/LICENSE b/cv/classification/cbam/pytorch/LICENSE
deleted file mode 100644
index f91eab172..000000000
--- a/cv/classification/cbam/pytorch/LICENSE
+++ /dev/null
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2019 Jongchan Park
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
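For context on what this commit removes: `MODELS/bam.py` (deleted below) gates a feature map with `att = 1 + sigmoid(channel_att * spatial_att)`. Note that, as committed, `ChannelGate.__init__` assigns the undefined name `gate_activation`, so constructing a `BAM` raises a `NameError`; with that stray line removed, a minimal smoke test would be:

```python
import torch
from MODELS.bam import BAM  # the module deleted below, minus the stray gate_activation line

feat = torch.randn(2, 64, 56, 56)  # (N, C, H, W) feature map
bam = BAM(gate_channel=64)
out = bam(feat)   # out = (1 + sigmoid(channel_att * spatial_att)) * feat
print(out.shape)  # torch.Size([2, 64, 56, 56])
```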
diff --git a/cv/classification/cbam/pytorch/MODELS/bam.py b/cv/classification/cbam/pytorch/MODELS/bam.py deleted file mode 100644 index cbda3d060..000000000 --- a/cv/classification/cbam/pytorch/MODELS/bam.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch -import math -import torch.nn as nn -import torch.nn.functional as F - -class Flatten(nn.Module): - def forward(self, x): - return x.view(x.size(0), -1) -class ChannelGate(nn.Module): - def __init__(self, gate_channel, reduction_ratio=16, num_layers=1): - super(ChannelGate, self).__init__() - self.gate_activation = gate_activation - self.gate_c = nn.Sequential() - self.gate_c.add_module( 'flatten', Flatten() ) - gate_channels = [gate_channel] - gate_channels += [gate_channel // reduction_ratio] * num_layers - gate_channels += [gate_channel] - for i in range( len(gate_channels) - 2 ): - self.gate_c.add_module( 'gate_c_fc_%d'%i, nn.Linear(gate_channels[i], gate_channels[i+1]) ) - self.gate_c.add_module( 'gate_c_bn_%d'%(i+1), nn.BatchNorm1d(gate_channels[i+1]) ) - self.gate_c.add_module( 'gate_c_relu_%d'%(i+1), nn.ReLU() ) - self.gate_c.add_module( 'gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]) ) - def forward(self, in_tensor): - avg_pool = F.avg_pool2d( in_tensor, in_tensor.size(2), stride=in_tensor.size(2) ) - return self.gate_c( avg_pool ).unsqueeze(2).unsqueeze(3).expand_as(in_tensor) - -class SpatialGate(nn.Module): - def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4): - super(SpatialGate, self).__init__() - self.gate_s = nn.Sequential() - self.gate_s.add_module( 'gate_s_conv_reduce0', nn.Conv2d(gate_channel, gate_channel//reduction_ratio, kernel_size=1)) - self.gate_s.add_module( 'gate_s_bn_reduce0', nn.BatchNorm2d(gate_channel//reduction_ratio) ) - self.gate_s.add_module( 'gate_s_relu_reduce0',nn.ReLU() ) - for i in range( dilation_conv_num ): - self.gate_s.add_module( 'gate_s_conv_di_%d'%i, nn.Conv2d(gate_channel//reduction_ratio, gate_channel//reduction_ratio, kernel_size=3, \ - padding=dilation_val, dilation=dilation_val) ) - self.gate_s.add_module( 'gate_s_bn_di_%d'%i, nn.BatchNorm2d(gate_channel//reduction_ratio) ) - self.gate_s.add_module( 'gate_s_relu_di_%d'%i, nn.ReLU() ) - self.gate_s.add_module( 'gate_s_conv_final', nn.Conv2d(gate_channel//reduction_ratio, 1, kernel_size=1) ) - def forward(self, in_tensor): - return self.gate_s( in_tensor ).expand_as(in_tensor) -class BAM(nn.Module): - def __init__(self, gate_channel): - super(BAM, self).__init__() - self.channel_att = ChannelGate(gate_channel) - self.spatial_att = SpatialGate(gate_channel) - def forward(self,in_tensor): - att = 1 + F.sigmoid( self.channel_att(in_tensor) * self.spatial_att(in_tensor) ) - return att * in_tensor diff --git a/cv/classification/cbam/pytorch/MODELS/cbam.py b/cv/classification/cbam/pytorch/MODELS/cbam.py deleted file mode 100644 index 3124c04b9..000000000 --- a/cv/classification/cbam/pytorch/MODELS/cbam.py +++ /dev/null @@ -1,95 +0,0 @@ -import torch -import math -import torch.nn as nn -import torch.nn.functional as F - -class BasicConv(nn.Module): - def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): - super(BasicConv, self).__init__() - self.out_channels = out_planes - self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) - self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None - 
self.relu = nn.ReLU() if relu else None - - def forward(self, x): - x = self.conv(x) - if self.bn is not None: - x = self.bn(x) - if self.relu is not None: - x = self.relu(x) - return x - -class Flatten(nn.Module): - def forward(self, x): - return x.view(x.size(0), -1) - -class ChannelGate(nn.Module): - def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): - super(ChannelGate, self).__init__() - self.gate_channels = gate_channels - self.mlp = nn.Sequential( - Flatten(), - nn.Linear(gate_channels, gate_channels // reduction_ratio), - nn.ReLU(), - nn.Linear(gate_channels // reduction_ratio, gate_channels) - ) - self.pool_types = pool_types - def forward(self, x): - channel_att_sum = None - for pool_type in self.pool_types: - if pool_type=='avg': - avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) - channel_att_raw = self.mlp( avg_pool ) - elif pool_type=='max': - max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) - channel_att_raw = self.mlp( max_pool ) - elif pool_type=='lp': - lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) - channel_att_raw = self.mlp( lp_pool ) - elif pool_type=='lse': - # LSE pool only - lse_pool = logsumexp_2d(x) - channel_att_raw = self.mlp( lse_pool ) - - if channel_att_sum is None: - channel_att_sum = channel_att_raw - else: - channel_att_sum = channel_att_sum + channel_att_raw - - scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) - return x * scale - -def logsumexp_2d(tensor): - tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) - s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) - outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() - return outputs - -class ChannelPool(nn.Module): - def forward(self, x): - return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) - -class SpatialGate(nn.Module): - def __init__(self): - super(SpatialGate, self).__init__() - kernel_size = 7 - self.compress = ChannelPool() - self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) - def forward(self, x): - x_compress = self.compress(x) - x_out = self.spatial(x_compress) - scale = F.sigmoid(x_out) # broadcasting - return x * scale - -class CBAM(nn.Module): - def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): - super(CBAM, self).__init__() - self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) - self.no_spatial=no_spatial - if not no_spatial: - self.SpatialGate = SpatialGate() - def forward(self, x): - x_out = self.ChannelGate(x) - if not self.no_spatial: - x_out = self.SpatialGate(x_out) - return x_out diff --git a/cv/classification/cbam/pytorch/MODELS/model_resnet.py b/cv/classification/cbam/pytorch/MODELS/model_resnet.py deleted file mode 100644 index 650664031..000000000 --- a/cv/classification/cbam/pytorch/MODELS/model_resnet.py +++ /dev/null @@ -1,205 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch.nn import init -from .cbam import * -from .bam import * - -def conv3x3(in_planes, out_planes, stride=1): - "3x3 convolution with padding" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False): - super(BasicBlock, self).__init__() - 
self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = nn.BatchNorm2d(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = nn.BatchNorm2d(planes) - self.downsample = downsample - self.stride = stride - - if use_cbam: - self.cbam = CBAM( planes, 16 ) - else: - self.cbam = None - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - residual = self.downsample(x) - - if not self.cbam is None: - out = self.cbam(out) - - out += residual - out = self.relu(out) - - return out - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False): - super(Bottleneck, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * 4) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - if use_cbam: - self.cbam = CBAM( planes * 4, 16 ) - else: - self.cbam = None - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - if not self.cbam is None: - out = self.cbam(out) - - out += residual - out = self.relu(out) - - return out - -class ResNet(nn.Module): - def __init__(self, block, layers, network_type, num_classes, att_type=None): - self.inplanes = 64 - super(ResNet, self).__init__() - self.network_type = network_type - # different model config between ImageNet and CIFAR - if network_type == "ImageNet": - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.avgpool = nn.AvgPool2d(7) - else: - self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) - - self.bn1 = nn.BatchNorm2d(64) - self.relu = nn.ReLU(inplace=True) - - if att_type=='BAM': - self.bam1 = BAM(64*block.expansion) - self.bam2 = BAM(128*block.expansion) - self.bam3 = BAM(256*block.expansion) - else: - self.bam1, self.bam2, self.bam3 = None, None, None - - self.layer1 = self._make_layer(block, 64, layers[0], att_type=att_type) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, att_type=att_type) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, att_type=att_type) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, att_type=att_type) - - self.fc = nn.Linear(512 * block.expansion, num_classes) - - init.kaiming_normal(self.fc.weight) - for key in self.state_dict(): - if key.split('.')[-1]=="weight": - if "conv" in key: - init.kaiming_normal(self.state_dict()[key], mode='fan_out') - if "bn" in key: - if "SpatialGate" in key: - self.state_dict()[key][...] = 0 - else: - self.state_dict()[key][...] = 1 - elif key.split(".")[-1]=='bias': - self.state_dict()[key][...] 
= 0 - - def _make_layer(self, block, planes, blocks, stride=1, att_type=None): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample, use_cbam=att_type=='CBAM')) - self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(self.inplanes, planes, use_cbam=att_type=='CBAM')) - - return nn.Sequential(*layers) - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - if self.network_type == "ImageNet": - x = self.maxpool(x) - - x = self.layer1(x) - if not self.bam1 is None: - x = self.bam1(x) - - x = self.layer2(x) - if not self.bam2 is None: - x = self.bam2(x) - - x = self.layer3(x) - if not self.bam3 is None: - x = self.bam3(x) - - x = self.layer4(x) - - if self.network_type == "ImageNet": - x = self.avgpool(x) - else: - x = F.avg_pool2d(x, 4) - x = x.view(x.size(0), -1) - x = self.fc(x) - return x - -def ResidualNet(network_type, depth, num_classes, att_type): - - assert network_type in ["ImageNet", "CIFAR10", "CIFAR100"], "network type should be ImageNet or CIFAR10 / CIFAR100" - assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101' - - if depth == 18: - model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, att_type) - - elif depth == 34: - model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, att_type) - - elif depth == 50: - model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, att_type) - - elif depth == 101: - model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, att_type) - - return model diff --git a/cv/classification/cbam/pytorch/README.md b/cv/classification/cbam/pytorch/README.md deleted file mode 100644 index c85c4e126..000000000 --- a/cv/classification/cbam/pytorch/README.md +++ /dev/null @@ -1,35 +0,0 @@ -# CBAM - -## Model description -Official PyTorch code for "[CBAM: Convolutional Block Attention Module (ECCV2018)](http://openaccess.thecvf.com/content_ECCV_2018/html/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.html)" - - -## Step 1: Installing - -```bash -pip3 install torch -pip3 install torchvision -``` - -## Step 2: Training - -ResNet50 based examples are included. Example scripts are included under ```./scripts/``` directory. -ImageNet data should be included under ```./data/ImageNet/``` with foler named ```train``` and ```val```. 
- -``` -# To train with CBAM (ResNet50 backbone) -# For 8 GPUs -python3 train_imagenet.py --ngpu 8 --workers 20 --arch resnet --depth 50 --epochs 100 --batch-size 256 --lr 0.1 --att-type CBAM --prefix RESNET50_IMAGENET_CBAM ./data/ImageNet -# For 1 GPUs -python3 train_imagenet.py --ngpu 1 --workers 20 --arch resnet --depth 50 --epochs 100 --batch-size 64 --lr 0.1 --att-type CBAM --prefix RESNET50_IMAGENET_CBAM ./data/ImageNet -``` - -## Result - -| GPU | FP32 | -| ----------- | ------------------------------------ | -| 8 cards | Prec@1 76.216 | - -## Reference - -- [MXNet implementation of CBAM with several modifications](https://github.com/bruinxiong/Modified-CBAMnet.mxnet) by [bruinxiong](https://github.com/bruinxiong) diff --git a/cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_bam.sh b/cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_bam.sh deleted file mode 100755 index b5b5b6669..000000000 --- a/cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_bam.sh +++ /dev/null @@ -1,9 +0,0 @@ -python3 train_imagenet.py \ - --ngpu 8 \ - --workers 20 \ - --arch resnet --depth 50 \ - --epochs 100 \ - --batch-size 256 --lr 0.1 \ - --att-type BAM \ - --prefix RESNET50_IMAGENET_BAM \ - ./data/ImageNet/ diff --git a/cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_cbam.sh b/cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_cbam.sh deleted file mode 100755 index 40d2c2d1a..000000000 --- a/cv/classification/cbam/pytorch/scripts/train_imagenet_resnet50_cbam.sh +++ /dev/null @@ -1,9 +0,0 @@ -python3 train_imagenet.py \ - --ngpu 8 \ - --workers 20 \ - --arch resnet --depth 50 \ - --epochs 100 \ - --batch-size 256 --lr 0.1 \ - --att-type CBAM \ - --prefix RESNET50_IMAGENET_CBAM \ - ./data/ImageNet/ diff --git a/cv/classification/cbam/pytorch/train_imagenet.py b/cv/classification/cbam/pytorch/train_imagenet.py deleted file mode 100644 index 46fdcd056..000000000 --- a/cv/classification/cbam/pytorch/train_imagenet.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. -# All Rights Reserved. -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. 
- -import argparse -import os -import shutil -import time -import random - -import torch -import torch.nn as nn -import torch.nn.parallel -import torch.backends.cudnn as cudnn -import torch.optim -import torch.utils.data -import torchvision.transforms as transforms -import torchvision.datasets as datasets -import torchvision.models as models -from MODELS.model_resnet import * -from PIL import ImageFile -ImageFile.LOAD_TRUNCATED_IMAGES = True -model_names = sorted(name for name in models.__dict__ - if name.islower() and not name.startswith("__") - and callable(models.__dict__[name])) - -parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') -parser.add_argument('data', metavar='DIR', - help='path to dataset') -parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet', - help='model architecture: ' + - ' | '.join(model_names) + - ' (default: resnet18)') -parser.add_argument('--depth', default=50, type=int, metavar='D', - help='model depth') -parser.add_argument('--ngpu', default=4, type=int, metavar='G', - help='number of gpus to use') -parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', - help='number of data loading workers (default: 4)') -parser.add_argument('--epochs', default=90, type=int, metavar='N', - help='number of total epochs to run') -parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='manual epoch number (useful on restarts)') -parser.add_argument('-b', '--batch-size', default=256, type=int, - metavar='N', help='mini-batch size (default: 256)') -parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, - metavar='LR', help='initial learning rate') -parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') -parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)') -parser.add_argument('--print-freq', '-p', default=10, type=int, - metavar='N', help='print frequency (default: 10)') -parser.add_argument('--resume', default='', type=str, metavar='PATH', - help='path to latest checkpoint (default: none)') -parser.add_argument("--seed", type=int, default=1234, metavar='BS', help='input batch size for training (default: 64)') -parser.add_argument("--prefix", type=str, required=True, metavar='PFX', help='prefix for logging & checkpoint saving') -parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluation only') -parser.add_argument('--att-type', type=str, choices=['BAM', 'CBAM'], default=None) -best_prec1 = 0 - -if not os.path.exists('./checkpoints'): - os.mkdir('./checkpoints') - -def main(): - global args, best_prec1 - global viz, train_lot, test_lot - args = parser.parse_args() - print ("args", args) - - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - random.seed(args.seed) - - # create model - if args.arch == "resnet": - model = ResidualNet( 'ImageNet', args.depth, 1000, args.att_type ) - - # define loss function (criterion) and optimizer - criterion = nn.CrossEntropyLoss().cuda() - - optimizer = torch.optim.SGD(model.parameters(), args.lr, - momentum=args.momentum, - weight_decay=args.weight_decay) - model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu))) - #model = torch.nn.DataParallel(model).cuda() - model = model.cuda() - print ("model") - print (model) - - # get the number of model parameters - print('Number of model parameters: {}'.format( - sum([p.data.nelement() for p in model.parameters()]))) - - # optionally resume 
from a checkpoint - if args.resume: - if os.path.isfile(args.resume): - print("=> loading checkpoint '{}'".format(args.resume)) - checkpoint = torch.load(args.resume) - args.start_epoch = checkpoint['epoch'] - best_prec1 = checkpoint['best_prec1'] - model.load_state_dict(checkpoint['state_dict']) - if 'optimizer' in checkpoint: - optimizer.load_state_dict(checkpoint['optimizer']) - print("=> loaded checkpoint '{}' (epoch {})" - .format(args.resume, checkpoint['epoch'])) - else: - print("=> no checkpoint found at '{}'".format(args.resume)) - - - cudnn.benchmark = True - - # Data loading code - traindir = os.path.join(args.data, 'train') - valdir = os.path.join(args.data, 'val') - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - - # import pdb - # pdb.set_trace() - val_loader = torch.utils.data.DataLoader( - datasets.ImageFolder(valdir, transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ])), - batch_size=args.batch_size, shuffle=False, - num_workers=args.workers, pin_memory=True) - if args.evaluate: - validate(val_loader, model, criterion, 0) - return - - train_dataset = datasets.ImageFolder( - traindir, - transforms.Compose([ - transforms.RandomResizedCrop(256), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ])) - - train_sampler = None - - train_loader = torch.utils.data.DataLoader( - train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), - num_workers=args.workers, pin_memory=True, sampler=train_sampler) - - for epoch in range(args.start_epoch, args.epochs): - adjust_learning_rate(optimizer, epoch) - - # train for one epoch - train(args, train_loader, model, criterion, optimizer, epoch) - - # evaluate on validation set - prec1 = validate(val_loader, model, criterion, epoch) - - # remember best prec@1 and save checkpoint - is_best = prec1 > best_prec1 - best_prec1 = max(prec1, best_prec1) - save_checkpoint({ - 'epoch': epoch + 1, - 'arch': args.arch, - 'state_dict': model.state_dict(), - 'best_prec1': best_prec1, - 'optimizer' : optimizer.state_dict(), - }, is_best, args.prefix) - - -def train(args, train_loader, model, criterion, optimizer, epoch): - batch_time = AverageMeter() - data_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - - # switch to train mode - model.train() - all_fps = [] - - end = time.time() - for i, (input, target) in enumerate(train_loader): - # measure data loading time - data_time.update(time.time() - end) - - target = target.cuda() - input_var = torch.autograd.Variable(input) - target_var = torch.autograd.Variable(target) - - # compute output - output = model(input_var) - loss = criterion(output, target_var) - - # measure accuracy and record loss - prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) - losses.update(loss.item(), input.size(0)) - top1.update(prec1[0], input.size(0)) - top5.update(prec5[0], input.size(0)) - - # compute gradient and do SGD step - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # measure elapsed time - fps = input.size(0) * args.ngpu / (time.time() - end) - all_fps.append(fps) - batch_time.update(time.time() - end) - end = time.time() - - if i % args.print_freq == 0: - print('Epoch: [{0}][{1}/{2}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' - 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' - 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' - 'Prec@5 
{top5.val:.3f} ({top5.avg:.3f})'.format( - epoch, i, len(train_loader), batch_time=batch_time, - data_time=data_time, loss=losses, top1=top1, top5=top5)) - print(f"EPOCH {epoch} Avg img/s: {sum(all_fps) / len(all_fps)}") - -def validate(val_loader, model, criterion, epoch): - batch_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - - # switch to evaluate mode - model.eval() - - end = time.time() - for i, (input, target) in enumerate(val_loader): - target = target.cuda() - with torch.no_grad(): - input_var = torch.autograd.Variable(input) - target_var = torch.autograd.Variable(target) - - # compute output - output = model(input_var) - loss = criterion(output, target_var) - - # measure accuracy and record loss - prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) - losses.update(loss.item(), input.size(0)) - top1.update(prec1[0], input.size(0)) - top5.update(prec5[0], input.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - if i % args.print_freq == 0: - print('Test: [{0}/{1}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' - 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' - 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( - i, len(val_loader), batch_time=batch_time, loss=losses, - top1=top1, top5=top5)) - - print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' - .format(top1=top1, top5=top5)) - - return top1.avg - - -def save_checkpoint(state, is_best, prefix): - filename='./checkpoints/%s_checkpoint.pth.tar'%prefix - torch.save(state, filename) - if is_best: - shutil.copyfile(filename, './checkpoints/%s_model_best.pth.tar'%prefix) - - -class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -def adjust_learning_rate(optimizer, epoch): - """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" - lr = args.lr * (0.1 ** (epoch // 30)) - for param_group in optimizer.param_groups: - param_group['lr'] = lr - - -def accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) - res.append(correct_k.mul_(100.0 / batch_size)) - return res - - -if __name__ == '__main__': - main() -- Gitee
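For anyone who still needs the removed model, a minimal standalone sketch of what `train_imagenet.py` above constructed; this is hypothetical usage that assumes the pre-deletion `MODELS` package is on the path and a PyTorch version that still accepts the deprecated `init.kaiming_normal` calls:

```python
import torch
from MODELS.model_resnet import ResidualNet

# ResNet-50 whose Bottleneck blocks are followed by CBAM attention
model = ResidualNet('ImageNet', 50, 1000, 'CBAM')
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```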